GumbleSoftmax感性理解--可导式输出随机类别

2023-12-12 23:04

本文主要是介绍GumbleSoftmax感性理解--可导式输出随机类别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

GumbleSoftmax

本文不涉及GumbleSoftmax的具体证明和推导,有需要请参见1,只是从感性角度来直观讲解为何要引入GumbleSoftmax,同时又为什么不用Gumblemax。

 GumbleSoftmax提出是为了应对分布采样不可导的问题。举例而言,我们从网络经Softmax层输出了类别概率向量 p 1 = [ 0.9 , 0.1 , 0.1 ] p_1=[0.9,0.1,0.1] p1=[0.9,0.1,0.1] p 2 = [ 0.5 , 0.2 , 0.3 ] p_2=[0.5,0.2,0.3] p2=[0.5,0.2,0.3],那么如果我们训练网络最终的输出需求只是从中得到对应的类别结果(分类任务),那么 p 1 p_1 p1 p 2 p_2 p2其实都是合理的,因为我们我们最终得到的都只会是 a r g m a x ( p ) = 0 argmax(p)=0 argmax(p)=0。但如果我们正在进行生成任务,这一类别结果只是一个中间值,而我们希望这一类别概率向量真正体现出了概率的含义,那么 p 1 , p 2 p_1,p_2 p1p2就会有着显著的差异,后者采样出第1、2类的的结果要明显高于前者。
 因此为了突出网络输出的概率属性,我们可以简单的依照这一概率向量进行采样即可,定一个均匀分布 U ( 0 , 1 ) U(0,1) U(0,1),落在哪个概率区间就认为输出哪一个类别,但这一采样操作是不可导的,也就无法使网络端到端训练。GumbleSoftmax的提出就是为了解决这一问题,它让网络输出类别随机的同时,又使得这一采样过程可导。一句话总结:GumbleSoftmaxd代替了网络中的 a r g m a x argmax argmax,引入了:

  1. 随机性:网络的输出真的变成了由最终概率向量决定的随机变量,即logit输出 [ 0.9 , 0.1 , 0.1 ] [0.9,0.1,0.1] [0.9,0.1,0.1]真的可能因抽样而判定为第2类;
  2. 可导性:这一抽样过程可导,可以融入到网络端到端训练过程中。(伪)

GumbleMax

 为了让网络的输出类别真正的随机,我们需要先将对 a r g m a x argmax argmax进行替换,既然网络输出随机的就不可导的话,我们就利用重参数技巧将这一随机性放到另一个随机变量上,也就得到了GumbleMax,公式如下:
x = a r g m a x ( l o g ( x ) + G ) , \bold{x}=argmax(log(\bold{x})+\bold{G}), x=argmax(log(x)+G),
其中 x , G \bold{x},\bold{G} x,G分别是网络输出的概率向量、符合Gumble分布的噪声向量, G i = − l o g ( − l o g ( U i ) ) , U i U ( 0 , 1 ) G_i=-log(-log(U_i)),U_i~U(0,1) Gi=log(log(Ui)),Ui U(0,1)。这一噪声向量的引入就会使得argmax的输出结果发生扰动,变成一个随机变量。同样是之前的例子, l o g ( p 1 ) + G log(p_1)+\bold{G} log(p1)+G就有可能变为 [ 0.5 , 0.6 , 0.5 ] [0.5,0.6,0.5] [0.5,0.6,0.5]而使得最终输出类别为第1类,而 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)服从这一随机变量服从 x x x的离散分布列证明见附1
 通过引入GumbleMax,我们成功的为网络的类别输出引入了随机性。但可导性的问题并没有解决,因为这里仍然是存在了argmax。

GumbleSoftMax

 GumbleSoftMax对GumbleMax的解决也很简单,它又把argmax替换成为了softmax,得到如下计算:
x = s o f t m a x ( ( l o g ( x ) + G ) / τ ) , \bold{x}=softmax((log(\bold{x})+\bold{G})/\tau), x=softmax((log(x)+G)/τ),
其中 τ \tau τ为为温度参数,这一算式中通过对argmax的软化实现了可导操作。至此,也就完成了为了网络输出引入可导随机性的目标。

矛盾

 讨论至此,有个非常反直觉的考量,那就是相比于GumbleMax的硬输出onehot向量,GumbleSoftMax的输出似乎又变成了概率向量,我们想要得到的具体的类别输出,还要继续再取argmax也就是 a r g m a x ( s o f t m a x ( ( l o g ( x ) + G ) ) / τ ) argmax(softmax((log(\bold{x})+\bold{G}))/\tau) argmax(softmax((log(x)+G))/τ)。那么这不是仍然不可导,仍然返回了GumbleMax的窘境?因此这里依据个人理解要做出以下的澄清:

  1. 确实不可导,如果我们希望从GumbleSoftMax输出一个类别值,那么就必然引入argmax,也就必然不可导。而在实际过程中,我们则是回避了对argmax求导的问题,直接对 s o f t m a x ( ( l o g ( x ) + G ) ) / τ softmax((log(\bold{x})+\bold{G}))/\tau softmax((log(x)+G))/τ进行求导,具体可以参见pytorch中Gumblesoftmax的实现2
  2. 既然如此,那为什么不照猫画虎在使用Gumblemax的时候就忽略argmax的存在,直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)求导?这是因为 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)本身才是我们想要求导的对象,而因为argmax本身不可导,所以引入了softmax来替代,也即我们相对 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]求导,迫不得已对 [ 0.8 , 0.1 , 0.1 ] [0.8,0.1,0.1] [0.8,0.1,0.1]求导,算是某种程度上的导数近似。而在1中的argmax本身也不是我们求导的对象,只是由于这一近似带来的补偿。而更进一步的,假设我们直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)进行求导,那么这一近似带来的误差只会更大,也让随机噪声的引入失去了意义,等价于对 l o g ( x ) log(x) log(x)求导。这也就是为什么开头的可导加了,因为我们是在对softmax求导,而不是argmax。

总结

 整体而言,GumbleSoftmax通过引入了Gumble随机噪声使得输出的类别真正具有随机性,而将argmax软化为softmax则使得这一随机过程可导。

参考文献


  1. Gumbel-Softmax Trick和Gumbel分布 ↩︎ ↩︎

  2. 请问用Gumbel-softmax的时候,怎么让softmax输出的概率分布转化成one-hot向量? ↩︎

这篇关于GumbleSoftmax感性理解--可导式输出随机类别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/486249

相关文章

GO语言zap日志库理解和使用方法示例

《GO语言zap日志库理解和使用方法示例》Zap是一个高性能、结构化日志库,专为Go语言设计,它由Uber开源,并且在Go社区中非常受欢迎,:本文主要介绍GO语言zap日志库理解和使用方法的相关资... 目录1. zap日志库介绍2.安装zap库3.配置日志记录器3.1 Logger3.2 Sugared

深入理解Redis线程模型的原理及使用

《深入理解Redis线程模型的原理及使用》Redis的线程模型整体还是多线程的,只是后台执行指令的核心线程是单线程的,整个线程模型可以理解为还是以单线程为主,基于这种单线程为主的线程模型,不同客户端的... 目录1 Redis是单线程www.chinasem.cn还是多线程2 Redis如何保证指令原子性2.

深入理解MySQL流模式

《深入理解MySQL流模式》MySQL的Binlog流模式是一种实时读取二进制日志的技术,允许下游系统几乎无延迟地获取数据库变更事件,适用于需要极低延迟复制的场景,感兴趣的可以了解一下... 目录核心概念一句话总结1. 背景知识:什么是 Binlog?2. 传统方式 vs. 流模式传统文件方式 (非流式)流

深入理解Go之==的使用

《深入理解Go之==的使用》本文主要介绍了深入理解Go之==的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录概述类型基本类型复合类型引用类型接口类型使用type定义的类型不可比较性谈谈map总结概述相信==判等操作,大

深入理解Mysql OnlineDDL的算法

《深入理解MysqlOnlineDDL的算法》本文主要介绍了讲解MysqlOnlineDDL的算法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小... 目录一、Online DDL 是什么?二、Online DDL 的三种主要算法2.1COPY(复制法)

从基础到高级详解Python数值格式化输出的完全指南

《从基础到高级详解Python数值格式化输出的完全指南》在数据分析、金融计算和科学报告领域,数值格式化是提升可读性和专业性的关键技术,本文将深入解析Python中数值格式化输出的相关方法,感兴趣的小伙... 目录引言:数值格式化的核心价值一、基础格式化方法1.1 三种核心格式化方式对比1.2 基础格式化示例

java -jar example.jar 产生的日志输出到指定文件的方法

《java-jarexample.jar产生的日志输出到指定文件的方法》这篇文章给大家介绍java-jarexample.jar产生的日志输出到指定文件的方法,本文给大家介绍的非常详细,对大家的... 目录怎么让 Java -jar example.jar 产生的日志输出到指定文件一、方法1:使用重定向1、

深入理解go中interface机制

《深入理解go中interface机制》本文主要介绍了深入理解go中interface机制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录前言interface使用类型判断总结前言go的interface是一组method的集合,不

Spring Boot集成/输出/日志级别控制/持久化开发实践

《SpringBoot集成/输出/日志级别控制/持久化开发实践》SpringBoot默认集成Logback,支持灵活日志级别配置(INFO/DEBUG等),输出包含时间戳、级别、类名等信息,并可通过... 目录一、日志概述1.1、Spring Boot日志简介1.2、日志框架与默认配置1.3、日志的核心作用

Java Spring的依赖注入理解及@Autowired用法示例详解

《JavaSpring的依赖注入理解及@Autowired用法示例详解》文章介绍了Spring依赖注入(DI)的概念、三种实现方式(构造器、Setter、字段注入),区分了@Autowired(注入... 目录一、什么是依赖注入(DI)?1. 定义2. 举个例子二、依赖注入的几种方式1. 构造器注入(Con