飞桨自然语言处理框架 paddlenlp的 trainer

2024-02-07 06:04

本文主要是介绍飞桨自然语言处理框架 paddlenlp的 trainer,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

飞桨(PaddlePaddle)的NLP库PaddleNLP中的Trainer类是一个用于训练和评估模型的简单但功能完整的循环。它被优化用于与PaddleNLP一起使用。Trainer类简化了训练过程,提供了自动的批处理、模型保存、日志记录等特性。
以下是Trainer类的主要参数和功能:

  1. 模型
    • model:可以是一个预训练的模型或一个自定义的paddle.nn.Layer。如果使用自定义模型,它需要与PaddleNLP提供的模型工作方式相同。
  2. criterion
    • 如果模型只输出logits,并且您想对模型的输出进行更多的计算,可以添加criterion层。
  3. args
    • args:用于调整训练的参数。如果未提供,将默认使用一个具有output_dir设置为当前目录中名为tmp_trainer的目录的基本TrainingArguments实例。
  4. 数据整理器(DataCollator)
    • data_collator:用于从train_dataseteval_dataset的列表中形成一批数据的功能。如果没有提供tokenizer,将默认使用default_data_collator;否则,将使用DataCollatorWithPadding的实例。
  5. 训练数据集和评估数据集
    • train_dataseteval_dataset:用于训练和评估的数据集。如果数据集是datasets.Dataset的实例,则不会接受model.forward()方法不接受的字段。
  6. 分词器(Tokenizer)
    • tokenizer:用于预处理数据的分词器。如果提供了,将在批量输入时自动将输入填充到最大长度,并在中断训练或重用模型时保存分词器。
  7. 计算指标(compute_metrics)
    • compute_metrics:用于在评估时计算指标的函数。它必须接受一个EvalPrediction对象并返回一个字典,字典中的字符串表示指标名称,对应的值表示指标值。
  8. 回调函数(callbacks)
    • callbacks:一个回调函数列表,用于自定义训练循环。可以将这些回调函数添加到默认回调函数列表中。如果想要移除默认使用的回调函数,可以使用Trainer.remove_callback方法。
  9. 优化器(optimizers)
    • optimizers:一个包含优化器和调度器的元组。如果没有提供,将默认使用AdamW优化器,并根据args使用get_linear_schedule_with_warmup调度器。
  10. 预处理logits用于指标(preprocess_logits_for_metrics)
    • preprocess_logits_for_metrics:一个函数,用于在每次评估步骤后预处理logits。它必须接受两个张量,即logits和labels,并返回处理后的logits。此函数的修改将在compute_metrics中反映在接收到的预测值上。
      Trainer类简化了训练流程,让用户可以更加专注于模型的设计和训练策略,而不必担心底层的训练细节。通过提供这些参数和功能,用户可以轻松地训练、评估和部署模型。
      paddlenlp/trainer/training_args.py

TrainingArguments 类是 PaddleNLP 中用于定义与训练循环相关的命令行参数的子集。这些参数用于配置训练过程的各种方面,例如输出目录、训练和评估的批处理大小、学习率、训练周期数等。通过 PdArgumentParser,可以将这个类转换为 argparse 参数,以便在命令行上指定。
以下是 TrainingArguments 类中一些关键参数的详细介绍:

  1. output_dir:模型预测和检查点的输出目录。
  2. overwrite_output_dir:如果为 True,将覆盖输出目录中的内容。
  3. do_train:是否运行训练。
  4. do_eval:是否在验证集上评估模型。
  5. do_predict:是否在测试集上进行预测。
  6. evaluation_strategy:训练期间采用的评估策略,可以是 “no”(不评估)、“steps”(每指定步数评估一次)或 “epoch”(每轮训练结束后评估)。
  7. per_device_train_batch_size:训练时的每个 GPU 核心/CPU 的批处理大小。
  8. per_device_eval_batch_size:评估时的每个 GPU 核心/CPU 的批处理大小。
  9. learning_rate:AdamW 优化器的初始学习率。
  10. num_train_epochs:要执行的总训练周期数。
  11. max_steps:要执行的总训练步数。
  12. log_on_each_node:在多节点分布式训练中,是否每个节点都进行日志记录。
  13. logging_dir:日志目录。
  14. logging_strategy:训练期间采用的日志策略,可以是 “no”(不记录日志)、“epoch”(每轮训练结束后记录日志)或 “steps”(每指定步数记录一次日志)。
  15. save_strategy:训练期间采用的检查点保存策略,可以是 “no”(不保存检查点)、“epoch”(每轮训练结束后保存检查点)或 “steps”(每指定步数保存一次检查点)。
    这些参数可以在您的训练脚本中使用,以配置和控制训练过程。TrainingArguments 类可以被转换为命令行参数,使用户能够轻松地在运行脚本时指定这些参数。
    TrainingArguments 类中的剩余参数用于进一步控制训练过程的高级特性,如混合精度训练、并行训练策略等。以下是对这些参数的详细介绍:
  16. save_steps:如果 save_strategy="steps",则在达到指定的步数之前保存两次检查点。
  17. save_total_limit:如果指定了值,将限制保存的检查点总数,并在 output_dir 中删除较旧的检查点。
  18. save_on_each_node:在多节点分布式训练中,是否在每个节点上保存模型和检查点。如果不同节点使用相同的存储,则不应激活此选项,因为文件名将相同。
  19. no_cuda:是否即使可用也禁用 CUDA。
  20. seed:训练开始时设置的随机种子,用于确保跨运行的可重复性。
  21. fp16:是否使用 16 位(混合)精度训练而不是 32 位训练。
  22. fp16_opt_level:对于 16 位训练,选择的 AMP 优化级别。
  23. amp_custom_black_list:自定义黑名单,用于指定哪些操作不应转换为 16 位或 32 位。
  24. amp_custom_white_list:自定义白名单,用于指定哪些操作应转换为 16 位或 32 位。
  25. amp_master_grad:对于 AMP 优化级别 'O2',是否使用 float32 权重梯度进行计算。
  26. sharding:是否使用 Paddle Sharding Data Parallel 训练。
  27. sharding_parallel_degree:在特定卡组中的 Sharding 参数。
  28. tensor_parallel_degree:张量并行度,用于指定将 Transformer 层分割成多少部分。
  29. pipeline_parallel_degree:流水线并行度,用于指定将所有 Transformer 层分割成多少阶段。
  30. sep_parallel_degree:Paddle 序列并行策略,可以减少激活 GPU 内存。
  31. tensor_parallel_config:一些影响模型并行性能的额外配置。
  32. pipeline_parallel_config:一些影响流水线并行使用的额外配置。
  33. sharding_parallel_config:一些影响 Sharding 并行的额外配置。
  34. recompute:是否在训练过程中重新计算梯度。
  35. num_workers:数据加载过程中使用的线程数。
  36. max_predictions_per_batch:每个批处理中最大预测的数量。
  37. prediction_loss_only:在评估和生成预测时,是否只返回损失。
    这些参数提供了对训练过程的细粒度控制,允许用户根据他们的需求和硬件配置来优化训练。在实际应用中,这些参数可能需要根据具体情况进行调整,以达到最佳的训练效果。
    在您提供的TrainingArguments类的参数说明中,涵盖了训练循环中涉及的各种配置选项。以下是对这些参数的详细介绍:
  38. per_device_train_batch_size:指定每个GPU核心/CPU在训练时使用的批处理大小。
  39. per_device_eval_batch_size:指定每个GPU核心/CPU在评估时使用的批处理大小。
  40. gradient_accumulation_steps:在执行反向传播和更新参数之前,累积梯度的步数。使用梯度累积时,一次步数对应一次反向传播。
  41. eval_accumulation_steps:在将预测结果移动到CPU之前,累积的预测步数。如果未设置,则整个预测结果将在GPU/TPU上累积后再移动到CPU。
  42. learning_rate:指定AdamW优化器的初始学习率。
  43. weight_decay:对AdamW优化器中的所有层(除偏置和LayerNorm层)应用的权重衰减。
  44. adam_beta1:AdamW优化器的beta1超参数。
  45. adam_beta2:AdamW优化器的beta2超参数。
  46. adam_epsilon:AdamW优化器的epsilon超参数。
  47. max_grad_norm:用于梯度裁剪的最大梯度范数。
  48. num_train_epochs:要执行的总训练周期数。
  49. max_steps:要执行的总训练步数。如果设置为正数,将覆盖num_train_epochs
  50. lr_scheduler_type:指定的学习率调度器类型。
  51. warmup_ratio:用于线性预热的总训练步数的比例。
  52. warmup_steps:用于从0到学习率进行线性预热的步数。
  53. num_cycles:余弦调度器中的波数。
  54. lr_end:多项式调度器中的结束学习率。
  55. power:多项式调度器中的幂因子。
    在训练过程的其他方面,还提供了以下配置选项:
  56. log_on_each_node:在多节点分布式训练中,是否每个节点都进行日志记录。
  57. logging_dir:日志目录。
  58. logging_strategy:训练期间采用的日志策略。
  59. logging_first_step:是否记录和评估第一个全局步骤。
  60. logging_steps:如果logging_strategy="steps",则两次日志之间的更新步数。
  61. save_strategy:训练期间采用的检查点保存策略。
  62. save_steps:如果save_strategy="steps",则在两次检查点保存之间的更新步数。
  63. save_total_limit:限制保存的检查点总数。
  64. save_on_each_node:在多节点分布式训练中,是否每个节点都保存模型和检查点。
  65. no_cuda:是否禁用CUDA。
  66. seed:训练开始时设置的随机种子。
  67. fp16:是否使用16位(混合)精度训练。
  68. fp16_opt_level:16位训练的AMP优化级别。
  69. amp_custom_black_list:自定义黑名单,用于指定哪些操作不应转换为16位或32位。
  70. amp_custom_white_list:自定义白名单,用于指定哪些操作应转换为16位或32位。
  71. amp_master_grad:是否使用float32权重梯度进行计算。
  72. sharding:是否使用Paddle Sharding Data Parallel训练。
  73. sharding_parallel_degree:Sharding参数,用于指定在特定卡组中的并行度。
  74. tensor_parallel_degree:张量并行度,用于指定将Transformer层分割成多少部分。
  75. pipeline_parallel_degree:流水线并行度,用于指定将所有Transformer层分割成多少阶段。

在您提供的TrainingArguments类的参数说明中,涵盖了训练循环中涉及的各种配置选项。以下是对这些参数的详细介绍:

  1. tensor_parallel_config:影响模型并行性能的一些额外配置,例如:
    • enable_mp_async_allreduce:支持在列并行线性反向传播期间的all_reduce(dx)与matmul(dw)重叠,可以加速模型并行性能。
    • enable_mp_skip_c_identity:支持在列并行线性和行并行线性中跳过c_identity,当与mp_async_allreduce一起设置时,可以进一步加速模型并行。
    • enable_mp_fused_linear_param_grad_add:支持在列并行线性中融合线性参数梯度添加,当与mp_async_allreduce一起设置时,可以进一步加速模型并行。
    • enable_delay_scale_loss:在优化器步骤累积梯度,所有梯度除以累积步数,而不是直接在损失上除以累积步数。
  2. pipeline_parallel_config:影响流水线并行使用的额外配置,例如:
    • disable_p2p_cache_shape:如果您使用的最大序列长度变化,请设置此选项。
    • disable_partial_send_recv:优化tensor并行的发送速度。
    • enable_delay_scale_loss:在优化器步骤累积梯度,所有梯度除以内部流水线累积步数,而不是直接在损失上除以累积步数。
    • enable_dp_comm_overlap:融合数据并行梯度通信。
    • enable_sharding_comm_overlap:融合sharding stage 1并行梯度通信。
    • enable_release_grads:在每次迭代后释放梯度,以减少峰值内存使用。梯度的创建将推迟到下一迭代的反向传播。
  3. sharding_parallel_config:影响Sharding并行的额外配置,例如:
    • enable_stage1_tensor_fusion:将小张量融合成大的张量块来加速通信,可能会增加内存占用。
    • enable_stage1_overlap:在回传计算之前,将小张量融合成大的张量块来加速通信,可能会损害回传速度。
    • enable_stage2_overlap:重叠stage2 NCCL通信与计算。重叠有一些约束,例如,对于广播重叠,logging_step应该大于1,在训练期间不应调用其他同步操作。
  4. recompute:是否重新计算前向传播以计算梯度。用于节省内存。仅支持具有transformer块的网络。
  5. scale_loss:fp16的初始scale_loss值。
  6. local_rank:分布式训练过程中的进程排名。
  7. dataloader_drop_last:是否丢弃最后一个不完整的批处理(如果数据集的长度不能被批处理大小整除)。
  8. eval_steps:如果evaluation_strategy="steps",则两次评估之间的更新步数。
  9. max_evaluate_steps:要执行的总评估步数。
  10. dataloader_num_workers:数据加载过程中使用的子进程数。
  11. past_index:一些模型如TransformerXL或XLNet可以使用过去的隐状态为其预测。如果此参数设置为正整数,Trainer将使用相应的输出(通常是索引2)作为过去状态,并在下一次训练步骤中将其提供给模型,作为关键字参数mems
  12. run_name:运行的描述符。通常用于日志记录。
  13. disable_tqdm:是否禁用tqdm进度条和指标表。
  14. remove_unused_columns:如果使用datasets.Dataset数据集,是否自动删除模型前向方法未使用的列。
  15. label_names:输入字典中对应于标签的键的列表。
  16. load_best_model_at_end:是否在训练结束时加载找到的最佳模型。
    在您提供的TrainingArguments类的参数说明中,涵盖了训练循环中涉及的各种配置选项。以下是对这些参数的详细介绍:
  17. metric_for_best_model:与load_best_model_at_end配合使用,指定用于比较两个不同模型的度量。必须是评估返回的度量名称,可以是带前缀"eval_"或不带前缀。如果未指定且load_best_model_at_end=True,则默认为"loss"(使用评估损失)。
  18. greater_is_better:与load_best_model_at_endmetric_for_best_model配合使用,指定更好的模型是否应该有更大的度量。默认为:
    • 如果metric_for_best_model设置为不是"loss""eval_loss"的值,则默认为True
    • 如果metric_for_best_model未设置,或设置为"loss""eval_loss",则默认为False
  19. ignore_data_skip:在继续训练时,是否跳过某些 epochs 和 batches 以确保数据加载与之前训练的数据加载阶段相同。如果设置为True,训练将更快开始,但结果可能与中断的训练结果不同。
  20. optim:要使用的优化器:adamwadafactor
  21. length_column_name:预计算长度的列名。如果该列存在,则按长度分组将使用这些值,而不是在训练启动时计算它们。除非group_by_lengthTrue且数据集是Dataset的实例。
  22. report_to:报告结果和日志的集成列表。支持的平台是"visualdl"/"wandb"/"tensorboard""none"表示不使用任何集成。
  23. wandb_api_key:Weights & Biases (WandB) API 密钥,用于与 WandB 服务进行身份验证。
  24. resume_from_checkpoint:要从中恢复的模型检查点的路径。此参数不直接由Trainer使用,而是由您的训练/评估脚本使用。
  25. flatten_param_grads:是否在优化器中使用flatten_param_grads方法,仅用于 NPU 设备。默认为False
  26. skip_profile_timer:是否跳过分析时间计时器,计时器将记录前向/反向/步等的时间使用情况。
  27. distributed_dataloader:是否使用分布式数据加载器。默认为False
    这些参数提供了对训练过程的细粒度控制,允许用户根据他们的需求和硬件配置来优化训练。在实际应用中,这些参数可能需要根据具体情况进行调整,以达到最佳的训练效果。

这篇关于飞桨自然语言处理框架 paddlenlp的 trainer的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/686725

相关文章

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

Python进行JSON和Excel文件转换处理指南

《Python进行JSON和Excel文件转换处理指南》在数据交换与系统集成中,JSON与Excel是两种极为常见的数据格式,本文将介绍如何使用Python实现将JSON转换为格式化的Excel文件,... 目录将 jsON 导入为格式化 Excel将 Excel 导出为结构化 JSON处理嵌套 JSON:

Python Web框架Flask、Streamlit、FastAPI示例详解

《PythonWeb框架Flask、Streamlit、FastAPI示例详解》本文对比分析了Flask、Streamlit和FastAPI三大PythonWeb框架:Flask轻量灵活适合传统应用... 目录概述Flask详解Flask简介安装和基础配置核心概念路由和视图模板系统数据库集成实际示例Stre

Spring Boot 中的默认异常处理机制及执行流程

《SpringBoot中的默认异常处理机制及执行流程》SpringBoot内置BasicErrorController,自动处理异常并生成HTML/JSON响应,支持自定义错误路径、配置及扩展,如... 目录Spring Boot 异常处理机制详解默认错误页面功能自动异常转换机制错误属性配置选项默认错误处理

SpringBoot 异常处理/自定义格式校验的问题实例详解

《SpringBoot异常处理/自定义格式校验的问题实例详解》文章探讨SpringBoot中自定义注解校验问题,区分参数级与类级约束触发的异常类型,建议通过@RestControllerAdvice... 目录1. 问题简要描述2. 异常触发1) 参数级别约束2) 类级别约束3. 异常处理1) 字段级别约束

Olingo分析和实践之OData框架核心组件初始化(关键步骤)

《Olingo分析和实践之OData框架核心组件初始化(关键步骤)》ODataSpringBootService通过初始化OData实例和服务元数据,构建框架核心能力与数据模型结构,实现序列化、URI... 目录概述第一步:OData实例创建1.1 OData.newInstance() 详细分析1.1.1

Java堆转储文件之1.6G大文件处理完整指南

《Java堆转储文件之1.6G大文件处理完整指南》堆转储文件是优化、分析内存消耗的重要工具,:本文主要介绍Java堆转储文件之1.6G大文件处理的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言文件为什么这么大?如何处理这个文件?分析文件内容(推荐)删除文件(如果不需要)查看错误来源如何避

使用Python构建一个高效的日志处理系统

《使用Python构建一个高效的日志处理系统》这篇文章主要为大家详细讲解了如何使用Python开发一个专业的日志分析工具,能够自动化处理、分析和可视化各类日志文件,大幅提升运维效率,需要的可以了解下... 目录环境准备工具功能概述完整代码实现代码深度解析1. 类设计与初始化2. 日志解析核心逻辑3. 文件处

Java docx4j高效处理Word文档的实战指南

《Javadocx4j高效处理Word文档的实战指南》对于需要在Java应用程序中生成、修改或处理Word文档的开发者来说,docx4j是一个强大而专业的选择,下面我们就来看看docx4j的具体使用... 目录引言一、环境准备与基础配置1.1 Maven依赖配置1.2 初始化测试类二、增强版文档操作示例2.

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口