分布式训练同步梯度出现形状不一致的解决方案

2024-09-06 19:12

本文主要是介绍分布式训练同步梯度出现形状不一致的解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、问题描述

          为了加快大模型的训练速度,采用了分布式训练策略,基于MultiWorkerServerStrategy模式,集群之间采用Ring—Reduce的通信机制,不同节点在同步梯度会借助collective_ops.all_gather方法将梯度进行汇聚收集,汇聚过程出现了:

allreduce_1/CollectiveGather_1 Inconsitent output shapes,got[20],but expected is [22]

allreduce_1/CollectiveGather  Inconsitent output shapes,got[16,8],but expected is [20,8]

从而终止了训练继续进行。

2、原因分析

         直观看是因为不连续的输出形状,即要求的输出形状对于第一个是[22],却输出了[20],造成了不一致,查阅相关资料发现在tensorflow1.15早期的版本中,底层的源码文件tensorflow/core/kernels/collective_ops.cc

当col_params_.instance.shape.num_elements() == 0时表明是首次批来的时候,记住了output_shape,当第二批次或后面的批次再来的时候,强行判断和首次记住的形状保持一致,如果不一致就报错打印出了上面的“输出形状不连续的问题”,即errors::Internal里的内容。这也就是之所以报梯度形状不一致的根本原因。

3、问题解决

          分析清楚了原因,制定对应的解决办法。当然可以将该段代码的逻辑去掉,当后面批次再来的时候,不做判断,而是让col_params_.instance.shape=output_shape始终跟最新你的输出保持一致。如下所示:去掉了老版本里的if else的判断,直接让col_params_.instance.shape=out_shape,兼容输出可变化的动态形状

该种解决方案优点是从根本上解决,上层应用无感知,然而缺点是改完后要重新编译cc代码生成so文件,或者升级到最新的版本,对于不开放的网络环境,升级tensorflow或重新编译成本巨高。为了依旧使用老版本,尽可能不动底层,采用修改上层的方法,虽然繁琐一些,但是修改成本会低恨多。

        对该问题进行更深入的分析,到底为什么梯度输出的形状会发生改变,即out_shape和首次批的输出形状可能会不一致呢?仔细梳理了每个批次的梯度产生过程。以一个id类特征product_id为例,假定训练的batch_size=1024,product总个数30,embedding后总参数大小[30,8], 第一个批次输入的批次数据是[1024,8],即每次输入1024个样本,而梯度回传有时候是[20,8],有时候是[16,8],之所以和输入没有对齐,经分析发现在反向传播,通过collective_all_gather收集各个集群上的梯度时,是以特征变量为单位,不是以样本量来衡量的,即会收集每个训练的特征变量,在各个节点上的梯度,收回来做累加或其他聚合操作,在有些1024的批data中,product_id包含20个(不同样本的product_id会有重复),有些1024的批data中,product_id包含16个,这样反回来的梯度是这20个或16个product_id的emb的梯度,所以看到的梯度的形状是[20,8]和[16,8],这也体现了训练的过程更新的可训练特征变量,以不同特征变量的个数来组织梯度也顺利成章。

解决办法:当一个批次的数据大小是batch_size的时候,根据以上分析,某个特征变量的不同值的个数上限是batch_size个,因此把梯度的形状pad成[batch_size,dim],这样就就保证了每次进入collective_op.all_gather的形状保持了一致,另一个问题就是这个batch_size如何传入cross_device_utils.py,刚上来考虑通过获取tensorgrah中输入变量的第一维的值来作为batch_size,这样会有个问题就是,当最后batch的大小不够一个batch_size的时候,补的形状就和前面的又一样,还是会失败,训练最后一步挂掉;因此考虑传入固定的静态手工配置的batch_size,通过参数传递的方式,内部经过的链路很长,会进入不同的模块,才会传导到cross_device_utils.py,这种方式改动太大,自然而然想到共享内存,python的共享内存可能需要第三方的工具包,成本也高,进而考虑共享文件,启动的时候将静态固定的batch_size写入一个固定的目录文件,在cross_device_utils.py里用到的时候读取文件,这样改动的成本还是有些繁琐,最后考虑python夸文件的变量共享,在cross_device_utils.py,定义一个全局变量global_batch_size给定默认值256,在训练启动的python文件main方法里通过引用修改该变量,即:

