SDXS:知识蒸馏在高效图像生成中的应用

2024-08-22 21:44

本文主要是介绍SDXS:知识蒸馏在高效图像生成中的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

扩散模型虽然在图像生成方面表现出色,但其迭代采样过程导致在低功耗设备上部署面临挑战,同时在云端高性能GPU平台上的能耗也不容忽视。为了解决这一问题,小米公司的Yuda Song、Zehao Sun、Xuanwu Yin等人提出了一种新的方法——SDXS,通过知识蒸馏简化了U-Net和图像解码器架构,并引入了一种创新的一步式DM训练技术,使用特征匹配和得分蒸馏,从而在单GPU上实现了大约100 FPS(比SD v1.5快30倍)和30 FPS(比SDXL快60倍)的推理速度。

图1为在图像生成时间限制为1秒的情况下,不同模型的性能对比。SDXL模型在这种情况下只能使用16次函数评估(NFEs)来生成稍微模糊的图像,而提出的SDXS-1024模型却能够生成30张清晰的图像。这表明SDXS-1024在保持图像质量的同时显著提高了生成速度。本方法还能够训练ControlNet,这是一种能够嵌入空间引导的网络,用于图像到图像的任务,如草图到图像的转换、修复和超分辨率等。证明了SDXS方法的灵活性和应用潜力。

方法

LDM框架由三个关键要素组成:文本编码器、图像解码器以及一个需要多次迭代以生成清晰图像的去噪模型。由于文本编码器的开销相对较低,因此优化其大小并不是研究的重点。

VAE优化:LDM框架通过将样本投影到计算效率更高的低维潜在空间,显著提高了高分辨率图像扩散模型的训练效率。这一过程通过使用预训练模型,如变分自编码器(Variational AutoEncoder, VAE)或向量量化变分自编码器(Vector Quantised-Variational AutoEncoder, VQVAE)来实现高比例图像压缩。VAE包含一个将图像映射到潜在空间的编码器,以及一个重建图像的解码器。其训练通过平衡重建损失、Kullback-Leibler (KL) 散度和GAN损失来优化。然而,训练中对所有样本同等对待引入了冗余。研究者们提出了一种VAE蒸馏(VD)损失,用于训练一个小型的图像解码器G: 其中,D是GAN判别器,用于平衡两个损失项,表示在8倍下采样图像上的L1损失。图2(a)展示了蒸馏小型图像解码器的训练策略。倡使用简化的CNN架构,不包含注意力机制和归一化层等复杂组件,只关注基本的残差块和上采样层。

U-Net优化: LDMs采用U-Net架构作为核心去噪模型,该架构结合了残差块和Transformer块。为了利用预训练的U-Nets的能力,同时减少计算需求和参数数量,研究者们采用了知识蒸馏策略,这一策略受到BK-SDM的块移除训练策略启发。这涉及从U-Net中选择性地移除残差和Transformer块,目的是训练一个更紧凑的模型,该模型仍能有效复现原始模型的中间特征图和输出。图2(b)展示了蒸馏小型U-Net的训练策略。知识蒸馏通过输出知识蒸馏(OKD)和特征知识蒸馏(FKD)损失实现:总的损失函数是两者的结合: 其中,λF​平衡两个损失项。与BK-SDM不同,研究者们排除了原始的去噪损失。模型基于SD-2.1基础版和SDXL-1.0基础版进行了小型化。对于SD-2.1基础版,研究者们去除了中间阶段、下采样阶段的最后阶段和上采样阶段的第一阶段,并去除了最高分辨率阶段的Transformer块。对于SDXL-1.0基础版,研究者们去除了大部分Transformer块。

ControlNet优化: ControlNet通过嵌入空间引导来增强扩散模型,使图像到图像的任务如草图到图像的转换、修复和超分辨率成为可能。它复制了U-Net的编码器架构和参数,并增加了额外的卷积层以纳入空间控制。尽管ControlNet继承了U-Net的参数并采用零卷积来提高训练稳定性,但其训练过程仍然成本高昂且显著受数据集质量影响。为了解决这些挑战,研究者们提出了一种蒸馏方法,将原始U-Net中的ControlNet蒸馏到小型U-Net中的相应ControlNet。图2(b)展示了这一过程,不是直接蒸馏ControlNet零卷积的输出,而是将ControlNet与U-Net结合,然后蒸馏U-Net的中间特征图和输出,这使得蒸馏后的ControlNet和小型U-Net能够更好地协同工作。考虑到ControlNet不影响U-Net编码器的特征图,特征蒸馏仅应用于U-Net的解码器。

尽管扩散模型(DMs)在图像生成方面表现出色,但它们依赖于多个采样步骤,即使采用先进的采样器,这也引入了显著的推理延迟。为了解决这个问题,先前的研究引入了知识蒸馏技术,例如渐进式蒸馏(progressive distillation)和一致性蒸馏(consistency distillation),旨在减少采样步骤并加速推理。然而,这些方法通常只能在4到8个采样步骤中产生清晰的图像,这与在生成对抗网络(GANs)中看到的一步式生成过程形成了鲜明对比。

直接训练一步式模型的方法包括初始化噪声ϵ,并使用常微分方程(ODE)采样器ψ进行采样以获得生成的图像,从而构建噪声-图像对。这些对在训练期间作为学生模型的输入和真实情况。然而,这种方法通常导致生成质量低下的图像。根本问题是使用预训练的DM生成的噪声-图像对的采样轨迹交叉,导致不适定问题。Rectified Flow通过拉直采样轨迹来解决这一挑战。它替换了训练目标,并提出了一种“重流”策略来优化配对,从而最小化轨迹交叉。

采样轨迹的交叉可能导致一个噪声输入对应多个真实图像,导致训练模型生成的图像是多个可行输出的加权和。为了解决这个问题,研究者们探索了改变权重方案以优先考虑更清晰图像的替代损失函数。在大多数情况下,可以使用L1损失、感知损失和LPIPS损失来改变权重形式。研究者们基于特征匹配的方法,计算由编码器模型生成的中间特征图上损失。具体来说,他们从DISTS损失中汲取灵感,对这些特征图应用结构相似性指数(SSIM)以获得更精细的特征匹配损失: 其中 是由编码器 编码的第 个中间特征图上计算的SSIM损失的权重,是由小型U-Net 生成的图像,是由原始U-Net xϕ​ 使用ODE采样器ψ生成的图像。在实践中,使用预训练的CNN骨干、ViT骨干和DM U-Net的编码器都能产生有利的结果,与MSE损失的比较在图6中展示。

尽管特征匹配损失可以产生几乎清晰的图像,但它未能实现真正的分布匹配,因此训练的模型只能作为正式训练的初始化。为了解决这一差距,Diff-Instruct中使用的训练策略,该策略旨在通过在时间步上匹配边际得分函数,使模型的输出分布与预训练模型的分布更紧密地对齐。然而,因为它需要在 t→T 时添加高水平的噪声以使目标得分可计算,此时估计的得分函数是不准确的。研究者们指出,扩散模型的采样轨迹从粗糙到精细,这意味着 t→T 时,得分函数提供了低频信息的梯度,而 t→0 时,它提供了高频信息的梯度。因此,研究者们将时间步分为两段:,后者被LFM替换,因为它可以提供足够的低频梯度。这种策略可以正式表示为: 其中 是在时间 t 和状态 下的函数,用于平衡两段的梯度,。研究者们有意将 α 设置接近1,并将 设置在高值,以确保模型的输出分布与预训练得分函数预测的分布平滑对齐。在概率密度显著重叠后,逐渐降低 α 和 。图3描述了训练策略,其中离线DM表示预训练DM的U-Net,在线DM是从离线DM初始化并在生成的图像上通过等式(1)微调得到的。在实践中,在线DM和学生DM交替训练,如算法1所示。

 一旦一步式DM训练完成,就可以像其他DM一样进行微调,以调整生成图像的风格。研究者们结合使用LoRA和提出的分段得分蒸馏来微调一步式DM,如图4所示。具体为将预训练的LoRA插入离线DM中,如果它也与教师DM兼容,也会插入到那里。要注意,不将LoRA插入在线DM中,因为它对应于一步式DM的输出分布。然后,使用与一步式训练相同的训练程序,但跳过特征匹配预热,因为LoRA微调比完全微调更稳定。另外当教师DM不能纳入预训练的LoRA时,使用降低的 。通过这种方式,可以将预训练的LoRA蒸馏到SDXS的LoRA中。

研究者们的方法也可以适应于ControlNet的训练,使微小的一步式模型能够在其图像生成过程中纳入图像条件,如图5所示。与用于文本到图像生成的基础模型相比,这里训练的模型是伴随前面提到的小型U-Net的蒸馏ControlNet,并且在训练期间U-Net的参数是固定的。重点是需要从教师模型采样的图像中提取控制图像,而不是从数据集图像中提取,以确保噪声、目标图像和控制图像形成一个配对三元组。此外,原始多步U-Net的伴随预训练ControlNet与在线U-Net和离线U-Net集成,但不参与训练。与文本编码器类似,其功能限于作为预训练的特征提取器。通过这种方式,为了进一步减少损失L,训练的ControlNet学习利用从目标图像中提取的控制图像。同时,得分蒸馏鼓励模型匹配边际分布,增强生成图像的上下文相关性。值得注意的是,研究发现用新初始化的噪声替换U-Net噪声输入的一部分可以增强控制能力。图5展示了基于特征匹配和得分蒸馏提出的一步式ControlNet训练策略。虚线表示梯度反向传播。

实验

研究者的代码是基于diffusers库开发的。由于他们无法访问SD v2.1基础版和SDXL的训练数据集,整个训练过程几乎是无数据的,完全依赖于公开可访问数据集中提供的提示。他们使用开源的预训练模型与这些提示结合,生成相应的图像。为了训练模型,他们将训练小批量大小配置在1,024到2,048之间。为了在现有硬件上适应这个批量大小,必要时他们有策略地实施了梯度累积。他们发现所提出训练策略导致模型生成的图像纹理较少。因此,在训练后,他们使用GAN损失结合极低秩的LoRA进行了短暂的微调。当需要GAN损失时,他们使用了StyleGAN-T中的Projected GAN损失,基本设置与ADD一致。对于SDXS-1024的训练,他们使用Vega,SDXL的紧凑版本,作为在线DM和离线DM的初始化,以减少训练开销。

表3为在MS-COCO 2017验证集上的定量结果,即FID和CLIP分数。由于FID对高斯分布的强烈假设,它不是衡量图像质量的一个好的指标,因为它受到生成样本多样性的显著影响。表3显示了MS-COCO 2017 5K子集上的性能比较,图7显示了一些示例。尽管模型大小和所需的采样步骤数量都有明显减少,但SDXS-512的提示跟随能力仍然优于SD v1.5。与Tiny SD(另一个为效率而设计的模型)相比,SDXS-512的优越性更加明显。这一观察结果也在SDXS-1024的性能中得到了一致的验证。使用所提方法训练LoRA的样本如图9所示。显然,模型生成的图像风格可以有效地转移到与离线DM集成的风格导向LoRA匹配的风格,同时通常保持场景布局的一致性。

研究者引入的一步式训练方法是足够通用的,可以应用于图像条件生成。他们展示了其在促进图像到图像转换方面的有效性,特别是利用ControlNet进行涉及canny边缘和深度图的转换。图8展示了两个不同任务的代表性示例,突出了生成图像紧密遵循控制图像提供的指导的能力。然而,这也揭示了在图像多样性方面的显著局限性。如图1所示,虽然问题可以通过替换提示来缓解,但它仍然是后续研究工作中加强的领域。

实验证明将高效的图像条件生成部署在边缘设备上是一个充满前景的研究方向,研究者计划在未来探索包括修复和超分辨率在内的更多应用。通过不断的技术创新和优化,人工智能在图像生成领域的应用将更加广泛和深入。

论文链接:https://arxiv.org/abs/2403.16627

项目地址:https://idkiro.github.io/sdxs/

这篇关于SDXS:知识蒸馏在高效图像生成中的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1097483

相关文章

C++高效内存池实现减少动态分配开销的解决方案

《C++高效内存池实现减少动态分配开销的解决方案》C++动态内存分配存在系统调用开销、碎片化和锁竞争等性能问题,内存池通过预分配、分块管理和缓存复用解决这些问题,下面就来了解一下... 目录一、C++内存分配的性能挑战二、内存池技术的核心原理三、主流内存池实现:TCMalloc与Jemalloc1. TCM

Python基于微信OCR引擎实现高效图片文字识别

《Python基于微信OCR引擎实现高效图片文字识别》这篇文章主要为大家详细介绍了一款基于微信OCR引擎的图片文字识别桌面应用开发全过程,可以实现从图片拖拽识别到文字提取,感兴趣的小伙伴可以跟随小编一... 目录一、项目概述1.1 开发背景1.2 技术选型1.3 核心优势二、功能详解2.1 核心功能模块2.

基于Python构建一个高效词汇表

《基于Python构建一个高效词汇表》在自然语言处理(NLP)领域,构建高效的词汇表是文本预处理的关键步骤,本文将解析一个使用Python实现的n-gram词频统计工具,感兴趣的可以了解下... 目录一、项目背景与目标1.1 技术需求1.2 核心技术栈二、核心代码解析2.1 数据处理函数2.2 数据处理流程

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

Python使用Tkinter打造一个完整的桌面应用

《Python使用Tkinter打造一个完整的桌面应用》在Python生态中,Tkinter就像一把瑞士军刀,它没有花哨的特效,却能快速搭建出实用的图形界面,作为Python自带的标准库,无需安装即可... 目录一、界面搭建:像搭积木一样组合控件二、菜单系统:给应用装上“控制中枢”三、事件驱动:让界面“活”

python如何生成指定文件大小

《python如何生成指定文件大小》:本文主要介绍python如何生成指定文件大小的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python生成指定文件大小方法一(速度最快)方法二(中等速度)方法三(生成可读文本文件–较慢)方法四(使用内存映射高效生成

如何确定哪些软件是Mac系统自带的? Mac系统内置应用查看技巧

《如何确定哪些软件是Mac系统自带的?Mac系统内置应用查看技巧》如何确定哪些软件是Mac系统自带的?mac系统中有很多自带的应用,想要看看哪些是系统自带,该怎么查看呢?下面我们就来看看Mac系统内... 在MAC电脑上,可以使用以下方法来确定哪些软件是系统自带的:1.应用程序文件夹打开应用程序文件夹

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y