【深度学习笔记】7_3 小批量随机梯度下降

2024-03-12 22:04

本文主要是介绍【深度学习笔记】7_3 小批量随机梯度下降,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.3 小批量随机梯度下降

在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,因此它有时也被称为批量梯度下降(batch gradient descent)。而随机梯度下降在每次迭代中只随机采样一个样本来计算梯度。正如我们在前几章中所看到的,我们还可以在每轮迭代中随机均匀采样多个样本来组成一个小批量,然后使用这个小批量来计算梯度。下面就来描述小批量随机梯度下降。

设目标函数 f ( x ) : R d → R f(\boldsymbol{x}): \mathbb{R}^d \rightarrow \mathbb{R} f(x):RdR。在迭代开始前的时间步设为0。该时间步的自变量记为 x 0 ∈ R d \boldsymbol{x}_0\in \mathbb{R}^d x0Rd,通常由随机初始化得到。在接下来的每一个时间步 t > 0 t>0 t>0中,小批量随机梯度下降随机均匀采样一个由训练数据样本索引组成的小批量 B t \mathcal{B}_t Bt。我们可以通过重复采样(sampling with replacement)或者不重复采样(sampling without replacement)得到一个小批量中的各个样本。前者允许同一个小批量中出现重复的样本,后者则不允许如此,且更常见。对于这两者间的任一种方式,都可以使用

g t ← ∇ f B t ( x t − 1 ) = 1 ∣ B ∣ ∑ i ∈ B t ∇ f i ( x t − 1 ) \boldsymbol{g}_t \leftarrow \nabla f_{\mathcal{B}_t}(\boldsymbol{x}_{t-1}) = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}_t}\nabla f_i(\boldsymbol{x}_{t-1}) gtfBt(xt1)=B1iBtfi(xt1)

来计算时间步 t t t的小批量 B t \mathcal{B}_t Bt上目标函数位于 x t − 1 \boldsymbol{x}_{t-1} xt1处的梯度 g t \boldsymbol{g}_t gt。这里 ∣ B ∣ |\mathcal{B}| B代表批量大小,即小批量中样本的个数,是一个超参数。同随机梯度一样,重复采样所得的小批量随机梯度 g t \boldsymbol{g}_t gt也是对梯度 ∇ f ( x t − 1 ) \nabla f(\boldsymbol{x}_{t-1}) f(xt1)的无偏估计。给定学习率 η t \eta_t ηt(取正数),小批量随机梯度下降对自变量的迭代如下:

x t ← x t − 1 − η t g t . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \eta_t \boldsymbol{g}_t. xtxt1ηtgt.

基于随机采样得到的梯度的方差在迭代过程中无法减小,因此在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减,例如 η t = η t α \eta_t=\eta t^\alpha ηt=ηtα(通常 α = − 1 \alpha=-1 α=1或者 − 0.5 -0.5 0.5)、 η t = η α t \eta_t = \eta \alpha^t ηt=ηαt(如 α = 0.95 \alpha=0.95 α=0.95)或者每迭代若干次后将学习率衰减一次。如此一来,学习率和(小批量)随机梯度乘积的方差会减小。而梯度下降在迭代过程中一直使用目标函数的真实梯度,无须自我衰减学习率。

小批量随机梯度下降中每次迭代的计算开销为 O ( ∣ B ∣ ) \mathcal{O}(|\mathcal{B}|) O(B)。当批量大小为1时,该算法即为随机梯度下降;当批量大小等于训练数据样本数时,该算法即为梯度下降。当批量较小时,每次迭代中使用的样本少,这会导致并行处理和内存使用效率变低。这使得在计算同样数目样本的情况下比使用更大批量时所花时间更多。当批量较大时,每个小批量梯度里可能含有更多的冗余信息。为了得到较好的解,批量较大时比批量较小时需要计算的样本数目可能更多,例如增大迭代周期数。

7.3.1 读取数据

本章里我们将使用一个来自NASA的测试不同飞机机翼噪音的数据集来比较各个优化算法 [1]。我们使用该数据集的前1,500个样本和5个特征,并使用标准化对数据进行预处理。

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ldef get_data_ch7():  # 本函数已保存在d2lzh_pytorch包中方便以后使用data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

7.3.2 从零开始实现

3.2节(线性回归的从零开始实现)中已经实现过小批量随机梯度下降算法。我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。具体来说,我们添加了一个状态输入states并将超参数放在字典hyperparams里。此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法里的梯度不需要除以批量大小。

def sgd(params, states, hyperparams):for p in params:p.data -= hyperparams['lr'] * p.grad.data

下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。它初始化一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = d2l.linreg, d2l.squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

当批量大小为样本总数1,500时,优化使用的是梯度下降。梯度下降的1个迭代周期对模型参数只迭代1次。可以看到6次迭代后目标函数值(训练损失)的下降趋向了平稳。

def train_sgd(lr, batch_size, num_epochs=2):train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)train_sgd(1, 1500, 6)

输出:

loss: 0.243605, 0.014335 sec per epoch

在这里插入图片描述

当批量大小为1时,优化使用的是随机梯度下降。为了简化实现,有关(小批量)随机梯度下降的实验中,我们未对学习率进行自我衰减,而是直接采用较小的常数学习率。随机梯度下降中,每处理一个样本会更新一次自变量(模型参数),一个迭代周期里会对自变量进行1,500次更新。可以看到,目标函数值的下降在1个迭代周期后就变得较为平缓。

train_sgd(0.005, 1)

输出:

loss: 0.243433, 0.270011 sec per epoch

在这里插入图片描述

虽然随机梯度下降和梯度下降在一个迭代周期里都处理了1,500个样本,但实验中随机梯度下降的一个迭代周期耗时更多。这是因为随机梯度下降在一个迭代周期里做了更多次的自变量迭代,而且单样本的梯度计算难以有效利用矢量计算。

当批量大小为10时,优化使用的是小批量随机梯度下降。它在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

train_sgd(0.05, 10)

输出:

loss: 0.242805, 0.078792 sec per epoch

在这里插入图片描述

7.3.3 简洁实现

在PyTorch里可以通过创建optimizer实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数optimizer_fn和超参数optimizer_hyperparams来创建optimizer实例。

# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net = nn.Sequential(nn.Linear(features.shape[-1], 1))loss = nn.MSELoss()optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)def eval_loss():return loss(net(features).view(-1), labels).item() / 2ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):# 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2l = loss(net(X).view(-1), y) / 2 optimizer.zero_grad()l.backward()optimizer.step()if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

使用PyTorch重复上一个实验。

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)

输出:

loss: 0.245491, 0.044150 sec per epoch

在这里插入图片描述

小结

  • 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
  • 在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减。
  • 通常,小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

参考文献

[1] 飞机机翼噪音数据集。https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise


注:除代码外本节与原书此节基本相同,原书传送门

这篇关于【深度学习笔记】7_3 小批量随机梯度下降的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/802719

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实

MyBatis分页插件PageHelper深度解析与实践指南

《MyBatis分页插件PageHelper深度解析与实践指南》在数据库操作中,分页查询是最常见的需求之一,传统的分页方式通常有两种内存分页和SQL分页,MyBatis作为优秀的ORM框架,本身并未提... 目录1. 为什么需要分页插件?2. PageHelper简介3. PageHelper集成与配置3.

Maven 插件配置分层架构深度解析

《Maven插件配置分层架构深度解析》:本文主要介绍Maven插件配置分层架构深度解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Maven 插件配置分层架构深度解析引言:当构建逻辑遇上复杂配置第一章 Maven插件配置的三重境界1.1 插件配置的拓扑

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Python开发文字版随机事件游戏的项目实例

《Python开发文字版随机事件游戏的项目实例》随机事件游戏是一种通过生成不可预测的事件来增强游戏体验的类型,在这篇博文中,我们将使用Python开发一款文字版随机事件游戏,通过这个项目,读者不仅能够... 目录项目概述2.1 游戏概念2.2 游戏特色2.3 目标玩家群体技术选择与环境准备3.1 开发环境3

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen