Pytorch中的Exponential Moving Average(EMA)

2023-10-14 16:20

本文主要是介绍Pytorch中的Exponential Moving Average(EMA),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

EMA介绍

EMA,指数移动平均,常用于更新模型参数、梯度等。

EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)

代码示例

下面以yolov7/utils/torch_utils.py代码为例:

class ModelEMA:""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-modelsKeep a moving average of everything in the model state_dict (parameters and buffers).This is intended to allow functionality likehttps://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageA smoothed version of the weights is necessary for some training schemes to perform well.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""def __init__(self, model, decay=0.9999, updates=0):# Create EMAself.ema = deepcopy(model.module if is_parallel(model) else model).eval()self.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / 2000))for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parameterswith torch.no_grad():self.updates += 1d = self.decay(self.updates)msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

ModelEMA类的__init__ 函数介绍

__init__ 函数的输入参数介绍

  • model:需要使用EMA策略更新参数的模型
  • decay:加权权重,默认为0.9999
  • updates:模型参数更新/迭代次数

__init__ 函数的初始化介绍

首先深拷贝一份模型

"""
创建EMA模型model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

接着,初始化updates次数,若是从头开始训练,则该参数为0

self.updates = updates

最后,定义加权权重decay的计算公式(这里呈指数型变化),

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA类的update()函数介绍

如果调用该函数,则更新updates以及decay,

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

取出当前模型的参数,为更新EMA模型的参数做准备,

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,

for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()

参考文章

【代码解读】在pytorch中使用EMA - 知乎

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

以史为鉴!EMA在机器学习中的应用 - 知乎

这篇关于Pytorch中的Exponential Moving Average(EMA)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/211675

相关文章

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确