PyTorch (1.2.0+): Hands-On Multi-Machine, Single-GPU Parallel Training (MNIST Classification)

2023-12-23 12:10


Background

This post is a quick hands-on exercise in multi-machine, single-GPU-per-node parallel training with PyTorch (1.2.0+); the underlying theory is not covered here.


References

https://blog.csdn.net/u010557442/article/details/79431520
https://zhuanlan.zhihu.com/p/116482019
https://blog.csdn.net/gbyy42299/article/details/103673840
https://blog.csdn.net/m0_38008956/article/details/86559432


Code

https://gitee.com/KevinYan37/pytorch_ddp

Steps

1. Set up the environment

Put several identically configured machines (same Ubuntu release, GPU model, NVIDIA driver, CUDA version, and PyTorch version) on the same subnet. For example, my two machines are at 192.168.10.235 and 192.168.10.236. Also disable the firewall (and any similar filtering) on both machines.
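The post does not include this step, but before launching it can save debugging time to confirm that the rendezvous port is actually bindable on the master node. A minimal sketch (the `port_is_free` helper is hypothetical, not part of the original code):

```python
import socket

def port_is_free(host: str, port: int) -> bool:
    """Return True if this machine can bind the given TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return True
        except OSError:
            return False

# Run this on the rank-0 machine before training; 23456 is the
# rendezvous port used later in this post.
print(port_is_free("0.0.0.0", 23456))
```

If this prints False, some other process already holds the port and `init_process_group` will hang or fail.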

2. Verify the environment
import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.distributed.is_available())
3. MNIST dataset code

The following code is copied from torchvision; only the download URLs are changed so the archives are read from a local path.

import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
from torchvision import datasets
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \
    verify_str_arg


def get_int(b: bytes) -> int:
    return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a file object that possibly decompresses 'path' on the fly.
       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
    """
    if not isinstance(path, torch._six.string_classes):
        return path
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')


SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8')
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
       Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x


class MNIST(datasets.VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    resources = [
        ("file://./data/MNIST/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("file://./data/MNIST/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("file://./data/MNIST/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("file://./data/MNIST/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self,
                 root: str,
                 train: bool = True,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = False) -> None:
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")
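To see what `get_int` and the magic-number check in `read_sn3_pascalvincent_tensor` are doing: the idx file header packs the element-type code into the upper bytes of a 4-byte big-endian integer and the dimension count into the lowest byte. A self-contained sketch with a fabricated header (not real MNIST data):

```python
import codecs
import struct

# Fabricate the big-endian "magic" header for a 2-D uint8 tensor:
# type code 8 (uint8) in the upper bytes, ndim = 2 in the lowest byte.
header = struct.pack(">i", 8 * 256 + 2)

# Decode it the same way get_int above does.
magic = int(codecs.encode(header, "hex"), 16)

print(magic % 256)   # -> 2 (number of dimensions)
print(magic // 256)  # -> 8 (element type; uint8 in SN3_PASCALVINCENT_TYPEMAP)
```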
4. Training code
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import time
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data 
import torch.utils.data.distributed
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from MNIST import MNIST

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')
# 必须设置的参数 
parser.add_argument('--tcp', type=str, default='tcp://192.168.10.235:23456', metavar='INIT',help='init method: TCP address of the rank-0 (master) machine')
parser.add_argument('--rank', type=int, default=0, metavar='N',help='pytorch distributed rank of this machine')
args = parser.parse_args()
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Initialize the process group; every machine must use the same init
# address (the rank-0 machine) and the same world_size.
dist.init_process_group(init_method=args.tcp,
                        backend="nccl",
                        rank=args.rank,
                        world_size=2,
                        group_name="pytorch_test")

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

train_dataset = MNIST('./data', train=True, download=True,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                      ]))

# Partition the data: each rank sees a disjoint shard of the training set.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# The sampler must actually be passed to the DataLoader; shuffling is then
# handled by the sampler itself, so shuffle must be False here.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           sampler=train_sampler,
                                           shuffle=False, **kwargs)

test_loader = torch.utils.data.DataLoader(
    MNIST('./data', train=False,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
          ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
if args.cuda:
    # Move the model to the GPU and wrap it so gradients are averaged
    # across machines on every backward pass.
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    # model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3]).cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


tot_time = 0
for epoch in range(1, args.epochs + 1):
    # Tell the sampler which epoch this is so every rank reshuffles with
    # the same seed and the shards stay disjoint.
    train_sampler.set_epoch(epoch)
    start_cpu_secs = time.time()
    train(epoch)
    end_cpu_secs = time.time()
    print("Epoch {} of {} took {:.3f}s".format(
        epoch, args.epochs, end_cpu_secs - start_cpu_secs))
    tot_time += end_cpu_secs - start_cpu_secs
    test()

print("Total time = {:.3f}s".format(tot_time))
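The DistributedSampler is what keeps the two machines from training on the same examples. Roughly speaking (ignoring shuffling), it pads the index list to a multiple of the world size and hands each rank an interleaved slice. A simplified sketch of that partitioning logic (`partition_indices` is an illustrative helper, not the real torch API):

```python
import math

def partition_indices(num_samples: int, world_size: int, rank: int):
    """Sketch of DistributedSampler's split (without shuffling): pad the
    index list so it divides evenly, then give each rank every
    world_size-th index starting at its own rank."""
    total = math.ceil(num_samples / world_size) * world_size
    indices = list(range(num_samples))
    indices += indices[:total - num_samples]  # pad by wrapping around
    return indices[rank:total:world_size]

# With 10 samples and 2 machines, each rank sees 5 disjoint samples:
print(partition_indices(10, 2, 0))  # -> [0, 2, 4, 6, 8]
print(partition_indices(10, 2, 1))  # -> [1, 3, 5, 7, 9]
```

This is also why `set_epoch` matters in the training loop above: with shuffling enabled, both ranks must shuffle with the same epoch-derived seed or their shards would overlap.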
5. Run the code

Run the script on each of the two machines.

# master node, rank 0
python test.py --tcp 'tcp://192.168.10.235:23456' --rank 0

On the other machine, run:

python test.py --tcp 'tcp://192.168.10.235:23456' --rank 1


Summary

This was just a simple hands-on run; the finer details and the underlying theory are left for future study.
