知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

本文主要是介绍知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/456622

相关文章

电脑提示xlstat4.dll丢失怎么修复? xlstat4.dll文件丢失处理办法

《电脑提示xlstat4.dll丢失怎么修复?xlstat4.dll文件丢失处理办法》长时间使用电脑,大家多少都会遇到类似dll文件丢失的情况,不过,解决这一问题其实并不复杂,下面我们就来看看xls... 在Windows操作系统中,xlstat4.dll是一个重要的动态链接库文件,通常用于支持各种应用程序

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Maven 配置中的 <mirror>绕过 HTTP 阻断机制的方法

《Maven配置中的<mirror>绕过HTTP阻断机制的方法》:本文主要介绍Maven配置中的<mirror>绕过HTTP阻断机制的方法,本文给大家分享问题原因及解决方案,感兴趣的朋友一... 目录一、问题场景:升级 Maven 后构建失败二、解决方案:通过 <mirror> 配置覆盖默认行为1. 配置示

SpringBoot排查和解决JSON解析错误(400 Bad Request)的方法

《SpringBoot排查和解决JSON解析错误(400BadRequest)的方法》在开发SpringBootRESTfulAPI时,客户端与服务端的数据交互通常使用JSON格式,然而,JSON... 目录问题背景1. 问题描述2. 错误分析解决方案1. 手动重新输入jsON2. 使用工具清理JSON3.

使用jenv工具管理多个JDK版本的方法步骤

《使用jenv工具管理多个JDK版本的方法步骤》jenv是一个开源的Java环境管理工具,旨在帮助开发者在同一台机器上轻松管理和切换多个Java版本,:本文主要介绍使用jenv工具管理多个JD... 目录一、jenv到底是干啥的?二、jenv的核心功能(一)管理多个Java版本(二)支持插件扩展(三)环境隔

SQL中JOIN操作的条件使用总结与实践

《SQL中JOIN操作的条件使用总结与实践》在SQL查询中,JOIN操作是多表关联的核心工具,本文将从原理,场景和最佳实践三个方面总结JOIN条件的使用规则,希望可以帮助开发者精准控制查询逻辑... 目录一、ON与WHERE的本质区别二、场景化条件使用规则三、最佳实践建议1.优先使用ON条件2.WHERE用

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6

Java中实现线程的创建和启动的方法

《Java中实现线程的创建和启动的方法》在Java中,实现线程的创建和启动是两个不同但紧密相关的概念,理解为什么要启动线程(调用start()方法)而非直接调用run()方法,是掌握多线程编程的关键,... 目录1. 线程的生命周期2. start() vs run() 的本质区别3. 为什么必须通过 st