强化学习原理python篇06——DQN

2024-01-28 06:04
文章标签 python 学习 原理 强化 06 dqn

本文主要是介绍强化学习原理python篇06——DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习原理python篇05——DQN

  • DQN 算法
    • 定义DQN网络
    • 初始化环境
    • 开始训练
    • 可视化结果

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) w1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)

3)在模拟器中执行动作a,观察奖励r和下一个状态s’。

4)将转移过程(s, a, r, s’)存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w)

7)计算损失: L = ( Q ( s , a , w ) – y ) 2 L=(Q(s, a, w)–y)^2 L=(Q(s,a,w)y)2

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net = nn.Sequential(# 输入为状态,样本为(1*n)nn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env = envself.tgt_net = tgt_netself.net = netdef generate_train_data(self, batch_size, epsilon):state, _ = env.reset()train_data = []while len(train_data)<batch_size*2:q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) > epsilon:action = self.env.action_space.sample()else:action = int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info = env.step(action)train_data.append([state, action, reward, new_state, terminated])state = new_stateif terminated:state, _ = env.reset()continuerandom.shuffle(train_data)                return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。y = []state_space = []action_space = []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward + gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx = [list(range(len(action_space))), action_space]y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, update=True):self.net.load_state_dict(self.tgt_net.state_dict())

初始化环境

   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)# 定义损失函数
loss = nn.MSELoss()# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

gamma = 0.8
for i in range(10000):batch = dqn.generate_train_data(256, 0.8)y_hat, y = dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l = loss(y_hat, y)l.backward()opt.step()print("MSE: {}".format(l.item()))if i % 5 == 0:dqn.update_net_parameters(update=True)

输出:

MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")state, info = env.reset()
total_rewards = 0while True:q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) > 0.9:#     action = env.action_space.sample()# else:action = int(torch.argmax(q_table_state))state, reward, terminated, truncted, info = env.step(action)if terminated:break

这篇关于强化学习原理python篇06——DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652678

相关文章

Java中流式并行操作parallelStream的原理和使用方法

《Java中流式并行操作parallelStream的原理和使用方法》本文详细介绍了Java中的并行流(parallelStream)的原理、正确使用方法以及在实际业务中的应用案例,并指出在使用并行流... 目录Java中流式并行操作parallelStream0. 问题的产生1. 什么是parallelS

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1

一文详解Python如何开发游戏

《一文详解Python如何开发游戏》Python是一种非常流行的编程语言,也可以用来开发游戏模组,:本文主要介绍Python如何开发游戏的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、python简介二、Python 开发 2D 游戏的优劣势优势缺点三、Python 开发 3D

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Python实现字典转字符串的五种方法

《Python实现字典转字符串的五种方法》本文介绍了在Python中如何将字典数据结构转换为字符串格式的多种方法,首先可以通过内置的str()函数进行简单转换;其次利用ison.dumps()函数能够... 目录1、使用json模块的dumps方法:2、使用str方法:3、使用循环和字符串拼接:4、使用字符

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Redis中Set结构使用过程与原理说明

《Redis中Set结构使用过程与原理说明》本文解析了RedisSet数据结构,涵盖其基本操作(如添加、查找)、集合运算(交并差)、底层实现(intset与hashtable自动切换机制)、典型应用场... 目录开篇:从购物车到Redis Set一、Redis Set的基本操作1.1 编程常用命令1.2 集