detectron2 DiffusionDet 训练自己的数据集

2024-03-07 04:52

本文主要是介绍detectron2 DiffusionDet 训练自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

配环境
git clone https://github.com/ShoufaChen/DiffusionDet# 创建环境
conda create -n diffusion python=3.9
conda activate diffusion
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install opencv-python# 安装detectron2
cd /data2/zy/DiffusionDet/
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2pip install timm # 不装就会报错 No module named 'timm' (diffusion) 
prepare datasets
mkdir -p datasets/coco
mkdir -p datasets/lvisln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017
ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017
修改配置文件等

复制一份train_net.py,命名为train.py,在其中添加下列代码注册数据集

#引入以下注释
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_coco_json
import pycocotools
#声明类别,尽量保持
CLASS_NAMES =["__background__","Inlet","Slightshort","Generalshort","Severeshort","Outlet"]
# 数据集路径
DATASET_ROOT = '/data2/zy/DiffusionDet/datasets/coco/'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')
TEST_PATH = os.path.join(DATASET_ROOT, 'test2017')TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')
TEST_JSON = os.path.join(ANN_ROOT, 'instances_test2017.json')# 声明数据集的子集
PREDEFINED_SPLITS_DATASET = {"coco_my_train": (TRAIN_PATH, TRAIN_JSON),"coco_my_val": (VAL_PATH, VAL_JSON),
}
#===========以下有两种注册数据集的方法,本人直接用的第二个plain_register_dataset的方式 也可以用register_dataset的形式==================
#注册数据集(这一步就是将自定义数据集注册进Detectron2)
def register_dataset():"""purpose: register all splits of dataset with PREDEFINED_SPLITS_DATASET"""for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items():register_dataset_instances(name=key,json_file=json_file,image_root=image_root)#注册数据集实例,加载数据集中的对象实例
def register_dataset_instances(name, json_file, image_root):"""purpose: register dataset to DatasetCatalog,register metadata to MetadataCatalog and set attribute"""DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))MetadataCatalog.get(name).set(json_file=json_file,image_root=image_root,evaluator_type="coco")#=============================
# 注册数据集和元数据
def plain_register_dataset():#训练集DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH))MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES,  # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭evaluator_type='coco', # 指定评估方式json_file=TRAIN_JSON,image_root=TRAIN_PATH)#DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val"))#验证/测试集DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH))MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭evaluator_type='coco', # 指定评估方式json_file=VAL_JSON,image_root=VAL_PATH)
# 查看数据集标注,可视化检查数据集标注是否正确,
#这个也可以自己写脚本判断,其实就是判断标注框是否超越图像边界
#可选择使用此方法
def checkout_dataset_annotation(name="coco_my_val"):#dataset_dicts = load_coco_json(TRAIN_JSON, TRAIN_PATH, name)dataset_dicts = load_coco_json(TRAIN_JSON, TRAIN_PATH)print(len(dataset_dicts))for i, d in enumerate(dataset_dicts,0):#print(d)img = cv2.imread(d["file_name"])visualizer = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(name), scale=1.5)vis = visualizer.draw_dataset_dict(d)#cv2.imshow('show', vis.get_image()[:, :, ::-1])cv2.imwrite('out/'+str(i) + '.jpg',vis.get_image()[:, :, ::-1])#cv2.waitKey(0)if i == 200:break

main中调用注册函数

def main(args):cfg = setup(args)register_dataset() # here to registerif args.eval_only:model = Trainer.build_model(cfg)kwargs = may_get_ema_checkpointer(cfg, model)if cfg.MODEL_EMA.ENABLED:EMADetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,resume=args.resume)else:DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,resume=args.resume)res = Trainer.ema_test(cfg, model)if cfg.TEST.AUG.ENABLED:res.update(Trainer.test_with_TTA(cfg, model))if comm.is_main_process():verify_results(cfg, res)return restrainer = Trainer(cfg)trainer.resume_or_load(resume=args.resume)return trainer.train()

在 DiffisionDet/configs 下新建demo.yaml,主要是修改batchsize和max_iter

_BASE_: "Base-DiffusionDet.yaml"
MODEL:WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"RESNETS:DEPTH: 50STRIDE_IN_1X1: FalseDiffusionDet:NUM_PROPOSALS: 100NUM_CLASSES: 5
DATASETS:TRAIN: ("coco_my_train",)TEST:  ("coco_my_val",)
SOLVER:IMS_PER_BATCH: 16BASE_LR: 0.000025STEPS: (5850, 7000)MAX_ITER: 7500# TOTAL_NUM_IMAGES / (IMS_PER_BATCH * NUM_GPUS) * num_epochs = MAX_ITER# 2000/(16*1)*60=7500 
INPUT:MIN_SIZE_TRAIN: (800,)CROP:ENABLED: FalseFORMAT: "RGB"
OUTPUT_DIR: ./OUTPUT/bs16
训练
 python train.py --num-gpus 1     --config-file configs/diffdet.coco.res50.yaml

这篇关于detectron2 DiffusionDet 训练自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/782414

相关文章

SQL Server修改数据库名及物理数据文件名操作步骤

《SQLServer修改数据库名及物理数据文件名操作步骤》在SQLServer中重命名数据库是一个常见的操作,但需要确保用户具有足够的权限来执行此操作,:本文主要介绍SQLServer修改数据... 目录一、背景介绍二、操作步骤2.1 设置为单用户模式(断开连接)2.2 修改数据库名称2.3 查找逻辑文件名

canal实现mysql数据同步的详细过程

《canal实现mysql数据同步的详细过程》:本文主要介绍canal实现mysql数据同步的详细过程,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的... 目录1、canal下载2、mysql同步用户创建和授权3、canal admin安装和启动4、canal

使用SpringBoot整合Sharding Sphere实现数据脱敏的示例

《使用SpringBoot整合ShardingSphere实现数据脱敏的示例》ApacheShardingSphere数据脱敏模块,通过SQL拦截与改写实现敏感信息加密存储,解决手动处理繁琐及系统改... 目录痛点一:痛点二:脱敏配置Quick Start——Spring 显示配置:1.引入依赖2.创建脱敏

详解如何使用Python构建从数据到文档的自动化工作流

《详解如何使用Python构建从数据到文档的自动化工作流》这篇文章将通过真实工作场景拆解,为大家展示如何用Python构建自动化工作流,让工具代替人力完成这些数字苦力活,感兴趣的小伙伴可以跟随小编一起... 目录一、Excel处理:从数据搬运工到智能分析师二、PDF处理:文档工厂的智能生产线三、邮件自动化:

Python数据分析与可视化的全面指南(从数据清洗到图表呈现)

《Python数据分析与可视化的全面指南(从数据清洗到图表呈现)》Python是数据分析与可视化领域中最受欢迎的编程语言之一,凭借其丰富的库和工具,Python能够帮助我们快速处理、分析数据并生成高质... 目录一、数据采集与初步探索二、数据清洗的七种武器1. 缺失值处理策略2. 异常值检测与修正3. 数据

pandas实现数据concat拼接的示例代码

《pandas实现数据concat拼接的示例代码》pandas.concat用于合并DataFrame或Series,本文主要介绍了pandas实现数据concat拼接的示例代码,具有一定的参考价值,... 目录语法示例:使用pandas.concat合并数据默认的concat:参数axis=0,join=

C#代码实现解析WTGPS和BD数据

《C#代码实现解析WTGPS和BD数据》在现代的导航与定位应用中,准确解析GPS和北斗(BD)等卫星定位数据至关重要,本文将使用C#语言实现解析WTGPS和BD数据,需要的可以了解下... 目录一、代码结构概览1. 核心解析方法2. 位置信息解析3. 经纬度转换方法4. 日期和时间戳解析5. 辅助方法二、L

使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)

《使用Python和Matplotlib实现可视化字体轮廓(从路径数据到矢量图形)》字体设计和矢量图形处理是编程中一个有趣且实用的领域,通过Python的matplotlib库,我们可以轻松将字体轮廓... 目录背景知识字体轮廓的表示实现步骤1. 安装依赖库2. 准备数据3. 解析路径指令4. 绘制图形关键

解决mysql插入数据锁等待超时报错:Lock wait timeout exceeded;try restarting transaction

《解决mysql插入数据锁等待超时报错:Lockwaittimeoutexceeded;tryrestartingtransaction》:本文主要介绍解决mysql插入数据锁等待超时报... 目录报错信息解决办法1、数据库中执行如下sql2、再到 INNODB_TRX 事务表中查看总结报错信息Lock

使用C#删除Excel表格中的重复行数据的代码详解

《使用C#删除Excel表格中的重复行数据的代码详解》重复行是指在Excel表格中完全相同的多行数据,删除这些重复行至关重要,因为它们不仅会干扰数据分析,还可能导致错误的决策和结论,所以本文给大家介绍... 目录简介使用工具C# 删除Excel工作表中的重复行语法工作原理实现代码C# 删除指定Excel单元