基于LeNet5的手写数字识别神经网络

2024-01-02 03:48

本文主要是介绍基于LeNet5的手写数字识别神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于CNN,迄今为止已经提出了各种网络结构。在这里我们着重介绍一下在1998年首次被提出的CNN元组LeNet。LeNet子啊1998年被提出,是进行手写数字识别的网络,如下图所示,他又连续的卷积层和池化层(正确地讲,是只“抽选元素”的子采样层),最后经全连接输出结果。

在初始的LeNet中,输入时32*32的图像,经过卷积层输出channel为6,大小28*28的feature map,在经过子采样(Subsampling)池化后,将图像大小变为14*14,(stride=2)在进行卷积,output_channel变为16,大小10*10,在经过一层子采样池化,将图像最终变为5*5,传给全连接层,经过全连接层处理后输出。具体处理流程如下图:

和现在的CNN相比,LeNet有几个不同点。第一个不同点在于激活函数,LeNet中使用的是sigmoid函数 ,而现在的CNN中主要使用ReLU函数。此外,原始的LeNet中使用子采样(Subsampling)缩小中间数据大小,而现在的CNN中Max池化是主流。

下面我们完成一个基于LeNet5的网络对MNIST数据集的识别:

首先我们先建立数据集,在这里可以说利用datasets下载这样的简易数据集简直不要太好用

mnist_train = datasets.MNIST('MNIST',True,transform=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()]),download=True)mnist_train = DataLoader(mnist_train,batch_size=batch_size,shuffle=True)mnist_test = datasets.MNIST('MNIST',False,transform=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()]),download=True)mnist_test = DataLoader(mnist_test,batch_size=batch_size,shuffle=True)

我们对下载好的数据集进行输出,看看情况怎么样(batch_size = 32)

 x,label = iter(mnist_train).next()print('x:',x.shape,' label:',label.shape)#输出结果:x: torch.Size([32, 1, 28, 28])  label: torch.Size([32])

下面我们来建立一个LeNet网络:

class lenet5(nn.Module):"""for MNIST DATASET"""def __init__(self):super(lenet5, self).__init__()# convolutionsself.cov_unit = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,stride=1,padding=1),nn.MaxPool2d(kernel_size=2,stride=2,padding=0),nn.Conv2d(6,16,kernel_size=5,stride=1,padding=1),nn.MaxPool2d(kernel_size=2,stride=2,padding=0))#flattenself.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))def forward(self,x):batchsz = x.size(0)x = self.cov_unit(x)x = x.view(batchsz,16*5*5)logits = self.fc_unit(x)return logits

在这里需要借鉴一下:LeNet论文阅读:LeNet结构以及参数个数计算_silent56_th的博客-CSDN博客icon-default.png?t=L892https://blog.csdn.net/silent56_th/article/details/53456522

博主的博客内对输入数据和隐藏层的参数分析以及为何不全采用全连接做了解释:我自己对于kernel_size部分的参数选定还有理解不到位的地方,在这里借鉴一下:

S1-C2对应关系

已经搭建好了LeNet网络,下面定义优化器和损失函数已经利用GPU进行加速:

device = torch.device('cuda')
model = lenet5().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(),lr=1e-3)

 在这里我们可以将model在控制台打印出来,观察一下,在整体观察一下LeNet网络模型:

# 控制台打印输出
model: lenet5((cov_unit): Sequential((0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc_unit): Sequential((0): Linear(in_features=400, out_features=120, bias=True)(1): ReLU()(2): Linear(in_features=120, out_features=84, bias=True)(3): ReLU()(4): Linear(in_features=84, out_features=10, bias=True))
)

 从这里可以看到,基本是我们所设置的一个网络模型,现在网络已经搭建完毕,优化器和参数都以设定,下面开始进行训练:

​
for batchidx,(x,label) in enumerate(mnist_train):x,label = x.to(device),label.to(device)logits = model(x)loss = criteon(logits,label)optimizer.zero_grad()loss.backward()optimizer.step()print('epoch:',epoch,' loss:',loss.item())

这里的logits原指sigmoid函数(标准logits函数),但是在这里用来表示最终全连接层输出,而非其本意。在每个epoch结束后,将损失函数loss的值打印在控制台

下面是进行的测试:

 model.eval()with torch.no_grad():total_num = 0total_correct = 0for x,label in mnist_test:x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum().item()total_num += x.size(0)acc = total_correct/total_numprint('epoch:',epoch,' accuarcy:',acc)

 这里的pred = logits.argmax(dim=1),argmax函数是返回最大值的索引,即经过训练后预测结果概率最大的索引,这里将pred和监督标签label进行比较,如果equal便加到total_correct中,最后计算acc。

在进行训练和测试之前,分别添加了model.train()以及model.eval()

(1). model.train()
启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
(2). model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False

BatchNormalization的思路是调整各层的激活值分布使其拥有适当的广度,要向神经网络中插入数据分布进行正规化的层,可以使学习快速进行、不那么依赖初始值同时还可以一定程度抑制过拟合

Dropout是一种在学习过程中随机删除神经元的方法,通过随机选择并删除神经元,停止向前传递信号,使用Dropout可以使训练数据和测试数据的识别精度的差距变小了,即使是表现力很强的网络,也可以抑制过拟合。

最后,我们为了使数据可以更好的展现和反馈,我们利用visdom进行可视化

viz = Visdom()
viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
global_step = 0

对于train_loss,从[0,0]坐标开始,每一个epoch执行完,global_step += 1 

  global_step += 1viz.line([loss.item()],[global_step],win='train_loss', update='append')

 将本次epoch内计算的loss以折线图的方式绘制

viz.images(x.view(-1, 1, 28, 28), win='x')
viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))

