神经网络第四篇:推理处理之手写数字识别

2024-06-24 11:18

本文主要是介绍神经网络第四篇:推理处理之手写数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

到目前为止,我们已经介绍完了神经网络的基本结构,现在用一个图像识别示例对前面的知识作整体的总结。本专题知识点如下:

  • MNIST数据集
  • 图像数据转图像
  • 神经网络的推理处理
  • 批处理

  •  MNIST数据集         
mnist数据图像

MNIST数据集由0到9的数字图像构成。像素取值在0到255之间。每个图像数据都相应地标有“7”、“2”、“1”等数字标签。MNIST数据集中,训练数据有6万张,测试图像有1万张。一般先用训练数据进行学习,再用学习到的模型(参数)对测试图像进行识别分类。MNIST数据集可以从官网下载,这里我们用python获取已经下载好并做过处理的MNIST数据集的相关信息:

#数据文件:mnist.pkl,大小约54M。
#文件读取,位置
import pandas as pd
network=pd.read_pickle('F:/deep-learning with python/dataset/mnist.pkl')
type(network) #类型:字典
network.keys() #关键字['train_label', 'train_img', 'test_img', 'test_label']
#训练数据形状,(60000, 784),6万个样本,每个样本由784个数据组成(1·28·28)
network['train_img'].shape
network['train_img'][0,:].max()  #第一个样本的最大值255
network['train_img'][0,:].min()  #第一个样本的最小值0,
network['train_label'].shape #训练标签形状(60000,),由0至9组成的6万个数据
network['train_label'].max() #最大值9
network['train_label'].min() #最小值0
network['train_label']  #训练标签:0~9
network['test_img'].shape #测试数据形状(10000, 784),1万个样本
network['test_label'].shape#测试数据标签形状(10000,)
network['test_label']  #测试标签:0~9

MNIST数据集保存在mnist.pkl,读者可点击:mnist数据集及权重参数下载 进行下载,mnist数据集下载的源码如下:

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

为方便训练和预测,一般以(训练数据,训练标签),(测试数据,测试标签)的形式修改数据格式。


  • 图像数据转图像                                                                                                                                       

​​​​​MNIST数据集是图像数据,每一张图的大小为28像素×28像素,像素值在0至255之间。在MNIST数据集中的形状是以一列数组的形式保存的(784个像素),因此要显示为图像时需要对数据进行相关转换。PIL是Python的图像库,可用于显示图像,这里我们用它来将MNIST图像数据转换为图像进行显示,方便大家对本主题知识的理解 。

import numpy as np
from PIL import Image
img_data=network['train_img'][1]     #训练数据中第二个样本的图像数据,(784,)
img_label=network['train_label'][1] #训练数据中第2个样本的标签为
print(img_label)                     #  0
img=img_data.reshape(28,28)#转为图像尺寸
print(img.shape)           #(28, 28)
show_img=Image.fromarray(np.uint8(img)) #转换为PIL显示图像的数据格式
show_img.show()                         #图像显示,0
第二张训练图像数据的图像显示

  • 神经网络的推理处理 

前面介绍的PIL库显示的训练数据中的第二个样本的图形是“0”,实际标签也是“0”。所谓神经网络的推理,即利用神经网络对训练数据进行学习(这里我们先直接使用学到的参数,保存在sample_weight.pkl文件中),利用学到的参数(w,b)对测试数据(test_img)进行识别,然后将识别结果与实际标签(test_label)进行比较,判断推理是否正确。原理如下图:

具体的推理处理过程为:

图像数据有784个像素(28×28),即输入层有784个神经元,推理的结果是识别数字0到9,因此输出层有10个神经元。此外,为和前面知识点相连,我们在这个神经网络中添加了2个隐藏层,第一个隐藏层有50个神经元,第二个隐藏层有100个神经元,隐藏层中神经元的个数可自己设置。由于本主题不涉及参数的学习,因此我们直接使用已学习到的参数,它保存在文件sample_weight.pkl中,根据我们设计的神经网络的结构,我们应该知道参数的结构。下面是神经网络推理的代码:

parms=pd.read_pickle('F:/deep-learning with python/ch03/sample_weight.pkl')
type(parms)  #字典
parms.keys() #['b1', 'W1', 'b2', 'W2', 'b3', 'W3']
parms['b1'].shape #(50,)
parms['W1'].shape #(784,50)
parms['b2'].shape #(100,)
parms['W2'].shape #(50,100)
parms['b3'].shape #(10,)
parms['W3'].shape #(100,10)     
def predict(parms,x):"""代码同三层神经网络的实现一样,只是将随机参数改为实际学到的参数 6激活函数在前面专题已给出"""W1,W2,W3=parms['W1'],parms['W2'],parms['W3']b1,b2,b3=parms['b1'],parms['b2'],parms['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)  #激活函数可在以前的文章中找到a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)   #激活函数可在以前的文章中找到return ytest_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签
accuracy_cnt=0
for i in range(len(test_img)):y=predict(parms,test_img[i])p=np.argmax(y)#获取概率最高的元素的索引if p==test_label[i]:accuracy_cnt+=1
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))  #0.9352

下面我们对代码做简单介绍,首先提取测试数据和测试标签。接着用for循环逐一取出测试数据中的图像数据,然后用predict()函数进行分类,该函数输出各个标签对应的概率,比如输出[0.1,0.2,0.4…,0.03],表示“0”的概率为0.1,1的概率为0.2,9的概率为0.03。然后我们取出这个概率列表中的最大值的索引即为分类结果。最后比较神经网络预测的结果和正确解标签(test_label),对1万张图预测正确的概率作为识别精度(93.52%)。         

在机器学习领域中,一般需要考虑数据预处理,这里我们可将像素0至255可缩小到0至1的范围内(即对所有数据均除以255),然后再输入至神经网络中,这种将数据限制在某个范围内的处理称为正则化,一般情况下,预处理会改善机器学习模型。读者可比较一下图像数据正则化后神经网络的识别精度。


  • 批处理

 上面只介绍了输入一张图像数据时的处理流程。即每次向神经网络中输入一个由784个元素(原本是一个28·28的二维数组)构成的一维数组后,输出一个有10个元素的一维数组。其数据形状如下:

现在我们想predict()函数一次性打包处理100张图像。为此可把X的形状改为100×784,将100张图片打包作为输入数据。数据形状如下:

批处理数据形状

可见,输入数据的形状为100×784,输出数据的形状为100×100,这说明了输入的100张图像的推理结果被一次性输出了。例如x[0]、x[1]....x[99]和y[0]、y[1]....y[99]保存了第1、2....到100张图像的图像数据及其推理结果。这种被打包的输入数据被称为(batch),批处理主要集中在数据计算上,而不是数据读入,因此批处理可缩短时间开销。下面用代码实现如下:

test_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签batch_size=100 #批数量
accuracy_cnt=0 #初始识别精度for i in range(0,len(test_img),batch_size):x_batch=test_img[i:i+batch_size]y_batch=predict(parms,x_batch)p=np.argmax(y_batch,axis=1)#取每列最大值accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))

批处理代码核心在于for循环语句添加了步数batch_size,输入predict()函数的参数x由以前的单条数据变为x_batch表示的100条数据,寻找二维数组中每行的最大值所在的列位置使用了参数axis=1。

至此,神经网络的基本知识已经讲解完了,后面的内容主要讲解权重参数的学习!欢迎关注微信公众号“Python生态智联”,学知识,享生活!

这篇关于神经网络第四篇:推理处理之手写数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1089999

相关文章

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

python处理带有时区的日期和时间数据

《python处理带有时区的日期和时间数据》这篇文章主要为大家详细介绍了如何在Python中使用pytz库处理时区信息,包括获取当前UTC时间,转换为特定时区等,有需要的小伙伴可以参考一下... 目录时区基本信息python datetime使用timezonepandas处理时区数据知识延展时区基本信息

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Java Response返回值的最佳处理方案

《JavaResponse返回值的最佳处理方案》在开发Web应用程序时,我们经常需要通过HTTP请求从服务器获取响应数据,这些数据可以是JSON、XML、甚至是文件,本篇文章将详细解析Java中处理... 目录摘要概述核心问题:关键技术点:源码解析示例 1:使用HttpURLConnection获取Resp

Java中Switch Case多个条件处理方法举例

《Java中SwitchCase多个条件处理方法举例》Java中switch语句用于根据变量值执行不同代码块,适用于多个条件的处理,:本文主要介绍Java中SwitchCase多个条件处理的相... 目录前言基本语法处理多个条件示例1:合并相同代码的多个case示例2:通过字符串合并多个case进阶用法使用

Java实现优雅日期处理的方案详解

《Java实现优雅日期处理的方案详解》在我们的日常工作中,需要经常处理各种格式,各种类似的的日期或者时间,下面我们就来看看如何使用java处理这样的日期问题吧,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言一、日期的坑1.1 日期格式化陷阱1.2 时区转换二、优雅方案的进阶之路2.1 线程安全重构2

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa