attention机制SENET、CBAM模块原理总结

2024-01-09 10:32

本文主要是介绍attention机制SENET、CBAM模块原理总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考博客:https://blog.csdn.net/weixin_33602281/article/details/85223216

def cbam_module(inputs,reduction_ratio=0.5,name=""):with tf.variable_scope("cbam_"+name, reuse=tf.AUTO_REUSE):#假如输入是[batsize,h,w,channel],#channel attension 因为要得到batsize * 1 * 1 * channel,它的全连接层第一层#隐藏层单元个数是channel / r, 第二层是channel,所以这里把channel赋值给hidden_numbatch_size,hidden_num=inputs.get_shape().as_list()[0],inputs.get_shape().as_list()[3]#通道attension#全局最大池化,窗口大小为h * w,所以对于这个数据[batsize,h,w,channel],他其实是求每个h * w面积的最大值#这里实现是先对h这个维度求最大值,然后对w这个维度求最大值,平均池化也一样maxpool_channel=tf.reduce_max(tf.reduce_max(inputs,axis=1,keepdims=True),axis=2,keepdims=True)avgpool_channel=tf.reduce_mean(tf.reduce_mean(inputs,axis=1,keepdims=True),axis=2,keepdims=True)#上面全局池化结果为batsize * 1 * 1 * channel,它这个拉平输入到全连接层#这个拉平,它会保留batsize,所以结果是[batsize,channel]maxpool_channel = tf.layers.Flatten()(maxpool_channel)avgpool_channel = tf.layers.Flatten()(avgpool_channel)#将上面拉平后结果输入到全连接层,第一个全连接层hiddensize = channel/r = channel * reduction_ratio,#第二哥全连接层hiddensize = channelmlp_1_max=tf.layers.dense(inputs=maxpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=None,activation=tf.nn.relu)mlp_2_max=tf.layers.dense(inputs=mlp_1_max,units=hidden_num,name="mlp_2",reuse=None)#全连接层输出结果为[batsize,channel],这里又降它转回到原来维度batsize * 1 * 1 * channel,mlp_2_max=tf.reshape(mlp_2_max,[batch_size,1,1,hidden_num])mlp_1_avg=tf.layers.dense(inputs=avgpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=True,activation=tf.nn.relu)mlp_2_avg=tf.layers.dense(inputs=mlp_1_avg,units=hidden_num,name="mlp_2",reuse=True)mlp_2_avg=tf.reshape(mlp_2_avg,[batch_size,1,1,hidden_num])#将平均和最大池化的结果维度都是[batch_size,1,1,channel]相加,然后进行sigmod,维度不变channel_attention=tf.nn.sigmoid(mlp_2_max+mlp_2_avg)#和最开始的inputs相乘,相当于[batch_size,1,1,channel] * [batch_size,h,w,channel]#只有维度一样才能相乘,这里相乘相当于给每个通道作用了不同的权重channel_refined_feature=inputs*channel_attention#空间attension#上面得到的结果维度依然是[batch_size,h,w,channel],#下面要进行全局通道池化,其实就是一条通道里面那个通道的值最大,其实就是对channel这个维度求最大值#每个通道池化相当于将通道压缩到了1维,有两个池化,结果为两个[batch_size,h,w,1]feature mapmaxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)#将两个[batch_size,h,w,1]的feature map进行通道合并得到[batch_size,h,w,2]的feature mapmax_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)#然后对上面的feature map用1个7*7的卷积核进行卷积得到[batch_size,h,w,1]的feature map,因为是用一个卷积核卷的#所以将2个输入通道压缩到了1个输出通道conv_layer=tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", activation=None)#然后再对上面得到的[batch_size,h,w,1]feature map进行sigmod,这里为什么要用一个卷积核压缩到1个通道,相当于只得到了一个面积的值#然后进行sigmod,因为我们要求的就是feature map面积上不同位置像素的中重要性,所以它压缩到了一个通道,然后求sigmodspatial_attention=tf.nn.sigmoid(conv_layer)#上面得到了空间attension feature map [batch_size,h,w,1],然后再用这个和经过空间attension作用的结果相乘得到最终的结果#这个结果就是经过通道和空间attension共同作用的结果refined_feature=channel_refined_feature*spatial_attentionreturn refined_feature

这篇关于attention机制SENET、CBAM模块原理总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/586816

相关文章

Spring Boot Interceptor的原理、配置、顺序控制及与Filter的关键区别对比分析

《SpringBootInterceptor的原理、配置、顺序控制及与Filter的关键区别对比分析》本文主要介绍了SpringBoot中的拦截器(Interceptor)及其与过滤器(Filt... 目录前言一、核心功能二、拦截器的实现2.1 定义自定义拦截器2.2 注册拦截器三、多拦截器的执行顺序四、过

JAVA线程的周期及调度机制详解

《JAVA线程的周期及调度机制详解》Java线程的生命周期包括NEW、RUNNABLE、BLOCKED、WAITING、TIMED_WAITING和TERMINATED,线程调度依赖操作系统,采用抢占... 目录Java线程的生命周期线程状态转换示例代码JAVA线程调度机制优先级设置示例注意事项JAVA线程

C# List.Sort四种重载总结

《C#List.Sort四种重载总结》本文详细分析了C#中List.Sort()方法的四种重载形式及其实现原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友... 目录1. Sort方法的四种重载2. 具体使用- List.Sort();- IComparable

SpringBoot项目整合Netty启动失败的常见错误总结

《SpringBoot项目整合Netty启动失败的常见错误总结》本文总结了SpringBoot集成Netty时常见的8类问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、端口冲突问题1. Tomcat与Netty端口冲突二、主线程被阻塞问题1. Netty启动阻

Java中自旋锁与CAS机制的深层关系与区别

《Java中自旋锁与CAS机制的深层关系与区别》CAS算法即比较并替换,是一种实现并发编程时常用到的算法,Java并发包中的很多类都使用了CAS算法,:本文主要介绍Java中自旋锁与CAS机制深层... 目录1. 引言2. 比较并交换 (Compare-and-Swap, CAS) 核心原理2.1 CAS

Java 队列Queue从原理到实战指南

《Java队列Queue从原理到实战指南》本文介绍了Java中队列(Queue)的底层实现、常见方法及其区别,通过LinkedList和ArrayDeque的实现,以及循环队列的概念,展示了如何高效... 目录一、队列的认识队列的底层与集合框架常见的队列方法插入元素方法对比(add和offer)移除元素方法

SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)

《SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)》本文总结了SpringBoot项目整合Kafka启动失败的常见错误,包括Kafka服务器连接问题、序列化配置错误、依赖配置问题、... 目录一、Kafka服务器连接问题1. Kafka服务器无法连接2. 开发环境与生产环境网络不通二、序

SQL 注入攻击(SQL Injection)原理、利用方式与防御策略深度解析

《SQL注入攻击(SQLInjection)原理、利用方式与防御策略深度解析》本文将从SQL注入的基本原理、攻击方式、常见利用手法,到企业级防御方案进行全面讲解,以帮助开发者和安全人员更系统地理解... 目录一、前言二、SQL 注入攻击的基本概念三、SQL 注入常见类型分析1. 基于错误回显的注入(Erro

Spring Boot 集成 mybatis核心机制

《SpringBoot集成mybatis核心机制》这篇文章给大家介绍SpringBoot集成mybatis核心机制,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值... 目录Spring Boot浅析1.依赖管理(Starter POMs)2.自动配置(AutoConfigu

Spring IOC核心原理详解与运用实战教程

《SpringIOC核心原理详解与运用实战教程》本文详细解析了SpringIOC容器的核心原理,包括BeanFactory体系、依赖注入机制、循环依赖解决和三级缓存机制,同时,介绍了SpringBo... 目录1. Spring IOC核心原理深度解析1.1 BeanFactory体系与内部结构1.1.1