强化学习11——DQN算法

2024-01-16 05:36
文章标签 算法 学习 强化 dqn

本文主要是介绍强化学习11——DQN算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DQN算法的全称为,Deep Q-Network,即在Q-learning算法的基础上引用深度神经网络来近似动作函数 Q ( s , a ) Q(s,a) Q(s,a) 。对于传统的Q-learning,当状态或动作数量特别大的时候,如处理一张图片,假设为 210 × 160 × 3 210×160×3 210×160×3,共有 25 6 ( 210 × 60 × 3 ) 256^{(210×60×3)} 256(210×60×3)种状态,难以存储,但可以使用参数化的函数 Q θ Q_{\theta} Qθ 来拟合这些数据,即DQN算法。同时DQN还引用了经验回放和目标网络,接下来将以此介绍。

CartPole 环境

image.png

在车杆环境中,通过移动小车,让小车上的杆保持垂直,如果杆的倾斜度数过大或者车子偏离初始位置的距离过大,或者坚持了一定的时间,则结束本轮训练。该智能体的状态是四维向量,每个状态是连续的,但其动作是离散的,动作的工作空间是2。

维度意义最小值最大值
0车的位置-2.42.4
1车的速度-InfInf
2杆的角度~ -41.8°~ 41.8°
3杆尖端的速度-InfInf
标号动作
0向左移动小车
1向右移动小车

深度网络

我们通过神经网络将输入向量 x x x映射到输出向量 y y y,通过下式表示:
y = f θ ( x ) y=f_{\theta}(x) y=fθ(x)
神经网络可以理解为是一个函数,输入输出都是向量,并且拥有可以学习的参数 θ \theta θ ,通过梯度下降等方法,使得神经网络能够逼近任意函数,当然可以用来近似动作价值函数:
y ⃗ = Q θ ( s ⃗ , a ⃗ ) \vec{y}=Q_{\theta}(\vec{s},\vec{a}) y =Qθ(s ,a )
在本环境种,由于状态的每一维度的值都是连续的,无法使用表格记录,因此可以使用一个神经网络表示函数Q。当动作是连续(无限)时,神经网络的输入是状态s和动作a,输出一个标量,表示在状态s下采取动作a能获得的价值。若动作是离散(有限)的,除了采取动作连续情况下的做法,还可以只将状态s输入到神经忘了,输出每一个动作的Q值。

假设使用神经网络拟合w,则每一个状态s下所有可能动作a的Q值为 Q w ( s , a ) Q_w(s,a) Qw(s,a),我们称为Q网络:

image.png

我们在Q-learning种使用下面的方式更新:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a'\in\mathcal{A}}Q(s',a')-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]
即让 Q ( s , a ) Q(s,a) Q(s,a) r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal{A}}Q(s',a') r+γmaxaAQ(s,a)靠近,那么Q网络的损失函数为均方误差的形式:
ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_{\omega}\frac{1}{2N}\sum_{i=1}^{N}\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a'}Q_\omega\left(s_i',a'\right)\right)\right]^2 ω=argωmin2N1i=1N[Qω(si,ai)(ri+γamaxQω(si,a))]2

经验回访

将Q-learning过程中,每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,之后在训练Q网络时,再从回访缓冲区中,随机采样若干数据进行训练。

image.png

在一般的监督学习中,都是假定训练数据是独立同分布的,而在强化学习中,连续的采样、交互所得到的数据有很强的相关性,这一时刻的状态和上一时刻的状态有关,不满足独立假设。通过在回访缓冲区采样,可以打破样本之间的相关性。另外每一个样本可以使用多次,也适合深度学习。

目标网络

构建两个网络,一个是目标网络,一个是当前网络,二者结构相同,都用于近似Q值。在实践中每隔若干步才把每步更新的当前网络参数复制给目标网络,这样做的好处是保证训练的稳定,当训练的结果不好时,可以不同步当前网络的值,避免Q值的估计发散。

image.png

在计算期望时,使用目标网络来计算:
Q 期望 = [ r t + γ max ⁡ a ′ Q ω ˉ ( s ′ , a ′ ) ] Q_\text{期望}=[r_t+\gamma\max_{a^{\prime}}Q_{\bar{\omega}}(s^{\prime},a^{\prime})] Q期望=[rt+γamaxQωˉ(s,a)]
具体流程如下所示:

  • 使用随机的网络参数 ω \omega ω初始化初始化当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a)
  • 复制相同的参数初始化目标网络 ω ˉ ← ω \bar{\omega}\gets \omega ωˉω
  • 初始化经验回访池R
  • for 序列 e = 1 → E e=1\to E e=1E do
    • 获取环境初始状态 s 1 s_1 s1
    • for 时间步 t = 1 → T 时间步t=1\to T 时间步t=1T do
      • 根据当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a) ϵ − g r e e d y \epsilon -greedy ϵgreedy策略选择动作 a t a_t at
      • 执行动作 a t a_t at,获得回报 r t r_t rt,环境状态变为 s t + 1 s_{t+1} st+1
      • ( s t , a t , r t , s t + 1 ) (s_t,a_t,r_t,s_{t+1}) (st,at,rt,st+1)存储进回池R
      • 若R中数据足够,则从R中采样N个数据 { ( s i , a i , r i , s i + 1 ) } i = 1 , … , N \{(s_i,a_i,r_i,s_{i+1})\}_{i=1,\ldots,N} {(si,ai,ri,si+1)}i=1,,N
      • 对每个数据,用目标网络计算 y = r i + γ max ⁡ a Q ω ˉ ( s i + 1 , a ) y=r_i+\gamma\max_aQ_{\bar{\omega}}(s_{i+1},a) y=ri+γmaxaQωˉ(si+1,a)
      • 最小化目标损失 L = 1 N ∑ i ( y i − Q ω ( s i , a i ) ) 2 L=\frac{1}{N}\sum_{i}(y_{i}-Q_{\omega}(s_{i},a_{i}))^{2} L=N1i(yiQω(si,ai))2,以更新当前网络 Q ω Q_{\omega} Qω
      • 更新目标网络
    • end for
  • end for
import random
from typing import Any
import gymnasium as gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils# 首先定义经验回收池的类,包括加入数据、采样数据
class ReplayBuffer:def __init__(self, capacity):# 创建一个队列,先进先出self.buffer=collections.deque(maxlen=capacity)def add(self,state,action,reward,next_state,done):# 加入数据self.buffer.append((state,action,reward,next_state,done))def sample(self,batch_size):# 随机采样数据mini_batch=random.sample(self.buffer,batch_size)# zip(*)取mini_batch中的每个元素(即取列),并返回一个元组state,action,reward,next_state,done=zip(*mini_batch)return np.array(state), action, reward, np.array(next_state), donedef size(self):return len(self.buffer)# 定义一个只有一层隐藏层的Q网络
class Qnet(torch.nn.Module):def __init__(self,state_dim,hidden_dim,action_dim):super(Qnet,self).__init__()# 定义一个全连接层,输入为state_dim维向量,输出为hidden_dim维向量self.fc1=torch.nn.Linear(state_dim,hidden_dim)# 定义一个全连接层,输入为hidden_dim维向量,输出为action_dim维向量self.fc2=torch.nn.Linear(hidden_dim,action_dim)def forward(self,state):x = F.relu(self.fc1(state))return self.fc2(x)class DQN:def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device):self.action_dim=action_dimself.q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)# 目标网络self.target_q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)# 使用Adam优化器self.optimizer=torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)# 折扣因子self.gamma=gamma# 贪婪策略self.epsilon=epsilon# 目标网络更新频率self.target_update=target_update# 计数器self.count=0self.device=devicedef take_action(self,state):# 判断是否需要贪婪策略if np.random.random()<self.epsilon:action=np.random.randint(self.action_dim)else:state=torch.tensor([state],dtype=torch.float).to(self.device)action=self.q_net(state).argmax().item()return actiondef update(self,transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# Q值q_values=self.q_net(states).gather(1,actions)# 下一个状态的最大Q值max_next_q_values=self.target_q_net(next_states).max(1)[0].view(-1, 1)q_targets=rewards+self.gamma*max_next_q_values*(1-dones)# 反向传播更新参数dqn_loss=torch.mean(F.mse_loss(q_values, q_targets)) # 均方误差损失函数self.optimizer.zero_grad()dqn_loss.backward()self.optimizer.step()if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络self.count += 1lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)return_list = []
for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()[0]aa=state[0]print(state)done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done,info, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)

image.png

这篇关于强化学习11——DQN算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/611475

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Springboot实现推荐系统的协同过滤算法

《Springboot实现推荐系统的协同过滤算法》协同过滤算法是一种在推荐系统中广泛使用的算法,用于预测用户对物品(如商品、电影、音乐等)的偏好,从而实现个性化推荐,下面给大家介绍Springboot... 目录前言基本原理 算法分类 计算方法应用场景 代码实现 前言协同过滤算法(Collaborativ

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

Java时间轮调度算法的代码实现

《Java时间轮调度算法的代码实现》时间轮是一种高效的定时调度算法,主要用于管理延时任务或周期性任务,它通过一个环形数组(时间轮)和指针来实现,将大量定时任务分摊到固定的时间槽中,极大地降低了时间复杂... 目录1、简述2、时间轮的原理3. 时间轮的实现步骤3.1 定义时间槽3.2 定义时间轮3.3 使用时