【YOLO改进】主干插入ShuffleAttention模块(基于MMYOLO)

2024-04-26 10:28

本文主要是介绍【YOLO改进】主干插入ShuffleAttention模块(基于MMYOLO),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

ShuffleAttention模块

论文链接:https://arxiv.org/abs/2102.00240

将ShuffleAttention模块添加到MMYOLO中

  1. 将开源代码ShuffleAttention.py文件复制到mmyolo/models/plugins目录下

  2. 导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS

  3. 确保 class ShuffleAttention中的输入维度为in_channels(因为MMYOLO会提前传入输入维度参数,所以要保持参数名的一致)

  4. 利用@MODELS.register_module()将“class ShuffleAttention(nn.Module)”注册:

  5. 修改mmyolo/models/plugins/__init__.py文件

  6. 在终端运行:

    python setup.py install
  7. 修改对应的配置文件,并且将plugins的参数“type”设置为“ShuffleAttention”,可参考【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-CSDN博客

修改后的ShuffleAttention.py

import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
from mmyolo.registry import MODELS@MODELS.register_module()
class ShuffleAttention(nn.Module):def __init__(self, in_channels=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = in_channelsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(in_channels // (2 * G), in_channels // (2 * G))self.cweight = Parameter(torch.zeros(1, in_channels // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, in_channels // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, in_channels // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, in_channels // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return outif __name__ == '__main__':input = torch.randn(50, 512, 7, 7)se = ShuffleAttention(channel=512, G=8)output = se(input)print(output.shape)

修改后的__init__.py

# Copyright (c) OpenMMLab. All rights reserved.
from .cbam import CBAM
from .Biformer import BiLevelRoutingAttention
from .A2Attention import DoubleAttention
from .CoordAttention import CoordAtt
from .CoTAttention import CoTAttention
from .ECA import ECAAttention
from .EffectiveSE import EffectiveSEModule
from .EMA import EMA
from .GC import GlobalContext
from .GE import GatherExcite
from .MHSA import MHSA
from .ParNetAttention import ParNetAttention
from .PolarizedSelfAttention import ParallelPolarizedSelfAttention
from .S2Attention import S2Attention
from .SE import SEAttention
from .SequentialSelfAttention import SequentialPolarizedSelfAttention
from .SGE import SpatialGroupEnhance
from .ShuffleAttention import ShuffleAttention
__all__ = ['CBAM', 'BiLevelRoutingAttention', 'DoubleAttention', 'CoordAtt','CoTAttention','ECAAttention', 'EffectiveSEModule', 'EMA','GlobalContext', 'GatherExcite', 'MHSA', 'ParNetAttention','ParallelPolarizedSelfAttention','S2Attention','SEAttention','SequentialPolarizedSelfAttention','SpatialGroupEnhance','ShuffleAttention']

修改后的配置文件(以configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py为例)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image pathnum_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True# -----model related-----
# Basic size of multi-scale prior box
anchors = [[(10, 13), (16, 30), (33, 23)],  # P3/8[(30, 61), (62, 45), (59, 119)],  # P4/16[(116, 90), (156, 198), (373, 326)]  # P5/32
]# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochsmodel_test_cfg = dict(# The config of multi-label for multi-class prediction.multi_label=True,# The number of boxes before NMSnms_pre=30000,score_thr=0.001,  # Threshold to filter out boxes.nms=dict(type='nms', iou_threshold=0.65),  # NMS type and thresholdmax_per_img=300)  # Max number of detections of each image# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(type='BatchShapePolicy',batch_size=val_batch_size_per_gpu,img_size=img_scale[0],# The image scale of padding should be divided by pad_size_divisorsize_divisor=32,# Additional paddings for pixel scaleextra_pad_ratio=0.5)# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)# ===============================Unmodified in most cases====================
model = dict(type='YOLODetector',data_preprocessor=dict(type='mmdet.DetDataPreprocessor',mean=[0., 0., 0.],std=[255., 255., 255.],bgr_to_rgb=True),backbone=dict(##修改部分plugins=[dict(cfg=dict(type='ShuffleAttention'),stages=(False, False, False, True))],type='YOLOv5CSPDarknet',deepen_factor=deepen_factor,widen_factor=widen_factor,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),neck=dict(type='YOLOv5PAFPN',deepen_factor=deepen_factor,widen_factor=widen_factor,in_channels=[256, 512, 1024],out_channels=[256, 512, 1024],num_csp_blocks=3,norm_cfg=norm_cfg,act_cfg=dict(type='SiLU', inplace=True)),bbox_head=dict(type='YOLOv5Head',head_module=dict(type='YOLOv5HeadModule',num_classes=num_classes,in_channels=[256, 512, 1024],widen_factor=widen_factor,featmap_strides=strides,num_base_priors=3),prior_generator=dict(type='mmdet.YOLOAnchorGenerator',base_sizes=anchors,strides=strides),# scaled based on number of detection layersloss_cls=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_cls_weight *(num_classes / 80 * 3 / num_det_layers)),loss_bbox=dict(type='IoULoss',iou_mode='ciou',bbox_format='xywh',eps=1e-7,reduction='mean',loss_weight=loss_bbox_weight * (3 / num_det_layers),return_iou=True),loss_obj=dict(type='mmdet.CrossEntropyLoss',use_sigmoid=True,reduction='mean',loss_weight=loss_obj_weight *((img_scale[0] / 640)**2 * 3 / num_det_layers)),prior_match_thr=prior_match_thr,obj_level_weights=obj_level_weights),test_cfg=model_test_cfg)albu_train_transforms = [dict(type='Blur', p=0.01),dict(type='MedianBlur', p=0.01),dict(type='ToGray', p=0.01),dict(type='CLAHE', p=0.01)
]pre_transform = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='LoadAnnotations', with_bbox=True)
]train_pipeline = [*pre_transform,dict(type='Mosaic',img_scale=img_scale,pad_val=114.0,pre_transform=pre_transform),dict(type='YOLOv5RandomAffine',max_rotate_degree=0.0,max_shear_degree=0.0,scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),# img_scale is (width, height)border=(-img_scale[0] // 2, -img_scale[1] // 2),border_val=(114, 114, 114)),dict(type='mmdet.Albu',transforms=albu_train_transforms,bbox_params=dict(type='BboxParams',format='pascal_voc',label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),keymap={'img': 'image','gt_bboxes': 'bboxes'}),dict(type='YOLOv5HSVRandomAug'),dict(type='mmdet.RandomFlip', prob=0.5),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip','flip_direction'))
]train_dataloader = dict(batch_size=train_batch_size_per_gpu,num_workers=train_num_workers,persistent_workers=persistent_workers,pin_memory=True,sampler=dict(type='DefaultSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,ann_file=train_ann_file,data_prefix=dict(img=train_data_prefix),filter_cfg=dict(filter_empty_gt=False, min_size=32),pipeline=train_pipeline))test_pipeline = [dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),dict(type='YOLOv5KeepRatioResize', scale=img_scale),dict(type='LetterResize',scale=img_scale,allow_scale_up=False,pad_val=dict(img=114)),dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),dict(type='mmdet.PackDetInputs',meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape','scale_factor', 'pad_param'))
]val_dataloader = dict(batch_size=val_batch_size_per_gpu,num_workers=val_num_workers,persistent_workers=persistent_workers,pin_memory=True,drop_last=False,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,test_mode=True,data_prefix=dict(img=val_data_prefix),ann_file=val_ann_file,pipeline=test_pipeline,batch_shapes_cfg=batch_shapes_cfg))test_dataloader = val_dataloaderparam_scheduler = None
optim_wrapper = dict(type='OptimWrapper',optimizer=dict(type='SGD',lr=base_lr,momentum=0.937,weight_decay=weight_decay,nesterov=True,batch_size_per_gpu=train_batch_size_per_gpu),constructor='YOLOv5OptimizerConstructor')default_hooks = dict(param_scheduler=dict(type='YOLOv5ParamSchedulerHook',scheduler_type='linear',lr_factor=lr_factor,max_epochs=max_epochs),checkpoint=dict(type='CheckpointHook',interval=save_checkpoint_intervals,save_best='auto',max_keep_ckpts=max_keep_ckpts))custom_hooks = [dict(type='EMAHook',ema_type='ExpMomentumEMA',momentum=0.0001,update_buffers=True,strict_load=False,priority=49)
]val_evaluator = dict(type='mmdet.CocoMetric',proposal_nums=(100, 1, 10),ann_file=data_root + val_ann_file,metric='bbox')
test_evaluator = val_evaluatortrain_cfg = dict(type='EpochBasedTrainLoop',max_epochs=max_epochs,val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

这篇关于【YOLO改进】主干插入ShuffleAttention模块(基于MMYOLO)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/937327

相关文章

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Python 基于http.server模块实现简单http服务的代码举例

《Python基于http.server模块实现简单http服务的代码举例》Pythonhttp.server模块通过继承BaseHTTPRequestHandler处理HTTP请求,使用Threa... 目录测试环境代码实现相关介绍模块简介类及相关函数简介参考链接测试环境win11专业版python

Nginx添加内置模块过程

《Nginx添加内置模块过程》文章指导如何检查并添加Nginx的with-http_gzip_static模块:确认该模块未默认安装后,需下载同版本源码重新编译,备份替换原有二进制文件,最后重启服务验... 目录1、查看Nginx已编辑的模块2、Nginx官网查看内置模块3、停止Nginx服务4、Nginx

python urllib模块使用操作方法

《pythonurllib模块使用操作方法》Python提供了多个库用于处理URL,常用的有urllib、requests和urlparse(Python3中为urllib.parse),下面是这些... 目录URL 处理库urllib 模块requests 库urlparse 和 urljoin编码和解码

创建springBoot模块没有目录结构的解决方案

《创建springBoot模块没有目录结构的解决方案》2023版IntelliJIDEA创建模块时可能出现目录结构识别错误,导致文件显示异常,解决方法为选择模块后点击确认,重新校准项目结构设置,确保源... 目录创建spChina编程ringBoot模块没有目录结构解决方案总结创建springBoot模块没有目录

idea Maven Springboot多模块项目打包时90%的问题及解决方案

《ideaMavenSpringboot多模块项目打包时90%的问题及解决方案》:本文主要介绍ideaMavenSpringboot多模块项目打包时90%的问题及解决方案,具有很好的参考价值,... 目录1. 前言2. 问题3. 解决办法4. jar 包冲突总结1. 前言之所以写这篇文章是因为在使用Mav

Python标准库datetime模块日期和时间数据类型解读

《Python标准库datetime模块日期和时间数据类型解读》文章介绍Python中datetime模块的date、time、datetime类,用于处理日期、时间及日期时间结合体,通过属性获取时间... 目录Datetime常用类日期date类型使用时间 time 类型使用日期和时间的结合体–日期时间(

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

Python中re模块结合正则表达式的实际应用案例

《Python中re模块结合正则表达式的实际应用案例》Python中的re模块是用于处理正则表达式的强大工具,正则表达式是一种用来匹配字符串的模式,它可以在文本中搜索和匹配特定的字符串模式,这篇文章主... 目录前言re模块常用函数一、查看文本中是否包含 A 或 B 字符串二、替换多个关键词为统一格式三、提

一文深入详解Python的secrets模块

《一文深入详解Python的secrets模块》在构建涉及用户身份认证、权限管理、加密通信等系统时,开发者最不能忽视的一个问题就是“安全性”,Python在3.6版本中引入了专门面向安全用途的secr... 目录引言一、背景与动机:为什么需要 secrets 模块?二、secrets 模块的核心功能1. 基