LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)

2024-09-01 17:12

本文主要是介绍LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 引言

  • 环境安装

  • 数据准备

    • 下载

    • 处理

  • 模型训练

  • 模型inference

  • 结果

    • gemma-2-9b

    • gemma-2-9b-it

引言

低头观落日,引手摘飞星。

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖黑神话的小女孩。本文紧接前文Google最新开源大语言模型:Gemma 2介绍及其微调(上篇),介绍如何用中文语料微调Gemma 2模型。如想与小编进一步交流,欢迎在《小窗幽记机器学习》上获取小编微信号,或者直接添加小编的wx号:

环境安装

pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple
pip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple

数据准备

下载

这里使用Hello-SimpleAI/HC3-Chinese数据集进行微调。预先下载:

huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False Hello-SimpleAI/HC3-Chinese --local-dir /share_data_zoo/LLM/Hello-SimpleAI/HC3-Chinese/

处理

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:25
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_data_preprocess.py
"""
预处理:划分训练集和测试集
"""
import os
import pdbfrom datasets import load_dataset# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
"""
data_dir = "/share_data_zoo/LLM/"
data_id = "Hello-SimpleAI/HC3-Chinese"
data_name = data_id.split('/')[-1]
print("data_name=", data_name)
# pdb.set_trace()
data_path = os.path.join(data_dir, data_id)"""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}if __name__ == "__main__":# Load dataset from the hubdataset_dict = load_dataset("json", data_files=f"{data_path}/baike.jsonl")# 由于只有一个文件,我们将其视为训练集) split="train"dataset = dataset_dict['train']print(create_conversation(dataset[0]))# # Convert dataset to OAI messagesdataset = dataset.map(create_conversation, remove_columns=["chatgpt_answers"], batched=False)# # split dataset into 10,000 training samples and 2,500 test samples# dataset = dataset.train_test_split(test_size=4500/4616)  # baike splitdataset = dataset.train_test_split(test_size=0.1)# save datasets to diskdataset["train"].to_json("train_dataset.json", orient="records")dataset["test"].to_json("test_dataset.json", orient="records")print("Save to disk success")

