Pytorch基础:torch.expand() 和 torch.repeat()

2024-05-09 00:28

本文主要是介绍Pytorch基础:torch.expand() 和 torch.repeat(),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在torch中,如果要改变某一个tensor的维度,可以利用viewexpandrepeattransposepermute等方法,这里对这些方法的一些容易混淆的地方做个总结。

expand和repeat函数是pytorch中常用于进行张量数据复制维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1

expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

import torcha = torch.tensor([1, 0, 2])     # a -> torch.Size([3])
b1 = a.expand(2, -1)            # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],[1, 0, 2]])
'''a = torch.tensor([[1], [0], [2]])   # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2)                 # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为  tensor([[1, 1],[0, 0],[2, 2]])
'''a = torch.Tensor([[1, 2, 3]])   # a -> torch.Size([1, 3])
b3 = a.expand(4, 3)              # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,# 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor([[1.,2.,3.],[1.,2.,3.],[1.,2.,3.],[1.,2.,3.]]
)
'''a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # a -> torch.Size([2, 3])
b4 = a.expand(4, 6)  # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''b5 = a.expand(1, 2, 3)  # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''b7 = a.expand(2, 3, 2)  # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''b8 = a.expand(2, -1, -1)  # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]]
)
'''# expand返回的张量与原版张量具有相同内存地址
print(b8.storage())  # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

沿着特定维度扩展张量,并返回扩展后的张量

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变
import torchif __name__ == '__main__':x = torch.rand(2, 3)y1 = x.repeat(4, 2)print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

  • torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

  • torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)y2 = x.repeat(2, 3)print(x.storage().data_ptr(), y1.storage().data_ptr())  # 52364352 52364352print(x.storage().data_ptr(), y2.storage().data_ptr())  # 52364352 8852096

这篇关于Pytorch基础:torch.expand() 和 torch.repeat()的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/971869

相关文章

从基础到进阶详解Pandas时间数据处理指南

《从基础到进阶详解Pandas时间数据处理指南》Pandas构建了完整的时间数据处理生态,核心由四个基础类构成,Timestamp,DatetimeIndex,Period和Timedelta,下面我... 目录1. 时间数据类型与基础操作1.1 核心时间对象体系1.2 时间数据生成技巧2. 时间索引与数据

安装centos8设置基础软件仓库时出错的解决方案

《安装centos8设置基础软件仓库时出错的解决方案》:本文主要介绍安装centos8设置基础软件仓库时出错的解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录安装Centos8设置基础软件仓库时出错版本 8版本 8.2.200android4版本 javas

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Linux基础命令@grep、wc、管道符的使用详解

《Linux基础命令@grep、wc、管道符的使用详解》:本文主要介绍Linux基础命令@grep、wc、管道符的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录grep概念语法作用演示一演示二演示三,带选项 -nwc概念语法作用wc,不带选项-c,统计字节数-

python操作redis基础

《python操作redis基础》Redis(RemoteDictionaryServer)是一个开源的、基于内存的键值对(Key-Value)存储系统,它通常用作数据库、缓存和消息代理,这篇文章... 目录1. Redis 简介2. 前提条件3. 安装 python Redis 客户端库4. 连接到 Re

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

SpringBoot基础框架详解

《SpringBoot基础框架详解》SpringBoot开发目的是为了简化Spring应用的创建、运行、调试和部署等,使用SpringBoot可以不用或者只需要很少的Spring配置就可以让企业项目快... 目录SpringBoot基础 – 框架介绍1.SpringBoot介绍1.1 概述1.2 核心功能2

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

Spring Boot集成SLF4j从基础到高级实践(最新推荐)

《SpringBoot集成SLF4j从基础到高级实践(最新推荐)》SLF4j(SimpleLoggingFacadeforJava)是一个日志门面(Facade),不是具体的日志实现,这篇文章主要介... 目录一、日志框架概述与SLF4j简介1.1 为什么需要日志框架1.2 主流日志框架对比1.3 SLF4