tensorflow之MNIST手写字符集训练可视化

2024-06-08 22:38

本文主要是介绍tensorflow之MNIST手写字符集训练可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

很多人认为卷积神经是一个黑箱子,把图片输入,输出结果为有监督式的学习(supervised learning),贴标签的形式,即可达到分类的效果。那么计算机到底做了什么事情呢?训练过程结果如何可视化?下面进行简单的介绍。

模型的搭建

@author XT
#第1层convolutional
W1 = tf.Variable(tf.truncated_normal([5,5,1,K],stddev=0.1),dtype=tf.float32,name='W1') #[filterheight,filterwith,input_channel,output_channel]
b1 = tf.Variable(tf.ones([K])/10,dtype=tf.float32,name='b1')
#第2层convolutional
W2 = tf.Variable(tf.truncated_normal([4,4,K,L],stddev=0.1),dtype=tf.float32,name='W2')
b2 = tf.Variable(tf.ones([L])/10,dtype=tf.float32,name='b2')
#第3层convolutional
W3 = tf.Variable(tf.truncated_normal([4,4,L,M],stddev=0.1),dtype=tf.float32,name='W3')
b3 = tf.Variable(tf.ones([M])/10,dtype=tf.float32,name='b3')
#convolutional out fully connected layer
W4 = tf.Variable(tf.truncated_normal([7*7*M,N],stddev=0.1),dtype=tf.float32,name='W4')
b4 = tf.Variable(tf.ones([N])/10,dtype=tf.float32,name='b4')
#output
W5 = tf.Variable(tf.truncated_normal([N,n_class],stddev=0.1),dtype=tf.float32,name='W5')#要随最后层修改
b5 = tf.Variable(tf.ones([n_class]),dtype=tf.float32,name='b5')

这里搭建了较简单的卷积神经网络,使用了3层卷积的权值,后加全连接层,最后是输出。结构为:
这里写图片描述

代码

@author XT
#Model
pkeep = tf.placeholder(tf.float32)
x = tf.placeholder(tf.float32, [None,784])#!!!注意图片格式大小    
x_image = tf.reshape(x,[-1,28,28,1])stride=1 # output is 28x28
Y1 = tf.nn.relu(tf.nn.conv2d(x_image,W1,strides=[1,stride,stride,1],padding='SAME')+b1)#sigmoid很差,要用relu
YF1 = tf.nn.dropout(Y1,pkeep)
h_poolYF1 = tf.nn.max_pool(YF1,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1stride=2 # output is 14x14
Y2 = tf.nn.relu(tf.nn.conv2d(h_poolYF1,W2,strides=[1,stride,stride,1],padding='SAME')+b2)#sigmoid
YF2 = tf.nn.dropout(Y2,pkeep)
h_poolYF2 = tf.nn.max_pool(YF2,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1stride = 2  # output is 7x7
Y3 = tf.nn.relu(tf.nn.conv2d(h_poolYF2,W3,strides=[1,stride,stride,1],padding='SAME')+b3)#sigmoid
YF3 = tf.nn.dropout(Y3,pkeep)
h_poolYF3 = tf.nn.max_pool(YF3,ksize=[1,2,2,1],strides=[1,1,1,1],padding='SAME')#池化步长为1# reshape the output from the third convolution for the fully connected layer
YY = tf.reshape(h_poolYF3, shape=[-1, 7*7*M])Y4 = tf.nn.relu(tf.matmul(YY,W4)+b4)#sigmoid
YF4 = tf.nn.dropout(Y4,pkeep)Ylogits = tf.matmul(YF4, W5)+b5
Y = tf.nn.softmax(Ylogits)#softmax

训练结果

1、总测试

这里写图片描述

2、Test One
这里写图片描述
输出概率:
这里写图片描述
第一卷积层:
这里写图片描述

这就是图片特征被激活的结果

部分代码

def plot_images(x, labels,max_index,name):'''plot one batch sizeimages:images_batchsize,4D tensor - [batch_size, width, height, channel]label_batch: 1D tensor - [batch_size]'''i = 0for one_pic_vic in x:one_pic_arr = np.reshape(one_pic_vic,(28,28))plt.subplot(1,1,i+1)plt.axis('off')plt.title('Label: %d   Forecast: %d'%(labels[i],max_index[i]), fontsize = 14)#采用A=0标签 +'  Forecast: '+max_index[i]plt.subplots_adjust(top=0.9)plt.imshow(one_pic_arr,cmap='gray')i+=1figure_title = nameax3  = plt.subplot(1,1,1)plt.text(0.5, -0.05, figure_title,horizontalalignment='center',fontsize=20,transform = ax3.transAxes)pylab.show()def show_rich_feature(x_relu,Node):print(x_relu.shape[1],"X",x_relu.shape[2])feature_map = tf.reshape(x_relu, [x_relu.shape[1],x_relu.shape[2],Node])images = tf.image.convert_image_dtype (feature_map, dtype=tf.uint8)images = sess.run(images)plt.figure(figsize=(10, 10))#if Node > 25,plot(5,5)for i in np.arange(0, Node):plt.subplot(2, 2, i + 1)#you need to change the subplot size if you use other layerplt.axis('off')plt.imshow(images[:,:,i])plt.show()

参考

【1】Tensorflow教程-VGG论文导读+Tensorflow实现+参数微调(fine-tuning)
http://v.youku.com/v_show/id_XMjcyNzYwMjkxMg==.html?spm=a2hzp.8244740.0.0
【2】谷歌云大会教程:没有博士学位如何玩转TensorFlow和深度学习(附资源)
http://www.sohu.com/a/128686069_465975

这篇关于tensorflow之MNIST手写字符集训练可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1043509

相关文章

Python数据分析与可视化的全面指南(从数据清洗到图表呈现)

《Python数据分析与可视化的全面指南(从数据清洗到图表呈现)》Python是数据分析与可视化领域中最受欢迎的编程语言之一,凭借其丰富的库和工具,Python能够帮助我们快速处理、分析数据并生成高质... 目录一、数据采集与初步探索二、数据清洗的七种武器1. 缺失值处理策略2. 异常值检测与修正3. 数据

使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)

《使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)》字体设计和矢量图形处理是编程中一个有趣且实用的领域,通过Python的matplotlib库,我们可以轻松将字体轮廓... 目录背景知识字体轮廓的表示实现步骤1. 安装依赖库2. 准备数据3. 解析路径指令4. 绘制图形关键

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

