使用pytorch构建ResNet50模型训练猫狗数据集

2024-05-30 06:04

本文主要是介绍使用pytorch构建ResNet50模型训练猫狗数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt# 定义模型的类别数量
num_classes = 2# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)# 预测
with torch.no_grad():out = model(batch_t)# 获取最高分数的类别
_, index = torch.max(out, 1)# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

这篇关于使用pytorch构建ResNet50模型训练猫狗数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1015768

相关文章

SpringBoot使用ffmpeg实现视频压缩

《SpringBoot使用ffmpeg实现视频压缩》FFmpeg是一个开源的跨平台多媒体处理工具集,用于录制,转换,编辑和流式传输音频和视频,本文将使用ffmpeg实现视频压缩功能,有需要的可以参考... 目录核心功能1.格式转换2.编解码3.音视频处理4.流媒体支持5.滤镜(Filter)安装配置linu

Redis中的Lettuce使用详解

《Redis中的Lettuce使用详解》Lettuce是一个高级的、线程安全的Redis客户端,用于与Redis数据库交互,Lettuce是一个功能强大、使用方便的Redis客户端,适用于各种规模的J... 目录简介特点连接池连接池特点连接池管理连接池优势连接池配置参数监控常用监控工具通过JMX监控通过Pr

解决mysql插入数据锁等待超时报错:Lock wait timeout exceeded;try restarting transaction

《解决mysql插入数据锁等待超时报错:Lockwaittimeoutexceeded;tryrestartingtransaction》:本文主要介绍解决mysql插入数据锁等待超时报... 目录报错信息解决办法1、数据库中执行如下sql2、再到 INNODB_TRX 事务表中查看总结报错信息Lock

apache的commons-pool2原理与使用实践记录

《apache的commons-pool2原理与使用实践记录》ApacheCommonsPool2是一个高效的对象池化框架,通过复用昂贵资源(如数据库连接、线程、网络连接)优化系统性能,这篇文章主... 目录一、核心原理与组件二、使用步骤详解(以数据库连接池为例)三、高级配置与优化四、典型应用场景五、注意事

使用Python实现Windows系统垃圾清理

《使用Python实现Windows系统垃圾清理》Windows自带的磁盘清理工具功能有限,无法深度清理各类垃圾文件,所以本文为大家介绍了如何使用Python+PyQt5开发一个Windows系统垃圾... 目录一、开发背景与工具概述1.1 为什么需要专业清理工具1.2 工具设计理念二、工具核心功能解析2.

Linux系统之stress-ng测压工具的使用

《Linux系统之stress-ng测压工具的使用》:本文主要介绍Linux系统之stress-ng测压工具的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、理论1.stress工具简介与安装2.语法及参数3.具体安装二、实验1.运行8 cpu, 4 fo

Java使用MethodHandle来替代反射,提高性能问题

《Java使用MethodHandle来替代反射,提高性能问题》:本文主要介绍Java使用MethodHandle来替代反射,提高性能问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录一、认识MethodHandle1、简介2、使用方式3、与反射的区别二、示例1、基本使用2、(重要)

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元

Linux lvm实例之如何创建一个专用于MySQL数据存储的LVM卷组

《Linuxlvm实例之如何创建一个专用于MySQL数据存储的LVM卷组》:本文主要介绍使用Linux创建一个专用于MySQL数据存储的LVM卷组的实例,具有很好的参考价值,希望对大家有所帮助,... 目录在Centos 7上创建卷China编程组并配置mysql数据目录1. 检查现有磁盘2. 创建物理卷3. 创

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结