attention机制SENET、CBAM模块原理总结

2024-01-09 10:32

本文主要是介绍attention机制SENET、CBAM模块原理总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考博客:https://blog.csdn.net/weixin_33602281/article/details/85223216

def cbam_module(inputs,reduction_ratio=0.5,name=""):with tf.variable_scope("cbam_"+name, reuse=tf.AUTO_REUSE):#假如输入是[batsize,h,w,channel],#channel attension 因为要得到batsize * 1 * 1 * channel,它的全连接层第一层#隐藏层单元个数是channel / r, 第二层是channel,所以这里把channel赋值给hidden_numbatch_size,hidden_num=inputs.get_shape().as_list()[0],inputs.get_shape().as_list()[3]#通道attension#全局最大池化,窗口大小为h * w,所以对于这个数据[batsize,h,w,channel],他其实是求每个h * w面积的最大值#这里实现是先对h这个维度求最大值,然后对w这个维度求最大值,平均池化也一样maxpool_channel=tf.reduce_max(tf.reduce_max(inputs,axis=1,keepdims=True),axis=2,keepdims=True)avgpool_channel=tf.reduce_mean(tf.reduce_mean(inputs,axis=1,keepdims=True),axis=2,keepdims=True)#上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层#这个拉平,它会保留batsize,所以结果是[batsize,channel]maxpool_channel = tf.layers.Flatten()(maxpool_channel)avgpool_channel = tf.layers.Flatten()(avgpool_channel)#将上面拉平后结果输入到全连接层,第一个全连接层hiddensize = channel/r = channel * reduction_ratio,#第二哥全连接层hiddensize = channelmlp_1_max=tf.layers.dense(inputs=maxpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=None,activation=tf.nn.relu)mlp_2_max=tf.layers.dense(inputs=mlp_1_max,units=hidden_num,name="mlp_2",reuse=None)#全连接层输出结果为[batsize,channel],这里又降它转回到原来维度batsize * 1 * 1 * channel,mlp_2_max=tf.reshape(mlp_2_max,[batch_size,1,1,hidden_num])mlp_1_avg=tf.layers.dense(inputs=avgpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=True,activation=tf.nn.relu)mlp_2_avg=tf.layers.dense(inputs=mlp_1_avg,units=hidden_num,name="mlp_2",reuse=True)mlp_2_avg=tf.reshape(mlp_2_avg,[batch_size,1,1,hidden_num])#将平均和最大池化的结果维度都是[batch_size,1,1,channel]相加,然后进行sigmod,维度不变channel_attention=tf.nn.sigmoid(mlp_2_max+mlp_2_avg)#和最开始的inputs相乘,相当于[batch_size,1,1,channel] * [batch_size,h,w,channel]#只有维度一样才能相乘,这里相乘相当于给每个通道作用了不同的权重channel_refined_feature=inputs*channel_attention#空间attension#上面得到的结果维度依然是[batch_size,h,w,channel],#下面要进行全局通道池化,其实就是一条通道里面那个通道的值最大,其实就是对channel这个维度求最大值#每个通道池化相当于将通道压缩到了1维,有两个池化,结果为两个[batch_size,h,w,1]feature mapmaxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)#将两个[batch_size,h,w,1]的feature map进行通道合并得到[batch_size,h,w,2]的feature mapmax_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)#然后对上面的feature map用1个7*7的卷积核进行卷积得到[batch_size,h,w,1]的feature map,因为是用一个卷积核卷的#所以将2个输入通道压缩到了1个输出通道conv_layer=tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", activation=None)#然后再对上面得到的[batch_size,h,w,1]feature map进行sigmod,这里为什么要用一个卷积核压缩到1个通道,相当于只得到了一个面积的值#然后进行sigmod,因为我们要求的就是feature map面积上不同位置像素的中重要性,所以它压缩到了一个通道,然后求sigmodspatial_attention=tf.nn.sigmoid(conv_layer)#上面得到了空间attension feature map [batch_size,h,w,1],然后再用这个和经过空间attension作用的结果相乘得到最终的结果#这个结果就是经过通道和空间attension共同作用的结果refined_feature=channel_refined_feature*spatial_attentionreturn refined_feature

这篇关于attention机制SENET、CBAM模块原理总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/586816

相关文章

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Redis中Set结构使用过程与原理说明

《Redis中Set结构使用过程与原理说明》本文解析了RedisSet数据结构,涵盖其基本操作(如添加、查找)、集合运算(交并差)、底层实现(intset与hashtable自动切换机制)、典型应用场... 目录开篇:从购物车到Redis Set一、Redis Set的基本操作1.1 编程常用命令1.2 集

Redis中的有序集合zset从使用到原理分析

《Redis中的有序集合zset从使用到原理分析》Redis有序集合(zset)是字符串与分值的有序映射,通过跳跃表和哈希表结合实现高效有序性管理,适用于排行榜、延迟队列等场景,其时间复杂度低,内存占... 目录开篇:排行榜背后的秘密一、zset的基本使用1.1 常用命令1.2 Java客户端示例二、zse

Redis中的AOF原理及分析

《Redis中的AOF原理及分析》Redis的AOF通过记录所有写操作命令实现持久化,支持always/everysec/no三种同步策略,重写机制优化文件体积,与RDB结合可平衡数据安全与恢复效率... 目录开篇:从日记本到AOF一、AOF的基本执行流程1. 命令执行与记录2. AOF重写机制二、AOF的

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

java程序远程debug原理与配置全过程

《java程序远程debug原理与配置全过程》文章介绍了Java远程调试的JPDA体系,包含JVMTI监控JVM、JDWP传输调试命令、JDI提供调试接口,通过-Xdebug、-Xrunjdwp参数配... 目录背景组成模块间联系IBM对三个模块的详细介绍编程使用总结背景日常工作中,每个程序员都会遇到bu

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

Python sys模块的使用及说明

《Pythonsys模块的使用及说明》Pythonsys模块是核心工具,用于解释器交互与运行时控制,涵盖命令行参数处理、路径修改、强制退出、I/O重定向、系统信息获取等功能,适用于脚本开发与调试,需... 目录python sys 模块详解常用功能与代码示例获取命令行参数修改模块搜索路径强制退出程序标准输入