torch.gather——沿特定维度收集数值

2024-01-28 06:30

本文主要是介绍torch.gather——沿特定维度收集数值,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyTorch学习笔记:torch.gather——沿特定维度收集数值

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

功能:从输入的数组中,沿指定的dim维度,利用索引变量index,将数据索引出来,并且堆叠成一个数组。直观可能不好理解,具体可以见代码案例。

输入:

input:输入的数组

dim:指定的维度

index:索引变量,数据类型需是长整型(int64)

注意:

  • inputindex具有相同的维数

  • outindex具有相同的形状

  • 除了dim维度,在每个维度上,索引在该维度上的大小要小于等于输入在该维度上的大小,即:
    i n d e x . s i z e ( d ) ≤ i n p u t . s i z e ( d ) , d ! = d i m index.size(d)≤input.size(d),\quad d!=dim index.size(d)input.size(d),d!=dim

代码案例

一般用法,当在一个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2,3,1,3]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

在这里插入图片描述

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2],[3],[1],[3]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

在这里插入图片描述

当同时在两个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2,3],[2,3,0],[3,0,1]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 5, 11, 17],[10, 16,  2],[15,  1,  7]])

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2],[2,3],[3,4]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 1,  2],[ 7,  8],[13, 14]])

官方文档

torch.gather:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather

这篇关于torch.gather——沿特定维度收集数值的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652751

相关文章

从基础到高级详解Python数值格式化输出的完全指南

《从基础到高级详解Python数值格式化输出的完全指南》在数据分析、金融计算和科学报告领域,数值格式化是提升可读性和专业性的关键技术,本文将深入解析Python中数值格式化输出的相关方法,感兴趣的小伙... 目录引言:数值格式化的核心价值一、基础格式化方法1.1 三种核心格式化方式对比1.2 基础格式化示例

MySQL按时间维度对亿级数据表进行平滑分表

《MySQL按时间维度对亿级数据表进行平滑分表》本文将以一个真实的4亿数据表分表案例为基础,详细介绍如何在不影响线上业务的情况下,完成按时间维度分表的完整过程,感兴趣的小伙伴可以了解一下... 目录引言一、为什么我们需要分表1.1 单表数据量过大的问题1.2 分表方案选型二、分表前的准备工作2.1 数据评估

Linux从文件中提取特定内容的实用技巧分享

《Linux从文件中提取特定内容的实用技巧分享》在日常数据处理和配置文件管理中,我们经常需要从大型文件中提取特定内容,本文介绍的提取特定行技术正是这些高级操作的基础,以提取含有1的简单需求为例,我们可... 目录引言1、方法一:使用 grep 命令1.1 grep 命令基础1.2 命令详解1.3 高级用法2

MySQL 强制使用特定索引的操作

《MySQL强制使用特定索引的操作》MySQL可通过FORCEINDEX、USEINDEX等语法强制查询使用特定索引,但优化器可能不采纳,需结合EXPLAIN分析执行计划,避免性能下降,注意版本差异... 目录1. 使用FORCE INDEX语法2. 使用USE INDEX语法3. 使用IGNORE IND

MySQL查询JSON数组字段包含特定字符串的方法

《MySQL查询JSON数组字段包含特定字符串的方法》在MySQL数据库中,当某个字段存储的是JSON数组,需要查询数组中包含特定字符串的记录时传统的LIKE语句无法直接使用,下面小编就为大家介绍两种... 目录问题背景解决方案对比1. 精确匹配方案(推荐)2. 模糊匹配方案参数化查询示例使用场景建议性能优

SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志

《SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志》在SpringBoot项目中,使用logback-spring.xml配置屏蔽特定路径的日志有两种常用方式,文中的... 目录方案一:基础配置(直接关闭目标路径日志)方案二:结合 Spring Profile 按环境屏蔽关

Spring Boot中JSON数值溢出问题从报错到优雅解决办法

《SpringBoot中JSON数值溢出问题从报错到优雅解决办法》:本文主要介绍SpringBoot中JSON数值溢出问题从报错到优雅的解决办法,通过修改字段类型为Long、添加全局异常处理和... 目录一、问题背景:为什么我的接口突然报错了?二、为什么会发生这个错误?1. Java 数据类型的“容量”限制

如何在pycharm安装torch包

《如何在pycharm安装torch包》:本文主要介绍如何在pycharm安装torch包方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录在pycharm安装torch包适http://www.chinasem.cn配于我电脑的指令为适用的torch包为总结在p

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

Python异步编程中asyncio.gather的并发控制详解

《Python异步编程中asyncio.gather的并发控制详解》在Python异步编程生态中,asyncio.gather是并发任务调度的核心工具,本文将通过实际场景和代码示例,展示如何结合信号量... 目录一、asyncio.gather的原始行为解析二、信号量控制法:给并发装上"节流阀"三、进阶控制