DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

本文主要是介绍DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

这篇关于DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/969655

相关文章

Java中流式并行操作parallelStream的原理和使用方法

《Java中流式并行操作parallelStream的原理和使用方法》本文详细介绍了Java中的并行流(parallelStream)的原理、正确使用方法以及在实际业务中的应用案例,并指出在使用并行流... 目录Java中流式并行操作parallelStream0. 问题的产生1. 什么是parallelS

Linux join命令的使用及说明

《Linuxjoin命令的使用及说明》`join`命令用于在Linux中按字段将两个文件进行连接,类似于SQL的JOIN,它需要两个文件按用于匹配的字段排序,并且第一个文件的换行符必须是LF,`jo... 目录一. 基本语法二. 数据准备三. 指定文件的连接key四.-a输出指定文件的所有行五.-o指定输出

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Redis 基本数据类型和使用详解

《Redis基本数据类型和使用详解》String是Redis最基本的数据类型,一个键对应一个值,它的功能十分强大,可以存储字符串、整数、浮点数等多种数据格式,本文给大家介绍Redis基本数据类型和... 目录一、Redis 入门介绍二、Redis 的五大基本数据类型2.1 String 类型2.2 Hash

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Linux创建服务使用systemctl管理详解

《Linux创建服务使用systemctl管理详解》文章指导在Linux中创建systemd服务,设置文件权限为所有者读写、其他只读,重新加载配置,启动服务并检查状态,确保服务正常运行,关键步骤包括权... 目录创建服务 /usr/lib/systemd/system/设置服务文件权限:所有者读写js,其他