稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】

2024-05-06 17:52

本文主要是介绍稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 简介
      • 1. 定义和目标
      • 2. 协方差函数与引入点
      • 3. 变分分布
      • 4. 近似后验参数的计算
      • 5. 计算具体步骤
      • 6. 优势与应用(时间复杂度)
  • 应用案例
      • 1. 初始化参数
      • 2. 计算核矩阵
      • 3. 计算优化变分参数
      • 4. 预测新数据点
      • 5. 结果展示
    • 符号和参数说明
  • python代码

简介

稀疏变分高斯过程(Sparse Variational Gaussian Processes, SVGP)是一种高效的高斯过程(GP)近似方法,它使用一组称为引入点的固定数据点来近似整个数据集。这种方法大大减少了高斯过程模型的计算复杂度,使其能够适用于大数据集。下面是SVGP的详细数学过程。

1. 定义和目标

在标准高斯过程中,给定数据集 { ( x i , y i ) } i = 1 N \{(\mathbf{x}_i, y_i)\}_{i=1}^N {(xi,yi)}i=1N,目标是学习一个映射 f f f ,其中 f ∼ G P ( m , k ) f \sim \mathcal{GP}(m, k) fGP(m,k) m m m 是均值函数, k k k 是协方差函数。SVGP的目标是使用一组较小的引入点 Z = { z i } i = 1 M \mathbf{Z} = \{\mathbf{z}_i\}_{i=1}^M Z={zi}i=1M (其中 M ≪ N M \ll N MN)来近似这个映射。

2. 协方差函数与引入点

引入点 Z \mathbf{Z} Z 被用于构建一个近似的协方差矩阵 K M M \mathbf{K}_{MM} KMM,其中包含引入点之间的协方差。实际的观测点 X \mathbf{X} X 与引入点之间的协方差表示为 K N M \mathbf{K}_{NM} KNM

3. 变分分布

在SVGP中,我们设定变分分布 q ( f ) q(f) q(f) 来近似真实的后验分布。变分分布假设形式为:
q ( f ) = ∫ p ( f ∣ u ) q ( u ) d u q(\mathbf{f}) = \int p(\mathbf{f} | \mathbf{u}) q(\mathbf{u}) \, d\mathbf{u} q(f)=p(fu)q(u)du
其中 u \mathbf{u} u 是在引入点上的函数值, q ( u ) = N ( u ∣ m , S ) q(\mathbf{u}) = \mathcal{N}(\mathbf{u} | \mathbf{m}, \mathbf{S}) q(u)=N(um,S) 是定义在引入点上的高斯分布,具有均值 m \mathbf{m} m 和协方差矩阵 S \mathbf{S} S

4. 近似后验参数的计算

变分参数 m \mathbf{m} m S \mathbf{S} S 通过最小化KL散度(Kullback-Leibler divergence)来学习,这是变分推断中的常用方法。这要求我们计算如下的期望对数似然和KL散度:

ELBO = E q ( f ) [ log ⁡ p ( y ∣ f ) ] − KL ( q ( u ) ∥ p ( u ) ) \text{ELBO} = \mathbb{E}_{q(\mathbf{f})}[\log p(\mathbf{y}|\mathbf{f})] - \text{KL}(q(\mathbf{u}) \| p(\mathbf{u})) ELBO=Eq(f)[logp(yf)]KL(q(u)p(u))

其中,第一项是在变分分布下数据的对数似然的期望,第二项是变分分布和先验分布之间的KL散度。

5. 计算具体步骤

  • 计算协方差矩阵 K M M \mathbf{K}_{MM} KMM, K N M \mathbf{K}_{NM} KNM K N N \mathbf{K}_{NN} KNN
  • 变分分布更新:通过优化ELBO来学习变分参数 m \mathbf{m} m S \mathbf{S} S
  • 后验均值和协方差的更新:在测试点 X ∗ \mathbf{X}_* X 上的后验均值和方差可以通过变分参数和核矩阵计算得到。

6. 优势与应用(时间复杂度)

SVGP减少了与 N N N 成二次方的计算复杂度,变为与 M M M 成二次方的计算复杂度,其中 M M M 通常远小于 N N N。这使得SVGP可以应用于大规模数据集的概率建模和推断。

应用案例

让我们通过一个具体的数值例子来解释稀疏变分高斯过程(SVGP)的操作和计算。假设我们有一个简单的一维回归任务,数据集由以下观测点构成:

  • 训练数据点 X X X X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]
  • 对应的目标值 y y y y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]

我们的目标是使用SVGP来拟合这些数据。我们设定使用两个引入点( M = 2 M = 2 M=2),位置分布在输入空间内:

  • 引入点 Z Z Z Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]

假设使用的核函数是平方指数核,其长度尺度设定为 1.0,输出方差也设定为 1.0。

1. 初始化参数

引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数随机初始化:
m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]

2. 计算核矩阵

设定噪声水平 σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1

  • K M M \mathbf{K}_{MM} KMM:核矩阵在引入点之间:
    K M M = [ 1.0 e − 4.5 / 2 e − 4.5 / 2 1.0 ] \mathbf{K}_{MM} = \begin{bmatrix} 1.0 & e^{-4.5/2} \\ e^{-4.5/2} & 1.0 \end{bmatrix} KMM=[1.0e4.5/2e4.5/21.0]
  • K N M \mathbf{K}_{NM} KNM:核矩阵在观测点和引入点之间:
    K N M = [ e − 0.25 / 2 e − 12.25 / 2 e − 0.25 / 2 e − 6.25 / 2 e − 1.0 / 2 e − 2.25 / 2 e − 4.5 / 2 e − 0.25 / 2 e − 9.0 / 2 e − 0.25 / 2 ] \mathbf{K}_{NM} = \begin{bmatrix} e^{-0.25/2} & e^{-12.25/2} \\ e^{-0.25/2} & e^{-6.25/2} \\ e^{-1.0/2} & e^{-2.25/2} \\ e^{-4.5/2} & e^{-0.25/2} \\ e^{-9.0/2} & e^{-0.25/2} \end{bmatrix} KNM= e0.25/2e0.25/2e1.0/2e4.5/2e9.0/2e12.25/2e6.25/2e2.25/2e0.25/2e0.25/2

3. 计算优化变分参数

在这一步,我们利用变分推断来优化引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数。假设我们使用期望传播(EP)或者自然梯度下降来优化。

(1). 计算引入点的后验分布参数:

  • 使用的核矩阵 K M M \mathbf{K}_{MM} KMM K N M \mathbf{K}_{NM} KNM 已经给出。
  • 计算精度矩阵(逆协方差矩阵) Λ \Lambda Λ
    Λ = K M M − 1 + K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) K N M \Lambda = \mathbf{K}_{MM}^{-1} + \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{K}_{NM} Λ=KMM1+KNMdiag(σ2+Var[fn]1)KNM
  • 其中, Var [ f n ] \text{Var}[\mathbf{f}_n] Var[fn] 是每个数据点的方差,可以假设初始为零。
  • 更新均值 m \mathbf{m} m
    m = Λ − 1 K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) y \mathbf{m} = \Lambda^{-1} \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{y} m=Λ1KNMdiag(σ2+Var[fn]1)y

(2). 优化变分参数:

  • 根据变分推断框架,我们最小化KL散度。这通常涉及到迭代更新 m \mathbf{m} m S \mathbf{S} S 直到收敛:
    S = Λ − 1 \mathbf{S} = \Lambda^{-1} S=Λ1

4. 预测新数据点

给定新的输入位置 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0],我们使用更新后的变分参数进行预测:

(1). 计算新数据点与引入点之间的核矩阵 K ∗ M \mathbf{K}_{*M} KM
K ∗ M = [ e − 1.0 / 2 e − 9.0 / 2 e − 4.0 / 2 e − 1.0 / 2 ] \mathbf{K}_{*M} = \begin{bmatrix} e^{-1.0/2} & e^{-9.0/2} \\ e^{-4.0/2} & e^{-1.0/2} \end{bmatrix} KM=[e1.0/2e4.0/2e9.0/2e1.0/2]

(2). 计算新数据点自身的核矩阵 K ∗ ∗ \mathbf{K}_{**} K∗∗
K ∗ ∗ = [ 1.0 e − 1.0 / 2 e − 1.0 / 2 1.0 ] \mathbf{K}_{**} = \begin{bmatrix} 1.0 & e^{-1.0/2} \\ e^{-1.0/2} & 1.0 \end{bmatrix} K∗∗=[1.0e1.0/2e1.0/21.0]

(3). 使用变分后验公式计算预测均值和方差:

  • 均值:
    μ ∗ = K ∗ M K M M − 1 m \mu_* = \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} \mathbf{m} μ=KMKMM1m
  • 方差:
    Σ ∗ = K ∗ ∗ − K ∗ M K M M − 1 ( K M M − S ) K M M − 1 K M ∗ \Sigma_* = \mathbf{K}_{**} - \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} (\mathbf{K}_{MM} - \mathbf{S}) \mathbf{K}_{MM}^{-1} \mathbf{K}_{M*} Σ=K∗∗KMKMM1(KMMS)KMM1KM
  • 这里 K M ∗ \mathbf{K}_{M*} KM K ∗ M \mathbf{K}_{*M} KM 的转置。

5. 结果展示

这种方法给出了在新数据点 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0] 处的预测分布,包括均值和方差,这些预测可以用于后续的分析或决策制定。

假设优化后,参数更新为:
m = [ 1.5 , 2.0 ] \mathbf{m} = [1.5, 2.0] m=[1.5,2.0]
S = [ 0.5 0.1 0.1 0.5 ] \mathbf{S} = \begin{bmatrix} 0.5 & 0.1 \\ 0.1 & 0.5 \end{bmatrix} S=[0.50.10.10.5]

  • 注意:在实际应用中,这些计算会通过自动化软件工具包如GPflow或GPyTorch来完成。

让我们详细解释上述例子中使用的每个符号和参数:

符号和参数说明

  1. X X X(训练数据点):

    • X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]:一维输入空间中的训练数据点。
  2. y y y(目标值):

    • y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]:对应于输入点 X X X 的目标输出。
  3. Z Z Z(引入点):

    • Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]:在输入空间中人为设置的参考点,用于近似全空间的高斯过程。
  4. M M M(引入点数量):

    • M = 2 M = 2 M=2:引入点的总数。
  5. 核函数:

    • 平方指数核(Squared Exponential Kernel),核函数的选择直接影响模型的平滑性和灵活性。
  6. 长度尺度(Length Scale):

    • 控制核函数的宽度,这里假设为 1.0。
  7. 输出方差(Output Variance):

    • 核函数的高度,这里假设为 1.0。
  8. σ 2 \sigma^2 σ2(观测噪声方差):

    • σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1:数据的噪声水平,影响模型对数据的敏感程度。
  9. m \mathbf{m} m(引入点的均值参数):

    • 初始设置为 m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
  10. S \mathbf{S} S(引入点的协方差参数):

  • 初始设置为 S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]
  1. K M M \mathbf{K}_{MM} KMM(引入点间的核矩阵):
  • 计算所有引入点之间的核值。
  1. K N M \mathbf{K}_{NM} KNM(训练点与引入点间的核矩阵):
  • 计算训练点与引入点之间的核值。
  1. K ∗ M \mathbf{K}_{*M} KM(新数据点与引入点间的核矩阵):
  • 计算新数据点与引入点之间的核值。
  1. K ∗ ∗ \mathbf{K}_{**} K∗∗(新数据点自身的核矩阵):
  • 计算新数据点之间的核值。
  1. 变分参数(Variational Parameters):
  • m \mathbf{m} m S \mathbf{S} S 是在变分推断中优化的参数,用于近似后验分布。

python代码

本代码包含参数更新和预测:

import torch
import torch.nn as nn
import numpy as np
import mathclass SVGP(nn.Module):def __init__(self, inducing_points, kernel_scale=1.0, jitter=1e-6):super(SVGP, self).__init__()self.inducing_points = nn.Parameter(torch.tensor(inducing_points, dtype=torch.float32))self.kernel_scale = kernel_scaleself.jitter = jitterself.variational_mean = nn.Parameter(torch.zeros(self.inducing_points.shape[0]))self.variational_cov = nn.Parameter(torch.eye(self.inducing_points.shape[0]))def rbf_kernel(self, X, Y):dist = torch.cdist(X, Y)**2return torch.exp(-0.5 / self.kernel_scale * dist)def forward(self, X, y=None):# Calculate kernel matricesK_mm = self.rbf_kernel(self.inducing_points, self.inducing_points) + self.jitter * torch.eye(self.inducing_points.shape[0])K_nm = self.rbf_kernel(X, self.inducing_points)K_mn = K_nm.TK_nn = self.rbf_kernel(X, X)# Compute the inverse of K_mmK_mm_inv = torch.inverse(K_mm)# If training mode, optimize the variational parametersif y is not None:noise = 0.1  # Fixed noise for simplicityA = torch.mm(torch.mm(K_nm, K_mm_inv), K_mn) + torch.eye(X.shape[0]) * noiseB = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()# Compute the variational lower bound and gradients for optimization# Placeholder for actual ELBO calculationloss = torch.mean((y - B)**2) + torch.trace(A)return loss# If not training mode, do the predictionelse:# Predictive mean and variancemu_star = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()v_star = K_nn - torch.mm(torch.mm(K_nm, K_mm_inv), K_mn)return mu_star, v_stardef train_model(self, X_train, y_train, learning_rate=0.01, epochs=100):optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)self.train()for epoch in range(epochs):optimizer.zero_grad()loss = self.forward(X_train, y_train)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')def predict(self, X_test):self.eval()mu_star, v_star = self.forward(X_test)return mu_star, v_star.diag()# Example usage
X_train = torch.tensor([[0.5], [1.5], [2.5], [3.5], [4.5]])
y_train = torch.tensor([1.0, 2.0, 3.0, 2.5, 1.5])
inducing_points = torch.tensor([[1.0], [4.0]])model = SVGP(inducing_points=inducing_points)
model.train_model(X_train, y_train)
mu_star, v_star = model.predict(torch.tensor([[2.0], [3.0]]))
print(f'Mean predictions: {mu_star}, Variances: {v_star}')

这篇关于稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/964965

相关文章

oracle 11g导入\导出(expdp impdp)之导入过程

《oracle11g导入导出(expdpimpdp)之导入过程》导出需使用SEC.DMP格式,无分号;建立expdir目录(E:/exp)并确保存在;导入在cmd下执行,需sys用户权限;若需修... 目录准备文件导入(impdp)1、建立directory2、导入语句 3、更改密码总结上一个环节,我们讲了

PHP应用中处理限流和API节流的最佳实践

《PHP应用中处理限流和API节流的最佳实践》限流和API节流对于确保Web应用程序的可靠性、安全性和可扩展性至关重要,本文将详细介绍PHP应用中处理限流和API节流的最佳实践,下面就来和小编一起学习... 目录限流的重要性在 php 中实施限流的最佳实践使用集中式存储进行状态管理(如 Redis)采用滑动

ShardingProxy读写分离之原理、配置与实践过程

《ShardingProxy读写分离之原理、配置与实践过程》ShardingProxy是ApacheShardingSphere的数据库中间件,通过三层架构实现读写分离,解决高并发场景下数据库性能瓶... 目录一、ShardingProxy技术定位与读写分离核心价值1.1 技术定位1.2 读写分离核心价值二

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

深入浅出Spring中的@Autowired自动注入的工作原理及实践应用

《深入浅出Spring中的@Autowired自动注入的工作原理及实践应用》在Spring框架的学习旅程中,@Autowired无疑是一个高频出现却又让初学者头疼的注解,它看似简单,却蕴含着Sprin... 目录深入浅出Spring中的@Autowired:自动注入的奥秘什么是依赖注入?@Autowired

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Redis实现高效内存管理的示例代码

《Redis实现高效内存管理的示例代码》Redis内存管理是其核心功能之一,为了高效地利用内存,Redis采用了多种技术和策略,如优化的数据结构、内存分配策略、内存回收、数据压缩等,下面就来详细的介绍... 目录1. 内存分配策略jemalloc 的使用2. 数据压缩和编码ziplist示例代码3. 优化的

redis-sentinel基础概念及部署流程

《redis-sentinel基础概念及部署流程》RedisSentinel是Redis的高可用解决方案,通过监控主从节点、自动故障转移、通知机制及配置提供,实现集群故障恢复与服务持续可用,核心组件包... 目录一. 引言二. 核心功能三. 核心组件四. 故障转移流程五. 服务部署六. sentinel部署