[cleanrl] ppo_continuous_action源码解析

2023-12-12 05:44

本文主要是介绍[cleanrl] ppo_continuous_action源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 import库(略)

import os
import random
import time
from dataclasses import dataclassimport gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

2 Args类(略)

定义了所有有关模型的参数,参数含义见英文注释。

@dataclass
class Args:exp_name: str = os.path.basename(__file__)[: -len(".py")]"""the name of this experiment"""seed: int = 1"""seed of the experiment"""torch_deterministic: bool = True"""if toggled, `torch.backends.cudnn.deterministic=False`"""cuda: bool = True"""if toggled, cuda will be enabled by default"""track: bool = False"""if toggled, this experiment will be tracked with Weights and Biases"""wandb_project_name: str = "cleanRL""""the wandb's project name"""wandb_entity: str = None"""the entity (team) of wandb's project"""capture_video: bool = False"""whether to capture videos of the agent performances (check out `videos` folder)"""save_model: bool = False"""whether to save model into the `runs/{run_name}` folder"""upload_model: bool = False"""whether to upload the saved model to huggingface"""hf_entity: str = """""the user or org name of the model repository from the Hugging Face Hub"""# Algorithm specific argumentsenv_id: str = "HalfCheetah-v4""""the id of the environment"""total_timesteps: int = 1000000"""total timesteps of the experiments"""learning_rate: float = 3e-4"""the learning rate of the optimizer"""num_envs: int = 1"""the number of parallel game environments"""num_steps: int = 2048"""the number of steps to run in each environment per policy rollout"""anneal_lr: bool = True"""Toggle learning rate annealing for policy and value networks"""gamma: float = 0.99"""the discount factor gamma"""gae_lambda: float = 0.95"""the lambda for the general advantage estimation"""num_minibatches: int = 32"""the number of mini-batches"""update_epochs: int = 10"""the K epochs to update the policy"""norm_adv: bool = True"""Toggles advantages normalization"""clip_coef: float = 0.2"""the surrogate clipping coefficient"""clip_vloss: bool = True"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""ent_coef: float = 0.0"""coefficient of the entropy"""vf_coef: float = 0.5"""coefficient of the value function"""max_grad_norm: float = 0.5"""the maximum norm for the gradient clipping"""target_kl: float = None"""the target KL divergence threshold"""# to be filled in runtimebatch_size: int = 0"""the batch size (computed in runtime)"""minibatch_size: int = 0"""the mini-batch size (computed in runtime)"""num_iterations: int = 0"""the number of iterations (computed in runtime)"""

3 定义Agent

使用gym.wrappers对原始gym环境进行修改:

  • FlattenObservation:将obs矩阵展平为1维向量
  • RecordEpisodeStatistics:记录episode的统计数据
  • ClipAction:剪裁action以满足action_space的要求
  • NormalizeObservation:对obs矩阵进行归一化
  • TransformObservation:对obs矩阵进行变换
  • NormalizeReward:对reward进行归一化
  • TransformReward:对reward进行变换
def make_env(env_id, idx, capture_video, run_name, gamma):def thunk():if capture_video and idx == 0:env = gym.make(env_id, render_mode="rgb_array")env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")else:env = gym.make(env_id)env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation spaceenv = gym.wrappers.RecordEpisodeStatistics(env)env = gym.wrappers.ClipAction(env)env = gym.wrappers.NormalizeObservation(env)env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))env = gym.wrappers.NormalizeReward(env, gamma=gamma)env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))return envreturn thunk

初始化神经网络中的每层的参数。

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):torch.nn.init.orthogonal_(layer.weight, std)torch.nn.init.constant_(layer.bias, bias_const)return layer

PPO(连续动作)的Agent类,Actor-Critic结构,其中Actor网络和Critic网络均基于MLP构建,激活函数使用Tanh

