微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

2023-10-23 08:59

本文主要是介绍微调预训练模型方式的文本语义匹配(Further Pretraining Bert),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

微调预训练模型方式的文本语义匹配(Further Pretraining Bert)

今年带着小伙伴参加了天池赛道三: 小布助手对话短文本语义匹配比赛,虽然最后没有杀进B榜,但也是预料之中的结果,最后成绩在110名左右,还算能接受。

言归正传,本文会解说苏剑林(苏神)的Baseline方案和代码,然后会分享我在Baseline上使用的tricks还有我们的方案和实验结果。

干货
Github:https://github.com/Ludong418/gaic-2021-task3
资料集合:个人整理的资料大全

苏神的Baseline

苏神在第二次发布数据没几天就公布了自己的Baseline,线上成绩大概是86左右,他的方案是mlm和文本语义匹配两个任务同时进行。

方案
词典不匹配问题

由于官方发布的数据是脱敏的,所以不建议大家直接使用其他已经训练好的预训练模型,当然也有人直接使用了,也有一定的效果,也有人好奇为什么脱敏数据不要用已经训练好的预训练模型呢?因为存在脱敏数据和预训练词典不匹配问题。nlp在深度学习预处理有一步很重要的过程就是 tokens 转 ids,而脱敏数据是已经转好的 ids,也就是下面形式。所以例子中的 ‘2’ 代表的意思是 ‘我’,若使用了预训练模型字典的 ‘2’,那就确实是 ‘2’ 的含义了。
在这里插入图片描述
所以我们就需要重新训练一个预训练模型。但是我们都知道一个新的预训练模型需要大量的数据和算力,如何去做呢?

如何训练脱敏数据的预训练模型

我们可以只保留模型参数部分,而tokens embeddings table可以替换掉,举个例子就是预训练模型中20000多个embedded table,随机换成自己数据词典大小的6000多个tokens,苏神的替换方式是保留了bert语料中token频数top6000+的embedded,当然我们也可以随机初始化6000多个embedded,两种方式我发现效果差不多。
在这里插入图片描述
但是别忘几个特殊的token要加入字典中,‘no’ 和 ‘yes’ 就是文本对的标签,‘相似’ 和 ‘不相似’,这两个token很关键。

0: pad, 1: unk, 2: cls, 3: sep, 4: mask, 5: no, 6: yes

最后就可以训练一个mlm(masked language model)了,但是苏神不只是就是训练一个mlm,而是在训练mlm过程中随便把文本语义匹配也做了,它使用第一个token([cls])的输出作为文本语义匹配任务的输出。

模型代码
预处理

这份数据的预处理过程比较简单,重点还是讲解mlm任务的输入格式,注意到output_ids第一个token要 +5, 目的就是用 [cls] 来预测 yes 或者 no。

def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)# +5 目的就是用 [cls] 来预测 yes 或者 nooutput_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids
模型

模型是一个mlm任务,就是希望被masked的token进行预测真实的token,完成一个完形填空的任务,而我们希望输出的第一个token([cls])用来预测 ‘yes’ 或者 ‘no’ 这两个token。
在这里插入图片描述
以下就是评估模型的代码:

def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)

注意到这行代码,其实就是选择了output的no和yes所在的维度的值进行预测的。

# y_pred shape:[batch_size, max_seq_len, voc_szie] 
y_pred = model.predict(x_true)[:, 0, 5:7]

在这里插入图片描述

我的方案

我们的做法差不多,只不过是预训练模型和文本语义匹配分开做了,我们先使用了脱敏数据和nazha预训练模型微调了一个新的预训练模型,然后利用新的预训练模型完成文本语义匹配二分类任务。

Further Pretraining

预训练模型使用了中文nezha-pytorch版本训练好的模型进行微调,使用了NeZhaForMaskedLM class 做mlm任务,但是源码是对全词计算loss,所以我修改了计算loss方法,只对masked的token计算loss。修改代码如下:

if labels is not None:# 只对mask的部分进行计算masked_lm_positions = torch.where(labels.view(-1) != 0)loss_fct = CrossEntropyLoss()  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size)[masked_lm_positions],labels.view(-1)[masked_lm_positions])outputs = (masked_lm_loss,) + outputsreturn outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
文本语义匹配

利用新的预训练模型做二分类任务,做法和常规的bert分类没有啥区别,最后线上效果能达到90左右,要提高模型的效果还得需要一些小tricks。

trick 1:数据增强

目的是为了增加训练数据数量,效果能略有提升零点几个点。
句子等价替换:如果 句子A = 句子B, 句子B = 句子C, 则 句子A = 句子C
句子对调:把所有的句子对调

trick 2:对抗式学习

