【损失函数】Quantile Loss 分位数损失

2024-01-04 07:52

本文主要是介绍【损失函数】Quantile Loss 分位数损失,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、介绍

Quantile Loss(分位数损失)是用于回归问题的一种损失函数,它允许我们对不同分位数的预测误差赋予不同的权重。这对于处理不同置信水平的预测非常有用,例如在风险管理等领域。

当我们需要对区间预测而不单是点预测时 分位数损失函数可以发挥很大作用

2、公式

$J_{\text {quant }}=\frac{1}{N} \sum_{i=1}^N \mathbb{I}_{\hat{y}_i \geq y_i}(1-\gamma)\left|y_i-\hat{y}_i\right|+\mathbb{I}_{\hat{y}_i<y_i} \gamma\left|y_i-\hat{y}_i\right|$

其中,{y}_i是实际目标值,\hat{y}_i 是模型的预测值,\gamma 是分位数水平,通常取值在 0 和 1 之间。

        我们如何理解这个损失函数呢?这个损失函数是一个分段的函数 ,将  \hat{y}_i \geq y_i(高估) 和  \hat{y}_i<y_i(低估) 两种情况分开来,并分别给予不同的系数。当 \gamma > 0.5 时,低估的损失要比高估的损失更大,反过来当 \gamma < 0.5 时,高估的损失比低估的损失大;分位数损失实现了分别用不同的系数控制高估和低估的损失,进而实现分位数回归。特别地,当 \gamma = 0.5 时,分位数损失退化为 MAE 损失,从这里可以看出 MAE 损失实际上是分位数损失的一个特例 — 中位数回归(这也可以解释为什么 MAE 损失对 outlier 更鲁棒:MSE 回归期望值,MAE 回归中位数,通常 outlier 对中位数的影响比对期望值的影响小)。      

        简单的总结下,分位数损失通过 \gamma 的不同取值来避免过拟合和欠拟合,实现分位数回归。

        分位数值的选择基于在实际中需要误差如何发挥作用,即在过程中误差为正时发挥更多作用还是在误差为负时发挥更大作用。

3、图像

        上图是分位数损失(Quantile Loss)在分位数为 0.3、0.5、0.7 时的图像。图中显示了预测值(f)与分位数损失之间的关系,可以看到 0.3 和 0.8 在高估和低估两种情况下损失是不同的,而 0.5 实际上就是 MAE。

4、实例

假设我们有以下情况:我们正在训练一个模型来预测房价涨幅区间。我们有以下目标值(真实值)和预测值:

  • 目标(真实值): [2.0, 1.0, 4.0, 3.5, 5.0]
  • 预测: [1.8, 0.9, 3.5, 3.0, 4.8]

我们使用 Quantile Loss作为损失函数:

import torch
import torch.nn as nnclass QuantileLoss(nn.Module):def __init__(self, quantile):super(QuantileLoss, self).__init__()self.quantile = quantiledef forward(self, y, y_pred):residual = y_pred - yloss = torch.max((self.quantile - 1) * residual, self.quantile * residual)return torch.mean(loss)
# 示例数据
y_true = torch.tensor([2.0, 1.0, 4.0, 3.5, 5.0], dtype=torch.float32)
y_pred = torch.tensor([1.8, 0.9, 3.5, 3.0, 4.8], dtype=torch.float32)
# 定义分位数水平 当分位数为 0.5 时,分位数损失退化为 MAE 损失
quantile = 0.5
# 初始化损失函数
quantile_loss = QuantileLoss(quantile)
# 计算损失
loss = quantile_loss(y_true, y_pred)
# Quantile Loss: 0.14999999105930328
print(f'Quantile Loss: {loss.item()}')

       在上述示例中,我们使用了一个简单的自定义 PyTorch 模块 `QuantileLoss`,它采用分位数水平作为参数,并计算相应的 Quantile Loss。这个例子中使用的分位数是 0.5,即中位数。此时分位数损失退化为 MAE 损失,实际应用中根据不同需求设定不同的分位数水平。

5、参考

损失函数 Loss Function 之 分位数损失 Quantile Loss - 知乎 (zhihu.com)

深度学习常用损失函数总览:基本形式、原理、特点 (qq.com)

这篇关于【损失函数】Quantile Loss 分位数损失的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/568649

相关文章

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

C++/类与对象/默认成员函数@构造函数的用法

《C++/类与对象/默认成员函数@构造函数的用法》:本文主要介绍C++/类与对象/默认成员函数@构造函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录名词概念默认成员函数构造函数概念函数特征显示构造函数隐式构造函数总结名词概念默认构造函数:不用传参就可以

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

Python函数返回多个值的多种方法小结

《Python函数返回多个值的多种方法小结》在Python中,函数通常用于封装一段代码,使其可以重复调用,有时,我们希望一个函数能够返回多个值,Python提供了几种不同的方法来实现这一点,需要的朋友... 目录一、使用元组(Tuple):二、使用列表(list)三、使用字典(Dictionary)四、 使

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

MySQL 字符串截取函数及用法详解

《MySQL字符串截取函数及用法详解》在MySQL中,字符串截取是常见的操作,主要用于从字符串中提取特定部分,MySQL提供了多种函数来实现这一功能,包括LEFT()、RIGHT()、SUBST... 目录mysql 字符串截取函数详解RIGHT(str, length):从右侧截取指定长度的字符SUBST

Kotlin运算符重载函数及作用场景

《Kotlin运算符重载函数及作用场景》在Kotlin里,运算符重载函数允许为自定义类型重新定义现有的运算符(如+-…)行为,从而让自定义类型能像内置类型那样使用运算符,本文给大家介绍Kotlin运算... 目录基本语法作用场景类对象数据类型接口注意事项在 Kotlin 里,运算符重载函数允许为自定义类型重