每日Attention学习15——Cross-Model Grafting Module

2024-08-26 02:20

本文主要是介绍每日Attention学习15——Cross-Model Grafting Module,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

模块出处

[CVPR 22] [link] [code] Pyramid Grafting Network for One-Stage High Resolution Saliency Detection


模块名称

Cross-Model Grafting Module (CMGM)


模块作用

Transformer与CNN之间的特征融合


模块结构

在这里插入图片描述


模块思想

Transformer在全局特征上更优,CNN在局部特征上更优,对这两者进行进行融合的最简单做法是直接相加或相乘。但是,相加或相乘本质上属于"局部"操作,如果某片区域两个特征的不确定性都较高,则会带来许多噪声。为此,本文提出了CMGM模块,通过交叉注意力的形式引入更为广泛的信息来增强融合效果。


模块代码
import torch.nn.functional as F
import torch.nn as nn
import torchclass CMGM(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.k = nn.Linear(dim, dim , bias=qkv_bias)self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.proj = nn.Linear(dim, dim)self.act = nn.ReLU(inplace=True)self.conv = nn.Conv2d(8,8,kernel_size=3, stride=1, padding=1)self.lnx = nn.LayerNorm(64)self.lny = nn.LayerNorm(64)self.bn = nn.BatchNorm2d(8)self.conv2 = nn.Sequential(nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True))def forward(self, x, y):batch_size = x.shape[0]chanel     = x.shape[1]sc = xx = x.view(batch_size, chanel, -1).permute(0, 2, 1)sc1 = xx = self.lnx(x)y = y.view(batch_size, chanel, -1).permute(0, 2, 1)y = self.lny(y)B, N, C = x.shapey_k = self.k(y).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)x_qv= self.qv(x).reshape(B,N,2,self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)x_q, x_v = x_qv[0], x_qv[1] y_k = y_k[0]attn = (x_q @ y_k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ x_v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = (x+sc1)x = x.permute(0,2,1)x = x.view(batch_size,chanel,*sc.size()[2:])x = self.conv2(x)+xreturn x, self.act(self.bn(self.conv(attn+attn.transpose(-1,-2))))if __name__ == '__main__':x = torch.randn([1, 64, 11, 11])y = torch.randn([1, 64, 11, 11])cmgm = CMGM(dim=64)out1, out2 = cmgm(x, y)print(out1.shape)  # out feature 1, 64, 11, 11print(out2.shape)  # cross attention matrix 1, 8, 121, 121

这篇关于每日Attention学习15——Cross-Model Grafting Module的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1107275

相关文章

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

如何解决idea的Module:‘:app‘platform‘android-32‘not found.问题

《如何解决idea的Module:‘:app‘platform‘android-32‘notfound.问题》:本文主要介绍如何解决idea的Module:‘:app‘platform‘andr... 目录idea的Module:‘:app‘pwww.chinasem.cnlatform‘android-32

Pydantic中model_validator的实现

《Pydantic中model_validator的实现》本文主要介绍了Pydantic中model_validator的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录引言基础知识创建 Pydantic 模型使用 model_validator 装饰器高级用法mo

GORM中Model和Table的区别及使用

《GORM中Model和Table的区别及使用》Model和Table是两种与数据库表交互的核心方法,但它们的用途和行为存在著差异,本文主要介绍了GORM中Model和Table的区别及使用,具有一... 目录1. Model 的作用与特点1.1 核心用途1.2 行为特点1.3 示例China编程代码2. Tab

Python中ModuleNotFoundError: No module named ‘timm’的错误解决

《Python中ModuleNotFoundError:Nomodulenamed‘timm’的错误解决》本文主要介绍了Python中ModuleNotFoundError:Nomodulen... 目录一、引言二、错误原因分析三、解决办法1.安装timm模块2. 检查python环境3. 解决安装路径问题

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.