PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

2024-06-24 01:44

本文主要是介绍PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • nn.MSELoss() 均方误差损失函数
    • 参数
    • 数学公式
      • 元素版本
    • 要点
    • 附录
  • 参考链接

nn.MSELoss() 均方误差损失函数

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x x x and target y y y.

计算输入和目标之间每个元素的均方误差(平方 L2 范数)。

参数

  • size_average (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失在批次中的每个损失元素上取平均(True);否则(False),在每个小批次中对损失求和。
    • reduceFalse 时忽略该参数。
    • 默认值是 True
  • reduce (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失根据 size_average 参数进行平均或求和。
    • reduceFalse 时,返回每个批次元素的损失,并忽略 size_average 参数。
    • 默认值是 True
  • reduction (str, 可选):
    • 指定应用于输出的归约方式。
    • 可选值为 'none''mean''sum'
      • 'none':不进行归约。
      • 'mean':输出的和除以输出的元素总数。
      • 'sum':输出的元素求和。
    • 注意:size_averagereduce 参数正在被弃用,同时指定这些参数中的任何一个都会覆盖 reduction 参数。
    • 默认值是 'mean'

数学公式

附录部分会验证下述公式和代码的一致性。

假设有 N N N 个样本,每个样本的输入为 x n x_n xn,目标为 y n y_n yn。均方误差损失的计算步骤如下:

  1. 单个样本的损失
    计算每个样本的均方误差:
    l n = ( x n − y n ) 2 l_n = (x_n - y_n)^2 ln=(xnyn)2
    其中 l n l_n ln 是第 n n n 个样本的损失。
  2. 总损失
    计算所有样本的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ n = 1 N l n = 1 N ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} l_n = \frac{1}{N} \sum_{n=1}^{N} (x_n - y_n)^2 L=N1n=1Nln=N1n=1N(xnyn)2
    如果 reduction 参数为 'sum',总损失为所有样本损失的和:
    L = ∑ n = 1 N l n = ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \sum_{n=1}^{N} l_n = \sum_{n=1}^{N} (x_n - y_n)^2 L=n=1Nln=n=1N(xnyn)2
    如果 reduction 参数为 'none',则返回每个样本的损失 l n l_n ln 组成的张量:
    L = [ l 1 , l 2 , … , l N ] = [ ( x 1 − y 1 ) 2 , ( x 2 − y 2 ) 2 , … , ( x N − y N ) 2 ] \mathcal{L} = [l_1, l_2, \ldots, l_N] = [(x_1 - y_1)^2, (x_2 - y_2)^2, \ldots, (x_N - y_N)^2] L=[l1,l2,,lN]=[(x1y1)2,(x2y2)2,,(xNyN)2]

元素版本

假设输入张量 x \mathbf{x} x 和目标张量 y \mathbf{y} y 具有相同的形状,每个张量包含 N N N 个元素。均方误差损失的计算步骤如下:

  1. 单个元素的损失
    计算每个元素的均方误差:
    l i j = ( x i j − y i j ) 2 l_{ij} = (x_{ij} - y_{ij})^2 lij=(xijyij)2
    其中 l i j l_{ij} lij 是输入张量和目标张量在位置 ( i , j ) (i, j) (i,j) 的元素损失。
  2. 总损失
    计算所有元素的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ i , j l i j = 1 N ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \frac{1}{N} \sum_{i,j} l_{ij} = \frac{1}{N} \sum_{i,j} (x_{ij} - y_{ij})^2 L=N1i,jlij=N1i,j(xijyij)2
    如果 reduction 参数为 'sum',总损失为所有元素损失的和:
    L = ∑ i , j l i j = ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \sum_{i,j} l_{ij} = \sum_{i,j} (x_{ij} - y_{ij})^2 L=i,jlij=i,j(xijyij)2
    如果 reduction 参数为 'none',则返回每个元素的损失 l i j l_{ij} lij 组成的张量:
    L = { l i j } = { ( x i j − y i j ) 2 } \mathcal{L} = \{l_{ij}\} = \{(x_{ij} - y_{ij})^2 \} L={lij}={(xijyij)2}

要点

  1. nn.MSELoss() 接受的输入和目标应具有相同的形状和类型。
    使用示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失
    criterion = nn.MSELoss()
    loss = criterion(input, target)print(f"Loss using nn.MSELoss: {loss.item()}")
    
    >>> Loss using nn.MSELoss: 0.25
    
  2. nn.MSELoss()reduction 参数指定了如何归约输出损失。默认值是 'mean',计算的是所有样本的平均损失。
    • 如果 reduction 参数为 'mean',损失是所有样本损失的平均值。
    • 如果 reduction 参数为 'sum',损失是所有样本损失的和。
    • 如果 reduction 参数为 'none',则返回每个样本的损失组成的张量。
      代码示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失(reduction='mean')
    criterion_mean = nn.MSELoss(reduction='mean')
    loss_mean = criterion_mean(input, target)
    print(f"Loss with reduction='mean': {loss_mean.item()}")# 使用 nn.MSELoss 计算损失(reduction='sum')
    criterion_sum = nn.MSELoss(reduction='sum')
    loss_sum = criterion_sum(input, target)
    print(f"Loss with reduction='sum': {loss_sum.item()}")# 使用 nn.MSELoss 计算损失(reduction='none')
    criterion_none = nn.MSELoss(reduction='none')
    loss_none = criterion_none(input, target)
    print(f"Loss with reduction='none': {loss_none}")
    
    >>> Loss with reduction='mean': 0.25
    >>> Loss with reduction='sum': 1.0
    >>> Loss with reduction='none': tensor([[0.2500, 0.2500],[0.2500, 0.2500]], grad_fn=<MseLossBackward0>)
    

附录

用于验证数学公式和函数实际运行的一致性

import torch
import torch.nn.functional as F# 假设有两个样本,每个样本有两个维度
input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 根据公式实现均方误差损失
def mse_loss(input, target):return ((input - target) ** 2).mean()# 使用 nn.MSELoss 计算损失
criterion = torch.nn.MSELoss(reduction='mean')
loss_torch = criterion(input, target)# 使用根据公式实现的均方误差损失
loss_custom = mse_loss(input, target)# 打印结果
print("PyTorch 计算的均方误差损失:", loss_torch.item())
print("根据公式实现的均方误差损失:", loss_custom.item())# 验证结果是否相等
assert torch.isclose(loss_torch, loss_custom), "数学公式验证失败"
>>> PyTorch 计算的均方误差损失: 0.25
>>> 根据公式实现的均方误差损失: 0.25

输出没有抛出 AssertionError,验证通过。

参考链接

MSELoss - Docs

这篇关于PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088849

相关文章

RabbitMQ 延时队列插件安装与使用示例详解(基于 Delayed Message Plugin)

《RabbitMQ延时队列插件安装与使用示例详解(基于DelayedMessagePlugin)》本文详解RabbitMQ通过安装rabbitmq_delayed_message_exchan... 目录 一、什么是 RabbitMQ 延时队列? 二、安装前准备✅ RabbitMQ 环境要求 三、安装延时队

从基础到高级详解Python数值格式化输出的完全指南

《从基础到高级详解Python数值格式化输出的完全指南》在数据分析、金融计算和科学报告领域,数值格式化是提升可读性和专业性的关键技术,本文将深入解析Python中数值格式化输出的相关方法,感兴趣的小伙... 目录引言:数值格式化的核心价值一、基础格式化方法1.1 三种核心格式化方式对比1.2 基础格式化示例

GO语言中函数命名返回值的使用

《GO语言中函数命名返回值的使用》在Go语言中,函数可以为其返回值指定名称,这被称为命名返回值或命名返回参数,这种特性可以使代码更清晰,特别是在返回多个值时,感兴趣的可以了解一下... 目录基本语法函数命名返回特点代码示例命名特点基本语法func functionName(parameters) (nam

Python Counter 函数使用案例

《PythonCounter函数使用案例》Counter是collections模块中的一个类,专门用于对可迭代对象中的元素进行计数,接下来通过本文给大家介绍PythonCounter函数使用案例... 目录一、Counter函数概述二、基本使用案例(一)列表元素计数(二)字符串字符计数(三)元组计数三、C

Java中的stream流分组示例详解

《Java中的stream流分组示例详解》Java8StreamAPI以函数式风格处理集合数据,支持分组、统计等操作,可按单/多字段分组,使用String、Map.Entry或Java16record... 目录什么是stream流1、根据某个字段分组2、按多个字段分组(组合分组)1、方法一:使用 Stri

Spring创建Bean的八种主要方式详解

《Spring创建Bean的八种主要方式详解》Spring(尤其是SpringBoot)提供了多种方式来让容器创建和管理Bean,@Component、@Configuration+@Bean、@En... 目录引言一、Spring 创建 Bean 的 8 种主要方式1. @Component 及其衍生注解

Python异步编程之await与asyncio基本用法详解

《Python异步编程之await与asyncio基本用法详解》在Python中,await和asyncio是异步编程的核心工具,用于高效处理I/O密集型任务(如网络请求、文件读写、数据库操作等),接... 目录一、核心概念二、使用场景三、基本用法1. 定义协程2. 运行协程3. 并发执行多个任务四、关键

从基础到进阶详解Python条件判断的实用指南

《从基础到进阶详解Python条件判断的实用指南》本文将通过15个实战案例,带你大家掌握条件判断的核心技巧,并从基础语法到高级应用一网打尽,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录​引言:条件判断为何如此重要一、基础语法:三行代码构建决策系统二、多条件分支:elif的魔法三、

Java利用@SneakyThrows注解提升异常处理效率详解

《Java利用@SneakyThrows注解提升异常处理效率详解》这篇文章将深度剖析@SneakyThrows的原理,用法,适用场景以及隐藏的陷阱,看看它如何让Java异常处理效率飙升50%,感兴趣的... 目录前言一、检查型异常的“诅咒”:为什么Java开发者讨厌它1.1 检查型异常的痛点1.2 为什么说

MySQL的配置文件详解及实例代码

《MySQL的配置文件详解及实例代码》MySQL的配置文件是服务器运行的重要组成部分,用于设置服务器操作的各种参数,下面:本文主要介绍MySQL配置文件的相关资料,文中通过代码介绍的非常详细,需要... 目录前言一、配置文件结构1.[mysqld]2.[client]3.[mysql]4.[mysqldum