pytorch之torch.flatten()和torch.nn.Flatten()的用法

2025-04-11 03:50
文章标签 用法 pytorch torch nn flatten

本文主要是介绍pytorch之torch.flatten()和torch.nn.Flatten()的用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用...

torch.flatten()和torch.nn.Flatten()的用法

flatten()函数的作用是将tensor铺平成一维

torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor
  • input (Tensor) – the input tensor.
  • start_dim (int) – the first dim to flatten
  • end_dim (int) – the last dim to flatten

start_dim和end_dim构成了整个你要选择铺平的维度范围

下面举例说明

x = torch.tensor([[1,2], [3,4], [5,6]])
x = x.flatten(0)
x
------------------------
tensor([1, 2, 3, 4, 5, 6])

对于图片数据,我们往往期望进入fc层的维度为(channels, N)这样

x = torch.tensor([[[1,2],[3,4]], [[5,6],[7,8]]])
x = x.flatten(1)
x
-------------------------
tensor([[1, 2],
        [3, 4],
        [5, 6]])

注:

torch.nn.Flatten(start_dim=1, end_dim=- 1)

start_dim 默认为 1

所以在构建网络时,下面两种是等价的

class Classifier(nn.Module):
    def __ipythonnit__(self):
        super(Classifier, self).__init__()
        # The arguments for commonly used modules:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # torch.nn.MaxPool2d(kernel_size, stride=None, padding=0)

        # input image size: [3, 128, 128]
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BATchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
China编程            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),
        )
        self.fc_layers = nn.Sequential(
         编程   nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            phpnn.Linear(256, 11)
        )

    def forward(self, x):
        # input (x): [batch_size, 3, 128, 128]
        # output: [batch_size, 11]

        # Extract features by convolutional layers.
        x = self.cnn_layers(x)

 编程       # The extracted feature map must be flatten before going to fully-connected layers.
        x = x.flatten(1)

        # The features are transformed by fully-connected layers to obtain the final logits.
        x = self.fc_layers(x)
        return x
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4, padding=0),

            nn.Flatten(),

            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 11)
        )

    def forward(self, x):
       
        x = self.layers(x)

        return x

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程China编程(www.chinasem.cn)。

这篇关于pytorch之torch.flatten()和torch.nn.Flatten()的用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1154173

相关文章

MySQL主从复制与读写分离的用法解读

《MySQL主从复制与读写分离的用法解读》:本文主要介绍MySQL主从复制与读写分离的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、主从复制mysql主从复制原理实验案例二、读写分离实验案例安装并配置mycat 软件设置mycat读写分离验证mycat读

HTML5 中的<button>标签用法和特征

《HTML5中的<button>标签用法和特征》在HTML5中,button标签用于定义一个可点击的按钮,它是创建交互式网页的重要元素之一,本文将深入解析HTML5中的button标签,详细介绍其属... 目录引言<button> 标签的基本用法<button> 标签的属性typevaluedisabled

SQL BETWEEN 语句的基本用法详解

《SQLBETWEEN语句的基本用法详解》SQLBETWEEN语句是一个用于在SQL查询中指定查询条件的重要工具,它允许用户指定一个范围,用于筛选符合特定条件的记录,本文将详细介绍BETWEEN语... 目录概述BETWEEN 语句的基本用法BETWEEN 语句的示例示例 1:查询年龄在 20 到 30 岁

CSS place-items: center解析与用法详解

《CSSplace-items:center解析与用法详解》place-items:center;是一个强大的CSS简写属性,用于同时控制网格(Grid)和弹性盒(Flexbox)... place-items: center; 是一个强大的 css 简写属性,用于同时控制 网格(Grid) 和 弹性盒(F

mysql中insert into的基本用法和一些示例

《mysql中insertinto的基本用法和一些示例》INSERTINTO用于向MySQL表插入新行,支持单行/多行及部分列插入,下面给大家介绍mysql中insertinto的基本用法和一些示例... 目录基本语法插入单行数据插入多行数据插入部分列的数据插入默认值注意事项在mysql中,INSERT I

mapstruct中的@Mapper注解的基本用法

《mapstruct中的@Mapper注解的基本用法》在MapStruct中,@Mapper注解是核心注解之一,用于标记一个接口或抽象类为MapStruct的映射器(Mapper),本文给大家介绍ma... 目录1. 基本用法2. 常用属性3. 高级用法4. 注意事项5. 总结6. 编译异常处理在MapSt

java中long的一些常见用法

《java中long的一些常见用法》在Java中,long是一种基本数据类型,用于表示长整型数值,接下来通过本文给大家介绍java中long的一些常见用法,感兴趣的朋友一起看看吧... 在Java中,long是一种基本数据类型,用于表示长整型数值。它的取值范围比int更大,从-922337203685477

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul

Python主动抛出异常的各种用法和场景分析

《Python主动抛出异常的各种用法和场景分析》在Python中,我们不仅可以捕获和处理异常,还可以主动抛出异常,也就是以类的方式自定义错误的类型和提示信息,这在编程中非常有用,下面我将详细解释主动抛... 目录一、为什么要主动抛出异常?二、基本语法:raise关键字基本示例三、raise的多种用法1. 抛

java中Optional的核心用法和最佳实践

《java中Optional的核心用法和最佳实践》Java8中Optional用于处理可能为null的值,减少空指针异常,:本文主要介绍java中Optional核心用法和最佳实践的相关资料,文中... 目录前言1. 创建 Optional 对象1.1 常规创建方式2. 访问 Optional 中的值2.1