内涵:算法学习之gumbel softmax

2024-06-05 23:32

本文主要是介绍内涵:算法学习之gumbel softmax,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. gumbel_softmax有什么用呢?

假设如下场景:
模型训练过程中, 网络的输出为p = [0.1, 0.7, 0.2], 三个数值分别为"向左", “向上”, "向右"的概率。 我们的决策可能是y = argmax§, 也即选择"向上"这条决策。
但是,这样做会有两个问题:

  1. argmax()函数是不可导的。这样网络就无法通过反向传播进行学习。
  2. argmax()的选择不具有随机性。同样的输出p选择100次,每次的结果都为"向上"。而按照概率为0.7的含义,100次应该有70次左右的决策结果是选择"向上".

而gumbel_softmax的作用就是解决上述这两个子问题.。

2.argmax(x)是什么?为什么不可导?

为了更直观,这里使用两维的vector
y = argmax(x); x = (x1, x2)

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch# https://www.itdaan.com/blog/2014/04/04/32dfc1abfd5a635469b7762c516a37b3.html
class Arrow3D(FancyArrowPatch):def __init__(self, xs, ys, zs, *args, **kwargs):FancyArrowPatchcyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)self._verts3d = xs, ys, zsdef draw(self, renderer):xs3d, ys3d, zs3d = self._verts3dxs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))FancyArrowPatch.draw(self, renderer)# 绘制argmax()的第一段
xs = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
ys = [0.5, 0.4, 0.3, 0.2 , 0.1, 0.0]
zs = [0, 0, 0, 0, 0, 0]
fig = plt.figure()
ax = axisartist.Subplot(fig, 111)
ax = fig.add_axes((0.1,0.1,0.8,0.8), projection='3d')
ax.plot3D(xs, ys, zs, c='red', marker='o')
ys = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
xs = [0.5, 0.4, 0.3, 0.2 , 0.1, 0.0]
zs = [1, 1, 1, 1, 1, 1]
ax.plot3D(xs, ys, zs, c='blue', marker='o')
plt.xlim(0, 1)
plt.ylim(0, 1)
ax.view_init(azim=30, elev=30)
plt.show()

在这里插入图片描述
多元函数可微分的充分条件是函数连续且具有偏导数. 从argmax的三维图可以看出, argmax(x), 首先在x1 = x2处不连续,因此在该点处必定是不可导的. 在红线处, 保持x1不变, 求 y相对于x2的偏微分,发现是不存在的.因为x1不变的情况下,x2也是无法有一个微小的变动. 故, argmax()函数不可微分.

3. 引入随机性:gumbel分布

为了在y=argmax§中引入随机性, 将其修改为y = argmax(log§ + G).G称之为gumbel分布, 它的数学表达式为G=-log(-log( ξ \xi ξ)))。引入该分布的作用是引入了随机性,且该随机性保证了该分布输出i的概率等于pi。下面是科学空间上的证明,比较容易理解。
在这里插入图片描述

4. 解决不可导:gumbel_softmax

  1. 解决不可导的方法可以用gumbel_softmax来处理。也即forward阶段,使用argmax操作,暂时不用管后面反向操作;但在反向阶段则使用gumbel_softmax来做bp计算,可以通过看pytorch中相关代码块有一个很清晰的认知。
def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:。。。gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log())  # ~Gumbel(0,1)gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)y_soft = gumbels.softmax(dim)if hard:# Straight through.index = y_soft.max(dim, keepdim=True)[1]y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)ret = y_hard - y_soft.detach() + y_softelse:# Reparametrization trick.ret = y_softreturn ret
  1. 这样做是没有问题的,但是前向的y_hard,与y_softmax我们还是要尽可能缩小它们之间的“误差”,因此gumbel_softmax中引入了温度t, t越小,softmax就越接近One-hot。为了训练稳定性,一般t会取一个比较大的数字,然后逐步缩小。

这篇关于内涵:算法学习之gumbel softmax的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1034482

相关文章

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

Java时间轮调度算法的代码实现

《Java时间轮调度算法的代码实现》时间轮是一种高效的定时调度算法,主要用于管理延时任务或周期性任务,它通过一个环形数组(时间轮)和指针来实现,将大量定时任务分摊到固定的时间槽中,极大地降低了时间复杂... 目录1、简述2、时间轮的原理3. 时间轮的实现步骤3.1 定义时间槽3.2 定义时间轮3.3 使用时

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

如何通过Golang的container/list实现LRU缓存算法

《如何通过Golang的container/list实现LRU缓存算法》文章介绍了Go语言中container/list包实现的双向链表,并探讨了如何使用链表实现LRU缓存,LRU缓存通过维护一个双向... 目录力扣:146. LRU 缓存主要结构 List 和 Element常用方法1. 初始化链表2.

golang字符串匹配算法解读

《golang字符串匹配算法解读》文章介绍了字符串匹配算法的原理,特别是Knuth-Morris-Pratt(KMP)算法,该算法通过构建模式串的前缀表来减少匹配时的不必要的字符比较,从而提高效率,在... 目录简介KMP实现代码总结简介字符串匹配算法主要用于在一个较长的文本串中查找一个较短的字符串(称为

通俗易懂的Java常见限流算法具体实现

《通俗易懂的Java常见限流算法具体实现》:本文主要介绍Java常见限流算法具体实现的相关资料,包括漏桶算法、令牌桶算法、Nginx限流和Redis+Lua限流的实现原理和具体步骤,并比较了它们的... 目录一、漏桶算法1.漏桶算法的思想和原理2.具体实现二、令牌桶算法1.令牌桶算法流程:2.具体实现2.1

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操