Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

2024-05-24 20:52

本文主要是介绍Pytorch入门(7)—— 梯度累加(Gradient Accumulation),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 梯度累加

  • 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
  • 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
    在这里插入图片描述
  • 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
    w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} ww+αi=1nn1wli 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
    w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} ww+αm1b=1mi=1nn1wlbi=w+αi=1mnmn1wli 可见这等价于使用 batch_size = m n mn mn 进行训练

2. 在 pytorch 中实现梯度累加

2.1 伪代码

  • pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有 .grad 属性存储其中每个元素的梯度值,通过 .requires_grad 属性控制其是否参与梯度计算。训练模型时,一般通过对标量 loss 执行 loss.backward() 自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度
  • pytorch 中梯度 tensor.grad 不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
    for inputs, labels in data_loader:# forward pass preds = model(inputs)loss  = criterion(preds, labels)# clear grad of last batch	optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
    
    这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即
    # batch accumulation parameter
    accum_iter = 4  # loop through enumaretad batches
    for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss  = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
    

2.2 线性回归案例

  • 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法

    本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!

  • 首先生成随机数据构造 dataset
    import torch
    from IPython import display
    from matplotlib import pyplot as plt
    import numpy as np
    import random
    import torch.utils.data as Data
    import torch.nn as nn
    import torch.optim as optim# 生成样本
    num_inputs = 2
    num_examples = 1000
    true_w = torch.Tensor([-2,3.4]).view(2,1)
    true_b = 4.2
    batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1)
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记
    labels = torch.mm(features,true_w) + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)
    
    1. 不使用梯度累加技巧,batch size 设置为 40
      # 构造 DataLoader
      batch_size = 40
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失函数
      criterion = nn.MSELoss()# SGD优化器
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad()            # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')'''
      epoch 1, loss: 0.5434057731628418
      epoch 2, loss: 0.1914414196014404
      epoch 3, loss: 0.06752514398097992
      '''
      
    2. 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
      # 构造 DataLoader
      batch_size = 10
      accum_iter = 4
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失
      criterion = nn.MSELoss()# SGD优化器对象
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))  loss = loss / accum_iter	# 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()              # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')
      '''
      epoch 1, loss: 0.5434057596921921
      epoch 2, loss: 0.19144139245152472
      epoch 3, loss: 0.06752512042224407
      '''
      
  • 可以观察到无论 epoch loss 还是 net[0].weight.grad 都完全相同,说明梯度累加不影响计算结果

这篇关于Pytorch入门(7)—— 梯度累加(Gradient Accumulation)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/999503

相关文章

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(

Redis 配置文件使用建议redis.conf 从入门到实战

《Redis配置文件使用建议redis.conf从入门到实战》Redis配置方式包括配置文件、命令行参数、运行时CONFIG命令,支持动态修改参数及持久化,常用项涉及端口、绑定、内存策略等,版本8... 目录一、Redis.conf 是什么?二、命令行方式传参(适用于测试)三、运行时动态修改配置(不重启服务

MySQL DQL从入门到精通

《MySQLDQL从入门到精通》通过DQL,我们可以从数据库中检索出所需的数据,进行各种复杂的数据分析和处理,本文将深入探讨MySQLDQL的各个方面,帮助你全面掌握这一重要技能,感兴趣的朋友跟随小... 目录一、DQL 基础:SELECT 语句入门二、数据过滤:WHERE 子句的使用三、结果排序:ORDE

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不