pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法

本文主要是介绍pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数组排序并返回前N值

对数组的第n个维度进项排序,并返回排序的前k个元素的values, indices

torch.topk(input, k, dim=n, largest=True, sorted=True, out=None) 
-> (Tensor, LongTensor)

例:取input的第1维

values, indices = torch.topk(input , 1, dim=1)

l a r g e s t { T r u e ,按照大到小排序 F a l s e ,按照小到大排序 largest\left\{\begin{array}{l}True,\mathrm{按照大到小排序}\\False,\mathrm{按照小到大排序}\end{array}\right. largest{True按照大到小排序False按照小到大排序

input:一个tensor数据

k:指明是得到前k个数据以及其index

dim: 指定在哪个维度上排序, 默认是最后一个维度

sorted:返回的结果按照顺序返回

out:可缺省,不要

按照索引取值:torch.gather(input,dim,index),或indicat_select

import torch
input = [[2, 3, 4, 5, 0, 0],[1, 4, 3, 0, 0, 0],[4, 2, 2, 5, 7, 0],[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
index = torch.LongTensor([[3],[2],[4],[0]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, index)
————————————————
版权声明:https://blog.csdn.net/cpluss/article/details/90260550 https://www.zhihu.com/question/374472015

截取Tensor

初始方法

```c
res1 = []
for i in range(10):res1.append(i*3)
res = out[:, res1]
## [narrow](https://pytorch.org/docs/stable/generated/torch.narrow.html?highlight=narrow#torch.narrow)维度范围返回 。返回的张量和张量共享相同的底层存储。
Narrow()的工作原理类似于高级索引。例如,在一个2D张量中,使用[:,0:5]选择列0到5中的所有行。同样的,可以使用torch.narrow(1,0,5)。然而,在高维张量中,对于每个维度都使用range操作是很麻烦的。使用narrow()可以更快更方便地实现这一点。```c
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)# 沿着x的第0维度,的第0位置开始,向下选取2个距离
tensor([[ 1,  2,  3],[ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)# 沿着x的第1维度,的第1位置开始,向下选取2个距离
tensor([[ 2,  3],[ 5,  6],[ 8,  9]])

mask方式选取torch.masked_select(input,mask)

>>> import torch
>>> x = torch.randn([3, 4])
>>> print(x)tensor([[ 1.2001,  1.2968, -0.6657, -0.6907],[-2.0099,  0.6249, -0.5382,  1.4458],[ 0.0684,  0.4118,  0.1011, -0.5684]])>>> # 将x中的每一个元素与0.5进行比较
>>> # 当元素大于等于0.5返回True,否则返回False
>>> mask = x.ge(0.5)
>>> print(mask)tensor([[ True,  True, False, False],[False,  True, False,  True],[False, False, False, False]])>>> print(torch.masked_select(x, mask))tensor([1.2001, 1.2968, 0.6249, 1.4458])————————————————
版权声明  https://pytorch.org/docs/stable/generated/torch.masked_select.html#torch.masked_select   https://cloud.tencent.com/developer/article/1755706

permute置换操作+res = input[:,0:N]

where()

>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],[-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],[0.0000, 0.0000]], dtype=torch.float64)



这篇关于pytorch数组处理:排序获取前k个(torch.topk(input , k, dim=1))+ 截取Tensor的几种方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1894

相关文章

MySQL启动报错:InnoDB表空间丢失问题及解决方法

《MySQL启动报错:InnoDB表空间丢失问题及解决方法》在启动MySQL时,遇到了InnoDB:Tablespace5975wasnotfound,该错误表明MySQL在启动过程中无法找到指定的s... 目录mysql 启动报错:InnoDB 表空间丢失问题及解决方法错误分析解决方案1. 启用 inno

Python函数返回多个值的多种方法小结

《Python函数返回多个值的多种方法小结》在Python中,函数通常用于封装一段代码,使其可以重复调用,有时,我们希望一个函数能够返回多个值,Python提供了几种不同的方法来实现这一点,需要的朋友... 目录一、使用元组(Tuple):二、使用列表(list)三、使用字典(Dictionary)四、 使

Linux查看系统盘和SSD盘的容量、型号及挂载信息的方法

《Linux查看系统盘和SSD盘的容量、型号及挂载信息的方法》在Linux系统中,管理磁盘设备和分区是日常运维工作的重要部分,而lsblk命令是一个强大的工具,它用于列出系统中的块设备(blockde... 目录1. 查看所有磁盘的物理信息方法 1:使用 lsblk(推荐)方法 2:使用 fdisk -l(

python web 开发之Flask中间件与请求处理钩子的最佳实践

《pythonweb开发之Flask中间件与请求处理钩子的最佳实践》Flask作为轻量级Web框架,提供了灵活的请求处理机制,中间件和请求钩子允许开发者在请求处理的不同阶段插入自定义逻辑,实现诸如... 目录Flask中间件与请求处理钩子完全指南1. 引言2. 请求处理生命周期概述3. 请求钩子详解3.1

使用Python获取JS加载的数据的多种实现方法

《使用Python获取JS加载的数据的多种实现方法》在当今的互联网时代,网页数据的动态加载已经成为一种常见的技术手段,许多现代网站通过JavaScript(JS)动态加载内容,这使得传统的静态网页爬取... 目录引言一、动态 网页与js加载数据的原理二、python爬取JS加载数据的方法(一)分析网络请求1

MySQL查看表的最后一个ID的常见方法

《MySQL查看表的最后一个ID的常见方法》在使用MySQL数据库时,我们经常会遇到需要查看表中最后一个id值的场景,无论是为了调试、数据分析还是其他用途,了解如何快速获取最后一个id都是非常实用的技... 目录背景介绍方法一:使用MAX()函数示例代码解释适用场景方法二:按id降序排序并取第一条示例代码解

Python中合并列表(list)的六种方法小结

《Python中合并列表(list)的六种方法小结》本文主要介绍了Python中合并列表(list)的六种方法小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录一、直接用 + 合并列表二、用 extend() js方法三、用 zip() 函数交叉合并四、用

Java 中的跨域问题解决方法

《Java中的跨域问题解决方法》跨域问题本质上是浏览器的一种安全机制,与Java本身无关,但Java后端开发者需要理解其来源以便正确解决,下面给大家介绍Java中的跨域问题解决方法,感兴趣的朋友一起... 目录1、Java 中跨域问题的来源1.1. 浏览器同源策略(Same-Origin Policy)1.

Python处理大量Excel文件的十个技巧分享

《Python处理大量Excel文件的十个技巧分享》每天被大量Excel文件折磨的你看过来!这是一份Python程序员整理的实用技巧,不说废话,直接上干货,文章通过代码示例讲解的非常详细,需要的朋友可... 目录一、批量读取多个Excel文件二、选择性读取工作表和列三、自动调整格式和样式四、智能数据清洗五、

通过cmd获取网卡速率的代码

《通过cmd获取网卡速率的代码》今天从群里看到通过bat获取网卡速率两段代码,感觉还不错,学习bat的朋友可以参考一下... 1、本机有线网卡支持的最高速度:%v%@echo off & setlocal enabledelayedexpansionecho 代码开始echo 65001编码获取: >