YOLOv8 Improvement Series, Part 2: Integrating the ATSS Label Assignment Strategy into YOLOv8 for a Quick Accuracy Gain

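Article content: integrating the ATSS label assignment strategy into the YOLOv8 training loss for a quick accuracy gain.
Recommendation (out of five stars): ⭐️⭐️⭐️⭐️⭐️
Accuracy gain (out of five stars): ⭐️⭐️⭐️⭐️⭐️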

✨Contents

  • 1. Introduction to ATSS
  • 2. Core Code Changes
    • 2.1 Modify the loss file
    • 2.2 Create the module file
    • 2.3 Modify the training code
    • 2.4 Troubleshooting


1. Introduction to ATSS

🌳Paper: link
🌳Source code: link
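🌳Problem statement: For years, object detection was dominated by anchor-based detectors. More recently, thanks to FPN and Focal Loss, anchor-free detectors have become popular. The ATSS paper points out that the essential difference between anchor-based and anchor-free detection is how positive and negative training samples are defined, and that this difference accounts for the performance gap between them. If both use the same definition of positive and negative samples during training, the final performance shows no obvious difference, whether regression starts from a box or from a point. The remaining question is how to select positive and negative samples well without relying on complex hand-designed rules.
🌳Main idea: For each GT (ground truth) box, ATSS first picks, on every feature level, the k candidate anchor boxes (anchors, not predictions) closest to the GT box. It then computes the IoU (Intersection over Union) between these candidates and the GT, takes the mean and standard deviation of those IoUs, and uses their sum as the IoU threshold. Candidates whose IoU exceeds the threshold are kept as positive samples; in the code in section 2.2, a candidate's center must also lie inside the GT box. If an anchor box matches several GT boxes, it is assigned to the GT with the highest IoU.
🌳Advantages: ATSS selects positive and negative samples automatically from the statistics of each object, which avoids hand-set fixed thresholds and improves model performance and efficiency. It has a single hyperparameter, k, and later use has shown that its performance is insensitive to k, so ATSS can be regarded as almost hyperparameter-free.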
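The selection rule described above can be sketched for a single GT box in plain Python. This is only an illustrative sketch, not the assigner used later in this article: the function names (box_center, box_iou, atss_select) are made up for the example, boxes are assumed to be (x1, y1, x2, y2) tuples, every level is assumed to contain at least one anchor, and the sample standard deviation is used to mirror the default behavior of torch.std.

```python
from statistics import mean, stdev


def box_center(box):
    """Center (cx, cy) of a box given as (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = box
    return (x1 + x2) / 2.0, (y1 + y2) / 2.0


def box_iou(a, b):
    """Plain IoU between two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def atss_select(gt, anchors_per_level, topk=9):
    """Hypothetical helper: ATSS positives for one GT box.

    anchors_per_level is a list with one list of anchor boxes per feature level.
    """
    gx, gy = box_center(gt)

    def sq_dist(anchor):
        cx, cy = box_center(anchor)
        return (cx - gx) ** 2 + (cy - gy) ** 2

    # Step 1: per level, keep the topk anchors whose centers are closest to the GT center.
    candidates = []
    for level_anchors in anchors_per_level:
        candidates.extend(sorted(level_anchors, key=sq_dist)[:topk])

    # Step 2: threshold = mean + standard deviation of the candidate IoUs.
    ious = [box_iou(gt, anchor) for anchor in candidates]
    spread = stdev(ious) if len(ious) > 1 else 0.0
    threshold = mean(ious) + spread

    # Step 3: keep candidates above the threshold whose center lies inside the GT.
    positives = []
    for anchor, value in zip(candidates, ious):
        cx, cy = box_center(anchor)
        inside = gt[0] < cx < gt[2] and gt[1] < cy < gt[3]
        if value > threshold and inside:
            positives.append(anchor)
    return positives, threshold


# Toy usage: one level, four anchors, k = 3.
gt = (10, 10, 30, 30)
levels = [[(10, 10, 30, 30), (20, 20, 40, 40), (0, 0, 20, 20), (100, 100, 120, 120)]]
positives, threshold = atss_select(gt, levels, topk=3)
print(positives, round(threshold, 3))
```

Because the threshold is derived from each GT's own candidates, an object with many well-overlapping anchors gets a high threshold, while an object that fits the anchors poorly gets a lower one. The tie-breaking step for anchors matched to several GT boxes is left out of this sketch.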

🌳Algorithm flowchart (figure not included)

2. Core Code Changes

2.1 Modify the loss file

Loss file: ultralytics\utils\loss.py
Change 1

            _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
                pred_scores.detach().sigmoid(),
                (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
                anchor_points * stride_tensor,
                gt_labels,
                gt_bboxes,
                mask_gt,
            )

Change it to

            _, target_bboxes, target_scores, fg_mask = self.assigner_atss(
                anchors,
                n_anchors_list,
                gt_labels,
                gt_bboxes,
                mask_gt,
                (pred_bboxes.detach() * stride_tensor_s).type(gt_bboxes.dtype),
            )

Change 2
Initialize the ATSS label assignment strategy:
self.assigner_atss = ATSSAssigner(9, num_classes=self.nc)
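The replacement call passes anchors and n_anchors_list, which the stock loss does not define and which this article does not show how to build. According to the assigner's docstring, the first is a set of anchor boxes of shape (num_total_anchors, 4) and the second is the number of anchors on each feature level. The following is a minimal sketch of what such inputs could look like, under loud assumptions: the helper name make_anchor_boxes is hypothetical, one square anchor is placed per grid cell, and the anchor side length of 5 × stride is an assumed value, not something taken from this article.

```python
def make_anchor_boxes(feature_sizes, strides, cell_size=5.0, offset=0.5):
    """Hypothetical helper: one square anchor box per grid cell, in input pixels.

    feature_sizes: list of (height, width) per feature level.
    strides: stride of each level relative to the input image.
    cell_size: anchor side length in units of the stride (assumed value).
    Returns (anchors, n_anchors_list), where anchors is a flat list of
    (x1, y1, x2, y2) tuples and n_anchors_list holds the count per level.
    """
    anchors = []
    n_anchors_list = []
    for (height, width), stride in zip(feature_sizes, strides):
        half = cell_size * stride / 2.0
        for row in range(height):
            for col in range(width):
                # Cell center in input-image coordinates.
                cx = (col + offset) * stride
                cy = (row + offset) * stride
                anchors.append((cx - half, cy - half, cx + half, cy + half))
        n_anchors_list.append(height * width)
    return anchors, n_anchors_list


# A 640x640 input with strides 8, 16, 32 gives 80x80, 40x40 and 20x20 grids.
anchors, n_anchors_list = make_anchor_boxes([(80, 80), (40, 40), (20, 20)], [8, 16, 32])
print(len(anchors), n_anchors_list)
```

In real training code these values would be tensors on the same device as the predictions (for example, built with torch.tensor(anchors)), and the per-anchor stride tensor passed as stride_tensor_s would have to be built in the same anchor order.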

2.2 创建模块文件

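After the changes above, the ATSSAssigner class cannot be found, because it has not been created yet. In the same utils folder, create a file named atss_assigner.py for the ATSS label assignment code, and import the class in loss.py (for example, from ultralytics.utils.atss_assigner import ATSSAssigner). The file content is as follows: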

import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.atss_fun import iou_calculator, select_highest_overlaps, dist_calculator, select_candidates_in_gts
from ultralytics.utils.ops import fp16_clamp


def cast_tensor_type(x, scale=1., dtype=None):
    if dtype == 'fp16':
        # scale is for preventing overflows
        x = (x / scale).half()
    return x


def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False, scale=1., dtype=None):
    """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""

    """Calculate IoU between 2D bboxes.
    Args:
        bboxes1 (Tensor): bboxes have shape (m, 4) in <x1, y1, x2, y2>
            format, or shape (m, 5) in <x1, y1, x2, y2, score> format.
        bboxes2 (Tensor): bboxes have shape (m, 4) in <x1, y1, x2, y2>
            format, shape (m, 5) in <x1, y1, x2, y2, score> format, or be
            empty. If ``is_aligned `` is ``True``, then m and n must be
            equal.
        mode (str): "iou" (intersection over union), "iof" (intersection
            over foreground), or "giou" (generalized intersection over
            union).
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
    @from MangoAI &3836712GKcH2717GhcK. please see https://github.com/iscyy/ultralyticsPro
    Returns:
        Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
    """
    assert bboxes1.size(-1) in [0, 4, 5]
    assert bboxes2.size(-1) in [0, 4, 5]
    if bboxes2.size(-1) == 5:
        bboxes2 = bboxes2[..., :4]
    if bboxes1.size(-1) == 5:
        bboxes1 = bboxes1[..., :4]

    if dtype == 'fp16':
        # change tensor type to save cpu and cuda memory and keep speed
        bboxes1 = cast_tensor_type(bboxes1, scale, dtype)
        bboxes2 = cast_tensor_type(bboxes2, scale, dtype)
        overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
        if not overlaps.is_cuda and overlaps.dtype == torch.float16:
            # resume cpu float32
            overlaps = overlaps.float()
        return overlaps

    return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)


def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        if is_aligned:
            return bboxes1.new(batch_shape + (rows, ))
        else:
            return bboxes1.new(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
        rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
            enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
    else:
        lt = torch.max(bboxes1[..., :, None, :2],
                       bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
        rb = torch.min(bboxes1[..., :, None, 2:],
                       bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        else:
            union = area1[..., None]
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :2],
                                    bboxes2[..., None, :, :2])
            enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
                                    bboxes2[..., None, :, 2:])

    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious
    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious


class ATSSAssigner(nn.Module):
    '''Adaptive Training Sample Selection Assigner'''

    def __init__(self,
                 topk=9,
                 num_classes=80):
        super(ATSSAssigner, self).__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes

    @torch.no_grad()
    def forward(self,
                anc_bboxes,
                n_level_bboxes,
                gt_labels,
                gt_bboxes,
                mask_gt,
                pd_bboxes):
        r"""This code is based on
            https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
        Args:
            anc_bboxes (Tensor): shape(num_total_anchors, 4)
            n_level_bboxes (List):len(3)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
            pd_bboxes (Tensor): shape(bs, n_max_boxes, 4)
        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
        """
        self.n_anchors = anc_bboxes.size(0)
        self.bs = gt_bboxes.size(0)
        self.n_max_boxes = gt_bboxes.size(1)

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return torch.full([self.bs, self.n_anchors], self.bg_idx).to(device), \
                   torch.zeros([self.bs, self.n_anchors, 4]).to(device), \
                   torch.zeros([self.bs, self.n_anchors, self.num_classes]).to(device), \
                   torch.zeros([self.bs, self.n_anchors]).to(device)

        overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        overlaps = overlaps.reshape([self.bs, -1, self.n_anchors])

        distances, ac_points = dist_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        distances = distances.reshape([self.bs, -1, self.n_anchors])

        is_in_candidate, candidate_idxs = self.select_topk_candidates(
            distances, n_level_bboxes, mask_gt)

        overlaps_thr_per_gt, iou_candidates = self.thres_calculator(
            is_in_candidate, candidate_idxs, overlaps)

        # select candidates iou >= threshold as positive
        is_pos = torch.where(
            iou_candidates > overlaps_thr_per_gt.repeat([1, 1, self.n_anchors]),
            is_in_candidate, torch.zeros_like(is_in_candidate))

        is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes)
        mask_pos = is_pos * is_in_gts * mask_gt

        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
            mask_pos, overlaps, self.n_max_boxes)

        # assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(
            gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # soft label with iou
        if pd_bboxes is not None:
            ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos
            ious = ious.max(axis=-2)[0].unsqueeze(-1)
            target_scores *= ious

        return target_labels.long(), target_bboxes, target_scores, fg_mask.bool()

    def select_topk_candidates(self,
                               distances,
                               n_level_bboxes,
                               mask_gt):
        mask_gt = mask_gt.repeat(1, 1, self.topk).bool()
        level_distances = torch.split(distances, n_level_bboxes, dim=-1)
        is_in_candidate_list = []
        candidate_idxs = []
        start_idx = 0
        for per_level_distances, per_level_boxes in zip(level_distances, n_level_bboxes):
            end_idx = start_idx + per_level_boxes
            selected_k = min(self.topk, per_level_boxes)
            _, per_level_topk_idxs = per_level_distances.topk(selected_k, dim=-1, largest=False)
            candidate_idxs.append(per_level_topk_idxs + start_idx)
            per_level_topk_idxs = torch.where(
                mask_gt, per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs))
            is_in_candidate = F.one_hot(per_level_topk_idxs, per_level_boxes).sum(dim=-2)
            is_in_candidate = torch.where(
                is_in_candidate > 1, torch.zeros_like(is_in_candidate), is_in_candidate)
            is_in_candidate_list.append(is_in_candidate.to(distances.dtype))
            start_idx = end_idx

        is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
        candidate_idxs = torch.cat(candidate_idxs, dim=-1)

        return is_in_candidate_list, candidate_idxs

    def thres_calculator(self,
                         is_in_candidate,
                         candidate_idxs,
                         overlaps):
        n_bs_max_boxes = self.bs * self.n_max_boxes
        _candidate_overlaps = torch.where(
            is_in_candidate > 0, overlaps, torch.zeros_like(overlaps))
        candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1])
        assist_idxs = self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device)
        assist_idxs = assist_idxs[:, None]
        faltten_idxs = candidate_idxs + assist_idxs
        candidate_overlaps = _candidate_overlaps.reshape(-1)[faltten_idxs]
        candidate_overlaps = candidate_overlaps.reshape([self.bs, self.n_max_boxes, -1])

        overlaps_mean_per_gt = candidate_overlaps.mean(axis=-1, keepdim=True)
        overlaps_std_per_gt = candidate_overlaps.std(axis=-1, keepdim=True)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt

        return overlaps_thr_per_gt, _candidate_overlaps

    '''@from MangoAI &3836712GKcH2717GhcK. please see https://github.com/iscyy/ultralyticsPro'''

    def get_targets(self,
                    gt_labels,
                    gt_bboxes,
                    target_gt_idx,
                    fg_mask):
        # assigned target labels
        batch_idx = torch.arange(self.bs, dtype=gt_labels.dtype, device=gt_labels.device)
        batch_idx = batch_idx[..., None]
        target_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long()
        target_labels = gt_labels.flatten()[target_gt_idx.flatten()]
        target_labels = target_labels.reshape([self.bs, self.n_anchors])
        target_labels = torch.where(
            fg_mask > 0, target_labels, torch.full_like(target_labels, self.bg_idx))

        # assigned target boxes
        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()]
        target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4])

        # assigned target scores
        target_scores = F.one_hot(target_labels.long(), self.num_classes + 1).float()
        target_scores = target_scores[:, :, :self.num_classes]

        return target_labels, target_bboxes, target_scores

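2.3 Modify the training code

Copy the YOLOv8 config file and name it ultralytics\cfg\models\v8\YOLOv8-ATSS.yaml; the configuration content itself needs no changes. The training script is as follows: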

import sys
import argparse
import os

# Add the local project path before importing ultralytics;
# appending it after the import has no effect on that import.
sys.path.append(r'F:\python\company_code\Algorithm_architecture\ultralyticsPro0425-YOLOv8')  # Path

from ultralytics import YOLO


def main(opt):
    yaml = opt.cfg
    weights = opt.weights
    # Build the model from the YAML config, then load pretrained weights.
    model = YOLO(yaml).load(weights)
    model.info()
    results = model.train(
        data=r'ultralytics\cfg\datasets\coco128.yaml',  # raw string: backslashes are path separators
        epochs=10,
        imgsz=416,
        workers=0,
        batch=4,
    )


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default=r'ultralytics\cfg\models\cfg2024\YOLOv8-标签分配策略\YOLOv8-ATSS.yaml', help='model config (YAML) path')
    parser.add_argument('--weights', type=str, default=r'weights\yolov8n.pt', help='initial weights path')
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

Run this script to train YOLOv8 with ATSS:

python train_v8.py --cfg ultralytics\cfg\models\v8\YOLOv8-ATSS.yaml

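2.4 Troubleshooting

  1. You modified the model in your local YOLOv8 files, but training always uses the ultralytics library installed in the virtual environment:
    • In this case your modified module was not loaded. Either copy the entire ultralytics folder into the virtual environment, or uninstall the ultralytics package from the environment so that only your local files can be imported.
  2. ModuleNotFoundError: No module named 'timm':
    • pip install timm -i https://pypi.tuna.tsinghua.edu.cn/simple/ (if a newer version causes environment problems, try pip install timm==0.6.13)
  3. ModuleNotFoundError: No module named 'einops':
    • pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple
  4. ModuleNotFoundError: No module named 'hub_sdk':
    • pip install hub_sdk -i https://pypi.tuna.tsinghua.edu.cn/simple/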
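For the first issue in the list above, it can help to check which copy of a package Python would actually import before changing the environment. The snippet below is a small sketch using only the standard library; the helper name module_origin is made up for this example.

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be loaded from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# If this prints a path under site-packages, the installed package is being
# picked up instead of the local, modified ultralytics folder.
print(module_origin("ultralytics"))
```

Run it from the same directory and with the same interpreter as the training script, since the import search path depends on both.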
