使用pytorch构建ResNet50模型训练猫狗数据集

2024-05-30 06:04

本文主要是介绍使用pytorch构建ResNet50模型训练猫狗数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt# 定义模型的类别数量
num_classes = 2# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()# 图像预处理
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)# 预测
with torch.no_grad():out = model(batch_t)# 获取最高分数的类别
_, index = torch.max(out, 1)# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条# 定义数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练次数
num_epochs = 10# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数for phase in ['train', 'test']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0# 使用tqdm显示进度条with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:for inputs, labels in dataloaders[phase]:optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)progress_bar.update(1)print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 更新训练次数计数器train_count += 1print(f'Training Count: {train_count}')

这篇关于使用pytorch构建ResNet50模型训练猫狗数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1015768

相关文章

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态

prometheus如何使用pushgateway监控网路丢包

《prometheus如何使用pushgateway监控网路丢包》:本文主要介绍prometheus如何使用pushgateway监控网路丢包问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录监控网路丢包脚本数据图表总结监控网路丢包脚本[root@gtcq-gt-monitor-prome

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期

java使用protobuf-maven-plugin的插件编译proto文件详解

《java使用protobuf-maven-plugin的插件编译proto文件详解》:本文主要介绍java使用protobuf-maven-plugin的插件编译proto文件,具有很好的参考价... 目录protobuf文件作为数据传输和存储的协议主要介绍在Java使用maven编译proto文件的插件