Critic网络的输入尺寸为(batch_size, obs_dim, 64),输出尺寸为(batch_size, 1),作用是形成obs到value的映射。向外暴露get_value函数以计算状态价值。

Actor网络包含两部分:

  • self.action_mean将obs映射到动作均值,输入尺寸为(batch_size, obs_dim, 64),输出尺寸为(batch_size, action_dim)
  • self.actor_logstd是一个(1, action_dim)大小的Parameter,用于形成动作方差的对数(后面需要对其使用torch.exp保证其为正数)

在cleanrl的实现里,Actor网络使用对角高斯分布来生成连续动作的分布,即根据Normal(action_mean, actor_std)对动作进行抽样。

get_action_and_value函数中计算了:

  • 动作分布probs
  • 动作采样probs.sample()
  • 对数似然probs.log_prob(action).sum(1)
  • probs.entropy().sum(1)
  • 状态价值self.critic(x)

在对数似然和熵的计算中,sum(1)用于计算多个相互独立动作的联合概率。

class Agent(nn.Module):def __init__(self, envs):super().__init__()self.critic = nn.Sequential(layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),nn.Tanh(),layer_init(nn.Linear(64, 64)),nn.Tanh(),layer_init(nn.Linear(64, 1), std=1.0),)self.actor_mean = nn.Sequential(layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),nn.Tanh(),layer_init(nn.Linear(64, 64)),nn.Tanh(),layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),)self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))def get_value(self, x):return self.critic(x)def get_action_and_value(self, x, action=None):action_mean = self.actor_mean(x)action_logstd = self.actor_logstd.expand_as(action_mean)action_std = torch.exp(action_logstd)probs = Normal(action_mean, action_std)if action is None:action = probs.sample()return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

4 训练Agent

