【Pytorch】10.CIFAR10模型搭建

2024-05-15 22:04
文章标签 模型 搭建 pytorch cifar10

本文主要是介绍【Pytorch】10.CIFAR10模型搭建,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CIFAR10模型

torchvision中官方给出的一个数据集,可以通过

dataset = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=torchvision.transforms.ToTensor())

来下载到指定文件夹

搭建CIFAR10模型

首先我们先去搜一下CIFAR10 model structure
在这里插入图片描述
可以看到,模型的训练步骤为

  • 输入为3通道 32*32像素,通过5*5的卷积核进行卷积操作,得到32通道32*32像素
  • 进行2*2卷积核的最大池化操作变为32通道16*16像素
  • 进行5*5卷积核的卷积操作变为32通道16*16像素
  • 进行2*2卷积核的最大池化操作变为32通道8*8像素
  • 进行5*5卷积核的卷积操作变为64通道8*8像素
  • 进行2*2卷积核的最大池化操作变为64通道4*4像素
  • 进行Flatten全链接操作展开为1024长度
  • 通过线性激活变为64长度
  • 通过线性激活变为10长度
    然后我们就可以进行搭建了

首层卷积层

输入为3通道 32*32像素,通过5*5的卷积核进行卷积操作,得到32通道32*32像素
因为输入输出都是32*32像素,所以我们就需要根据官方给出的公式来计算一下padding为多少
在这里插入图片描述
其中padding为未知变量,dilation为默认值1,stride为默认值1,kernel_size为5
根据输入输出都为32可以求出,padding为2

所以我们的首层卷积为输入3通道,输出32通道,卷积核为5,padding为2

self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)

其它层的推导同理,这里就不过多赘述

最终结果

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('./dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())dataLoader = torch.utils.data.DataLoader(dataset, batch_size=64)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)self.pool2 = nn.MaxPool2d(2)self.conv3 = nn.Conv2d(32, 64, kernel_size=5, padding=2)self.pool3 = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(1024, 64)self.fc2 = nn.Linear(64, 10)def forward(self, x):x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.conv3(x)x = self.pool3(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xnet = Net()
# print(net)
input_test = torch.ones((64, 3, 32, 32))  # torch.ones用于模拟数据,用于检验定义的每层操作是否有错误
output_test = net(input_test)
# print(output.shape)
writer = SummaryWriter('./logs')
writer.add_graph(net, input_test)       # 给定网络的类和输入的 input

这里新使用了torch.oneswriter.add_graph,根据注释再自己查看一下
在这里插入图片描述
add_graph生成的图像

这篇关于【Pytorch】10.CIFAR10模型搭建的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/993036

相关文章

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

Spring Cloud GateWay搭建全过程

《SpringCloudGateWay搭建全过程》:本文主要介绍SpringCloudGateWay搭建全过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Spring Cloud GateWay搭建1.搭建注册中心1.1添加依赖1.2 配置文件及启动类1.3 测

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

SpringBoot快速搭建TCP服务端和客户端全过程

《SpringBoot快速搭建TCP服务端和客户端全过程》:本文主要介绍SpringBoot快速搭建TCP服务端和客户端全过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录TCPServerTCPClient总结由于工作需要,研究了SpringBoot搭建TCP通信的过程

Gradle下如何搭建SpringCloud分布式环境

《Gradle下如何搭建SpringCloud分布式环境》:本文主要介绍Gradle下如何搭建SpringCloud分布式环境问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录Gradle下搭建SpringCloud分布式环境1.idea配置好gradle2.创建一个空的gr

Linux搭建单机MySQL8.0.26版本的操作方法

《Linux搭建单机MySQL8.0.26版本的操作方法》:本文主要介绍Linux搭建单机MySQL8.0.26版本的操作方法,本文通过图文并茂的形式给大家讲解的非常详细,感兴趣的朋友一起看看吧... 目录概述环境信息数据库服务安装步骤下载前置依赖服务下载方式一:进入官网下载,并上传到宿主机中,适合离线环境

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch