python中pytorch的广播机制——Broadcasting

2023-10-11 02:36

本文主要是介绍python中pytorch的广播机制——Broadcasting,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

广播机制

numpy 在算术运算期间采用“广播”来处理具有不同形状的 array ,即将较小的阵列在较大的阵列上“广播”,以便它们具有兼容的形状。Broadcasting是一种没有copy数据的expand

  • 不过两个维度不相同,在前面插入维度1
  • 扩张维度1到相同的维度

例如:Feature maps:[4,32,14,14]
Bias:[32,1,1]=>[1,32,1,1]=>[4,32,14,14]

A:[32,1,1]=>[1,32,1,1]=>[4,32,14,14]
B:[4,32,14,14]
这里就可以进行相同维度的相加

image


比如说一个[4,1]+[1,2]
那么这个[4,1]可以再复制列变为[4,2]
[1,2]可以再复制4行变为[4,2]

首先用1将那个小的维度的tensor扩展成大的维度相同的维度,然后将1扩张成两者的相同维度,如果有两个维度不相同,并且都不是1的话,则不能broadcasting

 

广播规则

当对两个 array 进行操作时,numpy 会逐元素比较它们的形状。从尾(即最右边)维度开始,然后向左逐渐比较。只有当两个维度 1)相等 or 2)其中一个维度是1 时,这两个维度才会被认为是兼容。

如果不满足这些条件,则会抛出 ValueError:operands could not be broadcast together 异常,表明 array 的形状不兼容。最终结果 array 的每个维度尽可能不为 1 ,是两个操作数各个维度中较大的值 。

例如,有一个 256x256x3 的 RGB 值图片 array ,需要将图像中的每种颜色缩放不同的值,此时可以将图像乘以具有 3 个值的一维 array 。根据广播规则排列这两个 array 的尾维度大小,是兼容的:

 图片(3d array): 256 x 256 x 3
缩放(1d array):             3
结果(3d array): 256 x 256 x 3

当比较的任一维度是 1 时,使用另一个,也就是说,大小为 1 的维度被拉伸或“复制”以匹配另一个维度。
在以下示例中,A 和 B 数组都有长度为 1 的维度,在广播操作期间扩展为更大的大小:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
result (4d array):  8 x 7 x 6 x 5

以二维为例,更加方便的解释“广播”:
已知 a.shape 是(5,1),b.shape 是(1,6),c.shape 是(6,),d.shape 是(), d 是一个标量, a, b, c,和 d 都可以“广播”到维度 (5,6);

a “广播”为一个 (5,6) array ,其中 a[:,0] 被“广播”到其他列,
b “广播”为一个 (5,6) array ,其中 b[0,:] 被广播到其他行,
c 类似于 (1,6) array ,其中 c[:] 广播到每一行,
d 是标量,“广播”为 (5,6) array ,其中每个元素都一样,重复d值。
 

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3 # 倒数第二个维度不兼容
>>> a = np.array([[ 0.0,  0.0,  0.0],
...               [10.0, 10.0, 10.0],
...               [20.0, 20.0, 20.0],
...               [30.0, 30.0, 30.0]])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a + b
array([[  1.,   2.,   3.],[11.,  12.,  13.],[21.,  22.,  23.],[31.,  32.,  33.]])
>>> b = np.array([1.0, 2.0, 3.0, 4.0])
>>> a + b
Traceback (most recent call last):
ValueError: operands could not be broadcast together with shapes (4,3) (4,)

 

 

在某些情况下,广播会拉伸两个 array 以形成一个大于任何一个初始 array 的结果 array 。 

