微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

2023-10-23 08:59

本文主要是介绍微调预训练模型方式的文本语义匹配(Further Pretraining Bert),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

今年带着小伙伴参加了天池赛道三: 小布助手对话短文本语义匹配比赛,虽然最后没有杀进B榜,但也是预料之中的结果,最后成绩在110名左右,还算能接受。

言归正传,本文会解说苏剑林(苏神)的Baseline方案和代码,然后会分享我在Baseline上使用的tricks还有我们的方案和实验结果。

干货
Github:https://github.com/Ludong418/gaic-2021-task3
资料集合:个人整理的资料大全

苏神的Baseline

苏神在第二次发布数据没几天就公布了自己的Baseline,线上成绩大概是86左右,他的方案是mlm和文本语义匹配两个任务同时进行。

方案
词典不匹配问题

由于官方发布的数据是脱敏的,所以不建议大家直接使用其他已经训练好的预训练模型,当然也有人直接使用了,也有一定的效果,也有人好奇为什么脱敏数据不要用已经训练好的预训练模型呢?因为存在脱敏数据和预训练词典不匹配问题。nlp在深度学习预处理有一步很重要的过程就是 tokens 转 ids,而脱敏数据是已经转好的 ids,也就是下面形式。所以例子中的 ‘2’ 代表的意思是 ‘我’,若使用了预训练模型字典的 ‘2’,那就确实是 ‘2’ 的含义了。
在这里插入图片描述
所以我们就需要重新训练一个预训练模型。但是我们都知道一个新的预训练模型需要大量的数据和算力,如何去做呢?

如何训练脱敏数据的预训练模型

我们可以只保留模型参数部分,而tokens embeddings table可以替换掉,举个例子就是预训练模型中20000多个embedded table,随机换成自己数据词典大小的6000多个tokens,苏神的替换方式是保留了bert语料中token频数top6000+的embedded,当然我们也可以随机初始化6000多个embedded,两种方式我发现效果差不多。
在这里插入图片描述
但是别忘几个特殊的token要加入字典中,‘no’ 和 ‘yes’ 就是文本对的标签,‘相似’ 和 ‘不相似’,这两个token很关键。

0: pad, 1: unk, 2: cls, 3: sep, 4: mask, 5: no, 6: yes

最后就可以训练一个mlm(masked language model)了,但是苏神不只是就是训练一个mlm,而是在训练mlm过程中随便把文本语义匹配也做了,它使用第一个token([cls])的输出作为文本语义匹配任务的输出。

模型代码
预处理

这份数据的预处理过程比较简单,重点还是讲解mlm任务的输入格式,注意到output_ids第一个token要 +5, 目的就是用 [cls] 来预测 yes 或者 no。

def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)# +5 目的就是用 [cls] 来预测 yes 或者 nooutput_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids
模型

模型是一个mlm任务,就是希望被masked的token进行预测真实的token,完成一个完形填空的任务,而我们希望输出的第一个token([cls])用来预测 ‘yes’ 或者 ‘no’ 这两个token。
在这里插入图片描述
以下就是评估模型的代码:

def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)

注意到这行代码,其实就是选择了output的no和yes所在的维度的值进行预测的。

# y_pred shape:[batch_size, max_seq_len, voc_szie] 
y_pred = model.predict(x_true)[:, 0, 5:7]

在这里插入图片描述

我的方案

我们的做法差不多,只不过是预训练模型和文本语义匹配分开做了,我们先使用了脱敏数据和nazha预训练模型微调了一个新的预训练模型,然后利用新的预训练模型完成文本语义匹配二分类任务。

Further Pretraining

预训练模型使用了中文nezha-pytorch版本训练好的模型进行微调,使用了NeZhaForMaskedLM class 做mlm任务,但是源码是对全词计算loss,所以我修改了计算loss方法,只对masked的token计算loss。修改代码如下:

if labels is not None:# 只对mask的部分进行计算masked_lm_positions = torch.where(labels.view(-1) != 0)loss_fct = CrossEntropyLoss()  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size)[masked_lm_positions],labels.view(-1)[masked_lm_positions])outputs = (masked_lm_loss,) + outputsreturn outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
文本语义匹配

利用新的预训练模型做二分类任务,做法和常规的bert分类没有啥区别,最后线上效果能达到90左右,要提高模型的效果还得需要一些小tricks。

trick 1:数据增强

目的是为了增加训练数据数量,效果能略有提升零点几个点。
句子等价替换:如果 句子A = 句子B, 句子B = 句子C, 则 句子A = 句子C
句子对调:把所有的句子对调

trick 2:对抗式学习

