【深度学习笔记】7_3 小批量随机梯度下降

2024-03-12 22:04

本文主要是介绍【深度学习笔记】7_3 小批量随机梯度下降,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.3 小批量随机梯度下降

在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,因此它有时也被称为批量梯度下降(batch gradient descent)。而随机梯度下降在每次迭代中只随机采样一个样本来计算梯度。正如我们在前几章中所看到的,我们还可以在每轮迭代中随机均匀采样多个样本来组成一个小批量,然后使用这个小批量来计算梯度。下面就来描述小批量随机梯度下降。

设目标函数 f ( x ) : R d → R f(\boldsymbol{x}): \mathbb{R}^d \rightarrow \mathbb{R} f(x):RdR。在迭代开始前的时间步设为0。该时间步的自变量记为 x 0 ∈ R d \boldsymbol{x}_0\in \mathbb{R}^d x0Rd,通常由随机初始化得到。在接下来的每一个时间步 t > 0 t>0 t>0中,小批量随机梯度下降随机均匀采样一个由训练数据样本索引组成的小批量 B t \mathcal{B}_t Bt。我们可以通过重复采样(sampling with replacement)或者不重复采样(sampling without replacement)得到一个小批量中的各个样本。前者允许同一个小批量中出现重复的样本,后者则不允许如此,且更常见。对于这两者间的任一种方式,都可以使用

g t ← ∇ f B t ( x t − 1 ) = 1 ∣ B ∣ ∑ i ∈ B t ∇ f i ( x t − 1 ) \boldsymbol{g}_t \leftarrow \nabla f_{\mathcal{B}_t}(\boldsymbol{x}_{t-1}) = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}_t}\nabla f_i(\boldsymbol{x}_{t-1}) gtfBt(xt1)=B1iBtfi(xt1)

来计算时间步 t t t的小批量 B t \mathcal{B}_t Bt上目标函数位于 x t − 1 \boldsymbol{x}_{t-1} xt1处的梯度 g t \boldsymbol{g}_t gt。这里 ∣ B ∣ |\mathcal{B}| B代表批量大小,即小批量中样本的个数,是一个超参数。同随机梯度一样,重复采样所得的小批量随机梯度 g t \boldsymbol{g}_t gt也是对梯度 ∇ f ( x t − 1 ) \nabla f(\boldsymbol{x}_{t-1}) f(xt1)的无偏估计。给定学习率 η t \eta_t ηt(取正数),小批量随机梯度下降对自变量的迭代如下:

x t ← x t − 1 − η t g t . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \eta_t \boldsymbol{g}_t. xtxt1ηtgt.

基于随机采样得到的梯度的方差在迭代过程中无法减小,因此在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减,例如 η t = η t α \eta_t=\eta t^\alpha ηt=ηtα(通常 α = − 1 \alpha=-1 α=1或者 − 0.5 -0.5 0.5)、 η t = η α t \eta_t = \eta \alpha^t ηt=ηαt(如 α = 0.95 \alpha=0.95 α=0.95)或者每迭代若干次后将学习率衰减一次。如此一来,学习率和(小批量)随机梯度乘积的方差会减小。而梯度下降在迭代过程中一直使用目标函数的真实梯度,无须自我衰减学习率。

小批量随机梯度下降中每次迭代的计算开销为 O ( ∣ B ∣ ) \mathcal{O}(|\mathcal{B}|) O(B)。当批量大小为1时,该算法即为随机梯度下降;当批量大小等于训练数据样本数时,该算法即为梯度下降。当批量较小时,每次迭代中使用的样本少,这会导致并行处理和内存使用效率变低。这使得在计算同样数目样本的情况下比使用更大批量时所花时间更多。当批量较大时,每个小批量梯度里可能含有更多的冗余信息。为了得到较好的解,批量较大时比批量较小时需要计算的样本数目可能更多,例如增大迭代周期数。

7.3.1 读取数据

本章里我们将使用一个来自NASA的测试不同飞机机翼噪音的数据集来比较各个优化算法 [1]。我们使用该数据集的前1,500个样本和5个特征,并使用标准化对数据进行预处理。

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ldef get_data_ch7():  # 本函数已保存在d2lzh_pytorch包中方便以后使用data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

7.3.2 从零开始实现

3.2节(线性回归的从零开始实现)中已经实现过小批量随机梯度下降算法。我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。具体来说,我们添加了一个状态输入states并将超参数放在字典hyperparams里。此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法里的梯度不需要除以批量大小。

def sgd(params, states, hyperparams):for p in params:p.data -= hyperparams['lr'] * p.grad.data

下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。它初始化一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = d2l.linreg, d2l.squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

当批量大小为样本总数1,500时,优化使用的是梯度下降。梯度下降的1个迭代周期对模型参数只迭代1次。可以看到6次迭代后目标函数值(训练损失)的下降趋向了平稳。

def train_sgd(lr, batch_size, num_epochs=2):train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)train_sgd(1, 1500, 6)

输出:

loss: 0.243605, 0.014335 sec per epoch

在这里插入图片描述

当批量大小为1时,优化使用的是随机梯度下降。为了简化实现,有关(小批量)随机梯度下降的实验中,我们未对学习率进行自我衰减,而是直接采用较小的常数学习率。随机梯度下降中,每处理一个样本会更新一次自变量(模型参数),一个迭代周期里会对自变量进行1,500次更新。可以看到,目标函数值的下降在1个迭代周期后就变得较为平缓。

train_sgd(0.005, 1)

输出:

loss: 0.243433, 0.270011 sec per epoch

在这里插入图片描述

虽然随机梯度下降和梯度下降在一个迭代周期里都处理了1,500个样本,但实验中随机梯度下降的一个迭代周期耗时更多。这是因为随机梯度下降在一个迭代周期里做了更多次的自变量迭代,而且单样本的梯度计算难以有效利用矢量计算。

当批量大小为10时,优化使用的是小批量随机梯度下降。它在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

train_sgd(0.05, 10)

输出:

loss: 0.242805, 0.078792 sec per epoch

在这里插入图片描述

7.3.3 简洁实现

在PyTorch里可以通过创建optimizer实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数optimizer_fn和超参数optimizer_hyperparams来创建optimizer实例。

# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net = nn.Sequential(nn.Linear(features.shape[-1], 1))loss = nn.MSELoss()optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)def eval_loss():return loss(net(features).view(-1), labels).item() / 2ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):# 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2l = loss(net(X).view(-1), y) / 2 optimizer.zero_grad()l.backward()optimizer.step()if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

使用PyTorch重复上一个实验。

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)

输出:

loss: 0.245491, 0.044150 sec per epoch

在这里插入图片描述

小结

  • 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
  • 在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减。
  • 通常,小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

参考文献

[1] 飞机机翼噪音数据集。https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise


注:除代码外本节与原书此节基本相同,原书传送门

这篇关于【深度学习笔记】7_3 小批量随机梯度下降的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/802719

相关文章

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

SpringBoot开发中十大常见陷阱深度解析与避坑指南

《SpringBoot开发中十大常见陷阱深度解析与避坑指南》在SpringBoot的开发过程中,即使是经验丰富的开发者也难免会遇到各种棘手的问题,本文将针对SpringBoot开发中十大常见的“坑... 目录引言一、配置总出错?是不是同时用了.properties和.yml?二、换个位置配置就失效?搞清楚加

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和