pytorch nn.utils.rnn.pack_padded_sequence 分析

2023-10-17 22:50

本文主要是介绍pytorch nn.utils.rnn.pack_padded_sequence 分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pack_padded_sequence

在nlp模型的forward方法中,可能有以下调用令读者疑惑

packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False)

为什么要使用pack_padded_sequence?

参考

  • Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()
  • Pytorch中pack_padded_sequence和pad_packed_sequence的理解

当我们训练RNN时,如果想要进行批次化训练,由于句子的长短不一,所以需要截断和填充。

  • 为什么要截断?对于那些太长的句子,一般选择一个合适的长度来进行截断。
  • 为什么要填充?对于那些太短的句子,需要以 填充字符(比如<pad>)填充,使得该批次内所有的句子长度相同。

但是,填充会带来其它问题:

  • 增加了计算复杂度。假设一个批次内有2个句子,长度分别为5和2。我们要保证批次内所有的句子长度相同,就需要把长度为2的句子填充为5。这样喂给RNN时,需要计算 2 × 5 = 10 2 \times 5 =10 2×5=10次,而实际真正需要的是 5 + 2 = 7 5+2=7 5+2=7次。
  • 得到的结果可能不准确。我们知道RNN取的是最后一个时间步的隐藏状态做为输出,虽然在填充时,一般是以全0的词向量填充,RNN神经元的权重乘以零不会影响最终的输出,但还有偏差 b b b,如果 b ≠ 0 b \neq 0 b=0,还是会影响到最后的输出。

    当然这个问题不大,主要是第1个问题,毕竟批次大小很大的时候影响还是不小的。

我们用图解进一步说明这个问题。假设某句子“Yes”只有一个单词,但是填充了多余的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差

那么我们正确的做法应该是怎么样呢?在上面这个例子,我们想要得到的仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示,如下图:

所以,Pytorch提供了pack_padded_sequence方法来压缩填充字符,加快RNN的计算效率。

pack_padded_sequence是如何压缩的?

那么它是如何做压缩的呢?举个例子,假如一个batch里有5个句子,长度分别是5、4、3、3、2、1。将它们按列压缩,在这个过程中删除了pad字符。所以你可以想象这样的训练过程:

  1. 第一个batch有5个单词,[I, I, This, No, Yes],它们被送入LSTM。
  2. 第二个batch有4个单词被送入LSTM。
  3. 以此类推,之后的batch长度逐渐减小,分别是3、3、2、1
    在这个过程中,pad字符被自然地忽略掉了。

pack_padded_sequence的参数含义

必备参数是句子向量embedded,以及每个句子长度的变量text_lengths。前者通常包含3个维度,即[批次大小、句子最大长度、单词向量长度](前两者顺序可换);后者通常是list类型,或者一维Tensor类型,包含了每个句子的长度。

  • batch_first表示输入的向量是batch维度优先的。
  • enforce_sorted代表输入的句子是否已经按照长度顺序排好,如果为False,那么函数估计会先按照长度排好,进行计算,再还原回原来的顺序。

这篇关于pytorch nn.utils.rnn.pack_padded_sequence 分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/228478

相关文章

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

MySQL 内存使用率常用分析语句

《MySQL内存使用率常用分析语句》用户整理了MySQL内存占用过高的分析方法,涵盖操作系统层确认及数据库层bufferpool、内存模块差值、线程状态、performance_schema性能数据... 目录一、 OS层二、 DB层1. 全局情况2. 内存占js用详情最近连续遇到mysql内存占用过高导致

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

Olingo分析和实践之EDM 辅助序列化器详解(最佳实践)

《Olingo分析和实践之EDM辅助序列化器详解(最佳实践)》EDM辅助序列化器是ApacheOlingoOData框架中无需完整EDM模型的智能序列化工具,通过运行时类型推断实现灵活数据转换,适用... 目录概念与定义什么是 EDM 辅助序列化器?核心概念设计目标核心特点1. EDM 信息可选2. 智能类

Olingo分析和实践之OData框架核心组件初始化(关键步骤)

《Olingo分析和实践之OData框架核心组件初始化(关键步骤)》ODataSpringBootService通过初始化OData实例和服务元数据,构建框架核心能力与数据模型结构,实现序列化、URI... 目录概述第一步:OData实例创建1.1 OData.newInstance() 详细分析1.1.1

Olingo分析和实践之ODataImpl详细分析(重要方法详解)

《Olingo分析和实践之ODataImpl详细分析(重要方法详解)》ODataImpl.java是ApacheOlingoOData框架的核心工厂类,负责创建序列化器、反序列化器和处理器等组件,... 目录概述主要职责类结构与继承关系核心功能分析1. 序列化器管理2. 反序列化器管理3. 处理器管理重要方

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种