nnUNet 更改学习率和衰减优化器的方法

2023-11-06 11:36

本文主要是介绍nnUNet 更改学习率和衰减优化器的方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

此为记录贴,逻辑混乱 仅供参考:
勿喷
nnUNet默认的学习率衰减方法为线性衰减,优化器为SGD,在.\nnUNet\nnunetv2\training\nnUNetTrainer\nnUNetTrainer.py文件中nnUNetTrainer基类中定义 如下:

    def configure_optimizers(self):optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,momentum=0.99, nesterov=True)lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)return optimizer, lr_scheduler

为了改变优化器和学习率衰减方法:
我们可以继承nnUNetTrainer类重写一个 nnUNetTrainerCosAnneal类,当然nnUnet已经贴心的为我们写好了 在.\nnUNet\nnunetv2\training\nnUNetTrainer\variants\optimizer\nnUNetTrainerAdam
原始代码如下:

import torch
from torch.optim import Adam, AdamWfrom nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainerclass nnUNetTrainerAdam(nnUNetTrainer):def configure_optimizers(self):optimizer = AdamW(self.network.parameters(),lr=self.initial_lr,weight_decay=self.weight_decay,amsgrad=True)# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,#                             momentum=0.99, nesterov=True)lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)return optimizer, lr_scheduler

如果按照上一篇博客的方法直接更改训练方法为nnUNetTrainerAdam的话,会弹出如下警告:

 UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1
.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first 
value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-ratewarnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`.

警告已经说的很明白了,就不翻译了,为了避免不能在训练的时候调整学习率,我们需要去改变lr_scheduler.step()optimizer.step() 调用顺序。就要在重写on_train_epoch_starttrain_step函数
下列文件可以作为参考:
要修改优化器也可以直接在
optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True)
更改即可

from torch.optim.lr_scheduler import CosineAnnealingLR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import *class nnUNetTrainerCosAnneal(nnUNetTrainer):def configure_optimizers(self):optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,momentum=0.99, nesterov=True)lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs,eta_min=1e-4)return optimizer, lr_schedulerdef on_train_epoch_start(self):self.network.train()# self.lr_scheduler.step() #don't need call lr_scheduler.step() in this functionself.print_to_log_file('')self.print_to_log_file(f'Epoch {self.current_epoch}')self.print_to_log_file(f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")# lrs are the same for all workers so we don't need to gather them in case of DDP trainingself.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)def train_step(self, batch: dict) -> dict:data = batch['data']target = batch['target']data = data.to(self.device, non_blocking=True)if isinstance(target, list):target = [i.to(self.device, non_blocking=True) for i in target]else:target = target.to(self.device, non_blocking=True)self.optimizer.zero_grad(set_to_none=True)# Autocast is a little bitch.# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)# So autocast will only be active if we have a cuda device.with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():output = self.network(data)# del datal = self.loss(output, target)if self.grad_scaler is not None:self.grad_scaler.scale(l).backward()self.grad_scaler.unscale_(self.optimizer)torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)self.grad_scaler.step(self.optimizer)self.grad_scaler.update()else:l.backward()torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)self.optimizer.step()self.lr_scheduler.step()## add lr_scheduler.step() after optimizer.step()return {'loss': l.detach().cpu().numpy()}

要使用这个类进行训练,运行以下命令即可:

nnUNetV2_train 002 2d 0 -tr nnUNetTrainerCosAnneal

记录完毕,继续炼丹

这篇关于nnUNet 更改学习率和衰减优化器的方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/356375

相关文章

python获取指定名字的程序的文件路径的两种方法

《python获取指定名字的程序的文件路径的两种方法》本文主要介绍了python获取指定名字的程序的文件路径的两种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 最近在做项目,需要用到给定一个程序名字就可以自动获取到这个程序在Windows系统下的绝对路径,以下

JavaScript中的高级调试方法全攻略指南

《JavaScript中的高级调试方法全攻略指南》什么是高级JavaScript调试技巧,它比console.log有何优势,如何使用断点调试定位问题,通过本文,我们将深入解答这些问题,带您从理论到实... 目录观点与案例结合观点1观点2观点3观点4观点5高级调试技巧详解实战案例断点调试:定位变量错误性能分

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法

《JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法》:本文主要介绍JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法,每种方法结合实例代码给大家介绍的非常... 目录引言:为什么"相等"判断如此重要?方法1:使用some()+includes()(适合小数组)方法2

504 Gateway Timeout网关超时的根源及完美解决方法

《504GatewayTimeout网关超时的根源及完美解决方法》在日常开发和运维过程中,504GatewayTimeout错误是常见的网络问题之一,尤其是在使用反向代理(如Nginx)或... 目录引言为什么会出现 504 错误?1. 探索 504 Gateway Timeout 错误的根源 1.1 后端

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Python实战之SEO优化自动化工具开发指南

《Python实战之SEO优化自动化工具开发指南》在数字化营销时代,搜索引擎优化(SEO)已成为网站获取流量的重要手段,本文将带您使用Python开发一套完整的SEO自动化工具,需要的可以了解下... 目录前言项目概述技术栈选择核心模块实现1. 关键词研究模块2. 网站技术seo检测模块3. 内容优化分析模

Java实现复杂查询优化的7个技巧小结

《Java实现复杂查询优化的7个技巧小结》在Java项目中,复杂查询是开发者面临的“硬骨头”,本文将通过7个实战技巧,结合代码示例和性能对比,手把手教你如何让复杂查询变得优雅,大家可以根据需求进行选择... 目录一、复杂查询的痛点:为何你的代码“又臭又长”1.1冗余变量与中间状态1.2重复查询与性能陷阱1.

Python内存优化的实战技巧分享

《Python内存优化的实战技巧分享》Python作为一门解释型语言,虽然在开发效率上有着显著优势,但在执行效率方面往往被诟病,然而,通过合理的内存优化策略,我们可以让Python程序的运行速度提升3... 目录前言python内存管理机制引用计数机制垃圾回收机制内存泄漏的常见原因1. 循环引用2. 全局变

MySQL 表空却 ibd 文件过大的问题及解决方法

《MySQL表空却ibd文件过大的问题及解决方法》本文给大家介绍MySQL表空却ibd文件过大的问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录一、问题背景:表空却 “吃满” 磁盘的怪事二、问题复现:一步步编程还原异常场景1. 准备测试源表与数据