pytorch 参数冻结 parameter-efficient fine-tuning

2024-08-27 08:12

本文主要是介绍pytorch 参数冻结 parameter-efficient fine-tuning,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目标:在网络中冻结部分参数进行高效训练

框架:pytorch (version 1.11.0)

基本实现

  1. 需要学习的参数requires_grad设置为True,冻结的设置为False
  2. 需要学习的参数要加到 optimizer的List中;对于冻结的参数,可以直接不加进去,(应该也可以加进去,但是requires_grad=False)

注意事项
3. 如果不传递参数的层,记得前向操作是要设置 with torch.no_grad,否则即便没有需要更新的参数,其layer的梯度也回传,效率低。

  1. 要保证所有参与前向的操作,都被用于计算loss。例如,a=self.layer(b),只要前向里出现了这个操作,就要保证a(或a的后续输出)要参与loss的计算。如果a算完了不用,是不可以的。(不论self.layer里是否有需要更新的参数)。ps:这点和不冻结设置下的要求不一样,如果所有参数都学,即便中间有一些变量操作是冗余的,也不会报错,只是增加计算代价而已。(比如,在clip框架里,如果不用text prompt, 就不要提取该特征)
  2. 要保证,所有需要更新的参数,都用于前向计算了。如何比较二者的参数,见下:

a. 记录需要梯度回传的参数:

grad_params = set()
for name, param in model.named_parameters():if param.requires_grad:grad_params.add(name)

b. 记录前向中使用的参数:

used_params = set()
def forward(self, x):for name, param in self.named_parameters():if param.requires_grad:param.register_hook(lambda grad, name=name: used_params.add(name))return self.model(x)

c. 比较二者差异

unused_params = grad_params - used_params
if unused_params:print("以下参数未在 forward 函数中使用:", unused_params)
else:print("所有需要计算梯度的参数都在 forward 函数中使用了。")

ps. 好像也可以通过在nn.parallel.DistributedDataParallel中设置find_unused_parameters=True来找到未使用的变量。(不过我没试过

这篇关于pytorch 参数冻结 parameter-efficient fine-tuning的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1111115

相关文章

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

SpringBoot请求参数接收控制指南分享

《SpringBoot请求参数接收控制指南分享》:本文主要介绍SpringBoot请求参数接收控制指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring Boot 请求参数接收控制指南1. 概述2. 有注解时参数接收方式对比3. 无注解时接收参数默认位置

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

Spring Boot项目部署命令java -jar的各种参数及作用详解

《SpringBoot项目部署命令java-jar的各种参数及作用详解》:本文主要介绍SpringBoot项目部署命令java-jar的各种参数及作用的相关资料,包括设置内存大小、垃圾回收... 目录前言一、基础命令结构二、常见的 Java 命令参数1. 设置内存大小2. 配置垃圾回收器3. 配置线程栈大小

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

SpringBoot利用@Validated注解优雅实现参数校验

《SpringBoot利用@Validated注解优雅实现参数校验》在开发Web应用时,用户输入的合法性校验是保障系统稳定性的基础,​SpringBoot的@Validated注解提供了一种更优雅的解... 目录​一、为什么需要参数校验二、Validated 的核心用法​1. 基础校验2. php分组校验3