12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩

2023-12-08 07:50

本文主要是介绍12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

过去一年里,语言模型的研究有了许多突破性的进展,BERT、XLNet、RoBERTa等预训练语言模型作为特征提取器横扫各大NLP榜单。但这些模型的参数量也相当惊人,比如BERT-base有一亿零九百万参数,BERT-large的参数量则高达三亿三千万,从而导致模型的训练及推理速度过慢。本文提出了一种“耐心的知识蒸馏” (Patient Knowledge Distillation) 方法对模型进行压缩,PKD方法重新定义损失函数,使student模型的隐藏层表示更接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强,在不损失太多精度的情况下,减小模型规模及推理时间。

59a42c2fee039ac5839697139fc78de6.png

研究者将提出的BERT-PKD模型与BERT模型微调 (fine-tuning) 和基线知识蒸馏模型在7个句子分类的基准数据集上比较,在12层teacher模型蒸馏到6层或者3层student模型的时候,绝大部分情况下PKD模型表现都优于两种基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。论文地址:https://arxiv.org/abs/1908.09355

引言

过去一年里,语言模型的研究有了许多突破性的进展,BERT、XLNet、RoBERTa等预训练语言模型作为特征提取器横扫各大NLP榜单。但这些模型的参数量也相当惊人,比如BERT-base有一亿零九百万参数,BERT-large的参数量则高达三亿三千万,从而导致模型的训练及推理速度过慢。本文提出了一种“耐心的知识蒸馏” (Patient Knowledge Distillation) 方法对模型进行压缩,PKD方法重新定义损失函数,使student模型的隐藏层表示更接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强,在不损失太多精度的情况下,减小模型规模及推理时间。
具体来说,对于句子分类类型的任务,当普通的知识蒸馏模型用来对模型进行压缩的时候,通常带来较大精度损失。究其原因,student模型 (student model) 在学习的时候只是学到了teacher模型 (teacher model) 最终预测的概率分布,而忽略了中间隐藏层表示的学习。就像老师在教学生的时候,学生只记住了最终的答案,但是对于中间的过程却完全没有学习。这样在遇到新问题的时候,student模型犯错误的概率更高。BERT-PKD模型重新定义损失函数,使student模型的隐藏层表示接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强。

模型

论文所提出的多层蒸馏,即student模型除了学习teacher模型的概率输出之外,还要学习一些中间层的输出。论文提出了两种方法,第一种是skip模式,即每隔几层去学习一个中间层,第二种是last模式,即学习teacher模型的最后几层。

823dcb3e7cac42de3d6d5a99c78cc545.png

对于最初的知识蒸馏方法,student模型使用交叉熵函数尽量拟合teacher模型的概率输出,同时student模型还需要学习ground truth,然后再让两个损失去做加权平均。

4150da028692c0ed9643e46b8e0f5cab.png

而PKD模型增加模型中间层的学习,同时为避免中间层学习计算量过大,让student模型仅学习[CLS]字符的中间层输出,使得模型能够同时学到[CLS]字符的各层的特征表示,对于中间层的学习,使用的损失函数是均方差函数。

8bb6182b2120e19d478ea80433f6e5ca.png

数据集

为了验证BERT-PKD的效果,本文在多个任务上将其与其他模型进行了比较。

实验结果

BERT-PKD模型、BERT模型微调 (fine-tuning) 和基线知识蒸馏模型在7个句子分类的基准数据集上比较如下表所示,在12层teacher模型蒸馏到6层或者3层student模型的时候,大部分情况下PKD的表现都优于同等规模的基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。这进一步验证了研究者的假设,学习了隐藏层表示的student模型优于只学模型预测概率的student模型。同时,student模型在MRPC任务上的表现较差,究其原因,可能是因为MRPC的数据较少,从而导致了过拟合。

b61eb564779c01c3e5b1d2ded08ec14f.png

BERT-PKD模型Last模式和Skip模式的对比如下表所示,Skip模式一般优于last模式。Skip模式下,层次之间的距离较远,从而让student学习到各种层次的信息。

d5b6695b362c10633eab8e3d133fa866.png

