Pytorch常用的函数(九)torch.gather()用法

2024-05-08 20:20

本文主要是介绍Pytorch常用的函数(九)torch.gather()用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.pyimport torchdef random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoreif __name__ == '__main__':x = torch.arange(64).reshape(1, 16, 4)random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4,  4,  4,  4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])

这篇关于Pytorch常用的函数(九)torch.gather()用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/971338

相关文章

MyBatis常用XML语法详解

《MyBatis常用XML语法详解》文章介绍了MyBatis常用XML语法,包括结果映射、查询语句、插入语句、更新语句、删除语句、动态SQL标签以及ehcache.xml文件的使用,感兴趣的朋友跟随小... 目录1、定义结果映射2、查询语句3、插入语句4、更新语句5、删除语句6、动态 SQL 标签7、ehc

JDK21对虚拟线程的几种用法实践指南

《JDK21对虚拟线程的几种用法实践指南》虚拟线程是Java中的一种轻量级线程,由JVM管理,特别适合于I/O密集型任务,:本文主要介绍JDK21对虚拟线程的几种用法,文中通过代码介绍的非常详细,... 目录一、参考官方文档二、什么是虚拟线程三、几种用法1、Thread.ofVirtual().start(

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Java8 Collectors.toMap() 的两种用法

《Java8Collectors.toMap()的两种用法》Collectors.toMap():JDK8中提供,用于将Stream流转换为Map,本文给大家介绍Java8Collector... 目录一、简单介绍用法1:根据某一属性,对对象的实例或属性做映射用法2:根据某一属性,对对象集合进行去重二、Du

Python打包成exe常用的四种方法小结

《Python打包成exe常用的四种方法小结》本文主要介绍了Python打包成exe常用的四种方法,包括PyInstaller、cx_Freeze、Py2exe、Nuitka,文中通过示例代码介绍的非... 目录一.PyInstaller11.安装:2. PyInstaller常用参数下面是pyinstal

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Python 常用数据类型详解之字符串、列表、字典操作方法

《Python常用数据类型详解之字符串、列表、字典操作方法》在Python中,字符串、列表和字典是最常用的数据类型,它们在数据处理、程序设计和算法实现中扮演着重要角色,接下来通过本文给大家介绍这三种... 目录一、字符串(String)(一)创建字符串(二)字符串操作1. 字符串连接2. 字符串重复3. 字

Python中的sort方法、sorted函数与lambda表达式及用法详解

《Python中的sort方法、sorted函数与lambda表达式及用法详解》文章对比了Python中list.sort()与sorted()函数的区别,指出sort()原地排序返回None,sor... 目录1. sort()方法1.1 sort()方法1.2 基本语法和参数A. reverse参数B.

vue监听属性watch的用法及使用场景详解

《vue监听属性watch的用法及使用场景详解》watch是vue中常用的监听器,它主要用于侦听数据的变化,在数据发生变化的时候执行一些操作,:本文主要介绍vue监听属性watch的用法及使用场景... 目录1. 监听属性 watch2. 常规用法3. 监听对象和route变化4. 使用场景附Watch 的