CvT(ICCV 2021)论文与代码解读

2024-06-10 22:04
文章标签 代码 解读 论文 2021 iccv cvt

本文主要是介绍CvT(ICCV 2021)论文与代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

paper:CvT: Introducing Convolutions to Vision Transformers

official implementation:https://github.com/microsoft/CvT

出发点

该论文的出发点是改进Vision Transformer (ViT) 的性能和效率。传统的ViT在处理图像分类任务时虽然表现出色,但在数据量较小的情况下,其表现不如同等规模的卷积神经网络(CNN)。研究人员认为这是因为ViT缺乏CNN固有的一些有利特性,如对局部空间信息的捕捉能力。本文提出通过在ViT结构中引入卷积操作来弥补这一不足,以获得更好的性能和鲁棒性。

创新点

本文解决了如何在保持ViT优点(如动态注意力机制、全局上下文建模和更好的泛化能力)的同时,引入卷积神经网络的优点(如局部感受野、权重共享和空间下采样)。具体来说,论文通过引入卷积的方式来增强ViT的局部信息捕捉能力和计算效率,从而在各种图像分类任务中取得更好的表现。具体如下

  1. 卷积token embedding层:在ViT的结构中引入卷积embedding层,通过卷积操作将图像转换为token,同时保留局部空间信息。这种方法使模型能够在多个阶段逐步减少令token序列长度,同时增加token特征维度,类似于CNN的设计。
  2. 卷积projection:标准Transformer模块中的线性投影替换为卷积投影。通过深度可分离卷积操作,进一步捕捉局部空间上下文,并减少注意力机制中的语义模糊性。此外,卷积投影的步幅可用于对key和value矩阵进行下采样,从而显著提高计算效率。
  3. 无需位置编码:实验表明,CvT模型可以在不使用位置编码的情况下取得良好的性能,这简化了模型设计,尤其适用于处理高分辨率图像任务。

方法介绍

CvT的整体pipeline如图2所示。作者将两种基于卷积的operation引入Vision Transformer中,即Convolutional Token Embedding和Convolutional Projection。如图2(a)所示,借鉴了CNN采用了一个多个stage的层级设计,本文一共包含三个stage。每个stage包括两部分,首先输入图片(或reshape后的二维token map)经过Convolutional Token Embedding层的处理,具体是通过一个重叠的卷积实现。这使得每个stage可以逐渐减少token的数量(即特征分辨率)并增加token的宽度(即特征的维度),从而实现空间降采样并增加特征表示的丰富性。和之前的各种视觉Transformer不同,本文在这里并没有加上一个位置编码。接下来是堆叠的多个本文提出的Convolutional Transformer Block如图2(b)所示, 其中一个深度可分离卷积作为卷积投影分别作用于query、key和value。class token只在最后一个stage添加,最后通过一个MLP head得到最终的输出预测类别。 

Convolutional Token Embedding

给定一张图片或前一个stage输出并reshape成二维的token map \(x_{i-1}\in \mathbb{R}^{H_{i-1}\times W_{i-1}\times C_{i-1}}\) 作为当前stage \(i\) 的输入,我们学习一个卷积 \(f(\cdot)\) 将 \(x_{i-1}\) 映射到新的token \(f(x_{i-1})\),卷积核大小为 \(s\times s\),步长为 \(s-o\),padding为 \(p\)。新的token map \(f(x_{i-1})\in \mathbb{R}^{H_i\times W_i\times C_i}\) 的高和宽分别为

\(f(x_{i-1})\) 然后展平成 \(H_iW_i\times C_i\) 的shape并经过layer normalization处理,然后作为输入到stage \(i\) 的后续transformer block中。

Convolution Token Embedding层使得我们可以通过调整卷积的参数来调整每个stage的token特征维度和数量。通过这种方式,每个stage我们逐渐减少token序列的长度同时增加token特征的维度,使得token能够在越来越大的空间中表示越来越复杂的视觉模式,类似于CNN的特征层。

Convolutional Projection for Attention

