分布式训练同步梯度出现形状不一致的解决方案

2024-09-06 19:12

本文主要是介绍分布式训练同步梯度出现形状不一致的解决方案,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、问题描述

          为了加快大模型的训练速度,采用了分布式训练策略,基于MultiWorkerServerStrategy模式,集群之间采用Ring—Reduce的通信机制,不同节点在同步梯度会借助collective_ops.all_gather方法将梯度进行汇聚收集,汇聚过程出现了:

allreduce_1/CollectiveGather_1 Inconsitent output shapes,got[20],but expected is [22]

allreduce_1/CollectiveGather  Inconsitent output shapes,got[16,8],but expected is [20,8]

从而终止了训练继续进行。

2、原因分析

         直观看是因为不连续的输出形状,即要求的输出形状对于第一个是[22],却输出了[20],造成了不一致,查阅相关资料发现在tensorflow1.15早期的版本中,底层的源码文件tensorflow/core/kernels/collective_ops.cc

当col_params_.instance.shape.num_elements() == 0时表明是首次批来的时候,记住了output_shape,当第二批次或后面的批次再来的时候,强行判断和首次记住的形状保持一致,如果不一致就报错打印出了上面的“输出形状不连续的问题”,即errors::Internal里的内容。这也就是之所以报梯度形状不一致的根本原因。

3、问题解决

          分析清楚了原因,制定对应的解决办法。当然可以将该段代码的逻辑去掉,当后面批次再来的时候,不做判断,而是让col_params_.instance.shape=output_shape始终跟最新你的输出保持一致。如下所示:去掉了老版本里的if else的判断,直接让col_params_.instance.shape=out_shape,兼容输出可变化的动态形状

该种解决方案优点是从根本上解决,上层应用无感知,然而缺点是改完后要重新编译cc代码生成so文件,或者升级到最新的版本,对于不开放的网络环境,升级tensorflow或重新编译成本巨高。为了依旧使用老版本,尽可能不动底层,采用修改上层的方法,虽然繁琐一些,但是修改成本会低恨多。

        对该问题进行更深入的分析,到底为什么梯度输出的形状会发生改变,即out_shape和首次批的输出形状可能会不一致呢?仔细梳理了每个批次的梯度产生过程。以一个id类特征product_id为例,假定训练的batch_size=1024,product总个数30,embedding后总参数大小[30,8], 第一个批次输入的批次数据是[1024,8],即每次输入1024个样本,而梯度回传有时候是[20,8],有时候是[16,8],之所以和输入没有对齐,经分析发现在反向传播,通过collective_all_gather收集各个集群上的梯度时,是以特征变量为单位,不是以样本量来衡量的,即会收集每个训练的特征变量,在各个节点上的梯度,收回来做累加或其他聚合操作,在有些1024的批data中,product_id包含20个(不同样本的product_id会有重复),有些1024的批data中,product_id包含16个,这样反回来的梯度是这20个或16个product_id的emb的梯度,所以看到的梯度的形状是[20,8]和[16,8],这也体现了训练的过程更新的可训练特征变量,以不同特征变量的个数来组织梯度也顺利成章。

解决办法:当一个批次的数据大小是batch_size的时候,根据以上分析,某个特征变量的不同值的个数上限是batch_size个,因此把梯度的形状pad成[batch_size,dim],这样就就保证了每次进入collective_op.all_gather的形状保持了一致,另一个问题就是这个batch_size如何传入cross_device_utils.py,刚上来考虑通过获取tensorgrah中输入变量的第一维的值来作为batch_size,这样会有个问题就是,当最后batch的大小不够一个batch_size的时候,补的形状就和前面的又一样,还是会失败,训练最后一步挂掉;因此考虑传入固定的静态手工配置的batch_size,通过参数传递的方式,内部经过的链路很长,会进入不同的模块,才会传导到cross_device_utils.py,这种方式改动太大,自然而然想到共享内存,python的共享内存可能需要第三方的工具包,成本也高,进而考虑共享文件,启动的时候将静态固定的batch_size写入一个固定的目录文件,在cross_device_utils.py里用到的时候读取文件,这样改动的成本还是有些繁琐,最后考虑python夸文件的变量共享,在cross_device_utils.py,定义一个全局变量global_batch_size给定默认值256,在训练启动的python文件main方法里通过引用修改该变量,即:

from tensorflow.python.distribute import cross_device_utils

cross_device_utils.global_batch_size=Configs[‘batch_size’]

具体修改如下:

修改前:

修改后:

4、总结

        本文对Ring-AllReduce通信框架下分布式训练梯度收集形状不一致的问题进行了分析,并阐述了从最底层和偏上层的不同解决思路。对使用稍早版本的tf搭建分布式训练平台有一定的借鉴作用。

这篇关于分布式训练同步梯度出现形状不一致的解决方案的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142878

相关文章

canal实现mysql数据同步的详细过程

《canal实现mysql数据同步的详细过程》:本文主要介绍canal实现mysql数据同步的详细过程,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的... 目录1、canal下载2、mysql同步用户创建和授权3、canal admin安装和启动4、canal

MyBatis Plus 中 update_time 字段自动填充失效的原因分析及解决方案(最新整理)

《MyBatisPlus中update_time字段自动填充失效的原因分析及解决方案(最新整理)》在使用MyBatisPlus时,通常我们会在数据库表中设置create_time和update... 目录前言一、问题现象二、原因分析三、总结:常见原因与解决方法对照表四、推荐写法前言在使用 MyBATis

Java死锁问题解决方案及示例详解

《Java死锁问题解决方案及示例详解》死锁是指两个或多个线程因争夺资源而相互等待,导致所有线程都无法继续执行的一种状态,本文给大家详细介绍了Java死锁问题解决方案详解及实践样例,需要的朋友可以参考下... 目录1、简述死锁的四个必要条件:2、死锁示例代码3、如何检测死锁?3.1 使用 jstack3.2

html 滚动条滚动过快会留下边框线的解决方案

《html滚动条滚动过快会留下边框线的解决方案》:本文主要介绍了html滚动条滚动过快会留下边框线的解决方案,解决方法很简单,详细内容请阅读本文,希望能对你有所帮助... 滚动条滚动过快时,会留下边框线但其实大部分时候是这样的,没有多出边框线的滚动条滚动过快时留下边框线的问题通常与滚动条样式和滚动行

Oracle修改端口号之后无法启动的解决方案

《Oracle修改端口号之后无法启动的解决方案》Oracle数据库更改端口后出现监听器无法启动的问题确实较为常见,但并非必然发生,这一问题通常源于​​配置错误或环境冲突​​,而非端口修改本身,以下是系... 目录一、问题根源分析​​​二、保姆级解决方案​​​​步骤1:修正监听器配置文件 (listener.

Linux实现线程同步的多种方式汇总

《Linux实现线程同步的多种方式汇总》本文详细介绍了Linux下线程同步的多种方法,包括互斥锁、自旋锁、信号量以及它们的使用示例,通过这些同步机制,可以解决线程安全问题,防止资源竞争导致的错误,示例... 目录什么是线程同步?一、互斥锁(单人洗手间规则)适用场景:特点:二、条件变量(咖啡厅取餐系统)工作流

MySQL版本问题导致项目无法启动问题的解决方案

《MySQL版本问题导致项目无法启动问题的解决方案》本文记录了一次因MySQL版本不一致导致项目启动失败的经历,详细解析了连接错误的原因,并提供了两种解决方案:调整连接字符串禁用SSL或统一MySQL... 目录本地项目启动报错报错原因:解决方案第一个:第二种:容器启动mysql的坑两种修改时区的方法:本地

安装centos8设置基础软件仓库时出错的解决方案

《安装centos8设置基础软件仓库时出错的解决方案》:本文主要介绍安装centos8设置基础软件仓库时出错的解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录安装Centos8设置基础软件仓库时出错版本 8版本 8.2.200android4版本 javas

Java中JSON格式反序列化为Map且保证存取顺序一致的问题

《Java中JSON格式反序列化为Map且保证存取顺序一致的问题》:本文主要介绍Java中JSON格式反序列化为Map且保证存取顺序一致的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未... 目录背景问题解决方法总结背景做项目涉及两个微服务之间传数据时,需要提供方将Map类型的数据序列化为co

Mysql的主从同步/复制的原理分析

《Mysql的主从同步/复制的原理分析》:本文主要介绍Mysql的主从同步/复制的原理分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录为什么要主从同步?mysql主从同步架构有哪些?Mysql主从复制的原理/整体流程级联复制架构为什么好?Mysql主从复制注意