Google版EfficientDet训练自己的数据集

2024-03-18 12:38

本文主要是介绍Google版EfficientDet训练自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Ubuntu 20.04,cuda10.1,TensorFlow2.1.0,python3.6.9环境,使用Google版efficientDet训练自己的数据集,预测图片,并将结果写入txt


0.项目地址

https://github.com/google/automl/tree/master/efficientdet


1.项目部署

使用docker部署环境

Ubuntu 20.04,使用TensorFlow官方提供的docker镜像部署

拉取镜像

docker pull tensorflow/tensorflow:2.1.0-gpu-py3

包含cuda10.1,TensorFlow2.1.0,python3.6.9

拉取镜像后建立容器,设置使用宿主机的gpu,并挂载目录

docker run -it  --gpus all  -v 宿主机目录:容器中的目录 tensorflow/tensorflow:2.1.0-gpu-py3  /bin/bash

下载代码

git clone https://github.com/google/automl.git

安装依赖

项目github页面中写通过pip install -r requirements.txt安装依赖,但博主安装似乎出了点问题,于是手动pip install其中的python依赖包:

absl-py>=0.7.1
matplotlib>=3.0.3
numpy>=1.16.4
Pillow>=6.0.0
PyYAML>=5.1
six>=1.12.0
tensorflow>=2.1.0
tensorflow-addons>=0.9.1
tensorflow-probability>=0.9.0

