自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制)

本文主要是介绍自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

目录标题

  • 自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)
  • 1. Scheduled Sampling(计划采样)
    • 1.1 概念解释
    • 1.2 代码实现
  • 2. Teacher forcing(教师强制)
    • 2.1 概念解释
    • 2.1 代码实现
  • 参考

1. Scheduled Sampling(计划采样)

1.1 概念解释

Scheduled Sampling是一种用于训练序列生成模型的策略,旨在缓解曝光偏差(Exposure Bias)问题。曝光偏差是指模型在训练时接触到的数据分布与测试时的数据分布不一致,导致性能下降。

在Scheduled Sampling中,模型在每个时间步骤都有一定的概率选择使用真实目标序列中的单词作为输入,而不是使用前一个时间步骤生成的单词。这样可以使模型更好地适应真实数据分布,减少曝光偏差问题。

具体来说,Scheduled Sampling使用以下公式计算每个时间步骤生成当前单词的概率:

P ( y t ∣ y 1 , . . . , y t − 1 ) = ( 1 − ϵ ) ∗ P model ( y t ∣ y 1 , . . . , y t − 1 ) + ϵ ∗ P data ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) = (1 - \epsilon) * P_{\text{model}}(y_t|y_1, ..., y_{t-1}) + \epsilon * P_{\text{data}}(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)=(1ϵ)Pmodel(yty1,...,yt1)+ϵPdata(yty1,...,yt1)其中, P ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)表示在给定前面的生成序列条件下生成当前单词 y t y_t yt的概率, P model ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{model}}(y_t|y_1, ..., y_{t-1}) Pmodel(yty1,...,yt1)表示模型生成该单词的概率, P data ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{data}}(y_t|y_1, ..., y_{t-1}) Pdata(yty1,...,yt1)表示真实目标序列中该单词的概率。参数 ϵ \epsilon ϵ用于控制采样策略,可以随着训练的进行而逐渐增加。

1.2 代码实现

下面是一个使用Python实现Scheduled Sampling的示例代码:

在这里插入图片描述

其中, P ( y t ∣ y < t , x ) P(y_t | y_{<t}, x) P(yty<t,x)表示在给定前文和输入的条件下,生成当前时间步的输出的概率。 P model P_{\text{model}} Pmodel表示由模型生成的概率分布, P prev P_{\text{prev}} Pprev表示根据上一个时间步的真实输出计算得到的概率分布。sample是从均匀分布中采样得到的一个随机数,threshold是一个控制Scheduled Sampling引入程度的超参数。

2. Teacher forcing(教师强制)

2.1 概念解释

Teacher forcing(教师强制)是一种在序列生成模型中使用的训练技术。具体来说,当使用RNN(循环神经网络)或类似架构的模型进行序列生成时,每个时间步都会根据前一个时间步的输入和隐藏状态生成输出。在训练期间,如果使用teacher forcing,那么每个时间步的输入将是真实的目标序列(而不是模型自身生成的序列)。这意味着模型在每个时间步都能够观察到正确的答案,从而更容易地学习到正确的模式和规律。

2.1 代码实现

import torch
import torch.nn as nn# 定义序列到序列模型
class Seq2SeqModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(Seq2SeqModel, self).__init__()self.hidden_dim = hidden_dim# 定义编码器self.encoder = nn.RNN(input_dim, hidden_dim)# 定义解码器self.decoder = nn.RNN(output_dim, hidden_dim)# 定义全连接层,将解码器的输出映射为目标序列self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, input_seq, target_seq):# 编码器计算输入序列的隐藏状态_, hidden_state = self.encoder(input_seq)# 解码器初始化隐藏状态decoder_hidden_state = hidden_state# 用真实目标序列作为输入来指导解码器的生成过程decoder_outputs, _ = self.decoder(target_seq, decoder_hidden_state)# 对解码器的输出应用全连接层进行映射output_seq = self.fc(decoder_outputs)return output_seq# 创建模型实例
input_dim = 10
hidden_dim = 20
output_dim = 10
model = Seq2SeqModel(input_dim, hidden_dim, output_dim)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):# 步骤1:将模型设为训练模式model.train()# 步骤2:清零梯度optimizer.zero_grad()# 步骤3:前向传播input_seq = torch.randn(5, 3, input_dim)  # 输入序列target_seq = torch.randn(5, 3, output_dim)  # 目标序列output_seq = model(input_seq, target_seq)# 步骤4:计算损失loss = criterion(output_seq, target_seq)# 步骤5:反向传播和优化loss.backward()optimizer.step()print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

