Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读

2023-12-22 13:36

本文主要是介绍Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

torch.nn.utils.clip_grad_norm_函数主要作用:

  神经网络深度逐渐增加,网络参数量增多的时候,容易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁torch.nn.utils.clip_grad_norm_(),即设置一个梯度大小的上限

注:旧版为torch.nn.utils.clip_grad_norm()

函数参数:

官网链接:https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None)

“Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.”

“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”

Parameters

  • parameters (Iterable[Tensor] or Tensor) – 实施梯度裁剪的可迭代网络参数
    an iterable of Tensors or a single Tensor that will have gradients normalized(一个由张量或单个张量组成的可迭代对象(模型参数),将梯度归一化)

  • max_norm (float) – 该组网络参数梯度的范数上限
    max norm of the gradients(梯度的最大值)

  • norm_type (float) –范数类型
    type of the used p-norm. Can be ‘inf’ for infinity norm.(所使用的范数类型。默认为L2范数,可以是无穷大范数(‘inf’))

  • error_if_nonfinite (bool)
    if True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf. Default: False (will switch to True in the future)

  • foreach (bool)
    use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: None

源码解读:

参考:https://blog.csdn.net/Mikeyboi/article/details/119522689
(建议大家看看源码,更好理解函数意义,有注释)

def clip_grad_norm_(parameters, max_norm, norm_type=2):# 处理传入的三个参数。# 首先将parameters中的非空网络参数存入一个列表,# 然后将max_norm和norm_type类型强制为浮点数。if isinstance(parameters, torch.Tensor):parameters = [parameters]parameters = list(filter(lambda p: p.grad is not None, parameters))max_norm = float(max_norm)norm_type = float(norm_type)#对无穷范数进行了单独计算,即取所有网络参数梯度范数中的最大值,定义为total_normif norm_type == inf:total_norm = max(p.grad.data.abs().max() for p in parameters)# 对于其他范数,计算所有网络参数梯度范数之和,再归一化,# 即等价于把所有网络参数放入一个向量,再对向量计算范数。将结果定义为total_normelse:total_norm = 0for p in parameters:param_norm = p.grad.data.norm(norm_type)total_norm += param_norm.item() ** norm_type # norm_type=2 求平方(二范数)total_norm = total_norm ** (1. / norm_type) # norm_type=2 等价于 开根号# 最后定义了一个“裁剪系数”变量clip_coef,为传入参数max_norm和total_norm的比值(+1e-6防止分母为0的情况)。# 如果max_norm > total_norm,即没有溢出预设上限,则不对梯度进行修改。# 反之则以clip_coef为系数对全部梯度进行惩罚,使最后的全部梯度范数归一化至max_norm的值。# 注意该方法返回了一个 total_norm,实际应用时可以通过该方法得到网络参数梯度的范数,以便确定合理的max_norm值。clip_coef = max_norm / (total_norm + 1e-6)if clip_coef < 1:for p in parameters:p.grad.data.mul_(clip_coef)return total_norm

使用方法及分析:

应用逻辑为:

  1. 先计算梯度;
  2. 裁剪梯度(在函数内部会判断是否需要裁剪,具体看源码解读);
  3. 最后更新网络参数。

因此 torch.nn.utils.clip_grad_norm_() 的使用应该在loss.backward() 之后,optimizer.step() 之前,

在U-Net中如下:

optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
grad_scaler.step(optimizer)
grad_scaler.update()

参考:https://blog.csdn.net/zhaohongfei_358/article/details/122820992

注意:

  • 从上面文章可以看到,clip_grad_norm 最后就是对所有的梯度乘以一个 clip_coefp.grad.data.mul_(clip_coef)),而且乘的前提是clip_coef一定是小于1的,所以,clip_grad_norm 只解决梯度爆炸问题,不解决梯度消失问题
  • clip_coef的定义**clip_coef = max_norm / (total_norm + 1e-6)** 可以知道:max_norm越大,对于梯度爆炸的解决越柔和,max_norm越小,对梯度爆炸的解决越狠

这篇关于Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/524137

相关文章

Nacos注册中心和配置中心的底层原理全面解读

《Nacos注册中心和配置中心的底层原理全面解读》:本文主要介绍Nacos注册中心和配置中心的底层原理的全面解读,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录临时实例和永久实例为什么 Nacos 要将服务实例分为临时实例和永久实例?1.x 版本和2.x版本的区别

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

MySQL的ALTER TABLE命令的使用解读

《MySQL的ALTERTABLE命令的使用解读》:本文主要介绍MySQL的ALTERTABLE命令的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、查看所建表的编China编程码格式2、修改表的编码格式3、修改列队数据类型4、添加列5、修改列的位置5.1、把列

Linux CPU飙升排查五步法解读

《LinuxCPU飙升排查五步法解读》:本文主要介绍LinuxCPU飙升排查五步法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录排查思路-五步法1. top命令定位应用进程pid2.php top-Hp[pid]定位应用进程对应的线程tid3. printf"%

解读@ConfigurationProperties和@value的区别

《解读@ConfigurationProperties和@value的区别》:本文主要介绍@ConfigurationProperties和@value的区别及说明,具有很好的参考价值,希望对大家... 目录1. 功能对比2. 使用场景对比@ConfigurationProperties@Value3. 核

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

Jupyter notebook安装步骤解读

《Jupyternotebook安装步骤解读》:本文主要介绍Jupyternotebook安装步骤,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、开始安装二、更改打开文件位置和快捷启动方式总结在安装Jupyter notebook 之前,确认您已安装pytho