MPNN 模型:GNN 传递规则的实现

2023-10-16 07:30

本文主要是介绍MPNN 模型:GNN 传递规则的实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

首先,假如我们定义一个极简的传递规则

f(X,A) = AX

A是邻接矩阵,X是特征矩阵, 其物理意义就是 通过矩阵乘法操作,批量把图中的相邻节点汇聚到当前节点。

但是由于A的对角线都是 0.因此自身的节点特征会被过滤掉。

图神经网络的核心是 吸周围之精华,再叠加自身,因而需要改进来保留自身特征。如何做?

方法是给每个节点添加一个自环,即将邻接矩阵对角线值各加1,此时用\widetilde{A}表示,\widetilde{A}X做到了聚合邻居节点并保留自身信息。

但是当图过于复杂时,聚合邻居信息会不断执行矩阵乘法或加法,可能导致特征值太大而溢出。如何做?

方法是邻接矩阵归一化。那么如何归一化呢?我们由A可以得到图的度D,由于A变成了\widetilde{A},我们认为\widetilde{A}的度为\widetilde{D}。常用的归一化方式就是用度数矩阵的倒数\widetilde{D}^{-1}

f(X,A) = \widetilde{D}^{-1}\widetilde{A}X

但是\widetilde{D}^{-1}\widetilde{A}仅仅对矩阵A进行了列上的缩放,操作后的元素值是不对称的,某种程度破坏了图结构的对称性。(这是为什么?)那么如何修复这种对称性呢?

方法是在行的方向上也进行对等缩放,具体 做法是,让邻接矩阵\widetilde{A}右乘一个缩放因子\widetilde{D}^{-1},这样就使得缩放版本的邻接矩阵重新恢复对称性。于是信息聚合的方式为

f(X,A) = \widetilde{D}^{-1}\widetilde{A}\widetilde{D}^{-1}X

\widetilde{D}^{-1}\widetilde{A}\widetilde{D}^{-1}能够很好地缩放邻接矩阵,既然-1次幂可以完成,为什么不尝试一下(-1/2)次幂呢?

事实上,对每个矩阵元素都实施\widetilde{D}^{-\frac{1}{2}}\widetilde{D}^{-\frac{1}{2}}=\frac{1}{\sqrt{deg(v_i)\sqrt{deg(v_j)}}}

这种操作可以对邻接矩阵地每一行每一列”无偏差“地进行一次归一化,以防相邻节点间度数不匹配对归一化地影响。(why)?

于是就出现了被众多学术论文广泛采纳地邻接矩阵地缩放形式

f(X,A) = \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X

考虑权值影响的信息聚合

上述仅仅考虑到邻接矩阵对获取邻居节点信息的影响,即只考虑拓扑结构施加的影响。事实上,对于特定节点而言,不同维度的特征值对给定任务的影响程度是不同的,如果第对各个特征值进行时 打分就,就要涉及到权值矩阵W了,也就是要构造更为完整的图神经网络模型 AWX。权值矩阵W通常是通过学习得到的。

f(X,A) = \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}XW

如果我们想压缩节点输出的维度,也可以缩减权值矩阵的输出维度。

在以上的分析中,没有考虑激活函数的影响,无法给予神经网络的非线性变换能力,因此通常我们需要使用sigmoid、tanh、Relu等作为激活函数,最后再用argmax函数模拟一个分类的输出。

reference:

《从深度学习到图神经网络:模型与实践》  张玉宏 等

code:

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np#定义节点
N = [(f"v{i}", 0) for i in range (1,3)] + [(f"v{i}",1) for i in range (3,5)] + [(f"v{i}",2) for i in range (5,6)] #定义节点#定义边
E = [("v1","v2"),("v1","v3"),("v2","v1"),("v2","v3"),("v2","v4"),("v3","v1"),("v3","v2"),("v3","v4"),("v4","v2"),("v4","v3"),("v4","v5"),("v5","v4")] #定义边G = nx.Graph() #构造图G.add_nodes_from(list(map(lambda x: x[0],N))) #给图添加节点
G.add_edges_from(E) #给图添加边ncolor =['r']*2 + ['b']*2 +['g']*1 #设置节点颜色
nsize = [700]*2 + [700]*2 + [700]*1 #设置节点的大小#显示图
nx.draw(G, with_labels= True, font_weight ='bold', font_color = 'w', node_color =ncolor, node_size =nsize)
plt.show()#借用nx构造邻接矩阵
A = np.array(nx.adjacency_matrix(G).todense())
print(A)#构造特征矩阵X
X = np.array([[i,-i, i+2] for i in  range (A.shape[0])])
print(X)#为了不丢失自己的属性,需要修改本身的邻接矩阵,因为最初邻接矩阵的斜对角线为0
I = np.eye(A.shape[0])
A_hat = A + I
print('A_hat')
print(A_hat)#计算自环邻接矩阵的度
D_hat = np.diag(np.sum(A_hat,axis= 0 ))
print(D_hat)#获取D——hat的逆矩阵,即一个缩放因子
D_1 = np.diag(D_hat) ** (-1) *np.eye(A_hat.shape[0])
print('D_1')
print(D_1)#缩放版的邻接矩阵
A_scale = D_1 @ A_hat  #对矩阵A仅仅进行了列方向上的缩放
print('A_scale')
print(A_scale)#用A_scale来聚合邻居节点的信息
X_new = A_scale @ X
print('X_new')
print(X_new)#修复原本的缩放的不对称性
scale_factor = D_1 @ A_hat @ D_1    #scale_factor 是对称的,而 A_scale是不对称 的
print('scale_factor')
print(scale_factor)#用scale_factor来聚合邻居节点的信息
X_new1 = scale_factor  @ X
print('X_new1')
print(X_new1)D_sq_half = np.diag(D_hat) ** (-0.5) *np.eye(A_hat.shape[0])
print('D_sq_half')
print(D_sq_half)#修复原本的缩放的不对称性
scale_factor2 = D_sq_half @ A_hat @ D_sq_half    #scale_factor 是对称的,而 A_scale是不对称 的
print('scale_factor2')
print(scale_factor2)#用scale_factor2来聚合邻居节点的信息
X_new2 = scale_factor2  @ X
print('X_new2')
print(X_new2)#给出的权值矩阵
W = np.array([[0.13,0.24],[0.37,-0.32],[0.14,-0.15]])X_new3 = X_new2 @ W
print(X_new3)#也可以缩减W的尺寸压缩节点的输出维度
W1 = np.array([[0.13],[0.37],[0.14]])
#计算logits
logits = X_new2 @ W1
print(logits)#以上都没有考虑到激活函数,无法模拟神经网络的非线性变换能力,可以使用激活函数
y = logits * (logits >0)  #使用Relu函数
print(y)#为了实现分类等功能,还需要添加一层Softmax
def softmax(x):return np.exp(x) /np.sum(np.exp(x), axis = 0)prob = softmax(y)
print('y')
print(y)#模拟一个分类输出
pred = np.argmax(prob)
print(pred)

这篇关于MPNN 模型:GNN 传递规则的实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/219974

相关文章

利用Python实现可回滚方案的示例代码

《利用Python实现可回滚方案的示例代码》很多项目翻车不是因为不会做,而是走错了方向却没法回头,技术选型失败的风险我们都清楚,但真正能提前规划“回滚方案”的人不多,本文从实际项目出发,教你如何用Py... 目录描述题解答案(核心思路)题解代码分析第一步:抽象缓存接口第二步:实现两个版本第三步:根据 Fea

Go语言使用slices包轻松实现排序功能

《Go语言使用slices包轻松实现排序功能》在Go语言开发中,对数据进行排序是常见的需求,Go1.18版本引入的slices包提供了简洁高效的排序解决方案,支持内置类型和用户自定义类型的排序操作,本... 目录一、内置类型排序:字符串与整数的应用1. 字符串切片排序2. 整数切片排序二、检查切片排序状态:

python利用backoff实现异常自动重试详解

《python利用backoff实现异常自动重试详解》backoff是一个用于实现重试机制的Python库,通过指数退避或其他策略自动重试失败的操作,下面小编就来和大家详细讲讲如何利用backoff实... 目录1. backoff 库简介2. on_exception 装饰器的原理2.1 核心逻辑2.2

Java实现视频格式转换的完整指南

《Java实现视频格式转换的完整指南》在Java中实现视频格式的转换,通常需要借助第三方工具或库,因为视频的编解码操作复杂且性能需求较高,以下是实现视频格式转换的常用方法和步骤,需要的朋友可以参考下... 目录核心思路方法一:通过调用 FFmpeg 命令步骤示例代码说明优点方法二:使用 Jaffree(FF

基于C#实现MQTT通信实战

《基于C#实现MQTT通信实战》MQTT消息队列遥测传输,在物联网领域应用的很广泛,它是基于Publish/Subscribe模式,具有简单易用,支持QoS,传输效率高的特点,下面我们就来看看C#实现... 目录1、连接主机2、订阅消息3、发布消息MQTT(Message Queueing Telemetr

Java实现图片淡入淡出效果

《Java实现图片淡入淡出效果》在现代图形用户界面和游戏开发中,**图片淡入淡出(FadeIn/Out)**是一种常见且实用的视觉过渡效果,它可以用于启动画面、场景切换、轮播图、提示框弹出等场景,通过... 目录1. 项目背景详细介绍2. 项目需求详细介绍2.1 功能需求2.2 非功能需求3. 相关技术详细

Python实现获取带合并单元格的表格数据

《Python实现获取带合并单元格的表格数据》由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,所以本文我们就来聊聊如何使用Python实现获取带合并单元格的表格数据吧... 由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,现将将封装成类,并通过调用list_exc

Nginx路由匹配规则及优先级详解

《Nginx路由匹配规则及优先级详解》Nginx作为一个高性能的Web服务器和反向代理服务器,广泛用于负载均衡、请求转发等场景,在配置Nginx时,路由匹配规则是非常重要的概念,本文将详细介绍Ngin... 目录引言一、 Nginx的路由匹配规则概述二、 Nginx的路由匹配规则类型2.1 精确匹配(=)2

使用animation.css库快速实现CSS3旋转动画效果

《使用animation.css库快速实现CSS3旋转动画效果》随着Web技术的不断发展,动画效果已经成为了网页设计中不可或缺的一部分,本文将深入探讨animation.css的工作原理,如何使用以及... 目录1. css3动画技术简介2. animation.css库介绍2.1 animation.cs

Java进行日期解析与格式化的实现代码

《Java进行日期解析与格式化的实现代码》使用Java搭配ApacheCommonsLang3和Natty库,可以实现灵活高效的日期解析与格式化,本文将通过相关示例为大家讲讲具体的实践操作,需要的可以... 目录一、背景二、依赖介绍1. Apache Commons Lang32. Natty三、核心实现代