pmf-automl源码分析

2024-04-20 23:38
文章标签 分析 源码 automl pmf

本文主要是介绍pmf-automl源码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • arxiv论文(有附录,但是字小)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • NIPS2018论文(字大但是没有附录)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • 代码
    https://github.com/rsheth80/pmf-automl

文章目录

  • 初窥项目文件
  • PMF模型训练
    • 数据切分
    • 初始隐变量
    • 模型的定义与训练
    • D个高斯过程的定义
    • 后验分布协方差矩阵的求解
      • transform_forward与transform_backward函数
      • get_cov函数的顶层设计
      • kernel的RBF
      • kernel的White
      • 求协方差矩阵复盘
    • GP前向函数的返回值的含义

初窥项目文件

用jupyter lab打开all_normalized_accuracy_with_pipelineID.csv
在这里插入图片描述

all_normalized_accuracy_with_pipelineID.zip contains the performance observations from running 42K pipelines on 553 OpenML datasets. The task was classification and the performance metric was balanced accuracy. Unzip prior to running code.

行表示pipeline id,列表示dataset id,元素表示balanced accuracy

在这里插入图片描述
简单查阅了一下pipelines.json,基本只有pcapolynomial两种preprocessor。

PMF模型训练

数据切分

Ytrain, Ytest, Ftrain, Ftest = get_data()
>>> Ytrain.shape
Out[2]: (42000, 464)
>>> Ytest.shape
Out[3]: (42000, 89)
>>> Ftrain.shape
Out[4]: (464, 46)
>>> Ftest.shape
Out[5]: (89, 46)

训练测试集切分,89个数据集作为测试集,464个训练集

初始隐变量

    imp = sklearn.impute.SimpleImputer(missing_values=np.nan, strategy='mean')X = sklearn.decomposition.PCA(Q).fit_transform(imp.fit(Ytrain).transform(Ytrain))
>>> X.shape
Out[7]: (42000, 20)

根据目前的理解,整个训练过程就是根据GP来训练X的隐变量。这个隐变量是用PCA初始化的。

处理训练集的缺失值,并降维为20维(42K个pipelines,数据集从553降为20个隐变量)

论文:the elements of Y Y Y are given by as nonlinear function of the latent variables, y n , d = f d ( x n ) + ϵ y_{n,d}=f_d(x_n)+\epsilon yn,d=fd(xn)+ϵ, where ϵ \epsilon ϵ is independent Gaussian noise.

这里的 Y Y Y指的是整个 42000 × 464 42000\times464 42000×464矩阵,那么 X X X就是pipeline空间的隐变量,这里隐变量维度 Q = 20 Q=20 Q=20 X X X的shape为 42000 × 20 42000\times20 42000×20

模型的定义与训练

模型的顶层定义:

    kernel = kernels.Add(kernels.RBF(Q, lengthscale=None), kernels.White(Q))m = gplvm.GPLVM(Q, X, Ytrain, kernel, N_max=N_max, D_max=batch_size)optimizer = torch.optim.SGD(m.parameters(), lr=lr)m = train(m, optimizer, f_callback=f_callback, f_stop=f_stop)

f_callbackf_stop都是两个local函数

    def f_callback(m, v, it, t):varn_list.append(transform_forward(m.variance).item())logpr_list.append(m().item()/m.D)if it == 1:t_list.append(t)else:t_list.append(t_list[-1] + t)if save_checkpoint and not (it % checkpoint_period):torch.save(m.state_dict(), fn_checkpoint + '_it%d.pt' % it)print('it=%d, f=%g, varn=%g, t: %g'% (it, logpr_list[-1], transform_forward(m.variance), t_list[-1]))
    def f_stop(m, v, it, t):if it >= maxiter-1:print('maxiter (%d) reached' % maxiter)return Truereturn False

看到训练函数train

def train(m, optimizer, f_callback=None, f_stop=None):it = 0while True:try:t = time.time()optimizer.zero_grad()nll = m()nll.backward()optimizer.step()it += 1t = time.time() - tif f_callback is not None:f_callback(m, nll, it, t)# f_stop should not be a substantial portion of total iteration timeif f_stop is not None and f_stop(m, nll, it, t):breakexcept KeyboardInterrupt:breakreturn m

论文公式(5):

N L L d = 1 2 ( N d l o g ( 2 π ) + l o g ∣ C d ∣ + Y c ( d )

这篇关于pmf-automl源码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921629

相关文章

Android 缓存日志Logcat导出与分析最佳实践

《Android缓存日志Logcat导出与分析最佳实践》本文全面介绍AndroidLogcat缓存日志的导出与分析方法,涵盖按进程、缓冲区类型及日志级别过滤,自动化工具使用,常见问题解决方案和最佳实... 目录android 缓存日志(Logcat)导出与分析全攻略为什么要导出缓存日志?按需过滤导出1. 按

Linux中的HTTPS协议原理分析

《Linux中的HTTPS协议原理分析》文章解释了HTTPS的必要性:HTTP明文传输易被篡改和劫持,HTTPS通过非对称加密协商对称密钥、CA证书认证和混合加密机制,有效防范中间人攻击,保障通信安全... 目录一、什么是加密和解密?二、为什么需要加密?三、常见的加密方式3.1 对称加密3.2非对称加密四、

MySQL中读写分离方案对比分析与选型建议

《MySQL中读写分离方案对比分析与选型建议》MySQL读写分离是提升数据库可用性和性能的常见手段,本文将围绕现实生产环境中常见的几种读写分离模式进行系统对比,希望对大家有所帮助... 目录一、问题背景介绍二、多种解决方案对比2.1 原生mysql主从复制2.2 Proxy层中间件:ProxySQL2.3

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

python panda库从基础到高级操作分析

《pythonpanda库从基础到高级操作分析》本文介绍了Pandas库的核心功能,包括处理结构化数据的Series和DataFrame数据结构,数据读取、清洗、分组聚合、合并、时间序列分析及大数据... 目录1. Pandas 概述2. 基本操作:数据读取与查看3. 索引操作:精准定位数据4. Group

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

MySQL 内存使用率常用分析语句

《MySQL内存使用率常用分析语句》用户整理了MySQL内存占用过高的分析方法,涵盖操作系统层确认及数据库层bufferpool、内存模块差值、线程状态、performance_schema性能数据... 目录一、 OS层二、 DB层1. 全局情况2. 内存占js用详情最近连续遇到mysql内存占用过高导致

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

Olingo分析和实践之EDM 辅助序列化器详解(最佳实践)

《Olingo分析和实践之EDM辅助序列化器详解(最佳实践)》EDM辅助序列化器是ApacheOlingoOData框架中无需完整EDM模型的智能序列化工具,通过运行时类型推断实现灵活数据转换,适用... 目录概念与定义什么是 EDM 辅助序列化器?核心概念设计目标核心特点1. EDM 信息可选2. 智能类

Olingo分析和实践之OData框架核心组件初始化(关键步骤)

《Olingo分析和实践之OData框架核心组件初始化(关键步骤)》ODataSpringBootService通过初始化OData实例和服务元数据,构建框架核心能力与数据模型结构,实现序列化、URI... 目录概述第一步:OData实例创建1.1 OData.newInstance() 详细分析1.1.1