[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

2023-11-30 22:12

本文主要是介绍[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; thenecho "-> preprocessed_11k or 212k"S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; thenS3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
elseecho "Something wrong in data folder name"exit
fiwget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if neededBATCH_SIZE=2
GRAD_ACCUMULATION=4StartTime=$(date +%s)CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \--pretrained_model_name_or_path $MODEL_NAME \--train_data_dir $TRAIN_DATA_DIR\--use_ema \--resolution 512 --center_crop --random_flip \--train_batch_size $BATCH_SIZE \--gradient_checkpointing \--mixed_precision="fp16" \--learning_rate 5e-05 \--max_grad_norm 1 \--lr_scheduler="constant" --lr_warmup_steps=0 \--report_to="all" \--max_train_steps=20 \--seed 1234 \--gradient_accumulation_steps $GRAD_ACCUMULATION \--checkpointing_steps 5 \--valid_steps 5 \--lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \--use_copy_weight_from_teacher \--unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \--output_dir $OUTPUT_DIREndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

这篇关于[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/438740

相关文章

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

深入解析 Java Future 类及代码示例

《深入解析JavaFuture类及代码示例》JavaFuture是java.util.concurrent包中用于表示异步计算结果的核心接口,下面给大家介绍JavaFuture类及实例代码,感兴... 目录一、Future 类概述二、核心工作机制代码示例执行流程2. 状态机模型3. 核心方法解析行为总结:三

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

Python使用Code2flow将代码转化为流程图的操作教程

《Python使用Code2flow将代码转化为流程图的操作教程》Code2flow是一款开源工具,能够将代码自动转换为流程图,该工具对于代码审查、调试和理解大型代码库非常有用,在这篇博客中,我们将深... 目录引言1nVflRA、为什么选择 Code2flow?2、安装 Code2flow3、基本功能演示

IIS 7.0 及更高版本中的 FTP 状态代码

《IIS7.0及更高版本中的FTP状态代码》本文介绍IIS7.0中的FTP状态代码,方便大家在使用iis中发现ftp的问题... 简介尝试使用 FTP 访问运行 Internet Information Services (IIS) 7.0 或更高版本的服务器上的内容时,IIS 将返回指示响应状态的数字代

MySQL 添加索引5种方式示例详解(实用sql代码)

《MySQL添加索引5种方式示例详解(实用sql代码)》在MySQL数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中,下面给大家分享MySQL添加索引5种方式示例详解(实用sql代码),... 在mysql数据库中添加索引可以帮助提高查询性能,尤其是在数据量大的表中。索引可以在创建表时定义,也可

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元