mask2former利用不确定性采样点选择提高模型性能

2024-06-13 04:04

本文主要是介绍mask2former利用不确定性采样点选择提高模型性能,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在机器学习和深度学习的训练过程中,不确定性高的点通常代表模型在这些点上的预测不够可靠或有较高的误差。因此,关注这些不确定性高的点,通过计算这些点的损失并进行梯度更新,可以有效地提高模型的整体性能。确定性高的点预测结果已经比较准确,相应地对模型的训练贡献较小,所以可以减少对这些点的关注或完全忽略它们的损失计算。

代码复现参考仓库:https://github.com/NielsRogge/Transformers-Tutorials

在这篇博客中,我们将详细解释 mask2former 中的一段代码,该代码通过不确定性采样点来选择重要点,并探讨其在模型训练中的重要性。mask2former原文描述比较简单,如下:
在这里插入图片描述

代码源自transformers库中的modeling_mask2former.py,主要讲解如下代码:

    def sample_points_using_uncertainty(self,logits: torch.Tensor,uncertainty_function,num_points: int,oversample_ratio: int,importance_sample_ratio: float,) -> torch.Tensor:"""This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. Theuncertainty is calculated for each point using the passed `uncertainty function` that takes points logitprediction as input.Args:logits (`float`):Logit predictions for P points.uncertainty_function:A function that takes logit predictions for P points and returns their uncertainties.num_points (`int`):The number of points P to sample.oversample_ratio (`int`):Oversampling parameter.importance_sample_ratio (`float`):Ratio of points that are sampled via importance sampling.Returns:point_coordinates (`torch.Tensor`):Coordinates for P sampled points."""num_boxes = logits.shape[0]num_points_sampled = int(num_points * oversample_ratio)# Get random point coordinatespoint_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)# Get sampled prediction value for the point coordinatespoint_logits = sample_point(logits, point_coordinates, align_corners=False)# Calculate the uncertainties based on the sampled prediction values of the pointspoint_uncertainties = uncertainty_function(point_logits)#[n1+n2, 1, 37632],理解为,值越大,不确定性越高num_uncertain_points = int(importance_sample_ratio * num_points)#9408num_random_points = num_points - num_uncertain_points#3136idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]#[n1+n2, 9408]这行代码的作用是从每个 num_boxes 的不确定性值中选择 num_uncertain_points 个最大值的索引。这些索引将用于从原始的点坐标张量 point_coordinates 中选择相应的点,这些点将被认为是基于不确定性的重要性采样点。shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)#这两行代码的主要目的是确保在从 point_coordinates 中选择点时,能够正确地访问全局索引,使得每个 box 的采样点能够准确地映射到整个张量中的位置。idx += shift[:, None]point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)#[n1+n2, 9408, 2]if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)return point_coordinates

以下是 sample_points_using_uncertainty 函数的参数解释:

  • logits (torch.Tensor): P 个点的 logit 预测值。
  • uncertainty_function: 一个函数,接受 P 个点的 logit 预测值并返回它们的不确定性。
  • num_points (int): 需要采样的点的数量 P。
  • oversample_ratio (int): 过采样参数,用于增加采样点的数量,以确保能在不确定性采样中选到合适的点。
  • importance_sample_ratio (float): 使用重要性采样选出的点的比例。

函数步骤解释

  1. 计算总采样点数

    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)
    

    num_boxes 是指预测的盒子数量,num_points_sampled 是经过过采样之后的总采样点数。

  2. 生成随机点的坐标

    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    

    在 [0, 1] * [0, 1] 空间内生成随机点的坐标。

  3. 获取这些随机点的预测值

    point_logits = sample_point(logits, point_coordinates, align_corners=False)
    

    对随机点的坐标进行采样,获取它们的预测 logit 值。

  4. 计算这些点的不确定性

    point_uncertainties = uncertainty_function(point_logits)
    

    使用 uncertainty_function 计算这些点的不确定性。

  5. 确定不确定性采样和随机采样的点数

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    

    根据 importance_sample_ratio 确定通过不确定性采样的点数 num_uncertain_points,以及剩余的随机采样点数 num_random_points

  6. 选择不确定性最高的点

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    

    使用 torch.topk 函数选择每个盒子中不确定性最高的 num_uncertain_points 个点,并获取它们的坐标。

  7. 添加随机点

    if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)
    

    如果需要添加随机采样点,将它们与不确定性采样点合并。

  8. 返回采样点的坐标

    return point_coordinates
    

    最终返回所有采样点的坐标。

