强化学习笔记1——ppo算法

2023-10-12 19:59
文章标签 算法 学习 笔记 强化 ppo

本文主要是介绍强化学习笔记1——ppo算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考莫烦Python的学习视频链接: 莫烦Python的学习视频.

  1. why PPO?
    根据 OpenAI 的官方博客, PPO 已经成为他们在强化学习上的默认算法. 如果一句话概括 PPO: OpenAI 提出的一种解决 Policy Gradient 不好确定 Learning rate (或者 Step size) 的问题. 因为如果 step size 过大, 学出来的 Policy 会一直乱动, 不会收敛, 但如果 Step Size 太小, 对于完成训练, 我们会等到绝望. PPO 利用 New Policy 和 Old Policy 的比例, 限制了 New Policy 的更新幅度, 让 Policy Gradient 对稍微大点的 Step size 不那么敏感.
  2. 一些讲解,来源: link.
    在这里插入图片描述on-policy和off-policy的含义
    在这里插入图片描述
  • 1可以用迁移思想来理解这个。t1为要训练的agent, t2 为fixed的agent,用于sample
  • 2在t2上反复采样,可以理解t2是个高手。t2是固定的,也不会导致(s,a)中某个a的概率发生变化
  • 3即可以理解为t2中采的样本一定都是正确的!。把t1模型,训练成t2模型的水平。
  • 4所以alpha go 的前期训练好像是借助的棋谱。棋谱相当于t2, 刚开始的go相当于t1

简单来说,因为难以在 p ( x ) p(x) p(x)中采样,所以曲线救国,从 q ( x ) q(x) q(x)采样求期望,再乘以一个weight,即 p ( x ) / q ( x ) p(x)/q(x) p(x)/q(x)
在这里插入图片描述
如果 p ( x ) p(x) p(x) q ( x ) q(x) q(x)差距很大,要多采样才行,采样数少会错误
在这里插入图片描述
替换原理
在这里插入图片描述
原问题中的替换原理。when to stop? 引入PPO解决原网络与现在网络不能差太多的问题,即两个分布不可以差太多
在这里插入图片描述
在这里插入图片描述
4. 算法伪代码
PPO-Penalty:近似地解决了TRPO之类的受KL约束的更新,但对目标函数中的KL偏离进行了惩罚而不是使其成为硬约束,并在训练过程中自动调整惩罚系数,以便对其进行适当缩放。
PPO-Clip:在目标中没有KL散度项,也完全没有约束。取而代之的是依靠对目标函数的专门裁剪来减小新老策略的差异。
在这里插入图片描述
KL散度用来限制新策略的更新幅度(重要)
在这里插入图片描述
在PPO clip中去掉了KL散度的计算,只限制了比例。效果更好。
多线程将加快学习进程。
5. 算法结构
在这里插入图片描述
在这里插入图片描述

  1. 代码

