大模型微调中的内存效率问题及解决方案

2024-09-02 14:04

本文主要是介绍大模型微调中的内存效率问题及解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

人工智能咨询培训老师叶梓 转载标明出处

大模型(LLMs)在大规模训练中的内存消耗问题日益凸显,传统的参数高效微调技术,如低秩适应(LoRA),虽然在一定程度上缓解了这一问题,但其性能在很多大规模微调场景下仍无法与全参数训练相媲美。

为了解决上述问题,香港科技大学以及伊利诺伊大学香槟分校的研究团队共同提出了一种新的训练策略——Layerwise Importance Sampled AdamW(LISA)。LISA策略基于对LoRA在微调任务中层级权重规范分布的观察,发现不同层的权重规范呈现出不寻常的偏斜分布。利用这一关键发现,研究者们提出了一种简单有效的训练方法,该方法在多种设置下的性能都超过了LoRA和全参数训练,同时内存成本与LoRA相当。图1为在Alpaca GPT-4数据集上,使用全参数训练(FT)、LoRA和LISA方法对LLaMA-2-7B模型进行训练时的损失变化情况。显示了LISA方法相比其他方法在训练损失上的优势。

论文链接:https://arxiv.org/abs/2403.17919

开源地址:https://github.com/OptimalScale/LMFlow

方法

为了理解LoRA如何仅用少量参数实现有效训练,研究者们对多个模型进行了实证研究,特别关注了不同层的权重规范。他们使用Alpaca-GPT4数据集进行微调,并在训练过程中详细记录了每一层ℓ在每次更新后的平均权重规范,其公式表示为:

其中表示层ℓ 的平均权重规范。

实验发现在LoRA训练中,嵌入层或语言模型(LM)头部层的权重规范显著大于中间层,有时甚至高出数百倍。然而,在全参数训练设置下,这种现象并不明显。

Figure 2 展示了GPT2和LLaMA-2-7B模型在LoRA和全参数训练期间的层级权重规范。图中x轴代表从嵌入权重到最终层的层级,y轴量化了权重规范。这一可视化揭示了一个关键趋势:嵌入层或LM头部层在LoRA中的权重规范远大于中间层。

基于上述发现,研究者们希望模拟LoRA的更新模式,通过采样不同的层进行冻结,以避免LoRA固有的低秩表示能力的限制,并模仿其快速学习过程。在全参数设置中,LoRA中权重规范较小的层也应该有较小的采样概率来解冻,以保持迭代中的预期学习率相同。这正是重要性采样的思想。

Algorithm 1 展示了LISA方法的步骤。在实践中,除了底部和顶部层外,LoRA中所有层的权重规范都较小,因此研究者们采用​=,其中γ控制优化过程中预期的解冻层数。γ作为一个补偿因子,用来桥接LoRA和全参数调优之间的差异,让LISA模拟与LoRA相似的层级更新模式。为了进一步控制实际设置中的内存消耗,研究者们每次随机采样γ层,以限制训练期间最大未冻结层数。

通过这种方法,LISA算法能够在保持内存效率的同时,提高大型语言模型微调的性能。这一创新方法为解决LoRA在大规模微调中的局限性提供了新的思路,并展示了在不同领域任务中应用的潜力。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
 

LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

为了证明LISA的内存效率,研究者们进行了峰值GPU内存消耗的实验。实验设置中,他们从Alpaca数据集(Taori et al., 2023)中随机抽取提示,并限制最大输出令牌长度为1024。重点关注两个关键超参数:LoRA的秩和LISA的激活层数。在其他超参数设置中,所有模型统一使用1的mini-batch大小,并排除了其他节省GPU内存的技术,如梯度检查点(Chen et al., 2016)、卸载(Ren et al., 2021)和快速注意力(Dao et al., 2022; Dao, 2023)。

Table 1为不同模型架构和配置下的峰值GPU内存消耗。特别是,当LISA配置增强了嵌入层(E)和两个额外层(E+H+2L)时,在微调LLaMA-2-70B模型时,与LoRA方法相比,显示出了相当大的GPU内存使用减少。具体而言LISA E+H+2L配置将峰值GPU内存从LoRA Rank 128配置所需的79G降低到75G。这种效率提升不是孤立的事件;在不同模型架构上观察到系统性的内存使用减少,表明LISA激活层的方法在内存效率上具有固有优势。

Figure 3 展示了不同方法和批量大小为1的LLaMA2-7B的GPU内存消耗。注意,LISA的内存减少允许LLaMA-2-7B在单个RTX4090(24GB)GPU上进行训练,这使得即使在笔记本电脑上也能负担得起高质量的微调。特别是由于LISA不引入适配器带来的额外参数,因此其激活内存消耗比LoRA少得多。由于pytorch(Paszke et al., 2019)与deepspeed(Rasley et al., 2020)允许在反向传播前删除冗余激活,LISA的激活内存甚至略低于全参数训练。

Figure 4 展示了不同方法和批量大小为1的LLaMA-2-7B的单次迭代时间成本。LISA由于减少了内存占用,还带来了加速效果。如图4所示,与全参数训练相比,LISA提供了大约2.9倍的加速,与LoRA相比大约有1.5倍的加速,这部分是由于去除了适配器结构。值得注意的是,LoRA和LISA的内存占用减少都显著加快了前向传播的速度,强调了内存高效训练的重要性。

LISA在保持显著内存节省的同时,还能在微调设置中获得有竞争力的性能。为了证明LISA优于LoRA,研究者们在Alpaca GPT-4数据集(Taori et al., 2023)上评估了它们的性能,该数据集包含由GPT-4(OpenAI et al., 2023)生成的52k对对话。微调的有效性在多个基准上进行评估:MT-Bench(Zheng et al., 2023)包含80个高质量的多轮问题,旨在从多个方面评估LLMs;MMLU(Hendrycks et al., 2020)总共包含57个任务,14,079个问题,涵盖广泛的世界知识;AGIEval(Zhong et al., 2023)作为以人为本的通用能力基准,包含9,316个实例;WinoGrande(Sakaguchi et al., 2021)是大规模常识推理数据集,包含44,000个实例,旨在挑战模型对上下文和常识知识的了解。

Table 2 和 Table 3 展示了中等规模LLMs的详细比较。基线包括全参数训练(FT)、低秩适应(LoRA)(Hu et al., 2022)和梯度低秩投影(GaLore)(Zhao et al., 2024)。结果表明,LISA在大多数评估轨道上一致性地优于其他微调方法,表明其在多样化任务和模型架构中的鲁棒性和有效性。LISA特别适用于指令跟随任务,在与其他基线方法相比时观察到较大的差距。LISA甚至超越了全参数训练,这表明当限制未冻结层数时,存在隐式正则化效果,类似于dropout(Srivastava et al., 2014)。

持续预训练对于使模型适应新数据和领域至关重要。为了评估LISA在持续预训练场景中的有效性,研究者们在数学领域与全参数训练进行了比较。

研究者们采用数学语料库OpenWebMath(Paster et al., 2023)构建持续预训练数据集。具体来说,他们从中提取了一个包含15亿令牌的高质量子集。详细情况在附录B.2中解释。在持续预训练后,然后对GSM8K(Cobbe et al., 2021)训练集进行相同的微调程序,该训练集包含7473个实例。

Table 4 显示,LISA能够实现与全参数训练相当甚至更好的性能,同时内存消耗要少得多。具体来说,与全参数训练相比,LISA只需要一半的内存成本。这表明LISA在计算效率和模型性能之间实现了更好的平衡。根据研究者的经验,将未冻结层数减少到原始大小的一半,在持续预训练期间不会变差甚至表现更好,同时内存消耗要少得多。

为了进一步证明LISA在大规模LLMs上的可扩展性,研究者们在LLaMA-2-70B(Touvron et al., 2023b)上进行了额外的微调实验。除了前面提到的指令跟随任务外,研究者们还使用了额外的特定领域微调任务,包括数学和医学QA基准。GSM8K数据集(Cobbe et al., 2021)包含7473个训练实例和1319个测试实例,用于数学领域。对于医学领域,研究者们选择了PubMedQA数据集(Jin et al., 2019),该数据集包含211.3K个人工生成的QA训练实例和1K个测试实例。

Table 5 显示,LISA在与LoRA相比时一致性地产生更好或相当的性能。此外,LISA在指令调整任务中再次超越了全参数训练,为LISA在大规模训练场景下的可扩展性提供了有力证据。

