Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)

本文主要是介绍Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        本笔记记录CNN做CIFAR100数据集的训练相关内容,代码中使用了类似VGG13的网络结构,做了两个Sequetial(CNN和全连接层),没有用Flatten层而是用reshape操作做CNN和全连接层的中转操作。由于网络层次较深,参数量相比之前的网络多了不少,因此只做了10次epoch(RTX4090),没有继续跑了,最终准确率大概在33.8%左右。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Inputos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,yy_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))#创建CNN网络,总共4个unit,每个unit主要是两个卷积层和Max Pooling池化层
cnn_layers = [#unit 1layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 2layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 3layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 4layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 5layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),
]def main():#[b, 32, 32, 3] => [b, 1, 1, 512]cnn_net = Sequential(cnn_layers)cnn_net.build(input_shape=[None, 32, 32, 3])#测试一下卷积层的输出#x = tf.random.normal([4, 32, 32, 3])#out = cnn_net(x)#print(out.shape)#创建全连接层, 输出为100分类fc_net = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(100, activation=None),])fc_net.build(input_shape=[None, 512])#设置优化器optimizer = optimizers.Adam(learning_rate=1e-4)#记录cnn层和全连接层所有可训练参数, 实现的效果类似list拼接,比如# [1, 2] + [3, 4] => [1, 2, 3, 4]variables = cnn_net.trainable_variables + fc_net.trainable_variables#进行训练num_epoches = 10for epoch in range(num_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:#[b, 32, 32, 3] => [b, 1, 1, 512]out = cnn_net(x)#flatten打平 => [b, 512]out = tf.reshape(out, [-1, 512])#使用全连接层做100分类logits输出#[b, 512] => [b, 100]logits = fc_net(out)#标签做one_hot encodingy_onehot = tf.one_hot(y, depth=100)#计算损失loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)#计算梯度grads = tape.gradient(loss, variables)#更新参数optimizer.apply_gradients(zip(grads, variables))if (step % 100 == 0):print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))#进行验证total_samples = 0total_correct = 0for x,y in test_db:out = cnn_net(x)out = tf.reshape(out, [-1, 512])logits = fc_net(out)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_samples += x.shape[0]total_correct += int(correct)#统计准确率acc = total_correct / total_samplesprint("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
if __name__ == '__main__':main()

运行结果:

这篇关于Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/918731

相关文章

SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南

《SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南》本文将基于开源项目springboot-easyexcel-batch进行解析与扩展,手把手教大家如何在SpringBo... 目录项目结构概览核心依赖百万级导出实战场景核心代码效果百万级导入实战场景监听器和Service(核心

setsid 命令工作原理和使用案例介绍

《setsid命令工作原理和使用案例介绍》setsid命令在Linux中创建独立会话,使进程脱离终端运行,适用于守护进程和后台任务,通过重定向输出和确保权限,可有效管理长时间运行的进程,本文给大家介... 目录setsid 命令介绍和使用案例基本介绍基本语法主要特点命令参数使用案例1. 在后台运行命令2.

使用Redis快速实现共享Session登录的详细步骤

《使用Redis快速实现共享Session登录的详细步骤》在Web开发中,Session通常用于存储用户的会话信息,允许用户在多个页面之间保持登录状态,Redis是一个开源的高性能键值数据库,广泛用于... 目录前言实现原理:步骤:使用Redis实现共享Session登录1. 引入Redis依赖2. 配置R

使用Python的requests库调用API接口的详细步骤

《使用Python的requests库调用API接口的详细步骤》使用Python的requests库调用API接口是开发中最常用的方式之一,它简化了HTTP请求的处理流程,以下是详细步骤和实战示例,涵... 目录一、准备工作:安装 requests 库二、基本调用流程(以 RESTful API 为例)1.

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

Python yield与yield from的简单使用方式

《Pythonyield与yieldfrom的简单使用方式》生成器通过yield定义,可在处理I/O时暂停执行并返回部分结果,待其他任务完成后继续,yieldfrom用于将一个生成器的值传递给另一... 目录python yield与yield from的使用代码结构总结Python yield与yield

Go语言使用select监听多个channel的示例详解

《Go语言使用select监听多个channel的示例详解》本文将聚焦Go并发中的一个强力工具,select,这篇文章将通过实际案例学习如何优雅地监听多个Channel,实现多任务处理、超时控制和非阻... 目录一、前言:为什么要使用select二、实战目标三、案例代码:监听两个任务结果和超时四、运行示例五

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

pandas数据的合并concat()和merge()方式

《pandas数据的合并concat()和merge()方式》Pandas中concat沿轴合并数据框(行或列),merge基于键连接(内/外/左/右),concat用于纵向或横向拼接,merge用于... 目录concat() 轴向连接合并(1) join='outer',axis=0(2)join='o

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k