PyTorch-Lightning:trining_step的自动优化

2024-04-12 09:44

本文主要是介绍PyTorch-Lightning:trining_step的自动优化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • PyTorch-Lightning:trining_step的自动优化
      • 总结:
    • class _ AutomaticOptimization()
      • def run
      • def _make_closure
      • def _training_step
        • class ClosureResult():
          • def from_training_step_output
      • class Closure

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

在这里插入图片描述

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:closure = self._make_closure(kwargs, optimizer, batch_idx)if (# when the strategy handles accumulation, we want to always call the optimizer stepnot self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()):# For gradient accumulation# -------------------# calculate loss (train step + train step end)# -------------------# automatic_optimization=True: perform ddp sync only when performing optimizer_stepwith _block_parallel_sync_behavior(self.trainer.strategy, block=True):closure()# ------------------------------# BACKWARD PASS# ------------------------------# gradient update with accumulated gradientselse:self._optimizer_step(batch_idx, closure)result = closure.consume_result()if result.loss is None:return {}return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:step_fn = self._make_step_fn(kwargs)backward_fn = self._make_backward_fn(optimizer)zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]loss: Optional[Tensor] = field(init=False, default=None)extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数

class Closure(AbstractClosure[ClosureResult]):warning_cache = WarningCache()def __init__(self,step_fn: Callable[[], ClosureResult],backward_fn: Optional[Callable[[Tensor], None]] = None,zero_grad_fn: Optional[Callable[[], None]] = None,):super().__init__()self._step_fn = step_fnself._backward_fn = backward_fnself._zero_grad_fn = zero_grad_fn@override@torch.enable_grad()def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:step_output = self._step_fn()if step_output.closure_loss is None:self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")if self._zero_grad_fn is not None:self._zero_grad_fn()if self._backward_fn is not None and step_output.closure_loss is not None:self._backward_fn(step_output.closure_loss)return step_output@overridedef __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:self._result = self.closure(*args, **kwargs)return self._result.loss

这篇关于PyTorch-Lightning:trining_step的自动优化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/896785

相关文章

JAVA实现Token自动续期机制的示例代码

《JAVA实现Token自动续期机制的示例代码》本文主要介绍了JAVA实现Token自动续期机制的示例代码,通过动态调整会话生命周期平衡安全性与用户体验,解决固定有效期Token带来的风险与不便,感兴... 目录1. 固定有效期Token的内在局限性2. 自动续期机制:兼顾安全与体验的解决方案3. 总结PS

linux部署NFS和autofs自动挂载实现过程

《linux部署NFS和autofs自动挂载实现过程》文章介绍了NFS(网络文件系统)和Autofs的原理与配置,NFS通过RPC实现跨系统文件共享,需配置/etc/exports和nfs.conf,... 目录(一)NFS1. 什么是NFS2.NFS守护进程3.RPC服务4. 原理5. 部署5.1安装NF

Docker多阶段镜像构建与缓存利用性能优化实践指南

《Docker多阶段镜像构建与缓存利用性能优化实践指南》这篇文章将从原理层面深入解析Docker多阶段构建与缓存机制,结合实际项目示例,说明如何有效利用构建缓存,组织镜像层次,最大化提升构建速度并减少... 目录一、技术背景与应用场景二、核心原理深入分析三、关键 dockerfile 解读3.1 Docke

MyBatis Plus实现时间字段自动填充的完整方案

《MyBatisPlus实现时间字段自动填充的完整方案》在日常开发中,我们经常需要记录数据的创建时间和更新时间,传统的做法是在每次插入或更新操作时手动设置这些时间字段,这种方式不仅繁琐,还容易遗漏,... 目录前言解决目标技术栈实现步骤1. 实体类注解配置2. 创建元数据处理器3. 服务层代码优化填充机制详

深入浅出Spring中的@Autowired自动注入的工作原理及实践应用

《深入浅出Spring中的@Autowired自动注入的工作原理及实践应用》在Spring框架的学习旅程中,@Autowired无疑是一个高频出现却又让初学者头疼的注解,它看似简单,却蕴含着Sprin... 目录深入浅出Spring中的@Autowired:自动注入的奥秘什么是依赖注入?@Autowired

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Python实战之SEO优化自动化工具开发指南

《Python实战之SEO优化自动化工具开发指南》在数字化营销时代,搜索引擎优化(SEO)已成为网站获取流量的重要手段,本文将带您使用Python开发一套完整的SEO自动化工具,需要的可以了解下... 目录前言项目概述技术栈选择核心模块实现1. 关键词研究模块2. 网站技术seo检测模块3. 内容优化分析模

Java实现复杂查询优化的7个技巧小结

《Java实现复杂查询优化的7个技巧小结》在Java项目中,复杂查询是开发者面临的“硬骨头”,本文将通过7个实战技巧,结合代码示例和性能对比,手把手教你如何让复杂查询变得优雅,大家可以根据需求进行选择... 目录一、复杂查询的痛点:为何你的代码“又臭又长”1.1冗余变量与中间状态1.2重复查询与性能陷阱1.

Python内存优化的实战技巧分享

《Python内存优化的实战技巧分享》Python作为一门解释型语言,虽然在开发效率上有着显著优势,但在执行效率方面往往被诟病,然而,通过合理的内存优化策略,我们可以让Python程序的运行速度提升3... 目录前言python内存管理机制引用计数机制垃圾回收机制内存泄漏的常见原因1. 循环引用2. 全局变

基于Redis自动过期的流处理暂停机制

《基于Redis自动过期的流处理暂停机制》基于Redis自动过期的流处理暂停机制是一种高效、可靠且易于实现的解决方案,防止延时过大的数据影响实时处理自动恢复处理,以避免积压的数据影响实时性,下面就来详... 目录核心思路代码实现1. 初始化Redis连接和键前缀2. 接收数据时检查暂停状态3. 检测到延时过