图像分类实战:深度学习在CIFAR-10数据集上的应用

2024-03-30 06:28

本文主要是介绍图像分类实战:深度学习在CIFAR-10数据集上的应用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.前言

        图像分类是计算机视觉领域的一个核心任务,算法能够自动识别图像中的物体或场景,并将其归类到预定义的类别中。近年来,深度学习技术的发展极大地推动了图像分类领域的进步。CIFAR-10数据集作为计算机视觉领域的一个经典小型数据集,为研究者提供了一个理想的实验平台,用于验证和比较不同的图像分类算法。本文将介绍CIFAR-10数据集的基本情况和加载方法,并展示如何构建与训练一个卷积神经网络(CNN)模型来进行图像分类,最后对模型的性能进行评估与可视化。

2.数据集介绍与加载

        CIFAR-10数据集由加拿大高等研究院(Canadian Institute for Advanced Research, CIFAR)发布,是计算机视觉领域广泛使用的基准数据集之一。它包含了10个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、船、卡车、马)的彩色图像,每类有6,000张图像,共计60,000张。所有图像尺寸统一为32x32像素,且已进行标准化处理,其色彩模式为RGB。数据集被划分为50,000张训练图像和10,000张测试图像,保证了训练集与测试集的均衡分布。

        数据加载

        使用Python的tensorflow.keras.datasets模块加载CIFAR-10数据集,同时进行必要的预处理,如归一化和标签转换。

import tensorflow as tf# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0# 将标签转换为one-hot编码
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

3.构建与训练CNN模型

        ResNet(Residual Neural Network)是一种深度残差学习网络,通过引入残差块解决了深度神经网络训练过程中的梯度消失和爆炸问题,从而能够构建和训练极深的模型,显著提升模型的性能和泛化能力。

        关于CNN模型的更多介绍,请看这篇文章:

卷积神经网络(CNN):图像识别的强大工具-CSDN博客文章浏览阅读795次,点赞9次,收藏18次。卷积神经网络是一种强大的图像识别工具,它能够自动学习图像的特征,并在各种图像识别任务中取得出色的效果。通过使用深度学习框架和大量的训练数据,我们可以构建出高效准确的卷积神经网络模型,实现对图像的分类、识别等任务。希望这篇文章能够帮助你更好地理解卷积神经网络在图像识别中的应用。如果你有任何问题或需要进一步的帮助,请随时提问。https://blog.csdn.net/meijinbo/article/details/137015665

3.1.构建模型

        使用Keras构建一个适用于CIFAR-10数据集的小型ResNet模型。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Densedef residual_block(input_tensor, filters, strides=1, use_projection=False):shortcut = input_tensorif use_projection:shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)shortcut = BatchNormalization()(shortcut)x = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input_tensor)x = BatchNormalization()(x)x = Activation('relu')(x)x = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)x = BatchNormalization()(x)if strides != 1 or input_tensor.shape[-1] != filters:shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)shortcut = BatchNormalization()(shortcut)x = Add()([shortcut, x])x = Activation('relu')(x)return xdef build_resnet():model = Sequential()model.add(Conv2D(16, kernel_size=3, padding='same', input_shape=(32, 32, 3)))model.add(BatchNormalization())model.add(Activation('relu'))for _ in range(2):model.add(residual_block(model.output, 16))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(residual_block(model.output, 32, strides=2, use_projection=True))for _ in range(2):model.add(residual_block(model.output, 32))model.add(GlobalAveragePooling2D())model.add(Dense(10, activation='softmax'))return modelresnet_model = build_resnet()
resnet_model.summary()

3.2.模型训练

        配置模型训练参数,启动训练过程,并监控训练进度。

resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])history = resnet_model.fit(x_train, y_train,batch_size=128,epochs=100,validation_data=(x_test, y_test),verbose=1)

4.模型性能评估与可视化

4.1.性能评估

        评估模型在测试集上的最终性能指标。

test_loss, test_acc = resnet_model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

 4.2.可视化

        绘制训练过程中损失和准确率曲线,以直观了解模型收敛情况与过拟合风险。

import matplotlib.pyplot as pltdef plot_history(history):plt.figure(figsize=(12, 6))plt.subplot(1, 2, 1)plt.plot(history.history['accuracy'], label='Training Accuracy')plt.plot(history.history['val_accuracy'], label='Validation Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.subplot(1, 2, 2)plt.plot(history.history['loss'], label='Training Loss')plt.plot(history.history['val_loss'], label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.show()plot_history(history)  # 显示训练过程中的准确率与损失曲线

        以下是基于PyTorch的实现:

import torch.nn as nn  
import torch.nn.functional as F  class SimpleCNN(nn.Module):  def __init__(self):  super(SimpleCNN, self).__init__()  self.conv1 = nn.Conv2d(3, 6, 5)  self.pool = nn.MaxPool2d(2, 2)  self.conv2 = nn.Conv2d(6, 16, 5)  self.fc1 = nn.Linear(16 * 5 * 5, 120)  self.fc2 = nn.Linear(120, 84)  self.fc3 = nn.Linear(84, 10)  def forward(self, x):  x = self.pool(F.relu(self.conv1(x)))  x = self.pool(F.relu(self.conv2(x)))  x = x.view(-1, 16 * 5 * 5)  x = F.relu(self.fc1(x))  x = F.relu(self.fc2(x))  x = self.fc3(x)  return x  # 实例化模型、定义损失函数和优化器  
model = SimpleCNN()  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # 训练模型  
for epoch in range(2):  # 假设我们训练两个epoch  running_loss = 0.0  for i, data in enumerate(trainloader, 0):  inputs, labels = data  optimizer.zero_grad()  outputs = model(inputs)  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  running_loss += loss.item()  if i % 2000 == 1999:  # 每2

 5.总结

        通过以上步骤,我们已经完成了在CIFAR-10数据集上使用深度学习进行图像分类的全过程。从数据集的介绍与加载,到构建并训练ResNet模型,再到模型性能的评估与可视化,这一系列操作展示了如何将理论知识应用于实际问题,揭示了深度学习在图像分类任务中的强大能力。实践中,可根据具体需求调整模型结构、优化策略等参数以进一步提升模型性能。

这篇关于图像分类实战:深度学习在CIFAR-10数据集上的应用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/860725

相关文章

SpringBoot多环境配置数据读取方式

《SpringBoot多环境配置数据读取方式》SpringBoot通过环境隔离机制,支持properties/yaml/yml多格式配置,结合@Value、Environment和@Configura... 目录一、多环境配置的核心思路二、3种配置文件格式详解2.1 properties格式(传统格式)1.

Python标准库之数据压缩和存档的应用详解

《Python标准库之数据压缩和存档的应用详解》在数据处理与存储领域,压缩和存档是提升效率的关键技术,Python标准库提供了一套完整的工具链,下面小编就来和大家简单介绍一下吧... 目录一、核心模块架构与设计哲学二、关键模块深度解析1.tarfile:专业级归档工具2.zipfile:跨平台归档首选3.

SQL Server跟踪自动统计信息更新实战指南

《SQLServer跟踪自动统计信息更新实战指南》本文详解SQLServer自动统计信息更新的跟踪方法,推荐使用扩展事件实时捕获更新操作及详细信息,同时结合系统视图快速检查统计信息状态,重点强调修... 目录SQL Server 如何跟踪自动统计信息更新:深入解析与实战指南 核心跟踪方法1️⃣ 利用系统目录

使用IDEA部署Docker应用指南分享

《使用IDEA部署Docker应用指南分享》本文介绍了使用IDEA部署Docker应用的四步流程:创建Dockerfile、配置IDEADocker连接、设置运行调试环境、构建运行镜像,并强调需准备本... 目录一、创建 dockerfile 配置文件二、配置 IDEA 的 Docker 连接三、配置 Do

解决pandas无法读取csv文件数据的问题

《解决pandas无法读取csv文件数据的问题》本文讲述作者用Pandas读取CSV文件时因参数设置不当导致数据错位,通过调整delimiter和on_bad_lines参数最终解决问题,并强调正确参... 目录一、前言二、问题复现1. 问题2. 通过 on_bad_lines=‘warn’ 跳过异常数据3

深入浅出SpringBoot WebSocket构建实时应用全面指南

《深入浅出SpringBootWebSocket构建实时应用全面指南》WebSocket是一种在单个TCP连接上进行全双工通信的协议,这篇文章主要为大家详细介绍了SpringBoot如何集成WebS... 目录前言为什么需要 WebSocketWebSocket 是什么Spring Boot 如何简化 We

java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)

《java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)》:本文主要介绍java中pdf模版填充表单踩坑的相关资料,OpenPDF、iText、PDFBox是三... 目录准备Pdf模版方法1:itextpdf7填充表单(1)加入依赖(2)代码(3)遇到的问题方法2:pd

Java Stream流之GroupBy的用法及应用场景

《JavaStream流之GroupBy的用法及应用场景》本教程将详细介绍如何在Java中使用Stream流的groupby方法,包括基本用法和一些常见的实际应用场景,感兴趣的朋友一起看看吧... 目录Java Stream流之GroupBy的用法1. 前言2. 基础概念什么是 GroupBy?Stream

python中列表应用和扩展性实用详解

《python中列表应用和扩展性实用详解》文章介绍了Python列表的核心特性:有序数据集合,用[]定义,元素类型可不同,支持迭代、循环、切片,可执行增删改查、排序、推导式及嵌套操作,是常用的数据处理... 目录1、列表定义2、格式3、列表是可迭代对象4、列表的常见操作总结1、列表定义是处理一组有序项目的

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499