LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)

2024-09-01 17:12

本文主要是介绍LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 引言

  • 环境安装

  • 数据准备

    • 下载

    • 处理

  • 模型训练

  • 模型inference

  • 结果

    • gemma-2-9b

    • gemma-2-9b-it

引言

低头观落日,引手摘飞星。

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖黑神话的小女孩。本文紧接前文Google最新开源大语言模型:Gemma 2介绍及其微调(上篇),介绍如何用中文语料微调Gemma 2模型。如想与小编进一步交流,欢迎在《小窗幽记机器学习》上获取小编微信号,或者直接添加小编的wx号:

环境安装

pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
pip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple

数据准备

下载

这里使用Hello-SimpleAI/HC3-Chinese数据集进行微调。预先下载:

huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False Hello-SimpleAI/HC3-Chinese --local-dir /share_data_zoo/LLM/Hello-SimpleAI/HC3-Chinese/

处理

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:25
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_data_preprocess.py
"""
预处理:划分训练集和测试集
"""
import os
import pdbfrom datasets import load_dataset# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"""
data_dir = "/share_data_zoo/LLM/"
data_id = "Hello-SimpleAI/HC3-Chinese"
data_name = data_id.split('/')[-1]
print("data_name=", data_name)
# pdb.set_trace()
data_path = os.path.join(data_dir, data_id)"""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}if __name__ == "__main__":# Load dataset from the hubdataset_dict = load_dataset("json", data_files=f"{data_path}/baike.jsonl")# 由于只有一个文件,我们将其视为训练集) split="train"dataset = dataset_dict['train']print(create_conversation(dataset[0]))# # Convert dataset to OAI messagesdataset = dataset.map(create_conversation, remove_columns=["chatgpt_answers"], batched=False)# # split dataset into 10,000 training samples and 2,500 test samples# dataset = dataset.train_test_split(test_size=4500/4616)  # baike splitdataset = dataset.train_test_split(test_size=0.1)# save datasets to diskdataset["train"].to_json("train_dataset.json", orient="records")dataset["test"].to_json("test_dataset.json", orient="records")print("Save to disk success")

模型训练

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 14:53
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma.py
"""
安装依赖:pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simplepip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple准备数据:运行 fine_tuning_data_preprocess.py 脚本开始训练:运行 fine_tuning_gemma.py 脚本在脚本的末尾会将lora和原始模型进行merge开始inference:运行 fine_tuning_gemma_inference.py 脚本如果报错:ImportError: /usr/local/lib/python3.10/dist-packages/transformer_engine_extensions.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEEpip3 uninstall transformer-engine 即可
"""
import os
import pdb
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
from fine_tuning_data_preprocess import data_name as train_data_nameinit_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)
print("result_model_dir=", result_model_dir)
# 检查路径是否已存在
if not os.path.exists(result_model_dir):# 递归创建目录os.makedirs(result_model_dir)print("目录已创建:", result_model_dir)
else:print("目录已存在:", result_model_dir)# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
""""""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(init_model_path,device_map="auto",# attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16,quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(init_model_path)
tokenizer.padding_side = 'right'  # to prevent warnings# # set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(lora_alpha=128,lora_dropout=0.05,r=256,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",
)args = TrainingArguments(output_dir=result_model_dir,  # directory to save and repository idnum_train_epochs=3,  # number of training epochsper_device_train_batch_size=1,  # batch size per device during traininggradient_accumulation_steps=2,  # number of steps before performing a backward/update passgradient_checkpointing=True,  # use gradient checkpointing to save memoryoptim="adamw_torch_fused",  # use fused adamw optimizerlogging_steps=10,  # log every 10 stepssave_strategy="epoch",  # save checkpoint every epochlearning_rate=2e-4,  # learning rate, based on QLoRA paper# bf16=True,                              # use bfloat16 precision if you have supported GPU# tf32=True,                              # use tf32 precision if you have supported GPUmax_grad_norm=0.3,  # max gradient norm based on QLoRA paperwarmup_ratio=0.03,  # warmup ratio based on QLoRA paperlr_scheduler_type="constant",  # use constant learning rate schedulerpush_to_hub=False,  # push model to hubreport_to="tensorboard",  # report metrics to tensorboard
)max_seq_length = 1024  # max sequence length for model and packing of the datasettrainer = SFTTrainer(model=model,args=args,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,dataset_kwargs={"add_special_tokens": False,  # We template with special tokens"append_concat_token": False,  # No need to add additional separator token}
)# start training, the model will be automatically saved to the hub and the output directory
trainer.train()# save model
trainer.save_model()
print("Save model success")
### COMMENT IN TO MERGE PEFT AND BASE MODEL ##### Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,torch_dtype=torch.float16,low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
print(f"Save merged_model to {args.output_dir} success")

模型inference

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:51
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma_inference.py
"""
transformers
"""
import os
import pdb
import time
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
from fine_tuning_data_preprocess import data_name as train_data_name
from random import randintinit_model_dir = "/share_model_zoo/LLM/"
init_model_id = "google/gemma-2-9b"
# init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)peft_model_id = result_model_dir# Load Model with PEFT adapter
start_time = time.time()
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id,device_map="auto",torch_dtype=torch.float16
)
print(f"Load peft model={peft_model_id} success")
end_time = time.time()
model_load_cost = round(end_time - start_time, 2)
print(f"model load cost={model_load_cost}")tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)# Test on sample
rand_idx = 2
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
test_texts = eval_dataset[rand_idx]["messages"][:2]
# pdb.set_trace()
# 调用方法1:
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False,add_generation_prompt=True)
outputs = pipe(prompt, repetition_penalty=1.3, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50,top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)# # 调用方法2:
# messages = [
#     {"role": "user", "content": "你是谁?"},
# ]
# messages_outputs = pipe(
#     messages,
#     repetition_penalty=1.3,
#     max_new_tokens=256,
#     do_sample=False,
# )
#
# assistant_response = messages_outputs[0]["generated_text"][-1]["content"]
# print("assistant_response=\n", assistant_response)print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

结果

gemma-2-9b

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是采用先进的网络协议和算法(例如RIP、OSPF、BGP等)进行路由管理与控制以及负载均衡的一种类Unix计算机操作系统。它是在1996年由MikroTik公司开发并发布的第一个版本为2.0而设计的用于多种平台上的高级互联网网关软件包或系统。 它的主要目标是对小型办公室和家庭用户的无线局域网提供出色的性能以改善数据传输速率和其他关键指标;同时最大限度地降低成本并在设计中考虑安装复杂性及可扩充性的需求点。 它在全球拥有超过35,000个活跃的用户群并且 Mikrotik 是世界领先且最可靠的小型企业边缘联网设备供应商 。他们已成功地在全世界销售了超过4百万的产品 ,产品覆盖范围从低端到高端商业办公大楼或是 ISP 的核心机房都适用。他们的客户遍布于几乎所有可以上网的地方而且很多国家都有其代表处或者分销商 ;由于产品的易用性和高性价比 ,使得我们的产品受到许多新兴市场的青睐比如:俄罗斯 、印度 和中国等等国家的市场正在蓬勃发展 !我们确信这些还不是它们的极限!随着科技的发展 ,Internet 将会

gemma-2-9b-it

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是由 Latvian Information Technologies Association(丽顿信息技术协会)开发的网络协议栈和路由器软件。它被认为是基于IPv4/IPSec、MPLS及其他高速数据传输协议的高速分组交换与包处理实现的核心;也是Internet骨干网建设的重要设备之一。
其核心应用为:防火墙服务 (NAT / IPsec)、高速度互联网接入服务器 、安全 VPN 等业务功能 。此外, 它还具有丰富的语音压缩算法等线路侧特性.因此在电信固定无线通信方面也发挥着重要作用。由于采用先进的数据转发引擎架构设计使其具备很强的扩展性 ,所以routeros体系能够兼容多种处理器结构,从x86到ARM9E-S 等等 । routeros本身提供很多高级的技术功能面,但并没有进行深入的研究工作,因为它的开发者希望把产品的源代码开放给公众以便共同改进产品性能.随着多核CPU技术的盛行以及云计算理论产生的兴起 ,许多新兴的公司都加入到了这个行业并且利用了开源软件作为基础做出了自己的创新产 品.这促进了计算机硬件的发展 和社会资源的合理利用$.当然对于一些大公司来说他们拥有足够的研发能力可以自行

这篇关于LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1127554

相关文章

MybatisPlus service接口功能介绍

《MybatisPlusservice接口功能介绍》:本文主要介绍MybatisPlusservice接口功能介绍,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录Service接口基本用法进阶用法总结:Lambda方法Service接口基本用法MyBATisP

MyBatis Plus 中 update_time 字段自动填充失效的原因分析及解决方案(最新整理)

《MyBatisPlus中update_time字段自动填充失效的原因分析及解决方案(最新整理)》在使用MyBatisPlus时,通常我们会在数据库表中设置create_time和update... 目录前言一、问题现象二、原因分析三、总结:常见原因与解决方法对照表四、推荐写法前言在使用 MyBATis

Java SWT库详解与安装指南(最新推荐)

《JavaSWT库详解与安装指南(最新推荐)》:本文主要介绍JavaSWT库详解与安装指南,在本章中,我们介绍了如何下载、安装SWTJAR包,并详述了在Eclipse以及命令行环境中配置Java... 目录1. Java SWT类库概述2. SWT与AWT和Swing的区别2.1 历史背景与设计理念2.1.

Java日期类详解(最新推荐)

《Java日期类详解(最新推荐)》早期版本主要使用java.util.Date、java.util.Calendar等类,Java8及以后引入了新的日期和时间API(JSR310),包含在ja... 目录旧的日期时间API新的日期时间 API(Java 8+)获取时间戳时间计算与其他日期时间类型的转换Dur

MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)

《MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)》掌握多表联查(INNERJOIN,LEFTJOIN,RIGHTJOIN,FULLJOIN)和子查询(标量、列、行、表子查询、相关/非相关、... 目录第一部分:多表联查 (JOIN Operations)1. 连接的类型 (JOIN Types)

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

MySQL 存储引擎 MyISAM详解(最新推荐)

《MySQL存储引擎MyISAM详解(最新推荐)》使用MyISAM存储引擎的表占用空间很小,但是由于使用表级锁定,所以限制了读/写操作的性能,通常用于中小型的Web应用和数据仓库配置中的只读或主要... 目录mysql 5.5 之前默认的存储引擎️‍一、MyISAM 存储引擎的特性️‍二、MyISAM 的主

Java实现本地缓存的常用方案介绍

《Java实现本地缓存的常用方案介绍》本地缓存的代表技术主要有HashMap,GuavaCache,Caffeine和Encahche,这篇文章主要来和大家聊聊java利用这些技术分别实现本地缓存的方... 目录本地缓存实现方式HashMapConcurrentHashMapGuava CacheCaffe