LISA的两个关键超参数是采样层数γ和采样周期K。为了直观和实证地指导这些超参数的选择,研究者们使用TinyLlama(Zhang et al., 2024)和LLaMA-2-7B(Touvron et al., 2023b)模型,在Alpaca-GPT4数据集上进行了消融研究。γ的配置,如E+H+2L、E+H+8L,分别表示为γ = 2和γ = 8。至于采样周期K = T /n,T = 122代表实验框架内的最大训练步骤。Table 6 中的发现表明,γ和K都显著影响LISA算法的性能。具体为较高的γ值增加了可训练参数的数量,尽管内存成本更高。另一方面,最优的K值促进了更频繁的层切换,从而在一定阈值内提高了性能,超出该阈值后性能可能会恶化。通常的经验法则是:更多的采样层和更高的采样周期会带来更好的性能。

由于LISA在算法上依赖于层的采样序列,研究者们进一步研究了LISA在三个不同运行中性能的变化,每个运行都使用不同的随机种子进行层选择。研究者们采用TinyLlama、LLaMA2-7B和Mistral-7B模型与Alpaca-GPT4数据集,同时保持所有其他超参数与前面指令跟随实验中使用的一致。Table 7 显示,LISA对不同的随机种子相当有韧性,三次运行之间的性能差距在0.13以内,与超过基线方法的性能增益相比,这是一个小值。

实验结果显示,LISA在保持相似或更低的GPU内存消耗的同时,在下游微调任务中的性能超越了LoRA,甚至在某些情况下还超越了全参数训练。

这篇关于大模型微调中的内存效率问题及解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1130164

相关文章

IDEA和GIT关于文件中LF和CRLF问题及解决

《IDEA和GIT关于文件中LF和CRLF问题及解决》文章总结:因IDEA默认使用CRLF换行符导致Shell脚本在Linux运行报错,需在编辑器和Git中统一为LF,通过调整Git的core.aut... 目录问题描述问题思考解决过程总结问题描述项目软件安装shell脚本上git仓库管理,但拉取后,上l

idea npm install很慢问题及解决(nodejs)

《ideanpminstall很慢问题及解决(nodejs)》npm安装速度慢可通过配置国内镜像源(如淘宝)、清理缓存及切换工具解决,建议设置全局镜像(npmconfigsetregistryht... 目录idea npm install很慢(nodejs)配置国内镜像源清理缓存总结idea npm in

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

idea突然报错Malformed \uxxxx encoding问题及解决

《idea突然报错Malformeduxxxxencoding问题及解决》Maven项目在切换Git分支时报错,提示project元素为描述符根元素,解决方法:删除Maven仓库中的resolv... 目www.chinasem.cn录问题解决方式总结问题idea 上的 maven China编程项目突然报错,是

Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题

《Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题》在爬虫工程里,“HTTPS”是绕不开的话题,HTTPS为传输加密提供保护,同时也给爬虫带来证书校验、... 目录一、核心问题与优先级检查(先问三件事)二、基础示例:requests 与证书处理三、高并发选型:

Python内存管理机制之垃圾回收与引用计数操作全过程

《Python内存管理机制之垃圾回收与引用计数操作全过程》SQLAlchemy是Python中最流行的ORM(对象关系映射)框架之一,它提供了高效且灵活的数据库操作方式,本文将介绍如何使用SQLAlc... 目录安装核心概念连接数据库定义数据模型创建数据库表基本CRUD操作创建数据读取数据更新数据删除数据查

前端导出Excel文件出现乱码或文件损坏问题的解决办法

《前端导出Excel文件出现乱码或文件损坏问题的解决办法》在现代网页应用程序中,前端有时需要与后端进行数据交互,包括下载文件,:本文主要介绍前端导出Excel文件出现乱码或文件损坏问题的解决办法,... 目录1. 检查后端返回的数据格式2. 前端正确处理二进制数据方案 1:直接下载(推荐)方案 2:手动构造

Python绘制TSP、VRP问题求解结果图全过程

《Python绘制TSP、VRP问题求解结果图全过程》本文介绍用Python绘制TSP和VRP问题的静态与动态结果图,静态图展示路径,动态图通过matplotlib.animation模块实现动画效果... 目录一、静态图二、动态图总结【代码】python绘制TSP、VRP问题求解结果图(包含静态图与动态图

MyBatis/MyBatis-Plus同事务循环调用存储过程获取主键重复问题分析及解决

《MyBatis/MyBatis-Plus同事务循环调用存储过程获取主键重复问题分析及解决》MyBatis默认开启一级缓存,同一事务中循环调用查询方法时会重复使用缓存数据,导致获取的序列主键值均为1,... 目录问题原因解决办法如果是存储过程总结问题myBATis有如下代码获取序列作为主键IdMappe

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4