不破坏预训练模型结构且与Lora微调后的模型等价

2024-06-14 13:36

本文主要是介绍不破坏预训练模型结构且与Lora微调后的模型等价,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

不破坏预训练模型结构且与Lora微调后的模型等价

  • 一.原理
  • 二.loss曲线
  • 三.代码

背景: Lora模块的引入破坏了图优化逻辑,是否能在不破坏原始的图的情况下,通过修改权值等价实现呢
方案: 将Lora的结果做为Ground True,去训练原始的Linear
小结: 方案虽然可行,但计算成本太高,Lora的初衷是减少微调的计算量

一.原理

在这里插入图片描述

二.loss曲线

在这里插入图片描述

三.代码

import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
from torch.utils.tensorboard import SummaryWriterclass PreTrainedModel(nn.Module):def __init__(self, input_dim, output_dim):super(PreTrainedModel, self).__init__()self.input_dim=input_dimself.output_dim=output_dimself.fc = nn.Linear(input_dim, output_dim)nn.init.normal_(self.fc.weight.data)nn.init.normal_(self.fc.bias.data)def forward(self, x):return self.fc(x)def clone(self):cloned_model = PreTrainedModel(self.input_dim,self.output_dim)cloned_model.load_state_dict(self.state_dict())return cloned_modelclass LoRALayer(nn.Module):def __init__(self, input_dim, low_rank_dim,lora_alpha=4.0):super(LoRALayer, self).__init__()self.U = nn.Linear(input_dim, low_rank_dim, bias=False)self.B = nn.Linear(low_rank_dim, input_dim, bias=False)self.lora_alpha=lora_alphann.init.normal_(self.U.weight.data)nn.init.normal_(self.B.weight.data)def forward(self, x):return x + self.B(self.U(x))*self.lora_alphaclass LoRAAdaptedModel(nn.Module):def __init__(self,input_dim,output_dim,low_rank_dim):super(LoRAAdaptedModel, self).__init__()self.pretrained_model = PreTrainedModel(input_dim,output_dim)self.lora = LoRALayer(output_dim, low_rank_dim)def clone_pretrained_model(self):return self.pretrained_model.clone()def forward(self, x):x = self.pretrained_model(x)x = self.lora(x)return xdef train():writer = SummaryWriter('runs/lora')input_dim = 128low_rank_dim=16output_dim=256torch.manual_seed(1)lora_adapted_model = LoRAAdaptedModel(input_dim,output_dim, low_rank_dim).cuda().eval()pretrained_model = lora_adapted_model.clone_pretrained_model().cuda()#criterion = nn.MSELoss()criterion = nn.L1Loss()optimizer = optim.Adam(pretrained_model.parameters(), lr=0.01)writer.add_graph(lora_adapted_model,torch.rand(32, input_dim).cuda())for epoch in range(10000000):running_loss = 0.0for i in range(100):input_data =torch.rand(8,input_dim,device="cuda")with torch.no_grad():gt=lora_adapted_model(input_data).detach()pred = pretrained_model(input_data)loss = criterion(pred,gt)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()avg_loss=running_loss / 100print('[%d] loss: %f' % (epoch + 1,avg_loss ))writer.add_scalar('training loss', avg_loss, epoch)running_loss = 0.0
train()

这篇关于不破坏预训练模型结构且与Lora微调后的模型等价的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1060551

相关文章

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

如何使用Maven创建web目录结构

《如何使用Maven创建web目录结构》:本文主要介绍如何使用Maven创建web目录结构的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录创建web工程第一步第二步第三步第四步第五步第六步第七步总结创建web工程第一步js通过Maven骨架创pytho

Python循环结构全面解析

《Python循环结构全面解析》循环中的代码会执行特定的次数,或者是执行到特定条件成立时结束循环,或者是针对某一集合中的所有项目都执行一次,这篇文章给大家介绍Python循环结构解析,感兴趣的朋友跟随... 目录for-in循环while循环循环控制语句break语句continue语句else子句嵌套的循

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Python+PyQt5实现文件夹结构映射工具

《Python+PyQt5实现文件夹结构映射工具》在日常工作中,我们经常需要对文件夹结构进行复制和备份,本文将带来一款基于PyQt5开发的文件夹结构映射工具,感兴趣的小伙伴可以跟随小编一起学习一下... 目录概述功能亮点展示效果软件使用步骤代码解析1. 主窗口设计(FolderCopyApp)2. 拖拽路径

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应