PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

2024-06-24 01:44

本文主要是介绍PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • nn.MSELoss() 均方误差损失函数
    • 参数
    • 数学公式
      • 元素版本
    • 要点
    • 附录
  • 参考链接

nn.MSELoss() 均方误差损失函数

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x x x and target y y y.

计算输入和目标之间每个元素的均方误差(平方 L2 范数)。

参数

  • size_average (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失在批次中的每个损失元素上取平均(True);否则(False),在每个小批次中对损失求和。
    • reduceFalse 时忽略该参数。
    • 默认值是 True
  • reduce (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失根据 size_average 参数进行平均或求和。
    • reduceFalse 时,返回每个批次元素的损失,并忽略 size_average 参数。
    • 默认值是 True
  • reduction (str, 可选):
    • 指定应用于输出的归约方式。
    • 可选值为 'none''mean''sum'
      • 'none':不进行归约。
      • 'mean':输出的和除以输出的元素总数。
      • 'sum':输出的元素求和。
    • 注意:size_averagereduce 参数正在被弃用,同时指定这些参数中的任何一个都会覆盖 reduction 参数。
    • 默认值是 'mean'

数学公式

附录部分会验证下述公式和代码的一致性。

假设有 N N N 个样本,每个样本的输入为 x n x_n xn,目标为 y n y_n yn。均方误差损失的计算步骤如下:

  1. 单个样本的损失
    计算每个样本的均方误差:
    l n = ( x n − y n ) 2 l_n = (x_n - y_n)^2 ln=(xnyn)2
    其中 l n l_n ln 是第 n n n 个样本的损失。
  2. 总损失
    计算所有样本的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ n = 1 N l n = 1 N ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} l_n = \frac{1}{N} \sum_{n=1}^{N} (x_n - y_n)^2 L=N1n=1Nln=N1n=1N(xnyn)2
    如果 reduction 参数为 'sum',总损失为所有样本损失的和:
    L = ∑ n = 1 N l n = ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \sum_{n=1}^{N} l_n = \sum_{n=1}^{N} (x_n - y_n)^2 L=n=1Nln=n=1N(xnyn)2
    如果 reduction 参数为 'none',则返回每个样本的损失 l n l_n ln 组成的张量:
    L = [ l 1 , l 2 , … , l N ] = [ ( x 1 − y 1 ) 2 , ( x 2 − y 2 ) 2 , … , ( x N − y N ) 2 ] \mathcal{L} = [l_1, l_2, \ldots, l_N] = [(x_1 - y_1)^2, (x_2 - y_2)^2, \ldots, (x_N - y_N)^2] L=[l1,l2,,lN]=[(x1y1)2,(x2y2)2,,(xNyN)2]

元素版本

假设输入张量 x \mathbf{x} x 和目标张量 y \mathbf{y} y 具有相同的形状,每个张量包含 N N N 个元素。均方误差损失的计算步骤如下:

  1. 单个元素的损失
    计算每个元素的均方误差:
    l i j = ( x i j − y i j ) 2 l_{ij} = (x_{ij} - y_{ij})^2 lij=(xijyij)2
    其中 l i j l_{ij} lij 是输入张量和目标张量在位置 ( i , j ) (i, j) (i,j) 的元素损失。
  2. 总损失
    计算所有元素的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ i , j l i j = 1 N ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \frac{1}{N} \sum_{i,j} l_{ij} = \frac{1}{N} \sum_{i,j} (x_{ij} - y_{ij})^2 L=N1i,jlij=N1i,j(xijyij)2
    如果 reduction 参数为 'sum',总损失为所有元素损失的和:
    L = ∑ i , j l i j = ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \sum_{i,j} l_{ij} = \sum_{i,j} (x_{ij} - y_{ij})^2 L=i,jlij=i,j(xijyij)2
    如果 reduction 参数为 'none',则返回每个元素的损失 l i j l_{ij} lij 组成的张量:
    L = { l i j } = { ( x i j − y i j ) 2 } \mathcal{L} = \{l_{ij}\} = \{(x_{ij} - y_{ij})^2 \} L={lij}={(xijyij)2}

要点

  1. nn.MSELoss() 接受的输入和目标应具有相同的形状和类型。
    使用示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失
    criterion = nn.MSELoss()
    loss = criterion(input, target)print(f"Loss using nn.MSELoss: {loss.item()}")
    
    >>> Loss using nn.MSELoss: 0.25
    
  2. nn.MSELoss()reduction 参数指定了如何归约输出损失。默认值是 'mean',计算的是所有样本的平均损失。
    • 如果 reduction 参数为 'mean',损失是所有样本损失的平均值。
    • 如果 reduction 参数为 'sum',损失是所有样本损失的和。
    • 如果 reduction 参数为 'none',则返回每个样本的损失组成的张量。
      代码示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失(reduction='mean')
    criterion_mean = nn.MSELoss(reduction='mean')
    loss_mean = criterion_mean(input, target)
    print(f"Loss with reduction='mean': {loss_mean.item()}")# 使用 nn.MSELoss 计算损失(reduction='sum')
    criterion_sum = nn.MSELoss(reduction='sum')
    loss_sum = criterion_sum(input, target)
    print(f"Loss with reduction='sum': {loss_sum.item()}")# 使用 nn.MSELoss 计算损失(reduction='none')
    criterion_none = nn.MSELoss(reduction='none')
    loss_none = criterion_none(input, target)
    print(f"Loss with reduction='none': {loss_none}")
    
    >>> Loss with reduction='mean': 0.25
    >>> Loss with reduction='sum': 1.0
    >>> Loss with reduction='none': tensor([[0.2500, 0.2500],[0.2500, 0.2500]], grad_fn=<MseLossBackward0>)
    

附录

用于验证数学公式和函数实际运行的一致性

import torch
import torch.nn.functional as F# 假设有两个样本,每个样本有两个维度
input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 根据公式实现均方误差损失
def mse_loss(input, target):return ((input - target) ** 2).mean()# 使用 nn.MSELoss 计算损失
criterion = torch.nn.MSELoss(reduction='mean')
loss_torch = criterion(input, target)# 使用根据公式实现的均方误差损失
loss_custom = mse_loss(input, target)# 打印结果
print("PyTorch 计算的均方误差损失:", loss_torch.item())
print("根据公式实现的均方误差损失:", loss_custom.item())# 验证结果是否相等
assert torch.isclose(loss_torch, loss_custom), "数学公式验证失败"
>>> PyTorch 计算的均方误差损失: 0.25
>>> 根据公式实现的均方误差损失: 0.25

输出没有抛出 AssertionError,验证通过。

参考链接

MSELoss - Docs

这篇关于PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088849

相关文章

MySQL数据库双机热备的配置方法详解

《MySQL数据库双机热备的配置方法详解》在企业级应用中,数据库的高可用性和数据的安全性是至关重要的,MySQL作为最流行的开源关系型数据库管理系统之一,提供了多种方式来实现高可用性,其中双机热备(M... 目录1. 环境准备1.1 安装mysql1.2 配置MySQL1.2.1 主服务器配置1.2.2 从

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

MyBatis常用XML语法详解

《MyBatis常用XML语法详解》文章介绍了MyBatis常用XML语法,包括结果映射、查询语句、插入语句、更新语句、删除语句、动态SQL标签以及ehcache.xml文件的使用,感兴趣的朋友跟随小... 目录1、定义结果映射2、查询语句3、插入语句4、更新语句5、删除语句6、动态 SQL 标签7、ehc

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

从基础到高级详解Go语言中错误处理的实践指南

《从基础到高级详解Go语言中错误处理的实践指南》Go语言采用了一种独特而明确的错误处理哲学,与其他主流编程语言形成鲜明对比,本文将为大家详细介绍Go语言中错误处理详细方法,希望对大家有所帮助... 目录1 Go 错误处理哲学与核心机制1.1 错误接口设计1.2 错误与异常的区别2 错误创建与检查2.1 基础

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1

一文详解Python如何开发游戏

《一文详解Python如何开发游戏》Python是一种非常流行的编程语言,也可以用来开发游戏模组,:本文主要介绍Python如何开发游戏的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录一、python简介二、Python 开发 2D 游戏的优劣势优势缺点三、Python 开发 3D

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Redis 基本数据类型和使用详解

《Redis基本数据类型和使用详解》String是Redis最基本的数据类型,一个键对应一个值,它的功能十分强大,可以存储字符串、整数、浮点数等多种数据格式,本文给大家介绍Redis基本数据类型和... 目录一、Redis 入门介绍二、Redis 的五大基本数据类型2.1 String 类型2.2 Hash