手写数据集minist基于pytorch分类学习

2024-06-03 22:12

本文主要是介绍手写数据集minist基于pytorch分类学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.Mnist数据集介绍
1.1 基本介绍
Mnist数据集可以算是学习深度学习最常用到的了。这个数据集包含70000张手写数字图片,分别是60000张训练图片和10000张测试图片,训练集由来自250个不同人手写的数字构成,一般来自高中生,一半来自工作人员,测试集(test set)也是同样比例的手写数字数据,并且保证了测试集和训练集的作者不同。每个图片都是2828个像素点,数据集会把一张图片的数据转成一个2828=784的一维向量存储起来。
里面的图片数据如下所示,每张图是0-9的手写数字黑底白字的图片,存储时,黑色用0表示,白色用0-1的浮点数表示。


1.2 数据集下载
1)官网下载
Mnist数据集的下载地址如下:http://yann.lecun.com/exdb/mnist/
打开后会有四个文件:


训练数据集:train-images-idx3-ubyte.gz
训练数据集标签:train-labels-idx1-ubyte.gz
测试数据集:t10k-images-idx3-ubyte.gz
测试数据集标签:t10k-labels-idx1-ubyte.gz
将这四个文件下载后放置到需要用的文件夹下即可不要解压!下载后是什么就怎么放!

2)代码导入
文件夹下运行下面的代码,即可自动检测数据集是否存在,若没有会自动进行下载,下载后在这一路径:

下载数据集:

# 下载数据集
from torchvision import datasets, transformstrain_set = datasets.MNIST("data",train=True,download=True, transform=transforms.ToTensor(),)
test_set = datasets.MNIST("data",train=False,download=True, transform=transforms.ToTensor(),)

参数解释:

datasets.MNIST:是Pytorch的内置函数torchvision.datasets.MNIST,可以导入数据集
train=True :读入的数据作为训练集
transform:读入我们自己定义的数据预处理操作
download=True:当我们的根目录(root)下没有数据集时,便自动下载
如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:原文件夹/data/MNIST/raw

14轮左右,模型识别准确率达到98%以上

 

 加载数据集

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

构建模型,模型主要由两个卷积层,两个池化层,以及一个全连接层构成,激活函数使用relu. 

 

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e

模型训练
每次训练完成后会自动保存参数到pkl模型中,如果路径中有Pkl文件,下次运行会自动加载上一次的模型参数,在这个基础上继续训练,第一次运行时没有模型参数,结束后会自动生成。

# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))

加载模型
第一次运行这里需要一个空的model文件夹

if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))

模型测试

def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))

自己手写数字图片识别函数(可选用)
这部分主要是加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()

MNIST中的数据识别测试数据
训练过程中的打印信息我进行了修改,这里设置的训练轮数是15轮,每次训练生成的pkl模型参数也是会更新的,想要更多训练信息可以查看对应的教程哦~

if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

完整代码:

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

 

 

 

 

这篇关于手写数据集minist基于pytorch分类学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1028251

相关文章

Java注解之超越Javadoc的元数据利器详解

《Java注解之超越Javadoc的元数据利器详解》本文将深入探讨Java注解的定义、类型、内置注解、自定义注解、保留策略、实际应用场景及最佳实践,无论是初学者还是资深开发者,都能通过本文了解如何利用... 目录什么是注解?注解的类型内置注编程解自定义注解注解的保留策略实际用例最佳实践总结在 Java 编程

一文教你Python如何快速精准抓取网页数据

《一文教你Python如何快速精准抓取网页数据》这篇文章主要为大家详细介绍了如何利用Python实现快速精准抓取网页数据,文中的示例代码简洁易懂,具有一定的借鉴价值,有需要的小伙伴可以了解下... 目录1. 准备工作2. 基础爬虫实现3. 高级功能扩展3.1 抓取文章详情3.2 保存数据到文件4. 完整示例

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

python处理带有时区的日期和时间数据

《python处理带有时区的日期和时间数据》这篇文章主要为大家详细介绍了如何在Python中使用pytz库处理时区信息,包括获取当前UTC时间,转换为特定时区等,有需要的小伙伴可以参考一下... 目录时区基本信息python datetime使用timezonepandas处理时区数据知识延展时区基本信息

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Pandas使用AdaBoost进行分类的实现

《Pandas使用AdaBoost进行分类的实现》Pandas和AdaBoost分类算法,可以高效地进行数据预处理和分类任务,本文主要介绍了Pandas使用AdaBoost进行分类的实现,具有一定的参... 目录什么是 AdaBoost?使用 AdaBoost 的步骤安装必要的库步骤一:数据准备步骤二:模型

Pandas统计每行数据中的空值的方法示例

《Pandas统计每行数据中的空值的方法示例》处理缺失数据(NaN值)是一个非常常见的问题,本文主要介绍了Pandas统计每行数据中的空值的方法示例,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是空值?为什么要统计空值?准备工作创建示例数据统计每行空值数量进一步分析www.chinasem.cn处