基于python3.6+tensorflow2.2的石头剪刀布案例

2024-02-24 18:20

本文主要是介绍基于python3.6+tensorflow2.2的石头剪刀布案例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

unzip_save.py

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import zipfile
import matplotlib.pyplot as plt
import matplotlib.image as mpimg# 解压
local_zip1 = 'E:/Python/pythonProject_1/rps/tmp/rps.zip' # 数据集压缩包路径
zip_ref1 = zipfile.ZipFile(local_zip1, 'r') # 打开压缩包,以读取方式
zip_ref1.extractall('E:/Python/pythonProject_1/rps/tmp/') # 解压到以下路径
zip_ref1.close()local_zip2 = 'E:/Python/pythonProject_1/rps/tmp/rps-test-set.zip' # 数据集压缩包路径
zip_ref2 = zipfile.ZipFile(local_zip2, 'r') # 打开压缩包,以读取方式
zip_ref2.extractall('E:/Python/pythonProject_1/rps/tmp/') # 解压到以下路径
zip_ref2.close()rock_dir = os.path.join('E:/Python/pythonProject_1/rps/tmp/rps/rock')
paper_dir = os.path.join('E:/Python/pythonProject_1/rps/tmp/rps/paper')
scissors_dir = os.path.join('E:/Python/pythonProject_1/rps/tmp/rps/scissors')rock_files = os.listdir(rock_dir)
print(rock_files[:10])paper_files = os.listdir(paper_dir)
print(paper_files[:10])scissors_files = os.listdir(scissors_dir)
print(scissors_files[:10])pic_index = 2
next_rock = [os.path.join(rock_dir, fname)for fname in rock_files[pic_index - 2:pic_index]]
next_paper = [os.path.join(paper_dir, fname)for fname in paper_files[pic_index - 2:pic_index]]
next_scissors = [os.path.join(scissors_dir, fname)for fname in scissors_files[pic_index - 2:pic_index]]for i, img_path in enumerate(next_rock+next_paper+next_scissors):img = mpimg.imread(img_path)plt.imshow(img)plt.axis('Off')plt.show()


model_training_fit.py

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息import tensorflow as tf
# from tensorflow import keras
# from tensorflow.keras.optimizers import RMSprop
# from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plttraining_datagen = ImageDataGenerator(# 数据增强rescale=1. / 255,rotation_range=40, # 旋转范围width_shift_range=0.2, # 宽平移height_shift_range=0.2,# 高平移shear_range=0.2, # 剪切zoom_range=0.2, # 缩放horizontal_flip=True,fill_mode='nearest'
)validation_datagen = ImageDataGenerator(rescale=1. / 255
)TRAINING_DIR = 'E:/Python/pythonProject_1/rps/tmp/rps/'
training_generator = training_datagen.flow_from_directory(TRAINING_DIR,target_size = (150, 150),class_mode = 'categorical'
)VALIDATION_DIR = 'E:/Python/pythonProject_1/rps/tmp/rps-test-set/'
validation_generator = validation_datagen.flow_from_directory(VALIDATION_DIR,target_size = (150, 150),class_mode = 'categorical'
)#======== 模型构建 =========
model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(64, (3, 3), activation = 'relu', input_shape = (150, 150, 3)), # 输入参数:过滤器数量,过滤器尺寸,激活函数:relu, 输入图像尺寸tf.keras.layers.MaxPooling2D(2, 2), # 池化:增强特征tf.keras.layers.Conv2D(64, (3, 3), activation = 'relu'), # 输入参数:过滤器数量、过滤器尺寸、激活函数:relutf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),  # 输入参数:过滤器数量、过滤器尺寸、激活函数:relutf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),  # 输入参数:过滤器数量、过滤器尺寸、激活函数:relutf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Flatten(), # 输入层tf.keras.layers.Dense(512, activation = 'relu'), # 全连接隐层 神经元数量:128 ,激活函数:relutf.keras.layers.Dense(3, activation = 'softmax') # 英文字母分类 26 ,阿拉伯数字分类 10  输出用的是softmax 概率化函数 使得所有输出加起来为1 0-1之间
])model.summary()#======== 模型参数编译 =========
model.compile(optimizer = 'rmsprop',loss = 'categorical_crossentropy', # 损失函数: 稀疏的交叉熵 binary_crossentropymetrics = ['accuracy']
)#======== 模型训练 =========
# Note that this may take some time.
history = model.fit_generator(training_generator,epochs = 25,validation_data = validation_generator,verbose = 1
)model.save('E:/Python/pythonProject_1/rps/model.h5') # model 保存#-----------------------------------------------------------
# Retrieve a list of list result on training and test data
# set for each training epoch
#-----------------------------------------------------------
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc)) # Get number of epochs#-----------------------------------------------------------
# Plot training and validation accuracy per epoch
#-----------------------------------------------------------
plt.plot(epochs, acc, 'r', label = "tra_acc")
plt.plot(epochs ,val_acc, 'b', label = "val_acc")
plt.title("training and validation accuracy")
plt.legend(loc=0)
plt.grid(ls='--')  # 生成网格
plt.show()
# 曲线呈直线是因为epochs/轮次太少
#-----------------------------------------------------------
# Plot training and validation loss per epoch
#-----------------------------------------------------------
plt.plot(epochs, loss, 'r', label = "train_loss")
plt.plot(epochs ,val_loss, 'b', label = "val_loss")
plt.title("training and validation loss")
plt.legend(loc=0)
plt.grid(ls='--')  # 生成网格
plt.show()
# 曲线呈直线是因为epochs/轮次太少