class PPO:def __init__(self):# 建 Actor Critic 网络# 搭计算图纸 graphself.sess = tf.Session()self.tfs = tf.placeholder(tf.float32, [None, S_DIM], 'state')  # 状态空间[None, S_DIM]self._build_anet('Critic')  # 建立critic网络,更新self.v# 得到self.v之后计算损失函数with tf.variable_scope('closs'):self.tfdc_r = tf.placeholder(tf.float32, [None, 1], name='discounted_r')  # 折扣奖励self.adv = self.tfdc_r - self.v  # ?这个可以理解TD error吗?closs = tf.reduce_mean(tf.square(self.adv))  # critic的损失函数self.ctrain = tf.train.AdamOptimizer(C_LR).minimize(closs)  # 接着训练critic# 建立pi网络和old_pi网络,获得相应参数pi, pi_params = self._build_anet('pi', trainable=True)oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)# ??这是什么with tf.variable_scope('sample_action'):self.sample_op = tf.squeeze(pi.sample(1), axis=0)# 将新pi参数赋给old_piwith tf.variable_scope('update_oldpi'):# 此时还没有赋值,要sess.run才行self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]with tf.variable_scope('aloss'):self.tfa = tf.placeholder(dtype=tf.float32, shape=[None, A_DIM], name='action')  # 动作空间self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')  # 优势函数with tf.variable_scope('surrogate'):ratio = pi.prob(self.tfa) / oldpi.prob(self.tfa)  # 概率密度surr = ratio * self.tfadv  # 差异大,奖励大惊讶度高if METHOD['name'] == 'kl_pen':self.tflam = tf.placeholder(tf.float32, None, 'lambda')kl = tf.distributions.kl_divergence(oldpi, pi)self.kl_mean = tf.reduce_mean(kl)self.aloss = -(tf.reduce_mean(surr - self.tflam * kl))else:  # clipping method, find this is better  限制了surrogate的变化幅度self.aloss = -tf.reduce_mean(tf.minimum(surr,tf.clip_by_value(ratio, 1. - METHOD['epsilon'], 1. + METHOD['epsilon']) * self.tfadv))  # 限定ratio的范围,我也不懂这个参数是怎么调的self.atrain = tf.train.AdamOptimizer(A_LR).minimize(self.aloss) # A_LR学习率,损失函数aloss# 写日志文件tf.summary.FileWriter('log/', self.sess.graph)self.sess.run(tf.global_variables_initializer())# 搭建网络函数def _build_anet(self, name, trainable=True):# Critic网络部分if name == 'Critic':with tf.variable_scope(name):# self.s_Critic = tf.placeholder(tf.float32, [None, S_DIM], 'state')# 两层神经网络,输出是self.v,即估计state valuel1_Critic = tf.layers.dense(self.tfs, 100, tf.nn.relu, trainable=trainable, name='l1')self.v = tf.layers.dense(l1_Critic, 1, trainable=trainable, name='value_predict')# Actor部分,分为‘pi’和‘oldpi’两个神经网络# 返回动作分布以及网络参数列表else:with tf.variable_scope(name):# self.s_Actor = tf.placeholder(tf.float32, [None, S_DIM], 'state')# ??这部分l1_Actor = tf.layers.dense(self.tfs, 100, tf.nn.relu, trainable=trainable, name='l1')mu = 2 * tf.layers.dense(l1_Actor, A_DIM, tf.nn.tanh, trainable=trainable, name='mu')sigma = tf.layers.dense(l1_Actor, A_DIM, tf.nn.softplus, trainable=trainable, name='sigma')norm_list = tf.distributions.Normal(loc=mu, scale=sigma)    # 正态分布params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)   #提取网络参数列表return norm_list, paramsdef update(self, s, a, r):# 将值赋给old_pi网络self.sess.run(self.update_oldpi_op)#为了取回self.adv,adv = self.sess.run(self.adv, {self.tfdc_r: r, self.tfs: s})    # 后面那个字典是什么意思?if METHOD['name'] == 'kl_pen':  # 选择kl-penalty方式for _ in range(A_UPDATE_STEPS):_, kl = self.sess.run([self.atrain, self.kl_mean],{self.tfa: a, self.tfadv: adv, self.tfs: s, self.tflam: METHOD['lam']})if kl > 4 * METHOD['kl_target']:  # this in in google's paperbreakif kl < METHOD['kl_target'] / 1.5:  # adaptive lambda, this is in OpenAI's paperMETHOD['lam'] /= 2elif kl > METHOD['kl_target'] * 1.5:METHOD['lam'] *= 2METHOD['lam'] = np.clip(METHOD['lam'], 1e-4, 10)  # sometimes explode, this clipping is my solutionelse:# 训练actor网络[self.sess.run(self.atrain, {self.tfs: s, self.tfa: a, self.tfadv: adv}) for _ in range(A_UPDATE_STEPS)]# 训练critic网络[self.sess.run(self.ctrain, {self.tfs: s, self.tfdc_r: r}) for _ in range(C_UPDATE_STEPS)]def choose_action(self, s):# 选动作s = s[np.newaxis, :]a = self.sess.run(self.sample_op, {self.tfs: s})[0]return np.clip(a, -2, 2)def get_v(self, s):# 算 state valueif s.ndim < 2:s = s[np.newaxis, :]return self.sess.run(self.v, {self.tfs: s})env = gym.make('Pendulum-v0').unwrapped
S_DIM = env.observation_space.shape[0]
A_DIM = env.action_space.shape[0]
ppo = PPO()
all_ep_r = []# ppo和环境的互动
# 达到最大回合数退出
for ep in range(EP_MAX):s = env.reset()buffer_s, buffer_a, buffer_r = [], [], []ep_r = 0for t in range(EP_LEN):env.render()a = ppo.choose_action(s)s_, r, done, _ = env.step(a)# 存储在buffer当中buffer_s.append(s)buffer_a.append(a)buffer_r.append((r + 8) / 8)s = s_ep_r += r# 如果buffer收集一个batch了或者episode结束了if (t + 1) % BATCH == 0 or t == EP_LEN - 1:# 计算discounted rewardv_s_ = ppo.get_v(s_)discounted_r = []for r in buffer_r[::-1]:v_s_ = r + GAMMA * v_s_discounted_r.append(v_s_)discounted_r.reverse()bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(discounted_r)# 清空bufferbuffer_s, buffer_a, buffer_r = [], [], []ppo.update(bs, ba, br)  # 更新PPOif ep == 0:all_ep_r.append(ep_r)else:all_ep_r.append(all_ep_r[-1] * 0.9 + ep_r * 0.1)print('Ep:%d | Ep_r:%f' % (ep, ep_r))plt.plot(np.arange(len(all_ep_r)), all_ep_r)
plt.xlabel('Episode')
plt.ylabel('Moving averaged episode reward')
plt.show()
  1. 多线程DPPO

这篇关于强化学习笔记1——ppo算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/198101

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Springboot实现推荐系统的协同过滤算法

《Springboot实现推荐系统的协同过滤算法》协同过滤算法是一种在推荐系统中广泛使用的算法,用于预测用户对物品(如商品、电影、音乐等)的偏好,从而实现个性化推荐,下面给大家介绍Springboot... 目录前言基本原理 算法分类 计算方法应用场景 代码实现 前言协同过滤算法(Collaborativ

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.