使用paddle2的DQN跑Mountain

2023-10-30 00:20
文章标签 使用 dqn mountain paddle2

本文主要是介绍使用paddle2的DQN跑Mountain,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.Agent

Agent就是一个接口,sample就是实现了一个随机探索,本质还是用的self.alg.predict()函数
然后Agent.learn(self, obs, act, reward, next_obs, terminal)就是将从环境拿到的obs, act, reward, next_obs, terminal转化为tensor形式,然后送给算法中的learn,即self.alg.learn(obs, act, reward, next_obs, terminal)

import parl
import paddle
import numpy as npclass Agent(parl.Agent):def __init__(self, algorithm, act_dim, e_greed=0.1, e_greed_decrement=0):super(Agent, self).__init__(algorithm)assert isinstance(act_dim, int)self.act_dim = act_dimself.global_step = 0self.update_target_steps = 200self.e_greed = e_greedself.e_greed_decrement = e_greed_decrementdef sample(self, obs):"""Sample an action `for exploration` when given an observationArgs:obs(np.float32): shape of (obs_dim,)Returns:act(int): action"""sample = np.random.random()if sample < self.e_greed:act = np.random.randint(self.act_dim)else:if np.random.random() < 0.01:act = np.random.randint(self.act_dim)else:act = self.predict(obs)self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)return act  ##返回动作def predict(self, obs):"""Predict an action when given an observationArgs:obs(np.float32): shape of (obs_dim,)Returns:act(int): action"""obs = paddle.to_tensor(obs, dtype='float32')	##将环境obs转换为tensor形式pred_q = self.alg.predict(obs)     ##调用了算法中的predict函数act = pred_q.argmax().numpy()[0]	##找最大值,返回第一个数据即actreturn actdef learn(self, obs, act, reward, next_obs, terminal):"""Update model with an episode dataArgs:obs(np.float32): shape of (batch_size, obs_dim)act(np.int32): shape of (batch_size)reward(np.float32): shape of (batch_size)next_obs(np.float32): shape of (batch_size, obs_dim)terminal(np.float32): shape of (batch_size)Returns:loss(float)"""if self.global_step % self.update_target_steps == 0:self.alg.sync_target()self.global_step += 1##扩展维度1变为【1】act = np.expand_dims(act, axis=-1)reward = np.expand_dims(reward, axis=-1)terminal = np.expand_dims(terminal, axis=-1)##将arrary转换为tensor形式obs = paddle.to_tensor(obs, dtype='float32')act = paddle.to_tensor(act, dtype='int32')reward = paddle.to_tensor(reward, dtype='float32')next_obs = paddle.to_tensor(next_obs, dtype='float32')terminal = paddle.to_tensor(terminal, dtype='float32')##调用算法中的learn,因为self.alg引用算法中的learn了loss = self.alg.learn(obs, act, reward, next_obs, terminal)return loss.numpy()[0]

2.Model

model就是定义网络的结构,nn.Linear(输入维度,输出维度)。前向网络就是输入进入全连接层,然后relu激活函数;再经过第二层全连接层,然后relu激活函数,最后再全连接层输出。输出维度为act_dim。

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parlclass Model(parl.Model):""" Linear network to solve Cartpole problem.Args:obs_dim (int): Dimension of observation space.act_dim (int): Dimension of action space."""def __init__(self, obs_dim, act_dim):super(Model, self).__init__()hid1_size = 128hid2_size = 128self.fc1 = nn.Linear(obs_dim, hid1_size)self.fc2 = nn.Linear(hid1_size, hid2_size)self.fc3 = nn.Linear(hid2_size, act_dim)def forward(self, obs):h1 = F.relu(self.fc1(obs))h2 = F.relu(self.fc2(h1))Q = self.fc3(h2)return Q

3.Train

import gym
import numpy as np
from parl.utils import logger, ReplayMemoryfrom Model import Model
from Agent import Agent
from parl.algorithms import DQNLEARN_FREQ = 5  # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率
MEMORY_SIZE = 20000  # replay memory的大小,越大越占用内存
MEMORY_WARMUP_SIZE = 200  # replay_memory 里需要预存一些经验数据,再从里面sample一个batch的经验让agent去learn
BATCH_SIZE = 32  # 每次给agent learn的数据数量,从replay memory随机里sample一批数据出来
LEARNING_RATE = 0.001  # 学习率
GAMMA = 0.99  # reward 的衰减因子,一般取 0.90.999 不等# train an episode
def run_train_episode(agent, env, rpm):total_reward = 0obs = env.reset()step = 0while True:step += 1action = agent.sample(obs)——训练的时候用sampl函数next_obs, reward, done, _ = env.step(action)#这里体现了Q-learningrpm.append(obs, action, reward, next_obs, done)——存储到经验池# train model——进行学习if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):# s,a,r,s',done(batch_obs, batch_action, batch_reward, batch_next_obs,batch_done) = rpm.sample_batch(BATCH_SIZE)train_loss = agent.learn(batch_obs, batch_action, batch_reward,batch_next_obs, batch_done)total_reward += rewardobs = next_obsif done:breakreturn total_reward# evaluate 5 episodes
def run_evaluate_episodes(agent, env, eval_episodes=5, render=False):eval_reward = []for i in range(eval_episodes):obs = env.reset()episode_reward = 0while True:action = agent.predict(obs)——用训练的模型与环境交互obs, reward, done, _ = env.step(action)episode_reward += reward##记录一轮的游戏得分if render:env.render()if done:breakeval_reward.append(episode_reward)##组装为数组,再进行求平均return np.mean(eval_reward)def main():env = gym.make('MountainCar-v0')obs_dim = env.observation_space.shape[0]act_dim = env.action_space.nlogger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))# set action_shape = 0 while in discrete control environmentrpm = ReplayMemory(MEMORY_SIZE, obs_dim, 0)# build an agentmodel = Model(obs_dim=obs_dim, act_dim=act_dim)alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)agent = Agent(alg, act_dim=act_dim, e_greed=0.1, e_greed_decrement=1e-6)
##加载模型save_path = './model.ckpt'agent.restore(save_path)# warmup memorywhile len(rpm) < MEMORY_WARMUP_SIZE:run_train_episode(agent, env, rpm)##总训练次数max_episode = 2000# start trainingepisode = 0while episode < max_episode:# train part一轮训练50for i in range(50):total_reward = run_train_episode(agent, env, rpm)episode += 1# test parteval_reward = run_evaluate_episodes(agent, env, render=True)logger.info('episode:{}    e_greed:{}   Test reward:{}'.format(episode, agent.e_greed, eval_reward))# save the parameters to ./model.ckptsave_path = './model.ckpt'agent.save(save_path)if __name__ == '__main__':main()

4.结果

在这里插入图片描述

这篇关于使用paddle2的DQN跑Mountain的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/304365

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1