12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩

2023-12-08 07:50

本文主要是介绍12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

过去一年里,语言模型的研究有了许多突破性的进展,BERT、XLNet、RoBERTa等预训练语言模型作为特征提取器横扫各大NLP榜单。但这些模型的参数量也相当惊人,比如BERT-base有一亿零九百万参数,BERT-large的参数量则高达三亿三千万,从而导致模型的训练及推理速度过慢。本文提出了一种“耐心的知识蒸馏” (Patient Knowledge Distillation) 方法对模型进行压缩,PKD方法重新定义损失函数,使student模型的隐藏层表示更接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强,在不损失太多精度的情况下,减小模型规模及推理时间。

59a42c2fee039ac5839697139fc78de6.png

研究者将提出的BERT-PKD模型与BERT模型微调 (fine-tuning) 和基线知识蒸馏模型在7个句子分类的基准数据集上比较,在12层teacher模型蒸馏到6层或者3层student模型的时候,绝大部分情况下PKD模型表现都优于两种基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。论文地址:https://arxiv.org/abs/1908.09355

引言

过去一年里,语言模型的研究有了许多突破性的进展,BERT、XLNet、RoBERTa等预训练语言模型作为特征提取器横扫各大NLP榜单。但这些模型的参数量也相当惊人,比如BERT-base有一亿零九百万参数,BERT-large的参数量则高达三亿三千万,从而导致模型的训练及推理速度过慢。本文提出了一种“耐心的知识蒸馏” (Patient Knowledge Distillation) 方法对模型进行压缩,PKD方法重新定义损失函数,使student模型的隐藏层表示更接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强,在不损失太多精度的情况下,减小模型规模及推理时间。
具体来说,对于句子分类类型的任务,当普通的知识蒸馏模型用来对模型进行压缩的时候,通常带来较大精度损失。究其原因,student模型 (student model) 在学习的时候只是学到了teacher模型 (teacher model) 最终预测的概率分布,而忽略了中间隐藏层表示的学习。就像老师在教学生的时候,学生只记住了最终的答案,但是对于中间的过程却完全没有学习。这样在遇到新问题的时候,student模型犯错误的概率更高。BERT-PKD模型重新定义损失函数,使student模型的隐藏层表示接近teacher模型的隐藏层表示,从而让student模型的泛化能力更强。

模型

论文所提出的多层蒸馏,即student模型除了学习teacher模型的概率输出之外,还要学习一些中间层的输出。论文提出了两种方法,第一种是skip模式,即每隔几层去学习一个中间层,第二种是last模式,即学习teacher模型的最后几层。

823dcb3e7cac42de3d6d5a99c78cc545.png

对于最初的知识蒸馏方法,student模型使用交叉熵函数尽量拟合teacher模型的概率输出,同时student模型还需要学习ground truth,然后再让两个损失去做加权平均。

4150da028692c0ed9643e46b8e0f5cab.png

而PKD模型增加模型中间层的学习,同时为避免中间层学习计算量过大,让student模型仅学习[CLS]字符的中间层输出,使得模型能够同时学到[CLS]字符的各层的特征表示,对于中间层的学习,使用的损失函数是均方差函数。

8bb6182b2120e19d478ea80433f6e5ca.png

数据集

为了验证BERT-PKD的效果,本文在多个任务上将其与其他模型进行了比较。

实验结果

BERT-PKD模型、BERT模型微调 (fine-tuning) 和基线知识蒸馏模型在7个句子分类的基准数据集上比较如下表所示,在12层teacher模型蒸馏到6层或者3层student模型的时候,大部分情况下PKD的表现都优于同等规模的基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。这进一步验证了研究者的假设,学习了隐藏层表示的student模型优于只学模型预测概率的student模型。同时,student模型在MRPC任务上的表现较差,究其原因,可能是因为MRPC的数据较少,从而导致了过拟合。

b61eb564779c01c3e5b1d2ded08ec14f.png

BERT-PKD模型Last模式和Skip模式的对比如下表所示,Skip模式一般优于last模式。Skip模式下,层次之间的距离较远,从而让student学习到各种层次的信息。

d5b6695b362c10633eab8e3d133fa866.png

student模型的计算量和参数数目如下表所示。在速度方面,6层BERT-PKD模型可将推理 (inference) 速度提高两倍,总参数量减少1.64倍;而三层BERT-PKD模型可以提速3.73倍,总参数量减少2.4倍。PKD模型由于需每层计算损失,student模型和teacher模型隐藏层的宽度(隐藏层的向量维度)相同,这其实对于模型是一个限制,student模型的参数量主要在层数的减少。

cc8abe24a0e02e2a5eeb57ef24be215d.png

结论

