Google版EfficientDet训练自己的数据集

2024-03-18 12:38

本文主要是介绍Google版EfficientDet训练自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Ubuntu 20.04,cuda10.1,TensorFlow2.1.0,python3.6.9环境,使用Google版efficientDet训练自己的数据集,预测图片,并将结果写入txt


0.项目地址

https://github.com/google/automl/tree/master/efficientdet


1.项目部署

使用docker部署环境

Ubuntu 20.04,使用TensorFlow官方提供的docker镜像部署

拉取镜像

docker pull tensorflow/tensorflow:2.1.0-gpu-py3

包含cuda10.1,TensorFlow2.1.0,python3.6.9

拉取镜像后建立容器,设置使用宿主机的gpu,并挂载目录

docker run -it  --gpus all  -v 宿主机目录:容器中的目录 tensorflow/tensorflow:2.1.0-gpu-py3  /bin/bash

下载代码

git clone https://github.com/google/automl.git

安装依赖

项目github页面中写通过pip install -r requirements.txt安装依赖,但博主安装似乎出了点问题,于是手动pip install其中的python依赖包:

absl-py>=0.7.1
matplotlib>=3.0.3
numpy>=1.16.4
Pillow>=6.0.0
PyYAML>=5.1
six>=1.12.0
tensorflow>=2.1.0
tensorflow-addons>=0.9.1
tensorflow-probability>=0.9.0

