pmf-automl源码分析

2024-04-20 23:38
文章标签 分析 源码 automl pmf

本文主要是介绍pmf-automl源码分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • arxiv论文(有附录,但是字小)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • NIPS2018论文(字大但是没有附录)
    Probabilistic Matrix Factorization for Automated Machine Learning
  • 代码
    https://github.com/rsheth80/pmf-automl

文章目录

  • 初窥项目文件
  • PMF模型训练
    • 数据切分
    • 初始隐变量
    • 模型的定义与训练
    • D个高斯过程的定义
    • 后验分布协方差矩阵的求解
      • transform_forward与transform_backward函数
      • get_cov函数的顶层设计
      • kernel的RBF
      • kernel的White
      • 求协方差矩阵复盘
    • GP前向函数的返回值的含义

初窥项目文件

用jupyter lab打开all_normalized_accuracy_with_pipelineID.csv
在这里插入图片描述

all_normalized_accuracy_with_pipelineID.zip contains the performance observations from running 42K pipelines on 553 OpenML datasets. The task was classification and the performance metric was balanced accuracy. Unzip prior to running code.

行表示pipeline id,列表示dataset id,元素表示balanced accuracy

在这里插入图片描述
简单查阅了一下pipelines.json,基本只有pcapolynomial两种preprocessor。

PMF模型训练

数据切分

Ytrain, Ytest, Ftrain, Ftest = get_data()
>>> Ytrain.shape
Out[2]: (42000, 464)
>>> Ytest.shape
Out[3]: (42000, 89)
>>> Ftrain.shape
Out[4]: (464, 46)
>>> Ftest.shape
Out[5]: (89, 46)

训练测试集切分,89个数据集作为测试集,464个训练集

初始隐变量

    imp = sklearn.impute.SimpleImputer(missing_values=np.nan, strategy='mean')X = sklearn.decomposition.PCA(Q).fit_transform(imp.fit(Ytrain).transform(Ytrain))
>>> X.shape
Out[7]: (42000, 20)

根据目前的理解,整个训练过程就是根据GP来训练X的隐变量。这个隐变量是用PCA初始化的。

处理训练集的缺失值,并降维为20维(42K个pipelines,数据集从553降为20个隐变量)

论文:the elements of Y Y Y are given by as nonlinear function of the latent variables, y n , d = f d ( x n ) + ϵ y_{n,d}=f_d(x_n)+\epsilon yn,d=fd(xn)+ϵ, where ϵ \epsilon ϵ is independent Gaussian noise.

这里的 Y Y Y指的是整个 42000 × 464 42000\times464 42000×464矩阵,那么 X X X就是pipeline空间的隐变量,这里隐变量维度 Q = 20 Q=20 Q=20 X X X的shape为 42000 × 20 42000\times20 42000×20

模型的定义与训练

模型的顶层定义:

    kernel = kernels.Add(kernels.RBF(Q, lengthscale=None), kernels.White(Q))m = gplvm.GPLVM(Q, X, Ytrain, kernel, N_max=N_max, D_max=batch_size)optimizer = torch.optim.SGD(m.parameters(), lr=lr)m = train(m, optimizer, f_callback=f_callback, f_stop=f_stop)

f_callbackf_stop都是两个local函数

    def f_callback(m, v, it, t):varn_list.append(transform_forward(m.variance).item())logpr_list.append(m().item()/m.D)if it == 1:t_list.append(t)else:t_list.append(t_list[-1] + t)if save_checkpoint and not (it % checkpoint_period):torch.save(m.state_dict(), fn_checkpoint + '_it%d.pt' % it)print('it=%d, f=%g, varn=%g, t: %g'% (it, logpr_list[-1], transform_forward(m.variance), t_list[-1]))
    def f_stop(m, v, it, t):if it >= maxiter-1:print('maxiter (%d) reached' % maxiter)return Truereturn False

看到训练函数train

def train(m, optimizer, f_callback=None, f_stop=None):it = 0while True:try:t = time.time()optimizer.zero_grad()nll = m()nll.backward()optimizer.step()it += 1t = time.time() - tif f_callback is not None:f_callback(m, nll, it, t)# f_stop should not be a substantial portion of total iteration timeif f_stop is not None and f_stop(m, nll, it, t):breakexcept KeyboardInterrupt:breakreturn m

论文公式(5):

N L L d = 1 2 ( N d l o g ( 2 π ) + l o g ∣ C d ∣ + Y c ( d )

这篇关于pmf-automl源码分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/921629

相关文章

MySQL 内存使用率常用分析语句

《MySQL内存使用率常用分析语句》用户整理了MySQL内存占用过高的分析方法,涵盖操作系统层确认及数据库层bufferpool、内存模块差值、线程状态、performance_schema性能数据... 目录一、 OS层二、 DB层1. 全局情况2. 内存占js用详情最近连续遇到mysql内存占用过高导致

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

Olingo分析和实践之EDM 辅助序列化器详解(最佳实践)

《Olingo分析和实践之EDM辅助序列化器详解(最佳实践)》EDM辅助序列化器是ApacheOlingoOData框架中无需完整EDM模型的智能序列化工具,通过运行时类型推断实现灵活数据转换,适用... 目录概念与定义什么是 EDM 辅助序列化器?核心概念设计目标核心特点1. EDM 信息可选2. 智能类

Olingo分析和实践之OData框架核心组件初始化(关键步骤)

《Olingo分析和实践之OData框架核心组件初始化(关键步骤)》ODataSpringBootService通过初始化OData实例和服务元数据,构建框架核心能力与数据模型结构,实现序列化、URI... 目录概述第一步:OData实例创建1.1 OData.newInstance() 详细分析1.1.1

Olingo分析和实践之ODataImpl详细分析(重要方法详解)

《Olingo分析和实践之ODataImpl详细分析(重要方法详解)》ODataImpl.java是ApacheOlingoOData框架的核心工厂类,负责创建序列化器、反序列化器和处理器等组件,... 目录概述主要职责类结构与继承关系核心功能分析1. 序列化器管理2. 反序列化器管理3. 处理器管理重要方

SpringBoot中六种批量更新Mysql的方式效率对比分析

《SpringBoot中六种批量更新Mysql的方式效率对比分析》文章比较了MySQL大数据量批量更新的多种方法,指出REPLACEINTO和ONDUPLICATEKEY效率最高但存在数据风险,MyB... 目录效率比较测试结构数据库初始化测试数据批量修改方案第一种 for第二种 case when第三种

解决1093 - You can‘t specify target table报错问题及原因分析

《解决1093-Youcan‘tspecifytargettable报错问题及原因分析》MySQL1093错误因UPDATE/DELETE语句的FROM子句直接引用目标表或嵌套子查询导致,... 目录报js错原因分析具体原因解决办法方法一:使用临时表方法二:使用JOIN方法三:使用EXISTS示例总结报错原

MySQL中的LENGTH()函数用法详解与实例分析

《MySQL中的LENGTH()函数用法详解与实例分析》MySQLLENGTH()函数用于计算字符串的字节长度,区别于CHAR_LENGTH()的字符长度,适用于多字节字符集(如UTF-8)的数据验证... 目录1. LENGTH()函数的基本语法2. LENGTH()函数的返回值2.1 示例1:计算字符串

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.