深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度

本文主要是介绍深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本篇文章利用tensorflow2.0自定义loss函数实现policy gradient策略梯度,自定义loss=-log(prob) *Vt
现在训练最高分能到193分,但是还是不稳定,在修改中,欢迎一起探讨文章代码也有参考莫烦大佬的代码action_dim = 2 //定义动作
state_dim = 4 //定义状态
env = gym.make('CartPole-v0')
class PGModel(tf.keras.Model):def __init__(self):super().__init__()self.dense1 = layers.Dense(128,input_dim=state_dim,activation='relu')layers.Dropout(0.1)self.all_acts = layers.Dense(units=action_dim)self.x = 0def call(self,inputs):x = self.dense1(inputs)x = self.all_acts(x)self.x = xoutput = tf.nn.softmax(x)return outputclass PG():def __init__(self):self.model = PGModel()def choose_action(self, s):prob = self.model.predict(np.array([s]))[0]#print(prob)return np.random.choice(len(prob),p=prob)def discount_reward(self,rewards,gamma=0.95): #衰减reward 通过最后一步奖励反推真实奖励out = np.zeros_like(rewards)dis_reward = 0for i in reversed(range(len(rewards))):dis_reward = dis_reward + gamma * rewards[i]  # 前一步的reward等于后一步衰减reward加上即时奖励乘以衰减因子out[i] = dis_rewardreturn  out/np.std(out - np.mean(out))def all_actf(self):all_act = self.model.xprint(all_act)return all_actdef reca_batch(self,a_batch):a = a_batchreturn adef def_loss(self,label=reca_batch,logit=all_actf):  //自定义loss函数neg_log_prob = tf.nn.softmax_cross_entropy_with_logits(labels=label,logits=logit)return neg_log_probdef train(self,records): #训练s_batch = np.array([record[0] for record in records]) #取状态,每次batch个状态a_batch = np.array([[1 if record[1]==i else 0 for i in range(action_dim)]for record in records])self.reca_batch(a_batch)prob_batch = self.model.predict(s_batch) * a_batchr_batch = self.discount_reward([record[2] for record in records ])self.model.compile(loss=self.def_loss,optimizer=optimizers.Adam(0.001))self.model.fit(s_batch,prob_batch,sample_weight=r_batch,verbose=1)episodes = 2000
score_list= []
pg = PG()for i in range(episodes):score = 0records = []s = env.reset()while True:a = pg.choose_action(s)#print(a)next_s,r,done,_ = env.step(a)records.append((s, a, r))s = next_sscore += rif done:pg.train(records)score_list.append(score)print("episode:", i, "score:", score, "maxscore:", max(score_list))breakif np.mean(score_list[-10:]) > 195:pg.model.save('CarPoleModel.h5')breakenv.close()

这篇关于深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/921862

相关文章

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

VSCode设置python SDK路径的实现步骤

《VSCode设置pythonSDK路径的实现步骤》本文主要介绍了VSCode设置pythonSDK路径的实现步骤,包括命令面板切换、settings.json配置、环境变量及虚拟环境处理,具有一定... 目录一、通过命令面板快速切换(推荐方法)二、通过 settings.json 配置(项目级/全局)三、

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

C++/类与对象/默认成员函数@构造函数的用法

《C++/类与对象/默认成员函数@构造函数的用法》:本文主要介绍C++/类与对象/默认成员函数@构造函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录名词概念默认成员函数构造函数概念函数特征显示构造函数隐式构造函数总结名词概念默认构造函数:不用传参就可以