【从零开始学习深度学习】13. 防止过拟合方法:权重衰减(L2惩罚项)介绍及示例演示

本文主要是介绍【从零开始学习深度学习】13. 防止过拟合方法:权重衰减(L2惩罚项)介绍及示例演示,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

上一篇文章中介绍了过拟合现象,即模型的训练误差远小于它在测试集上的误差。虽然增大训练数据集可能会减轻过拟合,但是获取额外的训练数据往往代价高昂。本节介绍应对过拟合问题的常用方法:权重衰减(weight decay)

目录

  • 1 权重衰减
    • 1.1 方法
    • 1.2 高维线性回归示例
    • 1.3 从零开始实现带权重衰减的回归模型
    • 1.4 权重衰减的Pytorch简洁实现
    • 小结

1 权重衰减

1.1 方法

权重衰减等价于 L 2 L_2 L2 范数正则化(regularization)。正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。我们先描述 L 2 L_2 L2范数正则化,再解释它为何又称权重衰减。

L 2 L_2 L2范数正则化在模型原损失函数基础上添加 L 2 L_2 L2范数惩罚项,从而得到训练所需要最小化的函数。 L 2 L_2 L2范数惩罚项指的是模型权重参数每个元素的平方和与一个正的常数的乘积。以线性回归损失函数
ℓ ( w 1 , w 2 , b ) = 1 n ∑ i = 1 n 1 2 ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) 2 \ell(w_1, w_2, b) = \frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right)^2 (w1,w2,b)=n1i=1n21(x1(i)w1+x2(i)w2+by(i))2

为例,其中 w 1 , w 2 w_1, w_2 w1,w2是权重参数, b b b是偏差参数,样本 i i i的输入为 x 1 ( i ) , x 2 ( i ) x_1^{(i)}, x_2^{(i)} x1(i),x2(i),标签为 y ( i ) y^{(i)} y(i),样本数为 n n n。将权重参数用向量 w = [ w 1 , w 2 ] \boldsymbol{w} = [w_1, w_2] w=[w1,w2]表示,带有 L 2 L_2 L2范数惩罚项的新损失函数为

ℓ ( w 1 , w 2 , b ) + λ 2 n ∥ w ∥ 2 , \ell(w_1, w_2, b) + \frac{\lambda}{2n} \|\boldsymbol{w}\|^2, (w1,w2,b)+2nλw2,

其中超参数 λ > 0 \lambda > 0 λ>0。当权重参数均为0时,惩罚项最小。当 λ \lambda λ较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素较接近0。当 λ \lambda λ设为0时,惩罚项完全不起作用。上式中 L 2 L_2 L2范数平方 ∥ w ∥ 2 \|\boldsymbol{w}\|^2 w2展开后得到 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22。有了 L 2 L_2 L2范数惩罚项后,在小批量随机梯度下降中,我们将线性回归中权重 w 1 w_1 w1 w 2 w_2 w2的迭代方式更改为

w 1 ← ( 1 − η λ ∣ B ∣ ) w 1 − η ∣ B ∣ ∑ i ∈ B x 1 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) , w 2 ← ( 1 − η λ ∣ B ∣ ) w 2 − η ∣ B ∣ ∑ i ∈ B x 2 ( i ) ( x 1 ( i ) w 1 + x 2 ( i ) w 2 + b − y ( i ) ) . \begin{aligned} w_1 &\leftarrow \left(1- \frac{\eta\lambda}{|\mathcal{B}|} \right)w_1 - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}}x_1^{(i)} \left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right),\\ w_2 &\leftarrow \left(1- \frac{\eta\lambda}{|\mathcal{B}|} \right)w_2 - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}}x_2^{(i)} \left(x_1^{(i)} w_1 + x_2^{(i)} w_2 + b - y^{(i)}\right). \end{aligned} w1w2(1Bηλ)w1BηiBx1(i)(x1(i)w1+x2(i)w2+by(i)),(1Bηλ)w2BηiBx2(i)(x1(i)w1+x2(i)w2+by(i)).

可见, L 2 L_2 L2范数正则化令权重 w 1 w_1 w1 w 2 w_2 w2先自乘小于1的数,再减去不含惩罚项的梯度。因此, L 2 L_2 L2范数正则化又叫权重衰减。权重衰减通过惩罚绝对值较大的模型参数为需要学习的模型增加了限制,这可能对过拟合有效

1.2 高维线性回归示例

下面,我们以高维线性回归为例来引入一个过拟合问题,并使用权重衰减来应对过拟合。设数据样本特征的维度为 p p p。对于训练数据集和测试数据集中特征为 x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp的任一样本,我们使用如下的线性函数来生成该样本的标签:

y = 0.05 + ∑ i = 1 p 0.01 x i + ϵ y = 0.05 + \sum_{i = 1}^p 0.01x_i + \epsilon y=0.05+i=1p0.01xi+ϵ

其中噪声项 ϵ \epsilon ϵ服从均值为0、标准差为0.01的正态分布。为了较容易地观察过拟合,我们考虑高维线性回归问题,如设维度 p = 200 p=200 p=200;同时,我们特意把训练数据集的样本数设低,如20。

%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import sys
import d2lzh_pytorch as d2ln_train, n_test, num_inputs = 20, 100, 200
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05features = torch.randn((n_train + n_test, num_inputs))
labels = torch.matmul(features, true_w) + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

1.3 从零开始实现带权重衰减的回归模型

