MNIST手写字符分类-卷积

2024-06-13 13:52

本文主要是介绍MNIST手写字符分类-卷积,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

MNIST手写字符分类-卷积

文章目录

  • MNIST手写字符分类-卷积
    • 1 模型构造
    • 2 训练
    • 3 推理
    • 4 导出
    • 5 onnx测试
    • 6 opencv部署
    • 7 总结

  在上一篇中,我们介绍了如何在pytorch中使用线性层+ReLU非线性层堆叠的网络进行手写字符识别的网络构建、训练、模型保存、导出和推理测试。本篇文章中,我们将要使用卷积层进行网络构建,并完成后续的训练、保存、导出,并使用opencv在C++中推理我们的模型,将结果可视化。

1 模型构造

  在pytorch中,卷积层的使用比较方便,需要注意的是卷积层的输入通道数、输出通道数、卷积核的大小等参数。这里直接放出构建的网络结构:

import torch
from torch import nn
from torch.utils.data import DataLoader
class ZKNNNet_Conv(nn.Module):def __init__(self):super(ZKNNNet_Conv, self).__init__()self.conv_stack = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(12*12*64, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):logits = self.conv_stack(x)return logits

在这里插入图片描述

从图中可以看出,该模型先堆叠了两个卷积层与ReLU单元,经过最大池化之后,展开并进行后续的全连接层训练。

2 训练

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_Conv
import os
# Download training data from open datasets.
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = ZKNNNet_Conv()
if os.path.exists("./model/model_conv.pth"):model.load_state_dict(torch.load("./model/model_conv.pth"))
model = model.to(device)
print(model)# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# Loss function
loss_fn = nn.CrossEntropyLoss()# Train
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# Test
def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return correctepochs = 200
maxAcc = 0
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)currentAcc = test(test_dataloader, model)if maxAcc < currentAcc:maxAcc = currentAcctorch.save(model.state_dict(), "./model/model_conv.pth")
print("Done!")

模型的训练代码与上一篇中的线性连接训练代码是一样的。
训练过程来看,使用卷积层,在相同数据集上训练,模型收敛速度比用线性层快很多。最终精度达到97.8%。

3 推理

模型训练完成之后,推理过程与上一篇一致,这里简单放一下推理代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_Convimport matplotlib.pyplot as plt# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))# Load the trained model
model = ZKNNNet_Conv()
model.load_state_dict(torch.load("./model/model_conv.pth"))
model.to(device)
model.eval()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)# Perform inference
with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i].cpu()label = labels[i].cpu()prediction = predicted[i].cpu()plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()accuracy = 100 * correct / totalprint("Accuracy on test set: {:.2f}%".format(accuracy))

4 导出

模型导出方式与上一篇一致。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer,ZKNNNet_Conv
model_conv = ZKNNNet_Conv()
if os.path.exists('./model/model_conv.pth'):model_conv.load_state_dict(torch.load('./model/model_conv.pth'))
model_conv = model_conv.to(device)
model_conv.eval()
torch.onnx.export(model_conv,torch.randn(1,1,28,28),'./model/model_conv.onnx',verbose=True)

