【域适应】基于深度域适应MMD损失的典型四分类任务实现

2024-04-12 07:28

本文主要是介绍【域适应】基于深度域适应MMD损失的典型四分类任务实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于

MMD (maximum mean discrepancy)是用来衡量两组数据分布之间相似度的度量。一般地,如果两组数据分布相似,那么MMD 损失就相对较小,说明两组数据/特征处于相似的特征空间中。基于这个想法,对于源域和目标域数据,在使用深度学习进行特征提取中,使用MMD损失,可以让模型提取两个域的共有特征/空间,从而实现源域到目标域的迁移。

参考论文:https://arxiv.org/abs/1409.6041

工具

Python

 

方法实现

定义mmd函数
#!/usr/bin/env python
# encoding: utf-8import torch# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):delta = f_of_X - f_of_Yloss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))return lossdef guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):n_samples = int(source.size()[0])+int(target.size()[0])total = torch.cat([source, target], dim=0)total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))L2_distance = ((total0-total1)**2).sum(2)if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]return sum(kernel_val)#/len(kernel_val)def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)loss = 0for i in range(batch_size):s1, s2 = i, (i+1)%batch_sizet1, t2 = s1+batch_size, s2+batch_sizeloss += kernels[s1, s2] + kernels[t1, t2]loss -= kernels[s1, t2] + kernels[s2, t1]return loss / float(batch_size)def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)XX = kernels[:batch_size, :batch_size]YY = kernels[batch_size:, batch_size:]XY = kernels[:batch_size, batch_size:]YX = kernels[batch_size:, :batch_size]loss = torch.mean(XX + YY - XY -YX)return loss
定义基于mmd特征对齐CNN模型
# encoding=utf-8import torch.nn as nn
import torch.nn.functional as Fclass Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 3)),nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 3)),nn.ReLU(),nn.Dropout(0.4),nn.MaxPool2d(kernel_size=(1, 2), stride=2))self.fc1 = nn.Sequential(nn.Linear(in_features=64 * 98, out_features=100),nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=100, out_features=2))def forward(self, src, tar):x_src = self.conv1(src)x_tar = self.conv1(tar)x_src = self.conv2(x_src)x_tar = self.conv2(x_tar)#print(x_src.shape)x_src = x_src.reshape(-1, 64 * 98)x_tar = x_tar.reshape(-1, 64 * 98)x_src_mmd = self.fc1(x_src)x_tar_mmd = self.fc1(x_tar)#x_src = self.fc1(x_src)#x_tar = self.fc1(x_tar)#x_src_mmd = self.fc2(x_src)#x_tar_mmd = self.fc2(x_tar)y_src = self.fc2(x_src_mmd)return y_src, x_src_mmd, x_tar_mmd

代码获取

后台私信;

其他相关域适应问题和代码开发,欢迎沟通和交流。

这篇关于【域适应】基于深度域适应MMD损失的典型四分类任务实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/896486

相关文章

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)

《使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)》字体设计和矢量图形处理是编程中一个有趣且实用的领域,通过Python的matplotlib库,我们可以轻松将字体轮廓... 目录背景知识字体轮廓的表示实现步骤1. 安装依赖库2. 准备数据3. 解析路径指令4. 绘制图形关键

C/C++中OpenCV 矩阵运算的实现

《C/C++中OpenCV矩阵运算的实现》本文主要介绍了C/C++中OpenCV矩阵运算的实现,包括基本算术运算(标量与矩阵)、矩阵乘法、转置、逆矩阵、行列式、迹、范数等操作,感兴趣的可以了解一下... 目录矩阵的创建与初始化创建矩阵访问矩阵元素基本的算术运算 ➕➖✖️➗矩阵与标量运算矩阵与矩阵运算 (逐元

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

C/C++和OpenCV实现调用摄像头

《C/C++和OpenCV实现调用摄像头》本文主要介绍了C/C++和OpenCV实现调用摄像头,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录准备工作1. 打开摄像头2. 读取视频帧3. 显示视频帧4. 释放资源5. 获取和设置摄像头属性

c/c++的opencv图像金字塔缩放实现

《c/c++的opencv图像金字塔缩放实现》本文主要介绍了c/c++的opencv图像金字塔缩放实现,通过对原始图像进行连续的下采样或上采样操作,生成一系列不同分辨率的图像,具有一定的参考价值,感兴... 目录图像金字塔简介图像下采样 (cv::pyrDown)图像上采样 (cv::pyrUp)C++ O

c/c++的opencv实现图片膨胀

《c/c++的opencv实现图片膨胀》图像膨胀是形态学操作,通过结构元素扩张亮区填充孔洞、连接断开部分、加粗物体,OpenCV的cv::dilate函数实现该操作,本文就来介绍一下opencv图片... 目录什么是图像膨胀?结构元素 (KerChina编程nel)OpenCV 中的 cv::dilate() 函

Python使用FFmpeg实现高效音频格式转换工具

《Python使用FFmpeg实现高效音频格式转换工具》在数字音频处理领域,音频格式转换是一项基础但至关重要的功能,本文主要为大家介绍了Python如何使用FFmpeg实现强大功能的图形化音频转换工具... 目录概述功能详解软件效果展示主界面布局转换过程截图完成提示开发步骤详解1. 环境准备2. 项目功能结

SpringBoot使用ffmpeg实现视频压缩

《SpringBoot使用ffmpeg实现视频压缩》FFmpeg是一个开源的跨平台多媒体处理工具集,用于录制,转换,编辑和流式传输音频和视频,本文将使用ffmpeg实现视频压缩功能,有需要的可以参考... 目录核心功能1.格式转换2.编解码3.音视频处理4.流媒体支持5.滤镜(Filter)安装配置linu

在Spring Boot中实现HTTPS加密通信及常见问题排查

《在SpringBoot中实现HTTPS加密通信及常见问题排查》HTTPS是HTTP的安全版本,通过SSL/TLS协议为通讯提供加密、身份验证和数据完整性保护,下面通过本文给大家介绍在SpringB... 目录一、HTTPS核心原理1.加密流程概述2.加密技术组合二、证书体系详解1、证书类型对比2. 证书获