PyTorch Demo-5 : 多GPU训练踩坑

2024-09-05 01:38
文章标签 训练 gpu pytorch demo

本文主要是介绍PyTorch Demo-5 : 多GPU训练踩坑,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

当数据量或者模型很大的时候往往单GPU已经无法满足我们的需求了,为了能够跑更大型的数据,多GPU训练是必要的。

PyTorch多卡训练的文章已经有很多,也写的很详细,比如:
https://zhuanlan.zhihu.com/p/98535650
https://zhuanlan.zhihu.com/p/74792767
不过写法各异,整合到自己的祖传代码里又有一些问题,在此记录一下踩坑。


DataParallel (DP)

最简单的是DP,只需要对model直接调用就可以了,更多细节可以参考前面的链接

gpus = [0, 1]
model = model.cuda(gpus)
model = nn.DataParallel(model, device_ids=gpus, output_device=gpus[0])

训练过程中需要把data设置 non_blocking=True ,参考non_blocking:

for idx, (data, target) in enumerate(train_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)

DP只能用于单机多卡,由主卡分发再在主卡统筹,所以负载不均衡的问题比较严重,通常主卡会多占用1-2G显存,而且效率没有DDP高。


DistributedDataParallel (DDP)

采用all-reduce算法,适用于多机多卡,也适用于单机多卡。关于DDP的细节还是参考链接写的更清楚。
主要步骤:

  • 在argparser里面定义一个local_rank, 用于确定当前进程所在的GPU
parser.add_argument('--local_rank', default=-1, type=int,help='node rank for distributed training')
  • 初始化通信方式和端口,设定当前的GPU号
torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
  • 分发训练数据
trainset = ...
train_sampler = None
# 设定一下参数,调用多卡才用
if use_multi_gpus:train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(trainset, batch_size=...,shuffle=(train_sampler is None),sampler=train_sampler)
  • 分配模型
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
  • 训练时数据
for idx, (data, target) in enumerate(train_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)

使用DDP需要在终端启动

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 torch_ddp.py

加入祖传代码

祖传代码Git
由于DDP模式是开启了多个进程来执行,因此在打印log和存储的时候可能会冲突导致打印的内容混乱,可以指定某一个rank打印,或者分批,把local_rank的默认值设为0的话,单卡的时候也能通用了:

# 指定rank打印
if args.local_rank == 0:print(f'loss:{loss:.4f}, acc:{acc:.4f} ...')# 打印出rank
print(f'rank:{args.local_rank} loss:{loss:.4f}, acc:{acc:.4f} ...')

参数存储通常是在测试时进行的,一方面可以指定 local_rank == 0 才存储,但是实际上,测试的时候每个GPU上的模型都是一样的,因此可以只测试一次,在循环交替的时候直接指定:

for epoch in range(total_epoch):train()scheduler.step()if args.local_rank == 0:test()

存储参数,在DDP模式下,直接存储参数key会变成model.module,可以在存之前先改成正常的state_dict:

if isinstance(model, nn.parallel.distributed.DistributedDataParallel):state = {'net': model.module.state_dict(),'acc': acc,'epoch': epoch}

DDP启动时可能会遇到地址冲突的情况,在启动命令中加入地址和端口 参考

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_addr 127.0.0.3 --master_port 23456 torch_ddp.py

find_unused_parameters=True错误,参考
参考1 参考2

这篇关于PyTorch Demo-5 : 多GPU训练踩坑的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1137611

相关文章

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

Springboot处理跨域的实现方式(附Demo)

《Springboot处理跨域的实现方式(附Demo)》:本文主要介绍Springboot处理跨域的实现方式(附Demo),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录Springboot处理跨域的方式1. 基本知识2. @CrossOrigin3. 全局跨域设置4.

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu