pytorch nn.utils.rnn.pack_padded_sequence 分析

2023-10-17 22:50

本文主要是介绍pytorch nn.utils.rnn.pack_padded_sequence 分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pack_padded_sequence

在nlp模型的forward方法中,可能有以下调用令读者疑惑

packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True, enforce_sorted=False)

为什么要使用pack_padded_sequence?

参考

  • Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()
  • Pytorch中pack_padded_sequence和pad_packed_sequence的理解

当我们训练RNN时,如果想要进行批次化训练,由于句子的长短不一,所以需要截断和填充。

  • 为什么要截断?对于那些太长的句子,一般选择一个合适的长度来进行截断。
  • 为什么要填充?对于那些太短的句子,需要以 填充字符(比如<pad>)填充,使得该批次内所有的句子长度相同。

但是,填充会带来其它问题:

  • 增加了计算复杂度。假设一个批次内有2个句子,长度分别为5和2。我们要保证批次内所有的句子长度相同,就需要把长度为2的句子填充为5。这样喂给RNN时,需要计算 2 × 5 = 10 2 \times 5 =10 2×5=10次,而实际真正需要的是 5 + 2 = 7 5+2=7 5+2=7次。
  • 得到的结果可能不准确。我们知道RNN取的是最后一个时间步的隐藏状态做为输出,虽然在填充时,一般是以全0的词向量填充,RNN神经元的权重乘以零不会影响最终的输出,但还有偏差 b b b,如果 b ≠ 0 b \neq 0 b=0,还是会影响到最后的输出。

    当然这个问题不大,主要是第1个问题,毕竟批次大小很大的时候影响还是不小的。

我们用图解进一步说明这个问题。假设某句子“Yes”只有一个单词,但是填充了多余的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差

那么我们正确的做法应该是怎么样呢?在上面这个例子,我们想要得到的仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示,如下图:

所以,Pytorch提供了pack_padded_sequence方法来压缩填充字符,加快RNN的计算效率。

pack_padded_sequence是如何压缩的?

那么它是如何做压缩的呢?举个例子,假如一个batch里有5个句子,长度分别是5、4、3、3、2、1。将它们按列压缩,在这个过程中删除了pad字符。所以你可以想象这样的训练过程:

  1. 第一个batch有5个单词,[I, I, This, No, Yes],它们被送入LSTM。
  2. 第二个batch有4个单词被送入LSTM。
  3. 以此类推,之后的batch长度逐渐减小,分别是3、3、2、1
    在这个过程中,pad字符被自然地忽略掉了。

pack_padded_sequence的参数含义

必备参数是句子向量embedded,以及每个句子长度的变量text_lengths。前者通常包含3个维度,即[批次大小、句子最大长度、单词向量长度](前两者顺序可换);后者通常是list类型,或者一维Tensor类型,包含了每个句子的长度。

  • batch_first表示输入的向量是batch维度优先的。
  • enforce_sorted代表输入的句子是否已经按照长度顺序排好,如果为False,那么函数估计会先按照长度排好,进行计算,再还原回原来的顺序。

这篇关于pytorch nn.utils.rnn.pack_padded_sequence 分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/228478

相关文章

Python中的Walrus运算符分析示例详解

《Python中的Walrus运算符分析示例详解》Python中的Walrus运算符(:=)是Python3.8引入的一个新特性,允许在表达式中同时赋值和返回值,它的核心作用是减少重复计算,提升代码简... 目录1. 在循环中避免重复计算2. 在条件判断中同时赋值变量3. 在列表推导式或字典推导式中简化逻辑

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Java程序进程起来了但是不打印日志的原因分析

《Java程序进程起来了但是不打印日志的原因分析》:本文主要介绍Java程序进程起来了但是不打印日志的原因分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java程序进程起来了但是不打印日志的原因1、日志配置问题2、日志文件权限问题3、日志文件路径问题4、程序

Java字符串操作技巧之语法、示例与应用场景分析

《Java字符串操作技巧之语法、示例与应用场景分析》在Java算法题和日常开发中,字符串处理是必备的核心技能,本文全面梳理Java中字符串的常用操作语法,结合代码示例、应用场景和避坑指南,可快速掌握字... 目录引言1. 基础操作1.1 创建字符串1.2 获取长度1.3 访问字符2. 字符串处理2.1 子字

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

Python 迭代器和生成器概念及场景分析

《Python迭代器和生成器概念及场景分析》yield是Python中实现惰性计算和协程的核心工具,结合send()、throw()、close()等方法,能够构建高效、灵活的数据流和控制流模型,这... 目录迭代器的介绍自定义迭代器省略的迭代器生产器的介绍yield的普通用法yield的高级用法yidle

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

C++ Sort函数使用场景分析

《C++Sort函数使用场景分析》sort函数是algorithm库下的一个函数,sort函数是不稳定的,即大小相同的元素在排序后相对顺序可能发生改变,如果某些场景需要保持相同元素间的相对顺序,可使... 目录C++ Sort函数详解一、sort函数调用的两种方式二、sort函数使用场景三、sort函数排序

kotlin中const 和val的区别及使用场景分析

《kotlin中const和val的区别及使用场景分析》在Kotlin中,const和val都是用来声明常量的,但它们的使用场景和功能有所不同,下面给大家介绍kotlin中const和val的区别,... 目录kotlin中const 和val的区别1. val:2. const:二 代码示例1 Java