最大均值差异(Maximum Mean Discrepancy, MMD)复现教程

2023-10-10 19:59

本文主要是介绍最大均值差异(Maximum Mean Discrepancy, MMD)复现教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文章主要为了复现这个MMD教程中的代码。

pytorch环境安装

下面参考pytorch的官方教程。

这是安装pytorch的先决条件,如果需要用到GPU加速的话还需要下载CUDA驱动。(不过这个小项目就不用啦)

先决条件
首先需要一个Anaconda做为package manager,为项目建立虚拟环境(因为不同项目对pytorch或者其他包的版本要求不同,不能兼容哦)。

之后要下载项目所需要的pytorch版本。如果项目中有说明具体的pytorch版本,最好下载对应的版本,会省很多问题。在这个MMD项目中没有明确说明版本,那么我们就选择pytorch1.1.0这个版本吧(1.1、1.2、1.4、1.5这几个版本的区别不太清楚,但是小版本改动不大。0.4的版本相比1.x的版本差别会大很多。1.6是最新的版本,一般新版本不太稳定不建议使用)。

然后点击跳转到先前版本

选择Windows->CPU only的命令行,复制下来,你可以直接在终端进入虚拟环境安装,也可以在后面打开VSCode,进项目再安装。(记得打开VPN哦,不然下载速度会很慢)


建立MMD项目

好的文件管理可以让你的电脑更加有序,不然项目一多就乱套了。(或者说我有整理洁癖也行哈哈哈哈)目录最好不要用中文,不然有些项目可能会出现乱七八糟的报错,还要改很长时间。打开一个盘建立一个pythonProjects文件夹,以后专门用来放python的项目,然后再创建一个MMD_test文件夹,用来放本次MMD项目的代码。

然后打开VSCode,可能会自动打开上次的项目,那么我们需要点击最上方“文件->新建窗口”,然后选择打开文件夹,选中之前创建的MMD_test。之后在最上方选择“终端->新终端”,在VSCode中打开一个终端,用conda activate 激活到目标虚拟环境中。

现在我们来粗略看一下MMD教程中的代码吧。

  1. 第一段代码
    定义了两个函数,具体下面都有说明,看起来像是为之后的测试提供封装好的函数工具,那么我们就新建一个.py文件,把这段代码复制进去,命名为mmd_tool.py。
import torchdef guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):'''将源域数据和目标域数据转化为核矩阵,即上文中的KParams: source: 源域数据(n * len(x))target: 目标域数据(m * len(y))kernel_mul: kernel_num: 取不同高斯核的数量fix_sigma: 不同高斯核的sigma值Return:sum(kernel_val): 多个核矩阵之和'''n_samples = int(source.size()[0])+int(target.size()[0])# 求矩阵的行数,一般source和target的尺度是一样的,这样便于计算total = torch.cat([source, target], dim=0)#将source,target按列方向合并#将total复制(n+m)份total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))#将total的每一行都复制成(n+m)行,即每个数据都扩展成(n+m)份total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))#求任意两个数据之间的和,得到的矩阵中坐标(i,j)代表total中第i行数据和第j行数据之间的l2 distance(i==j时为0)L2_distance = ((total0-total1)**2).sum(2) #调整高斯核函数的sigma值if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)#以fix_sigma为中值,以kernel_mul为倍数取kernel_num个bandwidth值(比如fix_sigma为1时,得到[0.25,0.5,1,2,4]bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]#高斯核函数的数学表达式kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]#得到最终的核矩阵return sum(kernel_val)#/len(kernel_val)def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):'''计算源域数据和目标域数据的MMD距离Params: source: 源域数据(n * len(x))target: 目标域数据(m * len(y))kernel_mul: kernel_num: 取不同高斯核的数量fix_sigma: 不同高斯核的sigma值Return:loss: MMD loss'''batch_size = int(source.size()[0])#一般默认为源域和目标域的batchsize相同kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)#根据式(3)将核矩阵分成4部分XX = kernels[:batch_size, :batch_size]YY = kernels[batch_size:, batch_size:]XY = kernels[:batch_size, batch_size:]YX = kernels[batch_size:, :batch_size]loss = torch.mean(XX + YY - XY -YX)return loss#因为一般都是n==m,所以L矩阵一般不加入计算
  1. 第二段代码
    用来生成后面测试用的两种不同分布下的数据,那我们就命名为data_generate.py。这里需要安装matplotlib库,直接pip install matplotlib就可以啦。
