pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法

本文主要是介绍pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数组排序并返回前N值

对数组的第n个维度进项排序,并返回排序的前k个元素的values, indices

torch.topk(input, k, dim=n, largest=True, sorted=True, out=None) 
-> (Tensor, LongTensor)

例:取input的第1维

values, indices = torch.topk(input , 1, dim=1)

l a r g e s t { T r u e ,按照大到小排序 F a l s e ,按照小到大排序 largest\left\{\begin{array}{l}True,\mathrm{按照大到小排序}\\False,\mathrm{按照小到大排序}\end{array}\right. largest{True按照大到小排序False按照小到大排序

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

sorted:返回的结果按照顺序返回

out:可缺省,不要

按照索引取值:torch.gather(input,dim,index),或indicat_select

import torch
input = [[2, 3, 4, 5, 0, 0],[1, 4, 3, 0, 0, 0],[4, 2, 2, 5, 7, 0],[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
index = torch.LongTensor([[3],[2],[4],[0]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, index)
————————————————
版权声明:https://blog.csdn.net/cpluss/article/details/90260550 https://www.zhihu.com/question/374472015

截取Tensor

初始方法

```c
res1 = []
for i in range(10):res1.append(i*3)
res = out[:, res1]
## [narrow](https://pytorch.org/docs/stable/generated/torch.narrow.html?highlight=narrow#torch.narrow)维度范围返回 。返回的张量和张量共享相同的底层存储。
Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。```c
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)# 沿着x的第0维度,的第0位置开始,向下选取2个距离
tensor([[ 1,  2,  3],[ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)# 沿着x的第1维度,的第1位置开始,向下选取2个距离
tensor([[ 2,  3],[ 5,  6],[ 8,  9]])

mask方式选取torch.masked_select(input,mask)

>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],[-2.0099,  0.6249, -0.5382,  1.4458],[ 0.0684,  0.4118,  0.1011, -0.5684]])>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)tensor([[ True,  True, False, False],[False,  True, False,  True],[False, False, False, False]])>>> print(torch.masked_select(x, mask))tensor([1.2001, 1.2968, 0.6249, 1.4458])————————————————
版权声明  https://pytorch.org/docs/stable/generated/torch.masked_select.html#torch.masked_select   https://cloud.tencent.com/developer/article/1755706

permute置换操作+res = input[:,0:N]

where()

>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],[-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],[0.0000, 0.0000]], dtype=torch.float64)



这篇关于pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1894

相关文章

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法

《JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法》:本文主要介绍JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法,每种方法结合实例代码给大家介绍的非常... 目录引言:为什么"相等"判断如此重要?方法1:使用some()+includes()(适合小数组)方法2

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

504 Gateway Timeout网关超时的根源及完美解决方法

《504GatewayTimeout网关超时的根源及完美解决方法》在日常开发和运维过程中,504GatewayTimeout错误是常见的网络问题之一,尤其是在使用反向代理(如Nginx)或... 目录引言为什么会出现 504 错误?1. 探索 504 Gateway Timeout 错误的根源 1.1 后端

Python自动化处理PDF文档的操作完整指南

《Python自动化处理PDF文档的操作完整指南》在办公自动化中,PDF文档处理是一项常见需求,本文将介绍如何使用Python实现PDF文档的自动化处理,感兴趣的小伙伴可以跟随小编一起学习一下... 目录使用pymupdf读写PDF文件基本概念安装pymupdf提取文本内容提取图像添加水印使用pdfplum

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

MySQL 表空却 ibd 文件过大的问题及解决方法

《MySQL表空却ibd文件过大的问题及解决方法》本文给大家介绍MySQL表空却ibd文件过大的问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录一、问题背景:表空却 “吃满” 磁盘的怪事二、问题复现:一步步编程还原异常场景1. 准备测试源表与数据

基于Redis自动过期的流处理暂停机制

《基于Redis自动过期的流处理暂停机制》基于Redis自动过期的流处理暂停机制是一种高效、可靠且易于实现的解决方案,防止延时过大的数据影响实时处理自动恢复处理,以避免积压的数据影响实时性,下面就来详... 目录核心思路代码实现1. 初始化Redis连接和键前缀2. 接收数据时检查暂停状态3. 检测到延时过

python 线程池顺序执行的方法实现

《python线程池顺序执行的方法实现》在Python中,线程池默认是并发执行任务的,但若需要实现任务的顺序执行,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录方案一:强制单线程(伪顺序执行)方案二:按提交顺序获取结果方案三:任务间依赖控制方案四:队列顺序消

SpringBoot通过main方法启动web项目实践

《SpringBoot通过main方法启动web项目实践》SpringBoot通过SpringApplication.run()启动Web项目,自动推断应用类型,加载初始化器与监听器,配置Spring... 目录1. 启动入口:SpringApplication.run()2. SpringApplicat