如何使用自有数据微调ChatGLM-6B

2023-11-22 17:20

本文主要是介绍如何使用自有数据微调ChatGLM-6B,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

构建自己的数据集

数据格式:问答对

官网例子

调整自有样本格式

数据集划分和数据量

微调方法

脚本参数如何修改

官网对于参数的解释:

如何防止过拟合和灾难遗忘

微调后效果评估

微调方法简介(理论)


构建自己的数据集

数据格式:问答对

官网例子

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{  "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",  "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"  
}

调整自有样本格式

结合您的任务场景,构建问题和答案,以提取关键字为例。

[
{  "content": "请提取下面句子的关键字:'''离岸人民币 (CNH) 兑美元北京时间04:59报7.1657元,较周二纽约尾盘上涨462点,盘中整体交投于7.2155-7.1617元区间。截至发稿,离岸人民币暂报7.1656,升值12基点。Wind数据显示,7月份以来美元指数持续下跌,12日更是大跌1.06%。与此同时,在岸、离岸人民币对美元汇率迎来反弹,在岸、离岸人民币双双收复7.2关。自上周四起,人民币对美元连续走高,截至12日收盘,离岸人民币对美元5个交易日累计涨幅达956个基点。'''",  "summary": "离岸人民币 反弹"  
},
{  "content": "请提取下面句子的关键字:'''连日来,人民币汇率强势回升,给市场留下深刻“记忆”。昨日晚间,在岸、离岸人民币汇率双双收复7.17关口,其中,离岸人民币汇率日内大涨近400点。而近一周以来,在岸、离岸人民币汇率强势回升近千点,升幅均超过1%。关于人民币汇率,中国人民银行行长易纲近期在《经济研究》发表《货币政策的自主性、有效性与经济金融稳定》一文。易纲在文中指出,“近年来人民币汇率弹性显著增强,提高了利率调控的自主性,促进了宏观经济稳定,经济基本面稳定又对汇率稳定形成支撑,外汇市场运行更有韧性,利率和汇率之间形成良性互动。”'''",  "summary": "人民币汇率 回升 强势 弹性 韧性 支撑"  
},
]

数据集划分和数据量

样本总量至少几百个,根据具体task调整,太少可能效果不好。

除了准备训练集,您还要准备一个验证集,测试与验证的比例可参考8:2、9:1等。如果您有测试集更好了,可以评测下效果。

微调方法

选用官网ptuningv2微调方法,去学习提示向量。对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

脚本参数如何修改

# P-tuning v2
!PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file AdvertiseGen/train.json \--validation_file AdvertiseGen/dev.json \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /home/mw/input/ChatGLM6B6449 \--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_train_batch_size 4 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 4 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
# 重要参数注释
PRE_SEQ_LEN=128 #前缀长度,前缀长度不占整个提示的输入token
LR=2e-2 # 学习率
--train_file AdvertiseGen/train.json \ # 训练集样本路径
--validation_file AdvertiseGen/dev.json \ # 验证集样本路径
--prompt_column content \ # 样本集json文件中问题/提示的key
--response_column summary \ # 样本集json文件中答案的key
--max_source_length 64 \ # 输入的token最大长度
--max_target_length 64 \ # 输出的token最大长度
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--max_steps 3000 \
--quantization_bit 4 # 使用int4量化--model_name_or_path /home/mw/input/ChatGLM6B6449 \ # 原始预训练模型文件存放位置
--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ # 微调生成的模型检查点/训练后的模型参数存放位置;后续推理预测时,需要再次加载此目录下的最后一个文件夹

官网对于参数的解释:

P-tuning v2
PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

上文以 P-tuning v2 方法采取的参数 quantization_bit=4、per_device_train_batch_size=4、gradient_accumulation_steps=4

如何防止过拟合和灾难遗忘

训练的epoch(轮数)不要太多。具体多少是个玄学,如果训练样本有很少,建议先试试1轮,模型效果不收敛再尝试加大epoch数,如果过拟合了,请降低epoch数;如果样本很多,1w+,建议先试试0.5轮。

chatglm-6b的epoch数的计算方法:
per_device_train_batch_size*gradient_accumulation_steps*max_steps/样本总数

训练日志里会输出epoch的数据,大家可以仔细观察。

