YOLOv8改进 | 注意力篇 | YOLOv8引入CBAM注意力机制

2024-08-30 08:20

本文主要是介绍YOLOv8改进 | 注意力篇 | YOLOv8引入CBAM注意力机制,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.CBAM介绍

摘要:我们提出了卷积块注意力模块(CBAM),这是一种用于前馈卷积神经网络的简单而有效的注意力模块。 给定中间特征图,我们的模块沿着两个独立的维度(通道和空间)顺序推断注意力图,然后将注意力图乘以输入特征图以进行自适应特征细化。 由于 CBAM 是一个轻量级通用模块,因此它可以无缝集成到任何 CNN 架构中,且开销可以忽略不计,并且可以与基础 CNN 一起进行端到端训练。 我们通过在 ImageNet-1K、MS COCO 检测和 VOC 2007 检测数据集上进行大量实验来验证我们的 CBAM。 我们的实验表明各种模型的分类和检测性能得到了一致的改进,证明了 CBAM 的广泛适用性。 代码和模型将公开。

官方论文地址:CBAM论文 

官方代码地址:CBAM代码

简单介绍:CBAM的主要思想是通过关注重要的特征并抑制不必要的特征来增强网络的表示能力。模块首先应用通道注意力,关注"重要的"特征,然后应用空间注意力,关注这些特征的"重要位置"。通过这种方式,CBAM有效地帮助网络聚焦于图像中的关键信息,提高了特征的表示力度,下图为其原理结构图。

2.核心代码

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""def __init__(self, channels: int) -> None:"""Initializes the class and sets the basic configurations and instance variables required."""super().__init__()self.pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)self.act = nn.Sigmoid()def forward(self, x: torch.Tensor) -> torch.Tensor:"""Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""return x * self.act(self.fc(self.pool(x)))class SpatialAttention(nn.Module):"""Spatial-attention module."""def __init__(self, kernel_size=7):"""Initialize Spatial-attention module with kernel size argument."""super().__init__()assert kernel_size in (3, 7), "kernel size must be 3 or 7"padding = 3 if kernel_size == 7 else 1self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.act = nn.Sigmoid()def forward(self, x):"""Apply channel and spatial attention on input for feature recalibration."""return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))class CBAM(nn.Module):"""Convolutional Block Attention Module."""def __init__(self, c1, kernel_size=7):"""Initialize CBAM with given input channel (c1) and kernel size."""super().__init__()self.channel_attention = ChannelAttention(c1)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""Applies the forward pass through C1 module."""return self.spatial_attention(self.channel_attention(x))

3.YOLOv8中添加CBAM方式  

3.1 在ultralytics/nn下新建Extramodule

3.2 在Extramodule里创建CBAM

在CBAM.py文件里添加给出的CBAM代码

添加完CBAM代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在task.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在task.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {CBAM}:c2 = ch[f]args = [c2, *args]

4.新建一个yolov8CBAM.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 2 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, CBAM, []]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, CBAM, []]- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)- [-1, 1, CBAM, []]- [[15, 19, 24], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5.模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLOif __name__ == '__main__':model = YOLO(r'E:\csdn\ultralytics-main\datasets\yolov8CBAM.yaml')model.train(data=r'E:\csdn\ultralytics-main\datasets\data.yaml',cache=False,imgsz=640,epochs=100,single_cls=False,  # 是否是单类别检测batch=16,close_mosaic=10,workers=0,device='0',optimizer='SGD',amp=True,project='runs/train',name='exp',)

模型结构打印,成功运行 :

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv8有效涨点专栏

这篇关于YOLOv8改进 | 注意力篇 | YOLOv8引入CBAM注意力机制的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120288

相关文章

Maven 配置中的 <mirror>绕过 HTTP 阻断机制的方法

《Maven配置中的<mirror>绕过HTTP阻断机制的方法》:本文主要介绍Maven配置中的<mirror>绕过HTTP阻断机制的方法,本文给大家分享问题原因及解决方案,感兴趣的朋友一... 目录一、问题场景:升级 Maven 后构建失败二、解决方案:通过 <mirror> 配置覆盖默认行为1. 配置示

Redis过期删除机制与内存淘汰策略的解析指南

《Redis过期删除机制与内存淘汰策略的解析指南》在使用Redis构建缓存系统时,很多开发者只设置了EXPIRE但却忽略了背后Redis的过期删除机制与内存淘汰策略,下面小编就来和大家详细介绍一下... 目录1、简述2、Redis http://www.chinasem.cn的过期删除策略(Key Expir

Go语言中Recover机制的使用

《Go语言中Recover机制的使用》Go语言的recover机制通过defer函数捕获panic,实现异常恢复与程序稳定性,具有一定的参考价值,感兴趣的可以了解一下... 目录引言Recover 的基本概念基本代码示例简单的 Recover 示例嵌套函数中的 Recover项目场景中的应用Web 服务器中

Jvm sandbox mock机制的实践过程

《Jvmsandboxmock机制的实践过程》:本文主要介绍Jvmsandboxmock机制的实践过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、背景二、定义一个损坏的钟1、 Springboot工程中创建一个Clock类2、 添加一个Controller

Dubbo之SPI机制的实现原理和优势分析

《Dubbo之SPI机制的实现原理和优势分析》:本文主要介绍Dubbo之SPI机制的实现原理和优势,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Dubbo中SPI机制的实现原理和优势JDK 中的 SPI 机制解析Dubbo 中的 SPI 机制解析总结Dubbo中

Java 的 Condition 接口与等待通知机制详解

《Java的Condition接口与等待通知机制详解》在Java并发编程里,实现线程间的协作与同步是极为关键的任务,本文将深入探究Condition接口及其背后的等待通知机制,感兴趣的朋友一起看... 目录一、引言二、Condition 接口概述2.1 基本概念2.2 与 Object 类等待通知方法的区别

CSS引入方式和选择符的讲解和运用小结

《CSS引入方式和选择符的讲解和运用小结》CSS即层叠样式表,是一种用于描述网页文档(如HTML或XML)外观和格式的样式表语言,它主要用于将网页内容的呈现(外观)和结构(内容)分离,从而实现... 目录一、前言二、css 是什么三、CSS 引入方式1、行内样式2、内部样式表3、链入外部样式表四、CSS 选

macOS Sequoia 15.5 发布: 改进邮件和屏幕使用时间功能

《macOSSequoia15.5发布:改进邮件和屏幕使用时间功能》经过常规Beta测试后,新的macOSSequoia15.5现已公开发布,但重要的新功能将被保留到WWDC和... MACOS Sequoia 15.5 正式发布!本次更新为 Mac 用户带来了一系列功能强化、错误修复和安全性提升,进一步增

嵌入式Linux驱动中的异步通知机制详解

《嵌入式Linux驱动中的异步通知机制详解》:本文主要介绍嵌入式Linux驱动中的异步通知机制,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言一、异步通知的核心概念1. 什么是异步通知2. 异步通知的关键组件二、异步通知的实现原理三、代码示例分析1. 设备结构

JVM垃圾回收机制之GC解读

《JVM垃圾回收机制之GC解读》:本文主要介绍JVM垃圾回收机制之GC,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、死亡对象的判断算法1.1 引用计数算法1.2 可达性分析算法二、垃圾回收算法2.1 标记-清除算法2.2 复制算法2.3 标记-整理算法2.4