以及安装coco api的python分支:
安装coco api(https://github.com/cocodataset/cocoapi),在当前环境下编译,并将编译后PythonAPI文件夹中的内容(pycocotools等)拷贝至efficientdet目录下
目录结构为automl-master/efficientdet/pycocotools


2.数据集准备

VOC2007格式数据集

数据集制作成VOC2007格式,目录结构如下
放在efficientdet目录下,目录结构为automl-master/efficientdet/VOCdevkit2007
在这里插入图片描述

数据集转换为tfrecord格式

*制作tfrecord格式的数据集前,先在dataset/create_pascal_tfrecord.py中,将pascal_label_map_dict修改为自己的类名

制作tfrecord格式数据集

PYTHONPATH=".:$PYTHONPATH"  python dataset/create_pascal_tfrecord.py \
--data_dir=VOCdevkit2007 \
--year=VOC2007  \
--output_path=mytfrecord/pascal  \
--set=trainval

制作完成后,在文件夹下生成100个tfrecord文件和1个json文件


3.预训练模型准备

下载预训练模型,解压到efficientdet目录下
博主使用efficientdet-d2模型,目录结构形如automl-master/efficientdet/efficientdet-d2


4.训练

训练

python main.py \
--mode=train_and_eval \
--training_file_pattern=mytfrecord/*.tfrecord \
--validation_file_pattern=mytfrecord/*.tfrecord \
--val_json_file=mytfrecord/json_pascal.json \
--model_name=efficientdet-d2 \
--model_dir=tmp/efficientdet-d2  \
--ckpt=efficientdet-d2  \
--train_batch_size=4 \
--eval_batch_size=1 \
--eval_samples=512 \
--hparams="num_classes=3 " 

其中num_classes为自己数据集的类数+1,model_name等根据自己使用的模型修改
训练生成的模型存放在automl-master/efficientdet/tmp/efficientdet-d2


5.预测

计算AP

制作test集的tfrecord文件
生成的文件在mytfrecord_test文件夹下

PYTHONPATH=".:$PYTHONPATH"  python dataset/create_pascal_tfrecord.py \
--data_dir=VOCdevkit2007 \
--year=VOC2007  \
--output_path=mytfrecord_test/pascal  \
--set=test 

计算AP

python main.py --mode=eval  \
--model_name=efficientdet-d2   --model_dir=tmp/efficientdet-d2  \
--validation_file_pattern=mytfrecord_test/pascal*  \
--testdev_dir='testdev_output'  \
--hparams="num_classes=3 " 

num_classes为自己数据集类数+1

预测自己的图片

*预测前,先将inference.py中的coco_id_mapping修改为自己数据集的类名
预测

python model_inspect.py \
--runmode=infer \
--model_name=efficientdet-d2 \
--max_boxes_to_draw=100  \
--min_score_thresh=0.7  \
--ckpt_path=tmp/efficientdet-d2  \
--input_image=VOCdevkit2007/demo/*.jpg \
--output_image_dir=res_img \
--hparams="num_classes=3 "

其中input_image为待检测图片存放路径,output_image_dir为检测结果图片输出路径,num_classes为自己数据集类数+1

远程查看tensorboard

cd至tmp/efficientdet-d2目录下
打开tensorboard

tensorboard  --logdir=. --port=6006

由于博主是在远程服务器上训练,要在本地查看远程服务器上的tensorboard,可以在Xshell中进行如下的ssh隧道设置
在这里插入图片描述
设置后,重新连接并启动tensorboard,即可在本地浏览器用http://127.0.1.1:6006/访问服务器上的tensorboard

预测结果写入txt

百度网盘链接:https://pan.baidu.com/s/1NjjIr9n63n65enF5rTWhZA
提取码:icgg
↑提供博主修改的一个可以将检测结果按类别写入txt的inference.py文件,写入的格式为每一行:文件名 置信度 xmin ymin xmax ymax
如:000174 0.967 278 337 327 358
使用时直接替换原inference.py,并将coco_id_mapping修改为自己数据集的类名,使用和上面相同的预测命令可以同时输出预测后的图片和txt
输出的txt存放在efficientdet目录下,文件名为类别


6.可能出现的问题

1)预测时,出现形如这样的错误↓

tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.(0) Invalid argument: Input to reshape is a tensor with 884736 values, but the requested shape requires a multiple of 90[[{{node Reshape}}]][[cond_88/while/Identity/_7860]](1) Invalid argument: Input to reshape is a tensor with 884736 values, but the requested shape requires a multiple of 90[[{{node Reshape}}]]
0 successful operations.
0 derived errors ignored.

在issues218中翻到类似问题,说是在预测时加上参数--enable_ema=False解决,但尝试后发现目前版本(2020.08.04)的代码中没有这个参数,于是加上--hparams="num_classes=自己数据集的类别数+1 ",发现可以解决问题

2)计算ap时,出现形如这样的错误(issues552)↓

 Invalid argument: Incompatible shapes: [32,810,128,128] vs. [32,9,128,128] 

也可以尝试通过指定num_classes解决


参考链接
https://blog.csdn.net/jy1023408440/article/details/105638482

这篇关于Google版EfficientDet训练自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822354

相关文章

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

pandas数据的合并concat()和merge()方式

《pandas数据的合并concat()和merge()方式》Pandas中concat沿轴合并数据框(行或列),merge基于键连接(内/外/左/右),concat用于纵向或横向拼接,merge用于... 目录concat() 轴向连接合并(1) join='outer',axis=0(2)join='o

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k

SpringBoot多环境配置数据读取方式

《SpringBoot多环境配置数据读取方式》SpringBoot通过环境隔离机制,支持properties/yaml/yml多格式配置,结合@Value、Environment和@Configura... 目录一、多环境配置的核心思路二、3种配置文件格式详解2.1 properties格式(传统格式)1.

解决pandas无法读取csv文件数据的问题

《解决pandas无法读取csv文件数据的问题》本文讲述作者用Pandas读取CSV文件时因参数设置不当导致数据错位,通过调整delimiter和on_bad_lines参数最终解决问题,并强调正确参... 目录一、前言二、问题复现1. 问题2. 通过 on_bad_lines=‘warn’ 跳过异常数据3

C#监听txt文档获取新数据方式

《C#监听txt文档获取新数据方式》文章介绍通过监听txt文件获取最新数据,并实现开机自启动、禁用窗口关闭按钮、阻止Ctrl+C中断及防止程序退出等功能,代码整合于主函数中,供参考学习... 目录前言一、监听txt文档增加数据二、其他功能1. 设置开机自启动2. 禁止控制台窗口关闭按钮3. 阻止Ctrl +

java如何实现高并发场景下三级缓存的数据一致性

《java如何实现高并发场景下三级缓存的数据一致性》这篇文章主要为大家详细介绍了java如何实现高并发场景下三级缓存的数据一致性,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 下面代码是一个使用Java和Redisson实现的三级缓存服务,主要功能包括:1.缓存结构:本地缓存:使

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口