scikit-learn中用gridsearchcv给随机森林(RF)自动调参

2024-03-15 16:58

本文主要是介绍scikit-learn中用gridsearchcv给随机森林(RF)自动调参,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

全文参考 1:http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_digits.html#parameter-estimation-using-grid-search-with-cross-validation

全文参考 2:http://scikit-learn.org/stable/modules/model_evaluation.html#the-scoring-parameter-defining-model-evaluation-rules

全文参考 3:http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score

全文参考 4:http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py

实验重点:随机森林(RandomForest) + 5折交叉验证(Cross-Validation) + 网格参数寻优(GridSearchCV) + 二分类问题中ROC曲线的绘制。

由于原始数据本身质量很好,且正负样本基本均衡,没有做数据预处理工作。

  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from sklearn.metrics import roc_curve
  5. from sklearn.metrics import roc_auc_score
  6. from sklearn.metrics import classification_report
  7. from sklearn.model_selection import GridSearchCV
  8. from sklearn.ensemble import RandomForestClassifier

  1. #导入数据,来源于:http://mldata.org/repository/tags/data/IDA_Benchmark_Repository/,见上图
  2. dataset = pd.read_csv('image_data.csv', header=None, encoding='utf-8')
  3. dataset_positive = dataset[dataset[0] == 1.0]
  4. dataset_negative = dataset[dataset[0] == -1.0]
  5. #训练集和测试集按照7:3分割,分割时兼顾正负样本所占比例
  6. #其中训练集基于5折交叉验证做网格搜索找出最优参数,应用于测试集以评价算法性能
  7. train_dataset = pd.concat([dataset_positive[0:832], dataset_negative[0:628]])
  8. train_recon = train_dataset.sort_index(axis=0, ascending=True)
  9. test_dataset = pd.concat([dataset_positive[832:1188], dataset_negative[628:898]])
  10. test_recon = test_dataset.sort_index(axis=0, ascending=True)
  11. y_train = np.array(train_recon[0])
  12. X_train = np.array(train_recon.drop([0], axis=1))
  13. y_test = np.array(test_recon[0])
  14. X_test = np.array(test_recon.drop([0], axis=1))

  1. # Set the parameters by cross-validation
  2. parameter_space = {
  3. "n_estimators": [10, 15, 20],
  4. "criterion": ["gini", "entropy"],
  5. "min_samples_leaf": [2, 4, 6],
  6. }
  7. #scores = ['precision', 'recall', 'roc_auc']
  8. scores = ['roc_auc']
  9. for score in scores:
  10. print("# Tuning hyper-parameters for %s" % score)
  11. print()
  12. clf = RandomForestClassifier(random_state=14)
  13. grid = GridSearchCV(clf, parameter_space, cv=5, scoring='%s' % score)
  14. #scoring='%s_macro' % score:precision_macro、recall_macro是用于multiclass/multilabel任务的
  15. grid.fit(X_train, y_train)
  16. print("Best parameters set found on development set:")
  17. print()
  18. print(grid.best_params_)
  19. print()
  20. print("Grid scores on development set:")
  21. print()
  22. means = grid.cv_results_['mean_test_score']
  23. stds = grid.cv_results_['std_test_score']
  24. for mean, std, params in zip(means, stds, grid.cv_results_['params']):
  25. print("%0.3f (+/-%0.03f) for %r"
  26. % (mean, std * 2, params))
  27. print()
  28. print("Detailed classification report:")
  29. print()
  30. print("The model is trained on the full development set.")
  31. print("The scores are computed on the full evaluation set.")
  32. print()
  33. bclf = grid.best_estimator_
  34. bclf.fit(X_train, y_train)
  35. y_true = y_test
  36. y_pred = bclf.predict(X_test)
  37. y_pred_pro = bclf.predict_proba(X_test)
  38. y_scores = pd.DataFrame(y_pred_pro, columns=bclf.classes_.tolist())[1].values
  39. print(classification_report(y_true, y_pred))
  40. auc_value = roc_auc_score(y_true, y_scores)

输出结果:


  1. #绘制ROC曲线
  2. fpr, tpr, thresholds = roc_curve(y_true, y_scores, pos_label=1.0)
  3. plt.figure()
  4. lw = 2
  5. plt.plot(fpr, tpr, color='darkorange', linewidth=lw, label='ROC curve (area = %0.4f)' % auc_value)
  6. plt.plot([0, 1], [0, 1], color='navy', linewidth=lw, linestyle='--')
  7. plt.xlim([0.0, 1.0])
  8. plt.ylim([0.0, 1.05])
  9. plt.xlabel('False Positive Rate')
  10. plt.ylabel('True Positive Rate')
  11. plt.title('Receiver operating characteristic example')
  12. plt.legend(loc="lower right")
  13. plt.show()




转自: https://blog.csdn.net/lixiaowang_327/article/details/53434744

这篇关于scikit-learn中用gridsearchcv给随机森林(RF)自动调参的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/812643

相关文章

一文详解MySQL如何设置自动备份任务

《一文详解MySQL如何设置自动备份任务》设置自动备份任务可以确保你的数据库定期备份,防止数据丢失,下面我们就来详细介绍一下如何使用Bash脚本和Cron任务在Linux系统上设置MySQL数据库的自... 目录1. 编写备份脚本1.1 创建并编辑备份脚本1.2 给予脚本执行权限2. 设置 Cron 任务2

MyBatis Plus 中 update_time 字段自动填充失效的原因分析及解决方案(最新整理)

《MyBatisPlus中update_time字段自动填充失效的原因分析及解决方案(最新整理)》在使用MyBatisPlus时,通常我们会在数据库表中设置create_time和update... 目录前言一、问题现象二、原因分析三、总结:常见原因与解决方法对照表四、推荐写法前言在使用 MyBATis

Python使用smtplib库开发一个邮件自动发送工具

《Python使用smtplib库开发一个邮件自动发送工具》在现代软件开发中,自动化邮件发送是一个非常实用的功能,无论是系统通知、营销邮件、还是日常工作报告,Python的smtplib库都能帮助我们... 目录代码实现与知识点解析1. 导入必要的库2. 配置邮件服务器参数3. 创建邮件发送类4. 实现邮件

Python使用pynput模拟实现键盘自动输入工具

《Python使用pynput模拟实现键盘自动输入工具》在日常办公和软件开发中,我们经常需要处理大量重复的文本输入工作,所以本文就来和大家介绍一款使用Python的PyQt5库结合pynput键盘控制... 目录概述:当自动化遇上可视化功能全景图核心功能矩阵技术栈深度效果展示使用教程四步操作指南核心代码解析

SpringBoot实现文件记录日志及日志文件自动归档和压缩

《SpringBoot实现文件记录日志及日志文件自动归档和压缩》Logback是Java日志框架,通过Logger收集日志并经Appender输出至控制台、文件等,SpringBoot配置logbac... 目录1、什么是Logback2、SpringBoot实现文件记录日志,日志文件自动归档和压缩2.1、

SpringCloud使用Nacos 配置中心实现配置自动刷新功能使用

《SpringCloud使用Nacos配置中心实现配置自动刷新功能使用》SpringCloud项目中使用Nacos作为配置中心可以方便开发及运维人员随时查看配置信息,及配置共享,并且Nacos支持配... 目录前言一、Nacos中集中配置方式?二、使用步骤1.使用$Value 注解2.使用@Configur

Golang实现Redis分布式锁(Lua脚本+可重入+自动续期)

《Golang实现Redis分布式锁(Lua脚本+可重入+自动续期)》本文主要介绍了Golang分布式锁实现,采用Redis+Lua脚本确保原子性,持可重入和自动续期,用于防止超卖及重复下单,具有一定... 目录1 概念应用场景分布式锁必备特性2 思路分析宕机与过期防止误删keyLua保证原子性可重入锁自动

python利用backoff实现异常自动重试详解

《python利用backoff实现异常自动重试详解》backoff是一个用于实现重试机制的Python库,通过指数退避或其他策略自动重试失败的操作,下面小编就来和大家详细讲讲如何利用backoff实... 目录1. backoff 库简介2. on_exception 装饰器的原理2.1 核心逻辑2.2

Java如何根据文件名前缀自动分组图片文件

《Java如何根据文件名前缀自动分组图片文件》一大堆文件(比如图片)堆在一个目录下,它们的命名规则遵循一定的格式,混在一起很难管理,所以本文小编就和大家介绍一下如何使用Java根据文件名前缀自动分组图... 目录需求背景分析思路实现代码输出结果知识扩展需求一大堆文件(比如图片)堆在一个目录下,它们的命名规

使用Python实现实时金价监控并自动提醒功能

《使用Python实现实时金价监控并自动提醒功能》在日常投资中,很多朋友喜欢在一些平台买点黄金,低买高卖赚点小差价,但黄金价格实时波动频繁,总是盯着手机太累了,于是我用Python写了一个实时金价监控... 目录工具能干啥?手把手教你用1、先装好这些"食材"2、代码实现讲解1. 用户输入参数2. 设置无头浏