分布式执行引擎ray入门--(3)Ray Train

2024-03-11 03:20

本文主要是介绍分布式执行引擎ray入门--(3)Ray Train,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfileimport torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Composeimport ray.train.torchdef train_func(config):# Model, Loss, Optimizermodel = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# model.to("cuda")  # This is done by `prepare_model`# [1] Prepare model.model = ray.train.torch.prepare_model(model)criterion = CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=0.001)# Datatransform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])data_dir = os.path.join(tempfile.gettempdir(), "data")train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=128, shuffle=True)# [2] Prepare dataloader.train_loader = ray.train.torch.prepare_data_loader(train_loader)# Trainingfor epoch in range(10):for images, labels in train_loader:# This is done by `prepare_data_loader`!# images, labels = images.to("cuda"), labels.to("cuda")outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# [3] Report metrics and checkpoint.metrics = {"loss": loss.item(), "epoch": epoch}with tempfile.TemporaryDirectory() as temp_checkpoint_dir:torch.save(model.module.state_dict(),os.path.join(temp_checkpoint_dir, "model.pt"))ray.train.report(metrics,checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),)if ray.train.get_context().get_world_rank() == 0:print(metrics)# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(train_func,scaling_config=scaling_config,# [5a] If running in a multi-node cluster, this is where you# should configure the run's persistent storage that is accessible# across all worker nodes.# run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))model = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpointdef train_func(config):...torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainertrainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

这篇关于分布式执行引擎ray入门--(3)Ray Train的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/796446

相关文章

Java Lettuce 客户端入门到生产的实现步骤

《JavaLettuce客户端入门到生产的实现步骤》本文主要介绍了JavaLettuce客户端入门到生产的实现步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 目录1 安装依赖MavenGradle2 最小化连接示例3 核心特性速览4 生产环境配置建议5 常见问题

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

Nginx分布式部署流程分析

《Nginx分布式部署流程分析》文章介绍Nginx在分布式部署中的反向代理和负载均衡作用,用于分发请求、减轻服务器压力及解决session共享问题,涵盖配置方法、策略及Java项目应用,并提及分布式事... 目录分布式部署NginxJava中的代理代理分为正向代理和反向代理正向代理反向代理Nginx应用场景

Java 线程池+分布式实现代码

《Java线程池+分布式实现代码》在Java开发中,池通过预先创建并管理一定数量的资源,避免频繁创建和销毁资源带来的性能开销,从而提高系统效率,:本文主要介绍Java线程池+分布式实现代码,需要... 目录1. 线程池1.1 自定义线程池实现1.1.1 线程池核心1.1.2 代码示例1.2 总结流程2. J

java中ssh2执行多条命令的四种方法

《java中ssh2执行多条命令的四种方法》本文主要介绍了java中ssh2执行多条命令的四种方法,包括分号分隔、管道分隔、EOF块、脚本调用,可确保环境配置生效,提升操作效率,具有一定的参考价值,感... 目录1 使用分号隔开2 使用管道符号隔开3 使用写EOF的方式4 使用脚本的方式大家平时有没有遇到自

mybatis直接执行完整sql及踩坑解决

《mybatis直接执行完整sql及踩坑解决》MyBatis可通过select标签执行动态SQL,DQL用ListLinkedHashMap接收结果,DML用int处理,注意防御SQL注入,优先使用#... 目录myBATiFBNZQs直接执行完整sql及踩坑select语句采用count、insert、u

一个Java的main方法在JVM中的执行流程示例详解

《一个Java的main方法在JVM中的执行流程示例详解》main方法是Java程序的入口点,程序从这里开始执行,:本文主要介绍一个Java的main方法在JVM中执行流程的相关资料,文中通过代码... 目录第一阶段:加载 (Loading)第二阶段:链接 (Linking)第三阶段:初始化 (Initia

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法

Java实现远程执行Shell指令

《Java实现远程执行Shell指令》文章介绍使用JSch在SpringBoot项目中实现远程Shell操作,涵盖环境配置、依赖引入及工具类编写,详解分号和双与号执行多指令的区别... 目录软硬件环境说明编写执行Shell指令的工具类总结jsch(Java Secure Channel)是SSH2的一个纯J

从入门到精通详解Python虚拟环境完全指南

《从入门到精通详解Python虚拟环境完全指南》Python虚拟环境是一个独立的Python运行环境,它允许你为不同的项目创建隔离的Python环境,下面小编就来和大家详细介绍一下吧... 目录什么是python虚拟环境一、使用venv创建和管理虚拟环境1.1 创建虚拟环境1.2 激活虚拟环境1.3 验证虚