图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现

本文主要是介绍图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

图神经网络实战(7)——图卷积网络详解与实现

    • 前言
    • 1. 图卷积层
    • 2. 比较 GCN 和 GNN
      • 2.1 数据集分析
      • 2.2 实现 GCN 架构
    • 小结
    • 系列链接

前言

图卷积网络 (Graph Convolutional Network, GCN) 架构由 KipfWelling2017 年提出,其理念是创建一种适用于图的高效卷积神经网络 (Convolutional Neural Networks, CNN)。更准确地说,它是图信号处理中图卷积操作的近似,由于其易用性,GCN 已成为最受欢迎的图神经网络 (Graph Neural Networks, GNN) 之一,是处理图数据时创建基线模型的首选架构。
在本节中,我们将讨论 Vanilla GNN 架构的局限性,这有助于我们理解 GCN 的核心思想。并详细介绍 GCN 的工作原理,解释为什么 GCNVanilla GNN 性能更好,通过使用 PyTorch Geometric 在 Cora 和 Facebook Page-Page 数据集上实现 GCN 来验证其性能。

1. 图卷积层

与表格或图像数据不同,图数据中节点的邻居数量并不总是相同。例如,在下图中,节点 13 个邻居,而节点 2 只有 1 个:

图数据

但是,观察图神经网络 (Graph Neural Networks, GNN) 层就会发现,邻居数量的差异并不会导致计算的复杂化。GNN 层由一个简单的求和公式组成,没有任何归一化系数,计算节点 i i i 的嵌入方法如下:
h i = ∑ j ∈ N i x j W T h_i=\sum_{j\in \mathcal N_i}x_jW^T hi=jNixjWT
假设节点 11,000 个邻居,而节点 2 只有 1 个邻居,那么 h 1 h_1 h1 嵌入的值将远远大于 h 2 h_2 h2 嵌入的值。这样便会出现一个问题,当我们要对这些嵌入进行比较时,如果它们的值相差过大,如何进行有意义的比较?
一个简单的解决方案是将嵌入除以邻居数量,用 deg ⁡ ( A ) \deg(A) deg(A) 表示节点的度,因此 GNN 层公式可以更新为:
h i = 1 deg ⁡ ( i ) ∑ j ∈ N i x j W T h_i=\frac 1{\deg(i)}\sum_{j\in \mathcal N_i}x_jW^T hi=deg(i)1jNixjWT
那么如何将其转化为矩阵乘法呢?首先回顾普通 GNN 层的计算公式:
H = A ~ T X W T H=\tilde A^TXW^T H=A~TXWT
其中, A ~ = A + I \tilde A=A+I A~=A+I。公式中缺少的是一个能为我们提供归一化系数 1 deg ⁡ ( A ) \frac 1 {\deg(A)} deg(A)1 的矩阵,可以利用度矩阵 D D D 来计算每个节点的邻居数量。上示图像中的图的度矩阵如下:
D = [ 3 0 0 0 0 1 0 0 0 0 2 0 0 0 0 2 ] D=\left[\begin{array}{c} 3 & 0 & 0 & 0\\ 0 & 1 & 0 & 0\\ 0 & 0 & 2 & 0\\ 0 & 0 & 0 & 2\\ \end{array}\right] D= 3000010000200002
使用 NumPy 表示以上矩阵:

import numpy as npD = np.array([[3, 0, 0, 0],[0, 1, 0, 0],[0, 0, 2, 0],[0, 0, 0, 2]
])

根据定义, D D D 给出了每个节点的度 deg ⁡ ( i ) \deg(i) deg(i) 。因此,根据度矩阵的逆矩阵 D − 1 D^{-1} D1 可以直接得到归一化系数 1 deg ⁡ ( A ) \frac 1 {\deg(A)} deg(A)1

可以使用 numpy.linalg.inv() 函数计算矩阵的逆:

print(np.linalg.inv(D))
'''输出如下
[[0.33333333 0.         0.         0.        ][0.         1.         0.         0.        ][0.         0.         0.5        0.        ][0.         0.         0.         0.5       ]]
'''

为了更加精确,在图中添加了自循环,用 A ~ = A + I \tilde A=A+I A~=A+I 表示。同样,我们也需要在度矩阵中加入自循环,即 D ~ = D + I \tilde D= D+I D~=D+I ,因此最终所需的矩阵为 D ~ − 1 = ( D + I ) − 1 \tilde D^{-1} = (D+I)^{-1} D~1=(D+I)1

NumPy 中,可以使用函数 numpy.identity(n) 快速创建指定维度 n 的单位矩阵 I I I

print(np.linalg.inv(D + np.identity(4)))
'''输出如下
[[0.25       0.         0.         0.        ][0.         0.5        0.         0.        ][0.         0.         0.33333333 0.        ][0.         0.         0.         0.33333333]]
'''

得到归一化系数矩阵后,有两种应用方式:

  • D ~ − 1 A ~ X W T \tilde D^{-1}\tilde AXW^T D~1A~XWT 会对每一行特征进行归一化处理。
  • A ~ D ~ − 1 X W T \tilde A \tilde D^{-1}XW^T A~D~1XWT 会对每一列特征进行归一化处理。

接下来,通过计算 D ~ − 1 A ~ \tilde D^{-1}\tilde A D~1A~ A ~ D ~ − 1 \tilde A \tilde D^{-1} A~D~1 进行验证:

D ~ − 1 A ~ = [ 1 4 0 0 0 0 1 2 0 0 0 0 1 3 0 0 0 0 1 3 ] ⋅ [ 1 1 1 1 1 1 0 0 1 0 1 1 1 0 1 1 ] = [ 1 4 1 4 1 4 1 4 1 2 1 2 0 0 1 3 0 1 3 1 3 1 3 0 1 3 1 3 ] A ~ D ~ − 1 = [ 1 1 1 1 1 1 0 0 1 0 1 1 1 0 1 1 ] ⋅ [ 1 4 0 0 0 0 1 2 0 0 0 0 1 3 0 0 0 0 1 3 ] = [ 1 4 1 2 1 3 1 3 1 4 1 2 0 0 1 4 0 1 3 1 3 1 4 0 1 3 1 3 ] \tilde D^{-1}\tilde A=\left[\begin{array}{c} \frac 14 & 0 & 0 & 0\\ 0 & \frac 12 & 0 & 0\\ 0 & 0 & \frac 13 & 0\\ 0 & 0 & 0 & \frac 13\\ \end{array}\right] \cdot \left[\begin{array}{c} 1 & 1 & 1 & 1\\ 1 & 1 & 0 & 0\\ 1 & 0 & 1 & 1\\ 1 & 0 & 1 &1\\ \end{array}\right]=\left[\begin{array}{c} \frac 14 & \frac 14 & \frac 14 & \frac 14\\ \frac 12 & \frac 12 & 0 & 0\\ \frac 13 & 0 & \frac 13 & \frac 13\\ \frac 13 & 0 & \frac 13 & \frac 13\\ \end{array}\right]\\ \tilde A \tilde D^{-1}=\left[\begin{array}{c} 1 & 1 & 1 & 1\\ 1 & 1 & 0 & 0\\ 1 & 0 & 1 & 1\\ 1 & 0 & 1 &1\\ \end{array}\right] \cdot \left[\begin{array}{c} \frac 14 & 0 & 0 & 0\\ 0 & \frac 12 & 0 & 0\\ 0 & 0 & \frac 13 & 0\\ 0 & 0 & 0 & \frac 13\\ \end{array}\right]=\left[\begin{array}{c} \frac 14 & \frac 12 & \frac 13 & \frac 13\\ \frac 14 & \frac 12 & 0 & 0\\ \frac 14 & 0 & \frac 13 & \frac 13\\ \frac 14 & 0 & \frac 13 & \frac 13\\ \end{array}\right] D~1A~= 41000021000031000031 1111110010111011 = 4121313141210041031314103131 A~D~1= 1111110010111011 41000021000031000031 = 4141414121210031031313103131

在第一种情况下,每一行的和都等于 1;在第二种情况下,每一列的和都等于 1。矩阵乘法可以使用 numpy.matmul() 函数执行,或使用 Python 内置的矩阵乘法运算符 @。定义邻接矩阵并使用 @ 操作符计算矩阵乘法:

A = np.array([[1, 1, 1, 1],[1, 1, 0, 0],[1, 0, 1, 1],[1, 0, 1, 1]
])
print(np.linalg.inv(D + np.identity(4)) @ A)
print('------------------------------')
print(A @ np.linalg.inv(D + np.identity(4)))
'''输出如下
[[0.25       0.25       0.25       0.25      ][0.5        0.5        0.         0.        ][0.33333333 0.         0.33333333 0.33333333][0.33333333 0.         0.33333333 0.33333333]]
------------------------------
[[0.25       0.5        0.33333333 0.33333333][0.25       0.5        0.         0.        ][0.25       0.         0.33333333 0.33333333][0.25       0.         0.33333333 0.33333333]]
'''

得到的结果与手动计算的矩阵乘法相同。那么,在实践中我们应该使用哪种应用方式?第一种方案似乎看起来合理,因为它能很好地对相邻节点特征进行归一化处理。
KipfWelling 提出,具有多个邻居的节点的特征很容易传播,而与之相反,孤立节点的特征不容易传播。在 GCN 论文中,作者提出了一种混合归一化方法来平衡这种影响。在实践中,使用以下公式为邻居较少的节点分配更高的权重:
H = D ~ − 1 2 A ~ T D ~ − 1 2 X W T H=\tilde D^{-\frac 12}\tilde A^T\tilde D^{-\frac 12}XW^T H=D~21A~TD~21XWT
就单个嵌入而言,上式可以写为:
h i = ∑ j ∈ N i 1 deg ⁡ ( i ) deg ⁡ ( j ) x j W T h_i=\sum_{j\in \mathcal N_i}\frac 1{\sqrt {\deg(i)}\sqrt {\deg(j)}}x_jW^T hi=jNideg(i) deg(j) 1xjWT
这就是实现原始图卷积层的数学公式。与普通的 GNN 层一样,我们可以通过堆叠图卷积层创建 GCN。接下来,使用 PyTorch Geometric 实现一个 GCN 模型,并验证其性能是否优于原始图神经网络模型。

2. 比较 GCN 和 GNN

我们已经证明了 vanilla GNN 性能优于 Node2Vec 模型,接下来,我们将其与 GCN 进行比较,比较它们在 Cora 和 Facebook Page-Page 数据集上的表现。
与普通 GNN 相比,GCN 的主要特点是通过考虑节点度来权衡其特征。在构建模型之前,我们首先计算这两个数据集中的节点度,这与 GCN 的性能直接相关。
根据我们对 GCN 架构的了解,可以猜测当节点度差异较大时,它的性能会更好。如果每个节点都有相同数量的邻居,那么无论使用哪种归一化方式,架构之间都是等价的: deg ⁡ ( i ) deg ⁡ ( i ) = deg ⁡ ( i ) \sqrt {\deg(i)} \sqrt {\deg(i)}= \deg (i) deg(i) deg(i) =deg(i)

2.1 数据集分析

(1)PyTorch Geometric 中导入 Planetoid 类,为了可视化节点度,同时导入两个附加类( degree 用于获取每个节点的邻居数,Counter 用于计算每个度数的节点数)和 matplotlib 库:

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import degree
from collections import Counter
import matplotlib.pyplot as plt

(2) 导入 Cora 数据集,并将图存储在 data 中:

dataset = Planetoid(root=".", name="Cora")
data = dataset[0]

(3) 计算图中每个节点的邻居数:

degrees = degree(data.edge_index[0]).numpy()

(4) 为了生成更自然的可视化效果,统计具有相同度的节点数量:

numbers = Counter(degrees)

(5) 使用条形图来绘制统计结果:

fig, ax = plt.subplots()
ax.set_xlabel('Node degree')
ax.set_ylabel('Number of nodes')
plt.bar(numbers.keys(), numbers.values())
plt.show()

节点度分布

从上图中可以看出,图中的度分布近似指数分布,从 1 个邻居( 485 个节点)到 168 个邻居( 1 个节点)不等,这种不平衡的数据集正是归一化处理的用武之地。

(6)Facebook Page-Page 数据集上重复同样的过程:

from torch_geometric.datasets import FacebookPagePage# Import dataset from PyTorch Geometric
dataset = FacebookPagePage(root=".")
data = dataset[0]# Create masks
data.train_mask = range(18000)
data.val_mask = range(18001, 20000)
data.test_mask = range(20001, 22470)# Get list of degrees for each node
degrees = degree(data.edge_index[0]).numpy()# Count the number of nodes for each degree
numbers = Counter(degrees)# Bar plot
fig, ax = plt.subplots()
ax.set_xlabel('Node degree')
ax.set_ylabel('Number of nodes')
plt.bar(numbers.keys(), numbers.values())
plt.show()

节点度分布

