pytorch Finetune和各层定制学习率

2024-06-17 15:48

本文主要是介绍pytorch Finetune和各层定制学习率,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

一、Finetune之权值初始化
第一步:保存模型参数
第二步:加载模型
第三步:初始化
二、不同层设置不同的学习率
补充:
我们知道一个良好的权值初始化,可以使收敛速度加快,甚至可以获得更好的精度。而在实际应用中,我们通常采用一个已经训练模型的模型的权值参数作为我们模型的初始化参数,也称之为Finetune,更宽泛的称之为迁移学习。迁移学习中的Finetune技术,本质上就是让我们新构建的模型,拥有一个较好的权值初始值。

finetune权值初始化三步曲,finetune就相当于给模型进行初始化,其流程共用三步:

第一步:保存模型,拥有一个预训练模型;
第二步:加载模型,把预训练模型中的权值取出来;
第三步:初始化,将权值对应的“放”到新模型中

一、Finetune之权值初始化

在进行finetune之前我们需要拥有一个模型或者是模型参数,因此需要了解如何保存模型。官方文档中介绍了两种保存模型的方法,一种是保存整个模型,另外一种是仅保存模型参数(官方推荐用这种方法),这里采用官方推荐的方法。

第一步:保存模型参数

若拥有模型参数,可跳过这一步。
假设创建了一个net = Net(),并且经过训练,通过以下方式保存:
torch.save(net.state_dict(), 'net_params.pkl')


第二步:加载模型

进行三步曲中的第二步,加载模型,这里只是加载模型的参数:
pretrained_dict = torch.load('net_params.pkl')
第三步:初始化

进行三步曲中的第三步,将取到的权值,对应的放到新模型中:
首先我们创建新模型,并且获取新模型的参数字典net_state_dict:
net = Net() # 创建net
net_state_dict = net.state_dict() # 获取已创建net的state_dict

接着将pretrained_dict里不属于net_state_dict的键剔除掉:
pretrained_dict_1 =  {k: v for k, v in pretrained_dict.items() if k in net_state_dict}  

然后,用预训练模型的参数字典 对 新模型的参数字典net_state_dict 进行更新:
net_state_dict.update(pretrained_dict_1)

最后,将更新了参数的字典 “放”回到网络中:
net.load_state_dict(net_state_dict)

这样,利用预训练模型参数对新模型的权值进行初始化过程就做完了。

采用finetune的训练过程中,有时候希望前面层的学习率低一些,改变不要太大,而后面的全连接层的学习率相对大一些。这时就需要对不同的层设置不同的学习率,下面就介绍如何为不同层配置不同的学习率。

二、不同层设置不同的学习率

在利用pre-trained model的参数做初始化之后,我们可能想让fc层更新相对快一些,而希望前面的权值更新小一些,这就可以通过为不同的层设置不同的学习率来达到此目的。

为不同层设置不同的学习率,主要通过优化器对多个参数组进行设置不同的参数。所以,只需要将原始的参数组,划分成两个,甚至更多的参数组,然后分别进行设置学习率。
这里将原始参数“切分”成fc3层参数和其余参数,为fc3层设置更大的学习率。

请看代码:

ignored_params = list(map(id, net.fc3.parameters())) # 返回的是parameters的 内存地址
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 
optimizer = optim.SGD([
{'params': base_params},
{'params': net.fc3.parameters(), 'lr': 0.001*10}], 0.001, momentum=0.9, weight_decay=1e-4)

第一行+ 第二行的意思就是,将fc3层的参数net.fc3.parameters()从原始参数net.parameters()中剥离出来
base_params就是剥离了fc3层的参数的其余参数,然后在优化器中为fc3层的参数单独设定学习率。

optimizer = optim.SGD(…)这里的意思就是 base_params中的层,用 0.001, momentum=0.9, weight_decay=1e-4
fc3层设定学习率为: 0.001*10

完整代码位于 https://github.com/tensor-yu/PyTorch_Tutorial/blob/master/Code/2_model/2_finetune.py

补充:

挑选出特定的层的机制是利用内存地址作为过滤条件,将需要单独设定的那部分参数,从总的参数中剔除。
base_params 是一个list,每个元素是一个Parameter 类
net.fc3.parameters() 是一个

ignored_params = list(map(id, net.fc3.parameters()))
net.fc3.parameters() 是一个<generator object parameters at 0x11b63bf00>
所以迭代的返回其中的parameter,这里有weight 和 bias
最终返回weight和bias所在内存的地址
 

 

https://blog.csdn.net/u011995719/article/details/85107310

https://blog.csdn.net/u011995719/article/details/85107310

https://blog.csdn.net/u011995719/article/details/85107310

 

这篇关于pytorch Finetune和各层定制学习率的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1069866

相关文章

苹果macOS 26 Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色

《苹果macOS26Tahoe主题功能大升级:可定制图标/高亮文本/文件夹颜色》在整体系统设计方面,macOS26采用了全新的玻璃质感视觉风格,应用于Dock栏、应用图标以及桌面小部件等多个界面... 科技媒体 MACRumors 昨日(6 月 13 日)发布博文,报道称在 macOS 26 Tahoe 中

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen