XGB-24:使用Scikit-Learn估计器接口

2024-03-27 16:52

本文主要是介绍XGB-24:使用Scikit-Learn估计器接口,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

概览

除了原生接口之外,XGBoost还提供了一个符合sklearn估计器指南的sklearn估计器接口。它支持回归、分类和学习排名。sklearn估计器接口的生存训练仍在进行中。
你可以在使用sklearn接口的示例集合中找到一些快速入门示例。使用sklearn接口的主要优势在于,它可以与sklearn提供的大多数实用程序一起工作,例如sklearn.model_selection.cross_validate()。此外,由于其流行度,许多其他库也认识sklearn估计器接口。
使用sklearn估计器接口,我们只需要几行Python代码就可以训练一个分类模型。下面是训练一个分类模型的示例:

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_splitimport xgboost as xgbX, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=94)# Use "hist" for constructing the trees, with early stopping enabled.
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2)# Fit the model, test sets are used for early stopping.
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])# Save model into JSON format.
clf.save_model("clf.json")

tree_method参数指定了构建树的方法,而early_stopping_rounds参数则启用了提前停止。提前停止可以帮助防止过拟合,并在训练过程中节省时间。

提前停止Early Stopping

可以通过参数early_stopping_rounds启用提前停止。另外,还可以使用回调函数xgboost.callback.EarlyStopping来指定有关提前停止行为的更多细节,包括XGBoost是否应返回最佳模型而不是完整的树栈:

early_stop = xgb.callback.EarlyStopping(rounds=2, metric_name='logloss', data_name='Validation_0', save_best=True
)
clf = xgb.XGBClassifier(tree_method="hist", callbacks=[early_stop])
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])

目前,XGBoost在估计器中没有实现数据拆分逻辑,而是依赖于xgboost.XGBModel.fit()方法的eval_set参数。如果要使用提前停止来防止过拟合,需要使用sklearn库中的sklearn.model_selection.train_test_split()函数手动将数据拆分为训练集和测试集。一些其他的机器学习算法,比如sklearn中的算法,将提前停止作为估计器的一部分,并且可以与交叉验证一起使用。然而,在交叉验证过程中使用提前停止可能并不是一个完美的方法,因为它会改变每个验证折叠的模型树的数量,导致不同的模型。一个更好的方法是在交叉验证后使用最佳的超参数以及提前停止重新训练模型。如果想尝试使用提前停止进行交叉验证的想法,这是一个开始的代码片段:

from sklearn.base import clone
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import StratifiedKFold, cross_validateimport xgboost as xgbX, y = load_breast_cancer(return_X_y=True)def fit_and_score(estimator, X_train, X_test, y_train, y_test):"""Fit the estimator on the train set and score it on both sets"""estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])train_score = estimator.score(X_train, y_train)test_score = estimator.score(X_test, y_test)return estimator, train_score, test_scorecv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=3)results = {}for train, test in cv.split(X, y):X_train = X[train]X_test = X[test]y_train = y[train]y_test = y[test]est, train_score, test_score = fit_and_score(clone(clf), X_train, X_test, y_train, y_test)results[est] = (train_score, test_score)

获取原生 booster 对象

**Sklearn 估计器接口主要用于训练,并没有实现 XGBoost 中所有可用的功能。**例如,为了获得缓存的预测结果,需要使用 xgboost.Booster.predict() 方法配合 xgboost.DMatrix。可以通过 xgboost.XGBModel.get_booster() 方法从 sklearn 接口中获取 booster 对象。

booster = clf.get_booster()
print(booster.num_boosted_rounds())

预测

当启用提前停止时,包括xgboost.XGBModel.predict()xgboost.XGBModel.score()xgboost.XGBModel.apply()在内的预测函数将自动使用最佳模型。这意味着xgboost.XGBModel.best_iteration用于指定在预测中使用的树的范围。
为了获得增量预测的缓存结果,可以使用xgboost.Booster.predict()方法。

并行线程数

在处理XGBoost和其他sklearn工具时,可以通过使用n_jobs参数来指定想要使用的线程数。默认情况下,XGBoost会使用计算机上所有可用的线程,这可能会在与sklearn的其他功能(如sklearn.model_selection.cross_validate())结合使用时产生一些有趣的结果。**如果XGBoost和sklearn都设置为使用所有线程,计算机可能会因为所谓的“线程颠簸”而显著变慢。**为了避免这种情况,只需将XGBoost的n_jobs参数设置为None(这使用了所有线程),并将sklearn的n_jobs参数设置为1。这样,这两个程序就能够顺畅地一起工作,而不会给计算机造成任何不必要的负担。

参考

  • https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator
  • https://xgboost.readthedocs.io/en/latest/python/examples/sklearn_examples.html#sphx-glr-python-examples-sklearn-examples-py
  • https://xgboost.readthedocs.io/en/latest/python/sklearn_estimator.html

这篇关于XGB-24:使用Scikit-Learn估计器接口的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/852808

相关文章

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Java使用jar命令配置服务器端口的完整指南

《Java使用jar命令配置服务器端口的完整指南》本文将详细介绍如何使用java-jar命令启动应用,并重点讲解如何配置服务器端口,同时提供一个实用的Web工具来简化这一过程,希望对大家有所帮助... 目录1. Java Jar文件简介1.1 什么是Jar文件1.2 创建可执行Jar文件2. 使用java

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

SpringBoot实现不同接口指定上传文件大小的具体步骤

《SpringBoot实现不同接口指定上传文件大小的具体步骤》:本文主要介绍在SpringBoot中通过自定义注解、AOP拦截和配置文件实现不同接口上传文件大小限制的方法,强调需设置全局阈值远大于... 目录一  springboot实现不同接口指定文件大小1.1 思路说明1.2 工程启动说明二 具体实施2

Java中的抽象类与abstract 关键字使用详解

《Java中的抽象类与abstract关键字使用详解》:本文主要介绍Java中的抽象类与abstract关键字使用详解,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、抽象类的概念二、使用 abstract2.1 修饰类 => 抽象类2.2 修饰方法 => 抽象方法,没有

MyBatis ParameterHandler的具体使用

《MyBatisParameterHandler的具体使用》本文主要介绍了MyBatisParameterHandler的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、概述二、源码1 关键属性2.setParameters3.TypeHandler1.TypeHa

Spring 中的切面与事务结合使用完整示例

《Spring中的切面与事务结合使用完整示例》本文给大家介绍Spring中的切面与事务结合使用完整示例,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录 一、前置知识:Spring AOP 与 事务的关系 事务本质上就是一个“切面”二、核心组件三、完

使用docker搭建嵌入式Linux开发环境

《使用docker搭建嵌入式Linux开发环境》本文主要介绍了使用docker搭建嵌入式Linux开发环境,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录1、前言2、安装docker3、编写容器管理脚本4、创建容器1、前言在日常开发全志、rk等不同

使用Python实现Word文档的自动化对比方案

《使用Python实现Word文档的自动化对比方案》我们经常需要比较两个Word文档的版本差异,无论是合同修订、论文修改还是代码文档更新,人工比对不仅效率低下,还容易遗漏关键改动,下面通过一个实际案例... 目录引言一、使用python-docx库解析文档结构二、使用difflib进行差异比对三、高级对比方