【无标题】PyTorch 常用算子说明

2024-05-29 10:52

本文主要是介绍【无标题】PyTorch 常用算子说明,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.增加维度

        print(a.unsqueeze(0).shape)  # 在0号维度位置插入一个维度

        print(a.unsqueeze(-1).shape)  # 在最后插入一个维度

        print(a.unsqueeze(3).shape)  # 在3号维度位置插入一个维度

2.删减维度

        a = torch.Tensor(1, 4, 1, 9)

        print(a.squeeze().shape) # 能删除的都删除掉

        print(a.squeeze(0).shape) # 尝试删除0号维度,ok

3.维度扩展(expand)

        b = torch.rand(32)

        f = torch.rand(4, 32, 14, 14)

        # 先进行维度增加

        b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)

        print(b.shape)

        # 再进行维度扩展

        b = b.expand(4, -1, 14, 14)  # -1表示这个维度保持不变,这里写32也可以

        print(b.shape)                  

         输出:

         torch.Size([1, 32, 1, 1])

        torch.Size([4, 32, 14, 14])

4.维度重复(repeat)

        print(b.shape)

        # 维度重复,32这里不想进行重复,所以就相当于"重复至1次"

        b = b.repeat(4, 1, 14, 14)

        print(b.shape)

        输出:

        torch.Size([1, 32, 1, 1])

        torch.Size([4, 32, 14, 14])

5.转置

        只适用于dim=2的Tensor。

        c = torch.Tensor(2, 4)

        print(c.t().shape)

        输出:

        torch.Size([4, 2])

  6. 维度交换

       d = torch.Tensor(6, 3, 1, 2)

        print(d.transpose(1, 3).contiguous().shape)  # 1号维度和3号维度交换

        输出:

        torch.Size([6, 2, 1, 3])

  7. permute

        h = torch.rand(4, 3, 6, 7)

        print(h.permute(0, 2, 3, 1).shape)

        输出:

        torch.Size([4, 6, 7, 3])

  8.gather

        1)input:输入

        2)dim:维度,常用的为0和1

        3)index:索引位置

        a=t.arange(0,16).view(4,4)

        print(a)

        index_1=t.LongTensor([[3,2,1,0]])

        b=a.gather(0,index_1)

        print(b)

        index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()

        c=a.gather(1,index_2)

        print(c)

        outout输出:

        tensor([[ 0,  1,  2,  3],

                         [ 4,  5,  6,  7],

                [ 8,  9, 10, 11],

                [12, 13, 14, 15]])

                tensor([[12,  9,  6,  3]])

        tensor([[ 0],

                   [ 5],

                  [10],

                  [15]])

        在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。

9.Chunk

             torch.chunk(tensor, chunks, dim=0)

              在给定维度(轴)上将输入张量进行分块儿

             直接用上面的数据来举个例子:

             l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份

             l.size(), m.size(), n.size()

              (torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))

                u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份

                u.size(), v.size()

        (torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

10.Stack

              合并新增(stack)

              stack需要保证两个Tensor的shape是一致的。

                c = torch.rand(4, 3, 32, 32)

                d = torch.rand(4, 3, 32, 32)

                print(torch.stack([c, d], dim=2).shape)

                print(torch.stack([c, d], dim=0).shape)

        运行结果:

                torch.Size([4, 3, 2, 32, 32])

                torch.Size([2, 4, 3, 32, 32])

11.View

        Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。

a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                   13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
a4 = a3.view(4, -1)
a5 = a3.view(2, 3, -1)

输出:

#a3

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24])

#a4

tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18],
        [19, 20, 21, 22, 23, 24]])

#a5
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],
        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])

12.reshape

        返回与 input张量数据大小一样、给定 shape的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。

a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = torch.reshape(a1, (3, 4))
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)

运行结果:

torch.Size([12])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

同view函数,也可以自动推断维度:a4 = torch.reshape(a1, (-1, 6))


 

这篇关于【无标题】PyTorch 常用算子说明的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1013470

相关文章

Spring中管理bean对象的方式(专业级说明)

《Spring中管理bean对象的方式(专业级说明)》在Spring框架中,Bean的管理是核心功能,主要通过IoC(控制反转)容器实现,下面给大家介绍Spring中管理bean对象的方式,感兴趣的朋... 目录1.Bean的声明与注册1.1 基于XML配置1.2 基于注解(主流方式)1.3 基于Java

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Java实现本地缓存的常用方案介绍

《Java实现本地缓存的常用方案介绍》本地缓存的代表技术主要有HashMap,GuavaCache,Caffeine和Encahche,这篇文章主要来和大家聊聊java利用这些技术分别实现本地缓存的方... 目录本地缓存实现方式HashMapConcurrentHashMapGuava CacheCaffe

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

Python将字符串转换为小写字母的几种常用方法

《Python将字符串转换为小写字母的几种常用方法》:本文主要介绍Python中将字符串大写字母转小写的四种方法:lower()方法简洁高效,手动ASCII转换灵活可控,str.translate... 目录一、使用内置方法 lower()(最简单)二、手动遍历 + ASCII 码转换三、使用 str.tr

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

Spring Boot 常用注解整理(最全收藏版)

《SpringBoot常用注解整理(最全收藏版)》本文系统整理了常用的Spring/SpringBoot注解,按照功能分类进行介绍,每个注解都会涵盖其含义、提供来源、应用场景以及代码示例,帮助开发... 目录Spring & Spring Boot 常用注解整理一、Spring Boot 核心注解二、Spr

Java中的内部类和常用类用法解读

《Java中的内部类和常用类用法解读》:本文主要介绍Java中的内部类和常用类用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录内部类和常用类内部类成员内部类静态内部类局部内部类匿名内部类常用类Object类包装类String类StringBuffer和Stri

MySQL连接池(Pool)常用方法详解

《MySQL连接池(Pool)常用方法详解》本文详细介绍了MySQL连接池的常用方法,包括创建连接池、核心方法连接对象的方法、连接池管理方法以及事务处理,同时,还提供了最佳实践和性能提示,帮助开发者构... 目录mysql 连接池 (Pool) 常用方法详解1. 创建连接池2. 核心方法2.1 pool.q