本文提出的卷积映射层的目的是实现对局部context的额外建模,并通过对 \(K\) 和 \(V\) 矩阵降采样来提高效率。

图3(a)展示了ViT中使用的position-wise线性投影,图3(b)展示了本文提出的 \(s\times s\) 卷积投影。如图3(b)所示,tokens首先reshape成一个2D token map,然后通过一个深度可分离卷积实现卷积投影。最后再将projected tokens展平成1D作为后续的输入,如下

其中 \(x_i^{q/k/v}\) 是 \(i\) 层 \(Q/K/V\) 矩阵的token输入,\(conv2d\) 是一个深度可分离卷积具体实现为:\(Depthwise\ Con2d\rightarrow BatchNorm2d\rightarrow Pointwise\ Conv2d\),\(s\) 表示卷积核大小。原始的position-wise线性投影可以通过1x1卷积实现,因此这里新的卷积投影可以看作是一种推广。

实验结果

作者设计三种不同size的模型如表2所示,其中CvT-X中的X表示模型总共的transformer block的数量。CvT-224中的W表示Wide。

表3是在ImageNet数据集上和其它SOTA模型的对比。

代码解析

这里的代码是官方实现,convolutional token embedding的代码如下,在每个stage的开始都会首先经过ConvEmbed,以cvt-13为例,一共3个stage,patch_size=[7, 3, 3],patch_stride=[4, 2, 2],patch_padding=[2, 1, 1]。

class ConvEmbed(nn.Module):""" Image to Conv Embedding"""def __init__(self,patch_size=7,in_chans=3,embed_dim=64,stride=4,padding=2,norm_layer=None):super().__init__()patch_size = to_2tuple(patch_size)self.patch_size = patch_sizeself.proj = nn.Conv2d(in_chans, embed_dim,kernel_size=patch_size,stride=stride,padding=padding)self.norm = norm_layer(embed_dim) if norm_layer else Nonedef forward(self, x):x = self.proj(x)B, C, H, W = x.shapex = rearrange(x, 'b c h w -> b (h w) c')if self.norm:x = self.norm(x)x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)return x

Attention的代码如下,在forward函数中会首先调用forward_conv得到q、k、v,这里的forward_conv就是本文提出的conv projection,在函数_build_projection中method='dw_bn',因此三个投影都是通过深度可分离卷积实现的。在self.forward_conv后就是普通的计算attention的过程了。

class Attention(nn.Module):def __init__(self,dim_in,dim_out,num_heads,qkv_bias=False,attn_drop=0.,proj_drop=0.,method='dw_bn',kernel_size=3,stride_kv=1,stride_q=1,padding_kv=1,padding_q=1,with_cls_token=True,**kwargs):super().__init__()self.stride_kv = stride_kvself.stride_q = stride_qself.dim = dim_outself.num_heads = num_heads# head_dim = self.qkv_dim // num_headsself.scale = dim_out ** -0.5self.with_cls_token = with_cls_tokenself.conv_proj_q = self._build_projection(dim_in, dim_out, kernel_size, padding_q,stride_q, 'linear' if method == 'avg' else method)self.conv_proj_k = self._build_projection(dim_in, dim_out, kernel_size, padding_kv,stride_kv, method)self.conv_proj_v = self._build_projection(dim_in, dim_out, kernel_size, padding_kv,stride_kv, method)self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim_out, dim_out)self.proj_drop = nn.Dropout(proj_drop)def _build_projection(self,dim_in,dim_out,kernel_size,padding,stride,method):if method == 'dw_bn':proj = nn.Sequential(OrderedDict([('conv', nn.Conv2d(dim_in,dim_in,kernel_size=kernel_size,padding=padding,stride=stride,bias=False,groups=dim_in)),('bn', nn.BatchNorm2d(dim_in)),('rearrage', Rearrange('b c h w -> b (h w) c')),]))elif method == 'avg':proj = nn.Sequential(OrderedDict([('avg', nn.AvgPool2d(kernel_size=kernel_size,padding=padding,stride=stride,ceil_mode=True)),('rearrage', Rearrange('b c h w -> b (h w) c')),]))elif method == 'linear':proj = Noneelse:raise ValueError('Unknown method ({})'.format(method))return projdef forward_conv(self, x, h, w):if self.with_cls_token:cls_token, x = torch.split(x, [1, h*w], 1)x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)if self.conv_proj_q is not None:q = self.conv_proj_q(x)else:q = rearrange(x, 'b c h w -> b (h w) c')if self.conv_proj_k is not None:k = self.conv_proj_k(x)else:k = rearrange(x, 'b c h w -> b (h w) c')if self.conv_proj_v is not None:v = self.conv_proj_v(x)else:v = rearrange(x, 'b c h w -> b (h w) c')if self.with_cls_token:q = torch.cat((cls_token, q), dim=1)k = torch.cat((cls_token, k), dim=1)v = torch.cat((cls_token, v), dim=1)return q, k, vdef forward(self, x, h, w):if (self.conv_proj_q is not Noneor self.conv_proj_k is not Noneor self.conv_proj_v is not None):q, k, v = self.forward_conv(x, h, w)q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scaleattn = F.softmax(attn_score, dim=-1)attn = self.attn_drop(attn)x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])x = rearrange(x, 'b h t d -> b t (h d)')x = self.proj(x)x = self.proj_drop(x)return x

