XGB-24:使用Scikit-Learn估计器接口

2024-03-27 16:52

本文主要是介绍XGB-24:使用Scikit-Learn估计器接口,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

概览

除了原生接口之外,XGBoost还提供了一个符合sklearn估计器指南的sklearn估计器接口。它支持回归、分类和学习排名。sklearn估计器接口的生存训练仍在进行中。
你可以在使用sklearn接口的示例集合中找到一些快速入门示例。使用sklearn接口的主要优势在于,它可以与sklearn提供的大多数实用程序一起工作,例如sklearn.model_selection.cross_validate()。此外,由于其流行度,许多其他库也认识sklearn估计器接口。
使用sklearn估计器接口,我们只需要几行Python代码就可以训练一个分类模型。下面是训练一个分类模型的示例:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_splitimport xgboost as xgbX, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=94)# Use "hist" for constructing the trees, with early stopping enabled.
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2)# Fit the model, test sets are used for early stopping.
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])# Save model into JSON format.
clf.save_model("clf.json")

tree_method参数指定了构建树的方法,而early_stopping_rounds参数则启用了提前停止。提前停止可以帮助防止过拟合,并在训练过程中节省时间。

提前停止Early Stopping

可以通过参数early_stopping_rounds启用提前停止。另外,还可以使用回调函数xgboost.callback.EarlyStopping来指定有关提前停止行为的更多细节,包括XGBoost是否应返回最佳模型而不是完整的树栈:

early_stop = xgb.callback.EarlyStopping(rounds=2, metric_name='logloss', data_name='Validation_0', save_best=True
)
clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop])
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])

目前,XGBoost在估计器中没有实现数据拆分逻辑,而是依赖于xgboost.XGBModel.fit()方法的eval_set参数。如果要使用提前停止来防止过拟合,需要使用sklearn库中的sklearn.model_selection.train_test_split()函数手动将数据拆分为训练集和测试集。一些其他的机器学习算法,比如sklearn中的算法,将提前停止作为估计器的一部分,并且可以与交叉验证一起使用。然而,在交叉验证过程中使用提前停止可能并不是一个完美的方法,因为它会改变每个验证折叠的模型树的数量,导致不同的模型。一个更好的方法是在交叉验证后使用最佳的超参数以及提前停止重新训练模型。如果想尝试使用提前停止进行交叉验证的想法,这是一个开始的代码片段:

from sklearn.base import clone
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import StratifiedKFold, cross_validateimport xgboost as xgbX, y = load_breast_cancer(return_X_y=True)def fit_and_score(estimator, X_train, X_test, y_train, y_test):"""Fit the estimator on the train set and score it on both sets"""estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])train_score = estimator.score(X_train, y_train)test_score = estimator.score(X_test, y_test)return estimator, train_score, test_scorecv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=3)results = {}for train, test in cv.split(X, y):X_train = X[train]X_test = X[test]y_train = y[train]y_test = y[test]est, train_score, test_score = fit_and_score(clone(clf), X_train, X_test, y_train, y_test)results[est] = (train_score, test_score)

获取原生 booster 对象

**Sklearn 估计器接口主要用于训练,并没有实现 XGBoost 中所有可用的功能。**例如,为了获得缓存的预测结果,需要使用 xgboost.Booster.predict() 方法配合 xgboost.DMatrix。可以通过 xgboost.XGBModel.get_booster() 方法从 sklearn 接口中获取 booster 对象。

booster = clf.get_booster()
print(booster.num_boosted_rounds())

预测

当启用提前停止时,包括xgboost.XGBModel.predict()xgboost.XGBModel.score()xgboost.XGBModel.apply()在内的预测函数将自动使用最佳模型。这意味着xgboost.XGBModel.best_iteration用于指定在预测中使用的树的范围。
为了获得增量预测的缓存结果,可以使用xgboost.Booster.predict()方法。

并行线程数

在处理XGBoost和其他sklearn工具时,可以通过使用n_jobs参数来指定想要使用的线程数。默认情况下,XGBoost会使用计算机上所有可用的线程,这可能会在与sklearn的其他功能(如sklearn.model_selection.cross_validate())结合使用时产生一些有趣的结果。**如果XGBoost和sklearn都设置为使用所有线程,计算机可能会因为所谓的“线程颠簸”而显著变慢。**为了避免这种情况,只需将XGBoost的n_jobs参数设置为None(这使用了所有线程),并将sklearn的n_jobs参数设置为1。这样,这两个程序就能够顺畅地一起工作,而不会给计算机造成任何不必要的负担。

参考

  • https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator
  • https://xgboost.readthedocs.io/en/latest/python/examples/sklearn_examples.html#sphx-glr-python-examples-sklearn-examples-py
  • https://xgboost.readthedocs.io/en/latest/python/sklearn_estimator.html

这篇关于XGB-24:使用Scikit-Learn估计器接口的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/852808

相关文章

python使用库爬取m3u8文件的示例

《python使用库爬取m3u8文件的示例》本文主要介绍了python使用库爬取m3u8文件的示例,可以使用requests、m3u8、ffmpeg等库,实现获取、解析、下载视频片段并合并等步骤,具有... 目录一、准备工作二、获取m3u8文件内容三、解析m3u8文件四、下载视频片段五、合并视频片段六、错误

gitlab安装及邮箱配置和常用使用方式

《gitlab安装及邮箱配置和常用使用方式》:本文主要介绍gitlab安装及邮箱配置和常用使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1.安装GitLab2.配置GitLab邮件服务3.GitLab的账号注册邮箱验证及其分组4.gitlab分支和标签的

spring中的ImportSelector接口示例详解

《spring中的ImportSelector接口示例详解》Spring的ImportSelector接口用于动态选择配置类,实现条件化和模块化配置,关键方法selectImports根据注解信息返回... 目录一、核心作用二、关键方法三、扩展功能四、使用示例五、工作原理六、应用场景七、自定义实现Impor

SpringBoot3应用中集成和使用Spring Retry的实践记录

《SpringBoot3应用中集成和使用SpringRetry的实践记录》SpringRetry为SpringBoot3提供重试机制,支持注解和编程式两种方式,可配置重试策略与监听器,适用于临时性故... 目录1. 简介2. 环境准备3. 使用方式3.1 注解方式 基础使用自定义重试策略失败恢复机制注意事项

nginx启动命令和默认配置文件的使用

《nginx启动命令和默认配置文件的使用》:本文主要介绍nginx启动命令和默认配置文件的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录常见命令nginx.conf配置文件location匹配规则图片服务器总结常见命令# 默认配置文件启动./nginx

在Windows上使用qemu安装ubuntu24.04服务器的详细指南

《在Windows上使用qemu安装ubuntu24.04服务器的详细指南》本文介绍了在Windows上使用QEMU安装Ubuntu24.04的全流程:安装QEMU、准备ISO镜像、创建虚拟磁盘、配置... 目录1. 安装QEMU环境2. 准备Ubuntu 24.04镜像3. 启动QEMU安装Ubuntu4

使用Python和OpenCV库实现实时颜色识别系统

《使用Python和OpenCV库实现实时颜色识别系统》:本文主要介绍使用Python和OpenCV库实现的实时颜色识别系统,这个系统能够通过摄像头捕捉视频流,并在视频中指定区域内识别主要颜色(红... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间详解

Windows下C++使用SQLitede的操作过程

《Windows下C++使用SQLitede的操作过程》本文介绍了Windows下C++使用SQLite的安装配置、CppSQLite库封装优势、核心功能(如数据库连接、事务管理)、跨平台支持及性能优... 目录Windows下C++使用SQLite1、安装2、代码示例CppSQLite:C++轻松操作SQ

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

Python并行处理实战之如何使用ProcessPoolExecutor加速计算

《Python并行处理实战之如何使用ProcessPoolExecutor加速计算》Python提供了多种并行处理的方式,其中concurrent.futures模块的ProcessPoolExecu... 目录简介完整代码示例代码解释1. 导入必要的模块2. 定义处理函数3. 主函数4. 生成数字列表5.