利用tensorflow使用预训练神经网络(VGG16)来训练模型

2024-03-29 13:48

本文主要是介绍利用tensorflow使用预训练神经网络(VGG16)来训练模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用tensorflow使用预训练神经网络(VGG16)来训练模型

文章目录

  • 利用tensorflow使用预训练神经网络(VGG16)来训练模型
    • 1.预训练神经网络是什么
    • 2.数据集以及神经网络的选择
    • 3.VGG网络架构
    • 4.确定我们训练的步骤
    • 5.微调
    • 结语

1.预训练神经网络是什么

​ 预训练神经网络是提前在大型数据库上训练过的网络,他蕴含了在大型数据集上训练过的权重,我们可以将他迁移到小型数据集上从而得到较高的准确率,举个例子来说,原本的神经网络是对几百种分类的大型数据上进行学习的,我们得到的训练模型含有获得的权重,我们将他迁移到只有几种分类的小型数据上从而来完成分类识别任务,这种又称迁移学习(**所谓迁移学习,**或者领域适应Domain Adaptation,一般就是要将从源领域(Source Domain)学习到的东西应用到目标领域(Target Domain)上去。源领域和目标领域之间往往有gap/domain discrepancy(源领域的数据和目标领域的数据遵循不同的分布)。

迁移学习能够将适用于大数据的模型迁移到小数据上,实现个性化迁移。)。

​ 那么,这里我们有一个明显的问题,熟悉神经网络的同学肯定知道,我们训练模型其实就是训练神经网络的各项参数,让每个权重结合,最终能够完成得到正确的输出,那么举个例子,在识别椅子与桌子之上的数据集,上的网络,得到的权重为什么能够应用到其他地方上呢。其实我们换个地方想想,在分类问题上,到底几个分类是由我们最后的输出层(全连接层,分类器)决定的,而底层的卷积层只是负责提取特征,所以我们可以使用预训练神经网络的卷积来提取特征(这里是出自网络上搜索,以及询问他人得到的理解,有不对的可以评论区指正)。

2.数据集以及神经网络的选择

​ 迁移学习一般用于小型数据上的识别上,我们再把著名的猫狗大战数据集拿出来,该数据集由一堆猫和狗的图片来识别在这里插入图片描述

然后在预训练神经网络上的选择,keras其中有提供一堆预训练神经网络,比如有VGG16,VGG19,ResNet,Xception等许多预训练神经网络,我们在这里选择VGG16(其实我也有拿Xception来测试,但可能由于代码编写问题,拿到的正确率并不高,所以我并没有拿出来。。。)

3.VGG网络架构

​ VGG是一个十分经典的神经网络,网上资料很多,我找到了该网络的各项模型的具体层结构:在这里插入图片描述

我们其实可以自己来构造当前的层次,但是我们使用预训练神经网络主要是想拿他训练过的权重,

4.确定我们训练的步骤

​ 预训练神经网络步骤可以由下面这张图来确定在这里插入图片描述

我们获得预训练网络的卷积基,抛弃直接连接输出的分类器,冻结卷积基(就是不让卷积基础的权重随着后面我们新加入的分类器传播的权重来改变),然后根据具体的任务我们添加新的分类器从而来进行新的训练

所以我们直接编写代码如下:

import tensorflow as tf
from tensorflow import keras
from keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
#提取新的与训练神经网络我们不要输出层,只要中间提取特征的卷积层
conv_base=keras.applications.VGG16(weights='imagenet',include_top=False)#使用在imageNet上使用的权重
#include_top表示是否需要哪些全连接层
#查看一下获取的网络结构
conv_base.summary()

在这里插入图片描述

然后我们根据具体的分类器来调整我们最后的输出层,最后是二分类问题,所以我们构造后面的分类器如下:

model=keras.Sequential()
#因为我们使用了已经提前训练好的参数,我们并不希望该权重改变,所以我们要将该权重设置为不可训练
conv_base.trainable=False
model.add(conv_base)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(512,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
model.summary()

之后我们开始训练(一些其他数据读取与训练代码与博主上期博客一样),我们查看每次训练的结果:

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 1: loss: 0.329, accuracy: 0.855

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 2: loss: 0.250, accuracy: 0.893

可以看到仅仅只是,两次训练

准确率便能接近达到0.893,但是在接下来的训练上准确率是在90左右,无法上升,达到过拟合了,我们为了进一步提高准确率,可以使用微调方法。

5.微调

刚才我们说过底层的卷积基是负责提取特征的部分,含有的权重是可以迁移过来直接使用的,那么接近输出层的卷积层的权重,是不是意味着可以调整。使得我们增加我们的准确率,但是有一个前提是我们的分类器必须提前训练好的,不然随机初始化后的分类器可能会破坏我们调整的卷积基。所以我们接下来的微调都是在第一轮训练过的分类器

#使得一些高层的卷积基可以训练
conv_base.trainable=True
for layers in conv_base.layers[:-3]:layers.trainable=False
optimizer=keras.optimizers.Adam(0.00005)#同时需要调整更低的学习速率

然后我们开始我们第二波的训练

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 1: loss: 0.196, accuracy: 0.915

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 2: loss: 0.136, accuracy: 0.944

可以看到,我们的正确率开始迅速上升达到95左右,说明我们的微调还是十分有用的。

结语

对于本次使用预训练神经网络中,我省略了一些代码,重点介绍了如何使用预训练神经网络以及微调上,有什么问题和独到的见解可以评论区指正,谢谢。

这篇关于利用tensorflow使用预训练神经网络(VGG16)来训练模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858759

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1