CVPR 2023 | 主干网络FasterNet 核心解读 代码分析

2023-11-01 04:36

本文主要是介绍CVPR 2023 | 主干网络FasterNet 核心解读 代码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文分享来自CVPR 2023的论文,提出了一种快速的主干网络,名为FasterNet

论文提出了一种新的卷积算子,partial convolution,部分卷积(PConv),通过减少冗余计算内存访问来更有效地提取空间特征。

创新在于部分卷积(PConv),它选择一部分通道的特性进行常规卷积剩余部分通道的特性保持不变,降低了计算复杂度,从而实现了快速高效的神经网络。

区别于常规卷积:PConv只对输入通道的一部分应用卷积,而保留其余部分不变。

论文地址:Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

代码地址:https://github.com/JierunChen/FasterNet/tree/master

目录

一、PConv算子设计原理

二、PConv算子的代码解析 

三、FasterNet模型原理

四、FasterNet模型测试

五、实验分析


背景:

  • MobileNet、ShuffleNet和GhostNet等利用深度卷积(DWConv)或 组卷积(GConv)来提取空间特征。
  • 然而,在减少FLOPs的过程中,算子经常会受到内存访问增加的副作用的影响
  • MicroNet进一步分解和稀疏网络,将其FLOPs推至极低水平。尽管这种方法在FLOPs方面有所改进,但其碎片计算效率很低。
  • 上述网络通常伴随着额外的数据操作,如级联、Shuffle和池化这些操作的运行时间对于小型模型来说往往很重要

一、PConv算子设计原理

 1、这种部分卷积的核心思想对输入特征图的部分通道应用卷积操作而保留其他通道不变。这种操作可以有效地减少计算冗余,提高计算效率。

对于连续或规则的内存访问,将第一个或最后一个连续的通道视为整个特征图的代表进行计算。

在不丧失一般性的情况下认为输入和输出特征图具有相同数量的通道

设计原因

通过利用特征图的冗余度可以进一步优化成本。

如下图所示,特征图在不同通道之间具有高度相似性。许多其他著作也涵盖了这种冗余,但很少有人以简单而有效的方式充分利用它。

于是出了PConv,对输入特征图的部分通道应用卷积操作而保留其他通道不变,同时减少计算冗余和内存访问。

2、为了充分有效地利用来自所有通道的信息,进一步将逐点卷积(PWConv)附加到PConv

它们在输入特征图上的有效感受野看起来像一个T形Conv,与均匀处理补丁的常规Conv相比,它更专注于中心位置。

通过实验表明:中心位置是卷积操作中最常见的突出位置,即中心位置的权重比周围的更重。这与集中于中心位置的T形计算一致。

虽然T形卷积可以直接用于高效计算,但作者表明,将T形卷积分解为PConv和PWConv更好,因为该分解利用了卷积操作间冗余并进一步节省了FLOPs。

二、PConv算子的代码解析 

PConv算子的代码:

'''
输入三个参数:dim(输入特征图的通道数),n_div(分割的组数)和forward(前向传播的方法)
输出:卷积后的特征图
'''
class Partial_conv3(nn.Module):def __init__(self, dim, n_div, forward):super().__init__()self.dim_conv3 = dim // n_div # 计算出卷积部分的通道数self.dim_untouched = dim - self.dim_conv3 # 计算出不需要卷积部分的通道数# 定义一个3*3卷积,输入通道数为self.dim_conv3,输出通道数也为self.dim_conv3,步长为1,填充为1,且不使用bias。self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)if forward == 'slicing':self.forward = self.forward_slicingelif forward == 'split_cat':self.forward = self.forward_split_catelse:raise NotImplementedError# 只适合推理def forward_slicing(self, x: Tensor) -> Tensor:# 对输入x进行深拷贝,以保持原始输入的完整性。后面的操作不会改变原始输入x。x = x.clone()   # 对输入x中前self.dim_conv3个通道应用卷积操作,并将结果保存回x中对应的位置。x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])return x# 适合训练/推理def forward_split_cat(self, x: Tensor) -> Tensor:# 使用torch.split函数将输入x沿着通道维度(即第1维,索引从0开始)分割成两个部分,# 分别为x1和x2。分割的长度为[self.dim_conv3, self.dim_untouched],# 即分割后的x1的通道数为self.dim_conv3,x2的通道数为self.dim_untouched。x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)x1 = self.partial_conv3(x1)x = torch.cat((x1, x2), 1)return x