本文提出一种基于BERT的知识蒸馏模型BERT-PKD,模型在GLUE基准大部分情况下的表现都优于同等规模的基线模型。并且在五个数据集上SST-2 (相比于teacher模型-2.3%准确率),QQP (-0.1%),MNLI-m (-2.2%),MNLI-mm (-1.8%),and QNLI (-1.4%) 的表现接近于teacher模型。同时,在速度方面,6层BERT-PKD模型可将推理 (inference) 速度提高两倍,总参数量减少1.64倍;而三层BERT-PKD模型可提速3.73倍,总参数量减少2.4倍。

这篇关于12层的bert参数量_EMNLP 2019 | BERTPKD:一种基于PKD方法的BERT模型压缩的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/weixin_39735509/article/details/110127168
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/469101

相关文章

SpringBoot实现文件记录日志及日志文件自动归档和压缩

《SpringBoot实现文件记录日志及日志文件自动归档和压缩》Logback是Java日志框架,通过Logger收集日志并经Appender输出至控制台、文件等,SpringBoot配置logbac... 目录1、什么是Logback2、SpringBoot实现文件记录日志,日志文件自动归档和压缩2.1、

使用Python获取JS加载的数据的多种实现方法

《使用Python获取JS加载的数据的多种实现方法》在当今的互联网时代,网页数据的动态加载已经成为一种常见的技术手段,许多现代网站通过JavaScript(JS)动态加载内容,这使得传统的静态网页爬取... 目录引言一、动态 网页与js加载数据的原理二、python爬取JS加载数据的方法(一)分析网络请求1

MySQL查看表的最后一个ID的常见方法

《MySQL查看表的最后一个ID的常见方法》在使用MySQL数据库时,我们经常会遇到需要查看表中最后一个id值的场景,无论是为了调试、数据分析还是其他用途,了解如何快速获取最后一个id都是非常实用的技... 目录背景介绍方法一:使用MAX()函数示例代码解释适用场景方法二:按id降序排序并取第一条示例代码解

Python中合并列表(list)的六种方法小结

《Python中合并列表(list)的六种方法小结》本文主要介绍了Python中合并列表(list)的六种方法小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录一、直接用 + 合并列表二、用 extend() js方法三、用 zip() 函数交叉合并四、用

Java 中的跨域问题解决方法

《Java中的跨域问题解决方法》跨域问题本质上是浏览器的一种安全机制,与Java本身无关,但Java后端开发者需要理解其来源以便正确解决,下面给大家介绍Java中的跨域问题解决方法,感兴趣的朋友一起... 目录1、Java 中跨域问题的来源1.1. 浏览器同源策略(Same-Origin Policy)1.

Java Stream.reduce()方法操作实际案例讲解

《JavaStream.reduce()方法操作实际案例讲解》reduce是JavaStreamAPI中的一个核心操作,用于将流中的元素组合起来产生单个结果,:本文主要介绍JavaStream.... 目录一、reduce的基本概念1. 什么是reduce操作2. reduce方法的三种形式二、reduce

MybatisX快速生成增删改查的方法示例

《MybatisX快速生成增删改查的方法示例》MybatisX是基于IDEA的MyBatis/MyBatis-Plus开发插件,本文主要介绍了MybatisX快速生成增删改查的方法示例,文中通过示例代... 目录1 安装2 基本功能2.1 XML跳转2.2 代码生成2.2.1 生成.xml中的sql语句头2

python3 pip终端出现错误解决的方法详解

《python3pip终端出现错误解决的方法详解》这篇文章主要为大家详细介绍了python3pip如果在终端出现错误该如何解决,文中的示例方法讲解详细,感兴趣的小伙伴可以跟随小编一起了解一下... 目录前言一、查看是否已安装pip二、查看是否添加至环境变量1.查看环境变量是http://www.cppcns

Linux给磁盘扩容(LVM方式)的方法实现

《Linux给磁盘扩容(LVM方式)的方法实现》本文主要介绍了Linux给磁盘扩容(LVM方式)的方法实现,涵盖PV/VG/LV概念及操作步骤,具有一定的参考价值,感兴趣的可以了解一下... 目录1 概念2 实战2.1 相关基础命令2.2 开始给LVM扩容2.3 总结最近测试性能,在本地打数据时,发现磁盘空

使用Python实现调用API获取图片存储到本地的方法

《使用Python实现调用API获取图片存储到本地的方法》开发一个自动化工具,用于从JSON数据源中提取图像ID,通过调用指定API获取未经压缩的原始图像文件,并确保下载结果与Postman等工具直接... 目录使用python实现调用API获取图片存储到本地1、项目概述2、核心功能3、环境准备4、代码实现