浅谈知识蒸馏(Knowledge Distillation)

2024-02-09 03:58

本文主要是介绍浅谈知识蒸馏(Knowledge Distillation),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

浅谈知识蒸馏(Knowledge Distillation)

前言:

在实验室做算法研究时,我们最看重的一般是模型精度,因为精度是我们模型有效性的最直接证明。而在公司做研发时,除了算法精度,我们还很关注模型的大小和内存占用。因为实验室模型一般运行在服务器上,很少有运算资源不足的情况,但是公司研发的算法功能最终都是要部署到实际的产品上的,像手机或者小型计算平台,其运算资源是很有限的。所以算法工程师在公司做预研时,算法建模一般都分两部分:先根据需求建模,并尽可能提高模型精度;然后进行模型压缩,在保证算法精度的情况下尽可能减少其参数量。

常用的模型压缩方法: 知识蒸馏、权重共享、模型剪枝、网络量化以及低秩分解。本文我们主要介绍知识蒸馏。
模型剪枝示意图:左边为训练好的大模型,通过剪枝删除掉一切权值趋于零的节点,达到缩减模型参数的效果上图为模型剪枝示意图:左边为训练好的大模型,通过剪枝删除掉一切权值趋于零的节点,达到缩减模型参数的效果

知识蒸馏(Knowledge Distillation,KD):

我一直觉得训练神经网络的过程很像求解线性方程组,用已知的数据及标签(对应求解方程中的xy点对)来拟合一批模型参数(对应方程组的系数矩阵)。一般来讲,在数据量有限的情况下,如果我们的模型过大,就很容易出现过拟合现象,此时我们需要缩减模型参数,或者添加正则项。

但在数据量足够的情况下,网络模型越复杂、参数量越大,训练出的模型性能会越好,而较小的网络却很难达到大网络那么好的效果。要让一个小网络达到和大模型相近的性能,我们就需要换一个思路,让大模型在训练过程中学习到的知识迁移到小模型中,而这个迁移的过程就叫做知识蒸馏(Knowledge Distillation,KD)

知识蒸馏的开山之作为大佬Hinton发表在NIPS2014文章《Distilling the Knowledge in a Neural Network》。其主要思想是:在给定输入的情况下训练迁移模型(Student Network),让其输出与原模型(Teacher Network)的输出一致,从而达到将原模型学习到的知识迁移至小网络的目标。
在这里插入图片描述
上图为知识蒸馏模型训练示意图:左侧为大参数量的原模型(Teacher Network),右侧为小网络(Student Network)

训练过程中,原模型(Teacher)输出 vi 与小模型(Student)输出 zi 之间的一致性约束是知识蒸馏的关键所在,即最小化下式:
在这里插入图片描述
对于输出一致性约束,常用的一般为各种距离度量、或者K-L散度等。在神经网络模型中,训练模型就是让模型的softmax输出与Ground Truth匹配;而知识蒸馏任务中,我们需要让Student网络与Teacher网络的的softmax输出尽可能匹配。

下式定义为普通的Softmax函数:
在这里插入图片描述
从上面softmax函数的定义式中我们不难看出,它先通过指数函数拉大输出节点之间的差异,然后通过归一化输出一个接近one-hot的向量(其中一个值很大,其他值接近于0)。对于普通的分类等任务,这样的操作没什么问题,但在知识蒸馏中,这种one-hot形式的输出对于知识的体现很有限,并不利于Student网络的学习(容易放大错误分类的概率,引入不必要的噪声)。所以我们通过引入一个温度参数T来将softmax输出的hard分布转化为soft。

加温度参数T后的softmax定义如下:
在这里插入图片描述
上述公式可以理解为:将网络的输出除以温度参数T后再做softmax,这样可以获得比较soft的输出向量:向量中每个值介于0~1之间,各个值之间的差异没有one-hot那么大。并且T的数值越大,分布越缓和。

训练过程中,模型总体的损失函数由两部分组成如下所示:
在这里插入图片描述
其中,Alpha和Beta为权重参数,Lsoft 为Distill Loss,保证Student网络输出与Teacher网络输出保持一致性,其定义如下:
在这里插入图片描述
其中,pj 和 qj 分别为Teacher网络和Student网络在温度T下的softmax输出向量的第j个值。

因为Teacher网络虽然已经经过了预训练,但其输出也会有一定的误差,为了降低将这些误差迁移到Student网络的可能性,在训练时还添加了Lhard :通过Ground-truth对Student网络的约束损失,定义如下:
在这里插入图片描述

其中,cj 为第j个类别的Ground-truth,qj 为Student网络softmax输出向量的第j个值。

这篇关于浅谈知识蒸馏(Knowledge Distillation)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/693050

相关文章

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

浅谈MySQL的容量规划

《浅谈MySQL的容量规划》进行MySQL的容量规划是确保数据库能够在当前和未来的负载下顺利运行的重要步骤,容量规划包括评估当前资源使用情况、预测未来增长、调整配置和硬件资源等,感兴趣的可以了解一下... 目录一、评估当前资源使用情况1.1 磁盘空间使用1.2 内存使用1.3 CPU使用1.4 网络带宽二、

浅谈mysql的not exists走不走索引

《浅谈mysql的notexists走不走索引》在MySQL中,​NOTEXISTS子句是否使用索引取决于子查询中关联字段是否建立了合适的索引,下面就来介绍一下mysql的notexists走不走索... 在mysql中,​NOT EXISTS子句是否使用索引取决于子查询中关联字段是否建立了合适的索引。以下

浅谈Redis Key 命名规范文档

《浅谈RedisKey命名规范文档》本文介绍了Redis键名命名规范,包括命名格式、具体规范、数据类型扩展命名、时间敏感型键名、规范总结以及实际应用示例,感兴趣的可以了解一下... 目录1. 命名格式格式模板:示例:2. 具体规范2.1 小写命名2.2 使用冒号分隔层级2.3 标识符命名3. 数据类型扩展命

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

浅谈配置MMCV环境,解决报错,版本不匹配问题

《浅谈配置MMCV环境,解决报错,版本不匹配问题》:本文主要介绍浅谈配置MMCV环境,解决报错,版本不匹配问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录配置MMCV环境,解决报错,版本不匹配错误示例正确示例总结配置MMCV环境,解决报错,版本不匹配在col

浅谈mysql的sql_mode可能会限制你的查询

《浅谈mysql的sql_mode可能会限制你的查询》本文主要介绍了浅谈mysql的sql_mode可能会限制你的查询,这个问题主要说明的是,我们写的sql查询语句违背了聚合函数groupby的规则... 目录场景:问题描述原因分析:解决方案:第一种:修改后,只有当前生效,若是mysql服务重启,就会失效;

国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)

《国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)》本文给大家利用deepseek模型搭建私有知识问答库的详细步骤和遇到的问题及解决办法,感兴趣的朋友一起看看吧... 目录1. 第1步大家在安装完ollama后,需要到系统环境变量中添加两个变量2. 第3步 “在cmd中

Spring核心思想之浅谈IoC容器与依赖倒置(DI)

《Spring核心思想之浅谈IoC容器与依赖倒置(DI)》文章介绍了Spring的IoC和DI机制,以及MyBatis的动态代理,通过注解和反射,Spring能够自动管理对象的创建和依赖注入,而MyB... 目录一、控制反转 IoC二、依赖倒置 DI1. 详细概念2. Spring 中 DI 的实现原理三、

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物