知识蒸馏dist和KLDiv代码和使用方法

2023-12-03 23:04

本文主要是介绍知识蒸馏dist和KLDiv代码和使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DIST 实现方法

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv方法

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

关于知识蒸馏的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏dist和KLDiv代码和使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/451194

相关文章

PyTorch核心方法之state_dict()、parameters()参数打印与应用案例

《PyTorch核心方法之state_dict()、parameters()参数打印与应用案例》PyTorch是一个流行的开源深度学习框架,提供了灵活且高效的方式来训练和部署神经网络,这篇文章主要介绍... 目录前言模型案例A. state_dict()方法验证B. parameters()C. 模型结构冻

Java中的ConcurrentBitSet使用小结

《Java中的ConcurrentBitSet使用小结》本文主要介绍了Java中的ConcurrentBitSet使用小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、核心澄清:Java标准库无内置ConcurrentBitSet二、推荐方案:Eclipse

Go语言结构体标签(Tag)的使用小结

《Go语言结构体标签(Tag)的使用小结》结构体标签Tag是Go语言中附加在结构体字段后的元数据字符串,用于提供额外的属性信息,这些信息可以通过反射在运行时读取和解析,下面就来详细的介绍一下Tag的使... 目录什么是结构体标签?基本语法常见的标签用途1.jsON 序列化/反序列化(最常用)2.数据库操作(

Java中ScopeValue的使用小结

《Java中ScopeValue的使用小结》Java21引入的ScopedValue是一种作用域内共享不可变数据的预览API,本文就来详细介绍一下Java中ScopeValue的使用小结,感兴趣的可以... 目录一、Java ScopedValue(作用域值)详解1. 定义与背景2. 核心特性3. 使用方法

spring中Interceptor的使用小结

《spring中Interceptor的使用小结》SpringInterceptor是SpringMVC提供的一种机制,用于在请求处理的不同阶段插入自定义逻辑,通过实现HandlerIntercept... 目录一、Interceptor 的核心概念二、Interceptor 的创建与配置三、拦截器的执行顺

利用c++判断水仙花数并输出示例代码

《利用c++判断水仙花数并输出示例代码》水仙花数是指一个三位数,其各位数字的立方和恰好等于该数本身,:本文主要介绍利用c++判断水仙花数并输出的相关资料,文中通过代码介绍的非常详细,需要的朋友可以... 以下是使用C++实现的相同逻辑代码:#include <IOStream>#include <vec

Python字符串处理方法超全攻略

《Python字符串处理方法超全攻略》字符串可以看作多个字符的按照先后顺序组合,相当于就是序列结构,意味着可以对它进行遍历、切片,:本文主要介绍Python字符串处理方法的相关资料,文中通过代码介... 目录一、基础知识:字符串的“不可变”特性与创建方式二、常用操作:80%场景的“万能工具箱”三、格式化方法

springboot+redis实现订单过期(超时取消)功能的方法详解

《springboot+redis实现订单过期(超时取消)功能的方法详解》在SpringBoot中使用Redis实现订单过期(超时取消)功能,有多种成熟方案,本文为大家整理了几个详细方法,文中的示例代... 目录一、Redis键过期回调方案(推荐)1. 配置Redis监听器2. 监听键过期事件3. Redi

C#中checked关键字的使用小结

《C#中checked关键字的使用小结》本文主要介绍了C#中checked关键字的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录✅ 为什么需要checked? 问题:整数溢出是“静默China编程”的(默认)checked的三种用

C#中预处理器指令的使用小结

《C#中预处理器指令的使用小结》本文主要介绍了C#中预处理器指令的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录 第 1 名:#if/#else/#elif/#endif✅用途:条件编译(绝对最常用!) 典型场景: 示例