利用MMDetection进行模型微调和权重初始化

2024-05-14 14:36

本文主要是介绍利用MMDetection进行模型微调和权重初始化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 模型微调
    • 修改第一处:更少的训练回合Epoch
    • 修改第二处:更小的学习率Learning Rate
    • 修改第三处:使用预训练模型
  • 权重初始化
    • init_cfg 的使用
    • 配置初始化器

本文基于 MMDetection官方文档,对模型微调和权重初始化进行第三方讲解。

在这里插入图片描述

模型微调

在 COCO 数据集上预训练的检测器可以作为其他数据集优质的预训练模型。
微调超参数与默认的训练策略不同。它通常需要更小的学习率和更少的训练回合。根据继承文件_base_的位置找到优化相关配置和训练和测试的配置的文件位置,我选择的Faster R-CNN相关配置文件位于mmdetection/configs/common/ms_3x_coco.py,为了不修改官方已经继承的配置文件,我们可以选择新建一个文件进行,例如mmdetection/configs/common/ms_3x_coco_finetuning.py。在进行下列步骤之前,请确保数据集与配置文件相匹配,并且检测头roi_headnum_classes与数据集类别数相匹配,参考利用MMDetection在自定义数据集上进行训练。

修改第一处:更少的训练回合Epoch

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=8, val_interval=1)
# max_epochs = 12 → 8

修改第二处:更小的学习率Learning Rate

optim_wrapper = dict(type='OptimWrapper',optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
# lr = 0.002 → 0.001

修改第三处:使用预训练模型

load_from = '/home/miqi/mmdetection/checkpoints/faster_rcnn_r50_fpn_mstrain_3x_coco_20210524_110822-e10bd31c.pth'

权重初始化

在训练过程中,适当的初始化策略有利于加快训练速度或获得更⾼的性能。 MMCV 提供了一些常⽤的初始化模块的⽅法,如 nn.Conv2d。 MMdetection 中的模型初始化主要使⽤ init_cfg

例如在mmdetection/mmdet/models/necks/fpn.py

@MODELS.register_module()
class FPN(BaseModule):def __init__(self,in_channels: List[int],out_channels: int,num_outs: int,start_level: int = 0,end_level: int = -1,add_extra_convs: Union[bool, str] = False,relu_before_extra_convs: bool = False,no_norm_on_lateral: bool = False,conv_cfg: OptConfigType = None,norm_cfg: OptConfigType = None,act_cfg: OptConfigType = None,upsample_cfg: ConfigType = dict(mode='nearest'),init_cfg: MultiConfig = dict(type='Xavier', layer='Conv2d', distribution='uniform')) -> None:super().__init__(init_cfg=init_cfg)

我们可以对init_cfg部分进行修改

init_cfg 的使用

  1. layer 键初始化模型

    如果我们只定义了 layer, 它只会在 layer 键中初始化网络层。

    注意: layer 键对应的值是 Pytorch 的带有 weights 和 bias 属性的类名(因此不⽀持 MultiheadAttention 层)。

  • 定义⽤于初始化具有相同配置的模块的 layer 键。

    init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d', 'Linear'], val=1)
    # ⽤相同的配置初始化整个模块
    
  • 定义⽤于初始化具有不同配置的层的 layer 键。

    init_cfg = [dict(type='Constant', layer='Conv1d', val=1),dict(type='Constant', layer='Conv2d', val=2),dict(type='Constant', layer='Linear', val=3)]
    # nn.Conv1d 将被初始化为 dict(type='Constant', val=1)
    # nn.Conv2d 将被初始化为 dict(type='Constant', val=2)
    # nn.Linear 将被初始化为 dict(type='Constant', val=3)
    
  1. 使⽤ override 键初始化模型
  • 当使⽤属性名初始化某些特定部分时,我们可以使⽤ override 键, override 中的值将忽略 init_cfg 中的值。

    # layers:
    # self.feat = nn.Conv1d(3, 1, 3)
    # self.reg = nn.Conv2d(3, 3, 3)
    # self.cls = nn.Linear(1,2)init_cfg = dict(type='Constant',layer=['Conv1d','Conv2d'], val=1, bias=2,override=dict(type='Constant', name='reg', val=3, bias=4))
    # self.feat and self.cls 将被初始化为 dict(type='Constant', val=1, bias=2)
    # 叫 'reg' 的模块将被初始化为 dict(type='Constant', val=3, bias=4)
    
  • 如果 init_cfg 中的 layer 为 None,则只会初始化 override 中有 name 的⼦模块,⽽ override 中的 type 和其他参数可以省略。

    # layers:
    # self.feat = nn.Conv1d(3, 1, 3)
    # self.reg = nn.Conv2d(3, 3, 3)
    # self.cls = nn.Linear(1,2)init_cfg = dict(type='Constant', val=1, bias=2, 	override=dict(name='reg'))# self.feat and self.cls 将被 Pytorch 初始化
    # 叫 'reg' 的模块将被 dict(type='Constant', val=1, bias=2) 初始化
    
  • 如果我们不定义 layeroverride 键,它不会初始化任何东西。

  • 无效的使用

    # override 没有 name 键的话是无效的
    init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,override=dict(type='Constant', val=3, bias=4))# override 有 name 键和其他参数但是没有 type 键也是无效的
    init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,override=dict(name='reg', val=3, bias=4))
    
  1. 使⽤预训练模型初始化模型

    init_cfg = dict(type='Pretrained',checkpoint='torchvision://resnet50')
    

配置初始化器

我们可以通过配置 init_cfg 为模型中任意组件灵活地选择初始化方式。目前我们可以在 init_cfg 中配置以下初始化器:

InitializerRegistered nameFunction
ConstantInitConstant将 weight 和 bias 初始化为指定常量,通常用于初始化卷积
XavierInitXavier将 weight Xavier 方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
NormalInitNormal将 weight 以正态分布的方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
TruncNormalInitTruncNormal将 weight 以被截断的正态分布的方式初始化,参数 a 和 b 为正态分布的有效区域;将 bias 初始化成指定常量,通常用于初始化 Transformer
UniformInitUniform将 weight 以均匀分布的方式初始化,参数 a 和 b 为均匀分布的范围;将 bias 初始化为指定常量,通常用于初始化卷积
KaimingInitKaiming将 weight 以 Kaiming 的方式初始化,将 bias 初始化成指定常量,通常用于初始化卷积
Caffe2XavierInitCaffe2XavierCaffe2 中 Xavier 初始化方式,在 Pytorch 中对应 “fan_in”, “normal” 模式的 Kaiming 初始化,,通常用于初始化卷
PretrainedPretrainedInit加载预训练权重

本贴后续会利用Faster R-CNN对预训练权重初始化和常用初始化进行实验,详情教程请见MMEgine权重初始化。

这篇关于利用MMDetection进行模型微调和权重初始化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/989009

相关文章

C++中RAII资源获取即初始化

《C++中RAII资源获取即初始化》RAII通过构造/析构自动管理资源生命周期,确保安全释放,本文就来介绍一下C++中的RAII技术及其应用,具有一定的参考价值,感兴趣的可以了解一下... 目录一、核心原理与机制二、标准库中的RAII实现三、自定义RAII类设计原则四、常见应用场景1. 内存管理2. 文件操

Linux使用scp进行远程目录文件复制的详细步骤和示例

《Linux使用scp进行远程目录文件复制的详细步骤和示例》在Linux系统中,scp(安全复制协议)是一个使用SSH(安全外壳协议)进行文件和目录安全传输的命令,它允许在远程主机之间复制文件和目录,... 目录1. 什么是scp?2. 语法3. 示例示例 1: 复制本地目录到远程主机示例 2: 复制远程主

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

windows系统上如何进行maven安装和配置方式

《windows系统上如何进行maven安装和配置方式》:本文主要介绍windows系统上如何进行maven安装和配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录1. Maven 简介2. maven的下载与安装2.1 下载 Maven2.2 Maven安装2.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Go语言中使用JWT进行身份验证的几种方式

《Go语言中使用JWT进行身份验证的几种方式》本文主要介绍了Go语言中使用JWT进行身份验证的几种方式,包括dgrijalva/jwt-go、golang-jwt/jwt、lestrrat-go/jw... 目录简介1. github.com/dgrijalva/jwt-go安装:使用示例:解释:2. gi

SpringBoot如何对密码等敏感信息进行脱敏处理

《SpringBoot如何对密码等敏感信息进行脱敏处理》这篇文章主要为大家详细介绍了SpringBoot对密码等敏感信息进行脱敏处理的几个常用方法,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录​1. 配置文件敏感信息脱敏​​2. 日志脱敏​​3. API响应脱敏​​4. 其他注意事项​​总结

python进行while遍历的常见错误解析

《python进行while遍历的常见错误解析》在Python中选择合适的遍历方式需要综合考虑可读性、性能和具体需求,本文就来和大家讲解一下python中while遍历常见错误以及所有遍历方法的优缺点... 目录一、超出数组范围问题分析错误复现解决方法关键区别二、continue使用问题分析正确写法关键点三

Python对PDF书签进行添加,修改提取和删除操作

《Python对PDF书签进行添加,修改提取和删除操作》PDF书签是PDF文件中的导航工具,通常包含一个标题和一个跳转位置,本教程将详细介绍如何使用Python对PDF文件中的书签进行操作... 目录简介使用工具python 向 PDF 添加书签添加书签添加嵌套书签Python 修改 PDF 书签Pytho