知识蒸馏dist和KLDiv代码和使用方法

2023-12-03 23:04

本文主要是介绍知识蒸馏dist和KLDiv代码和使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DIST 实现方法

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv方法

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

关于知识蒸馏的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏dist和KLDiv代码和使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/451194

相关文章

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

使用IDEA部署Docker应用指南分享

《使用IDEA部署Docker应用指南分享》本文介绍了使用IDEA部署Docker应用的四步流程:创建Dockerfile、配置IDEADocker连接、设置运行调试环境、构建运行镜像,并强调需准备本... 目录一、创建 dockerfile 配置文件二、配置 IDEA 的 Docker 连接三、配置 Do

Android Paging 分页加载库使用实践

《AndroidPaging分页加载库使用实践》AndroidPaging库是Jetpack组件的一部分,它提供了一套完整的解决方案来处理大型数据集的分页加载,本文将深入探讨Paging库... 目录前言一、Paging 库概述二、Paging 3 核心组件1. PagingSource2. Pager3.

python使用try函数详解

《python使用try函数详解》Pythontry语句用于异常处理,支持捕获特定/多种异常、else/final子句确保资源释放,结合with语句自动清理,可自定义异常及嵌套结构,灵活应对错误场景... 目录try 函数的基本语法捕获特定异常捕获多个异常使用 else 子句使用 finally 子句捕获所

C++11右值引用与Lambda表达式的使用

《C++11右值引用与Lambda表达式的使用》C++11引入右值引用,实现移动语义提升性能,支持资源转移与完美转发;同时引入Lambda表达式,简化匿名函数定义,通过捕获列表和参数列表灵活处理变量... 目录C++11新特性右值引用和移动语义左值 / 右值常见的左值和右值移动语义移动构造函数移动复制运算符

Python对接支付宝支付之使用AliPay实现的详细操作指南

《Python对接支付宝支付之使用AliPay实现的详细操作指南》支付宝没有提供PythonSDK,但是强大的github就有提供python-alipay-sdk,封装里很多复杂操作,使用这个我们就... 目录一、引言二、准备工作2.1 支付宝开放平台入驻与应用创建2.2 密钥生成与配置2.3 安装ali

C#中lock关键字的使用小结

《C#中lock关键字的使用小结》在C#中,lock关键字用于确保当一个线程位于给定实例的代码块中时,其他线程无法访问同一实例的该代码块,下面就来介绍一下lock关键字的使用... 目录使用方式工作原理注意事项示例代码为什么不能lock值类型在C#中,lock关键字用于确保当一个线程位于给定实例的代码块中时

MySQL 强制使用特定索引的操作

《MySQL强制使用特定索引的操作》MySQL可通过FORCEINDEX、USEINDEX等语法强制查询使用特定索引,但优化器可能不采纳,需结合EXPLAIN分析执行计划,避免性能下降,注意版本差异... 目录1. 使用FORCE INDEX语法2. 使用USE INDEX语法3. 使用IGNORE IND

C# $字符串插值的使用

《C#$字符串插值的使用》本文介绍了C#中的字符串插值功能,详细介绍了使用$符号的实现方式,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录$ 字符使用方式创建内插字符串包含不同的数据类型控制内插表达式的格式控制内插表达式的对齐方式内插表达式中使用转义序列内插表达式中使用

flask库中sessions.py的使用小结

《flask库中sessions.py的使用小结》在Flask中Session是一种用于在不同请求之间存储用户数据的机制,Session默认是基于客户端Cookie的,但数据会经过加密签名,防止篡改,... 目录1. Flask Session 的基本使用(1) 启用 Session(2) 存储和读取 Se