from tensorflow.python.distribute import cross_device_utils

cross_device_utils.global_batch_size=Configs[‘batch_size’]

具体修改如下:

修改前:

修改后:

4、总结

        本文对Ring-AllReduce通信框架下分布式训练梯度收集形状不一致的问题进行了分析,并阐述了从最底层和偏上层的不同解决思路。对使用稍早版本的tf搭建分布式训练平台有一定的借鉴作用。

这篇关于分布式训练同步梯度出现形状不一致的解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142878

相关文章

C++ vector越界问题的完整解决方案

《C++vector越界问题的完整解决方案》在C++开发中,std::vector作为最常用的动态数组容器,其便捷性与性能优势使其成为处理可变长度数据的首选,然而,数组越界访问始终是威胁程序稳定性的... 目录引言一、vector越界的底层原理与危害1.1 越界访问的本质原因1.2 越界访问的实际危害二、基

Python 字符串裁切与提取全面且实用的解决方案

《Python字符串裁切与提取全面且实用的解决方案》本文梳理了Python字符串处理方法,涵盖基础切片、split/partition分割、正则匹配及结构化数据解析(如BeautifulSoup、j... 目录python 字符串裁切与提取的完整指南 基础切片方法1. 使用切片操作符[start:end]2

C#控制台程序同步调用WebApi实现方式

《C#控制台程序同步调用WebApi实现方式》控制台程序作为Job时,需同步调用WebApi以确保获取返回结果后执行后续操作,否则会引发TaskCanceledException异常,同步处理可避免异... 目录同步调用WebApi方法Cls001类里面的写法总结控制台程序一般当作Job使用,有时候需要控制

Redis分布式锁中Redission底层实现方式

《Redis分布式锁中Redission底层实现方式》Redission基于Redis原子操作和Lua脚本实现分布式锁,通过SETNX命令、看门狗续期、可重入机制及异常处理,确保锁的可靠性和一致性,是... 目录Redis分布式锁中Redission底层实现一、Redission分布式锁的基本使用二、Red

redis和redission分布式锁原理及区别说明

《redis和redission分布式锁原理及区别说明》文章对比了synchronized、乐观锁、Redis分布式锁及Redission锁的原理与区别,指出在集群环境下synchronized失效,... 目录Redis和redission分布式锁原理及区别1、有的同伴想到了synchronized关键字

Linux部署中的文件大小写问题的解决方案

《Linux部署中的文件大小写问题的解决方案》在本地开发环境(Windows/macOS)一切正常,但部署到Linux服务器后出现模块加载错误,核心原因是Linux文件系统严格区分大小写,所以本文给大... 目录问题背景解决方案配置要求问题背景在本地开发环境(Windows/MACOS)一切正常,但部署到

Java中InputStream重复使用问题的几种解决方案

《Java中InputStream重复使用问题的几种解决方案》在Java开发中,InputStream是用于读取字节流的类,在许多场景下,我们可能需要重复读取InputStream中的数据,这篇文章主... 目录前言1. 使用mark()和reset()方法(适用于支持标记的流)2. 将流内容缓存到字节数组

MybatisPlus中removeById删除数据库未变解决方案

《MybatisPlus中removeById删除数据库未变解决方案》MyBatisPlus中,removeById需实体类标注@TableId注解以识别数据库主键,若字段名不一致,应通过value属... 目录MyBATisPlus中removeBypythonId删除数据库未变removeById(Se

创建springBoot模块没有目录结构的解决方案

《创建springBoot模块没有目录结构的解决方案》2023版IntelliJIDEA创建模块时可能出现目录结构识别错误,导致文件显示异常,解决方法为选择模块后点击确认,重新校准项目结构设置,确保源... 目录创建spChina编程ringBoot模块没有目录结构解决方案总结创建springBoot模块没有目录

idea Maven Springboot多模块项目打包时90%的问题及解决方案

《ideaMavenSpringboot多模块项目打包时90%的问题及解决方案》:本文主要介绍ideaMavenSpringboot多模块项目打包时90%的问题及解决方案,具有很好的参考价值,... 目录1. 前言2. 问题3. 解决办法4. jar 包冲突总结1. 前言之所以写这篇文章是因为在使用Mav