AIGC笔记--基于PEFT库使用LoRA

2024-05-29 04:52
文章标签 使用 lora 笔记 aigc peft

本文主要是介绍AIGC笔记--基于PEFT库使用LoRA,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict# 创建模型
class Simple_Model(nn.Module):def __init__(self):super().__init__()self.linear1 = nn.Linear(64, 128)self.linear2 = nn.Linear(128, 256)def forward(self, x: torch.Tensor):x = self.linear1(x)x = self.linear2(x)return xif __name__ == "__main__":# 初始化原始模型origin_model = Simple_Model()# 配置lora configmodel_lora_config = LoraConfig(r = 32, lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1init_lora_weights = "gaussian", # 参数初始化方式target_modules = ["linear1", "linear2"], # 对应层添加lora层lora_dropout = 0.1)# Test datainput_data = torch.rand(2, 64)origin_output = origin_model(input_data)# 原始模型的权重参数origin_state_dict = origin_model.state_dict() # 两种方式生成对应的lora模型,调用后会更改原始的模型new_model1 = get_peft_model(origin_model, model_lora_config)new_model2 = LoraModel(origin_model, model_lora_config, "default")output1 = new_model1(input_data)output2 = new_model2(input_data)# 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX# origin_output == output1 == output2# 获取lora权重参数,两者在key_name上会有区别new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)# origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]# new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重from diffusers import StableDiffusionPipelinesave_path = "./"global_step = 0StableDiffusionPipeline.save_lora_weights(save_directory = save_path,unet_lora_layers = new_model1_lora_state_dict,safe_serialization = True,weight_name = f"checkpoint-{global_step}.safetensors",)# 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本from safetensors import safe_openalpha = 1. # 参数融合因子lora_path = "./" + f"checkpoint-{global_step}.safetensors"state_dict = {}with safe_open(lora_path, framework="pt", device="cpu") as f:for key in f.keys():state_dict[key] = f.get_tensor(key)all_lora_weights = []for idx,key in enumerate(state_dict):# only process lora down keyif "lora_B." in key: continueup_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")layer_infos = model_key.split(".")[:-1]curr_layer = new_model1while len(layer_infos) > 0:temp_name = layer_infos.pop(0)curr_layer = curr_layer.__getattr__(temp_name)weight_down = state_dict[key].to(curr_layer.weight.data.device)weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)# 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])print('Load Lora Done')

5--完整代码

PEFT_LoRA

这篇关于AIGC笔记--基于PEFT库使用LoRA的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1012703

相关文章

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

C#中Guid类使用小结

《C#中Guid类使用小结》本文主要介绍了C#中Guid类用于生成和操作128位的唯一标识符,用于数据库主键及分布式系统,支持通过NewGuid、Parse等方法生成,感兴趣的可以了解一下... 目录前言一、什么是 Guid二、生成 Guid1. 使用 Guid.NewGuid() 方法2. 从字符串创建

Python使用python-can实现合并BLF文件

《Python使用python-can实现合并BLF文件》python-can库是Python生态中专注于CAN总线通信与数据处理的强大工具,本文将使用python-can为BLF文件合并提供高效灵活... 目录一、python-can 库:CAN 数据处理的利器二、BLF 文件合并核心代码解析1. 基础合

Python使用OpenCV实现获取视频时长的小工具

《Python使用OpenCV实现获取视频时长的小工具》在处理视频数据时,获取视频的时长是一项常见且基础的需求,本文将详细介绍如何使用Python和OpenCV获取视频时长,并对每一行代码进行深入解析... 目录一、代码实现二、代码解析1. 导入 OpenCV 库2. 定义获取视频时长的函数3. 打开视频文

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四