ResNet网络(三部曲_3)

2024-08-25 20:20
文章标签 网络 resnet 三部曲

本文主要是介绍ResNet网络(三部曲_3),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 1 网络介绍
  • 2 具体应用
    • 2.1 网络搭建
    • 2.2 网络训练
    • 2.3 模型测试
    • 2.4 小玩意儿

1 网络介绍

Deep Residual Learning for Image Recognition

论文地址:https://arxiv.org/abs/1512.03385

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2 具体应用

使用resnet50,进行猫狗二分类

2.1 网络搭建

ResNet50.py

# 定义 ResNet50 用于二分类任务的网络模型
from torch import nn
from torchvision import modelsclass ResNet50ForCatDog(nn.Module):def __init__(self):super(ResNet50ForCatDog, self).__init__()# 加载预训练的 ResNet50self.resnet50 = models.resnet50(pretrained=True)# 修改最后一层全连接层以适应二分类任务# num_ftrs是全连接层的输入神经元个数num_ftrs = self.resnet50.fc.in_features# 修改最后一层的全连接层,以符合自己的二分类【猫、狗】需求self.resnet50.fc = nn.Linear(num_ftrs, 2)def forward(self, x):return self.resnet50(x)

可以通过如下代码查看resnet50的结构

from torchvision import models
resnet50 = models.resnet50(pretrained=False)
print(resnet50)

在这里插入图片描述

pretrained=True 这个参数的作用是指示加载在大型数据集(如 ImageNet)上预训练好的模型权重。
使用预训练的模型有以下好处:
节省训练时间:预训练模型已经学习到了通用的图像特征,基于这样的模型进行微调,通常可以比从头开始训练更快地收敛到较好的结果。
利用已有知识:预训练模型在大规模数据上学习到的特征具有一定的通用性和鲁棒性,可以为新的任务提供有价值的初始特征表示。
提高性能:在许多情况下,使用预训练模型并进行适当的微调,可以获得比随机初始化权重并从头训练更好的性能。

2.2 网络训练

TrainResNet.py

import os
import torch
from torch import nn, optim
from Utils import write_to_txt
from ResNet50 import ResNet50ForCatDog
from DrawCurves import plot_metrics
from datetime import datetime
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# 定义数据集
ROOT_TRAIN = r'train数据集_path'
ROOT_TEST = r'val数据集_path'
# 文档写入路径
WRITER_PATH="./train_process"
# 定义超参数
batch_size = 64
learning_rate = 0.001
num_epochs = 30# 将图像RGB三个通道的像素值分别减去0.5,再除以0.5.从而将所有的像素值固定在[-1,1]范围内
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
train_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.RandomVerticalFlip(),  # 随机垂直旋转transforms.ToTensor(),  # 将0-255范围内的像素转为0-1范围内的tensornormalize])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize])train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)# 定义训练函数
def train(model, train_loader, criterion, optimizer):model.train()running_loss = 0.0running_corrects = 0.0for inputs, labels in train_loader:inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels)epoch_loss = running_loss / len(train_loader.dataset)epoch_acc = running_corrects / len(train_loader.dataset)print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")write_to_txt(WRITER_PATH,f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")return epoch_loss, epoch_acc.item()#  定义测试函数
def val(model, test_loader, criterion):model.eval()running_loss = 0.0running_corrects = 0.0with torch.no_grad():for inputs, labels in test_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels)epoch_loss = running_loss / len(test_loader.dataset)epoch_acc = running_corrects / len(test_loader.dataset)print(f"Test Loss: {epoch_loss:.4f}, Test Acc: {epoch_acc:.4f}")write_to_txt(WRITER_PATH,f"Test Loss: {epoch_loss:.4f}, Test Acc: {epoch_acc:.4f}")return epoch_loss, epoch_acc.item()"""
只保测试结果最好的那一个模型
"""
# 开始训练模型
# 如果显卡可用,则用显卡进行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("当前设备{}".format(device))
write_to_txt(WRITER_PATH,f"当前设备{device}")
model = ResNet50ForCatDog()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)max_acc = 0.0
save_epoch=0
Loss_train = []
Acc_train = []
Loss_val = []
Acc_val = []
# 计时
start_time = datetime.now()
print("当前训练模型是ResNet50,猫狗二分类,预定训练轮次-{}".format(num_epochs))
write_to_txt(WRITER_PATH,f"当前训练模型是ResNet50,猫狗二分类,预定训练轮次-{num_epochs}")
for t in range(num_epochs):print("-----第{}轮训练开始-----".format(t + 1))write_to_txt(WRITER_PATH, f"-----第{t + 1}轮训练开始-----")train_loss, train_acc = train(model, train_dataloader, criterion, optimizer)val_loss, val_acc = val(model, val_dataloader, criterion)# 将损失值和正确率写入列表Loss_train.append(train_loss)Acc_train.append(train_acc)Loss_val.append(val_loss)Acc_val.append(val_acc)# 保存最好的模型权重文件if val_acc > max_acc:folder = '../save_models'if not os.path.exists(folder):os.mkdir('../save_models')max_acc = val_accsave_epoch=tprint(f'save best model,第{t + 1}轮')write_to_txt(WRITER_PATH,f'save best model,第{t + 1}轮')torch.save(model.state_dict(), '../save_models/best_model.pth')
end_time = datetime.now()
print("start_time:{}".format(start_time))
print("end_time:{}".format(end_time))
print("{}训练总用时:{}".format(device, end_time - start_time))
plot_metrics(Loss_train, Acc_train, Loss_val, Acc_val, save_path_loss='loss.png', save_path_accuracy='acc.png')
print("Done!")
write_to_txt(WRITER_PATH,f'=========Done!=========')
write_to_txt(WRITER_PATH,f'start_time:{start_time}')
write_to_txt(WRITER_PATH,f'end_time:{end_time}')
write_to_txt(WRITER_PATH,f'{device}训练总用时:{end_time - start_time}')
write_to_txt(WRITER_PATH,f'保存最好模型为{save_epoch+1}轮次')
write_to_txt(WRITER_PATH,f"Train Loss: {Loss_train[save_epoch]:.4f}, Train Acc: {Acc_train[save_epoch]:.4f}")
write_to_txt(WRITER_PATH,f"Test Loss: {Loss_val[save_epoch]:.4f}, Test Acc: {Acc_val[save_epoch]:.4f}")
write_to_txt(WRITER_PATH,f'=========Done!=========')

from DrawCurves import plot_metrics
from Utils import write_to_txt
from ResNet50 import ResNet50ForCatDog
为自定义py文件

DrawCurves.py

绘制模型训练曲线

import matplotlib.pyplot as pltdef plot_metrics(train_losses, train_accuracies, test_losses, test_accuracies,save_path_loss=None, save_path_accuracy=None):"""绘制训练集和测试集的损失值及正确率曲线,并保存为图片文件(可选)。参数:- train_losses: 训练集损失值的列表或数组- train_accuracies: 训练集正确率的列表或数组- test_losses: 测试集损失值的列表或数组- test_accuracies: 测试集正确率的列表或数组- save_path_loss: 保存损失值图形的路径。如果为 None,则不保存图片- save_path_accuracy: 保存正确率图形的路径。如果为 None,则不保存图片"""epochs = range(1, len(train_losses) + 1)# 绘制训练集和测试集损失值plt.figure(figsize=(10, 5))plt.plot(epochs, train_losses, 'b', label='Train Loss')plt.plot(epochs, test_losses, 'r', label='Test Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Train and Test Loss')plt.legend(loc='upper right')# 如果提供了保存路径,则保存图片if save_path_loss:plt.savefig(save_path_loss)plt.show()# 绘制训练集和测试集正确率plt.figure(figsize=(10, 5))plt.plot(epochs, train_accuracies, 'b', label='Train Accuracy')plt.plot(epochs, test_accuracies, 'r', label='Test Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.title('Train and Test Accuracy')plt.legend(loc='upper right')# 如果提供了保存路径,则保存图片if save_path_accuracy:plt.savefig(save_path_accuracy)plt.show()

Utils.py

一个用于书写训练数据,一个用于生成指定范围内的随机数

import randomdef write_to_txt(file_path, content):"""将指定内容写入到指定的txt文件中,并确保每次写入的内容新起一行。参数:file_path (str): 文件的路径。content (str): 要写入文件的内容。"""with open(file_path, "a") as file:  # 以追加模式打开文件file.write(content + "\n")  # 写入内容并换行def generate_random_numbers(count, min_value, max_value):"""生成指定数量的随机数,并在给定范围内。参数:count (int): 要生成的随机数个数。min_value (int): 随机数的最小值(包括)。max_value (int): 随机数的最大值(包括)。返回:list: 生成的随机数列表。"""random_numbers = [random.randint(min_value, max_value) for _ in range(count)]return random_numbers

模型训练曲线图:
在这里插入图片描述
在这里插入图片描述
保留的最好模型信息
在这里插入图片描述

2.3 模型测试

TestModel.py

import torch
from ResNet50 import ResNet50ForCatDog
from torch.autograd import Variable
from torchvision import datasets, transforms,models
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from Utils import generate_random_numbersROOT_TEST = r'待测试数据集_path'val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)
print(len(val_dataset))# 如果显卡可用,则用显卡进行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("当前设备{}".format(device))# 调用net里面的定义的网络模型, 如果GPU可用则将模型转到GPU
# 加载预训练的 ResNet-50 模型
model=ResNet50ForCatDog()
model.to(device)
# 加载模型train.py里面训练的模型
model.load_state_dict(torch.load('../save_models/best_model.pth', weights_only=True))# 获取预测结果
classes = ['cat', 'dog']# 把tensor转成Image,方便可视化
show = ToPILImage()
# 进入验证阶段
model.eval()
# 随机30个图片,用于测试
arr=generate_random_numbers(30,0,len(val_dataset))
# 对val_dataset里面的照片进行推理验证
for i in arr:x, y = val_dataset[i][0], val_dataset[i][1]# show(x).show()img = show(x)# 使用 matplotlib 显示图像并设置标题plt.imshow(img)plt.title(i)plt.axis('off')  # 不显示坐标轴# 保存图像到本地文件# plt.savefig(f'../imgs/image_{i}.png', bbox_inches='tight', pad_inches=0)plt.show()x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)with torch.no_grad():pred = model(x)# print(pred)predicted, actual = classes[torch.argmax(pred[0])], classes[y]print(f'index: {i} ,Predicted: "{predicted}" ,Actual: "{actual}"')

测试结果:
在这里插入图片描述

2.4 小玩意儿

图形化界面+已训练好的模型=猫狗分类器【仅限于猫狗图片,其他图片无预测能力】
在这里插入图片描述
ClassifyGui.py

import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import torch
from torchvision import transforms
from ResNet50 import ResNet50ForCatDog# 定义一个简单的猫狗识别模型
class_names = ['Cat', 'Dog']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("当前设备{}".format(device))
# 加载预训练的 ResNet-50 模型
model=ResNet50ForCatDog()
model.to(device)
# 加载模型里面训练的模型
model.load_state_dict(torch.load('./best_model.pth', weights_only=True,map_location=device))
model.eval()# 图像预处理
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),  # 将0-255范围内的像素转为0-1范围内的tensornormalize])def predict_image(image_path):"""加载图像并进行预测"""image = Image.open(image_path)image = preprocess(image).unsqueeze(0)with torch.no_grad():output = model(image)_, predicted = torch.max(output, 1)return class_names[predicted.item()]def open_image():"""打开文件对话框选择图片"""file_path = filedialog.askopenfilename()if file_path:# 显示选中的图像img = Image.open(file_path)img.thumbnail((300, 150))  # 调整图像大小以适应展示区域img = ImageTk.PhotoImage(img)image_label.config(image=img)image_label.image = img# 预测图像类别result = predict_image(file_path)result_label.config(text=f'Prediction: {result}')# 初始化图形界面
root = tk.Tk()
root.title("Cat vs Dog Classifier")
root.geometry("500x300")  # 设置窗口大小# 创建界面组件
frame_top = tk.Frame(root, width=300, height=150)
frame_top.pack_propagate(False)  # 防止 Frame 自适应内容大小
frame_top.pack(pady=10)image_label = tk.Label(frame_top, bg='gray')
image_label.pack(expand=True)btn = tk.Button(root, text="Choose Image", command=open_image)
btn.pack(pady=10)result_label = tk.Label(root, text="Prediction: ", font=("Arial", 14))
result_label.pack(pady=10)# 运行图形界面
root.mainloop()

在这里插入图片描述
在这里插入图片描述

这篇关于ResNet网络(三部曲_3)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1106559

相关文章

Linux网络配置之网桥和虚拟网络的配置指南

《Linux网络配置之网桥和虚拟网络的配置指南》这篇文章主要为大家详细介绍了Linux中配置网桥和虚拟网络的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、网桥的配置在linux系统中配置一个新的网桥主要涉及以下几个步骤:1.为yum仓库做准备,安装组件epel-re

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

Linux高并发场景下的网络参数调优实战指南

《Linux高并发场景下的网络参数调优实战指南》在高并发网络服务场景中,Linux内核的默认网络参数往往无法满足需求,导致性能瓶颈、连接超时甚至服务崩溃,本文基于真实案例分析,从参数解读、问题诊断到优... 目录一、问题背景:当并发连接遇上性能瓶颈1.1 案例环境1.2 初始参数分析二、深度诊断:连接状态与

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为