这段代码定义了一个名为 Partial_conv3 的 PyTorch 模块,它是nn.Module的子类。这个模块主要实现了一种部分卷积(Partial Convolution); 

这种部分卷积的核心思想对输入特征图的部分通道应用卷积操作而保留其他通道不变。这种操作可以有效地减少计算冗余,提高计算效率。

方式1:slicing

 # 只适合推理def forward_slicing(self, x: Tensor) -> Tensor:# 对输入x进行深拷贝,以保持原始输入的完整性。后面的操作不会改变原始输入x。x = x.clone()   # 对输入x中前self.dim_conv3个通道应用卷积操作,并将结果保存回x中对应的位置。x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])return x

方式2:split_cat

    # 适合训练/推理def forward_split_cat(self, x: Tensor) -> Tensor:# 使用torch.split函数将输入x沿着通道维度(即第1维,索引从0开始)分割成两个部分,# 分别为x1和x2。分割的长度为[self.dim_conv3, self.dim_untouched],# 即分割后的x1的通道数为self.dim_conv3,x2的通道数为self.dim_untouched。x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)x1 = self.partial_conv3(x1)x = torch.cat((x1, x2), 1)return x

三、FasterNet模型原理

基于部分卷积算子PConv逐点卷积PWConv,作为主要的算子,进一步提出FasterNet。

这是一个新的神经网络家族,运行速度非常快,对许多视觉任务有效。模型架构如下:

它有4个层次级,每个层次级前面都有一个嵌入层(步长为4的常规4×4卷积)或一个合并层(步长为2的常规2×2卷积),用于空间下采样和通道数量扩展。每个阶段都有一堆FasterNet块。

每个FasterNet块有一个PConv层,后跟2个PWConv(或Conv 1×1)层。它们一起显示为倒置残差块,其中中间层具有扩展的通道数量,并且放置了Shorcut以重用输入特征。

最后两个阶段中的块消耗更少的内存访问,并且倾向于具有更高的FLOPS,因此,放置了更多FasterNet块,并相应地将更多计算分配给最后两个阶段。

补充一下标准化和激活层

标准化和激活层对于高性能神经网络也是不可或缺的。

然而,许多先前的工作在整个网络中过度使用这些层,这可能会限制特征多样性,从而损害性能。它还可以降低整体计算速度。

相比之下,只将它们放在每个中间PWConv之后,以保持特征多样性并实现较低的延迟。

四、FasterNet模型测试

使用默认的参数构建FasterNet

        mlp_ratio=2.0,

        embed_dim=96,

        depths=(1, 2, 8, 2),

        drop_path_rate=0.10,

看一下的模型参数 :

感觉模型也不小的。。。。。。。

测试代码分享给大家(代码存放路径:models/model_summary.py)

import torch.nn as nn
from fasternet import FasterNet
from torchsummary import summary# 默认参数
def fasternet(**kwargs):model = FasterNet(**kwargs)return model# S
def fasternet_s(**kwargs):model = FasterNet(mlp_ratio=2.0,embed_dim=128,depths=(1, 2, 13, 2),drop_path_rate=0.15,act_layer='RELU',fork_feat=True,**kwargs)return model# M
def fasternet_m(**kwargs):model = FasterNet(mlp_ratio=2.0,embed_dim=144,depths=(3, 4, 18, 3),drop_path_rate=0.2,act_layer='RELU',fork_feat=True,**kwargs)return model# L
def fasternet_l(**kwargs):model = FasterNet(mlp_ratio=2.0,embed_dim=192,depths=(3, 4, 18, 3),drop_path_rate=0.3,act_layer='RELU',fork_feat=True,**kwargs)return modelprint("fasternet:", fasternet)
model = fasternet()
summary(model, input_size=(3, 224, 224))print("fasternet_s:", fasternet_s)
model = fasternet_s()
summary(model, input_size=(3, 224, 224))print("fasternet_m:", fasternet_m)
model = fasternet_m()
summary(model, input_size=(3, 224, 224))print("fasternet_l:", fasternet_l)
model = fasternet_l()
summary(model, input_size=(3, 224, 224))

github有各个版本的预训练模型,大家可以测试一下。

nameresolutionacc#paramsFLOPsmodel
FasterNet-T0224x22471.93.9M0.34Gmodel
FasterNet-T1224x22476.27.6M0.85Gmodel
FasterNet-T2224x22478.915.0M1.90Gmodel
FasterNet-S224x22481.331.1M4.55Gmodel
FasterNet-M224x22483.053.5M8.72Gmodel
FasterNet-L224x22483.593.4M15.49Gmodel

官方给的数据:

五、实验分析

FasterNet在不同设备(CPU、GPU、ARM),精度-吞吐量和精度-延迟权衡方面具有最高的效率。

图像分类中,比较ImageNet-1k基准。具有类似TOP-1精度的模型被组合在一起。除MobileViT和EdgeNeXt的分辨率为256×256外,所有型号的分辨率均为224×224。OOM是内存不足的缩写。

关于COCO目标检测实例分割基准的结果,Flop是根据图像大小(1280,800)计算的。

分享完成~

这篇关于CVPR 2023 | 主干网络FasterNet 核心解读 代码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/320530

相关文章

SpringBoot中四种AOP实战应用场景及代码实现

《SpringBoot中四种AOP实战应用场景及代码实现》面向切面编程(AOP)是Spring框架的核心功能之一,它通过预编译和运行期动态代理实现程序功能的统一维护,在SpringBoot应用中,AO... 目录引言场景一:日志记录与性能监控业务需求实现方案使用示例扩展:MDC实现请求跟踪场景二:权限控制与

慢sql提前分析预警和动态sql替换-Mybatis-SQL

《慢sql提前分析预警和动态sql替换-Mybatis-SQL》为防止慢SQL问题而开发的MyBatis组件,该组件能够在开发、测试阶段自动分析SQL语句,并在出现慢SQL问题时通过Ducc配置实现动... 目录背景解决思路开源方案调研设计方案详细设计使用方法1、引入依赖jar包2、配置组件XML3、核心配

Java NoClassDefFoundError运行时错误分析解决

《JavaNoClassDefFoundError运行时错误分析解决》在Java开发中,NoClassDefFoundError是一种常见的运行时错误,它通常表明Java虚拟机在尝试加载一个类时未能... 目录前言一、问题分析二、报错原因三、解决思路检查类路径配置检查依赖库检查类文件调试类加载器问题四、常见

Python中的Walrus运算符分析示例详解

《Python中的Walrus运算符分析示例详解》Python中的Walrus运算符(:=)是Python3.8引入的一个新特性,允许在表达式中同时赋值和返回值,它的核心作用是减少重复计算,提升代码简... 目录1. 在循环中避免重复计算2. 在条件判断中同时赋值变量3. 在列表推导式或字典推导式中简化逻辑

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringBoot中配置文件的加载顺序解读

《SpringBoot中配置文件的加载顺序解读》:本文主要介绍SpringBoot中配置文件的加载顺序,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录SpringBoot配置文件的加载顺序1、命令⾏参数2、Java系统属性3、操作系统环境变量5、项目【外部】的ap

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

Mysql用户授权(GRANT)语法及示例解读

《Mysql用户授权(GRANT)语法及示例解读》:本文主要介绍Mysql用户授权(GRANT)语法及示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录mysql用户授权(GRANT)语法授予用户权限语法GRANT语句中的<权限类型>的使用WITH GRANT