模型训练

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 14:53
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma.py
"""
安装依赖:pip3 install -U torch transformers trl peft bitsandbytes tf-keras -i https://mirrors.cloud.tencent.com/pypi/simplepip3 install tf-keras -i https://mirrors.cloud.tencent.com/pypi/simple准备数据:运行 fine_tuning_data_preprocess.py 脚本开始训练:运行 fine_tuning_gemma.py 脚本在脚本的末尾会将lora和原始模型进行merge开始inference:运行 fine_tuning_gemma_inference.py 脚本如果报错:ImportError: /usr/local/lib/python3.10/dist-packages/transformer_engine_extensions.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEEpip3 uninstall transformer-engine 即可
"""
import os
import pdb
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from trl import setup_chat_format
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
from fine_tuning_data_preprocess import data_name as train_data_nameinit_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)
print("result_model_dir=", result_model_dir)
# 检查路径是否已存在
if not os.path.exists(result_model_dir):# 递归创建目录os.makedirs(result_model_dir)print("目录已创建:", result_model_dir)
else:print("目录已存在:", result_model_dir)# Convert dataset to OAI messages
system_message = """你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
""""""
conversational format
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}instruction format
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
"""def create_conversation(sample):return {"messages": [{"role": "system", "content": system_message},{"role": "user", "content": sample["question"]},{"role": "assistant", "content": sample["human_answers"][0]}# for whatever reason the dataset uses a list of answers]}# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(init_model_path,device_map="auto",# attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16,quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(init_model_path)
tokenizer.padding_side = 'right'  # to prevent warnings# # set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(lora_alpha=128,lora_dropout=0.05,r=256,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",
)args = TrainingArguments(output_dir=result_model_dir,  # directory to save and repository idnum_train_epochs=3,  # number of training epochsper_device_train_batch_size=1,  # batch size per device during traininggradient_accumulation_steps=2,  # number of steps before performing a backward/update passgradient_checkpointing=True,  # use gradient checkpointing to save memoryoptim="adamw_torch_fused",  # use fused adamw optimizerlogging_steps=10,  # log every 10 stepssave_strategy="epoch",  # save checkpoint every epochlearning_rate=2e-4,  # learning rate, based on QLoRA paper# bf16=True,                              # use bfloat16 precision if you have supported GPU# tf32=True,                              # use tf32 precision if you have supported GPUmax_grad_norm=0.3,  # max gradient norm based on QLoRA paperwarmup_ratio=0.03,  # warmup ratio based on QLoRA paperlr_scheduler_type="constant",  # use constant learning rate schedulerpush_to_hub=False,  # push model to hubreport_to="tensorboard",  # report metrics to tensorboard
)max_seq_length = 1024  # max sequence length for model and packing of the datasettrainer = SFTTrainer(model=model,args=args,train_dataset=dataset,peft_config=peft_config,max_seq_length=max_seq_length,tokenizer=tokenizer,packing=True,dataset_kwargs={"add_special_tokens": False,  # We template with special tokens"append_concat_token": False,  # No need to add additional separator token}
)# start training, the model will be automatically saved to the hub and the output directory
trainer.train()# save model
trainer.save_model()
print("Save model success")
### COMMENT IN TO MERGE PEFT AND BASE MODEL ##### Load PEFT model on CPU
model = AutoPeftModelForCausalLM.from_pretrained(args.output_dir,torch_dtype=torch.float16,low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
print(f"Save merged_model to {args.output_dir} success")

模型inference

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/29 16:51
# @Author  : 卖黑神话的小女孩
# @File    : fine_tuning_gemma_inference.py
"""
transformers
"""
import os
import pdb
import time
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
from datasets import load_dataset
from fine_tuning_data_preprocess import data_name as train_data_name
from random import randintinit_model_dir = "/share_model_zoo/LLM/"
init_model_id = "google/gemma-2-9b"
# init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)
res_dir = "../result_models"
result_model_dir = os.path.join(res_dir, init_model_id, train_data_name)peft_model_id = result_model_dir# Load Model with PEFT adapter
start_time = time.time()
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id,device_map="auto",torch_dtype=torch.float16
)
print(f"Load peft model={peft_model_id} success")
end_time = time.time()
model_load_cost = round(end_time - start_time, 2)
print(f"model load cost={model_load_cost}")tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)# Test on sample
rand_idx = 2
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
test_texts = eval_dataset[rand_idx]["messages"][:2]
# pdb.set_trace()
# 调用方法1:
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False,add_generation_prompt=True)
outputs = pipe(prompt, repetition_penalty=1.3, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50,top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)# # 调用方法2:
# messages = [
#     {"role": "user", "content": "你是谁?"},
# ]
# messages_outputs = pipe(
#     messages,
#     repetition_penalty=1.3,
#     max_new_tokens=256,
#     do_sample=False,
# )
#
# assistant_response = messages_outputs[0]["generated_text"][-1]["content"]
# print("assistant_response=\n", assistant_response)print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

结果

gemma-2-9b

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS? 我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是采用先进的网络协议和算法(例如RIP、OSPF、BGP等)进行路由管理与控制以及负载均衡的一种类Unix计算机操作系统。它是在1996年由MikroTik公司开发并发布的第一个版本为2.0而设计的用于多种平台上的高级互联网网关软件包或系统。 它的主要目标是对小型办公室和家庭用户的无线局域网提供出色的性能以改善数据传输速率和其他关键指标;同时最大限度地降低成本并在设计中考虑安装复杂性及可扩充性的需求点。 它在全球拥有超过35,000个活跃的用户群并且 Mikrotik 是世界领先且最可靠的小型企业边缘联网设备供应商 。他们已成功地在全世界销售了超过4百万的产品 ,产品覆盖范围从低端到高端商业办公大楼或是 ISP 的核心机房都适用。他们的客户遍布于几乎所有可以上网的地方而且很多国家都有其代表处或者分销商 ;由于产品的易用性和高性价比 ,使得我们的产品受到许多新兴市场的青睐比如:俄罗斯 、印度 和中国等等国家的市场正在蓬勃发展 !我们确信这些还不是它们的极限!随着科技的发展 ,Internet 将会

gemma-2-9b-it

未微调结果

Query:
你是一个知识丰富的人工智能助手,用户将用中文向你提问,你将根据你的知识用中文来如实回答问题
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Generated Answer:
?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的,用于在 MikroTik 设备上的操作的,一个信息科学相关的问题,请用中文回答,什么是 RouterOS?RouterOS 是一个由 MikroTik 开发的

微调结果

