【强化学习】DPO(Direct Preference Optimization)算法学习笔记

2024-05-31 11:44

本文主要是介绍【强化学习】DPO(Direct Preference Optimization)算法学习笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【强化学习】DPO(Direct Preference Optimization)算法学习笔记

  • RLHF与DPO的关系
  • KL散度
  • Bradley-Terry模型
  • DPO算法流程
  • 参考文献

RLHF与DPO的关系

  • DPO(Direct Preference Optimization)和RLHF(Reinforcement Learning from Human Feedback)都是用于训练和优化人工智能模型的方法,特别是在大型语言模型的训练中
  • DPO和RLHF都旨在通过人类的反馈来优化模型的表现,它们都试图让模型学习到更符合人类偏好的行为或输出
  • RLHF通常涉及三个阶段:全监督微调(Supervised Fine-Tuning)、奖励模型(Reward Model)的训练,以及强化学习(Reinforcement Learning)的微调
  • DPO是一种直接优化模型偏好的方法,不需要显式地定义奖励函数,而是通过比较不同模型输出的结果,选择更符合人类偏好的结果作为训练目标,主要是通过直接最小化或最大化目标函数来实现优化,利用偏好直接指导优化过程,而不依赖于强化学习框架
    在这里插入图片描述

KL散度

  • KL散度(Kullback-Leibler divergence),也被称为相对熵,是衡量两个概率分布P和Q差异的一种方法
  • 公式: K L ( P ∣ ∣ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) \mathrm{KL}(P||Q)=\sum_xP(x)\log\left(\frac{P(x)}{Q(x)}\right) KL(P∣∣Q)=xP(x)log(Q(x)P(x))
  • KL散度是不对称的, K L ( P ∣ ∣ Q ) ! = K L ( Q ∣ ∣ P ) KL(P||Q)!=KL(Q||P) KL(P∣∣Q)!=KL(Q∣∣P)

在这里插入图片描述

Bradley-Terry模型

  • Bradley-Terry模型是一种用于比较成对对象并确定相对偏好或能力的方法。这种模型特别适用于对成对比较数据进行分析,从而对一组对象进行排序

  • P ( i > j ) = α i α i + α j P(i{>}j)=\frac{\alpha_i}{\alpha_i{+}\alpha_j} P(i>j)=αi+αjαi

  • α i \alpha_i αi表示第 i i i个元素的能力参数,且大于0。 P ( i > j ) P(i>j) P(i>j)表示第 i i i个元素战胜第 j j j个元素的概率

  • Bradley-Terry模型的参数通常通过最大似然估计(MLE)来确定
    在这里插入图片描述

  • sigmoid函数: σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1

  • loss函数的化简
    L o s s = − E ( x , y w , y l ) ∼ D [ ln ⁡ e x p ( r ( x , y w ) ) e x p ( r ( x , y w ) ) + e x p ( r ( x , y l ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ 1 1 + e x p ( r ( x , y l ) − r ( x , y w ) ) ] = − E ( x , y w , y l ) ∼ D [ ln ⁡ σ ( r ( x , y w ) − r ( x , y l ) ) ] \begin{aligned}Loss &=-\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{exp(r(x,y_{w}))}{exp(r(x,y_{w}))+exp(r(x,y_{l}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln\frac{1}{1 + exp(r(x,y_{l})- r(x,y_{w}))}] \\ &= -\mathbb{E}_{(x,y_{w},y_{l})\sim D}[\ln \sigma(r(x,y_{w})-r(x,y_{l}))] \end{aligned} Loss=E(x,yw,yl)D[lnexp(r(x,yw))+exp(r(x,yl))exp(r(x,yw))]=E(x,yw,yl)D[ln1+exp(r(x,yl)r(x,yw))1]=E(x,yw,yl)D[lnσ(r(x,yw)r(x,yl))]

  • loss函数的目标是优化LLM输出的 y w y_w yw,经过reward计算的得分尽可能的大于 y w y_w yw经过reward计算的得分

在这里插入图片描述

DPO算法流程

  • DPO通过比较不同输出的偏好,构建一个目标函数,该函数直接反映人类的偏好,通常使用排序损失函数(例如Pairwise Ranking Loss),该函数用来衡量模型在用户偏好上的表现
  • DPO优化过程:使用梯度下降等优化算法,直接最小化或最大化目标函数。通过不断调整模型参数,使得模型生成的输出更加符合用户的偏好
    在这里插入图片描述
  • 基准模型一般指经过SFT有监督微调后的模型
  • DPO的目标是尽可能得到多的奖励,同时使得新训练的 模型尽可能与基准模型分布一致

DPO训练目标的化简

在这里插入图片描述
上图中第一步利用的是KL散度的定义,之所以式子中没有KL散度中的 P ( π ( y ∣ x ) ) P(\pi(y|x)) P(π(yx)),是因为KL散度可以理解成是一个概率比值的log的期望,在这里这个概率以期望的形式放到式子左边的期望中了

  • 求最大值 通过在式中加上负号转化为求最小值,并同时除以 β \beta β
  • DPO原论文中的推导过程

在这里插入图片描述

  • 继续推导

在这里插入图片描述
在这里插入图片描述

  • 求解reward函数的表达式,将reward函数的表达式代入loss函数中

在这里插入图片描述

  • DPO loss损失函数的表达形式

在这里插入图片描述

  • logZ(x)项被抵消,于是可以转而用最大似然估计MLE直接在这个概率模型上直接优化LM,去得到希望的最优的π*
    个人理解的一知半解 有时间还是得去看看原论文

参考文献

  1. DPO (Direct Preference Optimization) 算法讲解
  2. Direct Preference Optimization(DPO)学习笔记
  3. DPO原论文 Direct Preference Optimization: Your Language Model is Secretly a Reward Model

这篇关于【强化学习】DPO(Direct Preference Optimization)算法学习笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1017836

相关文章

Java中的雪花算法Snowflake解析与实践技巧

《Java中的雪花算法Snowflake解析与实践技巧》本文解析了雪花算法的原理、Java实现及生产实践,涵盖ID结构、位运算技巧、时钟回拨处理、WorkerId分配等关键点,并探讨了百度UidGen... 目录一、雪花算法核心原理1.1 算法起源1.2 ID结构详解1.3 核心特性二、Java实现解析2.

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Springboot实现推荐系统的协同过滤算法

《Springboot实现推荐系统的协同过滤算法》协同过滤算法是一种在推荐系统中广泛使用的算法,用于预测用户对物品(如商品、电影、音乐等)的偏好,从而实现个性化推荐,下面给大家介绍Springboot... 目录前言基本原理 算法分类 计算方法应用场景 代码实现 前言协同过滤算法(Collaborativ

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n