pytorch中backward函数的参数gradient作用的数学过程

2023-10-28 15:20

本文主要是介绍pytorch中backward函数的参数gradient作用的数学过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch中backward函数的参数gradient作用的数学过程

    • 问题描述
    • 实例分析

本机器学习小白最近在学习pytorch,在学习.backward()函数的过程中一直不能理解参数gradient的作用,感觉相关资料中对它的解释过于简单,几乎忽略了相关数学过程。这里分享一下我的理解,希望能对有需要的同学有些帮助。

.backward()函数是pytorch用来实现反向传播计算的关键,不了解反向传播或者.backward()函数的可以看看下面的文章:

  • pytorch中backward()函数详解
  • pytorch的计算图
  • PyTorch的自动求导机制详细解析,PyTorch的核心魔法
  • 反向传播——通俗易懂
  • 深度学习——以图读懂反向传播

问题描述

官方文档对.backward()的说明是这样的(看英文头痛的同学可以忽略解释内容):
在这里插入图片描述
gradient的解释是这样的:
在这里插入图片描述
可以粗浅的认为backward()就是在根据计算图计算tensor的“梯度”,但是这个“梯度”只有在tensor是标量(scalar)时才是真正意义上的梯度。

gradient在实际使用的过程中,是一个与使用.backward()的tensor维度一致的tensor。官方文档中只提到如果使用.backward()的tensor是一个标量,就可以省略gradient参数,而如果这个tensor是向量,则必须指明gradient参数,但这个参数对计算的作用是什么,文档中说的很模糊(好像就没提到)。

看了一些文章之后发现,“如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。”(来源:pytorch中backward()函数详解)
这里提到的grad_tensors参数就是现在的gradient参数。

所以本质上,gradient参数在向量与向量的求导中起作用,而backward()在这种情况下求得的各个元素的梯度实际上并不是Jacobian,而是Jacobian与gradient的乘积。

以下结合一些例子说明backward()函数的计算结果。

实例分析

来源:PyTorch的自动求导机制详细解析,PyTorch的核心魔法

import torchx = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0, 1.0, 7.0], requires_grad = True)z = x * yz.backward(torch.FloatTensor([1.0, 1.0, 1.0]))

运行完之后查看z分别关于x和y的梯度可以发现:

>>>x.grad.data
tensor([5., 1., 7.])
>>>y.grad.data
tensor([0., 2., 8.])

实际上上述代码的计算结果可以这么理解:
x = ( x 1 x 2 x 3 ) = ( 0.0 2.0 8.0 ) y = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) x = \begin{pmatrix} x_1 & x_2 & x_3\end{pmatrix} = \begin{pmatrix} 0.0 & 2.0 & 8.0\end{pmatrix}\\ y = \begin{pmatrix} y_1 & y_2 & y_3\end{pmatrix} = \begin{pmatrix} 5.0 & 1.0 & 7.0\end{pmatrix} x=(x1x2x3)=(0.02.08.0)y=(y1y2y3)=(5.01.07.0) z z z向量则是 x x x y y y每项相乘得到的向量(不是点乘也不是叉乘)
z = ( x 1 y 1 x 2 y 2 x 3 y 3 ) z = \begin{pmatrix} x_1y_1 & x_2y_2 & x_3y_3\end{pmatrix} z=(x1y1x2y2x3y3)那么 z z z关于 x x x的Jacobian就是
J = ( ∂ z ∂ x 1 ∂ z ∂ x 2 ∂ z ∂ x 3 ) = ( ∂ z 1 ∂ x 1 ∂ z 1 ∂ x 2 ∂ z 1 ∂ x 3 ∂ z 2 ∂ x 1 ∂ z 2 ∂ x 2 ∂ z 2 ∂ x 3 ∂ z 3 ∂ x 1 ∂ z 3 ∂ x 2 ∂ z 3 ∂ x 3 ) = ( y 1 0 0 0 y 2 0 0 0 y 3 ) J=\begin{pmatrix} \frac{\partial z}{\partial x_1} & \frac{\partial z}{\partial x_2} & \frac{\partial z}{\partial x_3} \end{pmatrix} = \begin{pmatrix} \frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2} & \frac{\partial z_1}{\partial x_3} \\ \frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2} & \frac{\partial z_2}{\partial x_3} \\ \frac{\partial z_3}{\partial x_1} & \frac{\partial z_3}{\partial x_2} & \frac{\partial z_3}{\partial x_3} \end{pmatrix} =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} J=(x1zx2zx3z)=x1z1x1z2x1z3x2z1x2z2x2z3x3z1x3z2x3z3=y1000y2000y3而我们引入的gradient参数是一个向量(这里用 v v v表示)
v = ( 1.0 1.0 1.0 ) T v = \begin{pmatrix} 1.0 & 1.0 &1.0 \end{pmatrix}^{T} v=(1.01.01.0)T然后将 J J J v v v相乘,我们就得到了我们看到的x.grad.data的结果:
J v = ( y 1 0 0 0 y 2 0 0 0 y 3 ) ( 1.0 1.0 1.0 ) = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) Jv =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} \begin{pmatrix} 1.0 \\ 1.0 \\ 1.0 \end{pmatrix} = \begin{pmatrix} y_1 \\ y_2 \\ y_3 \end{pmatrix} = \begin{pmatrix} 5.0 \\ 1.0 \\ 7.0 \end{pmatrix} Jv=y1000y2000y31.01.01.0=y1y2y3=5.01.07.0y.grad.data的结果这里就不演示了,因为这个例子里 x x x y y y具有对称性,只需要把结果的 y y y换成 x x x就得到了y.grad.data的结果,所以我们可以看到x.grad.data和y.grad.data在数值上就是 y y y x x x的值。

这篇关于pytorch中backward函数的参数gradient作用的数学过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/294116

相关文章

Pandas中统计汇总可视化函数plot()的使用

《Pandas中统计汇总可视化函数plot()的使用》Pandas提供了许多强大的数据处理和分析功能,其中plot()函数就是其可视化功能的一个重要组成部分,本文主要介绍了Pandas中统计汇总可视化... 目录一、plot()函数简介二、plot()函数的基本用法三、plot()函数的参数详解四、使用pl

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

SpringBoot请求参数接收控制指南分享

《SpringBoot请求参数接收控制指南分享》:本文主要介绍SpringBoot请求参数接收控制指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring Boot 请求参数接收控制指南1. 概述2. 有注解时参数接收方式对比3. 无注解时接收参数默认位置

Python的time模块一些常用功能(各种与时间相关的函数)

《Python的time模块一些常用功能(各种与时间相关的函数)》Python的time模块提供了各种与时间相关的函数,包括获取当前时间、处理时间间隔、执行时间测量等,:本文主要介绍Python的... 目录1. 获取当前时间2. 时间格式化3. 延时执行4. 时间戳运算5. 计算代码执行时间6. 转换为指

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@

shell编程之函数与数组的使用详解

《shell编程之函数与数组的使用详解》:本文主要介绍shell编程之函数与数组的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录shell函数函数的用法俩个数求和系统资源监控并报警函数函数变量的作用范围函数的参数递归函数shell数组获取数组的长度读取某下的