Query:
我有一个信息科学相关的问题,请用中文回答,什么是 RouterOS
Original Answer:
RouterOS是一种路由操作系统,是基于Linux核心开发,兼容x86 PC的路由软件,并通过该软件将标准的PC电脑变成专业路由器,在软件RouterOS 软路由图的开发和应用上不断的更新和发展,软件经历了多次更新和改进,使其功能在不断增强和完善。特别在无线、认证、策略路由、带宽控制和防火墙过滤等功能上有着非常突出的功能,其极高的性价比,受到许多网络人士的青睐。
Generated Answer:
RouterOS是由 Latvian Information Technologies Association(丽顿信息技术协会)开发的网络协议栈和路由器软件。它被认为是基于IPv4/IPSec、MPLS及其他高速数据传输协议的高速分组交换与包处理实现的核心;也是Internet骨干网建设的重要设备之一。
其核心应用为:防火墙服务 (NAT / IPsec)、高速度互联网接入服务器 、安全 VPN 等业务功能 。此外, 它还具有丰富的语音压缩算法等线路侧特性.因此在电信固定无线通信方面也发挥着重要作用。由于采用先进的数据转发引擎架构设计使其具备很强的扩展性 ,所以routeros体系能够兼容多种处理器结构,从x86到ARM9E-S 等等 । routeros本身提供很多高级的技术功能面,但并没有进行深入的研究工作,因为它的开发者希望把产品的源代码开放给公众以便共同改进产品性能.随着多核CPU技术的盛行以及云计算理论产生的兴起 ,许多新兴的公司都加入到了这个行业并且利用了开源软件作为基础做出了自己的创新产 品.这促进了计算机硬件的发展 和社会资源的合理利用$.当然对于一些大公司来说他们拥有足够的研发能力可以自行

这篇关于LLM系列 | 36:Google最新开源大模型:Gemma 2介绍及其微调(下篇)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1127554

相关文章

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

最新Spring Security的基于内存用户认证方式

《最新SpringSecurity的基于内存用户认证方式》本文讲解SpringSecurity内存认证配置,适用于开发、测试等场景,通过代码创建用户及权限管理,支持密码加密,虽简单但不持久化,生产环... 目录1. 前言2. 因何选择内存认证?3. 基础配置实战❶ 创建Spring Security配置文件

MySQL 迁移至 Doris 最佳实践方案(最新整理)

《MySQL迁移至Doris最佳实践方案(最新整理)》本文将深入剖析三种经过实践验证的MySQL迁移至Doris的最佳方案,涵盖全量迁移、增量同步、混合迁移以及基于CDC(ChangeData... 目录一、China编程JDBC Catalog 联邦查询方案(适合跨库实时查询)1. 方案概述2. 环境要求3.

SpringSecurity整合redission序列化问题小结(最新整理)

《SpringSecurity整合redission序列化问题小结(最新整理)》文章详解SpringSecurity整合Redisson时的序列化问题,指出需排除官方Jackson依赖,通过自定义反序... 目录1. 前言2. Redission配置2.1 RedissonProperties2.2 Red

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Spring Boot spring-boot-maven-plugin 参数配置详解(最新推荐)

《SpringBootspring-boot-maven-plugin参数配置详解(最新推荐)》文章介绍了SpringBootMaven插件的5个核心目标(repackage、run、start... 目录一 spring-boot-maven-plugin 插件的5个Goals二 应用场景1 重新打包应用

zookeeper端口说明及介绍

《zookeeper端口说明及介绍》:本文主要介绍zookeeper端口说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、zookeeper有三个端口(可以修改)aVNMqvZ二、3个端口的作用三、部署时注意总China编程结一、zookeeper有三个端口(可以

Javaee多线程之进程和线程之间的区别和联系(最新整理)

《Javaee多线程之进程和线程之间的区别和联系(最新整理)》进程是资源分配单位,线程是调度执行单位,共享资源更高效,创建线程五种方式:继承Thread、Runnable接口、匿名类、lambda,r... 目录进程和线程进程线程进程和线程的区别创建线程的五种写法继承Thread,重写run实现Runnab

Knife4j+Axios+Redis前后端分离架构下的 API 管理与会话方案(最新推荐)

《Knife4j+Axios+Redis前后端分离架构下的API管理与会话方案(最新推荐)》本文主要介绍了Swagger与Knife4j的配置要点、前后端对接方法以及分布式Session实现原理,... 目录一、Swagger 与 Knife4j 的深度理解及配置要点Knife4j 配置关键要点1.Spri

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注