上述代码中,我们定义了一个简单的序列到序列模型Seq2SeqModel,其中包括一个RNN编码器、一个RNN解码器和一个全连接层。在forward方法中,我们首先使用编码器计算输入序列的隐藏状态,然后将隐藏状态作为解码器的初始隐藏状态。接下来,我们使用真实目标序列来指导解码器的生成过程,并将解码器的输出映射为目标序列。在训练阶段,我们使用真实目标序列作为输入来指导模型的生成过程。最后,我们定义了损失函数和优化器,并进行训练。

需要注意的是,在实际应用中,模型的推理阶段并不会使用真实目标序列来指导生成过程。在推理阶段,可以将前一个时间步的模型输出作为下一个时间步的输入,从而进行序列的自我生成。

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

Scheduled Sampling的搜索结果_百度图片搜索 (baidu.com)
Teacher forcing RNN的搜索结果_百度图片搜索 (baidu.com)

在这里插入图片描述

这篇关于自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/235192

相关文章

Java对异常的认识与异常的处理小结

《Java对异常的认识与异常的处理小结》Java程序在运行时可能出现的错误或非正常情况称为异常,下面给大家介绍Java对异常的认识与异常的处理,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参... 目录一、认识异常与异常类型。二、异常的处理三、总结 一、认识异常与异常类型。(1)简单定义-什么是

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Golang 日志处理和正则处理的操作方法

《Golang日志处理和正则处理的操作方法》:本文主要介绍Golang日志处理和正则处理的操作方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录1、logx日志处理1.1、logx简介1.2、日志初始化与配置1.3、常用方法1.4、配合defer

springboot加载不到nacos配置中心的配置问题处理

《springboot加载不到nacos配置中心的配置问题处理》:本文主要介绍springboot加载不到nacos配置中心的配置问题处理,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录springboot加载不到nacos配置中心的配置两种可能Spring Boot 版本Nacos

Spring组件实例化扩展点之InstantiationAwareBeanPostProcessor使用场景解析

《Spring组件实例化扩展点之InstantiationAwareBeanPostProcessor使用场景解析》InstantiationAwareBeanPostProcessor是Spring... 目录一、什么是InstantiationAwareBeanPostProcessor?二、核心方法解

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

python web 开发之Flask中间件与请求处理钩子的最佳实践

《pythonweb开发之Flask中间件与请求处理钩子的最佳实践》Flask作为轻量级Web框架,提供了灵活的请求处理机制,中间件和请求钩子允许开发者在请求处理的不同阶段插入自定义逻辑,实现诸如... 目录Flask中间件与请求处理钩子完全指南1. 引言2. 请求处理生命周期概述3. 请求钩子详解3.1

Python处理大量Excel文件的十个技巧分享

《Python处理大量Excel文件的十个技巧分享》每天被大量Excel文件折磨的你看过来!这是一份Python程序员整理的实用技巧,不说废话,直接上干货,文章通过代码示例讲解的非常详细,需要的朋友可... 目录一、批量读取多个Excel文件二、选择性读取工作表和列三、自动调整格式和样式四、智能数据清洗五、

SpringBoot如何对密码等敏感信息进行脱敏处理

《SpringBoot如何对密码等敏感信息进行脱敏处理》这篇文章主要为大家详细介绍了SpringBoot对密码等敏感信息进行脱敏处理的几个常用方法,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录​1. 配置文件敏感信息脱敏​​2. 日志脱敏​​3. API响应脱敏​​4. 其他注意事项​​总结

Python使用python-docx实现自动化处理Word文档

《Python使用python-docx实现自动化处理Word文档》这篇文章主要为大家展示了Python如何通过代码实现段落样式复制,HTML表格转Word表格以及动态生成可定制化模板的功能,感兴趣的... 目录一、引言二、核心功能模块解析1. 段落样式与图片复制2. html表格转Word表格3. 模板生