student模型的计算量和参数数目如下表所示。在速度方面,6层BERT-PKD模型可将推理 (inference) 速度提高两倍,总参数量减少1.64倍;而三层BERT-PKD模型可以提速3.73倍,总参数量减少2.4倍。PKD模型由于需每层计算损失,student模型和teacher模型隐藏层的宽度(隐藏层的向量维度)相同,这其实对于模型是一个限制,student模型的参数量主要在层数的减少。

cc8abe24a0e02e2a5eeb57ef24be215d.png

结论

本文提出一种基于BERT的知识蒸馏模型BERT-PKD,模型在GLUE基准大部分情况下的表现都优于同等规模的基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。同时,在速度方面,6层BERT-PKD模型可将推理 (inference) 速度提高两倍,总参数量减少1.64倍;而三层BERT-PKD模型可提速3.73倍,总参数量减少2.4倍。

这篇关于12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/469101

相关文章

Python安装Pandas库的两种方法

《Python安装Pandas库的两种方法》本文介绍了三种安装PythonPandas库的方法,通过cmd命令行安装并解决版本冲突,手动下载whl文件安装,更换国内镜像源加速下载,最后建议用pipli... 目录方法一:cmd命令行执行pip install pandas方法二:找到pandas下载库,然后

Linux系统中查询JDK安装目录的几种常用方法

《Linux系统中查询JDK安装目录的几种常用方法》:本文主要介绍Linux系统中查询JDK安装目录的几种常用方法,方法分别是通过update-alternatives、Java命令、环境变量及目... 目录方法 1:通过update-alternatives查询(推荐)方法 2:检查所有已安装的 JDK方

SQL Server安装时候没有中文选项的解决方法

《SQLServer安装时候没有中文选项的解决方法》用户安装SQLServer时界面全英文,无中文选项,通过修改安装设置中的国家或地区为中文中国,重启安装程序后界面恢复中文,解决了问题,对SQLSe... 你是不是在安装SQL Server时候发现安装界面和别人不同,并且无论如何都没有中文选项?这个问题也

Java Thread中join方法使用举例详解

《JavaThread中join方法使用举例详解》JavaThread中join()方法主要是让调用改方法的thread完成run方法里面的东西后,在执行join()方法后面的代码,这篇文章主要介绍... 目录前言1.join()方法的定义和作用2.join()方法的三个重载版本3.join()方法的工作原

go动态限制并发数量的实现示例

《go动态限制并发数量的实现示例》本文主要介绍了Go并发控制方法,通过带缓冲通道和第三方库实现并发数量限制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录带有缓冲大小的通道使用第三方库其他控制并发的方法因为go从语言层面支持并发,所以面试百分百会问到

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

Spring Boot从main方法到内嵌Tomcat的全过程(自动化流程)

《SpringBoot从main方法到内嵌Tomcat的全过程(自动化流程)》SpringBoot启动始于main方法,创建SpringApplication实例,初始化上下文,准备环境,刷新容器并... 目录1. 入口:main方法2. SpringApplication初始化2.1 构造阶段3. 运行阶

Olingo分析和实践之ODataImpl详细分析(重要方法详解)

《Olingo分析和实践之ODataImpl详细分析(重要方法详解)》ODataImpl.java是ApacheOlingoOData框架的核心工厂类,负责创建序列化器、反序列化器和处理器等组件,... 目录概述主要职责类结构与继承关系核心功能分析1. 序列化器管理2. 反序列化器管理3. 处理器管理重要方

Python错误AttributeError: 'NoneType' object has no attribute问题的彻底解决方法

《Python错误AttributeError:NoneTypeobjecthasnoattribute问题的彻底解决方法》在Python项目开发和调试过程中,经常会碰到这样一个异常信息... 目录问题背景与概述错误解读:AttributeError: 'NoneType' object has no at

postgresql使用UUID函数的方法

《postgresql使用UUID函数的方法》本文给大家介绍postgresql使用UUID函数的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录PostgreSQL有两种生成uuid的方法。可以先通过sql查看是否已安装扩展函数,和可以安装的扩展函数