关键代码解读

1. 偏移量的生成
 shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)

这行代码的目的是为每个 box 生成一个偏移量(shift),用于转换局部索引为全局索引。

  • torch.arange(num_boxes, dtype=torch.long, device=logits.device) 生成一个从 0 到 num_boxes-1 的张量。
  • num_points_sampled 是每个 box 中采样的点的数量。
  • 乘法操作 num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 为每个 box 生成一个偏移量。例如,假设 num_points_sampled 为 100,那么生成的偏移量张量为 [0, 100, 200, 300, ...]

这些偏移量将用于将局部索引(即每个 box 内的索引)转换为全局索引(即在整个 point_coordinates 中的索引)。

2. 局部索引转换为全局索引
  idx += shift[:, None]

这行代码将局部索引转换为全局索引。

  • idxtorch.topk 返回的不确定性最高的点的局部索引,形状为 [num_boxes, num_uncertain_points]
  • shift[:, None] 的形状是 [num_boxes, 1],通过这种方式将每个 box 的偏移量广播到与 idx 的形状匹配。

通过将 shift 加到 idx 上,每个 box 的局部索引将变成全局索引。例如,如果第一个 box 的偏移量为 100,那么第一个 box 内的局部索引 [0, 1, 2, ...] 将变为 [100, 101, 102, ...]

总结

通过 sample_points_using_uncertainty 函数,我们可以有效地选择不确定性高的点进行训练,提高模型在这些关键点上的表现,同时减少确定性高的点的计算开销。这种不确定性采样方法结合了重要性采样和随机采样,确保了模型训练的高效性和鲁棒性。

这篇关于mask2former利用不确定性采样点选择提高模型性能的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056225

相关文章

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

Java慢查询排查与性能调优完整实战指南

《Java慢查询排查与性能调优完整实战指南》Java调优是一个广泛的话题,它涵盖了代码优化、内存管理、并发处理等多个方面,:本文主要介绍Java慢查询排查与性能调优的相关资料,文中通过代码介绍的非... 目录1. 事故全景:从告警到定位1.1 事故时间线1.2 关键指标异常1.3 排查工具链2. 深度剖析:

深入解析Java NIO在高并发场景下的性能优化实践指南

《深入解析JavaNIO在高并发场景下的性能优化实践指南》随着互联网业务不断演进,对高并发、低延时网络服务的需求日益增长,本文将深入解析JavaNIO在高并发场景下的性能优化方法,希望对大家有所帮助... 目录简介一、技术背景与应用场景二、核心原理深入分析2.1 Selector多路复用2.2 Buffer

基于Python Playwright进行前端性能测试的脚本实现

《基于PythonPlaywright进行前端性能测试的脚本实现》在当今Web应用开发中,性能优化是提升用户体验的关键因素之一,本文将介绍如何使用Playwright构建一个自动化性能测试工具,希望... 目录引言工具概述整体架构核心实现解析1. 浏览器初始化2. 性能数据收集3. 资源分析4. 关键性能指

Zabbix在MySQL性能监控方面的运用及最佳实践记录

《Zabbix在MySQL性能监控方面的运用及最佳实践记录》Zabbix通过自定义脚本和内置模板监控MySQL核心指标(连接、查询、资源、复制),支持自动发现多实例及告警通知,结合可视化仪表盘,可有效... 目录一、核心监控指标及配置1. 关键监控指标示例2. 配置方法二、自动发现与多实例管理1. 实践步骤

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Linux系统性能检测命令详解

《Linux系统性能检测命令详解》本文介绍了Linux系统常用的监控命令(如top、vmstat、iostat、htop等)及其参数功能,涵盖进程状态、内存使用、磁盘I/O、系统负载等多维度资源监控,... 目录toppsuptimevmstatIOStatiotopslabtophtopdstatnmon

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期