大模型增量预训练新技巧:解决灾难性遗忘

2024-02-04 09:36

本文主要是介绍大模型增量预训练新技巧:解决灾难性遗忘,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))layer_cnt = 0output = {}
for i in range(original_layers):for k in ckpt:if ('layers.' + str(i) + '.') in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1if (i+1) % split == 0:for k in ckpt:if ('layers.' + str(i) + '.') in k:if 'down_proj' in k or 'o_proj' in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])else:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1assert layer_cnt==layers
for k in ckpt:if not 'layers' in k:output[k] = ckpt[k]torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

这篇关于大模型增量预训练新技巧:解决灾难性遗忘的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/677032

相关文章

解决pandas无法读取csv文件数据的问题

《解决pandas无法读取csv文件数据的问题》本文讲述作者用Pandas读取CSV文件时因参数设置不当导致数据错位,通过调整delimiter和on_bad_lines参数最终解决问题,并强调正确参... 目录一、前言二、问题复现1. 问题2. 通过 on_bad_lines=‘warn’ 跳过异常数据3

解决RocketMQ的幂等性问题

《解决RocketMQ的幂等性问题》重复消费因调用链路长、消息发送超时或消费者故障导致,通过生产者消息查询、Redis缓存及消费者唯一主键可以确保幂等性,避免重复处理,本文主要介绍了解决RocketM... 目录造成重复消费的原因解决方法生产者端消费者端代码实现造成重复消费的原因当系统的调用链路比较长的时

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

SpringBoot监控API请求耗时的6中解决解决方案

《SpringBoot监控API请求耗时的6中解决解决方案》本文介绍SpringBoot中记录API请求耗时的6种方案,包括手动埋点、AOP切面、拦截器、Filter、事件监听、Micrometer+... 目录1. 简介2.实战案例2.1 手动记录2.2 自定义AOP记录2.3 拦截器技术2.4 使用Fi

kkFileView启动报错:报错2003端口占用的问题及解决

《kkFileView启动报错:报错2003端口占用的问题及解决》kkFileView启动报错因office组件2003端口未关闭,解决:查杀占用端口的进程,终止Java进程,使用shutdown.s... 目录原因解决总结kkFileViewjavascript启动报错启动office组件失败,请检查of

SQL Server安装时候没有中文选项的解决方法

《SQLServer安装时候没有中文选项的解决方法》用户安装SQLServer时界面全英文,无中文选项,通过修改安装设置中的国家或地区为中文中国,重启安装程序后界面恢复中文,解决了问题,对SQLSe... 你是不是在安装SQL Server时候发现安装界面和别人不同,并且无论如何都没有中文选项?这个问题也

游戏闪退弹窗提示找不到storm.dll文件怎么办? Stormdll文件损坏修复技巧

《游戏闪退弹窗提示找不到storm.dll文件怎么办?Stormdll文件损坏修复技巧》DLL文件丢失或损坏会导致软件无法正常运行,例如我们在电脑上运行软件或游戏时会得到以下提示:storm.dll... 很多玩家在打开游戏时,突然弹出“找不到storm.dll文件”的提示框,随后游戏直接闪退,这通常是由于

java内存泄漏排查过程及解决

《java内存泄漏排查过程及解决》公司某服务内存持续增长,疑似内存泄漏,未触发OOM,排查方法包括检查JVM配置、分析GC执行状态、导出堆内存快照并用IDEAProfiler工具定位大对象及代码... 目录内存泄漏内存问题排查1.查看JVM内存配置2.分析gc是否正常执行3.导出 dump 各种工具分析4.

Spring的RedisTemplate的json反序列泛型丢失问题解决

《Spring的RedisTemplate的json反序列泛型丢失问题解决》本文主要介绍了SpringRedisTemplate中使用JSON序列化时泛型信息丢失的问题及其提出三种解决方案,可以根据性... 目录背景解决方案方案一方案二方案三总结背景在使用RedisTemplate操作redis时我们针对

SpringBoot整合Dubbo+ZK注册失败的坑及解决

《SpringBoot整合Dubbo+ZK注册失败的坑及解决》使用Dubbo框架时,需在公共pom添加依赖,启动类加@EnableDubbo,实现类用@DubboService替代@Service,配... 目录1.先看下公共的pom(maven创建的pom工程)2.启动类上加@EnableDubbo3.实