深度学习基础 - MNIST实验(tensorflow+Softmax)

2024-01-15 11:18

本文主要是介绍深度学习基础 - MNIST实验(tensorflow+Softmax),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基于tensorflow开发框架,搭建softmax模型完成mnist分类任务。

本文的完整代码托管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,欢迎访问。

1.任务背景

1.1.目的

以MNIST手写数字识别为课题,研究基本深度学习方法的应用。本文先从Softmax模型切入,以熟悉tensorflow下mnist任务的开发流程。之后的文章将陆续引入MLP、CNN等模型,以达到更优异的识别效果。

1.2.数据集

MNIST任务是图像识别领域经典的“Helloworld”。在其所提供的数据集中,包含了6w个训练样本和1w个测试样本,均为黑白图片,大小28×28,以灰度矩阵的形式存放,数值取浮点数“0~1”对应“白~黑”。给出一些图片(X)及对应标注(Y)如下图所示:

display_some

2.实验过程

实验参考代码:python + tensorflow

2.1.数据预研

MNIST数据的一些基本信息如下:

输入:image - 784 的向量 --> image size [28*28]
输出:label - int(0-10)
#train:55k
#valid:5k
#test:10k

基于tensorflow对mnist数据进行加载与测试的示例代码如下:

mnist = input_data.read_data_sets('../data/mnist_data/',one_hot=True)
X_train_org, Y_train_org = mnist.train.images, mnist.train.labels
X_valid_org, Y_valid_org = mnist.validation.images, mnist.validation.labels
X_test_org,  Y_test_org  = mnist.test.images, mnist.test.labels# check the shape of dataset
print("train set shape: X-", X_train_org.shape, ", Y-", Y_train_org.shape)
print("valid set shape: X-", X_valid_org.shape, ", Y-", Y_valid_org.shape)
print("test set shape: X-", X_test_org.shape, ", Y-", Y_test_org.shape)

2.2.Softmax建模

Softmax回归可看作是Logistic回归模型向多分类任务的拓展,其模型可描述如下图:

softmax_graph

其公式表达如下:

softmax_formula_1

写成向量化形式:

softmax_formula_2

权值 W 和偏置 b 是这里需要学习的参数。

采用tensorflow可以轻松构建出Softmax模型,示例代码如下:

#========== Softmax Modeling ==========#
x = tf.placeholder("float", [None, 784])  # placeholder of inputW = tf.Variable(tf.zeros([784,10]))  # parameters (initial to 0)
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,W) + b)  # softmax computation graph

2.3.训练与测试

通过构建tensorflow对话(session),给定输入x,运行由x-->y的计算图(Computation Graph),后台可简单完成训练过程。这里采用的是简单的mini-batch Gradient Descent优化策略。

模型的训练样例代码如下:

y_ = tf.placeholder("float", [None, 10])  # placeholder of labelcross_entropy = -tf.reduce_sum(y_*tf.log(y))  # loss (cross-entropy)train_step = tf.train.GradientDescentOptimizer(learning_rate = 0.01).minimize(cross_entropy)  # using GD#========== Training ==========#
init = tf.global_variables_initializer()sess = tf.InteractiveSession()  # initial a session
sess.run(init)for i in range(1000):  # iterate  for 100 timesbatch_xs, batch_ys = mnist.train.next_batch(100)  # using mini-batchsess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

在验证集与测试集上评估所学模型的效果,以预测准确率(accuracy)为指标,得出结果如下:

valid accuracy 0.927
test accuracy 0.9208

可以看出,Softmax模型在经过一定时间的训练之后,达到了九成的分类准确率。与MNIST官网给出的线性分类器(单层NN)的准确级别相近。

3.实验小结

这里采用tensorflow开发框架搭建了Softmax多分类模型,实现了超过90%的测试准确率。模型的搭建以及训练测试过程十分简便。据tensorflow官网所述,使用多层神经网络等更复杂的模型还可进一步提升分类效果,接下来的文章,将对此进行跟进。

4.参考资料

  • 基础教程:tensorflow官网 - MNIST机器学习入门
  • 辅助资料:TensorFlow中的Nan值的陷阱

这篇关于深度学习基础 - MNIST实验(tensorflow+Softmax)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/608710

相关文章

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

SpringBoot开发中十大常见陷阱深度解析与避坑指南

《SpringBoot开发中十大常见陷阱深度解析与避坑指南》在SpringBoot的开发过程中,即使是经验丰富的开发者也难免会遇到各种棘手的问题,本文将针对SpringBoot开发中十大常见的“坑... 目录引言一、配置总出错?是不是同时用了.properties和.yml?二、换个位置配置就失效?搞清楚加