结合代码详细讲解DDPM的训练和采样过程

2024-08-30 23:12

本文主要是介绍结合代码详细讲解DDPM的训练和采样过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

请添加图片描述

1. 训练(扩散)过程

首先我们来逐个看一下训练过程中的所有符号的含义:

x 0 x_0 x0是真实图像;

t 是扩散的步数,取值范围从1到T;

ϵ \epsilon ϵ是从标准正态分布中采样的噪声;

ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;

x t x_t xt的表达式如下:

在这里插入图片描述

x t x_t xt x 0 x_0 x0加噪获得,其中 α t ‾ \overline{\alpha_{t}} αt是常数
因此训练过程总结成一句话就是,向真实图像 x 0 x_0 x0中加噪,获得加噪后的图像 x t x_t xt;然后将 x t x_t xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向 x 0 x_0 x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示 x t x_t xt;而预测噪声的过程则是从 x t x_t xt解码的过程,只是并没有选择直接解码出 x 0 x_0 x0,而是解码出加入的噪声,也就是残差。请添加图片描述

下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt 。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:

[图片]

所以很容易理解代码中的sqrt_alphas_bar就是 α t ‾ \sqrt{\overline{\alpha_{t}}} αt ,sqrt_one_minus_alphas_bar 就是 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和 x 0 x_0 x0一致,接着获取 x t x_t xt

x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:

self.model(x_t, t)

计算Loss函数:

loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

训练过程的完整代码:

class GaussianDiffusionTrainer(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)# calculations for diffusion q(x_t | x_{t-1}) and othersself.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))# 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。def forward(self, x_0):"""Algorithm 1."""t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致x_t = (extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')return loss

2. 推理(反向)过程

首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像 x T x_T xT中恢复出一张图像,表示成 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),我们没法推导出 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),但是 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t, x_0) p(xt1xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把 x 0 x_0 x0化简掉。最终 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布的均值为:
请添加图片描述

方差为 β t \beta_t βt
因此我们可以从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布中采样出一个 x t − 1 x_{t-1} xt1
请添加图片描述
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
在这里插入图片描述
注意:是标准差与标准正态分布相乘,而不是方差;

因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。

注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。

# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):def __init__(self, model, beta_1, beta_T, T):super().__init__()self.model = modelself.T = Tself.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())alphas = 1. - self.betasalphas_bar = torch.cumprod(alphas, dim=0)alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]self.register_buffer('coeff1', torch.sqrt(1. / alphas))self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))def predict_xt_prev_mean_from_eps(self, x_t, t, eps):assert x_t.shape == eps.shapereturn (extract(self.coeff1, t, x_t.shape) * x_t -extract(self.coeff2, t, x_t.shape) * eps)def p_mean_variance(self, x_t, t):# below: only log_variance is used in the KL computationsvar = torch.cat([self.posterior_var[1:2], self.betas[1:]])var = extract(var, t, x_t.shape)eps = self.model(x_t, t)xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)return xt_prev_mean, vardef forward(self, x_T):"""Algorithm 2."""x_t = x_T # 输入是一个标准正态分布噪声# 从T到1进行reverse过程for time_step in reversed(range(self.T)):print(time_step)t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_stepmean, var= self.p_mean_variance(x_t=x_t, t=t) # no noise when t == 0if time_step > 0:noise = torch.randn_like(x_t)else:noise = 0x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."x_0 = x_treturn torch.clip(x_0, -1, 1)

这篇关于结合代码详细讲解DDPM的训练和采样过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1122208

相关文章

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Redis中Set结构使用过程与原理说明

《Redis中Set结构使用过程与原理说明》本文解析了RedisSet数据结构,涵盖其基本操作(如添加、查找)、集合运算(交并差)、底层实现(intset与hashtable自动切换机制)、典型应用场... 目录开篇:从购物车到Redis Set一、Redis Set的基本操作1.1 编程常用命令1.2 集

Linux下利用select实现串口数据读取过程

《Linux下利用select实现串口数据读取过程》文章介绍Linux中使用select、poll或epoll实现串口数据读取,通过I/O多路复用机制在数据到达时触发读取,避免持续轮询,示例代码展示设... 目录示例代码(使用select实现)代码解释总结在 linux 系统里,我们可以借助 select、

k8s中实现mysql主备过程详解

《k8s中实现mysql主备过程详解》文章讲解了在K8s中使用StatefulSet部署MySQL主备架构,包含NFS安装、storageClass配置、MySQL部署及同步检查步骤,确保主备数据一致... 目录一、k8s中实现mysql主备1.1 环境信息1.2 部署nfs-provisioner1.2.

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工

Java 线程池+分布式实现代码

《Java线程池+分布式实现代码》在Java开发中,池通过预先创建并管理一定数量的资源,避免频繁创建和销毁资源带来的性能开销,从而提高系统效率,:本文主要介绍Java线程池+分布式实现代码,需要... 目录1. 线程池1.1 自定义线程池实现1.1.1 线程池核心1.1.2 代码示例1.2 总结流程2. J

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

Python的pandas库基础知识超详细教程

《Python的pandas库基础知识超详细教程》Pandas是Python数据处理核心库,提供Series和DataFrame结构,支持CSV/Excel/SQL等数据源导入及清洗、合并、统计等功能... 目录一、配置环境二、序列和数据表2.1 初始化2.2  获取数值2.3 获取索引2.4 索引取内容2

JS纯前端实现浏览器语音播报、朗读功能的完整代码

《JS纯前端实现浏览器语音播报、朗读功能的完整代码》在现代互联网的发展中,语音技术正逐渐成为改变用户体验的重要一环,下面:本文主要介绍JS纯前端实现浏览器语音播报、朗读功能的相关资料,文中通过代码... 目录一、朗读单条文本:① 语音自选参数,按钮控制语音:② 效果图:二、朗读多条文本:① 语音有默认值:②

Vue实现路由守卫的示例代码

《Vue实现路由守卫的示例代码》Vue路由守卫是控制页面导航的钩子函数,主要用于鉴权、数据预加载等场景,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、概念二、类型三、实战一、概念路由守卫(Navigation Guards)本质上就是 在路