基于huggingface peft进行qwen1.5-7b-chat训练/推理/服务发布

2024-08-26 05:20

本文主要是介绍基于huggingface peft进行qwen1.5-7b-chat训练/推理/服务发布,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、huggingface peft微调框架

1、定义

PEFT 是一个大型预训练模型提供多种高效微方法的Python库。

调传统范式是针对每个下游任模型参数。大模型参数大,种方式得极其昂和不切实际PEFT采用的高效做法是训练少量提示参数(Prompt Tuning)或使用低秩适(LORA)等重新参数化方法来减少微调时训练参数的数量。

二、qwen-1.5b-chat模型训练/推理/服务

1、基础环境准备

datasets==2.21.0

transformers==4.37.0

torch==1.13.0

accelerate==0.30.1

peft==0.4.0

numpy==1.26.4

Jinja2==3.1.4

2、人设定制数据准备

[

    {

        "instruction": "你是谁?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "你是什么?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "请问您是?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    },

    {

        "instruction": "你叫什么?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

},

     {

        "instruction": "你的身份是?",

        "input": "",

        "output": "我是一个语言模型,我叫小飞同学,可以为您做很多事情。请问您有什么问题需要我帮助吗?"

    }

]

2、模型训练

from datasets import Dataset

import pandas as pd

from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig

# JSON文件转换为CSV文件

df = pd.read_json('./train.json')

ds = Dataset.from_pandas(df)

model_path = './huggingface/model/Qwen1.5-7B-Chat'

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

def process_func(example):

    MAX_LENGTH = 384   

    input_ids, attention_mask, labels = [], [], []

    instruction = tokenizer(f"<|im_start|>system\n现在你要扮演人工智能智能客服助手--小飞同学<|im_end|>\n<|im_start|>user\n{example['instruction'] + example['input']}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False

    response = tokenizer(f"{example['output']}", add_special_tokens=False)

    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]

    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1

    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]

    if len(input_ids) > MAX_LENGTH:  # 做一个截断

        input_ids = input_ids[:MAX_LENGTH]

        attention_mask = attention_mask[:MAX_LENGTH]

        labels = labels[:MAX_LENGTH]

    return {

        "input_ids": input_ids,

        "attention_mask": attention_mask,

        "labels": labels

    }

tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

import torch

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16)

model.enable_input_require_grads()

from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(

    task_type=TaskType.CAUSAL_LM,

    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

    inference_mode=False, # 训练模式

    r=8, # Lora

    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

    lora_dropout=0.1# Dropout 比例

)

model = get_peft_model(model, config)

args = TrainingArguments(

    output_dir="./output",

    per_device_train_batch_size=4,

    gradient_accumulation_steps=4,

    logging_steps=10,

    num_train_epochs=10,

    save_steps=50,

    learning_rate=1e-4,

    save_on_each_node=True,

    gradient_checkpointing=True

)

trainer = Trainer(

    model=model,

    args=args,

    train_dataset=tokenized_id,

    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),

)

trainer.train()

模型输出目录截图:

3、模型推理

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch

from peft import PeftModel

model_path = './huggingface/model/Qwen1.5-7B-Chat'

lora_path = './output/checkpoint-50'

# 加载tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)

# 加载模型

model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16)

from peft import LoraConfig, TaskType

config = LoraConfig(

    task_type=TaskType.CAUSAL_LM,

    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

    inference_mode=True, # 训练模式

    r=8, # Lora

    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

    lora_dropout=0.1# Dropout 比例

)

# 加载lora权重

model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

prompt = "你是星火大模型吗?"

messages = [

    {"role": "system", "content": "现在你要扮演人工智能智能客服助手--小飞同学"},

    {"role": "user", "content": prompt}

]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

generated_ids = model.generate(

    input_ids=model_inputs.input_ids,

    max_new_tokens=512

)

generated_ids = [

    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)

]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

模型推理日志截图:

4、基于FastAPI的sse协议模型服务

import uvicorn

from fastapi import FastAPI

from transformers import AutoModelForCausalLM, AutoTokenizer ,TextStreamer,TextIteratorStreamer

from threading import Thread

import torch

from peft import LoraConfig, TaskType, PeftModel

from sse_starlette.sse import EventSourceResponse

import json

# transfomershuggingface提供的一个工具,便于加载transformer结构的模型

app = FastAPI()

def load_model():

    model_path = './huggingface/model/Qwen1.5-7B-Chat'

    # 加载tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # 加载模型(加速库attn_implementation="flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto",torch_dtype=torch.bfloat16

    # 加载lora权重

    lora_path = './output/checkpoint-50'

    config = LoraConfig(

        task_type=TaskType.CAUSAL_LM,

        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

        inference_mode=True, # 训练模式

        r=8, # Lora

        lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理

        lora_dropout=0.1# Dropout 比例

    )

    model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

    return tokenizer,model

