深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度

本文主要是介绍深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本篇文章利用tensorflow2.0自定义loss函数实现policy gradient策略梯度,自定义loss=-log(prob) *Vt
现在训练最高分能到193分,但是还是不稳定,在修改中,欢迎一起探讨文章代码也有参考莫烦大佬的代码action_dim = 2 //定义动作
state_dim = 4 //定义状态
env = gym.make('CartPole-v0')
class PGModel(tf.keras.Model):def __init__(self):super().__init__()self.dense1 = layers.Dense(128,input_dim=state_dim,activation='relu')layers.Dropout(0.1)self.all_acts = layers.Dense(units=action_dim)self.x = 0def call(self,inputs):x = self.dense1(inputs)x = self.all_acts(x)self.x = xoutput = tf.nn.softmax(x)return outputclass PG():def __init__(self):self.model = PGModel()def choose_action(self, s):prob = self.model.predict(np.array([s]))[0]#print(prob)return np.random.choice(len(prob),p=prob)def discount_reward(self,rewards,gamma=0.95): #衰减reward 通过最后一步奖励反推真实奖励out = np.zeros_like(rewards)dis_reward = 0for i in reversed(range(len(rewards))):dis_reward = dis_reward + gamma * rewards[i]  # 前一步的reward等于后一步衰减reward加上即时奖励乘以衰减因子out[i] = dis_rewardreturn  out/np.std(out - np.mean(out))def all_actf(self):all_act = self.model.xprint(all_act)return all_actdef reca_batch(self,a_batch):a = a_batchreturn adef def_loss(self,label=reca_batch,logit=all_actf):  //自定义loss函数neg_log_prob = tf.nn.softmax_cross_entropy_with_logits(labels=label,logits=logit)return neg_log_probdef train(self,records): #训练s_batch = np.array([record[0] for record in records]) #取状态,每次batch个状态a_batch = np.array([[1 if record[1]==i else 0 for i in range(action_dim)]for record in records])self.reca_batch(a_batch)prob_batch = self.model.predict(s_batch) * a_batchr_batch = self.discount_reward([record[2] for record in records ])self.model.compile(loss=self.def_loss,optimizer=optimizers.Adam(0.001))self.model.fit(s_batch,prob_batch,sample_weight=r_batch,verbose=1)episodes = 2000
score_list= []
pg = PG()for i in range(episodes):score = 0records = []s = env.reset()while True:a = pg.choose_action(s)#print(a)next_s,r,done,_ = env.step(a)records.append((s, a, r))s = next_sscore += rif done:pg.train(records)score_list.append(score)print("episode:", i, "score:", score, "maxscore:", max(score_list))breakif np.mean(score_list[-10:]) > 195:pg.model.save('CarPoleModel.h5')breakenv.close()

这篇关于深度强化学习系列tensorflow2.0自定义loss函数实现policy gradient策略梯度的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921862

相关文章

Spring Security 单点登录与自动登录机制的实现原理

《SpringSecurity单点登录与自动登录机制的实现原理》本文探讨SpringSecurity实现单点登录(SSO)与自动登录机制,涵盖JWT跨系统认证、RememberMe持久化Token... 目录一、核心概念解析1.1 单点登录(SSO)1.2 自动登录(Remember Me)二、代码分析三、

PyCharm中配置PyQt的实现步骤

《PyCharm中配置PyQt的实现步骤》PyCharm是JetBrains推出的一款强大的PythonIDE,结合PyQt可以进行pythion高效开发桌面GUI应用程序,本文就来介绍一下PyCha... 目录1. 安装China编程PyQt1.PyQt 核心组件2. 基础 PyQt 应用程序结构3. 使用 Q

springboot自定义注解RateLimiter限流注解技术文档详解

《springboot自定义注解RateLimiter限流注解技术文档详解》文章介绍了限流技术的概念、作用及实现方式,通过SpringAOP拦截方法、缓存存储计数器,结合注解、枚举、异常类等核心组件,... 目录什么是限流系统架构核心组件详解1. 限流注解 (@RateLimiter)2. 限流类型枚举 (

Python实现批量提取BLF文件时间戳

《Python实现批量提取BLF文件时间戳》BLF(BinaryLoggingFormat)作为Vector公司推出的CAN总线数据记录格式,被广泛用于存储车辆通信数据,本文将使用Python轻松提取... 目录一、为什么需要批量处理 BLF 文件二、核心代码解析:从文件遍历到数据导出1. 环境准备与依赖库

linux下shell脚本启动jar包实现过程

《linux下shell脚本启动jar包实现过程》确保APP_NAME和LOG_FILE位于目录内,首次启动前需手动创建log文件夹,否则报错,此为个人经验,供参考,欢迎支持脚本之家... 目录linux下shell脚本启动jar包样例1样例2总结linux下shell脚本启动jar包样例1#!/bin

go动态限制并发数量的实现示例

《go动态限制并发数量的实现示例》本文主要介绍了Go并发控制方法,通过带缓冲通道和第三方库实现并发数量限制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录带有缓冲大小的通道使用第三方库其他控制并发的方法因为go从语言层面支持并发,所以面试百分百会问到

Go语言并发之通知退出机制的实现

《Go语言并发之通知退出机制的实现》本文主要介绍了Go语言并发之通知退出机制的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1、通知退出机制1.1 进程/main函数退出1.2 通过channel退出1.3 通过cont

Python实现PDF按页分割的技术指南

《Python实现PDF按页分割的技术指南》PDF文件处理是日常工作中的常见需求,特别是当我们需要将大型PDF文档拆分为多个部分时,下面我们就来看看如何使用Python创建一个灵活的PDF分割工具吧... 目录需求分析技术方案工具选择安装依赖完整代码实现使用说明基本用法示例命令输出示例技术亮点实际应用场景扩

java如何实现高并发场景下三级缓存的数据一致性

《java如何实现高并发场景下三级缓存的数据一致性》这篇文章主要为大家详细介绍了java如何实现高并发场景下三级缓存的数据一致性,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 下面代码是一个使用Java和Redisson实现的三级缓存服务,主要功能包括:1.缓存结构:本地缓存:使

SpringBoot 异常处理/自定义格式校验的问题实例详解

《SpringBoot异常处理/自定义格式校验的问题实例详解》文章探讨SpringBoot中自定义注解校验问题,区分参数级与类级约束触发的异常类型,建议通过@RestControllerAdvice... 目录1. 问题简要描述2. 异常触发1) 参数级别约束2) 类级别约束3. 异常处理1) 字段级别约束