如何只用bert夺冠之对比学习代码解读

2024-03-19 17:50

本文主要是介绍如何只用bert夺冠之对比学习代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有监督对比学习:Supervised Contrastive Learning:
https://zhuanlan.zhihu.com/p/136332151

1. 自监督对比学习

一句话总结:不使用label数据,通过数据增强构造样本,使特征提取器提取的特征在增强样本和原始样本的距离更近,其他数据特征与原始样本的更远来训练特征提取器的方法。
关键思想:一个batch=n的数据,增强一次变成2n的数据,loss函数如下:
在这里插入图片描述

2. 监督对比学习

把标签数据加进来,但是计算loss还是以一个batch=n,自监督计算loss的思想来做。
在这里插入图片描述

问:真的有用吗?

3. 代码实现

  • 核心代码总结
    • 将label作为输入,当然结合了bert的原始输入,另外预测的时候,还是通过嫁接一个最基础的分类model来进行预测和输出。
      训练的时候,没有指定loss函数,只是在train_model中以监督对比loss和交叉熵loss最为输出,并继承与Loss类并指定了add_metric(loss),这样可能就可以了,反正也是给定一个数据,使loss最小
    • 优缺点分析:
      一个batch里面,正负样本就那么点,并且正样本之间的相似度不一定要高啊,在这个任务场景下,所以对比训练这个方式提升不大吧。
      可能在语义相似度计算任务上,真正的语义相似度计算而不是问答任务上,效果可能会好点吧。
    • 代码链接:https://github.com/xv44586/ccf_2020_qa_match/blob/main/pair-supervised-contrastive-learning.py
class SupervisedContrastiveLearning(Loss):"""https://arxiv.org/pdf/2011.01403.pdf"""def __init__(self, alpha=1., T=1., **kwargs):super(SupervisedContrastiveLearning, self).__init__(**kwargs)self.alpha = alpha  # loss weightself.T = T  # Temperaturedef compute_loss(self, inputs, mask=None):loss = self.compute_loss_of_scl(inputs)loss = loss * self.alphaself.add_metric(loss, name='scl_loss')return lossdef get_label_mask(self, y_true):"""获取batch内相同label样本"""label = K.cast(y_true, 'int32') # 转换数据类型label_2 = K.reshape(label, (1, -1)) # reshape成一行mask = K.equal(label_2, label) # 这两个shape都不一样,出来的是啥?知道了,应该原来是一列,现在换成一行,所以可以比较每个位置的label是否一样了mask = K.cast(mask, K.floatx()) # 又把它转成float类型,这样lable相等的位置为1.0,不相等的时候为0.0mask = mask * (1 - K.eye(K.shape(y_true)[0]))  # 排除对角线,即 i == j,对角线的位置的值全设置为0return maskdef compute_loss_of_scl(self, inputs, mask=None):y_pred, y_true = inputslabel_mask = self.get_label_mask(y_true) # mask是个二维矩阵,告诉i,j位置的lable是否一样y_pred = K.l2_normalize(y_pred, axis=1)  # 特征向量归一化similarities = K.dot(y_pred, K.transpose(y_pred))  # 相似矩阵,相当于是点乘similarities = similarities - K.eye(K.shape(y_pred)[0]) * 1e12  # 排除对角线,即 i == j,点乘然后排出对角线的位置similarities = similarities / self.T  # Temperature scalesimilarities = K.exp(similarities)  # expsum_similarities = K.sum(similarities, axis=-1, keepdims=True)  # sum i != k, 求和得到的应该是分母scl = similarities / sum_similarities # 这里算的还是全部的作为分子,但是我们只要把i和j位置label相同的作为分子,所以还要乘以maskscl = K.log((scl + K.epsilon()))  # sum log,取logscl = -K.sum(scl * label_mask, axis=1, keepdims=True) / (K.sum(label_mask, axis=1, keepdims=True) + K.epsilon()) # 乘以maskreturn K.mean(scl)class CrossEntropy(Loss):def compute_loss(self, inputs, mask=None):pred, ytrue = inputsytrue = K.cast(ytrue, K.floatx())loss = K.binary_crossentropy(ytrue, pred)loss = K.mean(loss)self.add_metric(loss, name='clf_loss')return loss# 加载预训练模型
bert = build_transformer_model(config_path=config_path,checkpoint_path=checkpoint_path,model='nezha',keep_tokens=keep_tokens,num_hidden_layers=12,
)# 将label作为输入
y_in = Input(shape=(None,))output = Lambda(lambda x: x[:, 0])(bert.output)
# output相当于是特征,监督对比函数相当于是利用了label类别信息求了个loss
scl_output = SupervisedContrastiveLearning(alpha=0.1, T=0.2, output_idx=0)([output, y_in])
# scl_output是监督对比函数的loss
output = Dropout(0.1)(output)
# 这个是分类的概率
clf_output = Dense(1, activation='sigmoid')(output)
# 这个是分类的loss
clf = CrossEntropy(0)([clf_output, y_in])
# clf是分类的loss
# model模型还是以bert作为输入,分类的概率clf_output作为输出
model = keras.models.Model(bert.input, clf_output)
model.summary()
# train_model模型是用于训练的模型,bert+label作为输入,scl_output是监督对比函数的loss与clf是分类的loss作为输出
# 将loss函数作为输出,后面complile就不用指定loss函数了吗?还有这种操作,那为啥要两个loss函数啊?
train_model = keras.models.Model(bert.input + [y_in], [scl_output, clf])optimizer = extend_with_weight_decay(Adam)
optimizer = extend_with_piecewise_linear_lr(optimizer)
opt = optimizer(learning_rate=1e-5, weight_decay_rate=0.1, exclude_from_weight_decay=['Norm', 'bias'],lr_schedule={int(len(train_generator) * 0.1 * epochs): 1, len(train_generator) * epochs: 0})train_model.compile(optimizer=opt,# 但是这里没有指定loss啊,所以loss是啥,所以
)# 正常的分类模型是这么做的,在compile里面指定了loss函数
"""
model = keras.models.Model(bert.input, output)
model.summary()model.compile(# 指定了loss函数loss=K.binary_crossentropy,optimizer=Adam(2e-5),  # 用足够小的学习率metrics=['accuracy'],
)class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self):self.best_val_f1 = 0.def on_epoch_end(self, epoch, logs=None):val_f1 = evaluate(valid_generator)if val_f1 > self.best_val_f1:self.best_val_f1 = val_f1model.save_weights('best_parimatch_model.weights')print(u'val_f1: %.5f, best_val_f1: %.5f\n' %(val_f1, self.best_val_f1))evaluator = Evaluator()
model.fit_generator(train_generator.generator(),steps_per_epoch=len(train_generator),epochs=5,callbacks=[evaluator],)"""def evaluate(data):P, R, TP = 0., 0., 0.for x, _ in tqdm(data):x_true = x[:2]y_true = x[-1]y_pred = model.predict(x_true)[:, 0]y_pred = np.round(y_pred)y_true = y_true[:, 0]R += y_pred.sum()P += y_true.sum()TP += ((y_pred + y_true) > 1).sum()print(P, R, TP)pre = TP / Rrec = TP / Preturn 2 * (pre * rec) / (pre + rec)class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self, save_path):self.best_val_f1 = 0.self.save_path = save_pathdef on_epoch_end(self, epoch, logs=None):val_f1 = evaluate(valid_generator)if val_f1 > self.best_val_f1:self.best_val_f1 = val_f1model.save_weights(self.save_path)print(u'val_f1: %.5f, best_val_f1: %.5f\n' %(val_f1, self.best_val_f1))def predict_to_file(path='pair_submission.tsv'):preds = []for x, _ in tqdm(test_generator):x = x[:2]pred = model.predict(x).argmax(axis=1)#         pred = np.round(pred)pred = pred.astype(int)preds.append(pred)preds = np.concatenate(preds)ret = []for d, p in zip(test_data, preds):q_id, _, r_id, _, _ = dret.append([str(q_id), str(r_id), str(p)])with open(path, 'w', encoding='utf8') as f:for l in ret:f.write('\t'.join(l) + '\n')if __name__ == '__main__':save_path = 'best_pair_scl_model.weights'evaluator = Evaluator(save_path)train_model.fit_generator(train_generator.generator(),steps_per_epoch=len(train_generator),epochs=epochs,callbacks=[evaluator],)model.load_weights(save_path)predict_to_file('pair_scl.tsv')

这篇关于如何只用bert夺冠之对比学习代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/826787

相关文章

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁

精选20个好玩又实用的的Python实战项目(有图文代码)

《精选20个好玩又实用的的Python实战项目(有图文代码)》文章介绍了20个实用Python项目,涵盖游戏开发、工具应用、图像处理、机器学习等,使用Tkinter、PIL、OpenCV、Kivy等库... 目录① 猜字游戏② 闹钟③ 骰子模拟器④ 二维码⑤ 语言检测⑥ 加密和解密⑦ URL缩短⑧ 音乐播放

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

C语言中%zu的用法解读

《C语言中%zu的用法解读》size_t是无符号整数类型,用于表示对象大小或内存操作结果,%zu是C99标准中专为size_t设计的printf占位符,避免因类型不匹配导致错误,使用%u或%d可能引发... 目录size_t 类型与 %zu 占位符%zu 的用途替代占位符的风险兼容性说明其他相关占位符验证示

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

Linux系统之lvcreate命令使用解读

《Linux系统之lvcreate命令使用解读》lvcreate是LVM中创建逻辑卷的核心命令,支持线性、条带化、RAID、镜像、快照、瘦池和缓存池等多种类型,实现灵活存储资源管理,需注意空间分配、R... 目录lvcreate命令详解一、命令概述二、语法格式三、核心功能四、选项详解五、使用示例1. 创建逻

详解MySQL中JSON数据类型用法及与传统JSON字符串对比

《详解MySQL中JSON数据类型用法及与传统JSON字符串对比》MySQL从5.7版本开始引入了JSON数据类型,专门用于存储JSON格式的数据,本文将为大家简单介绍一下MySQL中JSON数据类型... 目录前言基本用法jsON数据类型 vs 传统JSON字符串1. 存储方式2. 查询方式对比3. 索引

Python实现MQTT通信的示例代码

《Python实现MQTT通信的示例代码》本文主要介绍了Python实现MQTT通信的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 安装paho-mqtt库‌2. 搭建MQTT代理服务器(Broker)‌‌3. pytho

MySQL进行数据库审计的详细步骤和示例代码

《MySQL进行数据库审计的详细步骤和示例代码》数据库审计通过触发器、内置功能及第三方工具记录和监控数据库活动,确保安全、完整与合规,Java代码实现自动化日志记录,整合分析系统提升监控效率,本文给大... 目录一、数据库审计的基本概念二、使用触发器进行数据库审计1. 创建审计表2. 创建触发器三、Java

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种