使用Vue-ECharts实现数据可视化图表功能

《使用Vue-ECharts实现数据可视化图表功能》在前端开发中,经常会遇到需要展示数据可视化的需求,比如柱状图、折线图、饼图等,这类需求不仅要求我们准确地将数据呈现出来,还需要兼顾美观与交互体验,所... 目录前言为什么选择 vue-ECharts?1. 基于 ECharts,功能强大2. 更符合 Vue

Git可视化管理工具(SourceTree)使用操作大全经典

《Git可视化管理工具(SourceTree)使用操作大全经典》本文详细介绍了SourceTree作为Git可视化管理工具的常用操作,包括连接远程仓库、添加SSH密钥、克隆仓库、设置默认项目目录、代码... 目录前言:连接Gitee or github,获取代码:在SourceTree中添加SSH密钥:Cl

Pandas中统计汇总可视化函数plot()的使用

《Pandas中统计汇总可视化函数plot()的使用》Pandas提供了许多强大的数据处理和分析功能,其中plot()函数就是其可视化功能的一个重要组成部分,本文主要介绍了Pandas中统计汇总可视化... 目录一、plot()函数简介二、plot()函数的基本用法三、plot()函数的参数详解四、使用pl

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

Python 交互式可视化的利器Bokeh的使用

《Python交互式可视化的利器Bokeh的使用》Bokeh是一个专注于Web端交互式数据可视化的Python库,本文主要介绍了Python交互式可视化的利器Bokeh的使用,具有一定的参考价值,感... 目录1. Bokeh 简介1.1 为什么选择 Bokeh1.2 安装与环境配置2. Bokeh 基础2

基于Python打造一个可视化FTP服务器

《基于Python打造一个可视化FTP服务器》在日常办公和团队协作中,文件共享是一个不可或缺的需求,所以本文将使用Python+Tkinter+pyftpdlib开发一款可视化FTP服务器,有需要的小... 目录1. 概述2. 功能介绍3. 如何使用4. 代码解析5. 运行效果6.相关源码7. 总结与展望1

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1