深度学习之基于VGG16与ResNet50实现鸟类识别

2024-02-01 13:32

本文主要是介绍深度学习之基于VGG16与ResNet50实现鸟类识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

鸟类识别在之前做过,但是效果特别差。而且ResNet50的效果直接差到爆炸,这次利用VGG16与ResNet50的官方模型进行鸟类识别。

1.导入库

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os,pathlib,PIL
from tensorflow.keras import layers,models,Sequential,Input,Model
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Flatten,Dense,BatchNormalization,ZeroPadding2D,Activation,AveragePooling2D# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

2.数据准备

数据所在文件夹

data_dir = "E:/tmp/.keras/datasets/Birds_photos"
data_dir = pathlib.Path(data_dir)
img_count = len(list(data_dir.glob('*/*')))
print(img_count)#共565张图片

labels:一共是4类

all_images_paths = list(data_dir.glob('*'))##”*”匹配0个或多个字符
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths]

超参数的设置

height = 227
width = 227
batch_size = 8
epochs = 20

按照8:2的比例划分训练集与测试集

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.2
)
train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training'
)
test_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation'
)

查看数据

plt.figure(figsize=(15,10))for images,labels in train_ds:for i in range(32):ax = plt.subplot(4,8,i+1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break
plt.show()

在这里插入图片描述

3.VGG16网络

迁移学习调用VGG16的官方模型

conv_base = tf.keras.applications.VGG16(weights='imagenet',include_top=False)
#设置为不可训练
conv_base.trainable = False
#模型搭建
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(4,activation='sigmoid'))

模型编译&&训练

model.compile(optimizer="adam",loss='categorical_crossentropy',metrics=['accuracy']
)
history = model.fit(train_ds,validation_data=test_ds,epochs=epochs
)

在这里插入图片描述
模型的准确率很高,在博主实验的几个模型中,VGG16的模型准确率是最高的。

保存网络:

model.save("E:/Users/yqx/PycharmProjects/BirdsRec/model.h5")

利用网络模型进行预测:

new_model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/BirdsRec/model.h5")
plt.figure(figsize=(18,18))
plt.suptitle("预测结果展示")
for images,labels in test_ds:for i in range(8):ax = plt.subplot(2,4,i+1)plt.imshow(images[i])img_array = tf.expand_dims(images[i],0)#增加一个维度pre = new_model.predict(img_array)plt.title(all_label_names[np.argmax(pre)])plt.axis("off")break
plt.show()

在这里插入图片描述
绘制混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd#绘制混淆矩阵
def plot_cm(labels,pre):conf_numpy = confusion_matrix(labels,pre)#根据实际值和预测值绘制混淆矩阵conf_df = pd.DataFrame(conf_numpy,index=all_label_names,columns=all_label_names)#将data和all_label_names制成DataFrameplt.figure(figsize=(8,8))sns.heatmap(conf_df,annot=True,fmt="d",cmap="BuPu")#将data绘制为混淆矩阵plt.title('混淆矩阵',fontsize = 15)plt.ylabel('真实值',fontsize = 14)plt.xlabel('预测值',fontsize = 14)plt.show()
test_pre = []
test_label = []
for images,labels in test_ds:for image,label in zip(images,labels):img_array = tf.expand_dims(image,0)#增加一个维度pre = model.predict(img_array)#预测结果test_pre.append(all_label_names[np.argmax(pre)])#将预测结果传入列表test_label.append(all_label_names[np.argmax(label)])#将真实结果传入列表break#由于硬件问题。这里我只用了一个batch,一共8张图片。
plot_cm(test_label,test_pre)#绘制混淆矩阵

在这里插入图片描述

4.ResNet50网络

与VGG16不同的是,ResNet50的网络参数设置的是可以训练,经过多次实验,这样ResNet50的实验效果是最好的。

conv_base = tf.keras.applications.ResNet50(weights='imagenet',include_top=False)
#设置为可以训练
conv_base.trainable = True
#模型搭建
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(4,activation='sigmoid'))

在这里插入图片描述
虽然准确率在来回波动,但是整体的准确率是比较高的,比VGG16的准确率还是差一些的。博主关于ResNet50的了解还比较少,等到了解深刻了再回来更新。

这篇关于深度学习之基于VGG16与ResNet50实现鸟类识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/667426

相关文章

SpringBoot全局域名替换的实现

《SpringBoot全局域名替换的实现》本文主要介绍了SpringBoot全局域名替换的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录 项目结构⚙️ 配置文件application.yml️ 配置类AppProperties.Ja

Python实现批量CSV转Excel的高性能处理方案

《Python实现批量CSV转Excel的高性能处理方案》在日常办公中,我们经常需要将CSV格式的数据转换为Excel文件,本文将介绍一个基于Python的高性能解决方案,感兴趣的小伙伴可以跟随小编一... 目录一、场景需求二、技术方案三、核心代码四、批量处理方案五、性能优化六、使用示例完整代码七、小结一、

Java实现将HTML文件与字符串转换为图片

《Java实现将HTML文件与字符串转换为图片》在Java开发中,我们经常会遇到将HTML内容转换为图片的需求,本文小编就来和大家详细讲讲如何使用FreeSpire.DocforJava库来实现这一功... 目录前言核心实现:html 转图片完整代码场景 1:转换本地 HTML 文件为图片场景 2:转换 H

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

C#实现一键批量合并PDF文档

《C#实现一键批量合并PDF文档》这篇文章主要为大家详细介绍了如何使用C#实现一键批量合并PDF文档功能,文中的示例代码简洁易懂,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言效果展示功能实现1、添加文件2、文件分组(书签)3、定义页码范围4、自定义显示5、定义页面尺寸6、PDF批量合并7、其他方法

SpringBoot实现不同接口指定上传文件大小的具体步骤

《SpringBoot实现不同接口指定上传文件大小的具体步骤》:本文主要介绍在SpringBoot中通过自定义注解、AOP拦截和配置文件实现不同接口上传文件大小限制的方法,强调需设置全局阈值远大于... 目录一  springboot实现不同接口指定文件大小1.1 思路说明1.2 工程启动说明二 具体实施2

Python实现精确小数计算的完全指南

《Python实现精确小数计算的完全指南》在金融计算、科学实验和工程领域,浮点数精度问题一直是开发者面临的重大挑战,本文将深入解析Python精确小数计算技术体系,感兴趣的小伙伴可以了解一下... 目录引言:小数精度问题的核心挑战一、浮点数精度问题分析1.1 浮点数精度陷阱1.2 浮点数误差来源二、基础解决

Java实现在Word文档中添加文本水印和图片水印的操作指南

《Java实现在Word文档中添加文本水印和图片水印的操作指南》在当今数字时代,文档的自动化处理与安全防护变得尤为重要,无论是为了保护版权、推广品牌,还是为了在文档中加入特定的标识,为Word文档添加... 目录引言Spire.Doc for Java:高效Word文档处理的利器代码实战:使用Java为Wo

Java实现远程执行Shell指令

《Java实现远程执行Shell指令》文章介绍使用JSch在SpringBoot项目中实现远程Shell操作,涵盖环境配置、依赖引入及工具类编写,详解分号和双与号执行多指令的区别... 目录软硬件环境说明编写执行Shell指令的工具类总结jsch(Java Secure Channel)是SSH2的一个纯J

使用Python实现Word文档的自动化对比方案

《使用Python实现Word文档的自动化对比方案》我们经常需要比较两个Word文档的版本差异,无论是合同修订、论文修改还是代码文档更新,人工比对不仅效率低下,还容易遗漏关键改动,下面通过一个实际案例... 目录引言一、使用python-docx库解析文档结构二、使用difflib进行差异比对三、高级对比方