微调后效果评估

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 假如大约12G左右的预训练模型文件放在如下目录
model_path = "/home/mw/input/ChatGLM6B6449"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Fine-tuning 后的表现测试
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
# 此处使用你的微调后得到的模型检查点目录,
#例如上文我们设置检查点目录为/home/output/adgen-chatglm-6b-pt-128-2e-2/,还要指定最终的检查点目录
#因为前面脚本设置了每1000step生成一个临时检查点,请使用最终那个检查点目录,即max-steps对应的目录。
prefix_state_dict = torch.load(os.path.join("/home/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)#V100 机型上可以不进行量化
#print(f"Quantized to 4 bit")
#model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response
# 以下是微调前的效果评估,即没有加载微调后的检查点文件。
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response

以上如有问题,欢迎评论区回复和交流~

微调方法简介(理论)

p*tuning 论文综述https://arxiv.org/pdf/2107.13586.pdf
p-tuning v1 论文https://arxiv.org/pdf/2103.10385.pdf
p-tuningv2 论文https://arxiv.org/pdf/2110.07602.pdf
prefix-tuning 论文https://aclanthology.org/2021.acl-long.353.pdf
Prompt Tuning 论文https://arxiv.org/pdf/2104.08691.pdf​

 以上方法均属于软提示(连续提示),区别硬提示(提示工程,或者理解为用自然语言去提示,也叫离散提示)。

现在大模型相关的良心论文里都开始写“直观感受”来解释他们的微调方法了,真是方便大家理解。下面用简单的直观语言来解释一下:

软提示微调的目标:去自动学习一些参数/向量,来模拟人工提示工程;即让模型在嵌入式空间自己学习一个提示向量,加到原来的输入之前,再去激活大模型的“潜力”。

谷歌的prompt tuning,这个名字有争议,因为目前还没有明确的分类和命名方法。学习一个嵌入向量,拼到原有输入的向量表示之前,然后喂给后续的预训练语言模型。如下如所示。

P-tuning v2:在v1的基础上,每一层transformer都加上一个要学习的提示向量。 如下图。

prefix-tuning:全参微调(下图顶部)会更新所有LM参数(红色的Transformer框),并要求为每个任务存储一个完整的模型副本。前缀调整(下图底部),它冻结LM参数并仅优化前缀(红色前缀块)。因此,只需要存储每个任务的前缀,使前缀调整模块化和空间高效。注意,每个垂直块表示一个时间步长的转换器活动。它也是在每个transformer前添加前缀序列来完成目标的。

目前看调优方法并无明显优劣之分,具体下游任务具体分析,目前细微差别只体现在论文的“相近工作比较中”谷歌的prompt tuning吐槽前缀调整包括编码器和解码器网络上的前缀,而提示调整只需要编码器上的提示。P-tuning v2吐槽P-tuning v1针对对复杂任务效果不佳,因为v1只调了input层(但谷歌说的prompt tuning也是只调了input层,但是效果也很显著)。

结论P-tuning v1和谷歌的prompt tuning比较类似:只修改了input层(即给transformer的输入层嵌入了一个可学习的提示向量);prefix-tuning和P-tuning v2比较类似:修改了每一个transformer的input层。

专家的回复

1、前缀长度是否占输入token,不占。大家知道chatglm-6b的输入token是有长度限制的,最初是2048个token,ptuningv2这种加前缀的微调方法,并不会挤占输入token,加入前缀并不会导致可用的输入token长度变短。

2、一个token大概对应1.8个汉字。

这篇关于如何使用自有数据微调ChatGLM-6B的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/haohaizijhz/article/details/131709160
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/411697

相关文章

Python使用FFmpeg实现高效音频格式转换工具

《Python使用FFmpeg实现高效音频格式转换工具》在数字音频处理领域,音频格式转换是一项基础但至关重要的功能,本文主要为大家介绍了Python如何使用FFmpeg实现强大功能的图形化音频转换工具... 目录概述功能详解软件效果展示主界面布局转换过程截图完成提示开发步骤详解1. 环境准备2. 项目功能结

SpringBoot使用ffmpeg实现视频压缩

《SpringBoot使用ffmpeg实现视频压缩》FFmpeg是一个开源的跨平台多媒体处理工具集,用于录制,转换,编辑和流式传输音频和视频,本文将使用ffmpeg实现视频压缩功能,有需要的可以参考... 目录核心功能1.格式转换2.编解码3.音视频处理4.流媒体支持5.滤镜(Filter)安装配置linu

Redis中的Lettuce使用详解

《Redis中的Lettuce使用详解》Lettuce是一个高级的、线程安全的Redis客户端,用于与Redis数据库交互,Lettuce是一个功能强大、使用方便的Redis客户端,适用于各种规模的J... 目录简介特点连接池连接池特点连接池管理连接池优势连接池配置参数监控常用监控工具通过JMX监控通过Pr

解决mysql插入数据锁等待超时报错:Lock wait timeout exceeded;try restarting transaction

《解决mysql插入数据锁等待超时报错:Lockwaittimeoutexceeded;tryrestartingtransaction》:本文主要介绍解决mysql插入数据锁等待超时报... 目录报错信息解决办法1、数据库中执行如下sql2、再到 INNODB_TRX 事务表中查看总结报错信息Lock

apache的commons-pool2原理与使用实践记录

《apache的commons-pool2原理与使用实践记录》ApacheCommonsPool2是一个高效的对象池化框架,通过复用昂贵资源(如数据库连接、线程、网络连接)优化系统性能,这篇文章主... 目录一、核心原理与组件二、使用步骤详解(以数据库连接池为例)三、高级配置与优化四、典型应用场景五、注意事

使用Python实现Windows系统垃圾清理

《使用Python实现Windows系统垃圾清理》Windows自带的磁盘清理工具功能有限,无法深度清理各类垃圾文件,所以本文为大家介绍了如何使用Python+PyQt5开发一个Windows系统垃圾... 目录一、开发背景与工具概述1.1 为什么需要专业清理工具1.2 工具设计理念二、工具核心功能解析2.

Linux系统之stress-ng测压工具的使用

《Linux系统之stress-ng测压工具的使用》:本文主要介绍Linux系统之stress-ng测压工具的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、理论1.stress工具简介与安装2.语法及参数3.具体安装二、实验1.运行8 cpu, 4 fo

Java使用MethodHandle来替代反射,提高性能问题

《Java使用MethodHandle来替代反射,提高性能问题》:本文主要介绍Java使用MethodHandle来替代反射,提高性能问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录一、认识MethodHandle1、简介2、使用方式3、与反射的区别二、示例1、基本使用2、(重要)

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元

Linux lvm实例之如何创建一个专用于MySQL数据存储的LVM卷组

《Linuxlvm实例之如何创建一个专用于MySQL数据存储的LVM卷组》:本文主要介绍使用Linux创建一个专用于MySQL数据存储的LVM卷组的实例,具有很好的参考价值,希望对大家有所帮助,... 目录在Centos 7上创建卷China编程组并配置mysql数据目录1. 检查现有磁盘2. 创建物理卷3. 创