我们测试了FGM和VAT两种对抗学习,效果都有提升1个点左右,VAT也可以使用在mlm任务中,但是发现效果并不是很好。

trick 3:伪标签

利用训练好的模型对test数据集预测,把输出概率大于阈值的数据加入到训练集中训练,效果提升不明显。

trick 4:半监督学习(没有实现)

可以考虑使用半监督学习,例如mixtext、mixup等模型,目的也是为了增大训练模型样本,但这次比赛中没来得及去实验,但是在工作中的一个项目中使用了mixtext模型,效果有很大的提升。

trick 4:置信学习(没有实现)

数据集中总会或多或少出现错误标签,若能剔除掉这些错误标签在进行训练,效果会有一定的提升,所以可以考虑使用置信学习的方法去发现错误标签。

结论

本次比赛没能杀入B榜主要还是因为身为打工人,我们只能在下班和周末搞搞,很多想法并不能有足够的时间去实验,不像高校里的学生论文看的多,时间也足够。其次就是算力没跟上,显卡不够多,一个mlm模型要花一天多才能训练完成,然后加入对抗式学习速度更慢了。

整个任务的难度不大,一开始碰到脱敏数据我也懵了一阵,我也想到自己要训练一个预训练模型,但是不敢确定效果,但最后看来效果确实不错,所以深度学习这门技术还是得靠动手验证的科学,最后推荐一篇不错的论文Don’t Stop Pretraining: Adapt Language Models to Domains and Tasks。

这篇关于微调预训练模型方式的文本语义匹配(Further Pretraining Bert)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/weixin_40570579/article/details/116101615
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/266848

相关文章

Python获取浏览器Cookies的四种方式小结

《Python获取浏览器Cookies的四种方式小结》在进行Web应用程序测试和开发时,获取浏览器Cookies是一项重要任务,本文我们介绍四种用Python获取浏览器Cookies的方式,具有一定的... 目录什么是 Cookie?1.使用Selenium库获取浏览器Cookies2.使用浏览器开发者工具

Java获取当前时间String类型和Date类型方式

《Java获取当前时间String类型和Date类型方式》:本文主要介绍Java获取当前时间String类型和Date类型方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录Java获取当前时间String和Date类型String类型和Date类型输出结果总结Java获取

C#监听txt文档获取新数据方式

《C#监听txt文档获取新数据方式》文章介绍通过监听txt文件获取最新数据,并实现开机自启动、禁用窗口关闭按钮、阻止Ctrl+C中断及防止程序退出等功能,代码整合于主函数中,供参考学习... 目录前言一、监听txt文档增加数据二、其他功能1. 设置开机自启动2. 禁止控制台窗口关闭按钮3. 阻止Ctrl +

linux批量替换文件内容的实现方式

《linux批量替换文件内容的实现方式》本文总结了Linux中批量替换文件内容的几种方法,包括使用sed替换文件夹内所有文件、单个文件内容及逐行字符串,强调使用反引号和绝对路径,并分享个人经验供参考... 目录一、linux批量替换文件内容 二、替换文件内所有匹配的字符串 三、替换每一行中全部str1为st

Python实现终端清屏的几种方式详解

《Python实现终端清屏的几种方式详解》在使用Python进行终端交互式编程时,我们经常需要清空当前终端屏幕的内容,本文为大家整理了几种常见的实现方法,有需要的小伙伴可以参考下... 目录方法一:使用 `os` 模块调用系统命令方法二:使用 `subprocess` 模块执行命令方法三:打印多个换行符模拟

RabbitMQ消息总线方式刷新配置服务全过程

《RabbitMQ消息总线方式刷新配置服务全过程》SpringCloudBus通过消息总线与MQ实现微服务配置统一刷新,结合GitWebhooks自动触发更新,避免手动重启,提升效率与可靠性,适用于配... 目录前言介绍环境准备代码示例测试验证总结前言介绍在微服务架构中,为了更方便的向微服务实例广播消息,

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种

Linux线程之线程的创建、属性、回收、退出、取消方式

《Linux线程之线程的创建、属性、回收、退出、取消方式》文章总结了线程管理核心知识:线程号唯一、创建方式、属性设置(如分离状态与栈大小)、回收机制(join/detach)、退出方法(返回/pthr... 目录1. 线程号2. 线程的创建3. 线程属性4. 线程的回收5. 线程的退出6. 线程的取消7.

golang程序打包成脚本部署到Linux系统方式

《golang程序打包成脚本部署到Linux系统方式》Golang程序通过本地编译(设置GOOS为linux生成无后缀二进制文件),上传至Linux服务器后赋权执行,使用nohup命令实现后台运行,完... 目录本地编译golang程序上传Golang二进制文件到linux服务器总结本地编译Golang程序

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录