神经网络第四篇:推理处理之手写数字识别

2024-06-24 11:18

本文主要是介绍神经网络第四篇:推理处理之手写数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

到目前为止,我们已经介绍完了神经网络的基本结构,现在用一个图像识别示例对前面的知识作整体的总结。本专题知识点如下:

  • MNIST数据集
  • 图像数据转图像
  • 神经网络的推理处理
  • 批处理

  •  MNIST数据集         
mnist数据图像

MNIST数据集由0到9的数字图像构成。像素取值在0到255之间。每个图像数据都相应地标有“7”、“2”、“1”等数字标签。MNIST数据集中,训练数据有6万张,测试图像有1万张。一般先用训练数据进行学习,再用学习到的模型(参数)对测试图像进行识别分类。MNIST数据集可以从官网下载,这里我们用python获取已经下载好并做过处理的MNIST数据集的相关信息:

#数据文件:mnist.pkl,大小约54M。
#文件读取,位置
import pandas as pd
network=pd.read_pickle('F:/deep-learning with python/dataset/mnist.pkl')
type(network) #类型:字典
network.keys() #关键字['train_label', 'train_img', 'test_img', 'test_label']
#训练数据形状,(60000, 784),6万个样本,每个样本由784个数据组成(1·28·28)
network['train_img'].shape
network['train_img'][0,:].max()  #第一个样本的最大值255
network['train_img'][0,:].min()  #第一个样本的最小值0,
network['train_label'].shape #训练标签形状(60000,),由0至9组成的6万个数据
network['train_label'].max() #最大值9
network['train_label'].min() #最小值0
network['train_label']  #训练标签:0~9
network['test_img'].shape #测试数据形状(10000, 784),1万个样本
network['test_label'].shape#测试数据标签形状(10000,)
network['test_label']  #测试标签:0~9

MNIST数据集保存在mnist.pkl,读者可点击:mnist数据集及权重参数下载 进行下载,mnist数据集下载的源码如下:

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

为方便训练和预测,一般以(训练数据,训练标签),(测试数据,测试标签)的形式修改数据格式。


  • 图像数据转图像                                                                                                                                       

​​​​​MNIST数据集是图像数据,每一张图的大小为28像素×28像素,像素值在0至255之间。在MNIST数据集中的形状是以一列数组的形式保存的(784个像素),因此要显示为图像时需要对数据进行相关转换。PIL是Python的图像库,可用于显示图像,这里我们用它来将MNIST图像数据转换为图像进行显示,方便大家对本主题知识的理解 。

import numpy as np
from PIL import Image
img_data=network['train_img'][1]     #训练数据中第二个样本的图像数据,(784,)
img_label=network['train_label'][1] #训练数据中第2个样本的标签为
print(img_label)                     #  0
img=img_data.reshape(28,28)#转为图像尺寸
print(img.shape)           #(28, 28)
show_img=Image.fromarray(np.uint8(img)) #转换为PIL显示图像的数据格式
show_img.show()                         #图像显示,0
第二张训练图像数据的图像显示

  • 神经网络的推理处理 

前面介绍的PIL库显示的训练数据中的第二个样本的图形是“0”,实际标签也是“0”。所谓神经网络的推理,即利用神经网络对训练数据进行学习(这里我们先直接使用学到的参数,保存在sample_weight.pkl文件中),利用学到的参数(w,b)对测试数据(test_img)进行识别,然后将识别结果与实际标签(test_label)进行比较,判断推理是否正确。原理如下图:

具体的推理处理过程为:

图像数据有784个像素(28×28),即输入层有784个神经元,推理的结果是识别数字0到9,因此输出层有10个神经元。此外,为和前面知识点相连,我们在这个神经网络中添加了2个隐藏层,第一个隐藏层有50个神经元,第二个隐藏层有100个神经元,隐藏层中神经元的个数可自己设置。由于本主题不涉及参数的学习,因此我们直接使用已学习到的参数,它保存在文件sample_weight.pkl中,根据我们设计的神经网络的结构,我们应该知道参数的结构。下面是神经网络推理的代码:

parms=pd.read_pickle('F:/deep-learning with python/ch03/sample_weight.pkl')
type(parms)  #字典
parms.keys() #['b1', 'W1', 'b2', 'W2', 'b3', 'W3']
parms['b1'].shape #(50,)
parms['W1'].shape #(784,50)
parms['b2'].shape #(100,)
parms['W2'].shape #(50,100)
parms['b3'].shape #(10,)
parms['W3'].shape #(100,10)     
def predict(parms,x):"""代码同三层神经网络的实现一样,只是将随机参数改为实际学到的参数 6激活函数在前面专题已给出"""W1,W2,W3=parms['W1'],parms['W2'],parms['W3']b1,b2,b3=parms['b1'],parms['b2'],parms['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)  #激活函数可在以前的文章中找到a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)   #激活函数可在以前的文章中找到return ytest_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签
accuracy_cnt=0
for i in range(len(test_img)):y=predict(parms,test_img[i])p=np.argmax(y)#获取概率最高的元素的索引if p==test_label[i]:accuracy_cnt+=1
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))  #0.9352

下面我们对代码做简单介绍,首先提取测试数据和测试标签。接着用for循环逐一取出测试数据中的图像数据,然后用predict()函数进行分类,该函数输出各个标签对应的概率,比如输出[0.1,0.2,0.4…,0.03],表示“0”的概率为0.1,1的概率为0.2,9的概率为0.03。然后我们取出这个概率列表中的最大值的索引即为分类结果。最后比较神经网络预测的结果和正确解标签(test_label),对1万张图预测正确的概率作为识别精度(93.52%)。         

在机器学习领域中,一般需要考虑数据预处理,这里我们可将像素0至255可缩小到0至1的范围内(即对所有数据均除以255),然后再输入至神经网络中,这种将数据限制在某个范围内的处理称为正则化,一般情况下,预处理会改善机器学习模型。读者可比较一下图像数据正则化后神经网络的识别精度。


  • 批处理

 上面只介绍了输入一张图像数据时的处理流程。即每次向神经网络中输入一个由784个元素(原本是一个28·28的二维数组)构成的一维数组后,输出一个有10个元素的一维数组。其数据形状如下:

现在我们想predict()函数一次性打包处理100张图像。为此可把X的形状改为100×784,将100张图片打包作为输入数据。数据形状如下:

批处理数据形状

可见,输入数据的形状为100×784,输出数据的形状为100×100,这说明了输入的100张图像的推理结果被一次性输出了。例如x[0]、x[1]....x[99]和y[0]、y[1]....y[99]保存了第1、2....到100张图像的图像数据及其推理结果。这种被打包的输入数据被称为(batch),批处理主要集中在数据计算上,而不是数据读入,因此批处理可缩短时间开销。下面用代码实现如下:

test_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签batch_size=100 #批数量
accuracy_cnt=0 #初始识别精度for i in range(0,len(test_img),batch_size):x_batch=test_img[i:i+batch_size]y_batch=predict(parms,x_batch)p=np.argmax(y_batch,axis=1)#取每列最大值accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))

批处理代码核心在于for循环语句添加了步数batch_size,输入predict()函数的参数x由以前的单条数据变为x_batch表示的100条数据,寻找二维数组中每行的最大值所在的列位置使用了参数axis=1。

至此,神经网络的基本知识已经讲解完了,后面的内容主要讲解权重参数的学习!欢迎关注微信公众号“Python生态智联”,学知识,享生活!

这篇关于神经网络第四篇:推理处理之手写数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1089999

相关文章

SpringBoot分段处理List集合多线程批量插入数据方式

《SpringBoot分段处理List集合多线程批量插入数据方式》文章介绍如何处理大数据量List批量插入数据库的优化方案:通过拆分List并分配独立线程处理,结合Spring线程池与异步方法提升效率... 目录项目场景解决方案1.实体类2.Mapper3.spring容器注入线程池bejsan对象4.创建

PHP轻松处理千万行数据的方法详解

《PHP轻松处理千万行数据的方法详解》说到处理大数据集,PHP通常不是第一个想到的语言,但如果你曾经需要处理数百万行数据而不让服务器崩溃或内存耗尽,你就会知道PHP用对了工具有多强大,下面小编就... 目录问题的本质php 中的数据流处理:为什么必不可少生成器:内存高效的迭代方式流量控制:避免系统过载一次性

Python实现批量CSV转Excel的高性能处理方案

《Python实现批量CSV转Excel的高性能处理方案》在日常办公中,我们经常需要将CSV格式的数据转换为Excel文件,本文将介绍一个基于Python的高性能解决方案,感兴趣的小伙伴可以跟随小编一... 目录一、场景需求二、技术方案三、核心代码四、批量处理方案五、性能优化六、使用示例完整代码七、小结一、

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

PHP应用中处理限流和API节流的最佳实践

《PHP应用中处理限流和API节流的最佳实践》限流和API节流对于确保Web应用程序的可靠性、安全性和可扩展性至关重要,本文将详细介绍PHP应用中处理限流和API节流的最佳实践,下面就来和小编一起学习... 目录限流的重要性在 php 中实施限流的最佳实践使用集中式存储进行状态管理(如 Redis)采用滑动

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

Python自动化处理PDF文档的操作完整指南

《Python自动化处理PDF文档的操作完整指南》在办公自动化中,PDF文档处理是一项常见需求,本文将介绍如何使用Python实现PDF文档的自动化处理,感兴趣的小伙伴可以跟随小编一起学习一下... 目录使用pymupdf读写PDF文件基本概念安装pymupdf提取文本内容提取图像添加水印使用pdfplum

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

基于Redis自动过期的流处理暂停机制

《基于Redis自动过期的流处理暂停机制》基于Redis自动过期的流处理暂停机制是一种高效、可靠且易于实现的解决方案,防止延时过大的数据影响实时处理自动恢复处理,以避免积压的数据影响实时性,下面就来详... 目录核心思路代码实现1. 初始化Redis连接和键前缀2. 接收数据时检查暂停状态3. 检测到延时过

Java利用@SneakyThrows注解提升异常处理效率详解

《Java利用@SneakyThrows注解提升异常处理效率详解》这篇文章将深度剖析@SneakyThrows的原理,用法,适用场景以及隐藏的陷阱,看看它如何让Java异常处理效率飙升50%,感兴趣的... 目录前言一、检查型异常的“诅咒”:为什么Java开发者讨厌它1.1 检查型异常的痛点1.2 为什么说