【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?

2024-09-06 21:44

本文主要是介绍【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)
  • 【扩散模型(九)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

文章目录

  • 系列文章目录
    • adapter_modules 分为两类
  • 总结


通过前面的系列文章,很清楚要训练的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。

而 image_proj_model 这块比较简单,原码如下所示

    # freeze parameters of models to save more memoryunet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)#ip-adapterimage_proj_model = ImageProjModel(cross_attention_dim=unet.config.cross_attention_dim,clip_embeddings_dim=image_encoder.config.projection_dim,clip_extra_context_tokens=4,)

adapter_modules 分为两类

  1. AttnProcessor 对应 self attention
  2. IPAttnProcessor 对应 cross attention

按理说 self attention 对应的 AttnProcessor 应该不会被训练,但是 training = True,便让人非常费解。
在这里插入图片描述
进一步查看 AttnProcessor2_0 和 IPAttnProcessor2_0 后,就清楚了,因为从 AttnProcessor2_0 的构造函数(init)中并没有参数,就算是 trianing = True 也并不影响训练,实际训练的模块还是 IPAttnProcessor2_0 构造函数中的 to_k_ip 和 to_v_ip 两层 linear!

class AttnProcessor2_0(torch.nn.Module):r"""Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0)."""def __init__(self,hidden_size=None,cross_attention_dim=None,):super().__init__()if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(
...class IPAttnProcessor2_0(torch.nn.Module):r"""Attention processor for IP-Adapater for PyTorch 2.0.Args:hidden_size (`int`):The hidden size of the attention layer.cross_attention_dim (`int`):The number of channels in the `encoder_hidden_states`.scale (`float`, defaults to 1.0):the weight scale of image prompt.num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):The context length of the image features."""def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):super().__init__()if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")self.hidden_size = hidden_sizeself.cross_attention_dim = cross_attention_dimself.scale = scaleself.num_tokens = num_tokensself.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)def __call__(

总结

  1. IP-Adapter 训的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。
  2. 在 adapter_modules 中,实际只训了 IPAttnProcessor2_0 的 to_k_ip 和 to_v_ip。
  3. adapter_modules 是在每个有含有 cross attention 的 unet block 里进行的替换,如下图所示。

在这里插入图片描述

这篇关于【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1143206

相关文章

一文详解如何在idea中快速搭建一个Spring Boot项目

《一文详解如何在idea中快速搭建一个SpringBoot项目》IntelliJIDEA作为Java开发者的‌首选IDE‌,深度集成SpringBoot支持,可一键生成项目骨架、智能配置依赖,这篇文... 目录前言1、创建项目名称2、勾选需要的依赖3、在setting中检查maven4、编写数据源5、开启热

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

HTML5 搜索框Search Box详解

《HTML5搜索框SearchBox详解》HTML5的搜索框是一个强大的工具,能够有效提升用户体验,通过结合自动补全功能和适当的样式,可以创建出既美观又实用的搜索界面,这篇文章给大家介绍HTML5... html5 搜索框(Search Box)详解搜索框是一个用于输入查询内容的控件,通常用于网站或应用程

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

CSS3中的字体及相关属性详解

《CSS3中的字体及相关属性详解》:本文主要介绍了CSS3中的字体及相关属性,详细内容请阅读本文,希望能对你有所帮助... 字体网页字体的三个来源:用户机器上安装的字体,放心使用。保存在第三方网站上的字体,例如Typekit和Google,可以link标签链接到你的页面上。保存在你自己Web服务器上的字

MySQL存储过程之循环遍历查询的结果集详解

《MySQL存储过程之循环遍历查询的结果集详解》:本文主要介绍MySQL存储过程之循环遍历查询的结果集,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言1. 表结构2. 存储过程3. 关于存储过程的SQL补充总结前言近来碰到这样一个问题:在生产上导入的数据发现

SpringBoot服务获取Pod当前IP的两种方案

《SpringBoot服务获取Pod当前IP的两种方案》在Kubernetes集群中,SpringBoot服务获取Pod当前IP的方案主要有两种,通过环境变量注入或通过Java代码动态获取网络接口IP... 目录方案一:通过 Kubernetes Downward API 注入环境变量原理步骤方案二:通过

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul