基于VGG16的猫狗数据集分类

2024-06-04 00:44
文章标签 数据 分类 vgg16

本文主要是介绍基于VGG16的猫狗数据集分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 1. 作者介绍
  • 2. VGG16介绍
    • 2.1 背景介绍
    • 2.2 VGG16 结构
  • 3. Cat VS Dog数据集介绍
  • 4. 实验过程
    • 4.1 数据集处理
    • 4.2 训练部分设置
    • 4.3 训练结果
    • 4.4 问题分析
    • 4.5 单张图片测试
  • 5.完整训练代码与权重
  • 参考文献


1. 作者介绍

孙思伟,男,西安工程大学电子信息学院,2023级研究生
研究方向:深度强化学习与人工智能
电子邮件:sunsiwei0109@163.com

2. VGG16介绍

2.1 背景介绍

VGG16 是由牛津大学的K. Simonyan 和A. Zisserman 在论文“Very Deep Convolutional Networks for Large-Scale Image Recognition”中提出的卷积神经网络模型。该模型在 ImageNet 中实现了 92.7% 的前 5 名测试准确率,ImageNet 是一个包含 1000 个类别的 1400 多万张图像的数据集。它是提交给ILSVRC-2014的著名模型之一。它通过将大型内核大小的滤波器(第一层和第二卷积层分别为 11 个和 5 个)替换为多个 3×3 个内核大小的滤波器来改进 AlexNet。
论文地址: Very Deep Convolutional Networks for Large-Scale Image Recognition

2.2 VGG16 结构

VGG16 是一个流行的卷积神经网络(CNN)架构,主要用于图像识别和处理。它由多个卷积层、激活层、池化层和全连接层组成,图1为VGG网络结构图,具体如下:
在这里插入图片描述

  1. 输入层:接受大小为 224 x 224 x 3 的图像,这意味着每个图像有 224 x 224 像素,每个像素有 3 个颜色通道(红、绿、蓝)。
  2. 卷积层和ReLU激活层
    • 前两个卷积层各有 64 个过滤器,大小均为 3x3,步长为 1,通过ReLU激活函数进行非线性处理。
    • 接下来是两个卷积层,每层 128 个过滤器,过滤器大小和步长不变,同样使用ReLU激活。
    • 然后是三个卷积层,每层有 256 个过滤器。
    • 最后是三组卷积层,每组包含三个卷积层,每层有 512 个过滤器。
  3. 最大池化层(Max Pooling)
    • 每几个卷积层后会跟一个最大池化层,用于降低空间尺寸(特征图的高度和宽度),池化窗口大小通常为 2x2,步长为 2。
  4. 全连接层
    • 网络末端有三个全连接层。前两个全连接层各有 4096 个神经元,并使用ReLU激活函数。
    • 最后一个全连接层有 1000 个输出单元,对应于 1000 个类别的分类任务。
  5. Softmax层
    • 最终输出通过softmax层,该层将网络输出转换为概率分布,用于多类别分类。

VGG16 的特点是使用了很多堆叠的小卷积核(3x3),这样设计的好处是可以在保持感受野的情况下减少参数数量,提高网络的深度以学习更复杂的特征。整个网络架构简单但效果显著,适用于各种图像识别任务。

3. Cat VS Dog数据集介绍

猫狗大战(Cats vs. Dogs)数据集是一个用于二分类的图像识别任务的公开数据集,其目的是让机器学习模型能够区分图像是猫还是狗。
数据集最初由Kaggle在2013年为一个竞赛提供,现在已经成为计算机视觉和深度学习领域中一个非常流行的入门级数据集。图2是数据集内容。
数据集下载地址:Cats-vs-Dogs

Cats vs. Dogs数据集

数据集特点

  1. Number of data sets

    • 数据集包含25,000张彩色图像,其中猫和狗各占一半,即12,500张猫的图片和12,500张狗的图片。
  2. Data set size

    • 图片的大小不一,格式为.JPG。这为处理图像时的预处理步骤(如缩放、裁剪)提供了实际的挑战。
  3. Data set label

    • 每张图片都明确标记为“猫”或“狗”,这简化了监督学习模型的训练过程。

应用方向

  • 训练深度学习模型:这个数据集广泛用于训练各种深度学习模型,如卷积神经网络(CNN),来识别和分类图像内容。
  • 模型评估和比较:数据集也常用于评估不同模型和算法的性能,以及比较各种优化技术的效果。

挑战

  • 图像多样性:由于数据集中的图像在姿势、大小、背景和光照条件等方面都有很大的不同,这给模型的泛化能力带来了挑战。
  • 不均匀的图像质量:部分图像可能质量较低,或者包含干扰的背景元素,这需要更复杂的数据预处理或更强大的模型来克服。

4. 实验过程

4.1 数据集处理

  1. 安装并导入需要的模块:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import os
import shutil
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
  1. [2] 数据集划分
    我这里下载的数据集路径格式为:
    Dog_Cat
    ├─Dog
    └─Cat
    需要划分并将数据集设置为:
    Dog_Cat
    ├─train
    │ ├─cat
    │ └─dog
    └─valid
    ├─cat
    └─dog
#数据集划分
def create_train_valid_dirs(base_dir):#创建训练集和验证集的目录结构classes = ['dog', 'cat']for cls in classes:os.makedirs(os.path.join(base_dir, 'train', cls), exist_ok=True)os.makedirs(os.path.join(base_dir, 'valid', cls), exist_ok=True)def copy_files(src_files, target_dir):#将文件列表复制到目标目录for src in src_files:shutil.copy(src, target_dir)def split_data(source_dir, train_dir, valid_dir, valid_ratio=0.2):#划分数据到训练集和验证集classes = ['dog', 'cat']with ThreadPoolExecutor() as executor:for cls in classes:class_dir = os.path.join(source_dir, cls)files = [entry.path for entry in os.scandir(class_dir) if entry.is_file()]np.random.shuffle(files)split_point = int(len(files) * (1 - valid_ratio))train_files = files[:split_point]valid_files = files[split_point:]# 并行复制文件到训练和验证目录executor.submit(copy_files, train_files, os.path.join(train_dir, cls))executor.submit(copy_files, valid_files, os.path.join(valid_dir, cls))
# 主目录
base_dir = 'D:\\Desktop\\Datasets\\Dog_Cat'
source_dir = base_dir
train_dir = os.path.join(base_dir, 'train')
valid_dir = os.path.join(base_dir, 'valid')
# 创建目录结构并划分数据
create_train_valid_dirs(base_dir)
split_data(source_dir, train_dir, valid_dir)

4.2 训练部分设置

  1. 检查设备是否有CUDA以及导入数据集
# 检查CUDA是否可用,否则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # 打印出使用的设备,GPU还是CPU# 定义训练数据的变换流程
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),  # 随机大小、比例裁剪图像到224x224transforms.RandomHorizontalFlip(),  # 随机水平翻转图像transforms.ToTensor(),              # 将图片转换为Tensortransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化处理,使用ImageNet的均值和标准差
])# 定义验证数据的变换流程
valid_transform = transforms.Compose([transforms.Resize(256),             # 将图片大小调整为256x256transforms.CenterCrop(224),         # 中心裁剪到224x224transforms.ToTensor(),              # 将图片转换为Tensortransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 与训练集相同的标准化处理
])# 加载训练数据集,并应用预定义的变换
train_dataset = datasets.ImageFolder(root=r'D:\Desktop\Datasets\Dog_Cat\train', transform=train_transform)# 加载验证数据集,并应用预定义的变换
valid_dataset = datasets.ImageFolder(root=r'D:\Desktop\Datasets\Dog_Cat\valid', transform=valid_transform)# 定义训练数据加载器,设置批处理大小为32,并启用随机洗牌
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 定义验证数据加载器,设置批处理大小为32,不进行洗牌
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
  1. 网络结构

第一部分卷积:
在这里插入图片描述

# 定义VGG16模型结构
class Vgg16_net(nn.Module):def __init__(self,num_classes= 2):super(Vgg16_net, self).__init__()# 第一层卷积层self.layer1 = nn.Sequential(# 输入3通道图像,输出64通道特征图,卷积核大小3x3,步长1,填充1nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),# 对64通道特征图进行Batch Normalizationnn.BatchNorm2d(64),# 对64通道特征图进行ReLU激活函数nn.ReLU(inplace=True),# 输入64通道特征图,输出64通道特征图,卷积核大小3x3,步长1,填充1nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),# 对64通道特征图进行Batch Normalizationnn.BatchNorm2d(64),# 对64通道特征图进行ReLU激活函数nn.ReLU(inplace=True),# 进行2x2的最大池化操作,步长为2nn.MaxPool2d(kernel_size=2, stride=2))

第二部分卷积
在这里插入图片描述

        # 第二层卷积层self.layer2 = nn.Sequential(# 输入64通道特征图,输出128通道特征图,卷积核大小3x3,步长1,填充1nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),# 对128通道特征图进行Batch Normalizationnn.BatchNorm2d(128),# 对128通道特征图进行ReLU激活函数nn.ReLU(inplace=True),# 输入128通道特征图,输出128通道特征图,卷积核大小3x3,步长1,填充1nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),# 对128通道特征图进行Batch Normalizationnn.BatchNorm2d(128),nn.ReLU(inplace=True),# 进行2x2的最大池化操作,步长为2nn.MaxPool2d(2, 2))

第三部分卷积
在这里插入图片描述

        # 第三层卷积层self.layer3 = nn.Sequential(# 输入为128通道,输出为256通道,卷积核大小为33,步长为1,填充大小为1nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),# 批归一化nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))

第四部分卷积
在这里插入图片描述

        self.layer4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))

第五部分卷积
在这里插入图片描述

        self.layer5 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))

第六部分 全连接层
在这里插入图片描述

        self.fc = nn.Sequential(nn.Linear(512*7*7, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, num_classes))

forward定义

        self.conv = nn.Sequential(self.layer1,self.layer2,self.layer3,self.layer4,self.layer5)def forward(self, x):x = self.conv(x)# 对张量的拉平(flatten)操作,即将卷积层输出的张量转化为二维,全连接的输入尺寸为512x = x.view(x.size(0), -1)x = self.fc(x)return x
  1. 调用构建的网络,定义损失函数与优化器
model = Vgg16_net().to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  1. 训练部分,在服务器上训练,最后将图保存,需要直接显示就取消注释最后一行代码
def train_model(num_epochs,initial_weights = None):if initial_weights is not None:model.load_state_dict(torch.load(initial_weights))print(f"Loaded weights from {initial_weights}")train_losses = []validation_accuracies = []for epoch in range(num_epochs):model.train()train_loss = 0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()epoch_loss = train_loss / len(train_loader)train_losses.append(epoch_loss)print(f'Epoch {epoch+1}, Train Loss: {epoch_loss}')torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')# 验证模型model.eval()correct = 0total = 0with torch.no_grad():for images, labels in valid_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalvalidation_accuracies.append(accuracy)print(f'Validation Accuracy: {accuracy}%')# 绘图plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss Over Epochs')plt.legend()plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), validation_accuracies, label='Validation Accuracy', color='r')plt.xlabel('Epochs')plt.ylabel('Accuracy (%)')plt.title('Validation Accuracy Over Epochs')plt.legend()plt.savefig('training_validation_metrics.png')# plt.show()  # 这行被注释掉,以避免在无GUI环境中尝试显示图像
# 开始训练
train_model(100)

4.3 训练结果

在这里插入图片描述

请添加图片描述结果展示了模型在100个epochs(训练周期)中的训练损失(Training Loss)和验证准确率(Validation Accuracy)的变化情况。

左图:Training Loss Over Epochs
横轴(X轴)表示训练的epochs(训练周期数),从0到100。
纵轴(Y轴)表示训练损失(Loss)。
曲线展示了训练损失随着训练周期数增加而逐渐减少的趋势。最初的训练损失大约在0.8左右,随着训练的进行,损失持续下降,到训练结束时降至接近0.2左右。
说明模型在训练过程中逐渐学习和优化,其性能在不断提高。

右图:Validation Accuracy Over Epochs
横轴(X轴)表示训练的epochs(训练周期数),从0到100。
纵轴(Y轴)表示验证准确率(Accuracy),百分比表示。
曲线展示了验证准确率随着训练周期数增加而逐渐提高的趋势。最初的验证准确率大约在50%左右,随着训练的进行,准确率不断上升,到训练结束时达到接近90%左右。
验证准确率曲线在训练初期上升较快,随后逐渐趋于平稳,最终在高水平上稍微波动。这表明模型在验证集上的表现也在逐渐提高,达到一个较高的准确率水平。

总体而言,这个训练结果表明模型的训练过程是成功的。训练损失持续下降,验证准确率持续上升,最终模型在验证集上的准确率接近90%,说明模型在图像分类任务上有很好的表现

4.4 问题分析

训练过程中损失居然上升了,可能是过拟合,或者是陷入了局部最优,并且net已经加了Dorpout。
所以,这里加上一个学习率调度器试试
在这里插入图片描述

在这里插入图片描述
可以看到,曲线平缓了很多,且损失没有大幅上升

4.5 单张图片测试

  1. 调用前面的网络构建模型,再将训练好的权重load进来,用Jupyter会显示网络架构,其他不显示也没关系
#测试模型
model_test = Vgg16_net()
model_test.load_state_dict(torch.load('model_epoch_100.pth'))
model_test.eval() 

Vgg16_net(
(layer1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(layer2): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

  1. 图像处理
from PIL import Imagetransform = transforms.Compose([transforms.Resize((224, 224)),  # 假设使用224x224输入transforms.ToTensor(),          # 将图片转换为Tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 归一化
])
  1. 设置图片路径,将图片读取后增加一个batch维度,因为DataLoader在获取数据时会有batch维度,之后将数据传入GPU,用模型预测类别
# 加载图片
img1 = Image.open(r'C:\Users\Administrator\Pictures\cat1.png')
img2 = Image.open(r'C:\Users\Administrator\Pictures\dog1.png')#img2 = img2.convert('RGB')  # 确保图片是三通道的
#img.show()
# 应用预处理
img_tensor1 = transform(img1).unsqueeze(0)  # 增加批次维度
img_tensor2 = transform(img2).unsqueeze(0)  # 增加批次维度
#img_tensor.show()
# 确保使用与训练相同的设备
model_test = model_test.to(device)
img_tensor1 = img_tensor1.to(device)
img_tensor2 = img_tensor2.to(device)# 前向传播获取输出
with torch.no_grad():outputs1 = model_test(img_tensor1)outputs2 = model_test(img_tensor2)
classes = ('cat', 'dog')# 获取预测结果
_, predicted1 = torch.max(outputs1, 1)
_, predicted2 = torch.max(outputs2, 1)
print("Predicted class index:", classes[predicted1.item()])
print("Predicted class index:", classes[predicted2.item()])
display(img2)
  1. 测试结果
    在这里插入图片描述

5.完整训练代码与权重

在这里插入图片描述

链接:https://pan.baidu.com/s/1UW2aWyF8cRrf_tbaslJihw?pwd=18uk
提取码:18uk
–来自百度网盘超级会员V7的分享

Tips 数据集过大,这里是在3090上训练得到的权重,共训练了一天,如有需要,可以直接下载训练好的权重

参考文献

[1]深度学习12. CNN经典网络 VGG16 - 知乎
[2]VGGNet-16 Architecture: A Complete Guide
[3]pytorch-vgg16

这篇关于基于VGG16的猫狗数据集分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1028580

相关文章

SQL Server修改数据库名及物理数据文件名操作步骤

《SQLServer修改数据库名及物理数据文件名操作步骤》在SQLServer中重命名数据库是一个常见的操作,但需要确保用户具有足够的权限来执行此操作,:本文主要介绍SQLServer修改数据... 目录一、背景介绍二、操作步骤2.1 设置为单用户模式(断开连接)2.2 修改数据库名称2.3 查找逻辑文件名

canal实现mysql数据同步的详细过程

《canal实现mysql数据同步的详细过程》:本文主要介绍canal实现mysql数据同步的详细过程,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的... 目录1、canal下载2、mysql同步用户创建和授权3、canal admin安装和启动4、canal

使用SpringBoot整合Sharding Sphere实现数据脱敏的示例

《使用SpringBoot整合ShardingSphere实现数据脱敏的示例》ApacheShardingSphere数据脱敏模块,通过SQL拦截与改写实现敏感信息加密存储,解决手动处理繁琐及系统改... 目录痛点一:痛点二:脱敏配置Quick Start——Spring 显示配置:1.引入依赖2.创建脱敏

详解如何使用Python构建从数据到文档的自动化工作流

《详解如何使用Python构建从数据到文档的自动化工作流》这篇文章将通过真实工作场景拆解,为大家展示如何用Python构建自动化工作流,让工具代替人力完成这些数字苦力活,感兴趣的小伙伴可以跟随小编一起... 目录一、Excel处理:从数据搬运工到智能分析师二、PDF处理:文档工厂的智能生产线三、邮件自动化:

Python数据分析与可视化的全面指南(从数据清洗到图表呈现)

《Python数据分析与可视化的全面指南(从数据清洗到图表呈现)》Python是数据分析与可视化领域中最受欢迎的编程语言之一,凭借其丰富的库和工具,Python能够帮助我们快速处理、分析数据并生成高质... 目录一、数据采集与初步探索二、数据清洗的七种武器1. 缺失值处理策略2. 异常值检测与修正3. 数据

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)

《使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)》字体设计和矢量图形处理是编程中一个有趣且实用的领域,通过Python的matplotlib库,我们可以轻松将字体轮廓... 目录背景知识字体轮廓的表示实现步骤1. 安装依赖库2. 准备数据3. 解析路径指令4. 绘制图形关键

解决mysql插入数据锁等待超时报错:Lock wait timeout exceeded;try restarting transaction

《解决mysql插入数据锁等待超时报错:Lockwaittimeoutexceeded;tryrestartingtransaction》:本文主要介绍解决mysql插入数据锁等待超时报... 目录报错信息解决办法1、数据库中执行如下sql2、再到 INNODB_TRX 事务表中查看总结报错信息Lock

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元