下面先介绍从零开始实现权重衰减的方法:通过在目标函数后添加 L 2 L_2 L2范数惩罚项来实现权重衰减。

1.3.1 初始化模型参数

首先,定义随机初始化模型参数的函数。该函数为每个参数都附上梯度。

def init_params():w = torch.randn((num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

1.3.2 定义 L 2 L_2 L2范数惩罚项

下面定义 L 2 L_2 L2范数惩罚项:这里只惩罚模型的权重参数。

def l2_penalty(w):return (w**2).sum() / 2

1.3.3 定义训练和测试

下面定义如何在训练数据集和测试数据集上分别训练和测试模型:在计算最终的损失函数时添加 L 2 L_2 L2范数惩罚项。

batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_lossdataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)def fit_and_plot(lambd):w, b = init_params()train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:# 添加了L2范数惩罚项l = loss(net(X, w, b), y) + lambd * l2_penalty(w)l = l.sum()if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()d2l.sgd([w, b], lr, batch_size)train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', w.norm().item())

1.3.4 观察过拟合

训练并测试高维线性回归模型。当lambd设为0时,我们没有使用权重衰减。结果训练误差远小于测试集上的误差。这是典型的过拟合现象。

fit_and_plot(lambd=0)

输出:

L2 norm of w: 15.114808082580566

在这里插入图片描述

1.3.5 使用权重衰减

下面我们使用权重衰减。可以看出,训练误差虽然有所提高,但测试集上的误差有所下降。过拟合现象得到一定程度的缓解。另外,权重参数的 L 2 L_2 L2范数比不使用权重衰减时的更小,此时的权重参数更接近0。

fit_and_plot(lambd=3)

输出:

L2 norm of w: 0.035220853984355927

在这里插入图片描述

1.4 权重衰减的Pytorch简洁实现

这里我们直接在构造优化器实例时通过weight_decay参数来指定权重衰减超参数。默认下,PyTorch会对权重和偏差同时衰减。我们可以分别对权重和偏差构造优化器实例,从而只对权重衰减。

def fit_and_plot_pytorch(wd):# 对权重参数衰减。权重名称一般是以weight结尾net = nn.Linear(num_inputs, 1)nn.init.normal_(net.weight, mean=0, std=1)nn.init.normal_(net.bias, mean=0, std=1)optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 不对偏差参数衰减train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:l = loss(net(X), y).mean()optimizer_w.zero_grad()optimizer_b.zero_grad()l.backward()# 对两个optimizer实例分别调用step函数,从而分别更新权重和偏差optimizer_w.step()optimizer_b.step()train_ls.append(loss(net(train_features), train_labels).mean().item())test_ls.append(loss(net(test_features), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', net.weight.data.norm().item())

与从零开始实现权重衰减的实验现象类似,使用权重衰减可以在一定程度上缓解过拟合问题。

fit_and_plot_pytorch(0)   # 不使用衰减

输出:

L2 norm of w: 12.86785888671875

在这里插入图片描述

fit_and_plot_pytorch(3)  # 使用衰减

输出:

L2 norm of w: 0.09631537646055222

在这里插入图片描述

小结

  • 正则化通过为模型损失函数添加惩罚项使学出的模型参数值较小,是应对过拟合的常用手段。
  • 权重衰减等价于 L 2 L_2 L2范数正则化,通常会使学到的权重参数较接近0。
  • 权重衰减可以通过优化器中的weight_decay超参数来指定。
  • 可以定义多个优化器实例对不同的模型参数使用不同的迭代方法。

如果内容对你有帮助,感谢点赞+关注哦!

关注下方GZH:阿旭算法与机器学习,可获取更多干货内容~欢迎共同学习交流

这篇关于【从零开始学习深度学习】13. 防止过拟合方法:权重衰减(L2惩罚项)介绍及示例演示的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/359564

相关文章

Android 12解决push framework.jar无法开机的方法小结

《Android12解决pushframework.jar无法开机的方法小结》:本文主要介绍在Android12中解决pushframework.jar无法开机的方法,包括编译指令、框架层和s... 目录1. android 编译指令1.1 framework层的编译指令1.2 替换framework.ja

在.NET平台使用C#为PDF添加各种类型的表单域的方法

《在.NET平台使用C#为PDF添加各种类型的表单域的方法》在日常办公系统开发中,涉及PDF处理相关的开发时,生成可填写的PDF表单是一种常见需求,与静态PDF不同,带有**表单域的文档支持用户直接在... 目录引言使用 PdfTextBoxField 添加文本输入域使用 PdfComboBoxField

SQLyog中DELIMITER执行存储过程时出现前置缩进问题的解决方法

《SQLyog中DELIMITER执行存储过程时出现前置缩进问题的解决方法》在SQLyog中执行存储过程时出现的前置缩进问题,实际上反映了SQLyog对SQL语句解析的一个特殊行为,本文给大家介绍了详... 目录问题根源正确写法示例永久解决方案为什么命令行不受影响?最佳实践建议问题根源SQLyog的语句分

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

Python中的Walrus运算符分析示例详解

《Python中的Walrus运算符分析示例详解》Python中的Walrus运算符(:=)是Python3.8引入的一个新特性,允许在表达式中同时赋值和返回值,它的核心作用是减少重复计算,提升代码简... 目录1. 在循环中避免重复计算2. 在条件判断中同时赋值变量3. 在列表推导式或字典推导式中简化逻辑

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll