政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练

本文主要是介绍政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

简介

导入

基本批量大小和学习率

计算按比例分配的批量大小和学习率


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:使用 KerasNLP 和 tf.distribute 进行数据并行训练。

简介


分布式训练是一种在多台设备或机器上同时训练深度学习模型的技术。它有助于缩短训练时间,并允许使用更多数据训练更大的模型。KerasNLP 是一个为自然语言处理任务(包括分布式训练)提供工具和实用程序的库。

在本文中,我们将使用 KerasNLP 在 wikitext-2 数据集(维基百科文章的 200 万字数据集)上训练基于 BERT 的屏蔽语言模型 (MLM)。MLM 任务包括预测句子中的屏蔽词,这有助于模型学习单词的上下文表征。

本指南侧重于数据并行性,尤其是同步数据并行性,即每个加速器(GPU 或 TPU)都拥有一个完整的模型副本,并查看不同批次的部分输入数据。部分梯度在每个设备上计算、汇总,并用于计算全局梯度更新。

具体来说,本文将教您如何在以下两种设置中使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对代码做最小的改动:

—— 在一台机器上安装多个 GPU(通常为 2 至 8 个)(单主机、多设备训练)。这是研究人员和小规模行业工作流程最常见的设置。
—— 在由多台机器组成的集群上,每台机器安装一个或多个 GPU(多设备分布式训练)。这是大规模行业工作流程的良好设置,例如在 20-100 个 GPU 上对十亿字数据集进行高分辨率文本摘要模型训练。

!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import osos.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
import keras_nlp

在开始任何训练之前,让我们配置一下我们的单 GPU,使其显示为两个逻辑设备。

在使用两个或更多物理 GPU 进行训练时,这完全没有必要。这只是在默认 colab GPU 运行时(只有一个 GPU 可用)上显示真实分布式训练的一个技巧。

!nvidia-smi --query-gpu=memory.total --format=csv,noheader
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(physical_devices[0],[tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),],
)logical_devices = tf.config.list_logical_devices("GPU")
logical_devicesEPOCHS = 3
24576 MiB

要使用 Keras 模型进行单主机、多设备同步训练,您需要使用 tf.distribute.MirroredStrategy API。下面是其工作原理:

—— 实例化 MirroredStrategy,可选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
—— 使用该策略对象打开一个作用域,并在该作用域中创建所需的包含变量的所有 Keras 对象。通常情况下,这意味着在分发作用域内创建和编译模型。
—— 像往常一样通过 fit() 训练模型。

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2

基本批量大小和学习率

base_batch_size = 32
base_learning_rate = 1e-4

计算按比例分配的批量大小和学习率

scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

现在,我们需要下载并预处理 wikitext-2 数据集。该数据集将用于预训练 BERT 模型。我们将过滤掉短行,以确保数据有足够的语境用于训练。

keras.utils.get_file(origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")# Load wikitext-103 and filter out short lines.
wiki_train_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.train.tokens",).filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens").filter(lambda x: tf.strings.length(x) > 100).shuffle(buffer_size=500).batch(scaled_batch_size).cache().prefetch(tf.data.AUTOTUNE)
)

在上述代码中,我们下载并提取了 wikitext-2 数据集。然后,我们定义了三个数据集:wiki_train_ds、wiki_val_ds 和 wiki_test_ds。我们对这些数据集进行了过滤,以去除短行,并对其进行批处理,以提高训练效率。

在 NLP 训练/调整中,使用衰减学习率是一种常见的做法。在这里,我们将使用多项式衰减时间表(PolynomialDecay schedule)。

total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=scaled_learning_rate,decay_steps=total_training_steps,end_learning_rate=0.0,
)class PrintLR(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):print(f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}")

我们还要回调 TensorBoard,这样就能在本教程后半部分训练模型时可视化不同的指标。我们将所有回调放在一起,如下所示:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs"),PrintLR(),
]print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

准备好数据集后,我们现在要在 strategy.scope() 中初始化并编译模型和优化器:

