【深度学习:大模型微调】如何微调SAM

2024-02-20 09:04
文章标签 学习 深度 模型 微调 sam

本文主要是介绍【深度学习:大模型微调】如何微调SAM,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

【深度学习:大模型微调】如何微调SAM

    • 什么是 Segment Anything 模型 (SAM)?
    • 什么是模型微调?
    • 为什么要微调模型?
    • 如何微调分段任何模型[使用代码]
      • 背景与架构
      • 创建自定义数据集
      • 输入数据预处理
      • 训练设置
      • 训练循环
      • 保存检查点并从中启动模型
    • 针对下游应用程序的微调
    • 结论

随着 Meta 上周发布 Segment Anything Model (SAM),计算机视觉正在经历 ChatGPT 时刻。SAM 训练了超过 110 亿个分割掩码,是预测性 AI 用例的基础模型,而不是生成式 AI。虽然它在分割各种图像模态和问题空间的能力方面表现出了令人难以置信的灵活性,但它在没有“微调”功能的情况下发布。

本教程将概述使用掩码解码器微调 SAM 的一些关键步骤,特别是描述 SAM 中的哪些函数用于预处理/后处理数据,以便其处于良好的状态以进行微调。

什么是 Segment Anything 模型 (SAM)?

Segment Anything 模型 (SAM) 是由 Meta AI 开发的细分模型。它被认为是计算机视觉的第一个基础模型。SAM在包含数百万张图像和数十亿个掩码的庞大数据语料库上进行了训练,使其非常强大。顾名思义,SAM 能够为各种图像生成准确的分割掩码。SAM 的设计允许它考虑人类提示,使其在 Human In The Loop 注释中特别强大。这些提示可以是多模式的:它们可以是要分割的区域上的点、要分割的对象周围的边界框,也可以是有关应分割的内容的文本提示。

该模型分为 3 个组件:图像编码器、提示编码器和掩码解码器。

在这里插入图片描述
图像编码器为被分割的图像生成嵌入,而提示编码器为提示生成嵌入。图像编码器是模型中一个特别大的组件。这与轻量级掩码解码器形成鲜明对比,后者根据嵌入预测分割掩码。Meta AI 已将 Segment Anything 10 亿掩码 (SA-1B) 数据集上训练的模型的权重和偏差作为模型检查点。

什么是模型微调?

公开可用的先进模型具有自定义架构,通常提供预训练的模型权重。如果这些架构没有权重,那么用户需要从头开始训练模型,他们需要使用海量数据集来获得最先进的性能。

模型微调是采用预训练模型(架构+权重)并向其显示特定用例数据的过程。这通常是模型以前从未见过的数据,或者在其原始训练数据集中代表性不足的数据。

微调模型和从头开始之间的区别在于权重和偏差的起始值。如果我们从头开始训练,这些将根据某种策略随机初始化。在这样的起始配置中,模型对手头的任务“一无所知”并且表现不佳。通过使用预先存在的权重和偏差作为起点,我们可以“微调”权重和偏差,以便我们的模型在自定义数据集上更好地工作。例如,学习到识别猫的信息(边缘检测、计算爪子)对于识别狗很有用。

为什么要微调模型?

微调模型的目的是为了在预训练模型从未见过的数据上获得更高的性能。例如,在从手机摄像头收集的广泛数据集上训练的图像分割模型将主要从水平角度看到图像。

如果我们尝试将此模型用于从垂直角度拍摄的卫星图像,它的性能可能不会那么好。如果我们尝试分割屋顶,该模型可能不会产生最佳结果。预训练很有用,因为模型通常已经学会了如何分割对象,因此我们希望利用这个起点来构建一个可以准确分割屋顶的模型。此外,我们的自定义数据集可能不会有数百万个示例,因此我们希望进行微调,而不是从头开始训练模型。

微调是可取的,这样我们就可以在特定用例上获得更好的性能,而不必承担从头开始训练模型的计算成本。

如何微调分段任何模型[使用代码]

背景与架构

我们在简介部分概述了 SAM 架构。图像编码器具有复杂的架构和许多参数。为了微调模型,我们有必要关注掩模解码器,因为它是轻量级的,因此微调起来更容易、更快、内存效率更高。

为了微调 SAM,我们需要提取其架构的底层部分(图像和提示编码器、掩码解码器)。我们不能使用 SamPredictor.predict(链接)有两个原因:

  • 我们只想微调掩码解码器
  • 该函数调用 SamPredictor.predict_torch,它具有 @torch.no_grad() 装饰器(链接),这会阻止我们计算梯度

因此,我们需要检查 SamPredictor.predict 函数并调用适当的函数,并在我们想要微调的部分(掩码解码器)上启用梯度计算。这样做也是了解 SAM 工作原理的好方法。

创建自定义数据集

我们需要三件事来微调我们的模型:

  • 用于绘制分割的图像
  • 分割地面真值掩模
  • 提示输入模型

我们选择印章验证数据集(链接),因为它包含 SAM 在训练中可能未见过的数据(即文档上的印章)。我们可以通过使用预先训练的权重进行推理来验证它在此数据集上的表现是否良好,但并不完美。地面真实掩模也非常精确,这将使我们能够计算准确的损失。最后,该数据集包含分割掩码周围的边界框,我们可以将其用作 SAM 的提示。示例图像如下所示。这些边界框与人工注释者在生成分段时所经历的工作流程非常吻合。

在这里插入图片描述

输入数据预处理

我们需要预处理从 numpy 数组到 pytorch 张量的扫描。为此,我们可以跟踪 SamPredictor.set_image(链接)和 SamPredictor.set_torch_image(链接)内部发生的情况(对图像进行预处理)。首先,我们可以使用 utils.transform.ResizeLongestSide 来调整图像大小,因为这是预测器内部使用的转换器(链接)。然后我们可以将图像转换为pytorch张量并使用SAM预处理方法(链接)来完成预处理。

训练设置

我们下载 vit_b 模型的模型检查点,并将其加载进去:

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

我们可以使用默认值设置 Adam 优化器,并指定要调整的参数是掩码解码器的参数:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters()) 

同时,我们可以设置我们的损失函数,例如均方误差

loss_fn = torch.nn.MSELoss()

训练循环

在主训练循环中,我们将迭代数据项,生成掩模,并将它们与地面实况掩模进行比较,以便我们可以根据损失函数优化模型参数。

在此示例中,我们使用 GPU 进行训练,因为它比使用 CPU 快得多。在适当的张量上使用 .to(device) 非常重要,以确保我们不会在 CPU 上使用某些张量而在 GPU 上使用其他张量。

我们希望通过将编码器包装在 torch.no_grad() 上下文管理器中来嵌入图像,否则我们将遇到内存问题,并且我们不希望微调图像编码器。

with torch.no_grad():image_embedding = sam_model.image_encoder(input_image)

我们还可以在 no_grad 上下文管理器中生成提示嵌入。我们使用边界框坐标,转换为 pytorch 张量。

with torch.no_grad():sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None,boxes=box_torch,masks=None,)

最后,我们可以生成蒙版。请注意,这里我们处于单掩码生成模式(与通常输出的 3 个掩码相反)。

low_res_masks, iou_predictions = sam_model.mask_decoder(image_embeddings=image_embedding,image_pe=sam_model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=False,
)

这里的最后一步是将蒙版放大回原始图像大小,因为它们的分辨率较低。我们可以使用 Sam.postprocess_masks 来实现这一点。我们还希望从预测的掩码生成二进制掩码,以便我们可以将它们与我们的基本事实进行比较。为了不破坏反向传播,使用火炬泛函非常重要。

upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)from torch.nn.functional import threshold, normalizebinary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

最后,我们可以计算损失并运行优化步骤:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

通过在多个 epoch 和批次中重复此操作,我们可以微调 SAM 解码器。

保存检查点并从中启动模型

一旦我们完成训练并对性能提升感到满意,我们可以使用以下方法保存调整模型的状态字典:

torch.save(model.state_dict(), PATH)

然后,当我们想要对与我们用于微调模型的数据类似的数据进行推理时,我们可以加载此状态字典。

针对下游应用程序的微调

虽然 SAM 目前不提供开箱即用的微调功能,但我们正在构建与 Encord 平台集成的自定义微调器。如本文所示,我们对解码器进行微调以实现这一目标。这可以作为 Web 应用程序中开箱即用的一键式过程使用,其中超参数是自动设置的。

在这里插入图片描述
Original vanilla SAM mask:

在这里插入图片描述
由模型的微调版本生成的掩模:

在这里插入图片描述
我们可以看到这个面罩比原来的面罩更紧。这是对印章验证数据集中的一小部分图像进行微调,然后在以前未见过的示例上运行调整后的模型的结果。通过进一步的训练和更多的例子,我们可以获得更好的结果。

结论

就这样,伙计们!

您现在已经学会了如何微调分段任意模型 (SAM)。如果您希望对 SAM 进行开箱即用的微调,您可能还会有兴趣了解我们最近在 Encord 中发布了分段任意模型,使您无需编写任何代码即可微调模型。

在这里插入图片描述

这篇关于【深度学习:大模型微调】如何微调SAM的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/727691

相关文章

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实

MyBatis分页插件PageHelper深度解析与实践指南

《MyBatis分页插件PageHelper深度解析与实践指南》在数据库操作中,分页查询是最常见的需求之一,传统的分页方式通常有两种内存分页和SQL分页,MyBatis作为优秀的ORM框架,本身并未提... 目录1. 为什么需要分页插件?2. PageHelper简介3. PageHelper集成与配置3.

Maven 插件配置分层架构深度解析

《Maven插件配置分层架构深度解析》:本文主要介绍Maven插件配置分层架构深度解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Maven 插件配置分层架构深度解析引言:当构建逻辑遇上复杂配置第一章 Maven插件配置的三重境界1.1 插件配置的拓扑

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Python中__init__方法使用的深度解析

《Python中__init__方法使用的深度解析》在Python的面向对象编程(OOP)体系中,__init__方法如同建造房屋时的奠基仪式——它定义了对象诞生时的初始状态,下面我们就来深入了解下_... 目录一、__init__的基因图谱二、初始化过程的魔法时刻继承链中的初始化顺序self参数的奥秘默认