图像识别完整项目之Swin-Transformer,从获取关键词数据集到训练的完整过程

本文主要是介绍图像识别完整项目之Swin-Transformer,从获取关键词数据集到训练的完整过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

0. 前言

图像分类的大部分经典神经网络已经全部介绍完,并且已经作了测试

代码已经全部上传到资源,根据文章名或者关键词搜索即可

LeNet :pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类

AlexNet : pytorch 搭建AlexNet 对花进行分类

Vgg : pytorch 搭建 VGG 网络

GoogLeNet : pytorch 搭建GoogLeNet

ResNet : ResNet 训练CIFAR10数据集,并做图片分类


关于轻量级网络

MobileNet 系列:

  • V1 :MobileNet V1 图像分类
  • V2 :MobileNet V2 图像分类
  • V3 :MobileNet V3 图像分类

ShuffleNet 系列:

  • V1 : ShuffleNet V1 对花数据集训练
  • V2 : ShuffleNet V2 迁移学习对花数据集训练

EfficientNet 系列:

  • V1 :EfficientNet 分类花数据集
  • V2 :EfficientNet V2 

Swin-Transformer :Swin-Transformer 在图像识别中的应用


本章将根据 Swin-Transformer 网络对图像分类ending,包括如何获取数据集,训练网络、预测图像等等。

本文从头实现对Marvel superhero 进行分类记录,项目下载在后面

代码尽量简单,小白均可运行,不需要定义复杂的变量

网络精度高,采用迁移学习

1. 项目目录

文件目录如下所示:

注:项目的文件夹和代码不可更改,要不然会报错,至于超参数的更改下面会介绍!!

inference 是预测的文件夹,将预测的图像放在该文件夹下,可以实现批预测

my_dataset_from_net 爬虫脚本,可以自动从网络上下载图片

run_results  网络训练之后生成的信息,包括类别json文件、loss和accuracy精度曲线、学习率衰减曲线、训练过程日志、已经训练集和测试集的混淆矩阵

weights 下面存放的是Swin-Transformer 的预训练权重

py 文件:

  • model Swin-Transformer 网络
  • predict 预测脚本
  • process_data 根据爬虫下载的图片,自动划分训练集和测试,并且提出损坏图像
  • train 训练部分
  • utils 工具函数

详细的可以参考README 文件

2. 获取数据集

当然最开始要配置好环境和requirements.txt 文件

获取数据集在 my_dataset_from_net 文件下,运行文件下的main.py 可以得到:

脚本会自动在该文件下生成download_images文件目录,然后会根据关键词生成子文件夹

批下载的话,可以新建txt文件,按照这样操作就行:

按照下面操作:

选中baidu API ,load file就是刚刚新建的txt文件

Max number per keywords 就是每个关键词下载的图像个数,Threads 最好设定小一点,否则可能会漏下载

下载过程如下:

下载完成如下:

3. 对下载的图像处理、划分训练集和测试集

代码是 process_data.py 文件,因为代码用中文可能报错,这里要将文件夹改成英文

该脚本会自动删除那些 PIL 打不开的文件

代码会自动将每个子文件夹下按照 0.2比例划分测试集 

运行 process_data.py 结果如下:

代码会在主目录下生成数据

4. 开始训练

训练代码是 train.py 文件

4.1 超参数设定

超参数如下:

关于--freeze-layers,设定为True,只会训练MLP权重。False会训练全部网络

    parser = argparse.ArgumentParser()parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=32)parser.add_argument('--lr', type=float, default=0.0001)parser.add_argument('--lrf', type=float, default=0.1)parser.add_argument('--freeze-layers', type=bool, default=False)     # 是否冻结权重

至于分类的个数啊、对应标签json文件等等,这里使用 datasets.ImageFolder,代码会自动生成,不需要设定!!

只需要更改上面超参数就行!!

4.2 训练过程

将train这部分代码放开,可以查看网络训练图像信息

如下:

训练过程:

代码会自动计算分类的类别个数

训练结果:

4.3 生成的训练日志

生成的结果全部保存在run_results目录下:

json 文件:

loss-accuracy-curve:

学习率衰减曲线:

训练集和测试集的混淆矩阵:

训练日志:

5. 预测脚本

预测脚本在 inference 中,predict.py 会预测该目录下所有图片

不需要任何更改!!

运行 predict.py结果如下:

结果展示:

6. 项目的一些问题和下载

完整项目下载:图像识别完整项目之Swin-Transformer,从获取关键词数据集到训练的完整过程

爬虫下载图片的时候,下载的数目往往和设定的不一致,这个只需要将数目调大就行。事实上,本项目每个类别仅有200多张图片仍能有不错的表现

爬虫下载的图片有时候会出现不能打开的错误,但是在process_data脚本处理的时候,是没有报错的。

训练过程也没有出现错误,可能是process_data脚本的问题

如果不放心,可以手动删除,

预测的时候,因为预处理train mean和train std的原因,会计算的很慢,如果将项目部署的话,可以手动设定

这篇关于图像识别完整项目之Swin-Transformer,从获取关键词数据集到训练的完整过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/504636

相关文章

springboot项目中整合高德地图的实践

《springboot项目中整合高德地图的实践》:本文主要介绍springboot项目中整合高德地图的实践,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一:高德开放平台的使用二:创建数据库(我是用的是mysql)三:Springboot所需的依赖(根据你的需求再

一文详解如何在idea中快速搭建一个Spring Boot项目

《一文详解如何在idea中快速搭建一个SpringBoot项目》IntelliJIDEA作为Java开发者的‌首选IDE‌,深度集成SpringBoot支持,可一键生成项目骨架、智能配置依赖,这篇文... 目录前言1、创建项目名称2、勾选需要的依赖3、在setting中检查maven4、编写数据源5、开启热

SQL Server修改数据库名及物理数据文件名操作步骤

《SQLServer修改数据库名及物理数据文件名操作步骤》在SQLServer中重命名数据库是一个常见的操作,但需要确保用户具有足够的权限来执行此操作,:本文主要介绍SQLServer修改数据... 目录一、背景介绍二、操作步骤2.1 设置为单用户模式(断开连接)2.2 修改数据库名称2.3 查找逻辑文件名

C++中RAII资源获取即初始化

《C++中RAII资源获取即初始化》RAII通过构造/析构自动管理资源生命周期,确保安全释放,本文就来介绍一下C++中的RAII技术及其应用,具有一定的参考价值,感兴趣的可以了解一下... 目录一、核心原理与机制二、标准库中的RAII实现三、自定义RAII类设计原则四、常见应用场景1. 内存管理2. 文件操

SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志

《SpringBoot项目配置logback-spring.xml屏蔽特定路径的日志》在SpringBoot项目中,使用logback-spring.xml配置屏蔽特定路径的日志有两种常用方式,文中的... 目录方案一:基础配置(直接关闭目标路径日志)方案二:结合 Spring Profile 按环境屏蔽关

canal实现mysql数据同步的详细过程

《canal实现mysql数据同步的详细过程》:本文主要介绍canal实现mysql数据同步的详细过程,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的... 目录1、canal下载2、mysql同步用户创建和授权3、canal admin安装和启动4、canal

MySQL存储过程之循环遍历查询的结果集详解

《MySQL存储过程之循环遍历查询的结果集详解》:本文主要介绍MySQL存储过程之循环遍历查询的结果集,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言1. 表结构2. 存储过程3. 关于存储过程的SQL补充总结前言近来碰到这样一个问题:在生产上导入的数据发现

SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程

《SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程》LiteFlow是一款专注于逻辑驱动流程编排的轻量级框架,它以组件化方式快速构建和执行业务流程,有效解耦复杂业务逻辑,下面给大... 目录一、基础概念1.1 组件(Component)1.2 规则(Rule)1.3 上下文(Conte

SpringBoot服务获取Pod当前IP的两种方案

《SpringBoot服务获取Pod当前IP的两种方案》在Kubernetes集群中,SpringBoot服务获取Pod当前IP的方案主要有两种,通过环境变量注入或通过Java代码动态获取网络接口IP... 目录方案一:通过 Kubernetes Downward API 注入环境变量原理步骤方案二:通过

使用SpringBoot整合Sharding Sphere实现数据脱敏的示例

《使用SpringBoot整合ShardingSphere实现数据脱敏的示例》ApacheShardingSphere数据脱敏模块,通过SQL拦截与改写实现敏感信息加密存储,解决手动处理繁琐及系统改... 目录痛点一:痛点二:脱敏配置Quick Start——Spring 显示配置:1.引入依赖2.创建脱敏