MNIST数据集下的Softmax回归模型实验—我的第一个机器学习程序(包括MNIST数据集下载,Softmax介绍和源码)

本文主要是介绍MNIST数据集下的Softmax回归模型实验—我的第一个机器学习程序(包括MNIST数据集下载,Softmax介绍和源码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

MNIST数据集是一个入门级的计算机视觉数据集,它包含各种手写数字照片,它也包含每一张图片对应的标签,告诉我们这是数字几。

例如这幅图的标签分别是5、0、4、1。

数据集被分成两部分:60000 行的训练数据集mnist.train和10000行的测试数据集(mnist.test)。其中:60000 行的训练

部分拆为 55000 行的训练集和 5000 行的验证集。

接下来我将介绍一个简单的机器学习模型—CNN,来预测图片里面的数字。

首先介绍一下如何下载MNIST数据集

Tensorflow里面可以用如下代码导入MNIST数据集:

from  tensorflow.examples.tutorials.mnist   import   input_data

mnist  =  input_data.read_data_sets ( "MNIST_data/",  one_hot=True )

成功获取MNIST数据集后,发现本地已经下载了4个压缩文件:

#训练集的压缩文件, 9912422  bytes

Extracting MNIST_data / train-images-idx3-ubyte.gz

#训练集标签的压缩文件28881 bytes
Extracting MNIST_data / train-labels-idx1-ubyte.gz

#测试集的压缩文件1648877 bytes
Extracting MNIST_data / t10k-images-idx3-ubyte.gz

#测试集的压缩文件4542 bytes
Extracting MNIST_data / t10k-labels-idx1-ubyte.gz

我们可以在终端打印数据集的张量情况:

print ( mnist.train.images.shape )   #训练集的张量  

print ( mnist.train.labels.shape )   #训练集标签的张量

print ( mnist.validation.images.shape )  #验证集的张量

print ( mnist.validation.labels.shape )   #验证集标签的张量

print ( mnist.test.images.shape ) #测试集的张量

print ( mnist.test.labels.shape ) #测试集标签的张量 



我们发现:

1、MNIST数据集包含 55000 行训练集、5000 行验证集和10000 行测试集

2、每一张图片展开成一个 28 28 = 784 维的向量,展开的顺序可以随意的,只要保证每张图片的展开顺序一致即可

3、每一张图片的标签被初始化成 一个 10 维的“one-hot”向量


数据集下载完之后就要进入构建模型阶段了

Softmax回归模型

当我们在数据集上进行训练的时候,模型判定一张图片是几往往依靠一个概率比如80%的概率为5,20%的概率为2,那么最终将会判定这个数字为5。

为了得到这个概率,我们使用Softmax模型,这个模型分为两步

第一步

为了得到一张给定图片属于某个特定数字类的证据(evidence),我们对图片像素值进行加权求和。如果这个像素具有很强的证据证明这张图片不属于该类,那么相应的权值为负数,相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值是正数。我们也需要加入一个额外的偏置量(bias),因为输入往往会带有一些无关的干扰量。因此对于给定的输入图片x 它代表的是数字 i 的证据可以表示为


其中bi代表数字 i 类的偏置量,j代表给定图片 x 的像素索引用于像素求和。然后用softmax函数可以把这些证据转换成概率 y :

这里的softmax可以看成是一个激励(activation)函数或者链接(link)函数,把我们定义的线性函数的输出转换成我们想要的格式,也就是关于10个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转换成为一个概率值。

softmax函数可以定义为:

展开后可得

但是更多的时候把softmax模型函数定义为前一种形式:把输入值当成幂指数求值,再正则化这些结果值。这个幂运算表示,更大的证据对应更大的假设模型(hypothesis)里面的乘数权重值。反之,拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里的权值不可以是0值或者负值。Softmax然后会正则化这些权重值,使它们的总和等于1,以此构造一个有效的概率分布。

更进一步我们把Softmax模型写为

接下来我们来实现模型

首先我们在使用Tensorflow时需要先导入:

import tensorflow as tf

然后设置一个占位符x,作为Tensorflow计算时的输入,允许输入任意数量图像,并且将其张开为784维向量我们用2维的浮点数张量来表示这些图,这个张量的形状是 [None,784 ] 。(这里的 None 表示此张量的第一个维度可以是任何长度的。)

x = tf.placeholder("float", shape=[None, 784])

我们的模型也需要权重值和偏置量,当然我们可以把它们当做是另外的输入(使用占位符),但TensorFlow有一个更好的方法来表示它们: Variable 。它们可以用于计算输入值,也可以在计算中被修改。对于各种机器学习应用,一般都会有模型参数,可以用 Variable 表示。

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

我们赋予 tf.Variable 不同的初值来创建不同的 Variable :在这里,我们都用全为零的张量来初始化 W 和b 。因为我们要学习 W 和b 的值,它们的初值可以随意设置。

注意, W 的维度是[784,10],因为我们想要用784维的图片向量乘以它以得到一个10维的证据值向量,每一位对应不同数字类。 b 的形状是[10],所以我们可以直接把它加到输出上面。

现在,我们可以实现我们的模型

y = tf.nn.softmax(tf.matmul(x,W) + b)

模型至此已经写完了,接下来就要训练我们的模型

在机器学习中衡量一个模型好坏的因素,往往要通过损失(loss)或者是成本(cost),然后尽量最小化这两种指标来获取更好的模型,但从本质上来说两种方法取得的是一样的结果。

一个非常常见的成本函数是“交叉熵”(cross-entropy),今天我们也采用这个函数来训练我们的模型,它的定义如下:

y 是我们预测的概率分布, y' 是实际的分布(我们输入的one-hot vector)。比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。

为了计算交叉熵,我们首先需要添加一个新的占位符用于输入正确值:

y_ = tf.placeholder("float", shape=[None, 10])

然后我们可以用

来计算交叉熵

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

TensorFlow拥有一张描述你各个计算单元的图,它可以自动地使用反向传播算法(backpropagation algorithm)来有效地确定你的变量是如何影响你想要最小化的那个成本值的。然后,TensorFlow会用你选择的优化算法来不断地修改变量以降低成本。

在这里,我们要求TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动在这里,我们要求TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

现在,我们已经设置好了我们的模型。

Tensorflow依赖于一个高效的C++后端来进行计算。与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。

这里,我们使用更加方便的 InteractiveSession 类。通过它,你可以更加灵活地构建你的代码。它能让你在运行图的时候,插入一些计算图,这些计算图是由某些操作(operations)构成的。这对于工作在交互式环境中的人们来说非常便利,比如使用IPython。如果你没有使用 InteractiveSession ,那么你需要在启动session之前构建整个计算图,然后启动该计算图。

sess = tf.InteractiveSession()

变量 需要通过seesion初始化后,才能在session中使用。这一初始化步骤为,为初始值指定具体值(本例当中是全为零),并将其分配给每个 变量 ,可以一次性为所有 变量 完成此操作。

sess.run(tf.initialize_all_variables())
然后开始训练模型,这里我们让模型循环训练1000次!
for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,然后我们用这些数据点作为参数替换

之前的占位符来运行 train_step 。

使用一小部分的随机数据来进行训练被称为随机训练(stochastic training)- 在这里更确切的说是随机梯度下降训练。在理想情况

下,我们希望用我们所有的数据来进行每一步的训练,因为这能给我们更好的训练结果,但显然这需要很大的计算开销。所以,每一次

训练我们可以使用不同的数据子集,这样做既可以减少计算开销,又可以最大化地学习到数据集的总体特性。

最后就是评估我们的模型

tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如 tf.argmax(y,1) 返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

最后,我们计算所学习到的模型在测试数据集上面的正确率

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
到这里Softmax模型就全部完成了,正确率在91%左右。


这篇关于MNIST数据集下的Softmax回归模型实验—我的第一个机器学习程序(包括MNIST数据集下载,Softmax介绍和源码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/761140

相关文章

C#监听txt文档获取新数据方式

《C#监听txt文档获取新数据方式》文章介绍通过监听txt文件获取最新数据,并实现开机自启动、禁用窗口关闭按钮、阻止Ctrl+C中断及防止程序退出等功能,代码整合于主函数中,供参考学习... 目录前言一、监听txt文档增加数据二、其他功能1. 设置开机自启动2. 禁止控制台窗口关闭按钮3. 阻止Ctrl +

java如何实现高并发场景下三级缓存的数据一致性

《java如何实现高并发场景下三级缓存的数据一致性》这篇文章主要为大家详细介绍了java如何实现高并发场景下三级缓存的数据一致性,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 下面代码是一个使用Java和Redisson实现的三级缓存服务,主要功能包括:1.缓存结构:本地缓存:使

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

SQL server数据库如何下载和安装

《SQLserver数据库如何下载和安装》本文指导如何下载安装SQLServer2022评估版及SSMS工具,涵盖安装配置、连接字符串设置、C#连接数据库方法和安全注意事项,如混合验证、参数化查... 目录第一步:打开官网下载对应文件第二步:程序安装配置第三部:安装工具SQL Server Manageme

golang程序打包成脚本部署到Linux系统方式

《golang程序打包成脚本部署到Linux系统方式》Golang程序通过本地编译(设置GOOS为linux生成无后缀二进制文件),上传至Linux服务器后赋权执行,使用nohup命令实现后台运行,完... 目录本地编译golang程序上传Golang二进制文件到linux服务器总结本地编译Golang程序

zookeeper端口说明及介绍

《zookeeper端口说明及介绍》:本文主要介绍zookeeper端口说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、zookeeper有三个端口(可以修改)aVNMqvZ二、3个端口的作用三、部署时注意总China编程结一、zookeeper有三个端口(可以

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker