torch.gather——沿特定维度收集数值

2024-01-28 06:30

本文主要是介绍torch.gather——沿特定维度收集数值,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyTorch学习笔记:torch.gather——沿特定维度收集数值

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

功能:从输入的数组中,沿指定的dim维度,利用索引变量index,将数据索引出来,并且堆叠成一个数组。直观可能不好理解,具体可以见代码案例。

输入:

input:输入的数组

dim:指定的维度

index:索引变量,数据类型需是长整型(int64)

注意:

  • inputindex具有相同的维数

  • outindex具有相同的形状

  • 除了dim维度,在每个维度上,索引在该维度上的大小要小于等于输入在该维度上的大小,即:
    i n d e x . s i z e ( d ) ≤ i n p u t . s i z e ( d ) , d ! = d i m index.size(d)≤input.size(d),\quad d!=dim index.size(d)input.size(d),d!=dim

代码案例

一般用法,当在一个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2,3,1,3]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

在这里插入图片描述

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[2],[3],[1],[3]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

在这里插入图片描述

当同时在两个维度上进行索引时,以第一维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2,3],[2,3,0],[3,0,1]])
b=torch.gather(a,dim=0,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 5, 11, 17],[10, 16,  2],[15,  1,  7]])

以第二维度为例

import torch
a=torch.arange(20).reshape(4,5)
index=torch.tensor([[1,2],[2,3],[3,4]])
b=torch.gather(a,dim=1,index=index)
print(a)
print(b)

输出

tensor([[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]])
tensor([[ 1,  2],[ 7,  8],[13, 14]])

官方文档

torch.gather:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather

这篇关于torch.gather——沿特定维度收集数值的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652751

相关文章

SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志

《SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志》在SpringBoot项目中,使用logback-spring.xml配置屏蔽特定路径的日志有两种常用方式,文中的... 目录方案一:基础配置(直接关闭目标路径日志)方案二:结合 Spring Profile 按环境屏蔽关

Spring Boot中JSON数值溢出问题从报错到优雅解决办法

《SpringBoot中JSON数值溢出问题从报错到优雅解决办法》:本文主要介绍SpringBoot中JSON数值溢出问题从报错到优雅的解决办法,通过修改字段类型为Long、添加全局异常处理和... 目录一、问题背景:为什么我的接口突然报错了?二、为什么会发生这个错误?1. Java 数据类型的“容量”限制

如何在pycharm安装torch包

《如何在pycharm安装torch包》:本文主要介绍如何在pycharm安装torch包方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录在pycharm安装torch包适http://www.chinasem.cn配于我电脑的指令为适用的torch包为总结在p

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

Python异步编程中asyncio.gather的并发控制详解

《Python异步编程中asyncio.gather的并发控制详解》在Python异步编程生态中,asyncio.gather是并发任务调度的核心工具,本文将通过实际场景和代码示例,展示如何结合信号量... 目录一、asyncio.gather的原始行为解析二、信号量控制法:给并发装上"节流阀"三、进阶控制

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

JS常用组件收集

收集了一些平时遇到的前端比较优秀的组件,方便以后开发的时候查找!!! 函数工具: Lodash 页面固定: stickUp、jQuery.Pin 轮播: unslider、swiper 开关: switch 复选框: icheck 气泡: grumble 隐藏元素: Headroom

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

理解java虚拟机内存收集

学习《深入理解Java虚拟机》时个人的理解笔记 1、为什么要去了解垃圾收集和内存回收技术? 当需要排查各种内存溢出、内存泄漏问题时,当垃圾收集成为系统达到更高并发量的瓶颈时,我们就必须对这些“自动化”的技术实施必要的监控和调节。 2、“哲学三问”内存收集 what?when?how? 那些内存需要回收?什么时候回收?如何回收? 这是一个整体的问题,确定了什么状态的内存可以

Android中如何实现adb向应用发送特定指令并接收返回

1 ADB发送命令给应用 1.1 发送自定义广播给系统或应用 adb shell am broadcast 是 Android Debug Bridge (ADB) 中用于向 Android 系统发送广播的命令。通过这个命令,开发者可以发送自定义广播给系统或应用,触发应用中的广播接收器(BroadcastReceiver)。广播机制是 Android 的一种组件通信方式,应用可以监听广播来执行