罗斯基白话:TensorFlow + 实战系列(五)实战MNIST

2023-10-15 12:20

本文主要是介绍罗斯基白话:TensorFlow + 实战系列(五)实战MNIST,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

白话TensorFlow +实战系列(五)
实战MNIST

 

       这篇文章主要用全连接神经网络来实现MNIST手写数字识别的问题。首先介绍下MNIST数据集。

       1)MNIST数据集

       MNIST数据集是一个非常有名的手写数字识别数据集,它包含了60000张图片作为训练集,10000张图片为测试集,每张图为一个手写的0~9数字。如图:




其中每张图的大小均为28*28,这里大小指的的是像素。例如数字1所对应的像素矩阵为:




而我们要做的就是教会电脑识别每个手写数字。这个数据集非常经典,常作为学习神经网络的入门教材,一如每个程序员的第一个程序都是“helloword!”一样。

 

       2)数据处理

       数据集下载下来后有四个文件,分别为训练集图片,训练集答案,测试集图片,测试集答案。TensorFlow提供了一个类来处理MNIST数据,这个类会自动的将MNIST数据分为训练集,验证集与测试集,并且这些数据都是可以直接喂给神经网络作为输入用的。示例代码如下:



      

 其中input_data.read_data_sets会自动将数据集进行处理,one_hot = True用独热方式表示,意思是每个数字由one_hot方式表,例如数字0 = [1,0,0,0,0,0,0,0,0,0],1 = [0,1,0,0,0,0,0,0,0,0]。运行结果如下:




接下来就用一个全连接神经网络来识别数字。

 

       3)全连接神经网络

       首先定义超参数与参数,没啥好解释的,代码如下:


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10

 接着定义网络的结构,构建的网络只有一个隐藏层,隐藏层节点为500。代码如下:


def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn output

在输出层中,output并没有用到relu函数,因为在之后的softmax层中也是非线性激励,所以可以不用。

 

接着定义训练过程,代码如下:


def train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)


其中每一步的函数作用可以参考我的第二篇博客: 罗斯基白话:TensorFlow+实战系列(二)从零构建传统神经网络

里面有详细的解释。


完整代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn outputdef train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)def main(argv = None):mnist = input_data.read_data_sets("/tensorflow/mnst_data", one_hot=True)train(mnist)if __name__ == "__main__":tf.app.run()

 

最后执行的结果如图:




可以看到最终的准确率能达到98.19%,看来效果还是很不错的。嘿嘿。

       

这篇关于罗斯基白话:TensorFlow + 实战系列(五)实战MNIST的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/217665

相关文章

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2