torch.gather用法详解

2024-04-20 20:52
文章标签 详解 用法 torch gather

本文主要是介绍torch.gather用法详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.gather是PyTorch中的一个函数,用于从源张量中按照指定的索引张量来收集数据。

基本语法如下,

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • input:输入源张量
  • dim:要收集数据的维度
  • index:索引
  • sparse_grad:如果为True,则gather()在反向传播时会返回稀疏梯度
  • out:输出张量,形状与index相同

用法讲解

假设有以下输入张量x,

x = torch.tensor([[[ 1,  2],[ 3,  4]],[[ 5,  6],[ 7,  8]],[[ 9, 10],[11, 12]]
])

假设有以下索引index,

index = torch.tensor([[[0, 1],[1, 0]],[[1, 0],[0, 1]],[[0, 1],[1, 0]]
])

index的索引及里面的元素的对应关系如下,

index[0, 0, 0] = 0
index[0, 0, 1] = 1
index[0, 1, 0] = 1
index[0, 1, 1] = 0
index[1, 0, 0] = 1
index[1, 0, 1] = 0
index[1, 1, 0] = 0
index[1, 1, 1] = 1
index[2, 0, 0] = 0
index[2, 0, 1] = 1
index[2, 1, 0] = 1
index[2, 1, 1] = 0

接下来,有3种情况出现,分别是dim=0、dim=1、dim=2 

dim=0

拿index里的元素值去替换对应索引中第1个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [1, 0, 1]
[0, 1, 0], 1 -> [1, 1, 0]
[0, 1, 1], 0 -> [0, 1, 1]
[1, 0, 0], 1 -> [1, 0, 0]
[1, 0, 1], 0 -> [0, 0, 1]
[1, 1, 0], 0 -> [0, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [0, 0, 0]
[2, 0, 1], 1 -> [1, 0, 1]
[2, 1, 0], 1 -> [1, 1, 0]
[2, 1, 1], 0 -> [0, 1, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[1, 0, 1]],[x[1, 1, 0], x[0, 1, 1]],[[x[1, 0, 0], x[0, 0, 1],[x[0, 1, 0], x[1, 1, 1]],[[x[0, 0, 0], x[1, 0, 1], [x[1, 1, 0], x[0, 1, 1]]]]=[[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]])
"""

dim=1

拿index里的元素值去替换对应索引中第2个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 1, 1]
[0, 1, 0], 1 -> [0, 1, 0]
[0, 1, 1], 0 -> [0, 0, 1]
[1, 0, 0], 1 -> [1, 1, 0]
[1, 0, 1], 0 -> [1, 0, 1]
[1, 1, 0], 0 -> [1, 0, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 1, 1]
[2, 1, 0], 1 -> [2, 1, 0]
[2, 1, 1], 0 -> [2, 0, 1]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 1, 1]],[x[0, 1, 0], x[0, 0, 1]],[[x[1, 1, 0], x[1, 0, 1],[x[1, 0, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 1, 1], [x[2, 1, 0], x[2, 0, 1]]]]=[[[1, 4],[3, 2]],[[7, 6],[5, 8]],[[9, 12],[11, 10]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  4],[ 3,  2]],[[ 7,  6],[ 5,  8]],[[ 9, 12],[11, 10]]])
"""

dim=3

拿index里的元素值去替换对应索引中第3个维度的数值,得到新索引,

(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 0, 1]
[0, 1, 0], 1 -> [0, 1, 1]
[0, 1, 1], 0 -> [0, 1, 0]
[1, 0, 0], 1 -> [1, 0, 1]
[1, 0, 1], 0 -> [1, 0, 0]
[1, 1, 0], 0 -> [1, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 0, 1]
[2, 1, 0], 1 -> [2, 1, 1]
[2, 1, 1], 0 -> [2, 1, 0]

有了新索引后,便可根据新索引从输入张量中获取输出张量,

result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 0, 1]],[x[0, 1, 1], x[0, 1, 0]],[[x[1, 0, 1], x[1, 0, 0],[x[1, 1, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 0, 1], [x[2, 1, 1], x[2, 1, 0]]]]=[[[1, 2],[4, 3]],[[6, 5],[7, 8]],[[9, 10],[12, 11]]]
”“”

打印输出张量, 

print(result)
"""
实际值:
tensor([[[ 1,  2],[ 4,  3]],[[ 6,  5],[ 7,  8]],[[ 9, 10],[12, 11]]])
"""

这篇关于torch.gather用法详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921297

相关文章

HTML5 搜索框Search Box详解

《HTML5搜索框SearchBox详解》HTML5的搜索框是一个强大的工具,能够有效提升用户体验,通过结合自动补全功能和适当的样式,可以创建出既美观又实用的搜索界面,这篇文章给大家介绍HTML5... html5 搜索框(Search Box)详解搜索框是一个用于输入查询内容的控件,通常用于网站或应用程

mapstruct中的@Mapper注解的基本用法

《mapstruct中的@Mapper注解的基本用法》在MapStruct中,@Mapper注解是核心注解之一,用于标记一个接口或抽象类为MapStruct的映射器(Mapper),本文给大家介绍ma... 目录1. 基本用法2. 常用属性3. 高级用法4. 注意事项5. 总结6. 编译异常处理在MapSt

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

CSS3中的字体及相关属性详解

《CSS3中的字体及相关属性详解》:本文主要介绍了CSS3中的字体及相关属性,详细内容请阅读本文,希望能对你有所帮助... 字体网页字体的三个来源:用户机器上安装的字体,放心使用。保存在第三方网站上的字体,例如Typekit和Google,可以link标签链接到你的页面上。保存在你自己Web服务器上的字

java中long的一些常见用法

《java中long的一些常见用法》在Java中,long是一种基本数据类型,用于表示长整型数值,接下来通过本文给大家介绍java中long的一些常见用法,感兴趣的朋友一起看看吧... 在Java中,long是一种基本数据类型,用于表示长整型数值。它的取值范围比int更大,从-922337203685477

MySQL存储过程之循环遍历查询的结果集详解

《MySQL存储过程之循环遍历查询的结果集详解》:本文主要介绍MySQL存储过程之循环遍历查询的结果集,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言1. 表结构2. 存储过程3. 关于存储过程的SQL补充总结前言近来碰到这样一个问题:在生产上导入的数据发现

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul

从基础到进阶详解Pandas时间数据处理指南

《从基础到进阶详解Pandas时间数据处理指南》Pandas构建了完整的时间数据处理生态,核心由四个基础类构成,Timestamp,DatetimeIndex,Period和Timedelta,下面我... 目录1. 时间数据类型与基础操作1.1 核心时间对象体系1.2 时间数据生成技巧2. 时间索引与数据