【MetaLearning】有关Pytorch的元学习库higher的基本用法

2023-11-22 05:04

本文主要是介绍【MetaLearning】有关Pytorch的元学习库higher的基本用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【MetaLearning】有关Pytorch的元学习库higher的基本用法

文章目录

  • 【MetaLearning】有关Pytorch的元学习库higher的基本用法
    • 1. 基本介绍
    • 2. Toy Example
    • Reference

1. 基本介绍

higher.innerloop_ctxhigher库的上下文管理器,用于创建内部循环(inner loop)的上下文,内部循环通常用于元学习场景,其中在模型参数更新的内部循环中进行一些额外的操作。

这个上下文管理器主要有五个参数:(详细请参考官方库说明)

higher.innerloop_ctx(model, opt, device=None, copy_initial_weights=True, override=None, track_higher_grads=True)
  • 第一个参数model是需要进行内部循环的模型,通常是你的元模型
  • 第二个参数opt是优化器,这是你用来更新模型参数的优化器
  • 第三个参数copy_initial_weights是一个布尔值,用于指定是否在每个内部循环之前复制初始权重,如果设置为True则表示在每个内部循环之前都会将模型的初始权重进行复制,以确保每个内部循环都从相同的初始权重开始。如果设置为False,则所有的内部循环共享相同的权重模型。
  • 第四个参数override是一个字典,例如override={'lr':lr_tensor, "momentum': momentum_tensor},用于指定在内部循环期间覆盖优化器的参数,比如在这里示例中,lr_tensormomentum_tensor是张量,用于指定内部循环期间覆盖的学习率和动量。
  • 第五个参数track_higher_grads是一个布尔值,用于跟踪更高阶的梯度,如果是True,则内部循环中计算的梯度将被跟踪以支持高阶的梯度计算,如果设置为False,则不会跟踪高阶梯度。

with语句块中,通过(fmodel, diffopt)获取内部循环的上下文。fmodel表示内部循环中的模型,diffopt表示内部循环中的优化器,在这个上下文中,你可以执行内部循环的计算和参数更新。

下面给出一个基本的使用示例,演示如何使用higher.innerloop_ctx,使用higher库需要习惯下列的转变

从通常使用pytorch的用法

model = MyModel()
opt = torch.optim.Adam(model.parameters())for xs, ys in data:opt.zero_grad()logits = model(xs)loss = loss_function(logits, ys)loss.backward()opt.step()

转变到

model = MyModel()
opt = torch.optim.Adam(model.parameters())with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):for xs, ys in data:logits = fmodel(xs)  # modified `params` can also be passed as a kwargloss = loss_function(logits, ys)  # no need to call loss.backwards()diffopt.step(loss)  # note that `step` must take `loss` as an argument!,这一步相当于使用了loss.backward()和opt.step()# At the end of your inner loop you can obtain these e.g. ...grad_of_grads = torch.autograd.grad(meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))

训练模型和执行diffopt.step 来更新fmodel之间的区别在于,fmodel不会像原始部分中的opt.step()那样就地更新参数。 相反,每次调用 diffopt.step时都会以这样的方式创建新版本的参数,即fmodel将在下一步中使用新的参数,但所有以前的参数仍会保留。

运行的原理是什么呢?举个例子,fmodelfmodel.parameters(time=0)开始迭代(这里的time=0表示就是第0次迭代),当我们调用diffopt.stepN次之后,我们可以使用fmodel.parameters(time=i)来访问,其中i可以取到1N,并且我们仍然可以访问fmodel.parameters(time=0),这个结果和迭代之前是一样的,这是为什么呢?

因为fmodel的创建依赖于参数copy_initial_weights,如果copy_initial_weights=True,那么fmodel.parameters(time=0)是从原模型clone’d别且是detach’ed(即是从原模型克隆过来并且进行分离计算图了),如果copy_initial_weights=False,那么只是进行了clone’d并没有detach‘ed。

放一段原文在这里方便大家理解

I.e. fmodel starts with only fmodel.parameters(time=0) available, but after you called diffopt.step N times you can ask fmodel to give you fmodel.parameters(time=i) for any i up to N inclusive. Notice that fmodel.parameters(time=0) doesn’t change in this process at all, just every time fmodel is applied to some input it will use the latest version of parameters it currently has.

