卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释)

本文主要是介绍卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

一、CNN概述

二、CNN网络结构

三、CNN常见名词

四、使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类


一、CNN概述

        卷积神经网络 ( Convolutional Neural NetworkCNN) 作为人工神经网络中一种常见的深度学习架构,该网络是受到生物自然视觉认知机制启发而来,是一种特殊的多层前馈神经网络, CNN 是由简单的神经网络改进而来,使用卷积层和池化层替代全连接层结构,卷积层能够有效地将图像中的各种特征提取出并生成特征图。广泛应用于图像识别图像分类等领域 ,具有良好的扩展性和鲁棒性,截至目前,CNN 的深度呈不断增加的趋势

        CNN在图像分类识别中要做的事情是:给定一张图片,图片中是牛还是马不知道,是什么牛也不知道,现在需要模型判断这张图片里具体是一个什么东西,总之输出一个结果:如果是牛的话,那是什么牛?

【1】鲁棒性也称作健壮性(英语:Robustness一个系统或组织有抵御或克服不利条件的能力。鲁棒性则常被用来描述可以面对复杂适应系统的能力,需要更全面的对系统进行考虑。

二、CNN网络结构

1)输入层(Input layer),众多神经元(Neuron)接受大量非线形输入讯息。输入的讯息称为输入向量。

2)卷积层:是一块一块地来进行比对。它拿来比对的这个“小块”我们称之为Features,每一个feature就像是一个小图,对图像和滤波矩阵做内积(逐个元素相乘再求和)的操作就是所谓的卷积”操作,也是卷积神经网络的名字来源。

【1】卷积:滤波器filter与数据窗口做内积(在CNN中,滤波器filter带着一组固定权重的神经元)对局部输入数据进行卷积计算。每计算完一个数据窗口内的局部数据后,数据窗口不断平移滑动,直到计算完所有数据

3)池化pool层:保留主要的特征进一步删减冗余参数,提高特征提取效率。池化,简言之,即取区域平均或最大。

5)全连接层:就是把特征整合到一起(高度提纯特征),方便交给最后的分类器或者回归。

三、CNN常见名词

1感受野:某一个输出层的一个元素对应输入层的区域大小,被称为感受野,即输出层的一个元素在输入层上的映射区域。

2激活函数:常用的非线性激活函数有sigmoidtanhrelu等等,前两者sigmoid/tanh比较常见于全连接层,后者relu常见于卷积层。

四、使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类

主要步骤是:

1. 加载和预处理CIFAR-10数据集
2. 定义卷积神经网络 ConvNet 模型
3. 定义交叉熵损失函数和SGD优化器
4. 训练模型50个epoch
5. 打印训练损失并完成训练

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 训练数据
transform = transforms.Compose([transforms.ToTensor(),     # 转为tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])   # 归一化trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)    # 测试数据    
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 卷积神经网络定义
class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))    # 2层卷积池化x = self.pool(F.relu(self.conv2(x)))    # 2层卷积池化x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = ConvNet()
criterion = nn.CrossEntropyLoss()       # 损失函数定义
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)   # 优化器定义# 训练网络
for epoch in range(50):   # 50个epochrunning_loss = 0.0for i, data in enumerate(trainloader, 0):   # 遍历训练集inputs, labels = dataoptimizer.zero_grad()    # 梯度清零outputs = model(inputs)  # 神经网络前向传播loss = criterion(outputs, labels)    # 计算损失loss.backward()         # 反向传播optimizer.step()        # 更新参数running_loss += loss.item() # 累加损失loss = running_loss/len(trainset) # 打印Lossprint(f'Epoch {epoch+1}, Loss: {loss}') print('Finished Training')

这篇关于卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/679583

相关文章

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

Mysql实现范围分区表(新增、删除、重组、查看)

《Mysql实现范围分区表(新增、删除、重组、查看)》MySQL分区表的四种类型(范围、哈希、列表、键值),主要介绍了范围分区的创建、查询、添加、删除及重组织操作,具有一定的参考价值,感兴趣的可以了解... 目录一、mysql分区表分类二、范围分区(Range Partitioning1、新建分区表:2、分

MySQL 定时新增分区的实现示例

《MySQL定时新增分区的实现示例》本文主要介绍了通过存储过程和定时任务实现MySQL分区的自动创建,解决大数据量下手动维护的繁琐问题,具有一定的参考价值,感兴趣的可以了解一下... mysql创建好分区之后,有时候会需要自动创建分区。比如,一些表数据量非常大,有些数据是热点数据,按照日期分区MululbU

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

MySQL中查找重复值的实现

《MySQL中查找重复值的实现》查找重复值是一项常见需求,比如在数据清理、数据分析、数据质量检查等场景下,我们常常需要找出表中某列或多列的重复值,具有一定的参考价值,感兴趣的可以了解一下... 目录技术背景实现步骤方法一:使用GROUP BY和HAVING子句方法二:仅返回重复值方法三:返回完整记录方法四:

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

IDEA中新建/切换Git分支的实现步骤

《IDEA中新建/切换Git分支的实现步骤》本文主要介绍了IDEA中新建/切换Git分支的实现步骤,通过菜单创建新分支并选择是否切换,创建后在Git详情或右键Checkout中切换分支,感兴趣的可以了... 前提:项目已被Git托管1、点击上方栏Git->NewBrancjsh...2、输入新的分支的

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os