这篇关于CvT(ICCV 2021)论文与代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1049388

相关文章

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工

Java 线程池+分布式实现代码

《Java线程池+分布式实现代码》在Java开发中,池通过预先创建并管理一定数量的资源,避免频繁创建和销毁资源带来的性能开销,从而提高系统效率,:本文主要介绍Java线程池+分布式实现代码,需要... 目录1. 线程池1.1 自定义线程池实现1.1.1 线程池核心1.1.2 代码示例1.2 总结流程2. J

MySQL之搜索引擎使用解读

《MySQL之搜索引擎使用解读》MySQL存储引擎是数据存储和管理的核心组件,不同引擎(如InnoDB、MyISAM)采用不同机制,InnoDB支持事务与行锁,适合高并发场景;MyISAM不支持事务,... 目录mysql的存储引擎是什么MySQL存储引擎的功能MySQL的存储引擎的分类查看存储引擎1.命令

Spring的基础事务注解@Transactional作用解读

《Spring的基础事务注解@Transactional作用解读》文章介绍了Spring框架中的事务管理,核心注解@Transactional用于声明事务,支持传播机制、隔离级别等配置,结合@Tran... 目录一、事务管理基础1.1 Spring事务的核心注解1.2 注解属性详解1.3 实现原理二、事务事

JS纯前端实现浏览器语音播报、朗读功能的完整代码

《JS纯前端实现浏览器语音播报、朗读功能的完整代码》在现代互联网的发展中,语音技术正逐渐成为改变用户体验的重要一环,下面:本文主要介绍JS纯前端实现浏览器语音播报、朗读功能的相关资料,文中通过代码... 目录一、朗读单条文本:① 语音自选参数,按钮控制语音:② 效果图:二、朗读多条文本:① 语音有默认值:②

Vue实现路由守卫的示例代码

《Vue实现路由守卫的示例代码》Vue路由守卫是控制页面导航的钩子函数,主要用于鉴权、数据预加载等场景,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、概念二、类型三、实战一、概念路由守卫(Navigation Guards)本质上就是 在路

uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)

《uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)》在uni-app开发中,文件上传和图片处理是很常见的需求,但也经常会遇到各种问题,下面:本文主要介绍uni-app小程序项目中实... 目录方式一:使用<canvas>实现图片压缩(推荐,兼容性好)示例代码(小程序平台):方式二:使用uni

JAVA实现Token自动续期机制的示例代码

《JAVA实现Token自动续期机制的示例代码》本文主要介绍了JAVA实现Token自动续期机制的示例代码,通过动态调整会话生命周期平衡安全性与用户体验,解决固定有效期Token带来的风险与不便,感兴... 目录1. 固定有效期Token的内在局限性2. 自动续期机制:兼顾安全与体验的解决方案3. 总结PS

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型