CBAM注意力机制详解(附pytorch复现)

2024-03-01 07:12

本文主要是介绍CBAM注意力机制详解(附pytorch复现),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

论文原址:1807.06521.pdf (arxiv.org)

CBAM(Convolutional Block Attention Module)是一种卷积神经网络模块,旨在通过引入注意力机制来提升网络的表示能力。CBAM包含两个顺序子模块:通道注意力模块和空间注意力模块。

通过在深度网络的每个卷积块中自适应地优化中间特征图,CBAM通过强调通道和空间维度上的有意义特征,实现了对关键信息的关注和不必要信息的抑制。研究表明,CBAM在ImageNet-1K数据集上能够显著提高各种基线网络的准确性,通过grad-CAM可视化验证,CBAM增强的网络能够更准确地关注目标对象。在MS COCO和VOC 2007数据集上的目标检测任务中,CBAM也展现出显著的性能改进,而由于CBAM精心设计为轻量级模块,其在大多数情况下几乎没有参数和计算开销。CBAM注意力模块可广泛应用于提升卷积神经网络的表示能力。

Channel attention module(CAM)

通过平均池化和最大池化操作,整合输入特征图的空间信息,生成两个不同的空间上下文描述符,得到两个 1×1×C 的特征图,分别表示为 F_c_avg 和 F_c_max。将 F_c_avg 和 F_c_max 分别送入一个共享的多层感知机(MLP),该 MLP 具有一个隐藏层,其中第一层神经元个数为 C/r(r 为减少率),激活函数为 ReLU,第二层神经元个数为 C。这两层神经网络是共享的,即它们的权重相同。将两个 MLP 的输出特征进行逐元素相加,并通过 sigmoid 激活函数,生成通道注意力图 Mc。

这是对池化操作的使用进行实验比较的结果。研究者发现,采用平均池化和最大池化并行的方式能够取得更好的效果。可能是因为采用并行连接方式,相比于单一的池化,能够更有效地保留有用的信息,进而提升模型性能。

Spatial attention module(SAM)

首先,将 Channel Attention 模块输出的特征图作为 Spatial Attention 模块的输入特征图。接着,对输入特征图进行基于通道的全局最大池化和全局平均池化操作,得到两个 H×W×1 的特征图。然后,将这两个特征图在通道维度上进行拼接,经过一个 7×7 的卷积操作,将通道数降维为 1,即得到 H×W×1 的特征图。最后,经过 sigmoid 操作生成空间注意力特征,即 Ms。将该特征与输入特征图进行乘法操作,得到最终生成的特征。这一过程有助于模型关注输入特征图中的重要区域,从而增强表示能力。

CBAM的pytorch实现

"""
Original paper addresshttps: https://arxiv.org/pdf/1807.06521.pdf
Time: 2024-02-28
"""
import torch
from torch import nnclass ChannelAttention(nn.Module):def __init__(self, in_planes, reduction=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)# shared MLPself.mlp = nn.Sequential(nn.Conv2d(in_planes, in_planes // reduction, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_planes // reduction, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.mlp(self.avg_pool(x))max_out = self.mlp(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7, padding=3):super(SpatialAttention, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)class CBAM(nn.Module):def __init__(self, in_planes, reduction=16, kernel_size=7):super(CBAM, self).__init__()self.ca = ChannelAttention(in_planes, reduction)self.sa = SpatialAttention(kernel_size)def forward(self, x):out = x * self.ca(x)result = out * self.sa(out)return resultif __name__ == '__main__':block = CBAM(16)input = torch.rand(1, 16, 8, 8)output = block(input)print(output.shape)

参考文章

CBAM——即插即用的注意力模块(附代码)_cbam模块-CSDN博客

[ 注意力机制 ] 经典网络模型2——CBAM 详解与复现_cbam代码复现-CSDN博客

这篇关于CBAM注意力机制详解(附pytorch复现)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/761602

相关文章

HTML5 搜索框Search Box详解

《HTML5搜索框SearchBox详解》HTML5的搜索框是一个强大的工具,能够有效提升用户体验,通过结合自动补全功能和适当的样式,可以创建出既美观又实用的搜索界面,这篇文章给大家介绍HTML5... html5 搜索框(Search Box)详解搜索框是一个用于输入查询内容的控件,通常用于网站或应用程

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

Maven 配置中的 <mirror>绕过 HTTP 阻断机制的方法

《Maven配置中的<mirror>绕过HTTP阻断机制的方法》:本文主要介绍Maven配置中的<mirror>绕过HTTP阻断机制的方法,本文给大家分享问题原因及解决方案,感兴趣的朋友一... 目录一、问题场景:升级 Maven 后构建失败二、解决方案:通过 <mirror> 配置覆盖默认行为1. 配置示

CSS3中的字体及相关属性详解

《CSS3中的字体及相关属性详解》:本文主要介绍了CSS3中的字体及相关属性,详细内容请阅读本文,希望能对你有所帮助... 字体网页字体的三个来源:用户机器上安装的字体,放心使用。保存在第三方网站上的字体,例如Typekit和Google,可以link标签链接到你的页面上。保存在你自己Web服务器上的字

MySQL存储过程之循环遍历查询的结果集详解

《MySQL存储过程之循环遍历查询的结果集详解》:本文主要介绍MySQL存储过程之循环遍历查询的结果集,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言1. 表结构2. 存储过程3. 关于存储过程的SQL补充总结前言近来碰到这样一个问题:在生产上导入的数据发现

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul

从基础到进阶详解Pandas时间数据处理指南

《从基础到进阶详解Pandas时间数据处理指南》Pandas构建了完整的时间数据处理生态,核心由四个基础类构成,Timestamp,DatetimeIndex,Period和Timedelta,下面我... 目录1. 时间数据类型与基础操作1.1 核心时间对象体系1.2 时间数据生成技巧2. 时间索引与数据

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6