初涉LeNet5处理mnist (CNN卷积神经网络)

2024-06-15 22:38

本文主要是介绍初涉LeNet5处理mnist (CNN卷积神经网络),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
#输入节点个数
INPUT_NODE = 784
#输出节点个数
OUTPUT_NODE = 10
#图片的尺寸
IMAGE_SIZE = 28
#通道数
NUM_CHANNELS = 1
#输出节点的个数 
NUM_LABELS = 10
#第一层卷积层的过滤器深度及其尺寸
CONV1_DEEP = 32
CONV1_SIZE = 5
#第二层卷积层的过滤器深度及其尺寸
CONV2_DEEP = 64
CONV2_SIZE = 5
#全连接层的节点个数
FC_SIZE = 512def inference(input_tensor, train, regularizer):#第一层卷积层输入大小是28*28*1=784=INPUT_NODE  #卷积层参数个数计算: CONV1_SIZE*CONV1_SIZE*NUM_CHANNELS*CONV1_DEEP+bias =5*5*1*32+32  过滤器的长*宽*过滤器深度*当前层深度+biases(个数为过滤器深度)#过滤器尺寸5*5深度为32     从strides=[1, 1, 1, 1]可得  步长的长宽方向分别为1  第二维度跟第三维度表示分别为长宽方向步长#输出的深度为CONV1_DEEP=32  由于SAME是全0填充,因此输出的尺寸为当前输入矩阵的长宽分别除以对应的步长 28*28   bias与输出深度个数一致with tf.variable_scope('layer1-conv1'):#weight前两个维度过滤器的尺寸  第三个维度当前层的深度 第四个是过滤器的维度conv1_weights = tf.get_variable("weight", [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],initializer=tf.truncated_normal_initializer(stddev=0.1))conv1_biases = tf.get_variable("bias", [CONV1_DEEP], initializer=tf.constant_initializer(0.0))conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))#给convert卷积的这个结果加上偏置  然后利用激活函数ReLu#第二层 池化层   输入矩阵为第一层的输出  28*28*32    池化层输出与当前输入的深度一致32    池化层采用了2*2的过滤器尺寸  并且SAME方法全0填充#步长的长宽方向分别为2 所以输出尺寸为28/2=14  输出14*14*32的矩阵   池化层可以改变输入的尺寸但是不改变深度with tf.name_scope("layer2-pool1"):#其中relu1是激活函数  ksize是过滤器尺寸  strides是步长 SAME是全0填充  VALID是不适用全0   SAME方法得到的尺寸是输入的尺寸/步长#  VALID方法输出的尺寸是 ( 输入尺寸-过滤器尺寸+1)/2取得上限值pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="SAME")#第三层 卷积层  输入矩阵为14*14*32     本层步长为1  所以输出尺寸为14/1=14    输出的矩阵为14*14*64    with tf.variable_scope("layer3-conv2"):#weight前两个维度过滤器的尺寸  第三个维度当前层的深度 第四个是过滤器的维度 :尺寸为5*5 深度为64的过滤器,  当前层深度为32  输出深度为64conv2_weights = tf.get_variable("weight", [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],initializer=tf.truncated_normal_initializer(stddev=0.1))conv2_biases = tf.get_variable("bias", [CONV2_DEEP], initializer=tf.constant_initializer(0.0))conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))#第四层  池化层输入矩阵为上一层输出 14*14*64 过滤器尺寸为2*2 深度为64  池化层的输出深度同输入深度  步长分别为2#所以输出尺寸是14/2=7   pool2的输出矩阵7*7*64with tf.name_scope("layer4-pool2"):pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')pool_shape = pool2.get_shape().as_list()#pool2.get_shape()获得第四层输出矩阵的维度   #每一层神经网络的输入输出都为一个batch的矩阵 所以这里面的维度也包含了一个batch中数据的个数pool_shape[0]nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]reshaped = tf.reshape(pool2, [pool_shape[0], nodes])#把第四层的输出变为一个batch的向量#  第五层全连接层  输入为一组向量  向量长度为7*7*64=3136=nodes   输出一组长度为FC_SIZE=512的向量with tf.variable_scope('layer5-fc1'):fc1_weights = tf.get_variable("weight", [nodes, FC_SIZE],initializer=tf.truncated_normal_initializer(stddev=0.1))#只有全连接的权重需要加入正则化if regularizer != None: tf.add_to_collection('losses', regularizer(fc1_weights))fc1_biases = tf.get_variable("bias", [FC_SIZE], initializer=tf.constant_initializer(0.1))fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)#dropout在训练时会随机将部分的节点的输出改为0 避免过拟合问题  一般只在全连接层使用if train: fc1 = tf.nn.dropout(fc1, 0.5)#第六层  全连接层  也是输出层 输入为一组长度为512的向量 输出为一组长度为10的向量  这一次输出后会通过softmax得到分类结果      with tf.variable_scope('layer6-fc2'):fc2_weights = tf.get_variable("weight", [FC_SIZE, NUM_LABELS],initializer=tf.truncated_normal_initializer(stddev=0.1))if regularizer != None: tf.add_to_collection('losses', regularizer(fc2_weights))fc2_biases = tf.get_variable("bias", [NUM_LABELS], initializer=tf.constant_initializer(0.1))logit = tf.matmul(fc1, fc2_weights) + fc2_biasesreturn logitBATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 6000
MOVING_AVERAGE_DECAY = 0.99#定义训练过程
def train(mnist):# 定义输出为4维矩阵的placeholderx = tf.placeholder(tf.float32, [BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,NUM_CHANNELS],name='x-input')y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)y = inference(x,False,regularizer)global_step = tf.Variable(0, trainable=False)# 定义损失函数、学习率、滑动平均操作以及训练过程。variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)variables_averages_op = variable_averages.apply(tf.trainable_variables())cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))cross_entropy_mean = tf.reduce_mean(cross_entropy)loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,staircase=True)train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)with tf.control_dependencies([train_step, variables_averages_op]):train_op = tf.no_op(name='train')# 初始化TensorFlow持久化类。saver = tf.train.Saver()with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(TRAINING_STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE)reshaped_xs = np.reshape(xs, (BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,NUM_CHANNELS))_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})if i % 1000 == 0:print("After %d training step(s), loss on training batch is %g." % (step, loss_value))def main(argv=None):mnist = input_data.read_data_sets("datasets/MNIST_data", one_hot=True)train(mnist)if __name__ == '__main__':main()

这篇关于初涉LeNet5处理mnist (CNN卷积神经网络)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1064784

相关文章

电脑提示xlstat4.dll丢失怎么修复? xlstat4.dll文件丢失处理办法

《电脑提示xlstat4.dll丢失怎么修复?xlstat4.dll文件丢失处理办法》长时间使用电脑,大家多少都会遇到类似dll文件丢失的情况,不过,解决这一问题其实并不复杂,下面我们就来看看xls... 在Windows操作系统中,xlstat4.dll是一个重要的动态链接库文件,通常用于支持各种应用程序

SQL Server数据库死锁处理超详细攻略

《SQLServer数据库死锁处理超详细攻略》SQLServer作为主流数据库管理系统,在高并发场景下可能面临死锁问题,影响系统性能和稳定性,这篇文章主要给大家介绍了关于SQLServer数据库死... 目录一、引言二、查询 Sqlserver 中造成死锁的 SPID三、用内置函数查询执行信息1. sp_w

Java对异常的认识与异常的处理小结

《Java对异常的认识与异常的处理小结》Java程序在运行时可能出现的错误或非正常情况称为异常,下面给大家介绍Java对异常的认识与异常的处理,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参... 目录一、认识异常与异常类型。二、异常的处理三、总结 一、认识异常与异常类型。(1)简单定义-什么是

Golang 日志处理和正则处理的操作方法

《Golang日志处理和正则处理的操作方法》:本文主要介绍Golang日志处理和正则处理的操作方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录1、logx日志处理1.1、logx简介1.2、日志初始化与配置1.3、常用方法1.4、配合defer

springboot加载不到nacos配置中心的配置问题处理

《springboot加载不到nacos配置中心的配置问题处理》:本文主要介绍springboot加载不到nacos配置中心的配置问题处理,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录springboot加载不到nacos配置中心的配置两种可能Spring Boot 版本Nacos

python web 开发之Flask中间件与请求处理钩子的最佳实践

《pythonweb开发之Flask中间件与请求处理钩子的最佳实践》Flask作为轻量级Web框架,提供了灵活的请求处理机制,中间件和请求钩子允许开发者在请求处理的不同阶段插入自定义逻辑,实现诸如... 目录Flask中间件与请求处理钩子完全指南1. 引言2. 请求处理生命周期概述3. 请求钩子详解3.1

Python处理大量Excel文件的十个技巧分享

《Python处理大量Excel文件的十个技巧分享》每天被大量Excel文件折磨的你看过来!这是一份Python程序员整理的实用技巧,不说废话,直接上干货,文章通过代码示例讲解的非常详细,需要的朋友可... 目录一、批量读取多个Excel文件二、选择性读取工作表和列三、自动调整格式和样式四、智能数据清洗五、

SpringBoot如何对密码等敏感信息进行脱敏处理

《SpringBoot如何对密码等敏感信息进行脱敏处理》这篇文章主要为大家详细介绍了SpringBoot对密码等敏感信息进行脱敏处理的几个常用方法,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录​1. 配置文件敏感信息脱敏​​2. 日志脱敏​​3. API响应脱敏​​4. 其他注意事项​​总结

Python使用python-docx实现自动化处理Word文档

《Python使用python-docx实现自动化处理Word文档》这篇文章主要为大家展示了Python如何通过代码实现段落样式复制,HTML表格转Word表格以及动态生成可定制化模板的功能,感兴趣的... 目录一、引言二、核心功能模块解析1. 段落样式与图片复制2. html表格转Word表格3. 模板生

Python Pandas高效处理Excel数据完整指南

《PythonPandas高效处理Excel数据完整指南》在数据驱动的时代,Excel仍是大量企业存储核心数据的工具,Python的Pandas库凭借其向量化计算、内存优化和丰富的数据处理接口,成为... 目录一、环境搭建与数据读取1.1 基础环境配置1.2 数据高效载入技巧二、数据清洗核心战术2.1 缺失