Tensorflow实现的MNIST数据集的2层卷积2层全连接网络

2024-04-27 02:48

本文主要是介绍Tensorflow实现的MNIST数据集的2层卷积2层全连接网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

import tensorflow as tf
"""
h=w 图片尺寸
f=卷积核
p=padding 边界填补 ‘SAME’补充
s=strides 每一次走的步长
(h-f+2*p)/s + 1
"""
# 10 分类,输入图片尺寸 784*784
n_input=784
n_output=10
# 获取数据MNIST
mnist=('data/',one_hot = True)weights={# [3,3,1,64] 3*3 = h*w  卷积核, 1 channel, 64个特征图'wc1':tf.Variable(tf.random_normal([3,3,1,64],stddev=0.1)),'wc2':tf.Variable(tf.random_normal([3, 3, 64, 128], stddev = 0.1)),'wd1':tf.Variable(tf.random_normal([7*7*128,1024],stddev=0.1)),'wd2':tf.Variable(tf.random_normal([1024,n_output],stddev=0.1))
}
biases={'bc1':tf.Variable(tf.random_normal([64],stddev=0.1)),'bc2':tf.Variable(tf.random_normal([128],stddev=0.1)),'bd1':tf.Variable(tf.random_normal([1024],stddev=0.1)),'bd2':tf.Variable(tf.random_normal([n_output],stddev=0.1))
}def conv_basic(input, w, b, keepratio):input_r=tf.reshape(input,shape=[-1, 28,28,1])conv1 = tf.nn.conv2d(input_r,w['wc1'],strides=[1,1,1,1],padding='SAME')conv1 = tf.nn.relu(tf.nn.bias_add(conv1,b['bc1']))pool1 = tf.nn.max_pool(conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')pool_dr1 = tf.nn.dropout(pool1,keepratio)conv2 = tf.nn.conv2d(pool_dr1,w['wc2'],strides=[1,1,1,1],padding='SAME')conv2 = tf.nn.relu(tf.nn.bias_add(conv2,b['bc2']))pool2 = tf.nn.max_pool(conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')pool_dr2 = tf.nn.dropout(pool2,keepratio)# 全连接层dense1 = tf.reshape(pool_dr2,[-1,w['wd1'].get_shape().as_list()[0]])fc1 = tf.nn.relu(tf.add(tf.matmul(dense1,w['wd1']),b['bd1']))fc_dr1 = tf.nn.dropout(fc1,keepratio)_out = tf.add(tf.matmul(fc_dr1,w['wd2']),b['bd2'])out ={'input_r':input_r,'conv1':conv1,'pool1':pool1, 'pool_dr1': pool_dr1,'conv2': conv2,'pool2': pool2, 'pool_dr2': pool_dr2,'dense1':dense1,'fc1':fc1,  'fc_dr1':fc_dr1,'out': _out}return outx = tf.placeholder(tf.float32,[None,n_input])
y = tf.placeholder(tf.float32,[None,n_output])
keepratio = tf.placeholder(tf.float32)_pred = conv_basic(x, weights, biases, keepratio)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred,y))
optm = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)_corr = tf.equal(tf.argmax(_pred,1),tf.argmax(y,1))
accr = tf.redece_mean(tf.cast(_corr,tf.float32))init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)training_epochs = 15
batch_size = 16
display_step=1
for epoch in range(training_epochs):avg_cost=0.total_batch =10for i in range(total_batch):# 以 batch_size 大小来依次的获取数据batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(optm,feed_dict={x:batch_xs,y:batch_ys,keepratio:0.7})avg_cost += sess.run(loss,feed_dict={x:batch_xs,y:batch_ys,keepratio:1.})/total_batchif epoch % display_step==0:print('Epoch: %03d/%03d loss: %9f'%(epoch,training_epochs, avg_cost))train_acc = sess.run(accr, feed_dict={x:batch_xs,y:batch_ys,keepratio:0.7})

 

这篇关于Tensorflow实现的MNIST数据集的2层卷积2层全连接网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/939332

相关文章

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式

《Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式》本文详细介绍如何使用Java通过JDBC连接MySQL数据库,包括下载驱动、配置Eclipse环境、检测数据库连接等关键步骤,... 目录一、下载驱动包二、放jar包三、检测数据库连接JavaJava 如何使用 JDBC 连接 mys

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过