Pytorch Ignite 使用方法

2024-01-16 15:48
文章标签 使用 方法 pytorch ignite

本文主要是介绍Pytorch Ignite 使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch Ignite 使用方法

下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html

Engine

该框架的本质是class Engine,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:

  while epoch < max_epochs:# run an epoch on datadata_iter = iter(data)while True:try:batch = next(data_iter)output = process_function(batch)iter_counter += 1except StopIteration:data_iter = iter(data)if iter_counter == epoch_length:break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。

例如,用于监督任务的模型训练器:

def train_step(trainer, batch):model.train()optimizer.zero_grad()x, y = prepare_batch(batch)y_pred = model(x)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()return loss.item()trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

训练步骤的输出类型(即在上面的示例中loss.item())不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output并且可以进一步用于任何类型的处理。

默认情况下,epoch_length 长度由len(data)定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。

trainer.run(data, max_epochs=100, epoch_length=200)

如果data是长度未知的有限数据迭代器(对于用户),epoch_length则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。

任何复杂度的训练逻辑都可以使用train_step方法进行编码,并且可以使用此方法来构造训练器。

函数batch中的train_step参数是用户定义的,可以包含单个迭代所需的任何数据。

# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
# 
criterion_1 = ...
criterion_2 = ...
# ...def train_step(trainer, batch):data_1 = batch["data_1"]data_2 = batch["data_2"]# ...model_1.train()optimizer_1.zero_grad()loss_1 = forward_pass(data_1, model_1, criterion_1)loss_1.backward()optimizer_1.step()# ...model_2.train()optimizer_2.zero_grad()loss_2 = forward_pass(data_2, model_2, criterion_2)loss_2.backward()optimizer_2.step()# ...# User can return any type of structure.# 用户可以返回任何类型的数据结构return {"loss_1": loss_1,"loss_2": loss_2,# ...}trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

事件和处理程序

为了提高Engine灵活性,引入了事件系统,该系统促进了运行的每个步骤之间的交互:

  • Engine启动/完成
  • epoch开始/完成
  • 批处理迭代已开始/已完成

有关事件的完整列表,请参见Events

用户可以执行自定义代码作为事件处理程序。处理程序可以是任何函数:例如lambda,简单函数,类方法等。第一个参数可以选择是engine,但不是必须的。

让我们更详细地考虑当run()被调用时发生的情况:

fire_event(Events.STARTED)
while epoch < max_epochs:fire_event(Events.EPOCH_STARTED)# run once on datafor batch in data:fire_event(Events.ITERATION_STARTED)output = process_function(batch)fire_event(Events.ITERATION_COMPLETED)fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

首先,将触发*“Engine已启动”事件并执行其所有事件处理程序(我们将在下一段中看到如何添加事件处理程序)。接下来,在启动循环并发生“epoch 开始”*事件时,等等。每次触发事件时,都会执行附加的处理程序。

使用方法add_event_handler()on()装饰器附加事件处理程序很简单:

trainer = Engine(update_model)trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():print("Another message of start training")# attach handler with args, kwargs
mydata = [1, 2, 3, 4]def on_training_ended(data):print(f"Training is ended. mydata={data}")trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

可以通过remove_event_handler()或通过RemovableEventHandle 返回的引用来分离事件处理程序add_event_handler()。这可用于将已配置的引擎重用于多个循环:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})def log_metrics(engine, title):print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):evaluator.run(train_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):evaluator.run(validation_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):evaluator.run(test_loader)trainer.run(train_loader, max_epochs=100)

还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():# do somethingdef custom_event_filter(engine, event):if event in [1, 2, 5, 10, 50, 100]:return Truereturn False@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):# do something on 1, 2, 5, 10, 50, 100 iterationstrainer.run(train_loader, max_epochs=100)

自定义事件

用户还可以定义自定义事件。用户定义的事件应继承于引擎EventEnumregister_events()在引擎中注册。

from ignite.engine import EventEnumclass CustomEvents(EventEnum):"""Custom events defined by user"""CUSTOM_STARTED = 'custom_started'CUSTOM_COMPLETED = 'custom_completed'engine.register_events(*CustomEvents)

这些事件可用于附加任何处理程序,并使用触发fire_event()

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):# do something@engine.on(Events.STARTED)
def fire_custom_events(engine):engine.fire_event(CustomEvents.CUSTOM_STARTED)

时间线和事件

在事件下方,一些典型的处理程序显示在时间轴上,以进行训练循环,并在每个时期后进行评估:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U3DJJvxX-1618109731932)(https://pytorch.org/ignite/_images/timeline_and_events.png)]

状态

Engine引入了一个状态来存储process_function,当前时期,迭代和其他有用信息的输出。每个都Engine包含一个State,其中包括以下内容:

  • engine.state.seed:要在每个数据“ epoch”处设置的种子。
  • engine.state.epoch:引擎已完成的纪元数。初始化为0,第一个时期为1。
  • engine.state.iteration:引擎已完成的迭代次数。初始化为0,第一次迭代为1。
  • engine.state.max_epochs:要运行的时期数。初始化为1。
  • engine.state.output:为定义的process_function的输出Engine。见下文。
  • 等等

其他属性可以在的文档中找到State

在下面的代码中,engine.state.output将存储批次损失。此输出用于打印每次迭代的损耗。

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def on_iteration_completed(engine):iteration = engine.state.iterationepoch = engine.state.epochloss = engine.state.outputprint(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

由于对process_function的输出没有限制,因此Ignite为其和提供了output_transform参数 。参数output_transform是用于将engine.state.output转换为预期用途的函数。在下面,我们将看到不同类型的engine.state.output以及如何对其进行转换。metrics``handlers

在下面的代码中,engine.state.output将已处理批次的loss,y_pred,y。如果要附加Accuracy到引擎,则需要output_transform来从engine.state.output获取y_pred和y 。让我们看看如何做到这一点:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item(), y_pred, ytrainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output[0]print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

与上面类似,但是这次process_function的输出是处理后的批次的字典 loss,y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式。见下文:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return {'loss': loss.item(),'y_pred': y_pred,'y': y}trainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output['loss']print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

指标

库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线 和 2)存储整个输出历史记录。

指标可以附加到Engine

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.attach(evaluator, "accuracy")state = evaluator.run(validation_data)print("Result:", state.metrics)
# > {"accuracy": 0.12345}

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

完整的指标列表和API可以在ignite.metrics模块中找到。

这篇关于Pytorch Ignite 使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/613089

相关文章

Python中反转字符串的常见方法小结

《Python中反转字符串的常见方法小结》在Python中,字符串对象没有内置的反转方法,然而,在实际开发中,我们经常会遇到需要反转字符串的场景,比如处理回文字符串、文本加密等,因此,掌握如何在Pyt... 目录python中反转字符串的方法技术背景实现步骤1. 使用切片2. 使用 reversed() 函

Python中将嵌套列表扁平化的多种实现方法

《Python中将嵌套列表扁平化的多种实现方法》在Python编程中,我们常常会遇到需要将嵌套列表(即列表中包含列表)转换为一个一维的扁平列表的需求,本文将给大家介绍了多种实现这一目标的方法,需要的朋... 目录python中将嵌套列表扁平化的方法技术背景实现步骤1. 使用嵌套列表推导式2. 使用itert

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

Python使用vllm处理多模态数据的预处理技巧

《Python使用vllm处理多模态数据的预处理技巧》本文深入探讨了在Python环境下使用vLLM处理多模态数据的预处理技巧,我们将从基础概念出发,详细讲解文本、图像、音频等多模态数据的预处理方法,... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核

Python使用pip工具实现包自动更新的多种方法

《Python使用pip工具实现包自动更新的多种方法》本文深入探讨了使用Python的pip工具实现包自动更新的各种方法和技术,我们将从基础概念开始,逐步介绍手动更新方法、自动化脚本编写、结合CI/C... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核

在Linux中改变echo输出颜色的实现方法

《在Linux中改变echo输出颜色的实现方法》在Linux系统的命令行环境下,为了使输出信息更加清晰、突出,便于用户快速识别和区分不同类型的信息,常常需要改变echo命令的输出颜色,所以本文给大家介... 目python录在linux中改变echo输出颜色的方法技术背景实现步骤使用ANSI转义码使用tpu

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

C#中Guid类使用小结

《C#中Guid类使用小结》本文主要介绍了C#中Guid类用于生成和操作128位的唯一标识符,用于数据库主键及分布式系统,支持通过NewGuid、Parse等方法生成,感兴趣的可以了解一下... 目录前言一、什么是 Guid二、生成 Guid1. 使用 Guid.NewGuid() 方法2. 从字符串创建

Python使用python-can实现合并BLF文件

《Python使用python-can实现合并BLF文件》python-can库是Python生态中专注于CAN总线通信与数据处理的强大工具,本文将使用python-can为BLF文件合并提供高效灵活... 目录一、python-can 库:CAN 数据处理的利器二、BLF 文件合并核心代码解析1. 基础合