pytorch-YOLOv4训练自己的数据集

2024-02-19 06:38
文章标签 数据 训练 pytorch yolov4

本文主要是介绍pytorch-YOLOv4训练自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

题记:之前用YOLOv3训练了自己的数据集,详见该博客,Darknet--Yolov3训练自己的数据。准备试试YOLOv4,试试看校测效果是否再提高,因需要,用的pytorch版本。

一、下载代码

1、下载项目代码

git clone https://github.com/Tianxiaomo/pytorch-YOLOv4.git
cd pytorch-YOLOv4

2、下载预训练模型

在该目录下新建文件夹weight,用于存放权重文件,下载链接如下,下载后存放在weight文件夹下

  • baidu
    • yolov4.pth(https://pan.baidu.com/s/1ZroDvoGScDgtE1ja_QqJVw Extraction code:xrq9)
    • yolov4.conv.137.pth(https://pan.baidu.com/s/1ovBie4YyVQQoUrC3AY0joA Extraction code:kcel)
  • google
    • yolov4.pth(https://drive.google.com/open?id=1wv_LiFeCRYwtpkqREPeI13-gPELBDwuJ)
    • yolov4.conv.137.pth(https://drive.google.com/open?id=1fcbR0bWzYfIEdLJPzOsn4R5mlvR6IQyA)

二、配置环境

1、考虑到不同项目要求环境不同,直接在anaconda下重新创建一个虚拟环境,该环境起名为yolov4

conda create -n yolov4 python=3.6

激活该环境:

conda activate yolo_env

如需关闭环境,可用如下命令:

conda deactivate

2、安装pytorch

进入pytorch官网:https://pytorch.org/

下拉到这里:

根据自己的需求选择对应的选项,然后复制下面的命令在虚拟机上下载即可,我根据自己的服务器环境,选择了下面的命令:

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

安装完成,如下:

3、安装其他包

一开始没看到requirements.txt,挨个pip,缺啥pip啥,我装的如下:

pip install opencv-python
pip install scikit-image
pip3 install tqdm
pip install tensorboardX
pip install easydict
pip install pycocotools

后来发现可以一键安装所需的环境,需要用到代码里给的requirements.txt

命令如下:

pip3 install -r requirements.txt

三、测试模型

先测试一下demo,跑一下看看,命令如下:

python demo.py -cfgfile ./cfg/yolov4.cfg -weightfile ./weight/yolov4.weights -imgfile ./data/dog.jpg

这里的demo.py所用的模型是yolov4.weights而不是yolov4.pth,其中yolov4.weights的下载链接:

  • baidu(https://pan.baidu.com/s/1dAGEW8cm-dqK14TbhhVetA Extraction code:dm5b)
  • google(https://drive.google.com/open?id=1cewMfusmPjYWbrnuJRuKhPMwRe_b9PaT)

PS:经博友提醒,发现之前命令不小心写错了,写成了yolov4.pth,已更正。因此提醒我试了下,demo.py用yolov4.pth跑出的来的测试图,没有框框,故猜测demo.py只能用yolov4.weights(仅个人猜测,未求证)。若要使用yolov4.pth测试,可先跳到本文第六步测试,要使用models.py来测试,具体步骤后面都有写,这里不做过多赘述。

输入命令行显示如下:

文件夹下会生成一个predictions-dog.jpg文件(原来的名字时候predictions.jpg,为了好区分我自己改了):

四、数据集

1、数据集准备

关于准备数据集,和YOLOv3一样,之前有写过,可参考上一篇博客Darknet--Yolov3训练自己的数据中的第三部分(三、数据集准备),准备好所需的图片、XML文件、训练和验证的.txt.文件。

2、数据转换

准备train.txt,内容是图片名和box,格式如下:

image_path1 x1,y1,x2,y2,id x1,y1,x2,y2,id x1,y1,x2,y2,id ...
image_path2 x1,y1,x2,y2,id x1,y1,x2,y2,id x1,y1,x2,y2,id ...
...
  • image_path : 图片名
  • x1,y1 : 左上角坐标
  • x2,y2 : 右下角坐标
  • id : 物体类别

处理代码,官方给的好像是处理json格式的,我的是xml格式,重新找了一个。在pytorch-YOLOv4文件夹下,新建voc_annotation.py文件,具体代码如下:

import xml.etree.ElementTree as ET
from os import getcwdsets=[('myData', 'train'), ('myData', 'val'), ('myData', 'test')]classes = ["car", "truck", "bus", "moto", "bike", "tricycle", "pedestrian", "plate", "driver", "codriver", "tissue", "mark", "decorate"]def convert_annotation(year, image_id, list_file):##in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))in_file = open('myData/Annotations/%s.xml'%(image_id))tree=ET.parse(in_file)root = tree.getroot()for obj in root.iter('object'):difficult = 0 if obj.find('difficult')!=None:difficult = obj.find('difficult').textcls = obj.find('name').textif cls not in classes or int(difficult)==1:continuecls_id = classes.index(cls)xmlbox = obj.find('bndbox')b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))wd = getcwd()for year, image_set in sets:image_ids = open('myData/ImageSets/Main/%s.txt'%(image_set)).read().strip().split()list_file = open('%s_%s.txt'%(year, image_set), 'w')for image_id in image_ids:list_file.write('%s/myData/JPEGImages/%s.jpg'%(wd, image_id))convert_annotation(year, image_id, list_file)list_file.write('\n')list_file.close()

运行该代码,会在pytorch-YOLOv4下生成三个文件:

我将这三个文件拷贝到myData文件夹下,为了方便(偷懒不想改训练代码了),将名字改成如下:

至此,所需要的数据就做好了。

五、训练模型

1、参数设置

cfg.py中根据自己需求修改如下项:

2、训练

python train.py -l 0.001 -g 2 -pretrained ./weight/yolov4.conv.137.pth -classes 13 -dir ./myData/JPEGImages/ -train_label_path ./myData/train.txt#-l             学习率
#-g             gpu id
#-pretrained    预训练权值
#-classes       类别种类
#-dir           图片所在文件夹

训练生成的模型保存在checkpoints文件夹中:

六、测试

测试模型需要用到models.py,输入命令:

python models.py <num_classes> <weightfile> <imgfile> <IN_IMAGE_H> <IN_IMAGE_W> <namefile(optional)># <num_classes>           类别
# <weightfile>            模型
# <imgfile>               要检测的图片
# <IN_IMAGE_H>            图片的高
# <IN_IMAGE_W>            图片的宽
# <namefile(optional)>    类别标签

 我自己的命令如下:

python models.py 13 /checkpoints/Yolov4_epoch14.pth data/a.jpg 608 608 data/mydata.names

结果报错:

不知道为啥(原谅我又懒又菜),于是改了下代码,直接在代码里给<namefile>这个参数:

因为代码里给过<namefile>的参数了,命令行里就不给了:

python models.py 13 /checkpoints/Yolov4_epoch14.pth data/a.jpg 608 608

测试结果保存在根目录下的predictions.jpg,如下:

PS:框框和字体有点细,看不清,有空再改一下。图片中的框不准,应该是我才训练了14轮就拿来测了,后面到几万应该会好很多。训练结束再来更新。

七、评估

后续再补充。。。

八、一些问题

1、train的时候遇到一个问题,说是需要创建自己的'get_image_id':

打开dataset.py,定位到出错的地方,推测是图片名字格式不对,于是改了图片名为纯数字的,但是发现并不行,看到作者注释写的:

去看了GitHub上的global image id方式,并不会用。最后直接修改dataset.py的“get_image_id”,并将自己的名字格式改为“前缀-数字.jpg”的格式,对应的xml也改为“前缀-数字.xml”:

改完可以跑了:

大家也可以根据自己的名字改写“get_image_id”,欢迎交流。

2、不知道为啥,训练速度比之前darknet-yolov3慢好多。我的26000多张图片,训练一轮要两小时,不知道是哪里出问题,有木有大佬能给说道一二。

后记:YOLOV4有darknet和pytorch等多个版本,选这个版本是因工作需要,后续要转成TensorRT,而它刚好提供了一整套转换工具,后续会继续更新模型如何转ONNX以及TensorRT。写博客为记录踩坑过程,顺便备忘,如有错误,还望大佬指教!

 

 

 

这篇关于pytorch-YOLOv4训练自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/723807

相关文章

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k

SpringBoot多环境配置数据读取方式

《SpringBoot多环境配置数据读取方式》SpringBoot通过环境隔离机制,支持properties/yaml/yml多格式配置,结合@Value、Environment和@Configura... 目录一、多环境配置的核心思路二、3种配置文件格式详解2.1 properties格式(传统格式)1.

解决pandas无法读取csv文件数据的问题

《解决pandas无法读取csv文件数据的问题》本文讲述作者用Pandas读取CSV文件时因参数设置不当导致数据错位,通过调整delimiter和on_bad_lines参数最终解决问题,并强调正确参... 目录一、前言二、问题复现1. 问题2. 通过 on_bad_lines=‘warn’ 跳过异常数据3

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

C#监听txt文档获取新数据方式

《C#监听txt文档获取新数据方式》文章介绍通过监听txt文件获取最新数据,并实现开机自启动、禁用窗口关闭按钮、阻止Ctrl+C中断及防止程序退出等功能,代码整合于主函数中,供参考学习... 目录前言一、监听txt文档增加数据二、其他功能1. 设置开机自启动2. 禁止控制台窗口关闭按钮3. 阻止Ctrl +

java如何实现高并发场景下三级缓存的数据一致性

《java如何实现高并发场景下三级缓存的数据一致性》这篇文章主要为大家详细介绍了java如何实现高并发场景下三级缓存的数据一致性,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 下面代码是一个使用Java和Redisson实现的三级缓存服务,主要功能包括:1.缓存结构:本地缓存:使

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I