with strategy.scope():# Everything that creates variables should be under the strategy scope.# In general this is only model construction & `compile()`.model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")# This line just sets pooled_dense layer as non-trainiable, we do this to avoid# warnings of this layer being unusedmodel_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = Falsemodel_dist.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],jit_compile=False,)model_dist.fit(wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 - sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 - val_sparse_categorical_accuracy: 0.3485
Epoch 2/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 - sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 - sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 - val_sparse_categorical_accuracy: 0.4006
Epoch 3/3239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 - sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 - sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.4230

根据范围拟合模型后,我们对其进行正常评估!

model_dist.evaluate(wiki_test_ds)
 29/29 ━━━━━━━━━━━━━━━━━━━━ 3s 60ms/step - loss: 1.9197 - sparse_categorical_accuracy: 0.8527[0.9470901489257812, 0.4373602867126465]

对于跨多台计算机的分布式训练(而不是只利用单台计算机上的多个设备进行训练),您可以使用两种分布式策略:MultiWorkerMirroredStrategy 和 ParameterServerStrategy:

—— tf.distribution.MultiWorkerMirroredStrategy(多工作站策略)实现了一种 CPU/GPU 多工作站同步解决方案,可与 Keras 风格的模型构建和训练循环配合使用,并使用跨副本的梯度同步还原。
—— tf.distribution.experimental.ParameterServerStrategy(参数服务器策略)实现了一种异步 CPU/GPU 多工作站解决方案,其中参数存储在参数服务器上,工作站异步更新梯度到参数服务器。


这篇关于政安晨:【Keras机器学习示例演绎】(四十二)—— 使用 KerasNLP 和 tf.distribute 进行数据并行训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/982333

相关文章

Java中流式并行操作parallelStream的原理和使用方法

《Java中流式并行操作parallelStream的原理和使用方法》本文详细介绍了Java中的并行流(parallelStream)的原理、正确使用方法以及在实际业务中的应用案例,并指出在使用并行流... 目录Java中流式并行操作parallelStream0. 问题的产生1. 什么是parallelS

Linux join命令的使用及说明

《Linuxjoin命令的使用及说明》`join`命令用于在Linux中按字段将两个文件进行连接,类似于SQL的JOIN,它需要两个文件按用于匹配的字段排序,并且第一个文件的换行符必须是LF,`jo... 目录一. 基本语法二. 数据准备三. 指定文件的连接key四.-a输出指定文件的所有行五.-o指定输出

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

Linux kill正在执行的后台任务 kill进程组使用详解

《Linuxkill正在执行的后台任务kill进程组使用详解》文章介绍了两个脚本的功能和区别,以及执行这些脚本时遇到的进程管理问题,通过查看进程树、使用`kill`命令和`lsof`命令,分析了子... 目录零. 用到的命令一. 待执行的脚本二. 执行含子进程的脚本,并kill2.1 进程查看2.2 遇到的

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

k8s按需创建PV和使用PVC详解

《k8s按需创建PV和使用PVC详解》Kubernetes中,PV和PVC用于管理持久存储,StorageClass实现动态PV分配,PVC声明存储需求并绑定PV,通过kubectl验证状态,注意回收... 目录1.按需创建 PV(使用 StorageClass)创建 StorageClass2.创建 PV

Redis 基本数据类型和使用详解

《Redis基本数据类型和使用详解》String是Redis最基本的数据类型,一个键对应一个值,它的功能十分强大,可以存储字符串、整数、浮点数等多种数据格式,本文给大家介绍Redis基本数据类型和... 目录一、Redis 入门介绍二、Redis 的五大基本数据类型2.1 String 类型2.2 Hash

Redis中Hash从使用过程到原理说明

《Redis中Hash从使用过程到原理说明》RedisHash结构用于存储字段-值对,适合对象数据,支持HSET、HGET等命令,采用ziplist或hashtable编码,通过渐进式rehash优化... 目录一、开篇:Hash就像超市的货架二、Hash的基本使用1. 常用命令示例2. Java操作示例三

Linux创建服务使用systemctl管理详解

《Linux创建服务使用systemctl管理详解》文章指导在Linux中创建systemd服务,设置文件权限为所有者读写、其他只读,重新加载配置,启动服务并检查状态,确保服务正常运行,关键步骤包括权... 目录创建服务 /usr/lib/systemd/system/设置服务文件权限:所有者读写js,其他