此时x.shape为[16,1,28,28],str(pred.detach().cpu.numpy())是将预测值变为数据类型打印出来

经过15个epoch我们可以看到,这个识别的精确度已经很高了,我们在看一下visdom可视化的结果:

 也是可以看到的,虽然略有起伏,但是train_loss还是在逐步下降的,我们抽取了10个数据进行展示,可以看到, 预测的结果也是十分准确的。

Conclusion:LeNet只是在1998年最早提出来的CNN,与现在的CNN虽然有些许不同,但是差别也不是很大,考虑到提出的时间很早,所以LeNet还是十分令人称奇的

这篇关于基于LeNet5的手写数字识别神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561264

相关文章

Python中图片与PDF识别文本(OCR)的全面指南

《Python中图片与PDF识别文本(OCR)的全面指南》在数据爆炸时代,80%的企业数据以非结构化形式存在,其中PDF和图像是最主要的载体,本文将深入探索Python中OCR技术如何将这些数字纸张转... 目录一、OCR技术核心原理二、python图像识别四大工具库1. Pytesseract - 经典O

Python基于微信OCR引擎实现高效图片文字识别

《Python基于微信OCR引擎实现高效图片文字识别》这篇文章主要为大家详细介绍了一款基于微信OCR引擎的图片文字识别桌面应用开发全过程,可以实现从图片拖拽识别到文字提取,感兴趣的小伙伴可以跟随小编一... 目录一、项目概述1.1 开发背景1.2 技术选型1.3 核心优势二、功能详解2.1 核心功能模块2.

Python验证码识别方式(使用pytesseract库)

《Python验证码识别方式(使用pytesseract库)》:本文主要介绍Python验证码识别方式(使用pytesseract库),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录1、安装Tesseract-OCR2、在python中使用3、本地图片识别4、结合playwrigh

使用Python和PaddleOCR实现图文识别的代码和步骤

《使用Python和PaddleOCR实现图文识别的代码和步骤》在当今数字化时代,图文识别技术的应用越来越广泛,如文档数字化、信息提取等,PaddleOCR是百度开源的一款强大的OCR工具包,它集成了... 目录一、引言二、环境准备2.1 安装 python2.2 安装 PaddlePaddle2.3 安装

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

java字符串数字补齐位数详解

《java字符串数字补齐位数详解》:本文主要介绍java字符串数字补齐位数,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java字符串数字补齐位数一、使用String.format()方法二、Apache Commons Lang库方法三、Java 11+的St

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

讯飞webapi语音识别接口调用示例代码(python)

《讯飞webapi语音识别接口调用示例代码(python)》:本文主要介绍如何使用Python3调用讯飞WebAPI语音识别接口,重点解决了在处理语音识别结果时判断是否为最后一帧的问题,通过运行代... 目录前言一、环境二、引入库三、代码实例四、运行结果五、总结前言基于python3 讯飞webAPI语音

使用Python开发一个图像标注与OCR识别工具

《使用Python开发一个图像标注与OCR识别工具》:本文主要介绍一个使用Python开发的工具,允许用户在图像上进行矩形标注,使用OCR对标注区域进行文本识别,并将结果保存为Excel文件,感兴... 目录项目简介1. 图像加载与显示2. 矩形标注3. OCR识别4. 标注的保存与加载5. 裁剪与重置图像