稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】

2024-05-06 17:52

本文主要是介绍稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 简介
      • 1. 定义和目标
      • 2. 协方差函数与引入点
      • 3. 变分分布
      • 4. 近似后验参数的计算
      • 5. 计算具体步骤
      • 6. 优势与应用(时间复杂度)
  • 应用案例
      • 1. 初始化参数
      • 2. 计算核矩阵
      • 3. 计算优化变分参数
      • 4. 预测新数据点
      • 5. 结果展示
    • 符号和参数说明
  • python代码

简介

稀疏变分高斯过程(Sparse Variational Gaussian Processes, SVGP)是一种高效的高斯过程(GP)近似方法,它使用一组称为引入点的固定数据点来近似整个数据集。这种方法大大减少了高斯过程模型的计算复杂度,使其能够适用于大数据集。下面是SVGP的详细数学过程。

1. 定义和目标

在标准高斯过程中,给定数据集 { ( x i , y i ) } i = 1 N \{(\mathbf{x}_i, y_i)\}_{i=1}^N {(xi,yi)}i=1N,目标是学习一个映射 f f f ,其中 f ∼ G P ( m , k ) f \sim \mathcal{GP}(m, k) fGP(m,k) m m m 是均值函数, k k k 是协方差函数。SVGP的目标是使用一组较小的引入点 Z = { z i } i = 1 M \mathbf{Z} = \{\mathbf{z}_i\}_{i=1}^M Z={zi}i=1M (其中 M ≪ N M \ll N MN)来近似这个映射。

2. 协方差函数与引入点

引入点 Z \mathbf{Z} Z 被用于构建一个近似的协方差矩阵 K M M \mathbf{K}_{MM} KMM,其中包含引入点之间的协方差。实际的观测点 X \mathbf{X} X 与引入点之间的协方差表示为 K N M \mathbf{K}_{NM} KNM

3. 变分分布

在SVGP中,我们设定变分分布 q ( f ) q(f) q(f) 来近似真实的后验分布。变分分布假设形式为:
q ( f ) = ∫ p ( f ∣ u ) q ( u ) d u q(\mathbf{f}) = \int p(\mathbf{f} | \mathbf{u}) q(\mathbf{u}) \, d\mathbf{u} q(f)=p(fu)q(u)du
其中 u \mathbf{u} u 是在引入点上的函数值, q ( u ) = N ( u ∣ m , S ) q(\mathbf{u}) = \mathcal{N}(\mathbf{u} | \mathbf{m}, \mathbf{S}) q(u)=N(um,S) 是定义在引入点上的高斯分布,具有均值 m \mathbf{m} m 和协方差矩阵 S \mathbf{S} S

4. 近似后验参数的计算

变分参数 m \mathbf{m} m S \mathbf{S} S 通过最小化KL散度(Kullback-Leibler divergence)来学习,这是变分推断中的常用方法。这要求我们计算如下的期望对数似然和KL散度:

ELBO = E q ( f ) [ log ⁡ p ( y ∣ f ) ] − KL ( q ( u ) ∥ p ( u ) ) \text{ELBO} = \mathbb{E}_{q(\mathbf{f})}[\log p(\mathbf{y}|\mathbf{f})] - \text{KL}(q(\mathbf{u}) \| p(\mathbf{u})) ELBO=Eq(f)[logp(yf)]KL(q(u)p(u))

其中,第一项是在变分分布下数据的对数似然的期望,第二项是变分分布和先验分布之间的KL散度。

5. 计算具体步骤

  • 计算协方差矩阵 K M M \mathbf{K}_{MM} KMM, K N M \mathbf{K}_{NM} KNM K N N \mathbf{K}_{NN} KNN
  • 变分分布更新:通过优化ELBO来学习变分参数 m \mathbf{m} m S \mathbf{S} S
  • 后验均值和协方差的更新:在测试点 X ∗ \mathbf{X}_* X 上的后验均值和方差可以通过变分参数和核矩阵计算得到。

6. 优势与应用(时间复杂度)

