【强化学习】13 —— Actor-Critic 算法

2023-11-01 11:15

本文主要是介绍【强化学习】13 —— Actor-Critic 算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • REINFORCE 存在的问题
  • Actor-Critic
  • A2C: Advantageous Actor-Critic
  • 代码实践
    • 结果
  • 参考

REINFORCE 存在的问题

  • 基于片段式数据的任务
    • 通常情况下,任务需要有终止状态,REINFORCE才能直接计算累计折扣奖励
  • 低数据利用效率
    • 实际中,REINFORCE需要大量的训练数据
  • 高训练方差(最重要的缺陷
    • 从单个或多个片段中采样到的值函数具有很高的方差

Actor-Critic

在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,用于指导策略的更新。REINFOCE 算法用蒙特卡洛方法来估计 Q ( s , a ) Q(s,a) Q(s,a),能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。
在这里插入图片描述
评论家Critic Q Φ ( s , a ) Q_\Phi (s,a) QΦ(s,a):

  • 学会准确估计当前演员策略(actor policy)的动作价值。通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。 Q Φ ( s , a ) ≃ r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] Q_\Phi(s,a)\simeq r(s,a)+\gamma\mathbb{E}_{s^{\prime}\thicksim p(s^{\prime}|s,a),a^{\prime}\thicksim\pi_\theta(a^{\prime}|s^{\prime})}[Q_\Phi(s^{\prime},a^{\prime})] QΦ(s,a)r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]

演员Actor π θ ( s , a ) \pi_\theta(s,a) πθ(s,a):

  • 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。 J ( θ ) = E s ∼ p , π θ [ π θ ( a ∣ s ) Q Φ ( s , a ) ] ∂ f ( θ ) ∂ θ = E π θ [ ∂ log ⁡ π θ ( a ∣ s ) ∂ θ Q Φ ( s , a ) ] \begin{aligned}J(\theta)&=\mathbb{E}_{s\sim p,\pi_\theta}[\pi_\theta(a|s)Q_\Phi(s,a)]\\\\\frac{\partial f(\theta)}{\partial\theta}&=\mathbb{E}_{\pi_\theta}\left[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}Q_\Phi(s,a)\right]\end{aligned} J(θ)θf(θ)=Esp,πθ[πθ(as)QΦ(s,a)]=Eπθ[θlogπθ(as)QΦ(s,a)]

A2C: Advantageous Actor-Critic

思想:通过减去一个基线函数来标准化评论家的打分

  • 更多信息指导:降低较差动作概率,提高较优动作概率
  • 进一步降低方差

优势函数(Advantage Function) A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^\pi(s,a)=Q^\pi(s,a)-V^\pi(s) Aπ(s,a)=Qπ(s,a)Vπ(s)
在这里插入图片描述
若只采用动作值的方式,虽然也会选择A2,但是方差相对会更大,同时所有的动作都是出于上升的状态,只是上升程度的问题。而采用优势函数的方式,部分动作的优势函数值是负的,可以直接降低相应动作的概率,同时方差更小。

状态-动作值和状态值函数 Q π ( s , a ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) ] \begin{aligned} Q^{\pi}(s,a)& =r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a),a^{\prime}\sim\pi_\theta(a^{\prime}|s^{\prime})}\left[Q_\Phi(s^{\prime},a^{\prime})\right] \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})] \end{aligned} Qπ(s,a)=r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]=r(s,a)+γEsp(ss,a)[Vπ(s)]

因此我们只需要拟合状态值函数来拟合优势函数 A π ( s , a ) = Q π ( s , a ) − V π ( s ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) − V π ( s ) ] ≃ r ( s , a ) + γ ( V π ( s ′ ) − V π ( s ) ) \begin{aligned} A^{\pi}(s,a)& =Q^\pi(s,a)-V^\pi(s) \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})-V^{\pi}(s)] \\ &\simeq r(s,a)+\gamma(V^{\pi}(s^{\prime})-V^{\pi}(s)) \end{aligned} Aπ(s,a)=Qπ(s,a)Vπ(s)=r(s,a)+γEsp(ss,a)[Vπ(s)Vπ(s)]r(s,a)+γ(Vπ(s)Vπ(s))


在策略梯度中,可以把梯度写成下面这个更加一般的形式: g = E [ ∑ t = 0 T ψ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0Tψtθlogπθ(atst)]其中, ψ t \psi_t ψt可以有很多种形式: 1. ∑ t ′ = 0 T γ t ′ r t ′ : 轨迹的总回报; 2. ∑ t ′ = t T γ t ′ − t r t ′ : 动作 a t 之后的回报; 3. ∑ t ′ = t T γ t ′ − t r t ′ − b ( s t ) : 基准线版本的改进 ; 4. Q π θ ( s t , a t ) : 动作价值函数; 5. A π θ ( s t , a t ) : 优势函数; 6. r t + γ V π θ ( s t + 1 ) − V π θ ( s t ) : 时序差分残差。 \begin{aligned} &1.\sum_{t^{\prime}=0}^T\gamma^{t^{\prime}}r_{t^{\prime}}:\textit{轨迹的总回报;} \\ &2.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}:\textit{动作}a_t\textit{之后的回报;} \\ &\begin{aligned}3.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}-b(s_t):\textit{基准线版本的改进};\end{aligned} \\ &4.Q^{\pi_\theta}(s_t,a_t):\textit{动作价值函数;} \\ &5.A^{\pi_\theta}(s_t,a_t):\textit{优势函数;} \\ &6.r_t+\gamma V^{\pi_\theta}(s_{t+1})-V^{\pi_\theta}(s_t):\textit{时序差分残差。} \end{aligned} 1.t=0Tγtrt:轨迹的总回报;2.t=tTγttrt:动作at之后的回报;3.t=tTγttrtb(st):基准线版本的改进;4.Qπθ(st,at):动作价值函数;5.Aπθ(st,at):优势函数;6.rt+γVπθ(st+1)Vπθ(st):时序差分残差。

REINFORCE 通过蒙特卡洛采样的方法对策略梯度的估计是无偏的,但是方差非常大。我们可以用形式(3)引入基线函数 b ( s t ) b(s_t) b(st)(baseline function)来减小方差。此外,我们也可以采用 Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是形式(4)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了 A A A函数,我们称之为优势函数(advantage function),这便是形式(5)。更进一步,我们可以利用等式 Q = r + γ V Q=r+\gamma V Q=r+γV得到形式(6)。

Actor 的更新采用策略梯度的原则,那 Critic 如何更新呢?我们将 Critic 价值网络表示为 V ω V_\omega Vω,参数为 ω \omega ω。于是,我们可以采取时序差分残差的学习方式,对于单个数据定义如下价值函数的损失函数: L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac12(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21(r+γVω(st+1)Vω(st))2

与 DQN 中一样,我们采取类似于目标网络的方法,将上式中 r + γ V ω ( s t + 1 ) r+\gamma V_\omega(s_{t+1}) r+γVω(st+1)作为时序差分目标,不会产生梯度来更新价值函数。因此,价值函数的梯度为:
∇ ω L ( ω ) = − ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) ∇ ω V ω ( s t ) \nabla_{\omega}\mathcal{L}(\omega)=-(r+\gamma V_{\omega}(s_{t+1})-V_{\omega}(s_{t}))\nabla_{\omega}V_{\omega}(s_{t}) ωL(ω)=(r+γVω(st+1)Vω(st))ωVω(st)

然后使用梯度下降方法来更新 Critic 价值网络参数即可。

算法伪代码:

在这里插入图片描述

代码实践

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import utilclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)# 输入是某个状态,输出则是状态的价值。
class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class ActorCritic:def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,device, numOfEpisodes, env):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.gamma = gammaself.device = deviceself.env = envself.numOfEpisodes = numOfEpisodes# 根据动作概率分布随机采样def takeAction(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action_probs = self.actor(state)action_dist = torch.distributions.Categorical(action_probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)# 时序差分目标td_target = rewards + self.gamma * self.critic(next_states) * (1 - terminateds + truncateds)# 时序差分误差td_delta = td_target - self.critic(states)log_probs = torch.log(self.actor(states).gather(1, actions))# 均方误差损失函数actor_loss = torch.mean(-log_probs * td_delta.detach())critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()def ACTrain(self):returnList = []for i in range(10):with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:for episode in range(int(self.numOfEpisodes / 10)):# initialize statestate, info = self.env.reset()terminated = Falsetruncated = FalseepisodeReward = 0transition_dict = {'states': [],'actions': [],'next_states': [],'rewards': [],'terminateds': [],'truncateds': []}# Loop for each step of episode:while 1:action = self.takeAction(state)next_state, reward, terminated, truncated, info = self.env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['terminateds'].append(terminated)transition_dict['truncateds'].append(truncated)state = next_stateepisodeReward += rewardif terminated or truncated:breakself.update(transition_dict)returnList.append(episodeReward)if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (self.numOfEpisodes / 10 * i + episode + 1),'return':'%.3f' % np.mean(returnList[-10:])})pbar.update(1)return returnList

结果

在这里插入图片描述
在这里插入图片描述
可以发现,Actor-Critic 算法很快便能收敛到最优策略,并且训练过程非常稳定,抖动情况相比 REINFORCE 算法有了明显的改进,这说明价值函数的引入减小了方差。

参考

[1] 伯禹AI
[2] https://www.davidsilver.uk/teaching/
[3] 动手学强化学习
[4] Reinforcement Learning

这篇关于【强化学习】13 —— Actor-Critic 算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/322571

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Springboot实现推荐系统的协同过滤算法

《Springboot实现推荐系统的协同过滤算法》协同过滤算法是一种在推荐系统中广泛使用的算法,用于预测用户对物品(如商品、电影、音乐等)的偏好,从而实现个性化推荐,下面给大家介绍Springboot... 目录前言基本原理 算法分类 计算方法应用场景 代码实现 前言协同过滤算法(Collaborativ

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

Java时间轮调度算法的代码实现

《Java时间轮调度算法的代码实现》时间轮是一种高效的定时调度算法,主要用于管理延时任务或周期性任务,它通过一个环形数组(时间轮)和指针来实现,将大量定时任务分摊到固定的时间槽中,极大地降低了时间复杂... 目录1、简述2、时间轮的原理3. 时间轮的实现步骤3.1 定义时间槽3.2 定义时间轮3.3 使用时