PyTorch Demo-3 : 动态调整学习率

2024-09-05 01:38

本文主要是介绍PyTorch Demo-3 : 动态调整学习率,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

# 一些必要的库和参数
import torch
import torch.nn as nn
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
以SGD为例
model = models.resnet18()init_lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), init_lr)
# 查看学习率
for param_group in optimizer.param_groups:print(param_group['lr'])
# 0.1
1. 官方例子里是如下的自定义函数方式,以最常用的调整策略StepLR为例,每隔一定轮数进行改变
# Reference:https://github.com/pytorch/examples/blob/master/imagenet/main.py
# 每30轮学习率乘以0.1
def adjust_learning_rate(optimizer, epoch, init_lr):"""optimizer: 优化器epoch: 训练轮数,也可以根据需要加入其它参数init_lr:初始学习率,也可以设置为全局变量"""lr = init_lr * (0.1 ** (epoch // 30))for param_group in optimizer.param_groups:param_group['lr'] = lr
total_epoch = 100
lrs = []
# 每一轮调用函数即可
for epoch in range(total_epoch):adjust_learning_rate(optimizer, epoch, init_lr)lrs.append(optimizer.param_groups[0]['lr'])plt.plot(range(total_epoch), lrs)
plt.title('adjustLR')
plt.savefig('adjustLR.jpg', bbox_inches='tight')

adjustLR

这种方法可以很方便根据自己的逻辑获得想要的学习率变化策略,可以很复杂,也可以很简单。

2. lr_scheduler

PyTorch中提供了多种预设的学习率策略,都包含在torch.optim.lr_scheduler ,详细见 Docs 。

同理,以 StepLR 为例

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
total_epoch = 100
lrs = []
for epoch in range(total_epoch):lrs.append(optimizer.param_groups[0]['lr'])# 调用step()即更新学习率scheduler.step()
plt.plot(range(total_epoch), lrs)
plt.title('StepLR')

在这里插入图片描述
可以看到,两种方式的StepLR效果是一样的。

3. 带warmup的学习率调整
3.1 自定义函数
def adjust_learning_rate(optimizer, warm_up_step, epoch, init_lr):if epoch < warm_up_step:lr = (epoch + 1) / warm_up_step * init_lrelse:lr = init_lr * (0.1 ** (epoch // 30))for param_group in optimizer.param_groups:param_group['lr'] = lr

在这里插入图片描述

3.2 LambdaLR

lr_scheduler中也有一项自定义的学习率调整方法,通过构造匿名函数来实现

lambda_ = lambda epoch: (epoch + 1) / warm_up_step if epoch < warm_up_step else 0.1 ** (epoch // 30)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)

在这里插入图片描述

需要注意的是,自定义函数的方式是直接对优化器中的学习率赋值,而LambdaLR是学习率的权重!
3.3 余弦变换

余弦变换也是常用的学习率调整策略之一,跟steplr可以达到差不多的效果,但是从训练图像上看会更平稳一些。

lambda_ = lambda epoch: (epoch + 1) / warm_up_step if epoch < warm_up_step else 0.5 * (np.cos((epoch - warm_up_step) / (total_epoch - warm_up_step) * np.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)

在这里插入图片描述

4. Other

其他策略如:余弦退火,指数变换,正弦变换,学习率重启等。

这篇关于PyTorch Demo-3 : 动态调整学习率的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1137608

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Java调用C#动态库的三种方法详解

《Java调用C#动态库的三种方法详解》在这个多语言编程的时代,Java和C#就像两位才华横溢的舞者,各自在不同的舞台上展现着独特的魅力,然而,当它们携手合作时,又会碰撞出怎样绚丽的火花呢?今天,我们... 目录方法1:C++/CLI搭建桥梁——Java ↔ C# 的“翻译官”步骤1:创建C#类库(.NET

MyBatis编写嵌套子查询的动态SQL实践详解

《MyBatis编写嵌套子查询的动态SQL实践详解》在Java生态中,MyBatis作为一款优秀的ORM框架,广泛应用于数据库操作,本文将深入探讨如何在MyBatis中编写嵌套子查询的动态SQL,并结... 目录一、Myhttp://www.chinasem.cnBATis动态SQL的核心优势1. 灵活性与可

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Mybatis嵌套子查询动态SQL编写实践

《Mybatis嵌套子查询动态SQL编写实践》:本文主要介绍Mybatis嵌套子查询动态SQL编写方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录前言一、实体类1、主类2、子类二、Mapper三、XML四、详解总结前言MyBATis的xml文件编写动态SQL

SpringBoot实现Kafka动态反序列化的完整代码

《SpringBoot实现Kafka动态反序列化的完整代码》在分布式系统中,Kafka作为高吞吐量的消息队列,常常需要处理来自不同主题(Topic)的异构数据,不同的业务场景可能要求对同一消费者组内的... 目录引言一、问题背景1.1 动态反序列化的需求1.2 常见问题二、动态反序列化的核心方案2.1 ht

golang实现动态路由的项目实践

《golang实现动态路由的项目实践》本文主要介绍了golang实现动态路由项目实践,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习... 目录一、动态路由1.结构体(数据库的定义)2.预加载preload3.添加关联的方法一、动态路由1

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不