强化学习原理python篇06——DQN

2024-01-28 06:04
文章标签 python 学习 原理 强化 06 dqn

本文主要是介绍强化学习原理python篇06——DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习原理python篇05——DQN

  • DQN 算法
    • 定义DQN网络
    • 初始化环境
    • 开始训练
    • 可视化结果

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) w1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)

3)在模拟器中执行动作a,观察奖励r和下一个状态s’。

4)将转移过程(s, a, r, s’)存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w)

7)计算损失: L = ( Q ( s , a , w ) – y ) 2 L=(Q(s, a, w)–y)^2 L=(Q(s,a,w)y)2

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net = nn.Sequential(# 输入为状态,样本为(1*n)nn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env = envself.tgt_net = tgt_netself.net = netdef generate_train_data(self, batch_size, epsilon):state, _ = env.reset()train_data = []while len(train_data)<batch_size*2:q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) > epsilon:action = self.env.action_space.sample()else:action = int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info = env.step(action)train_data.append([state, action, reward, new_state, terminated])state = new_stateif terminated:state, _ = env.reset()continuerandom.shuffle(train_data)                return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。y = []state_space = []action_space = []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward + gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx = [list(range(len(action_space))), action_space]y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, update=True):self.net.load_state_dict(self.tgt_net.state_dict())

初始化环境

   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)# 定义损失函数
loss = nn.MSELoss()# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

gamma = 0.8
for i in range(10000):batch = dqn.generate_train_data(256, 0.8)y_hat, y = dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l = loss(y_hat, y)l.backward()opt.step()print("MSE: {}".format(l.item()))if i % 5 == 0:dqn.update_net_parameters(update=True)

输出:

MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")state, info = env.reset()
total_rewards = 0while True:q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) > 0.9:#     action = env.action_space.sample()# else:action = int(torch.argmax(q_table_state))state, reward, terminated, truncted, info = env.step(action)if terminated:break

这篇关于强化学习原理python篇06——DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652678

相关文章

使用Python的requests库调用API接口的详细步骤

《使用Python的requests库调用API接口的详细步骤》使用Python的requests库调用API接口是开发中最常用的方式之一,它简化了HTTP请求的处理流程,以下是详细步骤和实战示例,涵... 目录一、准备工作:安装 requests 库二、基本调用流程(以 RESTful API 为例)1.

Python清空Word段落样式的三种方法

《Python清空Word段落样式的三种方法》:本文主要介绍如何用python-docx库清空Word段落样式,提供三种方法:设置为Normal样式、清除直接格式、创建新Normal样式,注意需重... 目录方法一:直接设置段落样式为"Normal"方法二:清除所有直接格式设置方法三:创建新的Normal样

Python调用LibreOffice处理自动化文档的完整指南

《Python调用LibreOffice处理自动化文档的完整指南》在数字化转型的浪潮中,文档处理自动化已成为提升效率的关键,LibreOffice作为开源办公软件的佼佼者,其命令行功能结合Python... 目录引言一、环境搭建:三步构建自动化基石1. 安装LibreOffice与python2. 验证安装

把Python列表中的元素移动到开头的三种方法

《把Python列表中的元素移动到开头的三种方法》在Python编程中,我们经常需要对列表(list)进行操作,有时,我们希望将列表中的某个元素移动到最前面,使其成为第一项,本文给大家介绍了把Pyth... 目录一、查找删除插入法1. 找到元素的索引2. 移除元素3. 插入到列表开头二、使用列表切片(Lis

Python按照24个实用大方向精选的上千种工具库汇总整理

《Python按照24个实用大方向精选的上千种工具库汇总整理》本文整理了Python生态中近千个库,涵盖数据处理、图像处理、网络开发、Web框架、人工智能、科学计算、GUI工具、测试框架、环境管理等多... 目录1、数据处理文本处理特殊文本处理html/XML 解析文件处理配置文件处理文档相关日志管理日期和

Python标准库datetime模块日期和时间数据类型解读

《Python标准库datetime模块日期和时间数据类型解读》文章介绍Python中datetime模块的date、time、datetime类,用于处理日期、时间及日期时间结合体,通过属性获取时间... 目录Datetime常用类日期date类型使用时间 time 类型使用日期和时间的结合体–日期时间(

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

Python yield与yield from的简单使用方式

《Pythonyield与yieldfrom的简单使用方式》生成器通过yield定义,可在处理I/O时暂停执行并返回部分结果,待其他任务完成后继续,yieldfrom用于将一个生成器的值传递给另一... 目录python yield与yield from的使用代码结构总结Python yield与yield

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