知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

本文主要是介绍知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/456622

相关文章

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁

精选20个好玩又实用的的Python实战项目(有图文代码)

《精选20个好玩又实用的的Python实战项目(有图文代码)》文章介绍了20个实用Python项目,涵盖游戏开发、工具应用、图像处理、机器学习等,使用Tkinter、PIL、OpenCV、Kivy等库... 目录① 猜字游戏② 闹钟③ 骰子模拟器④ 二维码⑤ 语言检测⑥ 加密和解密⑦ URL缩短⑧ 音乐播放

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

Springboot项目启动失败提示找不到dao类的解决

《Springboot项目启动失败提示找不到dao类的解决》SpringBoot启动失败,因ProductServiceImpl未正确注入ProductDao,原因:Dao未注册为Bean,解决:在启... 目录错误描述原因解决方法总结***************************APPLICA编

Python安装Pandas库的两种方法

《Python安装Pandas库的两种方法》本文介绍了三种安装PythonPandas库的方法,通过cmd命令行安装并解决版本冲突,手动下载whl文件安装,更换国内镜像源加速下载,最后建议用pipli... 目录方法一:cmd命令行执行pip install pandas方法二:找到pandas下载库,然后

Linux系统中查询JDK安装目录的几种常用方法

《Linux系统中查询JDK安装目录的几种常用方法》:本文主要介绍Linux系统中查询JDK安装目录的几种常用方法,方法分别是通过update-alternatives、Java命令、环境变量及目... 目录方法 1:通过update-alternatives查询(推荐)方法 2:检查所有已安装的 JDK方

SQL Server安装时候没有中文选项的解决方法

《SQLServer安装时候没有中文选项的解决方法》用户安装SQLServer时界面全英文,无中文选项,通过修改安装设置中的国家或地区为中文中国,重启安装程序后界面恢复中文,解决了问题,对SQLSe... 你是不是在安装SQL Server时候发现安装界面和别人不同,并且无论如何都没有中文选项?这个问题也

Java Thread中join方法使用举例详解

《JavaThread中join方法使用举例详解》JavaThread中join()方法主要是让调用改方法的thread完成run方法里面的东西后,在执行join()方法后面的代码,这篇文章主要介绍... 目录前言1.join()方法的定义和作用2.join()方法的三个重载版本3.join()方法的工作原

电脑提示d3dx11_43.dll缺失怎么办? DLL文件丢失的多种修复教程

《电脑提示d3dx11_43.dll缺失怎么办?DLL文件丢失的多种修复教程》在使用电脑玩游戏或运行某些图形处理软件时,有时会遇到系统提示“d3dx11_43.dll缺失”的错误,下面我们就来分享超... 在计算机使用过程中,我们可能会遇到一些错误提示,其中之一就是缺失某个dll文件。其中,d3dx11_4

游戏闪退弹窗提示找不到storm.dll文件怎么办? Stormdll文件损坏修复技巧

《游戏闪退弹窗提示找不到storm.dll文件怎么办?Stormdll文件损坏修复技巧》DLL文件丢失或损坏会导致软件无法正常运行,例如我们在电脑上运行软件或游戏时会得到以下提示:storm.dll... 很多玩家在打开游戏时,突然弹出“找不到storm.dll文件”的提示框,随后游戏直接闪退,这通常是由于