pytorch中backward函数的参数gradient作用的数学过程

2023-10-28 15:20

本文主要是介绍pytorch中backward函数的参数gradient作用的数学过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch中backward函数的参数gradient作用的数学过程

    • 问题描述
    • 实例分析

本机器学习小白最近在学习pytorch,在学习.backward()函数的过程中一直不能理解参数gradient的作用,感觉相关资料中对它的解释过于简单,几乎忽略了相关数学过程。这里分享一下我的理解,希望能对有需要的同学有些帮助。

.backward()函数是pytorch用来实现反向传播计算的关键,不了解反向传播或者.backward()函数的可以看看下面的文章:

  • pytorch中backward()函数详解
  • pytorch的计算图
  • PyTorch的自动求导机制详细解析,PyTorch的核心魔法
  • 反向传播——通俗易懂
  • 深度学习——以图读懂反向传播

问题描述

官方文档对.backward()的说明是这样的(看英文头痛的同学可以忽略解释内容):
在这里插入图片描述
gradient的解释是这样的:
在这里插入图片描述
可以粗浅的认为backward()就是在根据计算图计算tensor的“梯度”,但是这个“梯度”只有在tensor是标量(scalar)时才是真正意义上的梯度。

gradient在实际使用的过程中,是一个与使用.backward()的tensor维度一致的tensor。官方文档中只提到如果使用.backward()的tensor是一个标量,就可以省略gradient参数,而如果这个tensor是向量,则必须指明gradient参数,但这个参数对计算的作用是什么,文档中说的很模糊(好像就没提到)。

看了一些文章之后发现,“如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。”(来源:pytorch中backward()函数详解)
这里提到的grad_tensors参数就是现在的gradient参数。

所以本质上,gradient参数在向量与向量的求导中起作用,而backward()在这种情况下求得的各个元素的梯度实际上并不是Jacobian,而是Jacobian与gradient的乘积。

以下结合一些例子说明backward()函数的计算结果。

实例分析

来源:PyTorch的自动求导机制详细解析,PyTorch的核心魔法

import torchx = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0, 1.0, 7.0], requires_grad = True)z = x * yz.backward(torch.FloatTensor([1.0, 1.0, 1.0]))

运行完之后查看z分别关于x和y的梯度可以发现:

>>>x.grad.data
tensor([5., 1., 7.])
>>>y.grad.data
tensor([0., 2., 8.])

实际上上述代码的计算结果可以这么理解:
x = ( x 1 x 2 x 3 ) = ( 0.0 2.0 8.0 ) y = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) x = \begin{pmatrix} x_1 & x_2 & x_3\end{pmatrix} = \begin{pmatrix} 0.0 & 2.0 & 8.0\end{pmatrix}\\ y = \begin{pmatrix} y_1 & y_2 & y_3\end{pmatrix} = \begin{pmatrix} 5.0 & 1.0 & 7.0\end{pmatrix} x=(x1x2x3)=(0.02.08.0)y=(y1y2y3)=(5.01.07.0) z z z向量则是 x x x y y y每项相乘得到的向量(不是点乘也不是叉乘)
z = ( x 1 y 1 x 2 y 2 x 3 y 3 ) z = \begin{pmatrix} x_1y_1 & x_2y_2 & x_3y_3\end{pmatrix} z=(x1y1x2y2x3y3)那么 z z z关于 x x x的Jacobian就是
J = ( ∂ z ∂ x 1 ∂ z ∂ x 2 ∂ z ∂ x 3 ) = ( ∂ z 1 ∂ x 1 ∂ z 1 ∂ x 2 ∂ z 1 ∂ x 3 ∂ z 2 ∂ x 1 ∂ z 2 ∂ x 2 ∂ z 2 ∂ x 3 ∂ z 3 ∂ x 1 ∂ z 3 ∂ x 2 ∂ z 3 ∂ x 3 ) = ( y 1 0 0 0 y 2 0 0 0 y 3 ) J=\begin{pmatrix} \frac{\partial z}{\partial x_1} & \frac{\partial z}{\partial x_2} & \frac{\partial z}{\partial x_3} \end{pmatrix} = \begin{pmatrix} \frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2} & \frac{\partial z_1}{\partial x_3} \\ \frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2} & \frac{\partial z_2}{\partial x_3} \\ \frac{\partial z_3}{\partial x_1} & \frac{\partial z_3}{\partial x_2} & \frac{\partial z_3}{\partial x_3} \end{pmatrix} =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} J=(x1zx2zx3z)=x1z1x1z2x1z3x2z1x2z2x2z3x3z1x3z2x3z3=y1000y2000y3而我们引入的gradient参数是一个向量(这里用 v v v表示)
v = ( 1.0 1.0 1.0 ) T v = \begin{pmatrix} 1.0 & 1.0 &1.0 \end{pmatrix}^{T} v=(1.01.01.0)T然后将 J J J v v v相乘,我们就得到了我们看到的x.grad.data的结果:
J v = ( y 1 0 0 0 y 2 0 0 0 y 3 ) ( 1.0 1.0 1.0 ) = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) Jv =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} \begin{pmatrix} 1.0 \\ 1.0 \\ 1.0 \end{pmatrix} = \begin{pmatrix} y_1 \\ y_2 \\ y_3 \end{pmatrix} = \begin{pmatrix} 5.0 \\ 1.0 \\ 7.0 \end{pmatrix} Jv=y1000y2000y31.01.01.0=y1y2y3=5.01.07.0y.grad.data的结果这里就不演示了,因为这个例子里 x x x y y y具有对称性,只需要把结果的 y y y换成 x x x就得到了y.grad.data的结果,所以我们可以看到x.grad.data和y.grad.data在数值上就是 y y y x x x的值。

这篇关于pytorch中backward函数的参数gradient作用的数学过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/294116

相关文章

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Redis中Set结构使用过程与原理说明

《Redis中Set结构使用过程与原理说明》本文解析了RedisSet数据结构,涵盖其基本操作(如添加、查找)、集合运算(交并差)、底层实现(intset与hashtable自动切换机制)、典型应用场... 目录开篇:从购物车到Redis Set一、Redis Set的基本操作1.1 编程常用命令1.2 集

Linux下利用select实现串口数据读取过程

《Linux下利用select实现串口数据读取过程》文章介绍Linux中使用select、poll或epoll实现串口数据读取,通过I/O多路复用机制在数据到达时触发读取,避免持续轮询,示例代码展示设... 目录示例代码(使用select实现)代码解释总结在 linux 系统里,我们可以借助 select、

k8s中实现mysql主备过程详解

《k8s中实现mysql主备过程详解》文章讲解了在K8s中使用StatefulSet部署MySQL主备架构,包含NFS安装、storageClass配置、MySQL部署及同步检查步骤,确保主备数据一致... 目录一、k8s中实现mysql主备1.1 环境信息1.2 部署nfs-provisioner1.2.

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Python中的sort方法、sorted函数与lambda表达式及用法详解

《Python中的sort方法、sorted函数与lambda表达式及用法详解》文章对比了Python中list.sort()与sorted()函数的区别,指出sort()原地排序返回None,sor... 目录1. sort()方法1.1 sort()方法1.2 基本语法和参数A. reverse参数B.

Spring的基础事务注解@Transactional作用解读

《Spring的基础事务注解@Transactional作用解读》文章介绍了Spring框架中的事务管理,核心注解@Transactional用于声明事务,支持传播机制、隔离级别等配置,结合@Tran... 目录一、事务管理基础1.1 Spring事务的核心注解1.2 注解属性详解1.3 实现原理二、事务事

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型