5 onnx测试

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasetsimport matplotlib.pyplot as pltfrom PIL import Imagesess = rt.InferenceSession("model/model_conv.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.numpy()labels = labels.numpy()outputs = sess.run(None,{input_name:images})[0]outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i]label = labels[i]plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()

至此,模型已经成功的转换成onnx模型,可以用于后续各种部署环境的部署。

6 opencv部署

本例中,使用C++/opencv来尝试部署刚才训练的模型。输入为在之前的博文中提到的将MNIST测试集导出成png图片保存。

#include "opencv2/opencv.hpp"#include <iostream>
#include <filesystem>
#include <string>
#include <vector>int main(int argc, char** argv)
{if (argc != 3){std::cerr << "Usage: MNISTClassifier_onnx_opencv <onnx_model_path> <image_path>" << std::endl;return 1;}cv::dnn::Net net = cv::dnn::readNetFromONNX(argv[1]);if (net.empty()){std::cout << "Error: Failed to load ONNX file." << std::endl;return 1;}std::filesystem::path srcPath(argv[2]);for (auto& imgPath : std::filesystem::recursive_directory_iterator(srcPath)){if(!std::filesystem::is_regular_file(imgPath))continue;const cv::Mat image = cv::imread(imgPath.path().string(), cv::IMREAD_GRAYSCALE);if (image.empty()){std::cerr << "Error: Failed to read image file." << std::endl;continue;}const cv::Size size(28, 28);cv::Mat resized_image;cv::resize(image, resized_image, size);cv::Mat float_image;resized_image.convertTo(float_image, CV_32F, 1.0 / 255.0);cv::Mat input_blob = cv::dnn::blobFromImage(float_image);net.setInput(input_blob);cv::Mat output = net.forward();cv::Point classIdPoint;double confidence;cv::minMaxLoc(output.reshape(1, 1), nullptr, &confidence, nullptr, &classIdPoint);const int class_id = classIdPoint.x;std::cout << "Class ID: " << class_id << std::endl;std::cout << "Confidence: " << confidence << std::endl;cv::Mat bigImg;cv::resize(image,bigImg,cv::Size(128,128));auto parentPath = imgPath.path().parent_path();auto label = parentPath.filename().string()+std::string("<->")+std::to_string(class_id);cv::putText(bigImg, label, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255), 1);cv::imshow("img",bigImg);cv::waitKey();}return 0;
}

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

7 总结

使用卷积神经网络进行MNIST手写字符识别,在模型结构无明显复杂的情况下,模型收敛速度较全连接层构建的网络收敛速度快。

按照相同的套路导出成onnx模型之后,直接通过opencv可以部署,简化深度学习算法部署的难度。

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

这篇关于MNIST手写字符分类-卷积的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1057493

相关文章

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

C#如何去掉文件夹或文件名非法字符

《C#如何去掉文件夹或文件名非法字符》:本文主要介绍C#如何去掉文件夹或文件名非法字符的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录C#去掉文件夹或文件名非法字符net类库提供了非法字符的数组这里还有个小窍门总结C#去掉文件夹或文件名非法字符实现有输入字

idea报错java: 非法字符: ‘\ufeff‘的解决步骤以及说明

《idea报错java:非法字符:‘ufeff‘的解决步骤以及说明》:本文主要介绍idea报错java:非法字符:ufeff的解决步骤以及说明,文章详细解释了为什么在Java中会出现uf... 目录BOM是什么?1. BOM的作用2. 为什么会出现 \ufeff 错误?3. 如何解决 \ufeff 问题?最

使用Java编写一个字符脱敏工具类

《使用Java编写一个字符脱敏工具类》这篇文章主要为大家详细介绍了如何使用Java编写一个字符脱敏工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、字符脱敏工具类2、测试工具类3、测试结果1、字符脱敏工具类import lombok.extern.slf4j.Slf4j

解决IDEA报错:编码GBK的不可映射字符问题

《解决IDEA报错:编码GBK的不可映射字符问题》:本文主要介绍解决IDEA报错:编码GBK的不可映射字符问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录IDEA报错:编码GBK的不可映射字符终端软件问题描述原因分析解决方案方法1:将命令改为方法2:右下jav

Pandas使用AdaBoost进行分类的实现

《Pandas使用AdaBoost进行分类的实现》Pandas和AdaBoost分类算法,可以高效地进行数据预处理和分类任务,本文主要介绍了Pandas使用AdaBoost进行分类的实现,具有一定的参... 目录什么是 AdaBoost?使用 AdaBoost 的步骤安装必要的库步骤一:数据准备步骤二:模型

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

C语言字符函数和字符串函数示例详解

《C语言字符函数和字符串函数示例详解》本文详细介绍了C语言中字符分类函数、字符转换函数及字符串操作函数的使用方法,并通过示例代码展示了如何实现这些功能,通过这些内容,读者可以深入理解并掌握C语言中的字... 目录一、字符分类函数二、字符转换函数三、strlen的使用和模拟实现3.1strlen函数3.2st

C# string转unicode字符的实现

《C#string转unicode字符的实现》本文主要介绍了C#string转unicode字符的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随... 目录1. 获取字符串中每个字符的 Unicode 值示例代码:输出:2. 将 Unicode 值格式化

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep