利用PyTorch构建三层线性网络完成对MNIST数据集识别

2024-01-02 03:48

本文主要是介绍利用PyTorch构建三层线性网络完成对MNIST数据集识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里首先简单介绍一些MNIST数据集:

MNIST数据集内共包含70000张手写数字图像,数字范围0~9,大小为28*28,其中60000张用于训练学习,10000张用于数据测试,图像为灰度图像,数字位置居中,可以减少预处理和加快运行

在学习编程入门时,无论哪个语言,hello world往往是第一步,再进行深度学习入门时,MNIST数据集研究透了,基本就可以入门了。

下面向大家展示一下MNIST数据集内图像:

--------------------------------------------------------------------------------------------------------------------------------

下面进入今天的正题:利用PyTorch搭建一个三层线性网络,完成对MNIST数据集的训练并且进行测试 :

本次demo包含两个目录文件,一个是utils.py,另一个是mnist_train.py,在utils.py内我们放置了三个函数,分别是plot_curve,plot_image和one_hot。

plot_curve:用于绘制对MNIST数据集进行训练时损失函数曲线,方便观察

def plot_curve(data):fig = plt.figure()plt.plot(range(len(data)),data,color = 'blue')plt.legend(['value'],loc = 'upper right')plt.xlabel('step')plt.ylabel('value')plt.show()

plot_image:对于训练和识别过程,可以很方便的将训练结果可视化

def plot_image(img,label,name):fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()

one_hot:PyTorch内还没有对one-hot函数的实现,在这用scatter完成简单的一个编码

注:one_hot编码:

One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。

One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。

def one_hot(label,depth=10):out = torch.zeros(label.size(0),depth)idx = torch.LongTensor(label).view(-1,1)out.scatter_(dim = 1,index = idx,value = 1)return out

这样,我们的utils.py里面的三个工具函数就已经编码完毕,这三个函数只是达到辅助可视化的作用,不会对训练和测试产生任何影响,所以大家如果图省事,放在一个函数里也可以

下面我们实现MNIST的train:

1.导入相关包:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hot

2.准备数据集: 

在这儿我们使用DataLoader完成数据集的下载,是一种十分方便的方式,这个地方是通过torch vision实现的。

torchvision是PyTorch的一个图形库,服务于PyTorch深度学习框架,构建计算机视觉模型

torchvision.transforms:常用的图像预处理方法,利用Compose将对图片的操作整合起来

torchvision.datasets:常用的datasets数据集实现,如MNIST,CIFAR10等

torchvision.model:常用的模型预训练,如LeNet,ResNet,VGG等

解释一下这里的ToTensor:数据归一化到均值为0,方差为1(是将数据除以255),即图像进来以后,先进行通道转换,然后判断图像类型,若是uint8类型,就除以255;否则返回原图。

而这里的Normalize是对数据按通道进行标准化,即减去均值,再除以方差

其中,0.1307和0.3081是mnist数据集的均值和标准差,因为mnist数据值都是灰度图,所以图像的通道数只有一个,因此均值和标准差各一个。要是imagenet数据集的话,由于它的图像都是RGB图像,因此他们的均值和标准差各3个,分别对应其R,G,B值。例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的标准化系数(RGB三个通道对应三组系数)。数据集给出的均值和标准差系数,每个数据集都不同的,都是数据集提供方给出的。

transforms.Normalize(mean,std)的计算公式是:input=\frac{input - mean}{std}

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)

因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。

那transform.Normalize()是怎么工作的呢?以上面代码为例,ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1)

3.创建网络

class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x

这里采用的是最最最基本的线性层连接,共有三层,激活函数采用的relu函数

在这里Linear会把28*28的图像铺平展开为784个元素的一维数组后进行处理,在forward内不断传给下一层。

这个地方只写了前向传播的forward(),并没有写反向传播的backward()是因为在pytorch的求导过程中,有以下两种情况:

如果是标量对向量求导(scalar对tensor求导),那么就可以保证上面的计算图的根节点只有一个,此时不用引入grad_tensors参数,直接调用backward函数即可
如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。

4.Train

for epoch in range(5):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)

经过五轮peoch完成对MNIST60000张图片的训练,并将每轮结果打印出来,将损失函数记录,并调用plot_curve展现train_loss下降折线图

5.Test

total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

 预测值pred取结果中间概率最大的index值作为他的label,可以用过argmax返回最大值的索引,上述代码即可以理解为在dim=1处,取最大值索引

而正确值correct是将y和pred之间做一个比较,利用sum()可得到当前batch中预测结果正确的一个总个数,最终是Tensor类型,再将其转换为数据类型,加上item(),最后total_correct累加

之后就是调用工具函数将数据可视化

mnist_train.py完整代码如下:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hotbatch_size = 512   #GPU单次运行处理图片的数量         批处理大小
#step1.load mnist
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)x,y = next(iter((train_loader)))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x,y,'image sample')#step2.create network
class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
#[w1,w2,w3,b1,b2,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)train_loss = []for epoch in range(3):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)
# we get optimal [w1,b1,w2,b2,w3,b3]total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

下面将结果展示一下

上图为读取mnist_train中的数据,展示图片以及sample

 

 这是损失函数train_loss的图像,可以看到,随着学习的深入,损失值是在不断下降的,最后逐渐趋于稳定,但是本次demo只是利用最简单的三层结构,且利用的SGD,在其它方法的训练下,可能会获得更好的训练效果

识别测试结果,可以看到,准确率还是比较高的,最后的acc大概达到了接近90%

这就是本次的小demo,有很多理解自己也不算吃的很透,随着学习的深入会理解的更加透彻

最后,感谢各位看官,欢迎批评,互相学习进步! 

 

这篇关于利用PyTorch构建三层线性网络完成对MNIST数据集识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561262

相关文章

使用Python和PaddleOCR实现图文识别的代码和步骤

《使用Python和PaddleOCR实现图文识别的代码和步骤》在当今数字化时代,图文识别技术的应用越来越广泛,如文档数字化、信息提取等,PaddleOCR是百度开源的一款强大的OCR工具包,它集成了... 目录一、引言二、环境准备2.1 安装 python2.2 安装 PaddlePaddle2.3 安装

Java注解之超越Javadoc的元数据利器详解

《Java注解之超越Javadoc的元数据利器详解》本文将深入探讨Java注解的定义、类型、内置注解、自定义注解、保留策略、实际应用场景及最佳实践,无论是初学者还是资深开发者,都能通过本文了解如何利用... 目录什么是注解?注解的类型内置注编程解自定义注解注解的保留策略实际用例最佳实践总结在 Java 编程

一文教你Python如何快速精准抓取网页数据

《一文教你Python如何快速精准抓取网页数据》这篇文章主要为大家详细介绍了如何利用Python实现快速精准抓取网页数据,文中的示例代码简洁易懂,具有一定的借鉴价值,有需要的小伙伴可以了解下... 目录1. 准备工作2. 基础爬虫实现3. 高级功能扩展3.1 抓取文章详情3.2 保存数据到文件4. 完整示例

Java中的StringBuilder之如何高效构建字符串

《Java中的StringBuilder之如何高效构建字符串》本文将深入浅出地介绍StringBuilder的使用方法、性能优势以及相关字符串处理技术,结合代码示例帮助读者更好地理解和应用,希望对大家... 目录关键点什么是 StringBuilder?为什么需要 StringBuilder?如何使用 St

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

python处理带有时区的日期和时间数据

《python处理带有时区的日期和时间数据》这篇文章主要为大家详细介绍了如何在Python中使用pytz库处理时区信息,包括获取当前UTC时间,转换为特定时区等,有需要的小伙伴可以参考一下... 目录时区基本信息python datetime使用timezonepandas处理时区数据知识延展时区基本信息

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll

Pandas统计每行数据中的空值的方法示例

《Pandas统计每行数据中的空值的方法示例》处理缺失数据(NaN值)是一个非常常见的问题,本文主要介绍了Pandas统计每行数据中的空值的方法示例,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是空值?为什么要统计空值?准备工作创建示例数据统计每行空值数量进一步分析www.chinasem.cn处