设置一些参数,稍微解释一下几个参数的含义:

  • batch_sizenum_envsnum_steps的乘积,表示跑一次迭代能收集到多少样本
  • minibatch_size:每次训练都从大的batch中抽取小的minibatch进行训练
  • num_iterations:整个训练过程跑几轮迭代
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:import wandbwandb.init(project=args.wandb_project_name,entity=args.wandb_entity,sync_tensorboard=True,config=vars(args),name=run_name,monitor_gym=True,save_code=True,)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text("hyperparameters","|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministicdevice = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

实例化envs、agent以及optim。

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

定义需要收集的信息

  • obs:观测到的环境状态
  • actions:动作采样值
  • logprobs:动作采样的对数似然
  • rewards:即时奖励
  • dones:episode是否结束
  • values:状态价值
# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

next_obs存储每步的观测结果,next_done存储每步是否导致episode结束。这两个变量用于计算由最后一个动作导致的下一个状态的价值。

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs, _ = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(args.num_envs).to(device)

step的for循环里,Actor网络和Critic网络基于当前策略(旧策略)收集样本。因为旧策略不作为参数参与到梯度下降过程,因此需要torch.no_grad()包围相关数值的计算过程。

for iteration in range(1, args.num_iterations + 1):# Annealing the rate if instructed to do so.if args.anneal_lr:frac = 1.0 - (iteration - 1.0) / args.num_iterationslrnow = frac * args.learning_rateoptimizer.param_groups[0]["lr"] = lrnowfor step in range(0, args.num_steps):global_step += args.num_envsobs[step] = next_obsdones[step] = next_done# ALGO LOGIC: action logicwith torch.no_grad():action, logprob, _, value = agent.get_action_and_value(next_obs)values[step] = value.flatten()actions[step] = actionlogprobs[step] = logprob# TRY NOT TO MODIFY: execute the game and log data.next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())next_done = np.logical_or(terminations, truncations)rewards[step] = torch.tensor(reward).to(device).view(-1)next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)if "final_info" in infos:for info in infos["final_info"]:if info and "episode" in info:print(f"global_step={global_step}, episodic_return={info['episode']['r']}")writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

这部分基于value、reward计算GAE(广义优势估计)。从最后一个reward开始,通过迭代计算:

  • δ t = r t + γ ∗ V ( s t + 1 ) − V ( s t ) \delta_t = r_t+\gamma * V(s_{t+1})-V(s_t) δt=rt+γV(st+1)V(st)
  • a t = δ t + γ ∗ λ ∗ a t + 1 a_t = \delta_t + \gamma * \lambda * a_{t+1} at=δt+γλat+1
###############################################
for iteration in range(1, args.num_iterations + 1):【在iteration的for循环中拼接上一段代码】
################################################ bootstrap value if not donewith torch.no_grad():next_value = agent.get_value(next_obs).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)lastgaelam = 0for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = next_valueelse:nextnonterminal = 1.0 - dones[t + 1]nextvalues = values[t + 1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelamreturns = advantages + values

原先的矩阵都是(num_envs, num_steps, XX_dim)的形状,现在转换成(batch_size, XX_dim)的形状,后面要基于batch划分minibatch进行训练。

###############################################
for iteration in range(1, args.num_iterations + 1):【在iteration的for循环中拼接上一段代码】
################################################ flatten the batchb_obs = obs.reshape((-1,) + envs.single_observation_space.shape)b_logprobs = logprobs.reshape(-1)b_actions = actions.reshape((-1,) + envs.single_action_space.shape)b_advantages = advantages.reshape(-1)b_returns = returns.reshape(-1)b_values = values.reshape(-1)# Optimizing the policy and value networkb_inds = np.arange(args.batch_size)clipfracs = []

minibatch的划分是基于b_inds进行的,所以先使用shuffle进行打乱,然后在start的for循环里每次抽取minibatch,计算新的newlogprobentropynewvalue。根据新的和旧的logprob计算ratio,用于后面的PPO截断。

###############################################
for iteration in range(1, args.num_iterations + 1):......
###############################################for epoch in range(args.update_epochs):np.random.shuffle(b_inds)for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()

首先采用kl-approx使用蒙特卡洛近似KL散度approx_kl,然后获取minibatch的advantage,按需归一化。最后进行PPO截断,计算policy loss。

###############################################
for iteration in range(1, args.num_iterations + 1):......for epoch in range(args.update_epochs):......for start in range(0, args.batch_size, args.minibatch_size):【在start的for循环中拼接上一段代码】
###############################################with torch.no_grad():# calculate approx_kl http://joschu.net/blog/kl-approx.htmlold_approx_kl = (-logratio).mean()approx_kl = ((ratio - 1) - logratio).mean()clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()

根据旧的b_returns和新的newvalue计算value loss。当然这里也提供了value loss clip。

###############################################
for iteration in range(1, args.num_iterations + 1):......for epoch in range(args.update_epochs):......for start in range(0, args.batch_size, args.minibatch_size):【在start的for循环中拼接上一段代码】
################################################ Value lossnewvalue = newvalue.view(-1)if args.clip_vloss:v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2v_clipped = b_values[mb_inds] + torch.clamp(newvalue - b_values[mb_inds],-args.clip_coef,args.clip_coef,)v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)v_loss = 0.5 * v_loss_max.mean()else:v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

根据policy loss、value loss和entropy加权求和得到总的loss,然后反向传播优化参数。在之前计算了新旧策略之间的KL散度,这里可以利用KL散度实现early stopping,即KL散度大于阈值则停止当前batch的训练。(当然也可以停止掉当前minibatch的训练)

###############################################
for iteration in range(1, args.num_iterations + 1):......for epoch in range(args.update_epochs):......for start in range(0, args.batch_size, args.minibatch_size):【在start的for循环中拼接上一段代码】
###############################################entropy_loss = entropy.mean()loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coefoptimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)optimizer.step()if args.target_kl is not None and approx_kl > args.target_kl:break

tensorboard记录数据,没什么好说的。

###############################################
for iteration in range(1, args.num_iterations + 1):【在iteration的for循环中拼接上一段代码】
###############################################y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()var_y = np.var(y_true)explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y# TRY NOT TO MODIFY: record rewards for plotting purposeswriter.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)writer.add_scalar("losses/value_loss", v_loss.item(), global_step)writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)writer.add_scalar("losses/explained_variance", explained_var, global_step)print("SPS:", int(global_step / (time.time() - start_time)))writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

模型保存的一些操作,也没什么好说的。

if args.save_model:model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"torch.save(agent.state_dict(), model_path)print(f"model saved to {model_path}")from cleanrl_utils.evals.ppo_eval import evaluateepisodic_returns = evaluate(model_path,make_env,args.env_id,eval_episodes=10,run_name=f"{run_name}-eval",Model=Agent,device=device,gamma=args.gamma,)for idx, episodic_return in enumerate(episodic_returns):writer.add_scalar("eval/episodic_return", episodic_return, idx)if args.upload_model:from cleanrl_utils.huggingface import push_to_hubrepo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_namepush_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")envs.close()
writer.close()

这篇关于[cleanrl] ppo_continuous_action源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/483445

相关文章

全面解析Golang 中的 Gorilla CORS 中间件正确用法

《全面解析Golang中的GorillaCORS中间件正确用法》Golang中使用gorilla/mux路由器配合rs/cors中间件库可以优雅地解决这个问题,然而,很多人刚开始使用时会遇到配... 目录如何让 golang 中的 Gorilla CORS 中间件正确工作一、基础依赖二、错误用法(很多人一开

Mysql中设计数据表的过程解析

《Mysql中设计数据表的过程解析》数据库约束通过NOTNULL、UNIQUE、DEFAULT、主键和外键等规则保障数据完整性,自动校验数据,减少人工错误,提升数据一致性和业务逻辑严谨性,本文介绍My... 目录1.引言2.NOT NULL——制定某列不可以存储NULL值2.UNIQUE——保证某一列的每一

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

MySQL CTE (Common Table Expressions)示例全解析

《MySQLCTE(CommonTableExpressions)示例全解析》MySQL8.0引入CTE,支持递归查询,可创建临时命名结果集,提升复杂查询的可读性与维护性,适用于层次结构数据处... 目录基本语法CTE 主要特点非递归 CTE简单 CTE 示例多 CTE 示例递归 CTE基本递归 CTE 结

Spring Boot 3.x 中 WebClient 示例详解析

《SpringBoot3.x中WebClient示例详解析》SpringBoot3.x中WebClient是响应式HTTP客户端,替代RestTemplate,支持异步非阻塞请求,涵盖GET... 目录Spring Boot 3.x 中 WebClient 全面详解及示例1. WebClient 简介2.

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

Spring Boot3.0新特性全面解析与应用实战

《SpringBoot3.0新特性全面解析与应用实战》SpringBoot3.0作为Spring生态系统的一个重要里程碑,带来了众多令人兴奋的新特性和改进,本文将深入解析SpringBoot3.0的... 目录核心变化概览Java版本要求提升迁移至Jakarta EE重要新特性详解1. Native Ima

spring中的@MapperScan注解属性解析

《spring中的@MapperScan注解属性解析》@MapperScan是Spring集成MyBatis时自动扫描Mapper接口的注解,简化配置并支持多数据源,通过属性控制扫描路径和过滤条件,利... 目录一、核心功能与作用二、注解属性解析三、底层实现原理四、使用场景与最佳实践五、注意事项与常见问题六

nginx -t、nginx -s stop 和 nginx -s reload 命令的详细解析(结合应用场景)

《nginx-t、nginx-sstop和nginx-sreload命令的详细解析(结合应用场景)》本文解析Nginx的-t、-sstop、-sreload命令,分别用于配置语法检... 以下是关于 nginx -t、nginx -s stop 和 nginx -s reload 命令的详细解析,结合实际应