pytorch强化学习(2)——重写DQN

2024-03-07 01:36
文章标签 学习 重写 pytorch 强化 dqn

本文主要是介绍pytorch强化学习(2)——重写DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

思路

在q-learning当中,Q函数的输入是状态state和action,输出是q-value。

而DQN就是使用神经网络来拟合Q函数,所以从直观上来说,我觉得神经网络的输入应该是状态state和action,输出应该是q-value。

但是,网上绝大多数DQN的代码实现都把state作为网络输入,把所有action的q-value的组合作为网络输出。我觉得这是不直观的、令人费解的,于是我按照自己的想法写了一份DQN代码。

在下面的代码中,神经网络的输入是state和action的连接,若干个浮点数表示state,一个整数表示action。神经网络的输出只有一个元素,代表q-value的值。

代码

env.py

import gym
from DQN_brain import DQN
import matplotlib.pyplot as plt
import numpylr = 1e-3  # 学习率
gamma = 0.9  # 折扣因子
epsilon = 0.9  # 贪心系数
n_hidden = 50  # 隐含层神经元个数env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)if __name__ == '__main__':reward_list = []for i in range(100):# 获取初始环境state = env.reset()[0]  # len=4total_reward = 0done = Falsewhile True:# 获取最优动作action = dqn.optimal_action(state)# 有一定概率不采取最优动作,而是随机选择一个动作执行,这一点很重要if numpy.random.random() > epsilon:action = numpy.random.randint(n_actions)# 更新环境next_state, reward, done, _, _ = env.step(action)dqn.learning(state, next_state, action, reward, done)# 更新一些变量state = next_statetotal_reward += rewardif done:breakprint("第%d回合,total_reward=%f" % (i, total_reward))reward_list.append(total_reward)# 绘图episodes_list = list(range(len(reward_list)))plt.plot(episodes_list, reward_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN Returns')plt.show()

DQN_brain.py

import torch
from torch import nn, Tensorclass Net(nn.Module):# 构造有2个隐含层的网络def __init__(self, input_dim: int, n_hidden: int, output_dim: int):super().__init__()self.network = nn.Sequential(torch.nn.Linear(input_dim, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, output_dim, dtype=torch.float),)# 前传,直接调用Net对象,其实就是调用forward函数def forward(self, x):  # [b,n_states]return self.network(x)class DQN:def __init__(self, n_states: int, n_hidden: int, n_actions: int, lr: float, gamma: float, epsilon: float):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 实例化训练网络,网络的输入是state+action,# 网络的输出是只有一个元素的一维向量,代表该动作在该状态下的q-valueself.q_net = Net(self.n_states + 1, self.n_hidden, 1)# 优化器,更新训练网络的参数self.q_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数# 把状态和动作转化为tensor并连接起来def _concat_input(self, state: list[float], action: int):state_tensor = torch.tensor(state, dtype=torch.float)action_tensor = torch.tensor([action], dtype=torch.float)return torch.concat([state_tensor, action_tensor])# 获取q-value值最大的actiondef optimal_action(self, state: list[float]):q_values = torch.tensor([], dtype=torch.float)# 获取所有action的q-valuefor action in range(self.n_actions):q_values = torch.concat([q_values, self.get_q_value(state, action)])# 返回值最大的那个下标,item()函数只能对只有单个元素的tensor使用return torch.argmax(q_values).item()# 更新网络def learning(self,state: list[float],next_state: list[float],action: int,reward: float,done: bool) -> None:# 下一状态的最优动作next_optimal_action = self.optimal_action(next_state)# 当前状态q_valueq_value = self.get_q_value(state, action)# 下一状态q_valuenext_q_value = self.get_q_value(next_state, next_optimal_action)# q_target计算q_target = reward + self.gamma * next_q_value * (1. - float(done))# 计算loss,然后反向传播,然后梯度下降loss: Tensor = self.criterion(q_value, q_target)self.q_optimizer.zero_grad()loss.backward()self.q_optimizer.step()# 根据状态和动作获取q_valuedef get_q_value(self, state: list[float], action: int) -> Tensor:return self.q_net(self._concat_input(state, action))# tensor([5.5241], grad_fn=<ViewBackward0>)

这篇关于pytorch强化学习(2)——重写DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/782042

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你