Pytorch函数——torch.gather详解

2024-01-15 21:20

本文主要是介绍Pytorch函数——torch.gather详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在学习强化学习时,顺便复习复习pytorch的基本内容,遇到了 torch.gather()函数,参考图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)进行解释。

pytorch官网对函数给出的解释:

image.png

即input是一个矩阵,根据dim的值,将index的值替换到不同的维度的索引,当dim为0时,index替代i的值,成为第0维度的索引。

输入和输出的矩阵形式相同。

例子:首先我们生成3×3的矩阵,明确行索引的概念,第0行指的是[3,4,5],第0列指的是[[3] [6] [9]]

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

index为行向量且dim=0时

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[9, 7, 5]])

当dim=0时,替换第0维度。由于input为二维列表,因此第0维度指的是选择第几行的维度,即行索引所在的维度,替换了i的索引,为input[index[i][j]] [j]

那么我们会输出tensor([[ input[2][j] input[1][j] input[0][j] ]]),那么j如何获得呢?从index of index中拿到,index每一个元素的索引为(0,0) (0,1) (0,2),取j,则为0,1,2,那么输出则为tensor([[ input[2][0] input[1][1] input[0][2] ]]),即

tensor([[9, 7, 5]])

输入行向量index,且dim=1

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5, 4, 3]])

维度为1,则替换列索引的值,那么输出为tensor([[ input[i][2] input[i][1] input[i][0] ]]),index每一个元素的索引为(0,0) (0,1) (0,2),i均为1,那么tensor([[ input[0][2] input[0][1] input[0][0] ]])

输入为行向量,dim=0

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[5],[7],[9]])

维度为0,则替换行索引,且输出与输入的格式相同,为

tensor([input[2][j],input[1][j],input[0][j]])

index每一个元素的索引为(0,0) (1,0) (2,0),j对应的值为0,0,0,则

tensor([input[2][0],input[1][0],input[0][0]])

输入为行向量,dim=1

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5],[7],[9]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([input[i][2],input[i][1],input[i][0]])

index每一个元素的索引为(0,0) (1,0) (2,0),i对应的值为0,1,2,则

tensor([input[0][2],input[1][1],input[2][0]])

输入为二维矩阵,且dim=1

index = torch.tensor([[0, 2],[1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[3, 5],[7, 8]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([[input[i][0], input[i][2]],[input[i][1], input[i][2]]])

替换为行索引后,可得:

tensor([[input[0][0], input[0][2]],[input[1][1], input[1][2]]])

在强化学习中的应用

在PyTorch官网DQN页面的代码中,i是state,j是a

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

我们使用dim=1action_batch 将获得的动作列表替换为列索引,即可获得每个state下该动作的动作价值。

这篇关于Pytorch函数——torch.gather详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/610246

相关文章

MySQL数据库双机热备的配置方法详解

《MySQL数据库双机热备的配置方法详解》在企业级应用中,数据库的高可用性和数据的安全性是至关重要的,MySQL作为最流行的开源关系型数据库管理系统之一,提供了多种方式来实现高可用性,其中双机热备(M... 目录1. 环境准备1.1 安装mysql1.2 配置MySQL1.2.1 主服务器配置1.2.2 从

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

MyBatis常用XML语法详解

《MyBatis常用XML语法详解》文章介绍了MyBatis常用XML语法,包括结果映射、查询语句、插入语句、更新语句、删除语句、动态SQL标签以及ehcache.xml文件的使用,感兴趣的朋友跟随小... 目录1、定义结果映射2、查询语句3、插入语句4、更新语句5、删除语句6、动态 SQL 标签7、ehc

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

从基础到高级详解Go语言中错误处理的实践指南

《从基础到高级详解Go语言中错误处理的实践指南》Go语言采用了一种独特而明确的错误处理哲学,与其他主流编程语言形成鲜明对比,本文将为大家详细介绍Go语言中错误处理详细方法,希望对大家有所帮助... 目录1 Go 错误处理哲学与核心机制1.1 错误接口设计1.2 错误与异常的区别2 错误创建与检查2.1 基础

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1

一文详解Python如何开发游戏

《一文详解Python如何开发游戏》Python是一种非常流行的编程语言,也可以用来开发游戏模组,:本文主要介绍Python如何开发游戏的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、python简介二、Python 开发 2D 游戏的优劣势优势缺点三、Python 开发 3D

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Redis 基本数据类型和使用详解

《Redis基本数据类型和使用详解》String是Redis最基本的数据类型,一个键对应一个值,它的功能十分强大,可以存储字符串、整数、浮点数等多种数据格式,本文给大家介绍Redis基本数据类型和... 目录一、Redis 入门介绍二、Redis 的五大基本数据类型2.1 String 类型2.2 Hash