以及安装coco api的python分支:
安装coco api(https://github.com/cocodataset/cocoapi),在当前环境下编译,并将编译后PythonAPI文件夹中的内容(pycocotools等)拷贝至efficientdet目录下
目录结构为automl-master/efficientdet/pycocotools


2.数据集准备

VOC2007格式数据集

数据集制作成VOC2007格式,目录结构如下
放在efficientdet目录下,目录结构为automl-master/efficientdet/VOCdevkit2007
在这里插入图片描述

数据集转换为tfrecord格式

*制作tfrecord格式的数据集前,先在dataset/create_pascal_tfrecord.py中,将pascal_label_map_dict修改为自己的类名

制作tfrecord格式数据集

PYTHONPATH=".:$PYTHONPATH"  python dataset/create_pascal_tfrecord.py \
--data_dir=VOCdevkit2007 \
--year=VOC2007  \
--output_path=mytfrecord/pascal  \
--set=trainval

制作完成后,在文件夹下生成100个tfrecord文件和1个json文件


3.预训练模型准备

下载预训练模型,解压到efficientdet目录下
博主使用efficientdet-d2模型,目录结构形如automl-master/efficientdet/efficientdet-d2


4.训练

训练

python main.py \
--mode=train_and_eval \
--training_file_pattern=mytfrecord/*.tfrecord \
--validation_file_pattern=mytfrecord/*.tfrecord \
--val_json_file=mytfrecord/json_pascal.json \
--model_name=efficientdet-d2 \
--model_dir=tmp/efficientdet-d2  \
--ckpt=efficientdet-d2  \
--train_batch_size=4 \
--eval_batch_size=1 \
--eval_samples=512 \
--hparams="num_classes=3 " 

其中num_classes为自己数据集的类数+1,model_name等根据自己使用的模型修改
训练生成的模型存放在automl-master/efficientdet/tmp/efficientdet-d2


5.预测

计算AP

制作test集的tfrecord文件
生成的文件在mytfrecord_test文件夹下

PYTHONPATH=".:$PYTHONPATH"  python dataset/create_pascal_tfrecord.py \
--data_dir=VOCdevkit2007 \
--year=VOC2007  \
--output_path=mytfrecord_test/pascal  \
--set=test 

计算AP

python main.py --mode=eval  \
--model_name=efficientdet-d2   --model_dir=tmp/efficientdet-d2  \
--validation_file_pattern=mytfrecord_test/pascal*  \
--testdev_dir='testdev_output'  \
--hparams="num_classes=3 " 

num_classes为自己数据集类数+1

预测自己的图片

*预测前,先将inference.py中的coco_id_mapping修改为自己数据集的类名
预测

python model_inspect.py \
--runmode=infer \
--model_name=efficientdet-d2 \
--max_boxes_to_draw=100  \
--min_score_thresh=0.7  \
--ckpt_path=tmp/efficientdet-d2  \
--input_image=VOCdevkit2007/demo/*.jpg \
--output_image_dir=res_img \
--hparams="num_classes=3 "

其中input_image为待检测图片存放路径,output_image_dir为检测结果图片输出路径,num_classes为自己数据集类数+1

远程查看tensorboard

cd至tmp/efficientdet-d2目录下
打开tensorboard

tensorboard  --logdir=. --port=6006

由于博主是在远程服务器上训练,要在本地查看远程服务器上的tensorboard,可以在Xshell中进行如下的ssh隧道设置
在这里插入图片描述
设置后,重新连接并启动tensorboard,即可在本地浏览器用http://127.0.1.1:6006/访问服务器上的tensorboard

预测结果写入txt

百度网盘链接:https://pan.baidu.com/s/1NjjIr9n63n65enF5rTWhZA
提取码:icgg
↑提供博主修改的一个可以将检测结果按类别写入txt的inference.py文件,写入的格式为每一行:文件名 置信度 xmin ymin xmax ymax
如:000174 0.967 278 337 327 358
使用时直接替换原inference.py,并将coco_id_mapping修改为自己数据集的类名,使用和上面相同的预测命令可以同时输出预测后的图片和txt
输出的txt存放在efficientdet目录下,文件名为类别


6.可能出现的问题

1)预测时,出现形如这样的错误↓

tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.(0) Invalid argument: Input to reshape is a tensor with 884736 values, but the requested shape requires a multiple of 90[[{{node Reshape}}]][[cond_88/while/Identity/_7860]](1) Invalid argument: Input to reshape is a tensor with 884736 values, but the requested shape requires a multiple of 90[[{{node Reshape}}]]
0 successful operations.
0 derived errors ignored.

在issues218中翻到类似问题,说是在预测时加上参数--enable_ema=False解决,但尝试后发现目前版本(2020.08.04)的代码中没有这个参数,于是加上--hparams="num_classes=自己数据集的类别数+1 ",发现可以解决问题

2)计算ap时,出现形如这样的错误(issues552)↓

 Invalid argument: Incompatible shapes: [32,810,128,128] vs. [32,9,128,128] 

也可以尝试通过指定num_classes解决


参考链接
https://blog.csdn.net/jy1023408440/article/details/105638482

这篇关于Google版EfficientDet训练自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822354

相关文章

SpringBoot分段处理List集合多线程批量插入数据方式

《SpringBoot分段处理List集合多线程批量插入数据方式》文章介绍如何处理大数据量List批量插入数据库的优化方案:通过拆分List并分配独立线程处理,结合Spring线程池与异步方法提升效率... 目录项目场景解决方案1.实体类2.Mapper3.spring容器注入线程池bejsan对象4.创建

PHP轻松处理千万行数据的方法详解

《PHP轻松处理千万行数据的方法详解》说到处理大数据集,PHP通常不是第一个想到的语言,但如果你曾经需要处理数百万行数据而不让服务器崩溃或内存耗尽,你就会知道PHP用对了工具有多强大,下面小编就... 目录问题的本质php 中的数据流处理:为什么必不可少生成器:内存高效的迭代方式流量控制:避免系统过载一次性

C#实现千万数据秒级导入的代码

《C#实现千万数据秒级导入的代码》在实际开发中excel导入很常见,现代社会中很容易遇到大数据处理业务,所以本文我就给大家分享一下千万数据秒级导入怎么实现,文中有详细的代码示例供大家参考,需要的朋友可... 目录前言一、数据存储二、处理逻辑优化前代码处理逻辑优化后的代码总结前言在实际开发中excel导入很

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

GSON框架下将百度天气JSON数据转JavaBean

《GSON框架下将百度天气JSON数据转JavaBean》这篇文章主要为大家详细介绍了如何在GSON框架下实现将百度天气JSON数据转JavaBean,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言一、百度天气jsON1、请求参数2、返回参数3、属性映射二、GSON属性映射实战1、类对象映

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

MySQL中查询和展示LONGBLOB类型数据的技巧总结

《MySQL中查询和展示LONGBLOB类型数据的技巧总结》在MySQL中LONGBLOB是一种二进制大对象(BLOB)数据类型,用于存储大量的二进制数据,:本文主要介绍MySQL中查询和展示LO... 目录前言1. 查询 LONGBLOB 数据的大小2. 查询并展示 LONGBLOB 数据2.1 转换为十

使用SpringBoot+InfluxDB实现高效数据存储与查询

《使用SpringBoot+InfluxDB实现高效数据存储与查询》InfluxDB是一个开源的时间序列数据库,特别适合处理带有时间戳的监控数据、指标数据等,下面详细介绍如何在SpringBoot项目... 目录1、项目介绍2、 InfluxDB 介绍3、Spring Boot 配置 InfluxDB4、I

Java整合Protocol Buffers实现高效数据序列化实践

《Java整合ProtocolBuffers实现高效数据序列化实践》ProtocolBuffers是Google开发的一种语言中立、平台中立、可扩展的结构化数据序列化机制,类似于XML但更小、更快... 目录一、Protocol Buffers简介1.1 什么是Protocol Buffers1.2 Pro