基于muist数据集的maxout网络实现分类 ----代码分享

2023-10-11 07:38

本文主要是介绍基于muist数据集的maxout网络实现分类 ----代码分享,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

运行环境:windows,tensorflow - gpu-1.13.1

#---------------------------------理解mnist数据集
#导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data #从网上下载mnist数据集的模块
mnist = input_data.read_data_sets('MNIST_data/',one_hot = False) #从指定文件夹导入数据集的数据
##分析mnist数据集
#print('输入训练数据集数据:',mnist.train.images) #打引导如数据集的数据
#print('输入训练数据集shape:',mnist.train.images.shape) #打印训练数据集的形状
#print('输入测试数据集shape:',mnist.test.images.shape) #用于评估训练过程中的准确度
#print('输入验证数据集shape:',mnist.validation.images.shape) #用于评估最终模型的准确度
#print('输入标签的shape:',mnist.train.labels.shape)
#展示mnist数据集
#import pylab 
#im = mnist.test.images[6] #train中的第六张图
#im = im.reshape(-1,28)
#pylab.imshow(im)
#pylab.show()#-----------------------------------------------#-------------------------------正向传播结构
import tensorflow as tf
tf.reset_default_graph()
#分析图片特点定义变量
#define placeholder
x = tf.placeholder(tf.float32,[None, 784]) #mnist data have 784 value
#y = tf.placeholder(tf.float32,[None,10]) #labels have 10 value
y = tf.placeholder(tf.int32,[None]) 
#定义学习参数
W = tf.Variable(tf.random_normal([784,10])) #Normally,we set weight as random
b = tf.Variable(tf.zeros([10]))#Normally,we set base as zero
#print(b)
#with tf.Session() as sess:
#    print(sess.run(b))
#定义输出节点
#pred = tf.nn.softmax(tf.matmul(x,W) + b) #sotfmax分类
z = tf.matmul(x,W) + b
maxout = tf.reduce_max(z,axis=1,keep_dims=True)
#设置学习参数
W2 = tf.Variable(tf.truncated_normal([1,10],stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))pred = tf.nn.softmax(tf.matmul(maxout,W2)+b2) # Softmax分类
#-------------------------------------------#-------------------------------------定义反向结构及传播参数
#损失函数
#cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #生成的pred与样本标签y进行交叉熵运算,然后取平均值
#cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels = y,logits = z))
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=z))
#定义参数
learning_rate = 0.3
#使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 
#----------------------------------------------------------#--------------------------------训练模型并输出中间状态参数
training_epochs = 200
batch_size = 500
display_step = 1saver = tf.train.Saver()
model_path = 'log/mnist_model.ckpt'#启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer()) #初始化OP#启动循环开始训练for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)#循环所有数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)#运行优化器_,c = sess.run([optimizer, cost], feed_dict = {x:batch_xs,y:batch_ys})#计算平均loss值avg_cost += c / total_batch#显示训练中的详细信息if (epoch+1) % display_step == 0:print('Epoch:','%04d' % (epoch+1),'cost','{:.9f}'.format(avg_cost))print('Finish!')

这篇关于基于muist数据集的maxout网络实现分类 ----代码分享的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/186539

相关文章

使用Python和OpenCV库实现实时颜色识别系统

《使用Python和OpenCV库实现实时颜色识别系统》:本文主要介绍使用Python和OpenCV库实现的实时颜色识别系统,这个系统能够通过摄像头捕捉视频流,并在视频中指定区域内识别主要颜色(红... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间详解

PostgreSQL中MVCC 机制的实现

《PostgreSQL中MVCC机制的实现》本文主要介绍了PostgreSQL中MVCC机制的实现,通过多版本数据存储、快照隔离和事务ID管理实现高并发读写,具有一定的参考价值,感兴趣的可以了解一下... 目录一 MVCC 基本原理python1.1 MVCC 核心概念1.2 与传统锁机制对比二 Postg

SpringBoot整合Flowable实现工作流的详细流程

《SpringBoot整合Flowable实现工作流的详细流程》Flowable是一个使用Java编写的轻量级业务流程引擎,Flowable流程引擎可用于部署BPMN2.0流程定义,创建这些流程定义的... 目录1、流程引擎介绍2、创建项目3、画流程图4、开发接口4.1 Java 类梳理4.2 查看流程图4

SQL Server修改数据库名及物理数据文件名操作步骤

《SQLServer修改数据库名及物理数据文件名操作步骤》在SQLServer中重命名数据库是一个常见的操作,但需要确保用户具有足够的权限来执行此操作,:本文主要介绍SQLServer修改数据... 目录一、背景介绍二、操作步骤2.1 设置为单用户模式(断开连接)2.2 修改数据库名称2.3 查找逻辑文件名

C++中零拷贝的多种实现方式

《C++中零拷贝的多种实现方式》本文主要介绍了C++中零拷贝的实现示例,旨在在减少数据在内存中的不必要复制,从而提高程序性能、降低内存使用并减少CPU消耗,零拷贝技术通过多种方式实现,下面就来了解一下... 目录一、C++中零拷贝技术的核心概念二、std::string_view 简介三、std::stri

C++高效内存池实现减少动态分配开销的解决方案

《C++高效内存池实现减少动态分配开销的解决方案》C++动态内存分配存在系统调用开销、碎片化和锁竞争等性能问题,内存池通过预分配、分块管理和缓存复用解决这些问题,下面就来了解一下... 目录一、C++内存分配的性能挑战二、内存池技术的核心原理三、主流内存池实现:TCMalloc与Jemalloc1. TCM

OpenCV实现实时颜色检测的示例

《OpenCV实现实时颜色检测的示例》本文主要介绍了OpenCV实现实时颜色检测的示例,通过HSV色彩空间转换和色调范围判断实现红黄绿蓝颜色检测,包含视频捕捉、区域标记、颜色分析等功能,具有一定的参考... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间

Python虚拟环境与Conda使用指南分享

《Python虚拟环境与Conda使用指南分享》:本文主要介绍Python虚拟环境与Conda使用指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、python 虚拟环境概述1.1 什么是虚拟环境1.2 为什么需要虚拟环境二、Python 内置的虚拟环境工具

Python实例题之pygame开发打飞机游戏实例代码

《Python实例题之pygame开发打飞机游戏实例代码》对于python的学习者,能够写出一个飞机大战的程序代码,是不是感觉到非常的开心,:本文主要介绍Python实例题之pygame开发打飞机... 目录题目pygame-aircraft-game使用 Pygame 开发的打飞机游戏脚本代码解释初始化部

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取