使用 PAI-Blade 加速 StableDiffusion Fine-Tuning

2023-12-15 18:45

本文主要是介绍使用 PAI-Blade 加速 StableDiffusion Fine-Tuning,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

01

背景

Stable Diffusion 模型自从发布以来在互联网上发展迅猛,它可以根据用户输入的文本描述信息生成相关图片,用户也可以提供自己喜爱的风格的照片,来对模型进行微调。例如当我们输入 "A photo of sks dog in a bucket" ,StableDiffusion 模型会生成类似下面的图片:

02

PAI-Blade 加速 PyTorch 训练

PAI-Blade 使用编译优化技术提高 PyTorch 程序的执行效率,其代码已经开源在

Github: https://github.com/alibaba/BladeDISC.

PAI-Blade API

使用 PAI-Blade 对 PyTorch 程序进行加速非常简单,只需要在原有程序上插入两行代码即可:

# 1. import PAI-Blade Python package import torch_blade# 2. compile and accelerate 'model' performancemodel = torch.compile(backend='aot_disc')(model)for batch, label in data_loader(): output = model(**batch) loss = compute_loss(output, label) loss.backward() optimizer.step()

torch.compile(backend='aot_disc')(model) 使用 BladeDISC 作为 TorchDynamo 的编译器后端,加速 PyTorch 模型的的前向和反向计算,其中 model也可以是一段 PyTorch 实现的 Python 函数。

PAI-Blade 编译 Pipeline

TorchDynamo 将 PyTorch 程序记录到一个或多个 FX Graph 上,PAI-Blade 通过一系列 Pass 优化计算图的执行效率。

https://pytorch.org/docs/2.1/torch.compiler_deepdive.html

MHLO Conversion PAI-Blade 引入了 Torch-MLIR Project 将 PyTorch IR 转换为 MLIR 世界中的 MHLO Dialect,以便进一步使用 BladeDISC 编译器进行性能优化,同时 PAI-Blade 开发团队也将 MHLO 转换相关代码贡献给了社区。

https://github.com/llvm/torch-mlir

BlaDNN Library 提供了高性能计算密集型算子库,PAI-Blade 会根据计算图上的一些典型 Pattern,自动的将一部分子图替换为等价的,有极致性能的 BlaDNN 算子。

Memory Intensive Kernel Fusion

算子融合是图层面编译优化最重要的收益来源,一个典型的 workload 上,可能会包含 element-wise 算子,动态 shape 的 broadcast/reshape/reduce 算子以及计算密集型算子,例如 GEMM 等。在 PyTorch 中,每一个算子都是一个独立的 kernel,而过多的 kernel 会导致 Tensor 在 Cache 中频繁的交换,导致显存带宽的浪费,而频繁的发射 kernel 也会造成一定的额外的开销。

对于如上图的一个典型的访存类算子 workload ,类似 XLA 做法会将 schedule 相同的算子合并在一起,从而将 7 个 kernel 合并为 3 个 kernel。BladeDISC 会采用更为激进的 fusion 策略,从而进一步提高 workload 性能:

  • 每个 fusion block 表示为独立的 ww 结构,使用 shared-memory 进行粘连,从而将 kernel 数量由 3 减少到 1
  • 使用 AStitch 技术,将不同的 loop 结构黏贴在一起,通过 index 推导生成一个 loop 结构,同时引入了 index_cache, value_cache 消除冗余的 index 计算。

在上面 workload 中,BladeDISC 的 fusion 策略可以将 kernel 数量从 7 减少到 1,并且在 kernel 内部使用 index 推导和 cache 来减少冗余的计算,从而逼近硬件的理论峰值。

Inplace Mutation 优化

在 PyTorch Eager 模式下,通过 inplace 算子 (aten.add_) 可以实现对输入的 tensor (w) 进行更新,而不需要一个额外的输出 Tensor。但是在 MLIR 世界里,IR 必须是符合 SSA 形式的,所以没有办法直接表示 inplace 语义,通常的做法是增加一个 D2D memcpy 算子来将输出的 buffer (w') 覆盖输入 buffer (w)。但这样做会造成额外的一次显存拷贝。

BladeDISC 的做法是找到需要 inplace 更新的两个buffer,在 MHLO IR 上进行标记,将 w和 w' 标记为相同的 buffer,在生成 gpu.store指时,将输出直接写回 wbuffer,从而节省一次显存拷贝所造成的额外开销。

03

Benchmark

PAI-Blade 在 A10 和 A100 上最大可获得 41.6 % 和 28.4% 的性能收益(batchsize=1)。

04

在 DSW 上使用 PAI-Blade

  1. 在 PAI 平台中创建 DSW 实例,并使用如下自定义 Docker 镜像,具体步骤可以参考文档

https://help.aliyun.com/zh/pai/user-guide/overview-5

pai-blade-registry.cn-hangzhou.cr.aliyuncs.com/pai-blade/aicompiler:latest-stablediffusion-torch-2.0.1-cu118
  1. 创建 Jupyter Notebook,启动 fine-tuning 任务
!cd /opt/StableDiffusion && bash launch_dreambooth_train.sh

在看到如下日志时,表示微调任务执行完成:

  1. 启动推理任务,并在 Jupyter Notebook 中查看生成的图片
!cd /opt/StableDiffusion && python inference.py && cp dog-bucket.png /mnt/workspace

参考文档:

  • BladeDISC:

https://github.com/alibaba/BladeDISC

  • TorchDynamo:

https://pytorch.org/docs/2.1/torch.compiler_deepdive.html

  • Torch-MLIR Project:

https://github.com/llvm/torch-mlir

  • 文档:

https://help.aliyun.com/zh/pai/user-guide/overview-5


 

这篇关于使用 PAI-Blade 加速 StableDiffusion Fine-Tuning的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/497469

相关文章

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

Pandas透视表(Pivot Table)的具体使用

《Pandas透视表(PivotTable)的具体使用》透视表用于在数据分析和处理过程中进行数据重塑和汇总,本文就来介绍一下Pandas透视表(PivotTable)的具体使用,感兴趣的可以了解一下... 目录前言什么是透视表?使用步骤1. 引入必要的库2. 读取数据3. 创建透视表4. 查看透视表总结前言