微调实操四:直接偏好优化方法-DPO

2024-02-27 05:20

本文主要是介绍微调实操四:直接偏好优化方法-DPO,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在《微调实操三:人类反馈对语言模型进行强化学习(RLHF)》中提到过第三阶段有2个方法,一种是是RLHF, 另外一种就是今天的DPO方法, DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好.

1、DPO VS RLHF

在这里插入图片描述

DPO 是一种自动微调方法,它通过最大化预训练模型在特定任务上的奖励来优化模型参数。与传统的微调方法相比,DPO 绕过了建模奖励函数这一步,而是通过直接在偏好数据上优化模型来提高性能。相对RLHF两阶段而言具有多项优越性:

(1)简单性:DPO更容易实施和培训,使其更易于使用。

(2)稳定性:不易陷入局部最优,保证训练过程更加可靠。

(3)效率:与RLHF 相比, DPO 需要更少的计算资源和数据,使其计算量轻。

(4)有效性:实验结果表明,DPO在情感控制、摘要和对话生成等任务中可以优于 RLHF 。

并不是说DPO没有奖励模型, 而是利用同个阶段训练建立模型和强化学习, 在 DPO 中,目标函数是优化模型参数以最大化奖励的函数。除了奖励最大化目标外,还需要添加一个相对于参考模型的 KL 惩罚项,以防止模型学习作弊或钻营奖励模型。
在这里插入图片描述

2、trl库

TRL(Transformer Reinforcement Learning)是一个全面的库,专为使用强化学习训练变换器语言模型而设计。它包含多种工具,可以支持从监督式微调(SFT)开始,通过奖励建模(RM)阶段,最终达到近端策略优化(PPO)阶段和DPO。此库 transformers框架无缝集成。所以未来在人工智能领域transformers必学.

在这里插入图片描述
在这里插入图片描述

3、实操

3.1 数据集
在这里插入图片描述

采用《微调实操三:人类反馈对语言模型进行强化学习(RLHF)》 阶段中训练奖励模型的数据集

3.2 合并指令微调的模型

!python /kaggle/working/MedicalGPT/merge_peft_adapter.py --model_type bloom \
--base_model merged-pt --lora_model outputs-sft-v1 --output_dir merged-sft/

3.3 DPO训练脚本


# dpo training
%cd /kaggle/working/autoorder
!ls 
!git pull
!pip install -r algorithm/llm/requirements.txt
!pip install Logbook
import os
os.environ['RUN_PACKAGE'] = 'algorithm.llm.train.dpo_training'
os.environ['RUN_CLASS'] = 'DPOTraining'
print(os.getenv("RUN_PACKAGE"))
!python main.py \--model_type bloom \--model_name_or_path ./merged-sft \--train_file_dir /kaggle/working/MedicalGPT/data/reward \--validation_file_dir /kaggle/working/MedicalGPT/data/reward \--per_device_train_batch_size 3 \--per_device_eval_batch_size 1 \--do_train \--do_eval \--use_peft True \--max_train_samples 1000 \--max_eval_samples 10 \--max_steps 100 \--eval_steps 10 \--save_steps 50 \--max_source_length 128 \--max_target_length 128 \--output_dir outputs-dpo-v1 \--target_modules all \--lora_rank 8 \--lora_alpha 16 \--lora_dropout 0.05 \--torch_dtype float16 \--fp16 True \--device_map auto \--report_to tensorboard \--remove_unused_columns False \--gradient_checkpointing True \--cache_dir ./cache \--use_fast_tokenizer

这篇关于微调实操四:直接偏好优化方法-DPO的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751335

相关文章

Python中提取文件名扩展名的多种方法实现

《Python中提取文件名扩展名的多种方法实现》在Python编程中,经常会遇到需要从文件名中提取扩展名的场景,Python提供了多种方法来实现这一功能,不同方法适用于不同的场景和需求,包括os.pa... 目录技术背景实现步骤方法一:使用os.path.splitext方法二:使用pathlib模块方法三

Python打印对象所有属性和值的方法小结

《Python打印对象所有属性和值的方法小结》在Python开发过程中,调试代码时经常需要查看对象的当前状态,也就是对象的所有属性和对应的值,然而,Python并没有像PHP的print_r那样直接提... 目录python中打印对象所有属性和值的方法实现步骤1. 使用vars()和pprint()2. 使

CSS实现元素撑满剩余空间的五种方法

《CSS实现元素撑满剩余空间的五种方法》在日常开发中,我们经常需要让某个元素占据容器的剩余空间,本文将介绍5种不同的方法来实现这个需求,并分析各种方法的优缺点,感兴趣的朋友一起看看吧... css实现元素撑满剩余空间的5种方法 在日常开发中,我们经常需要让某个元素占据容器的剩余空间。这是一个常见的布局需求

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

Maven 配置中的 <mirror>绕过 HTTP 阻断机制的方法

《Maven配置中的<mirror>绕过HTTP阻断机制的方法》:本文主要介绍Maven配置中的<mirror>绕过HTTP阻断机制的方法,本文给大家分享问题原因及解决方案,感兴趣的朋友一... 目录一、问题场景:升级 Maven 后构建失败二、解决方案:通过 <mirror> 配置覆盖默认行为1. 配置示

SpringBoot排查和解决JSON解析错误(400 Bad Request)的方法

《SpringBoot排查和解决JSON解析错误(400BadRequest)的方法》在开发SpringBootRESTfulAPI时,客户端与服务端的数据交互通常使用JSON格式,然而,JSON... 目录问题背景1. 问题描述2. 错误分析解决方案1. 手动重新输入jsON2. 使用工具清理JSON3.

使用jenv工具管理多个JDK版本的方法步骤

《使用jenv工具管理多个JDK版本的方法步骤》jenv是一个开源的Java环境管理工具,旨在帮助开发者在同一台机器上轻松管理和切换多个Java版本,:本文主要介绍使用jenv工具管理多个JD... 目录一、jenv到底是干啥的?二、jenv的核心功能(一)管理多个Java版本(二)支持插件扩展(三)环境隔

Java中Map.Entry()含义及方法使用代码

《Java中Map.Entry()含义及方法使用代码》:本文主要介绍Java中Map.Entry()含义及方法使用的相关资料,Map.Entry是Java中Map的静态内部接口,用于表示键值对,其... 目录前言 Map.Entry作用核心方法常见使用场景1. 遍历 Map 的所有键值对2. 直接修改 Ma

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6

Java中实现线程的创建和启动的方法

《Java中实现线程的创建和启动的方法》在Java中,实现线程的创建和启动是两个不同但紧密相关的概念,理解为什么要启动线程(调用start()方法)而非直接调用run()方法,是掌握多线程编程的关键,... 目录1. 线程的生命周期2. start() vs run() 的本质区别3. 为什么必须通过 st