Now, what exactly is fmodel.parameters(time=0)? It is created here and depends on copy_initial_weights. If copy_initial_weights==True then fmodel.parameters(time=0) are clone’d and detach’ed parameters of model. Otherwise they are only clone’d, but not detach’ed!

That means that when we do meta-optimization step, the original model’s parameters will actually accumulate gradients if and only if copy_initial_weights==False. And in MAML we want to optimize model’s starting weights so we actually do need to get gradients from meta-optimization step.

2. Toy Example

import torch
import torch.nn as nn
import torch.optim as optim
import higher
import numpy as npnp.random.seed(1)
torch.manual_seed(3)
N = 100
actual_multiplier = 3.5
meta_lr = 0.00001
loops = 5 # how many iterations in the inner loop we want to dox = torch.tensor(np.random.random((N,1)), dtype=torch.float64) # features for inner training loop
y = x * actual_multiplier # target for inner training loop
model = nn.Linear(1, 1, bias=False).double() # simplest possible model - multiple input x by weight w without bias
meta_opt = optim.SGD(model.parameters(), lr=meta_lr, momentum=0.)def run_inner_loop_once(model, verbose, copy_initial_weights):lr_tensor = torch.tensor([0.3], requires_grad=True)momentum_tensor = torch.tensor([0.5], requires_grad=True)opt = optim.SGD(model.parameters(), lr=0.3, momentum=0.5)with higher.innerloop_ctx(model, opt, copy_initial_weights=copy_initial_weights, override={'lr': lr_tensor, 'momentum': momentum_tensor}) as (fmodel, diffopt):for j in range(loops):if verbose:print('Starting inner loop step j=={0}'.format(j))print('    Representation of fmodel.parameters(time={0}): {1}'.format(j, str(list(fmodel.parameters(time=j)))))print('    Notice that fmodel.parameters() is same as fmodel.parameters(time={0}): {1}'.format(j, (list(fmodel.parameters())[0] is list(fmodel.parameters(time=j))[0])))out = fmodel(x)if verbose:print('    Notice how `out` is `x` multiplied by the latest version of weight: {0:.4} * {1:.4} == {2:.4}'.format(x[0,0].item(), list(fmodel.parameters())[0].item(), out[0].item()))loss = ((out - y)**2).mean()diffopt.step(loss)if verbose:# after all inner training let's see all steps' parameter tensorsprint()print("Let's print all intermediate parameters versions after inner loop is done:")for j in range(loops+1):print('    For j=={0} parameter is: {1}'.format(j, str(list(fmodel.parameters(time=j)))))print()# let's imagine now that our meta-learning optimization is trying to check how far we got in the end from the actual_multiplierweight_learned_after_full_inner_loop = list(fmodel.parameters())[0]meta_loss = (weight_learned_after_full_inner_loop - actual_multiplier)**2print('  Final meta-loss: {0}'.format(meta_loss.item()))meta_loss.backward() # will only propagate gradient to original model parameter's `grad` if copy_initial_weight=Falseif verbose:print('  Gradient of final loss we got for lr and momentum: {0} and {1}'.format(lr_tensor.grad, momentum_tensor.grad))print('  If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smaller')return meta_loss.item()print('=================== Run Inner Loop First Time (copy_initial_weights=True) =================\n')
meta_loss_val1 = run_inner_loop_once(model, verbose=True, copy_initial_weights=True)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))print('=================== Run Inner Loop Second Time (copy_initial_weights=False) =================\n')
meta_loss_val2 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))print('=================== Run Inner Loop Third Time (copy_initial_weights=False) =================\n')
final_meta_gradient = list(model.parameters())[0].grad.item()
# Now let's double-check `higher` library is actually doing what it promised to do, not just giving us
# a bunch of hand-wavy statements and difficult to read code.
# We will do a simple SGD step using meta_opt changing initial weight for the training and see how meta loss changed
meta_opt.step()
meta_opt.zero_grad()
meta_step = - meta_lr * final_meta_gradient # how much meta_opt actually shifted inital weight value
# before we run inner loop third time, we update the meta parameter firstly.
meta_loss_val3 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)meta_loss_gradient_approximation = (meta_loss_val3 - meta_loss_val2) / meta_stepprint()
print('Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: {0:.4} VS {1:.4}'.format(meta_loss_gradient_approximation, final_meta_gradient))

结果如下

=================== Run Inner Loop First Time (copy_initial_weights=True) =================Starting inner loop step j==0Representation of fmodel.parameters(time=0): [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]Notice that fmodel.parameters() is same as fmodel.parameters(time=0): TrueNotice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.9915 == -0.4135
Starting inner loop step j==1Representation of fmodel.parameters(time=1): [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]Notice that fmodel.parameters() is same as fmodel.parameters(time=1): TrueNotice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.1217 == -0.05075
Starting inner loop step j==2Representation of fmodel.parameters(time=2): [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]Notice that fmodel.parameters() is same as fmodel.parameters(time=2): TrueNotice how `out` is `x` multiplied by the latest version of weight: 0.417 * 1.015 == 0.4231
Starting inner loop step j==3Representation of fmodel.parameters(time=3): [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]Notice that fmodel.parameters() is same as fmodel.parameters(time=3): TrueNotice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.064 == 0.8607
Starting inner loop step j==4Representation of fmodel.parameters(time=4): [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]Notice that fmodel.parameters() is same as fmodel.parameters(time=4): TrueNotice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.867 == 1.196Let's print all intermediate parameters versions after inner loop is done:For j==0 parameter is: [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]For j==1 parameter is: [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]For j==2 parameter is: [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]For j==3 parameter is: [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]For j==4 parameter is: [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]For j==5 parameter is: [tensor([[3.3908]], dtype=torch.float64, grad_fn=<AddBackward0>)]Final meta-loss: 0.011927987982895929Gradient of final loss we got for lr and momentum: tensor([-1.6295]) and tensor([-0.9496])If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smallerLet's see if we got any gradient for initial model parameters: None=================== Run Inner Loop Second Time (copy_initial_weights=False) =================Final meta-loss: 0.011927987982895929Let's see if we got any gradient for initial model parameters: tensor([[-0.0053]], dtype=torch.float64)=================== Run Inner Loop Third Time (copy_initial_weights=False) =================Final meta-loss: 0.01192798770078706Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: -0.005311 VS -0.005311

Reference

Parper: Generalized Inner Loop Meta-Learning
What does the copy_initial_weights documentation mean in the higher library for Pytorch?

这篇关于【MetaLearning】有关Pytorch的元学习库higher的基本用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/407707

相关文章

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

MySQL数据库中ENUM的用法是什么详解

《MySQL数据库中ENUM的用法是什么详解》ENUM是一个字符串对象,用于指定一组预定义的值,并可在创建表时使用,下面:本文主要介绍MySQL数据库中ENUM的用法是什么的相关资料,文中通过代码... 目录mysql 中 ENUM 的用法一、ENUM 的定义与语法二、ENUM 的特点三、ENUM 的用法1

JavaSE正则表达式用法总结大全

《JavaSE正则表达式用法总结大全》正则表达式就是由一些特定的字符组成,代表的是一个规则,:本文主要介绍JavaSE正则表达式用法的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录常用的正则表达式匹配符正则表China编程达式常用的类Pattern类Matcher类PatternSynta

Go语言数据库编程GORM 的基本使用详解

《Go语言数据库编程GORM的基本使用详解》GORM是Go语言流行的ORM框架,封装database/sql,支持自动迁移、关联、事务等,提供CRUD、条件查询、钩子函数、日志等功能,简化数据库操作... 目录一、安装与初始化1. 安装 GORM 及数据库驱动2. 建立数据库连接二、定义模型结构体三、自动迁

MySQL之InnoDB存储引擎中的索引用法及说明

《MySQL之InnoDB存储引擎中的索引用法及说明》:本文主要介绍MySQL之InnoDB存储引擎中的索引用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录1、背景2、准备3、正篇【1】存储用户记录的数据页【2】存储目录项记录的数据页【3】聚簇索引【4】二

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