基于muist数据集的maxout网络实现分类 ----代码分享

2023-10-11 07:38

本文主要是介绍基于muist数据集的maxout网络实现分类 ----代码分享,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

运行环境:windows,tensorflow - gpu-1.13.1

#---------------------------------理解mnist数据集
#导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data #从网上下载mnist数据集的模块
mnist = input_data.read_data_sets('MNIST_data/',one_hot = False) #从指定文件夹导入数据集的数据
##分析mnist数据集
#print('输入训练数据集数据:',mnist.train.images) #打引导如数据集的数据
#print('输入训练数据集shape:',mnist.train.images.shape) #打印训练数据集的形状
#print('输入测试数据集shape:',mnist.test.images.shape) #用于评估训练过程中的准确度
#print('输入验证数据集shape:',mnist.validation.images.shape) #用于评估最终模型的准确度
#print('输入标签的shape:',mnist.train.labels.shape)
#展示mnist数据集
#import pylab 
#im = mnist.test.images[6] #train中的第六张图
#im = im.reshape(-1,28)
#pylab.imshow(im)
#pylab.show()#-----------------------------------------------#-------------------------------正向传播结构
import tensorflow as tf
tf.reset_default_graph()
#分析图片特点定义变量
#define placeholder
x = tf.placeholder(tf.float32,[None, 784]) #mnist data have 784 value
#y = tf.placeholder(tf.float32,[None,10]) #labels have 10 value
y = tf.placeholder(tf.int32,[None]) 
#定义学习参数
W = tf.Variable(tf.random_normal([784,10])) #Normally,we set weight as random
b = tf.Variable(tf.zeros([10]))#Normally,we set base as zero
#print(b)
#with tf.Session() as sess:
#    print(sess.run(b))
#定义输出节点
#pred = tf.nn.softmax(tf.matmul(x,W) + b) #sotfmax分类
z = tf.matmul(x,W) + b
maxout = tf.reduce_max(z,axis=1,keep_dims=True)
#设置学习参数
W2 = tf.Variable(tf.truncated_normal([1,10],stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))pred = tf.nn.softmax(tf.matmul(maxout,W2)+b2) # Softmax分类
#-------------------------------------------#-------------------------------------定义反向结构及传播参数
#损失函数
#cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #生成的pred与样本标签y进行交叉熵运算,然后取平均值
#cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels = y,logits = z))
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=z))
#定义参数
learning_rate = 0.3
#使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 
#----------------------------------------------------------#--------------------------------训练模型并输出中间状态参数
training_epochs = 200
batch_size = 500
display_step = 1saver = tf.train.Saver()
model_path = 'log/mnist_model.ckpt'#启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer()) #初始化OP#启动循环开始训练for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)#循环所有数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)#运行优化器_,c = sess.run([optimizer, cost], feed_dict = {x:batch_xs,y:batch_ys})#计算平均loss值avg_cost += c / total_batch#显示训练中的详细信息if (epoch+1) % display_step == 0:print('Epoch:','%04d' % (epoch+1),'cost','{:.9f}'.format(avg_cost))print('Finish!')

这篇关于基于muist数据集的maxout网络实现分类 ----代码分享的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/186539

相关文章

C#监听txt文档获取新数据方式

《C#监听txt文档获取新数据方式》文章介绍通过监听txt文件获取最新数据,并实现开机自启动、禁用窗口关闭按钮、阻止Ctrl+C中断及防止程序退出等功能,代码整合于主函数中,供参考学习... 目录前言一、监听txt文档增加数据二、其他功能1. 设置开机自启动2. 禁止控制台窗口关闭按钮3. 阻止Ctrl +

java如何实现高并发场景下三级缓存的数据一致性

《java如何实现高并发场景下三级缓存的数据一致性》这篇文章主要为大家详细介绍了java如何实现高并发场景下三级缓存的数据一致性,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 下面代码是一个使用Java和Redisson实现的三级缓存服务,主要功能包括:1.缓存结构:本地缓存:使

如何在Java Spring实现异步执行(详细篇)

《如何在JavaSpring实现异步执行(详细篇)》Spring框架通过@Async、Executor等实现异步执行,提升系统性能与响应速度,支持自定义线程池管理并发,本文给大家介绍如何在Sprin... 目录前言1. 使用 @Async 实现异步执行1.1 启用异步执行支持1.2 创建异步方法1.3 调用

Spring Boot配置和使用两个数据源的实现步骤

《SpringBoot配置和使用两个数据源的实现步骤》本文详解SpringBoot配置双数据源方法,包含配置文件设置、Bean创建、事务管理器配置及@Qualifier注解使用,强调主数据源标记、代... 目录Spring Boot配置和使用两个数据源技术背景实现步骤1. 配置数据源信息2. 创建数据源Be

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

linux批量替换文件内容的实现方式

《linux批量替换文件内容的实现方式》本文总结了Linux中批量替换文件内容的几种方法,包括使用sed替换文件夹内所有文件、单个文件内容及逐行字符串,强调使用反引号和绝对路径,并分享个人经验供参考... 目录一、linux批量替换文件内容 二、替换文件内所有匹配的字符串 三、替换每一行中全部str1为st

SpringBoot集成MyBatis实现SQL拦截器的实战指南

《SpringBoot集成MyBatis实现SQL拦截器的实战指南》这篇文章主要为大家详细介绍了SpringBoot集成MyBatis实现SQL拦截器的相关知识,文中的示例代码讲解详细,有需要的小伙伴... 目录一、为什么需要SQL拦截器?二、MyBATis拦截器基础2.1 核心接口:Interceptor

SpringBoot集成EasyPoi实现Excel模板导出成PDF文件

《SpringBoot集成EasyPoi实现Excel模板导出成PDF文件》在日常工作中,我们经常需要将数据导出成Excel表格或PDF文件,本文将介绍如何在SpringBoot项目中集成EasyPo... 目录前言摘要简介源代码解析应用场景案例优缺点分析类代码方法介绍测试用例小结前言在日常工作中,我们经

基于Python实现简易视频剪辑工具

《基于Python实现简易视频剪辑工具》这篇文章主要为大家详细介绍了如何用Python打造一个功能完备的简易视频剪辑工具,包括视频文件导入与格式转换,基础剪辑操作,音频处理等功能,感兴趣的小伙伴可以了... 目录一、技术选型与环境搭建二、核心功能模块实现1. 视频基础操作2. 音频处理3. 特效与转场三、高