pytorch中backward函数的参数gradient作用的数学过程

2023-10-28 15:20

本文主要是介绍pytorch中backward函数的参数gradient作用的数学过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch中backward函数的参数gradient作用的数学过程

    • 问题描述
    • 实例分析

本机器学习小白最近在学习pytorch,在学习.backward()函数的过程中一直不能理解参数gradient的作用,感觉相关资料中对它的解释过于简单,几乎忽略了相关数学过程。这里分享一下我的理解,希望能对有需要的同学有些帮助。

.backward()函数是pytorch用来实现反向传播计算的关键,不了解反向传播或者.backward()函数的可以看看下面的文章:

  • pytorch中backward()函数详解
  • pytorch的计算图
  • PyTorch的自动求导机制详细解析,PyTorch的核心魔法
  • 反向传播——通俗易懂
  • 深度学习——以图读懂反向传播

问题描述

官方文档对.backward()的说明是这样的(看英文头痛的同学可以忽略解释内容):
在这里插入图片描述
gradient的解释是这样的:
在这里插入图片描述
可以粗浅的认为backward()就是在根据计算图计算tensor的“梯度”,但是这个“梯度”只有在tensor是标量(scalar)时才是真正意义上的梯度。

gradient在实际使用的过程中,是一个与使用.backward()的tensor维度一致的tensor。官方文档中只提到如果使用.backward()的tensor是一个标量,就可以省略gradient参数,而如果这个tensor是向量,则必须指明gradient参数,但这个参数对计算的作用是什么,文档中说的很模糊(好像就没提到)。

看了一些文章之后发现,“如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。”(来源:pytorch中backward()函数详解)
这里提到的grad_tensors参数就是现在的gradient参数。

所以本质上,gradient参数在向量与向量的求导中起作用,而backward()在这种情况下求得的各个元素的梯度实际上并不是Jacobian,而是Jacobian与gradient的乘积。

以下结合一些例子说明backward()函数的计算结果。

实例分析

来源:PyTorch的自动求导机制详细解析,PyTorch的核心魔法

import torchx = torch.tensor([0.0, 2.0, 8.0], requires_grad = True)y = torch.tensor([5.0, 1.0, 7.0], requires_grad = True)z = x * yz.backward(torch.FloatTensor([1.0, 1.0, 1.0]))

运行完之后查看z分别关于x和y的梯度可以发现:

>>>x.grad.data
tensor([5., 1., 7.])
>>>y.grad.data
tensor([0., 2., 8.])

实际上上述代码的计算结果可以这么理解:
x = ( x 1 x 2 x 3 ) = ( 0.0 2.0 8.0 ) y = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) x = \begin{pmatrix} x_1 & x_2 & x_3\end{pmatrix} = \begin{pmatrix} 0.0 & 2.0 & 8.0\end{pmatrix}\\ y = \begin{pmatrix} y_1 & y_2 & y_3\end{pmatrix} = \begin{pmatrix} 5.0 & 1.0 & 7.0\end{pmatrix} x=(x1x2x3)=(0.02.08.0)y=(y1y2y3)=(5.01.07.0) z z z向量则是 x x x y y y每项相乘得到的向量(不是点乘也不是叉乘)
z = ( x 1 y 1 x 2 y 2 x 3 y 3 ) z = \begin{pmatrix} x_1y_1 & x_2y_2 & x_3y_3\end{pmatrix} z=(x1y1x2y2x3y3)那么 z z z关于 x x x的Jacobian就是
J = ( ∂ z ∂ x 1 ∂ z ∂ x 2 ∂ z ∂ x 3 ) = ( ∂ z 1 ∂ x 1 ∂ z 1 ∂ x 2 ∂ z 1 ∂ x 3 ∂ z 2 ∂ x 1 ∂ z 2 ∂ x 2 ∂ z 2 ∂ x 3 ∂ z 3 ∂ x 1 ∂ z 3 ∂ x 2 ∂ z 3 ∂ x 3 ) = ( y 1 0 0 0 y 2 0 0 0 y 3 ) J=\begin{pmatrix} \frac{\partial z}{\partial x_1} & \frac{\partial z}{\partial x_2} & \frac{\partial z}{\partial x_3} \end{pmatrix} = \begin{pmatrix} \frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2} & \frac{\partial z_1}{\partial x_3} \\ \frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2} & \frac{\partial z_2}{\partial x_3} \\ \frac{\partial z_3}{\partial x_1} & \frac{\partial z_3}{\partial x_2} & \frac{\partial z_3}{\partial x_3} \end{pmatrix} =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} J=(x1zx2zx3z)=x1z1x1z2x1z3x2z1x2z2x2z3x3z1x3z2x3z3=y1000y2000y3而我们引入的gradient参数是一个向量(这里用 v v v表示)
v = ( 1.0 1.0 1.0 ) T v = \begin{pmatrix} 1.0 & 1.0 &1.0 \end{pmatrix}^{T} v=(1.01.01.0)T然后将 J J J v v v相乘,我们就得到了我们看到的x.grad.data的结果:
J v = ( y 1 0 0 0 y 2 0 0 0 y 3 ) ( 1.0 1.0 1.0 ) = ( y 1 y 2 y 3 ) = ( 5.0 1.0 7.0 ) Jv =\begin{pmatrix} y_1 & 0 & 0\\ 0 & y_2 & 0\\ 0 & 0 & y_3 \end{pmatrix} \begin{pmatrix} 1.0 \\ 1.0 \\ 1.0 \end{pmatrix} = \begin{pmatrix} y_1 \\ y_2 \\ y_3 \end{pmatrix} = \begin{pmatrix} 5.0 \\ 1.0 \\ 7.0 \end{pmatrix} Jv=y1000y2000y31.01.01.0=y1y2y3=5.01.07.0y.grad.data的结果这里就不演示了,因为这个例子里 x x x y y y具有对称性,只需要把结果的 y y y换成 x x x就得到了y.grad.data的结果,所以我们可以看到x.grad.data和y.grad.data在数值上就是 y y y x x x的值。

这篇关于pytorch中backward函数的参数gradient作用的数学过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/294116

相关文章

oracle 11g导入\导出(expdp impdp)之导入过程

《oracle11g导入导出(expdpimpdp)之导入过程》导出需使用SEC.DMP格式,无分号;建立expdir目录(E:/exp)并确保存在;导入在cmd下执行,需sys用户权限;若需修... 目录准备文件导入(impdp)1、建立directory2、导入语句 3、更改密码总结上一个环节,我们讲了

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法

ShardingProxy读写分离之原理、配置与实践过程

《ShardingProxy读写分离之原理、配置与实践过程》ShardingProxy是ApacheShardingSphere的数据库中间件,通过三层架构实现读写分离,解决高并发场景下数据库性能瓶... 目录一、ShardingProxy技术定位与读写分离核心价值1.1 技术定位1.2 读写分离核心价值二

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

Java Kafka消费者实现过程

《JavaKafka消费者实现过程》Kafka消费者通过KafkaConsumer类实现,核心机制包括偏移量管理、消费者组协调、批量拉取消息及多线程处理,手动提交offset确保数据可靠性,自动提交... 目录基础KafkaConsumer类分析关键代码与核心算法2.1 订阅与分区分配2.2 拉取消息2.3

GO语言中函数命名返回值的使用

《GO语言中函数命名返回值的使用》在Go语言中,函数可以为其返回值指定名称,这被称为命名返回值或命名返回参数,这种特性可以使代码更清晰,特别是在返回多个值时,感兴趣的可以了解一下... 目录基本语法函数命名返回特点代码示例命名特点基本语法func functionName(parameters) (nam

Python Counter 函数使用案例

《PythonCounter函数使用案例》Counter是collections模块中的一个类,专门用于对可迭代对象中的元素进行计数,接下来通过本文给大家介绍PythonCounter函数使用案例... 目录一、Counter函数概述二、基本使用案例(一)列表元素计数(二)字符串字符计数(三)元组计数三、C

python中的显式声明类型参数使用方式

《python中的显式声明类型参数使用方式》文章探讨了Python3.10+版本中类型注解的使用,指出FastAPI官方示例强调显式声明参数类型,通过|操作符替代Union/Optional,可提升代... 目录背景python函数显式声明的类型汇总基本类型集合类型Optional and Union(py