Facebook Page-Page 数据集的图的节点度分布看起来更加失衡,邻居数量从 1709 不等。出于同样的原因,Facebook Page-Page 数据集也是应用 GCN 的合适实例。

2.2 实现 GCN 架构

我们可以从零开始实现 GCN 层,但这里我们无需再从头造轮子,PyTorch Geometric 已经内置了 GCN 层,首先在 Cora 数据集上实现 GCN 架构。

(1)PyTorch Geometric 中导入 GCN 层,并导入 PyTorch

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConvdataset = Planetoid(root=".", name="Cora")
data = dataset[0]

(2) 创建函数 accuracy() 计算模型准确率:

def accuracy(y_pred, y_true):"""Calculate accuracy."""return torch.sum(y_pred == y_true) / len(y_true)

(3) 创建 GCN 类,其中 __init__() 函数接受三个参数作为输入:输入维度 dim_in、隐藏维度 dim_h 和输出维度 dim_out

class GCN(torch.nn.Module):"""Graph Convolutional Network"""def __init__(self, dim_in, dim_h, dim_out):super().__init__()self.gcn1 = GCNConv(dim_in, dim_h)self.gcn2 = GCNConv(dim_h, dim_out)

(4) forward() 方法使用两个 GCN 层,并对分类结果应用 log_softmax 函数:

    def forward(self, x, edge_index):h = self.gcn1(x, edge_index)h = torch.relu(h)h = self.gcn2(h, edge_index)return F.log_softmax(h, dim=1)

(5) fit() 方法与 Vanilla GNN 相同,为了更好的比较,使用具有相同参数的 Adam 优化器,其中学习率 lr0.1L2 正则化 weight_decay0.0005

    def fit(self, data, epochs):criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.Adam(self.parameters(),lr=0.01,weight_decay=5e-4)self.train()for epoch in range(epochs+1):optimizer.zero_grad()out = self(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])acc = accuracy(out[data.train_mask].argmax(dim=1),data.y[data.train_mask])loss.backward()optimizer.step()if(epoch % 20 == 0):val_loss = criterion(out[data.val_mask], data.y[data.val_mask])val_acc = accuracy(out[data.val_mask].argmax(dim=1),data.y[data.val_mask])print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc:'f' {acc*100:>5.2f}% | Val Loss: {val_loss:.2f} | 'f'Val Acc: {val_acc*100:.2f}%')

(6) 编写 test() 方法:

    @torch.no_grad()def test(self, data):self.eval()out = self(data.x, data.edge_index)acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])return acc

(7) 实例化模型并训练 100epoch

# Create the Vanilla GNN model
gcn = GCN(dataset.num_features, 16, dataset.num_classes)
print(gcn)# Train
gcn.fit(data, epochs=100)

训练过程中的输出结果如下:

训练过程

(8) 最后,在测试集上对模型进行评估:

acc = gcn.test(data)
print(f'\nGCN test accuracy: {acc*100:.2f}%\n')# GCN test accuracy: 80.30%

重复此实验 100 次,模型的平均准确率为 80.26%(±0.59%),明显 vanilla GNN 模型的平均准确率 74.99%(±1.60%)

(9) 将同样的模型应用于 Facebook Page-Page 数据集,其平均准确率可以达到 91.78%(±0.31%),同样比 vanilla GNN 的结果( 84.91%(±1.88%) )高出很多:

# Load Facebook Page-Page
dataset = FacebookPagePage(root=".")
data = dataset[0]
data.train_mask = range(18000)
data.val_mask = range(18001, 20000)
data.test_mask = range(20001, 22470)# Train GCN
gcn = GCN(dataset.num_features, 16, dataset.num_classes)
print(gcn)
gcn.fit(data, epochs=100)
acc = gcn.test(data)
print(f'\nGCN test accuracy: {acc*100:.2f}%\n')

模型训练过程

下表总结了不同模型在不同数据集上的准确率和标准差:

MLPGNNGCN
Cora53.47%(±1.95%)74.99%(±1.60%)80.26%(±0.59%)
Facebook75.22%(±0.39%)84.91%(±1.88%)91.78%(±0.31%)

我们可以将这些性能提升归因于这两个数据集中节点度的分布的不平衡性。通过对特征进行归一化处理,并考虑中心节点及其邻居的数量,GCN 的灵活性得到了极大的提升,可以很好地处理各种类型的图。但节点分类远不是 GCN 的唯一应用,在之后的学习中,我们将看到 GCN 模型的更多新颖应用。

小结

在本节中,我们改进了 vanilla GNN 层,使其能够正确归一化节点特征,这一改进引入了图卷积网络 (Graph Convolutional Network, GCN) 层和混合归一化。在 CoraFacebook Page-Page 数据集上,我们对比了 GCN 架构与 Node2Vecvanilla GNN 之间的性能差异。由于采用了归一化处理,GCN 在这两个数据集中都具有较高的准确率。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络

这篇关于图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/872455

相关文章

C#借助Spire.XLS for .NET实现在Excel中添加文档属性

《C#借助Spire.XLSfor.NET实现在Excel中添加文档属性》在日常的数据处理和项目管理中,Excel文档扮演着举足轻重的角色,本文将深入探讨如何在C#中借助强大的第三方库Spire.... 目录为什么需要程序化添加Excel文档属性使用Spire.XLS for .NET库实现文档属性管理Sp

C++ move 的作用详解及陷阱最佳实践

《C++move的作用详解及陷阱最佳实践》文章详细介绍了C++中的`std::move`函数的作用,包括为什么需要它、它的本质、典型使用场景、以及一些常见陷阱和最佳实践,感兴趣的朋友跟随小编一起看... 目录C++ move 的作用详解一、一句话总结二、为什么需要 move?C++98/03 的痛点⚡C++

Python+FFmpeg实现视频自动化处理的完整指南

《Python+FFmpeg实现视频自动化处理的完整指南》本文总结了一套在Python中使用subprocess.run调用FFmpeg进行视频自动化处理的解决方案,涵盖了跨平台硬件加速、中间素材处理... 目录一、 跨平台硬件加速:统一接口设计1. 核心映射逻辑2. python 实现代码二、 中间素材处

MySQL中between and的基本用法、范围查询示例详解

《MySQL中betweenand的基本用法、范围查询示例详解》BETWEENAND操作符在MySQL中用于选择在两个值之间的数据,包括边界值,它支持数值和日期类型,示例展示了如何使用BETWEEN... 目录一、between and语法二、使用示例2.1、betwphpeen and数值查询2.2、be

python中的flask_sqlalchemy的使用及示例详解

《python中的flask_sqlalchemy的使用及示例详解》文章主要介绍了在使用SQLAlchemy创建模型实例时,通过元类动态创建实例的方式,并说明了如何在实例化时执行__init__方法,... 目录@orm.reconstructorSQLAlchemy的回滚关联其他模型数据库基本操作将数据添

Java数组动态扩容的实现示例

《Java数组动态扩容的实现示例》本文主要介绍了Java数组动态扩容的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1 问题2 方法3 结语1 问题实现动态的给数组添加元素效果,实现对数组扩容,原始数组使用静态分配

Java中ArrayList与顺序表示例详解

《Java中ArrayList与顺序表示例详解》顺序表是在计算机内存中以数组的形式保存的线性表,是指用一组地址连续的存储单元依次存储数据元素的线性结构,:本文主要介绍Java中ArrayList与... 目录前言一、Java集合框架核心接口与分类ArrayList二、顺序表数据结构中的顺序表三、常用代码手动

Python实现快速扫描目标主机的开放端口和服务

《Python实现快速扫描目标主机的开放端口和服务》这篇文章主要为大家详细介绍了如何使用Python编写一个功能强大的端口扫描器脚本,实现快速扫描目标主机的开放端口和服务,感兴趣的小伙伴可以了解下... 目录功能介绍场景应用1. 网络安全审计2. 系统管理维护3. 网络故障排查4. 合规性检查报错处理1.

JAVA线程的周期及调度机制详解

《JAVA线程的周期及调度机制详解》Java线程的生命周期包括NEW、RUNNABLE、BLOCKED、WAITING、TIMED_WAITING和TERMINATED,线程调度依赖操作系统,采用抢占... 目录Java线程的生命周期线程状态转换示例代码JAVA线程调度机制优先级设置示例注意事项JAVA线程

Python轻松实现Word到Markdown的转换

《Python轻松实现Word到Markdown的转换》在文档管理、内容发布等场景中,将Word转换为Markdown格式是常见需求,本文将介绍如何使用FreeSpire.DocforPython实现... 目录一、工具简介二、核心转换实现1. 基础单文件转换2. 批量转换Word文件三、工具特性分析优点局