利用Eager写的自定义模型来训练神经网络(mnist示例)

2024-03-29 13:48

本文主要是介绍利用Eager写的自定义模型来训练神经网络(mnist示例),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用Eager写的自定义循环来训练神经网络(mnist示例)

文章目录

  • 利用Eager写的自定义循环来训练神经网络(mnist示例)
    • 1.为什么
    • 2.确认eager是否可以使用
    • 3.数据集选择
    • 4.数据预处理
    • 5.模型的建立
    • 6.对于每次训练的自定义
    • 7.模型的验证
    • 8.结语

1.为什么

tensorflow的keras提供了非常具体的舒服封装的神经网络和各项功能,但是越封装越影响我们对他的自定义的神经网络,所以tensorflow提供了一个 eager 自定义模式,可以自定义内层的网络循环,,这样可以在同时拥有高度封装的网络模型的情况下对训练的每一步进行调试,控制每一步的流程,也有利于我们搞懂神经网络的内部循环。

2.确认eager是否可以使用

我使用的版本是tensorflow2.4.0,默认支持eager模式,但是我们也可以用个功能来查看

tf.executing_eagerly()
True

这样就说明我们的电脑可以支持eager的运行模式,接下来我们就可以使用eager来支持对于训练每一步的自定义模式。

3.数据集选择

我们在这里选择深度学习的hello world测试集mnist,来测试,他是由60000张手写体图片组成的,每张图片大小为(28,28)灰度图,每张图片读取后可以这样展示出来

import matplotlib.pyplot as plt
import random 
ch=random.choice(range(len(train_image)))
%matplotlib inline
plt.imshow(train_image[ch])

在这里插入图片描述

4.数据预处理

对于该数据集,我们打算使用CNN卷积神经网络来解决该问题,在使用卷积神经网络需要注意使用对于图片的卷积神经网络的话。需要输入的具有四个维度分别是(batch,width,length,channel)但是我们输入的图片目前只有三个维度,最后一个维度需要我们去扩充,并且图片数据现在是无符号整型数,我们还要转换成浮点型,以及将数据归一化,代码如下:

import tensorflow as tf
from tensorflow import keras
(train_image,train_label),(test_image,test_label)=keras.datasets.mnist.load_data()#导入数据
train_images=tf.expand_dims(train_image,-1)#扩展维度
train_images=tf.cast(train_images/255,tf.float32)#归一化处理数据,然后转化数据格式
train_label=tf.cast(train_label,tf.int64)#处理标签数据
train_data=tf.data.Dataset.from_tensor_slices((train_images,train_label))#制作数据
BATCH_SIZE=32
train_data=train_data.shuffle(10000).batch(BATCH_SIZE)
#制作训练数据

5.模型的建立

​ 对于mnist这样的数据集,我们就按照正常的卷积神经网络的布局去解决就好了,唯一要注意的就是我们在这里最后不选择激活(最后返回一个十维的张量,我们可以根据哪个维度的值比较大来确定该图是哪个值,这样的值我们称为logits,同时我们也要注意如果不激活的话在计算损失的话也需要注意要声明我们最后没有激活)

model=tf.keras.Sequential()
model.add(keras.layers.Conv2D(16,(3,3),input_shape=(28,28,1),activation='relu'))
model.add(keras.layers.Conv2D(32,(3,3),activation='relu'))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(10))

6.对于每次训练的自定义

​ 接下来才是我们eager的自定义循环最重要的部分,我们先缕清数据到底是怎么训练的,每一批次的数据输入模型,模型返回值,计算与真值的损失函数,从而来修改模型的可训练参数,如果不用eager自定义模式,那么其实这些每一步都由tensorflow自己帮我们完成,那么接下来这些由我们自己来完成,同时还要注意每一epoch训练完要输出准确率和loss那么接下来让我们来用代码实现

train_loss = tf.keras.metrics.Mean('train_loss')#meteics计算损失对象
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
#计算每一步的训练准确率#在每一批次的训练中我们规定如何去训练
def train_step(model, images, labels):#利用t来记录每个变量的变化with tf.GradientTape() as t:pred = model(images)loss_step = loss_func(labels, pred)#求解损失函数对于可训练参数的微分grads = t.gradient(loss_step, model.trainable_variables)#对于优化器来应用这样求解的梯度对于可训练参数的修改呢optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss_step)#返回loss的均值train_accuracy(labels, pred)#返回准确率均值
def train():for epoch in range(10):for (batch, (images, labels)) in enumerate(dataset):train_step(model, images, labels)print('Epoch{} loss is {}, accuracy is {}'.format(epoch,train_loss.result(),train_accuracy.result()))train_loss.reset_states()#重新设置训练参数的状态train_accuracy.reset_states()
train()#最后我们直接运行参数开始训练

输出以下内容

Epoch0 loss is 0.9328579902648926, accuracy is 0.7062000036239624

Epoch1 loss is 0.3824250400066376, accuracy is 0.8821166753768921

Epoch2 loss is 0.3067557215690613, accuracy is 0.90461665391922

Epoch3 loss is 0.27012768387794495, accuracy is 0.914900004863739

Epoch4 loss is 0.24741436541080475, accuracy is 0.9226499795913696

Epoch5 loss is 0.22932228446006775, accuracy is 0.9286666512489319

Epoch6 loss is 0.2168780267238617, accuracy is 0.9328500032424927

Epoch7 loss is 0.20622499287128448, accuracy is 0.9351166486740112

Epoch8 loss is 0.197651669383049, accuracy is 0.9382666945457458

Epoch9 loss is 0.19150032103061676, accuracy is 0.940500020980835

这样我们就完成了每一步的训练(但我在实际跑的时候,由于我不了解对于缓存的处理,我把CPU跑爆了,即使我将batch,和shuffle的数据降低,还是爆了,应该不是计算资源吧,我调用GPU跑也还是这样,希望有知道的人可以在评论区指点迷津),我们来测试一下预测值

7.模型的验证

features,label=next(iter(train_data))#获取一个批次的数据
features.shape#每一批次的数据为三十二个
prediction=model(features)
prediction.shape
tf.argmax(prediction, axis=1)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 8, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

那么我们接下来查看真值

print(label)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 2, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

我们可以看到我们获得的模型虽然有误差但是整体来说准确率还是挺高的。

8.结语

​ 在本例中,我们完成了对于模型每次训练的自定义完成了在我们自己定义下的循环中去不断训练模型的实例,但在本次训练中对于CPU的使用还需优化,暂时没找到解决的办法希望有懂得人,或者对博客中指出的错误都可以在博客中讨论。

这篇关于利用Eager写的自定义模型来训练神经网络(mnist示例)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858756

相关文章

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java高效实现PowerPoint转PDF的示例详解

《Java高效实现PowerPoint转PDF的示例详解》在日常开发或办公场景中,经常需要将PowerPoint演示文稿(PPT/PPTX)转换为PDF,本文将介绍从基础转换到高级设置的多种用法,大家... 目录为什么要将 PowerPoint 转换为 PDF安装 Spire.Presentation fo

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Vue实现路由守卫的示例代码

《Vue实现路由守卫的示例代码》Vue路由守卫是控制页面导航的钩子函数,主要用于鉴权、数据预加载等场景,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、概念二、类型三、实战一、概念路由守卫(Navigation Guards)本质上就是 在路

JAVA实现Token自动续期机制的示例代码

《JAVA实现Token自动续期机制的示例代码》本文主要介绍了JAVA实现Token自动续期机制的示例代码,通过动态调整会话生命周期平衡安全性与用户体验,解决固定有效期Token带来的风险与不便,感兴... 目录1. 固定有效期Token的内在局限性2. 自动续期机制:兼顾安全与体验的解决方案3. 总结PS

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型

Python屏幕抓取和录制的详细代码示例

《Python屏幕抓取和录制的详细代码示例》随着现代计算机性能的提高和网络速度的加快,越来越多的用户需要对他们的屏幕进行录制,:本文主要介绍Python屏幕抓取和录制的相关资料,需要的朋友可以参考... 目录一、常用 python 屏幕抓取库二、pyautogui 截屏示例三、mss 高性能截图四、Pill

Java中的Schema校验技术与实践示例详解

《Java中的Schema校验技术与实践示例详解》本主题详细介绍了在Java环境下进行XMLSchema和JSONSchema校验的方法,包括使用JAXP、JAXB以及专门的JSON校验库等技术,本文... 目录1. XML和jsON的Schema校验概念1.1 XML和JSON校验的必要性1.2 Sche

使用MapStruct实现Java对象映射的示例代码

《使用MapStruct实现Java对象映射的示例代码》本文主要介绍了使用MapStruct实现Java对象映射的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、什么是 MapStruct?二、实战演练:三步集成 MapStruct第一步:添加 Mave