>>> a = np.array([0.0, 10.0, 20.0, 30.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a[:, np.newaxis] + b
array([[ 1.,   2.,   3.],[11.,  12.,  13.],[21.,  22.,  23.],[31.,  32.,  33.]])

 

newaxis 运算符将新轴插入到 a 中,使其成为二维 4x1 array 。将 4x1 array 与形状为 (3,) 的 b 组合,产生一个 4x3 array 。 

这里注意要都从右端进行匹配:
A:[                     ]
B:          [           ]
就是这样补充
我们看个例子吧:

a=torch.randn(2,3,4)
b=torch.randn(2,3)
a+b
#The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

image


但是这样是可以的

image


也就是(2,3,4)+(2,3)是不可以的,(2,3,4)+(3,4)是可以的,因为他们是右看齐的。

Situation 1:
▪ [4, 32, 14, 14]
▪ [1, 32, 1, 1] => [4, 32, 14, 14]

Situation 2
▪ [4, 32, 14, 14]
▪ [14, 14] => [1, 1, 14, 14] => [4, 32, 14, 14]

Situation 3
▪ [4, 32, 14, 14]
▪ [2, 32, 14, 14]
▪ Dim 0 has dim, can NOT insert and expand to same
▪ Dim 0 has distinct dim, NOT size 1
▪ NOT broadcasting-able

Situation 4
▪ [4, 32, 14, 14]
▪ [4, 32, 14]
这样是不行的,因为我们要右看齐,match from
last dim

Situation 5
▪ [4, 3, 32, 32]
▪ + [32, 32]
▪ + [3, 1, 1]
▪ + [1, 1, 1, 1]
这都是可以的

这篇关于python中pytorch的广播机制——Broadcasting的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/184911

相关文章

利用Python实现可回滚方案的示例代码

《利用Python实现可回滚方案的示例代码》很多项目翻车不是因为不会做,而是走错了方向却没法回头,技术选型失败的风险我们都清楚,但真正能提前规划“回滚方案”的人不多,本文从实际项目出发,教你如何用Py... 目录描述题解答案(核心思路)题解代码分析第一步:抽象缓存接口第二步:实现两个版本第三步:根据 Fea

Python中CSV文件处理全攻略

《Python中CSV文件处理全攻略》在数据处理和存储领域,CSV格式凭借其简单高效的特性,成为了电子表格和数据库中常用的文件格式,Python的csv模块为操作CSV文件提供了强大的支持,本文将深入... 目录一、CSV 格式简介二、csv模块核心内容(一)模块函数(二)模块类(三)模块常量(四)模块异常

Python报错ModuleNotFoundError的10种解决方案

《Python报错ModuleNotFoundError的10种解决方案》在Python开发中,ModuleNotFoundError是最常见的运行时错误之一,通常由模块路径配置错误、依赖缺失或命名冲... 目录一、常见错误场景与原因分析二、10种解决方案与代码示例1. 检查并安装缺失模块2. 动态添加模块

python利用backoff实现异常自动重试详解

《python利用backoff实现异常自动重试详解》backoff是一个用于实现重试机制的Python库,通过指数退避或其他策略自动重试失败的操作,下面小编就来和大家详细讲讲如何利用backoff实... 目录1. backoff 库简介2. on_exception 装饰器的原理2.1 核心逻辑2.2

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

Python实现获取带合并单元格的表格数据

《Python实现获取带合并单元格的表格数据》由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,所以本文我们就来聊聊如何使用Python实现获取带合并单元格的表格数据吧... 由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,现将将封装成类,并通过调用list_exc

Python logging模块使用示例详解

《Pythonlogging模块使用示例详解》Python的logging模块是一个灵活且强大的日志记录工具,广泛应用于应用程序的调试、运行监控和问题排查,下面给大家介绍Pythonlogging模... 目录一、为什么使用 logging 模块?二、核心组件三、日志级别四、基本使用步骤五、快速配置(bas

Python日期和时间完全指南与实战

《Python日期和时间完全指南与实战》在软件开发领域,‌日期时间处理‌是贯穿系统设计全生命周期的重要基础能力,本文将深入解析Python日期时间的‌七大核心模块‌,通过‌企业级代码案例‌揭示最佳实践... 目录一、背景与核心价值二、核心模块详解与实战2.1 datetime模块四剑客2.2 时区处理黄金法

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

Python文件操作与IO流的使用方式

《Python文件操作与IO流的使用方式》:本文主要介绍Python文件操作与IO流的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、python文件操作基础1. 打开文件2. 关闭文件二、文件读写操作1.www.chinasem.cn 读取文件2. 写