SVGP减少了与 N N N 成二次方的计算复杂度,变为与 M M M 成二次方的计算复杂度,其中 M M M 通常远小于 N N N。这使得SVGP可以应用于大规模数据集的概率建模和推断。

应用案例

让我们通过一个具体的数值例子来解释稀疏变分高斯过程(SVGP)的操作和计算。假设我们有一个简单的一维回归任务,数据集由以下观测点构成:

  • 训练数据点 X X X X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]
  • 对应的目标值 y y y y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]

我们的目标是使用SVGP来拟合这些数据。我们设定使用两个引入点( M = 2 M = 2 M=2),位置分布在输入空间内:

  • 引入点 Z Z Z Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]

假设使用的核函数是平方指数核,其长度尺度设定为 1.0,输出方差也设定为 1.0。

1. 初始化参数

引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数随机初始化:
m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]

2. 计算核矩阵

设定噪声水平 σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1

  • K M M \mathbf{K}_{MM} KMM:核矩阵在引入点之间:
    K M M = [ 1.0 e − 4.5 / 2 e − 4.5 / 2 1.0 ] \mathbf{K}_{MM} = \begin{bmatrix} 1.0 & e^{-4.5/2} \\ e^{-4.5/2} & 1.0 \end{bmatrix} KMM=[1.0e4.5/2e4.5/21.0]
  • K N M \mathbf{K}_{NM} KNM:核矩阵在观测点和引入点之间:
    K N M = [ e − 0.25 / 2 e − 12.25 / 2 e − 0.25 / 2 e − 6.25 / 2 e − 1.0 / 2 e − 2.25 / 2 e − 4.5 / 2 e − 0.25 / 2 e − 9.0 / 2 e − 0.25 / 2 ] \mathbf{K}_{NM} = \begin{bmatrix} e^{-0.25/2} & e^{-12.25/2} \\ e^{-0.25/2} & e^{-6.25/2} \\ e^{-1.0/2} & e^{-2.25/2} \\ e^{-4.5/2} & e^{-0.25/2} \\ e^{-9.0/2} & e^{-0.25/2} \end{bmatrix} KNM= e0.25/2e0.25/2e1.0/2e4.5/2e9.0/2e12.25/2e6.25/2e2.25/2e0.25/2e0.25/2

3. 计算优化变分参数

在这一步,我们利用变分推断来优化引入点的均值 m \mathbf{m} m 和协方差 S \mathbf{S} S 参数。假设我们使用期望传播(EP)或者自然梯度下降来优化。

(1). 计算引入点的后验分布参数:

  • 使用的核矩阵 K M M \mathbf{K}_{MM} KMM K N M \mathbf{K}_{NM} KNM 已经给出。
  • 计算精度矩阵(逆协方差矩阵) Λ \Lambda Λ
    Λ = K M M − 1 + K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) K N M \Lambda = \mathbf{K}_{MM}^{-1} + \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{K}_{NM} Λ=KMM1+KNMdiag(σ2+Var[fn]1)KNM
  • 其中, Var [ f n ] \text{Var}[\mathbf{f}_n] Var[fn] 是每个数据点的方差,可以假设初始为零。
  • 更新均值 m \mathbf{m} m
    m = Λ − 1 K N M ⊤ diag ( 1 σ 2 + Var [ f n ] ) y \mathbf{m} = \Lambda^{-1} \mathbf{K}_{NM}^\top \text{diag}(\frac{1}{\sigma^2 + \text{Var}[\mathbf{f}_n]}) \mathbf{y} m=Λ1KNMdiag(σ2+Var[fn]1)y

(2). 优化变分参数:

  • 根据变分推断框架,我们最小化KL散度。这通常涉及到迭代更新 m \mathbf{m} m S \mathbf{S} S 直到收敛:
    S = Λ − 1 \mathbf{S} = \Lambda^{-1} S=Λ1

4. 预测新数据点

给定新的输入位置 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0],我们使用更新后的变分参数进行预测:

(1). 计算新数据点与引入点之间的核矩阵 K ∗ M \mathbf{K}_{*M} KM
K ∗ M = [ e − 1.0 / 2 e − 9.0 / 2 e − 4.0 / 2 e − 1.0 / 2 ] \mathbf{K}_{*M} = \begin{bmatrix} e^{-1.0/2} & e^{-9.0/2} \\ e^{-4.0/2} & e^{-1.0/2} \end{bmatrix} KM=[e1.0/2e4.0/2e9.0/2e1.0/2]

(2). 计算新数据点自身的核矩阵 K ∗ ∗ \mathbf{K}_{**} K∗∗
K ∗ ∗ = [ 1.0 e − 1.0 / 2 e − 1.0 / 2 1.0 ] \mathbf{K}_{**} = \begin{bmatrix} 1.0 & e^{-1.0/2} \\ e^{-1.0/2} & 1.0 \end{bmatrix} K∗∗=[1.0e1.0/2e1.0/21.0]

(3). 使用变分后验公式计算预测均值和方差:

  • 均值:
    μ ∗ = K ∗ M K M M − 1 m \mu_* = \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} \mathbf{m} μ=KMKMM1m
  • 方差:
    Σ ∗ = K ∗ ∗ − K ∗ M K M M − 1 ( K M M − S ) K M M − 1 K M ∗ \Sigma_* = \mathbf{K}_{**} - \mathbf{K}_{*M} \mathbf{K}_{MM}^{-1} (\mathbf{K}_{MM} - \mathbf{S}) \mathbf{K}_{MM}^{-1} \mathbf{K}_{M*} Σ=K∗∗KMKMM1(KMMS)KMM1KM
  • 这里 K M ∗ \mathbf{K}_{M*} KM K ∗ M \mathbf{K}_{*M} KM 的转置。

5. 结果展示

这种方法给出了在新数据点 x ∗ = [ 2.0 , 3.0 ] x_* = [2.0, 3.0] x=[2.0,3.0] 处的预测分布,包括均值和方差,这些预测可以用于后续的分析或决策制定。

假设优化后,参数更新为:
m = [ 1.5 , 2.0 ] \mathbf{m} = [1.5, 2.0] m=[1.5,2.0]
S = [ 0.5 0.1 0.1 0.5 ] \mathbf{S} = \begin{bmatrix} 0.5 & 0.1 \\ 0.1 & 0.5 \end{bmatrix} S=[0.50.10.10.5]

  • 注意:在实际应用中,这些计算会通过自动化软件工具包如GPflow或GPyTorch来完成。

让我们详细解释上述例子中使用的每个符号和参数:

符号和参数说明

  1. X X X(训练数据点):

    • X = [ 0.5 , 1.5 , 2.5 , 3.5 , 4.5 ] \mathbf{X} = [0.5, 1.5, 2.5, 3.5, 4.5] X=[0.5,1.5,2.5,3.5,4.5]:一维输入空间中的训练数据点。
  2. y y y(目标值):

    • y = [ 1.0 , 2.0 , 3.0 , 2.5 , 1.5 ] \mathbf{y} = [1.0, 2.0, 3.0, 2.5, 1.5] y=[1.0,2.0,3.0,2.5,1.5]:对应于输入点 X X X 的目标输出。
  3. Z Z Z(引入点):

    • Z = [ 1.0 , 4.0 ] \mathbf{Z} = [1.0, 4.0] Z=[1.0,4.0]:在输入空间中人为设置的参考点,用于近似全空间的高斯过程。
  4. M M M(引入点数量):

    • M = 2 M = 2 M=2:引入点的总数。
  5. 核函数:

    • 平方指数核(Squared Exponential Kernel),核函数的选择直接影响模型的平滑性和灵活性。
  6. 长度尺度(Length Scale):

    • 控制核函数的宽度,这里假设为 1.0。
  7. 输出方差(Output Variance):

    • 核函数的高度,这里假设为 1.0。
  8. σ 2 \sigma^2 σ2(观测噪声方差):

    • σ 2 = 0.1 \sigma^2 = 0.1 σ2=0.1:数据的噪声水平,影响模型对数据的敏感程度。
  9. m \mathbf{m} m(引入点的均值参数):

    • 初始设置为 m = [ 0.0 , 0.0 ] \mathbf{m} = [0.0, 0.0] m=[0.0,0.0]
  10. S \mathbf{S} S(引入点的协方差参数):

  • 初始设置为 S = [ 1.0 0.0 0.0 1.0 ] \mathbf{S} = \begin{bmatrix} 1.0 & 0.0 \\ 0.0 & 1.0 \end{bmatrix} S=[1.00.00.01.0]
  1. K M M \mathbf{K}_{MM} KMM(引入点间的核矩阵):
  • 计算所有引入点之间的核值。
  1. K N M \mathbf{K}_{NM} KNM(训练点与引入点间的核矩阵):
  • 计算训练点与引入点之间的核值。
  1. K ∗ M \mathbf{K}_{*M} KM(新数据点与引入点间的核矩阵):
  • 计算新数据点与引入点之间的核值。
  1. K ∗ ∗ \mathbf{K}_{**} K∗∗(新数据点自身的核矩阵):
  • 计算新数据点之间的核值。
  1. 变分参数(Variational Parameters):
  • m \mathbf{m} m S \mathbf{S} S 是在变分推断中优化的参数,用于近似后验分布。

python代码

本代码包含参数更新和预测:

import torch
import torch.nn as nn
import numpy as np
import mathclass SVGP(nn.Module):def __init__(self, inducing_points, kernel_scale=1.0, jitter=1e-6):super(SVGP, self).__init__()self.inducing_points = nn.Parameter(torch.tensor(inducing_points, dtype=torch.float32))self.kernel_scale = kernel_scaleself.jitter = jitterself.variational_mean = nn.Parameter(torch.zeros(self.inducing_points.shape[0]))self.variational_cov = nn.Parameter(torch.eye(self.inducing_points.shape[0]))def rbf_kernel(self, X, Y):dist = torch.cdist(X, Y)**2return torch.exp(-0.5 / self.kernel_scale * dist)def forward(self, X, y=None):# Calculate kernel matricesK_mm = self.rbf_kernel(self.inducing_points, self.inducing_points) + self.jitter * torch.eye(self.inducing_points.shape[0])K_nm = self.rbf_kernel(X, self.inducing_points)K_mn = K_nm.TK_nn = self.rbf_kernel(X, X)# Compute the inverse of K_mmK_mm_inv = torch.inverse(K_mm)# If training mode, optimize the variational parametersif y is not None:noise = 0.1  # Fixed noise for simplicityA = torch.mm(torch.mm(K_nm, K_mm_inv), K_mn) + torch.eye(X.shape[0]) * noiseB = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()# Compute the variational lower bound and gradients for optimization# Placeholder for actual ELBO calculationloss = torch.mean((y - B)**2) + torch.trace(A)return loss# If not training mode, do the predictionelse:# Predictive mean and variancemu_star = torch.mm(torch.mm(K_nm, K_mm_inv), self.variational_mean.unsqueeze(1)).squeeze()v_star = K_nn - torch.mm(torch.mm(K_nm, K_mm_inv), K_mn)return mu_star, v_stardef train_model(self, X_train, y_train, learning_rate=0.01, epochs=100):optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)self.train()for epoch in range(epochs):optimizer.zero_grad()loss = self.forward(X_train, y_train)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')def predict(self, X_test):self.eval()mu_star, v_star = self.forward(X_test)return mu_star, v_star.diag()# Example usage
X_train = torch.tensor([[0.5], [1.5], [2.5], [3.5], [4.5]])
y_train = torch.tensor([1.0, 2.0, 3.0, 2.5, 1.5])
inducing_points = torch.tensor([[1.0], [4.0]])model = SVGP(inducing_points=inducing_points)
model.train_model(X_train, y_train)
mu_star, v_star = model.predict(torch.tensor([[2.0], [3.0]]))
print(f'Mean predictions: {mu_star}, Variances: {v_star}')

这篇关于稀疏变分高斯过程【超简单,全流程解析,案例应用,简单代码】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/964965

相关文章

MySQL字符串转数值的方法全解析

《MySQL字符串转数值的方法全解析》在MySQL开发中,字符串与数值的转换是高频操作,本文从隐式转换原理、显式转换方法、典型场景案例、风险防控四个维度系统梳理,助您精准掌握这一核心技能,需要的朋友可... 目录一、隐式转换:自动但需警惕的&ld编程quo;双刃剑”二、显式转换:三大核心方法详解三、典型场景

JAVA项目swing转javafx语法规则以及示例代码

《JAVA项目swing转javafx语法规则以及示例代码》:本文主要介绍JAVA项目swing转javafx语法规则以及示例代码的相关资料,文中详细讲解了主类继承、窗口创建、布局管理、控件替换、... 目录最常用的“一行换一行”速查表(直接全局替换)实际转换示例(JFramejs → JavaFX)迁移建

Go异常处理、泛型和文件操作实例代码

《Go异常处理、泛型和文件操作实例代码》Go语言的异常处理机制与传统的面向对象语言(如Java、C#)所使用的try-catch结构有所不同,它采用了自己独特的设计理念和方法,:本文主要介绍Go异... 目录一:异常处理常见的异常处理向上抛中断程序恢复程序二:泛型泛型函数泛型结构体泛型切片泛型 map三:文

Springboot3 ResponseEntity 完全使用案例

《Springboot3ResponseEntity完全使用案例》ResponseEntity是SpringBoot中控制HTTP响应的核心工具——它能让你精准定义响应状态码、响应头、响应体,相比... 目录Spring Boot 3 ResponseEntity 完全使用教程前置准备1. 项目基础依赖(M

MyBatis中的两种参数传递类型详解(示例代码)

《MyBatis中的两种参数传递类型详解(示例代码)》文章介绍了MyBatis中传递多个参数的两种方式,使用Map和使用@Param注解或封装POJO,Map方式适用于动态、不固定的参数,但可读性和安... 目录✅ android方式一:使用Map<String, Object>✅ 方式二:使用@Param

SpringBoot实现图形验证码的示例代码

《SpringBoot实现图形验证码的示例代码》验证码的实现方式有很多,可以由前端实现,也可以由后端进行实现,也有很多的插件和工具包可以使用,在这里,我们使用Hutool提供的小工具实现,本文介绍Sp... 目录项目创建前端代码实现约定前后端交互接口需求分析接口定义Hutool工具实现服务器端代码引入依赖获

GO语言实现串口简单通讯

《GO语言实现串口简单通讯》本文分享了使用Go语言进行串口通讯的实践过程,详细介绍了串口配置、数据发送与接收的代码实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 目录背景串口通讯代码代码块分解解析完整代码运行结果背景最近再学习 go 语言,在某宝用5块钱买了个

利用Python在万圣节实现比心弹窗告白代码

《利用Python在万圣节实现比心弹窗告白代码》:本文主要介绍关于利用Python在万圣节实现比心弹窗告白代码的相关资料,每个弹窗会显示一条温馨提示,程序通过参数方程绘制爱心形状,并使用多线程技术... 目录前言效果预览要点1. 爱心曲线方程2. 显示温馨弹窗函数(详细拆解)2.1 函数定义和延迟机制2.2

C++11中的包装器实战案例

《C++11中的包装器实战案例》本文给大家介绍C++11中的包装器实战案例,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录引言1.std::function1.1.什么是std::function1.2.核心用法1.2.1.包装普通函数1.2.

SQL 注入攻击(SQL Injection)原理、利用方式与防御策略深度解析

《SQL注入攻击(SQLInjection)原理、利用方式与防御策略深度解析》本文将从SQL注入的基本原理、攻击方式、常见利用手法,到企业级防御方案进行全面讲解,以帮助开发者和安全人员更系统地理解... 目录一、前言二、SQL 注入攻击的基本概念三、SQL 注入常见类型分析1. 基于错误回显的注入(Erro