NLP实践——文本生成中停不下来的问题

2023-10-23 08:59

本文主要是介绍NLP实践——文本生成中停不下来的问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NLP实践——文本生成中停不下来的问题

  • 1. 问题概述
  • 2. 造成的原因
  • 3. 解决的方法
  • 4. 效果

1. 问题概述

对于NLG任务,在推理阶段可能经常会遇到“停不下来”的问题,即重复的token被反复预测出来。
例如,输入“Google”,翻译模型可能会翻译为“谷歌谷歌”。

这个问题已经有很多人研究很久了,在模型侧提出的应对方案也有很多,本文介绍最简便的一种处理方法,只需要添加一行代码,就可以有效地改善。

2. 造成的原因

对于这种现象出现的原因,有很多相关的分析和介绍,其中苏神的这篇文章让我感到受益匪浅,从数学的角度分析了为什么会重复,非常建议大家读一下这篇文章。

3. 解决的方法

其实在transformers的源码中,以及预置了一个参数,用来控制对重复出现token的惩罚,思想非常朴素,最早应该是出现在CTRL的论文中:
https://arxiv.org/pdf/1909.05858.pdf

我们来看一下论文里是怎么描述的:
ctrl
在生成的时候,就是在计算词表中词汇的概率嘛,如果我们不希望之前出现的token连续出现,那只要把出现过的token对应的得分,人为地降低就好了,也就是给它一个惩罚的力度,让它变小一点。

反应在代码中,就是transformers/generation_utils.py中的GenerationMixin.generate方法,其中的repetition_penalty参数,就是用来控制这个惩罚的,也就是论文中的theta。

这个参数必须为大于0的浮点数,当取值为1.0时,相当于什么也没有做。如果在调用generate的时候给了这个参数,则会创建一个RepetitionPenaltyLogitsProcessor,简单看一下这个Processor是如何运作的:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):r""":class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.Args:repetition_penalty (:obj:`float`):The parameter for repetition penalty. 1.0 means no penalty. See `this paper<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details."""def __init__(self, penalty: float):if not isinstance(penalty, float) or not (penalty > 0):raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")self.penalty = penaltydef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:score = torch.gather(scores, 1, input_ids)# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probabilityscore = torch.where(score < 0, score * self.penalty, score / self.penalty)scores.scatter_(1, input_ids, score)return scores

其中input_ids就是generate时,输入的input_ids, scores是每一步推理计算出来的为下一步提供的得分。简单来说,这个类就是根据输入序列的token id,把score里边对应位置的得分取出来,然后惩罚一下这些位置的得分,让它的得分变小,然后把惩罚过的分数,替换掉原来计算出来的得分。

4. 效果

还是以翻译模型为例,采用的模型是opus-mt-en-zh,实例化这个模型:

from transformers import AutoModelWithLMHead,AutoTokenizer
mode_name = 'liam168/trans-opus-mt-en-zh'
model = AutoModelWithLMHead.from_pretrained(mode_name)
tokenizer = AutoTokenizer.from_pretrained(mode_name)

翻译一个词:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果为“谷歌谷歌”。可以看到,当输入文本很短时,很容易就出现了重复。

而如果在generate的时候,增加一个参数:

text = 'Google'
batch = tokenizer.prepare_seq2seq_batch(src_texts=[text], return_tensors='pt', max_length=512)
batch['repetition_penalty'] = 1.2   # 论文中默认的参数1.2
translation = model.generate(**batch)
res = tokenizer.batch_decode(translation, skip_special_tokens=True)

翻译结果就变成了只有一个"谷歌"。

再大胆一点,如果把惩罚力度设置为无穷大,也会出问题。当设置惩罚为float('inf')时,在翻译句子“Google has Google translate”的时候,就会变成“谷歌有Google翻译”,第二个Google就因为被惩罚了而没有翻译成谷歌,而如果惩罚为1.2,则翻译结果为“谷歌有谷歌翻译”。所以惩罚力度设置为多大,还需要自己把握一下。

这篇关于NLP实践——文本生成中停不下来的问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266847

相关文章

springboot项目中整合高德地图的实践

《springboot项目中整合高德地图的实践》:本文主要介绍springboot项目中整合高德地图的实践,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一:高德开放平台的使用二:创建数据库(我是用的是mysql)三:Springboot所需的依赖(根据你的需求再

SpringBoot3应用中集成和使用Spring Retry的实践记录

《SpringBoot3应用中集成和使用SpringRetry的实践记录》SpringRetry为SpringBoot3提供重试机制,支持注解和编程式两种方式,可配置重试策略与监听器,适用于临时性故... 目录1. 简介2. 环境准备3. 使用方式3.1 注解方式 基础使用自定义重试策略失败恢复机制注意事项

MySQL MCP 服务器安装配置最佳实践

《MySQLMCP服务器安装配置最佳实践》本文介绍MySQLMCP服务器的安装配置方法,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下... 目录mysql MCP 服务器安装配置指南简介功能特点安装方法数据库配置使用MCP Inspector进行调试开发指

SQLite3命令行工具最佳实践指南

《SQLite3命令行工具最佳实践指南》SQLite3是轻量级嵌入式数据库,无需服务器支持,具备ACID事务与跨平台特性,适用于小型项目和学习,sqlite3.exe作为命令行工具,支持SQL执行、数... 目录1. SQLite3简介和特点2. sqlite3.exe使用概述2.1 sqlite3.exe

苹果macOS 26 Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色

《苹果macOS26Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色》在整体系统设计方面,macOS26采用了全新的玻璃质感视觉风格,应用于Dock栏、应用图标以及桌面小部件等多个界面... 科技媒体 MACRumors 昨日(6 月 13 日)发布博文,报道称在 macOS 26 Tahoe 中

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取

SQL中JOIN操作的条件使用总结与实践

《SQL中JOIN操作的条件使用总结与实践》在SQL查询中,JOIN操作是多表关联的核心工具,本文将从原理,场景和最佳实践三个方面总结JOIN条件的使用规则,希望可以帮助开发者精准控制查询逻辑... 目录一、ON与WHERE的本质区别二、场景化条件使用规则三、最佳实践建议1.优先使用ON条件2.WHERE用

Springboot整合Redis主从实践

《Springboot整合Redis主从实践》:本文主要介绍Springboot整合Redis主从的实例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言原配置现配置测试LettuceConnectionFactory.setShareNativeConnect

MySQL 设置AUTO_INCREMENT 无效的问题解决

《MySQL设置AUTO_INCREMENT无效的问题解决》本文主要介绍了MySQL设置AUTO_INCREMENT无效的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录快速设置mysql的auto_increment参数一、修改 AUTO_INCREMENT 的值。

关于跨域无效的问题及解决(java后端方案)

《关于跨域无效的问题及解决(java后端方案)》:本文主要介绍关于跨域无效的问题及解决(java后端方案),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录通用后端跨域方法1、@CrossOrigin 注解2、springboot2.0 实现WebMvcConfig