predict.py

import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow import keras
model = keras.models.load_model('E:/Python/pythonProject_1/rps/model.h5')# predicting images
path = 'E:/Python/pythonProject_1/rps/scissor.png'
img = image.load_img(path, target_size=(150, 150))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)images = np.vstack([x])
classes = model.predict(images, batch_size=10)
print(classes)


model.summary

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 148, 148, 64)      1792      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 74, 74, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 72, 72, 64)        36928     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 34, 34, 128)       73856     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 17, 17, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 128)       147584    
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 6272)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               3211776   
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 1539      
=================================================================
Total params: 3,473,475
Trainable params: 3,473,475
Non-trainable params: 0
_________________________________________________________________


测试结果:

Epoch 23/25
79/79 [==============================] - 83s 1s/step - loss: 0.0410 - accuracy: 0.9905 - val_loss: 0.0064 - val_accuracy: 1.0000
Epoch 24/25
79/79 [==============================] - 82s 1s/step - loss: 0.0621 - accuracy: 0.9798 - val_loss: 0.1802 - val_accuracy: 0.9382
Epoch 25/25
79/79 [==============================] - 82s 1s/step - loss: 0.0704 - accuracy: 0.9821 - val_loss: 0.0640 - val_accuracy: 0.9704


 预测结果:

scissor.png

>>> print(classes)[[0. 0. 1.]]


数据来源地址: https://laurencemoroney.com/datasets.html

这篇关于基于python3.6+tensorflow2.2的石头剪刀布案例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/742957

相关文章

Java中的分布式系统开发基于 Zookeeper 与 Dubbo 的应用案例解析

《Java中的分布式系统开发基于Zookeeper与Dubbo的应用案例解析》本文将通过实际案例,带你走进基于Zookeeper与Dubbo的分布式系统开发,本文通过实例代码给大家介绍的非常详... 目录Java 中的分布式系统开发基于 Zookeeper 与 Dubbo 的应用案例一、分布式系统中的挑战二

Java 中的 equals 和 hashCode 方法关系与正确重写实践案例

《Java中的equals和hashCode方法关系与正确重写实践案例》在Java中,equals和hashCode方法是Object类的核心方法,广泛用于对象比较和哈希集合(如HashMa... 目录一、背景与需求分析1.1 equals 和 hashCode 的背景1.2 需求分析1.3 技术挑战1.4

Java中实现对象的拷贝案例讲解

《Java中实现对象的拷贝案例讲解》Java对象拷贝分为浅拷贝(复制值及引用地址)和深拷贝(递归复制所有引用对象),常用方法包括Object.clone()、序列化及JSON转换,需处理循环引用问题,... 目录对象的拷贝简介浅拷贝和深拷贝浅拷贝深拷贝深拷贝和循环引用总结对象的拷贝简介对象的拷贝,把一个

Java中最全最基础的IO流概述和简介案例分析

《Java中最全最基础的IO流概述和简介案例分析》JavaIO流用于程序与外部设备的数据交互,分为字节流(InputStream/OutputStream)和字符流(Reader/Writer),处理... 目录IO流简介IO是什么应用场景IO流的分类流的超类类型字节文件流应用简介核心API文件输出流应用文

MyBatis分页查询实战案例完整流程

《MyBatis分页查询实战案例完整流程》MyBatis是一个强大的Java持久层框架,支持自定义SQL和高级映射,本案例以员工工资信息管理为例,详细讲解如何在IDEA中使用MyBatis结合Page... 目录1. MyBATis框架简介2. 分页查询原理与应用场景2.1 分页查询的基本原理2.1.1 分

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java 正则表达式的使用实战案例

《Java正则表达式的使用实战案例》本文详细介绍了Java正则表达式的使用方法,涵盖语法细节、核心类方法、高级特性及实战案例,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 目录一、正则表达式语法详解1. 基础字符匹配2. 字符类([]定义)3. 量词(控制匹配次数)4. 边

Python Counter 函数使用案例

《PythonCounter函数使用案例》Counter是collections模块中的一个类,专门用于对可迭代对象中的元素进行计数,接下来通过本文给大家介绍PythonCounter函数使用案例... 目录一、Counter函数概述二、基本使用案例(一)列表元素计数(二)字符串字符计数(三)元组计数三、C

Spring Boot 整合 SSE(Server-Sent Events)实战案例(全网最全)

《SpringBoot整合SSE(Server-SentEvents)实战案例(全网最全)》本文通过实战案例讲解SpringBoot整合SSE技术,涵盖实现原理、代码配置、异常处理及前端交互,... 目录Spring Boot 整合 SSE(Server-Sent Events)1、简述SSE与其他技术的对

MySQL 临时表与复制表操作全流程案例

《MySQL临时表与复制表操作全流程案例》本文介绍MySQL临时表与复制表的区别与使用,涵盖生命周期、存储机制、操作限制、创建方法及常见问题,本文结合实例代码给大家介绍的非常详细,感兴趣的朋友跟随小... 目录一、mysql 临时表(一)核心特性拓展(二)操作全流程案例1. 复杂查询中的临时表应用2. 临时