import random
import matplotlib
import matplotlib.pyplot as pltSAMPLE_SIZE = 500
buckets = 50#第一种分布:对数正态分布,得到一个中值为mu,标准差为sigma的正态分布。mu可以取任何值,sigma必须大于零。
plt.subplot(1,2,1)
plt.xlabel("random.lognormalvariate")
mu = -0.6
sigma = 0.15#将输出数据限制到0-1之间
res1 = [random.lognormvariate(mu, sigma) for _ in xrange(1, SAMPLE_SIZE)]
plt.hist(res1, buckets)#第二种分布:beta分布。参数的条件是alpha 和 beta 都要大于0, 返回值在0~1之间。
plt.subplot(1,2,2)
plt.xlabel("random.betavariate")
alpha = 1
beta = 10
res2 = [random.betavariate(alpha, beta) for _ in xrange(1, SAMPLE_SIZE)]
plt.hist(res2, buckets)plt.savefig('data.jpg)
plt.show()

我们在终端输入python data_generate.py直接来运行一下,看看有没有什么问题。
报错啦,这里漏了一个 ’ 。

加上后保存再运行。(注意中英文输入法的切换哦!)又报错啦!

我们把报错信息NameError后面的内容复制到百度查一下(程序员的日常,现学现卖!)
在这里插入图片描述
哈哈哈是python版本的问题,从这就可以看出,原来的代码使用python2写的,但是我们装的是python3,不过没关系,python2到python3没有特别大的改动,做一些小修改就行啦。
再次运行结果图,而且它还在你的当前文件夹下保存了这个图片。

  1. 第三段代码
    总共有两种情况,第一种情况是取不同分布数据,第二种情况,取相同分布数据,看MMD的效果。那我们把这两段代码合并以下,并修改之前的xrange错误。
from torch.autograd import Variable#参数值见上段代码
#分别从对数正态分布和beta分布取两组数据
diff_1 = []
for i in range(10):diff_1.append([random.lognormvariate(mu, sigma) for _ in range(1, SAMPLE_SIZE)])diff_2 = []
for i in range(10):diff_2.append([random.betavariate(alpha, beta) for _ in range(1, SAMPLE_SIZE)])X = torch.Tensor(diff_1)
Y = torch.Tensor(diff_2)
X,Y = Variable(X), Variable(Y)
print mmd_rbf(X,Y)#参数值见以上代码
#从对数正态分布取两组数据
same_1 = []
for i in range(10):same_1.append([random.lognormvariate(mu, sigma) for _ in range(1, SAMPLE_SIZE)])same_2 = []
for i in range(10):same_2.append([random.lognormvariate(mu, sigma) for _ in range(1, SAMPLE_SIZE)])X = torch.Tensor(same_1)
Y = torch.Tensor(same_2)
X,Y = Variable(X), Variable(Y)
print mmd_rbf(X,Y)

运行一下吧:

报错啦,这个问题也是python2到python3版本变换的一个经典问题,python3中的print需要加括号。

print(mmd_rbf(X,Y))

我们还注意到这里用了mmd_rbf函数,但是我们把这个函数定义在了mmd_tool.py文件里面,所以运行mmd_test.py文件时,文件应该不知道这个函数的意义,那怎么解决呢?我们直接在顶部加一个声明,类似于import包一样:

from mmd_test import *	# 这里的*代表导入mmd_test里面所有定义的函数,你也可也指定单独的函数导入

你也可以直接把这两个文件合并起来,就用一个文件,虽然这样很方便,但是如果项目很大的话我们还是需要把不同作用的代码分开,方便管理,养成良好的编程习惯。

之后我们再次运行。

又报错啦,SAMPLE_SIZE没有定义!!是因为我们生成数据的代码在data_generate.py文件里,还包括mu、sigma等等变量,这个问题是不是和上面的问题一样,那我们可以import来解决!(import就完事了)

from data_generate import *

然后我们再运行试试!

结果出来了!恭喜你哈哈哈,接下来就可以仔细看一看这个代码是如何运行的,原理是什么啦。


注意测试代码里面有这些代码:

from torch.autograd import Variable
...
X,Y = Variable(X), Variable(Y)
...

这里的Variable其实是pytorch很早以前版本的一个类,在所有张量定义使用的时候都要加一下,但是现在为了简洁已经删除了,只不过有些时候不会报错。为了更规范我们还是把包含Variable的代码都修改一下。

最后再补充一个关于if name == ‘main’: 的小知识点。

这篇关于最大均值差异(Maximum Mean Discrepancy, MMD)复现教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/182728

相关文章

基于C#实现PDF转图片的详细教程

《基于C#实现PDF转图片的详细教程》在数字化办公场景中,PDF文件的可视化处理需求日益增长,本文将围绕Spire.PDFfor.NET这一工具,详解如何通过C#将PDF转换为JPG、PNG等主流图片... 目录引言一、组件部署二、快速入门:PDF 转图片的核心 C# 代码三、分辨率设置 - 清晰度的决定因

Java Scanner类解析与实战教程

《JavaScanner类解析与实战教程》JavaScanner类(java.util包)是文本输入解析工具,支持基本类型和字符串读取,基于Readable接口与正则分隔符实现,适用于控制台、文件输... 目录一、核心设计与工作原理1.底层依赖2.解析机制A.核心逻辑基于分隔符(delimiter)和模式匹

spring AMQP代码生成rabbitmq的exchange and queue教程

《springAMQP代码生成rabbitmq的exchangeandqueue教程》使用SpringAMQP代码直接创建RabbitMQexchange和queue,并确保绑定关系自动成立,简... 目录spring AMQP代码生成rabbitmq的exchange and 编程queue执行结果总结s

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

Python pandas库自学超详细教程

《Pythonpandas库自学超详细教程》文章介绍了Pandas库的基本功能、安装方法及核心操作,涵盖数据导入(CSV/Excel等)、数据结构(Series、DataFrame)、数据清洗、转换... 目录一、什么是Pandas库(1)、Pandas 应用(2)、Pandas 功能(3)、数据结构二、安

2025版mysql8.0.41 winx64 手动安装详细教程

《2025版mysql8.0.41winx64手动安装详细教程》本文指导Windows系统下MySQL安装配置,包含解压、设置环境变量、my.ini配置、初始化密码获取、服务安装与手动启动等步骤,... 目录一、下载安装包二、配置环境变量三、安装配置四、启动 mysql 服务,修改密码一、下载安装包安装地

电脑提示d3dx11_43.dll缺失怎么办? DLL文件丢失的多种修复教程

《电脑提示d3dx11_43.dll缺失怎么办?DLL文件丢失的多种修复教程》在使用电脑玩游戏或运行某些图形处理软件时,有时会遇到系统提示“d3dx11_43.dll缺失”的错误,下面我们就来分享超... 在计算机使用过程中,我们可能会遇到一些错误提示,其中之一就是缺失某个dll文件。其中,d3dx11_4

Linux下在线安装启动VNC教程

《Linux下在线安装启动VNC教程》本文指导在CentOS7上在线安装VNC,包含安装、配置密码、启动/停止、清理重启步骤及注意事项,强调需安装VNC桌面以避免黑屏,并解决端口冲突和目录权限问题... 目录描述安装VNC安装 VNC 桌面可能遇到的问题总结描js述linux中的VNC就类似于Window

Go语言编译环境设置教程

《Go语言编译环境设置教程》Go语言支持高并发(goroutine)、自动垃圾回收,编译为跨平台二进制文件,云原生兼容且社区活跃,开发便捷,内置测试与vet工具辅助检测错误,依赖模块化管理,提升开发效... 目录Go语言优势下载 Go  配置编译环境配置 GOPROXYIDE 设置(VS Code)一些基本

Windows环境下解决Matplotlib中文字体显示问题的详细教程

《Windows环境下解决Matplotlib中文字体显示问题的详细教程》本文详细介绍了在Windows下解决Matplotlib中文显示问题的方法,包括安装字体、更新缓存、配置文件设置及编码調整,并... 目录引言问题分析解决方案详解1. 检查系统已安装字体2. 手动添加中文字体(以SimHei为例)步骤