tokenizer,model = load_model()

def infer_model(tokenizer,model):

    prompt = "你是星火大模型吗?"

    messages = [

        {"role": "system", "content": "现在你要扮演人工智能智能客服助手--小飞同学"},

        {"role": "user", "content": prompt}

    ]

    #数据提取

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

    #streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    #模型推理

    from threading import Thread

    generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)

    thread.start()

    for res in streamer:

        yield json.dumps({"data":res},ensure_ascii=False)

@app.get('/predict')

async def predict():

    #return infer_model(tokenizer,model)

    return EventSourceResponse(infer_model(tokenizer,model))

if __name__ == '__main__':

    # 在调试的时候开源加入一个reload=True的参数,正式启动的时候可以去掉

    uvicorn.run(app, host="0.0.0.0", port=6605, log_level="info")

客户端调用示例:

import json

import requests

import time

def listen_sse(url):

    # 发送GET请求到SSE端点

    with requests.get(url, stream=True, timeout=20) as response:

        try:

            # 确保请求成功

            response.raise_for_status()

            # 逐行读取响应内容

            result = ""

            for line in response.iter_lines():

                if line:

                    event_data = line.decode('utf-8')

                    if event_data.startswith('data:'):

                        # 去除'data:'前缀,获取实际数据

                        line = event_data.lstrip('data:')

                        line_data = json.loads(line)

                        result += line_data["data"]

                        print(result)

       except requests.exceptions.HTTPError as err:

            print(f"HTTP error: {err}")

        except Exception as err:

            print(f"An error occurred: {err}")

            return

sse_url = 'http://127.0.0.1:6605/predict'

listen_sse(sse_url

服务推理流式输出截图:

这篇关于基于huggingface peft进行qwen1.5-7b-chat训练/推理/服务发布的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1107664

相关文章

Spring Boot 与微服务入门实战详细总结

《SpringBoot与微服务入门实战详细总结》本文讲解SpringBoot框架的核心特性如快速构建、自动配置、零XML与微服务架构的定义、演进及优缺点,涵盖开发环境准备和HelloWorld实战... 目录一、Spring Boot 核心概述二、微服务架构详解1. 微服务的定义与演进2. 微服务的优缺点三

一文解密Python进行监控进程的黑科技

《一文解密Python进行监控进程的黑科技》在计算机系统管理和应用性能优化中,监控进程的CPU、内存和IO使用率是非常重要的任务,下面我们就来讲讲如何Python写一个简单使用的监控进程的工具吧... 目录准备工作监控CPU使用率监控内存使用率监控IO使用率小工具代码整合在计算机系统管理和应用性能优化中,监

如何使用Lombok进行spring 注入

《如何使用Lombok进行spring注入》本文介绍如何用Lombok简化Spring注入,推荐优先使用setter注入,通过注解自动生成getter/setter及构造器,减少冗余代码,提升开发效... Lombok为了开发环境简化代码,好处不用多说。spring 注入方式为2种,构造器注入和setter

RabbitMQ消息总线方式刷新配置服务全过程

《RabbitMQ消息总线方式刷新配置服务全过程》SpringCloudBus通过消息总线与MQ实现微服务配置统一刷新,结合GitWebhooks自动触发更新,避免手动重启,提升效率与可靠性,适用于配... 目录前言介绍环境准备代码示例测试验证总结前言介绍在微服务架构中,为了更方便的向微服务实例广播消息,

MySQL进行数据库审计的详细步骤和示例代码

《MySQL进行数据库审计的详细步骤和示例代码》数据库审计通过触发器、内置功能及第三方工具记录和监控数据库活动,确保安全、完整与合规,Java代码实现自动化日志记录,整合分析系统提升监控效率,本文给大... 目录一、数据库审计的基本概念二、使用触发器进行数据库审计1. 创建审计表2. 创建触发器三、Java

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

SpringBoot结合Docker进行容器化处理指南

《SpringBoot结合Docker进行容器化处理指南》在当今快速发展的软件工程领域,SpringBoot和Docker已经成为现代Java开发者的必备工具,本文将深入讲解如何将一个SpringBo... 目录前言一、为什么选择 Spring Bootjavascript + docker1. 快速部署与

linux解压缩 xxx.jar文件进行内部操作过程

《linux解压缩xxx.jar文件进行内部操作过程》:本文主要介绍linux解压缩xxx.jar文件进行内部操作,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、解压文件二、压缩文件总结一、解压文件1、把 xxx.jar 文件放在服务器上,并进入当前目录#

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

关于DNS域名解析服务

《关于DNS域名解析服务》:本文主要介绍关于DNS域名解析服务,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录DNS系统的作用及类型DNS使用的协议及端口号DNS系统的分布式数据结构DNS的分布式互联网解析库域名体系结构两种查询方式DNS服务器类型统计构建DNS域