机器学习笔记:如何用MLflow管理模型生命周期

2023-11-29 02:18

本文主要是介绍机器学习笔记:如何用MLflow管理模型生命周期,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1

MLflow介绍

MLflow是一个解决机器学习生命周期管理的平台,在上面对模型进行跟踪、重现、管理和部署。

MLFlow解决了如下几个问题:

1、算法训练实验难于追踪,所以我们需要有一个实验管理工具Tracking。这个工具能够记录算法,算法参数,模型结果,效果等数据。

2、算法脚本难于重复运行,原因很多,比如代码版本,模型参数,还有运行环境。解决办法就是所有的算法项目应该都有一套标准的Projects概念,记录下来这些东西,并且这个Projects是可以拟合所有算法框架的。

3、部署模型是一个艰难的过程,ML界目前还没有一个标准的打包和部署模型的机制。解决的办法是Models概念,Models提供了工具和标准帮助你部署各种算法框架的模型。

MLflow提供了三个主要组件:

  • MLflow 跟踪模块 ,用于记录实验数据的API和UI,其中的数据包括参数、代码版本、评价指标和使用过的输出文件。

  • MLflow 项目模块 ,用于可重现性运行的一种代码封装格式,将代码封装在MLflow项目中,可以指定其中的依赖关系,并允许任何其他用户稍后再次运行它,并可靠地对结果进行重现。

  • MLflow 模型模块 ,一种简单的模型封装格式,允许将模型部署到许多工具。例如,如果将模型封装为一个Python函数,MLflow模型就可以将其部署到Docker或Azure ML进行线上服务,部署到Apache Spark进行批量打分。

2

数据集介绍

这里使用的是Kaggle竞赛题中预测信用卡风险客户的数据集,数据链接:

https://www.kaggle.com/c/home-credit-default-risk/data。

其中有一张表中是HomeCredit_columns.csv,记录了每个栏位的含义。数据集任务是需要判断每单申请信贷资金的风险等级,这是一个比较明确的二分类任务。这里我们主要使用:

  • application_train.csv

  • application_test.csv

3

特征工程

train文件相较于test文件多了一栏目标变量TARGET,Kaggle上有一句非常经典的话,数据和特征决定了机器学习的上限,而模型和算法只是逼近这个上限而已。通过总结和归纳,特征工程的步骤主要包括:

from sklearn.preprocessing import LabelEncoder
import pickle
import pandas as pdtrain_data = pd.read.csv("./application_train.csv")
print(application_train.select_dtypes(include=['object']).head())le = LabelEncoder() 
train_data['ORGANIZATION_TYPE'] = le.fit_transform(train_data['ORGANIZATION_TYPE'].values.tolist()+["UNKNOWN"])
output = open('ORGANIZATION_TYPE.pkl', 'wb')
pickle.dump(le, output)
output.close()# test文件转换
test_data = pd.read.csv("./application_test.csv")
pkl_file = open('ORGANIZATION_TYPE.pkl', 'rb')
le_origanization = pickle.load(pkl_file) 
pkl_file.close()
# 查看transformer中类别
print(le_origanization.classes_)
# 替换test中ORGANIZATION_TYPE未出现在train文件中为"UNKNOWN"
ORGANIZATION_TYPE_diff = set(test_data["ORGANIZATION_TYPE"])-set(train_data["ORGANIZATION_TYPE"])
if len(ORGANIZATION_TYPE_diff)>0:replace_dict = dict(zip(list(ORGANIZATION_TYPE_diff),len(list(ORGANIZATION_TYPE_diff))*["UNKNOWN"]))test_data["ORGANIZATION_TYPE"].replace(replace_dict,inplace=True)test_data['ORGANIZATION_TYPE'] = le_departure.transform(test_data['ORGANIZATION_TYPE'])

实际工程实施中,离线跑批模型是将特征工程加工过程提交到spark、Hadoop等大数据平台上进行的,产出变量宽表后再进行预测。上面的LabelEncoder与StandardScaler步骤是需要在sql文件中加工完成的。

4

模型训练

当特征工程完成后,可以使用MLflow Tracking API记录模型运行参数,

import argparse
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score
import xgboost as xgb
import mlflow
import mlflow.xgboostdef parse_args():parser = argparse.ArgumentParser(description='XGBoost example')parser.add_argument('--max-depth', type=int, default=3,help='Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit (default: 3)')parser.add_argument('--colsample-bytree', type=float, default=1.0,help='subsample ratio of columns when constructing each tree (default: 1.0)')parser.add_argument('--subsample', type=float, default=1.0,help='subsample ratio of the training instances (default: 1.0)')return parser.parse_args()def main():# parse command-line argumentsargs = parse_args()app_train_domain = pd.read_pickle("./app_train_domain.pkl")features_ids = app_train_domain['SK_ID_CURR']labels = app_train_domain['TARGET']features = app_train_domain.drop(columns = ['SK_ID_CURR', 'TARGET'])train_x, test_x, train_y, test_y = train_test_split(features, labels, random_state=11)dtrain = xgb.DMatrix(train_x, label = train_y)dtest = xgb.DMatrix(test_x, label = test_y)# enable auto loggingmlflow.xgboost.autolog()with mlflow.start_run():params={'booster':'gbtree','objective': 'binary:logistic','eval_metric': 'auc','max_depth':args.max_depth,'lambda':10,'subsample':args.subsample,'colsample_bytree':args.colsample_bytree,'min_child_weight':2,'eta': 0.025,'seed':0,'nthread':8,'silent':1}model=xgb.train(params,dtrain,num_boost_round=100,evals=[(dtrain, 'train'),(dtest, 'test')])y_proba = model.predict(dtest)auc = roc_auc_score(test_y, y_proba)# log metricsmlflow.log_metrics({'auc': auc})if __name__ == '__main__':main()

保存后运行代码,

python train.py --colsample-bytree 0.8 --subsample 0.9

或者以工程的方式运行,

mlflow run . -P colsample-bytree=0.8 -P subsample=0.85

输入

mlflow ui

即可在 http://127.0.0.1:5000 中看到模型运行结果,并进行多次结果的对比。

mlflow与众多机器学习框架都有相应的接口,例如通过:

mlflow.xgboost.save_model(model, "my_model")

将刚刚运行得到的模型进行保存,保存后得到的目录中包含:

  • conda.yaml 

  • MLmodel

  • model.xgb

最后利用model register功能将模型保存后,可以管理上线模型的版本:

from pprint import pprintclient = MlflowClient()
for rm in client.list_registered_models():pprint(dict(rm), indent=4)

这样就可以查看到所有Registered模型。

各位小伙伴如有兴趣交流,欢迎关注公众号留言勾搭

代码猴

在看点这里

这篇关于机器学习笔记:如何用MLflow管理模型生命周期的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/431084

相关文章

SpringBoot 多环境开发实战(从配置、管理与控制)

《SpringBoot多环境开发实战(从配置、管理与控制)》本文详解SpringBoot多环境配置,涵盖单文件YAML、多文件模式、MavenProfile分组及激活策略,通过优先级控制灵活切换环境... 目录一、多环境开发基础(单文件 YAML 版)(一)配置原理与优势(二)实操示例二、多环境开发多文件版

Redis实现高效内存管理的示例代码

《Redis实现高效内存管理的示例代码》Redis内存管理是其核心功能之一,为了高效地利用内存,Redis采用了多种技术和策略,如优化的数据结构、内存分配策略、内存回收、数据压缩等,下面就来详细的介绍... 目录1. 内存分配策略jemalloc 的使用2. 数据压缩和编码ziplist示例代码3. 优化的

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

SpringBoot集成XXL-JOB实现任务管理全流程

《SpringBoot集成XXL-JOB实现任务管理全流程》XXL-JOB是一款轻量级分布式任务调度平台,功能丰富、界面简洁、易于扩展,本文介绍如何通过SpringBoot项目,使用RestTempl... 目录一、前言二、项目结构简述三、Maven 依赖四、Controller 代码详解五、Service

深入解析C++ 中std::map内存管理

《深入解析C++中std::map内存管理》文章详解C++std::map内存管理,指出clear()仅删除元素可能不释放底层内存,建议用swap()与空map交换以彻底释放,针对指针类型需手动de... 目录1️、基本清空std::map2️、使用 swap 彻底释放内存3️、map 中存储指针类型的对象

Linux系统管理与进程任务管理方式

《Linux系统管理与进程任务管理方式》本文系统讲解Linux管理核心技能,涵盖引导流程、服务控制(Systemd与GRUB2)、进程管理(前台/后台运行、工具使用)、计划任务(at/cron)及常用... 目录引言一、linux系统引导过程与服务控制1.1 系统引导的五个关键阶段1.2 GRUB2的进化优

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Spring Security 前后端分离场景下的会话并发管理

《SpringSecurity前后端分离场景下的会话并发管理》本文介绍了在前后端分离架构下实现SpringSecurity会话并发管理的问题,传统Web开发中只需简单配置sessionManage... 目录背景分析传统 web 开发中的 sessionManagement 入口ConcurrentSess

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

Linux之UDP和TCP报头管理方式

《Linux之UDP和TCP报头管理方式》文章系统讲解了传输层协议UDP与TCP的核心区别:UDP无连接、不可靠,适合实时传输(如视频),通过端口号标识应用;TCP有连接、可靠,通过确认应答、序号、窗... 目录一、关于端口号1.1 端口号的理解1.2 端口号范围的划分1.3 认识知名端口号1.4 一个进程