【动手学强化学习 / Hands-on Reinforcement Learning】The SAC Algorithm

2023-11-23 14:31
Tags: algorithm, learning, reinforcement, hands-on, sac

This article walks through the SAC algorithm as presented in 动手学强化学习 (Hands-on Reinforcement Learning). I hope it serves as a useful reference for developers working on this topic.

An excellent Zhihu post explaining the SAC algorithm: https://zhuanlan.zhihu.com/p/85003758

1. Slides

(The slide screenshots from the original post are not reproduced here. The one annotation that survives: the higher the temperature coefficient α, the flatter the policy distribution.)
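To make the temperature remark concrete: in the maximum-entropy framework behind SAC, the soft-optimal policy has a Boltzmann form, π(a|s) ∝ exp(Q(s, a)/α), so α acts like a temperature. The small standalone snippet below (an illustration only, not part of the book's code; the names boltzmann and q_values are made up for this example) shows how raising α flattens the distribution over a fixed set of action values:

    import numpy as np

    def boltzmann(q_values, alpha):
        """pi(a) proportional to exp(Q(a) / alpha): a softmax with temperature alpha."""
        logits = np.asarray(q_values, dtype=np.float64) / alpha
        logits -= logits.max()  # subtract the max for numerical stability
        probs = np.exp(logits)
        return probs / probs.sum()

    q_values = [1.0, 0.5, 0.0]  # toy action values for a single state
    print(boltzmann(q_values, alpha=0.1))  # low temperature: ~[0.99, 0.01, 0.00], nearly greedy
    print(boltzmann(q_values, alpha=5.0))  # high temperature: ~[0.37, 0.33, 0.30], nearly uniform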

2. Algorithm pseudocode

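The pseudocode screenshots from the original post are not available here. For reference, the soft actor-critic loop with automatic temperature tuning (Haarnoja et al., 2018), which the code in the next section implements, runs roughly as follows:

1. Initialize the policy network π_θ, two Q networks Q_ω1 and Q_ω2, their target copies, and an empty replay buffer.
2. Interact with the environment: sample a ~ π_θ(·|s), execute it, and store (s, a, r, s', done) in the buffer.
3. Once the buffer holds enough transitions, repeatedly sample a mini-batch and
   - compute the soft TD target y = r + γ(1 − done)[min_j Q_target,j(s', a') − α·log π_θ(a'|s')], with a' freshly sampled from π_θ(·|s'), and regress both Q networks toward y;
   - update the policy by minimizing E[α·log π_θ(ã|s) − min_j Q_j(s, ã)], where ã is drawn via the reparameterization trick;
   - update the temperature by minimizing E[α·(Ĥ − H_target)], where Ĥ = −log π_θ(ã|s) is the entropy estimate, so α grows when the policy's entropy falls below the target and shrinks otherwise;
   - softly update the target Q networks: ω_target ← τ·ω + (1 − τ)·ω_target.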

3. Code

Here we apply the SAC algorithm to train an agent on the inverted-pendulum (Pendulum) task.
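The script imports rl_utils, a helper module distributed with the book's code repository that is not shown in this post. As a rough guide, here is a minimal sketch of the three helpers the script calls (ReplayBuffer, train_off_policy_agent, moving_average); the function names come from the script itself, but the implementations below are assumptions and the book's versions may differ. The sketch also assumes the older Gym API the script targets (env.seed(0), env.reset() returning only the state, env.step() returning four values).

    import collections
    import random
    import numpy as np
    from tqdm import tqdm

    class ReplayBuffer:
        """FIFO experience replay buffer."""
        def __init__(self, capacity):
            self.buffer = collections.deque(maxlen=capacity)

        def add(self, state, action, reward, next_state, done):
            self.buffer.append((state, action, reward, next_state, done))

        def sample(self, batch_size):
            transitions = random.sample(self.buffer, batch_size)
            state, action, reward, next_state, done = zip(*transitions)
            return np.array(state), action, reward, np.array(next_state), done

        def size(self):
            return len(self.buffer)

    def moving_average(a, window_size):
        """Simple smoothing of the return curve (same length as the input)."""
        return np.convolve(np.array(a), np.ones(window_size) / window_size, mode='same')

    def train_off_policy_agent(env, agent, num_episodes, replay_buffer,
                               minimal_size, batch_size):
        """Generic off-policy training loop; returns the list of per-episode returns."""
        return_list = []
        for episode in tqdm(range(num_episodes)):
            episode_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done, _ = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a,
                                       'rewards': b_r, 'next_states': b_ns,
                                       'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
        return return_list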

import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utils


class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # rsample() uses the reparameterization trick
        log_prob = dist.log_prob(normal_sample)
        action = torch.tanh(normal_sample)
        # log-density of the tanh-squashed Gaussian (change of variables)
        log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)


class SACContinuous:
    ''' SAC algorithm for continuous actions '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # policy network
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # first Q network
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # second Q network
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # first target Q network
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                                   action_dim).to(device)  # second target Q network
        # initialize the target Q networks with the same parameters as the Q networks
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # optimizing log(alpha) instead of alpha makes training more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # allow gradients w.r.t. alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # target entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # compute the target Q value
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        # as in earlier chapters, reshape the Pendulum reward to ease training
        rewards = (rewards + 8.0) / 8.0

        # update the two Q networks
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = torch.mean(
            F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(
            F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # update the policy network
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update alpha
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


env_name = 'Pendulum-v0'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class QValueNet(torch.nn.Module):
    ''' Q network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class SAC:
    ''' SAC algorithm for discrete actions '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 alpha_lr, target_entropy, tau, gamma, device):
        # policy network
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # first Q network
        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        # second Q network
        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_1 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # first target Q network
        self.target_critic_2 = QValueNet(state_dim, hidden_dim,
                                         action_dim).to(device)  # second target Q network
        # initialize the target Q networks with the same parameters as the Q networks
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # optimizing log(alpha) instead of alpha makes training more stable
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # allow gradients w.r.t. alpha
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # target entropy
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    # compute the target Q value; the expectation is taken directly over the
    # policy network's output probabilities
    def calc_target(self, rewards, next_states, dones):
        next_probs = self.actor(next_states)
        next_log_probs = torch.log(next_probs + 1e-8)
        entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)
        q1_value = self.target_critic_1(next_states)
        q2_value = self.target_critic_2(next_states)
        min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)
        next_value = min_qvalue + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)  # actions are no longer of float type
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # update the two Q networks
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_q_values = self.critic_1(states).gather(1, actions)
        critic_1_loss = torch.mean(
            F.mse_loss(critic_1_q_values, td_target.detach()))
        critic_2_q_values = self.critic_2(states).gather(1, actions)
        critic_2_loss = torch.mean(
            F.mse_loss(critic_2_q_values, td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # update the policy network
        probs = self.actor(states)
        log_probs = torch.log(probs + 1e-8)
        # compute the entropy directly from the probabilities
        entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)
        q1_value = self.critic_1(states)
        q2_value = self.critic_2(states)
        min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value),
                               dim=1,
                               keepdim=True)  # expectation computed directly from the probabilities
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update alpha
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
tau = 0.005  # 软更新参数
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,
            target_entropy, tau, gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()
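A closing note on the log-probability computation in PolicyNetContinuous.forward: the action is produced by squashing a reparameterized Gaussian sample u ~ N(μ(s), σ(s)) through tanh, so the change-of-variables formula for densities gives

    log π(a|s) = log N(u; μ(s), σ(s)) − Σ_i log(1 − tanh(u_i)²),  with a = tanh(u).

This is exactly the term subtracted as torch.log(1 - action.pow(2) + 1e-7): at that point action equals tanh(u) (the subsequent scaling by action_bound only shifts the log-density by a constant and is not tracked here), and the 1e-7 guards against taking the log of zero when the action saturates near ±1.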

This concludes the article on the SAC algorithm from 动手学强化学习 (Hands-on Reinforcement Learning); I hope it proves helpful!