我们测试了FGM和VAT两种对抗学习,效果都有提升1个点左右,VAT也可以使用在mlm任务中,但是发现效果并不是很好。

trick 3:伪标签

利用训练好的模型对test数据集预测,把输出概率大于阈值的数据加入到训练集中训练,效果提升不明显。

trick 4:半监督学习(没有实现)

可以考虑使用半监督学习,例如mixtext、mixup等模型,目的也是为了增大训练模型样本,但这次比赛中没来得及去实验,但是在工作中的一个项目中使用了mixtext模型,效果有很大的提升。

trick 4:置信学习(没有实现)

数据集中总会或多或少出现错误标签,若能剔除掉这些错误标签在进行训练,效果会有一定的提升,所以可以考虑使用置信学习的方法去发现错误标签。

结论

本次比赛没能杀入B榜主要还是因为身为打工人,我们只能在下班和周末搞搞,很多想法并不能有足够的时间去实验,不像高校里的学生论文看的多,时间也足够。其次就是算力没跟上,显卡不够多,一个mlm模型要花一天多才能训练完成,然后加入对抗式学习速度更慢了。

整个任务的难度不大,一开始碰到脱敏数据我也懵了一阵,我也想到自己要训练一个预训练模型,但是不敢确定效果,但最后看来效果确实不错,所以深度学习这门技术还是得靠动手验证的科学,最后推荐一篇不错的论文Don’t Stop Pretraining: Adapt Language Models to Domains and Tasks。

这篇关于微调预训练模型方式的文本语义匹配(Further Pretraining Bert)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/266848

相关文章

gitlab安装及邮箱配置和常用使用方式

《gitlab安装及邮箱配置和常用使用方式》:本文主要介绍gitlab安装及邮箱配置和常用使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1.安装GitLab2.配置GitLab邮件服务3.GitLab的账号注册邮箱验证及其分组4.gitlab分支和标签的

C++中零拷贝的多种实现方式

《C++中零拷贝的多种实现方式》本文主要介绍了C++中零拷贝的实现示例,旨在在减少数据在内存中的不必要复制,从而提高程序性能、降低内存使用并减少CPU消耗,零拷贝技术通过多种方式实现,下面就来了解一下... 目录一、C++中零拷贝技术的核心概念二、std::string_view 简介三、std::stri

苹果macOS 26 Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色

《苹果macOS26Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色》在整体系统设计方面,macOS26采用了全新的玻璃质感视觉风格,应用于Dock栏、应用图标以及桌面小部件等多个界面... 科技媒体 MACRumors 昨日(6 月 13 日)发布博文,报道称在 macOS 26 Tahoe 中

Linux脚本(shell)的使用方式

《Linux脚本(shell)的使用方式》:本文主要介绍Linux脚本(shell)的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录概述语法详解数学运算表达式Shell变量变量分类环境变量Shell内部变量自定义变量:定义、赋值自定义变量:引用、修改、删

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取

python判断文件是否存在常用的几种方式

《python判断文件是否存在常用的几种方式》在Python中我们在读写文件之前,首先要做的事情就是判断文件是否存在,否则很容易发生错误的情况,:本文主要介绍python判断文件是否存在常用的几种... 目录1. 使用 os.path.exists()2. 使用 os.path.isfile()3. 使用

Mybatis的分页实现方式

《Mybatis的分页实现方式》MyBatis的分页实现方式主要有以下几种,每种方式适用于不同的场景,且在性能、灵活性和代码侵入性上有所差异,对Mybatis的分页实现方式感兴趣的朋友一起看看吧... 目录​1. 原生 SQL 分页(物理分页)​​2. RowBounds 分页(逻辑分页)​​3. Page

Linux链表操作方式

《Linux链表操作方式》:本文主要介绍Linux链表操作方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、链表基础概念与内核链表优势二、内核链表结构与宏解析三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势六、典型应用场景七、调试技巧与

Linux实现线程同步的多种方式汇总

《Linux实现线程同步的多种方式汇总》本文详细介绍了Linux下线程同步的多种方法,包括互斥锁、自旋锁、信号量以及它们的使用示例,通过这些同步机制,可以解决线程安全问题,防止资源竞争导致的错误,示例... 目录什么是线程同步?一、互斥锁(单人洗手间规则)适用场景:特点:二、条件变量(咖啡厅取餐系统)工作流

RedisTemplate默认序列化方式显示中文乱码的解决

《RedisTemplate默认序列化方式显示中文乱码的解决》本文主要介绍了SpringDataRedis默认使用JdkSerializationRedisSerializer导致数据乱码,文中通过示... 目录1. 问题原因2. 解决方案3. 配置类示例4. 配置说明5. 使用示例6. 验证存储结果7.