NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调

2024-03-03 22:40

本文主要是介绍NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  以往,我们在使用HuggingFace在训练BERT模型时,代码写得比较复杂,涉及到数据处理、token编码、模型编码、模型训练等步骤,从事NLP领域的人都有这种切身感受。事实上,HugggingFace中提供了datasets模块(数据处理)和Trainer函数,使得我们的模型训练较为方便。关于datasets模块,可参考文章NLP(六十二)HuggingFace中的Datasets使用。
  本文将会介绍如何使用HuggingFace中的Trainer对BERT模型微调。

Trainer

  Trainer是HuggingFace中的模型训练函数,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer 。
  Trainer的传入参数如下:

model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None
args: TrainingArguments = None
data_collator: typing.Optional[DataCollator] = None
train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None
eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None
tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None
compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None
callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None
optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None)
preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )

参数解释:

  • model为预训练模型
  • args为TrainingArguments(训练参数)类
  • data_collator会将数据集中的元素组成一个batch,默认使用default_data_collator(),如果tokenizer没有提供,则使用DataCollatorWithPadding
  • train_dataset, eval_dataset为训练集,验证集
  • tokenizer为模型训练使用的tokenizer
  • model_init为模型初始化
  • compute_metrics为验证集的评估指标计算函数
  • callbacks为训练过程中的callback列表
  • optimizers为模型训练中的优化器
  • preprocess_logits_for_metrics为模型评估阶段前对logits的预处理

  TrainingArguments为训练参数类,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments,传入参数非常多(transformers版本4.32.1中有98个参数!),我们在这里只介绍几个常见的:

output_dir: stroverwrite_output_dir: bool = False
evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no'
per_gpu_train_batch_size: typing.Optional[int] = None
per_gpu_eval_batch_size: typing.Optional[int] = None
learning_rate: float = 5e-05
num_train_epochs: float = 3.0
logging_dir: typing.Optional[str] = None
logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'
save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'save_steps: float = 500
report_to: typing.Optional[typing.List[str]] = None

参数解释:

  • output_dir为模型输出目录
  • evaluation_strategy为模型评估策略
  1. “no": 不做模型评估
  2. “steps”: 按训练步数(steps)进行评估,需指定步数
  3. “epoch”: 每个epoch训练完后进行评估
  • per_gpu_train_batch_size, per_gpu_eval_batch_size为每个GPU上训练集和测试集的batch size,也有CPU上的对应参数
  • learning_rate为学习率
  • logging_dir为日志输出目录
  • logging_strategy为日志输出策略,同样有no, steps, epoch三种,意义同上
  • save_strategy为模型保存策略,同样有no, steps, epoch三种,意义同上
  • report_to为模型训练、评估中的重要指标(如loss, accurace)输出之处,可选择azure_ml, clearml, codecarbon, comet_ml, dagshub, flyte, mlflow, neptune, tensorboard, wandb,使用all会输出到所有的地方,使用no则不会输出。

  下面我们使用Trainer进行BERT模型微调,给出英语、中文数据集上文本分类的示例代码。

BERT微调

  使用datasets模块导入imdb数据集(英语影评数据集,常用于文本分类),加载预训练模型bert-base-cased的tokenizer。

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
raw_datasets = datasets.load_dataset('imdb')

  查看数据集,有train(训练集)、test(测试集)、unsupervised(非监督)三部分,我们这里使用训练集和测试集,各自有25000个样本。

raw_datasets
DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 25000})test: Dataset({features: ['text', 'label'],num_rows: 25000})unsupervised: Dataset({features: ['text', 'label'],num_rows: 50000})
})

  创建数据tokenize函数,对文本进行tokenize,最大长度设置为300,同时使用data_collector为DataCollatorWithPadding。

def tokenize_function(sample):return tokenizer(sample['text'], max_length=300, truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

  加载分类模型,输出类别为2.

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

  设置compute_metrics函数,在评估过程中输出accuracy, f1, precision, recall四个指标。设置训练参数TrainingArguments类,设置Trainer。

from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_supportdef compute_metrics(pred):labels = pred.label_idspreds = pred.predictions.argmax(-1)precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')acc = accuracy_score(labels, preds)return {'accuracy': acc,'f1': f1,'precision': precision,'recall': recall}training_args = TrainingArguments(output_dir='imdb_test_trainer', # 指定输出文件夹,没有会自动创建evaluation_strategy="epoch",per_device_train_batch_size=32,per_device_eval_batch_size=32,learning_rate=5e-5,num_train_epochs=3,warmup_ratio=0.2,logging_dir='./imdb_train_logs',logging_strategy="epoch",save_strategy="epoch",report_to="tensorboard") trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["test"],data_collator=data_collator,  # 在定义了tokenizer之后,其实这里的data_collator就不用再写了,会自动根据tokenizer创建tokenizer=tokenizer,compute_metrics=compute_metrics
)

  开启模型训练。

trainer.train()
EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.3643000.2232230.9106000.9105090.9122760.910600
20.1648000.2044200.9239600.9239410.9243750.923960
30.0710000.2413500.9255200.9255100.9257590.925520
TrainOutput(global_step=588, training_loss=0.20003824169132986, metrics={'train_runtime': 1539.8692, 'train_samples_per_second': 48.705, 'train_steps_per_second': 0.382, 'total_flos': 1.156249755e+16, 'train_loss': 0.20003824169132986, 'epoch': 3.0})

  以上为英语数据集的文本分类模型微调。
  中文数据集使用sougou-mini数据集(训练集4000个样本,测试集495个样本,共5个输出类别),预训练模型采用bert-base-chinese。代码基本与英语数据集差不多,只要修改 预训练模型,数据集加载 和 最大长度为128,输出类别。以下是不同的代码之处:

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)data_files = {"train": "./data/sougou/train.csv", "test": "./data/sougou/test.csv"}
raw_datasets = datasets.load_dataset("csv", data_files=data_files, delimiter=",")
...
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=5)
...

输出结果如下:

EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.8492000.1151890.9696970.9694490.9700730.969697
20.1069000.0939870.9737370.9737700.9753720.973737
30.0478000.0788610.9737370.9737400.9741170.973737

模型评估

  在上述模型评估过程中,已经有了模型评估的各项指标。
  本文也给出单独做模型评估的代码,方便后续对模型做量化时(后续介绍BERT模型的动态量化)获取量化前后模型推理的各项指标。
  中文数据集文本分类模型评估代码如下:

import torch
from transformers import AutoModelForSequenceClassificationMAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)from transformers import AutoTokenizer, DataCollatorWithPaddingtokenizer = AutoTokenizer.from_pretrained(checkpoint)import pandas as pdtest_df = pd.read_csv("./data/sougou/test.csv")
test_df.head()
textlabel
0届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...0
1商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...0
2今天下午,沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多,因此这场比赛实际...0
3本报讯中国足协准备好了与特鲁西埃谈判的合同文本,也在北京给他预订好了房间,但特鲁西埃爽约了!...0
4网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日,2006年尤伯杯羽毛球赛在日...0
import numpy as np
import times_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():row_s_time = time.time()true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)# print(encoded_text)logits = model(**encoded_text)label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]pred_labels.append(label_id)if i % 100 == 0:print(i, (time.time() - row_s_time)*1000, label_id)print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])
0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time:  352.44047810332944
true_labels[:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pred_labels[:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
from sklearn.metrics import classification_reportprint(classification_report(true_labels, pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9691    0.9495    0.9592        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9895    0.9495    0.9691        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495

总结

  本文介绍了如何使用HuggingFace中的Trainer对BERT模型微调。可以看到,使用Trainer进行模型微调,代码较为简洁,且支持功能丰富,是理想的模型训练方式。
  本文项目代码已开源至Github,网址为:https://github.com/percent4/PyTorch_Learning/tree/master/huggingface_learning 。
  本人已开通个人博客网站,网址为:https://percent4.github.io/ ,欢迎大家访问~

  欢迎关注我的公众号NLP奇幻之旅,原创技术文章第一时间推送。

  欢迎关注我的知识星球“自然语言处理奇幻之旅”,笔者正在努力构建自己的技术社区。

这篇关于NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/771150

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

利用python实现对excel文件进行加密

《利用python实现对excel文件进行加密》由于文件内容的私密性,需要对Excel文件进行加密,保护文件以免给第三方看到,本文将以Python语言为例,和大家讲讲如何对Excel文件进行加密,感兴... 目录前言方法一:使用pywin32库(仅限Windows)方法二:使用msoffcrypto-too

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删