图神经网络框架DGL实现Graph Attention Network (GAT)笔记

2024-09-08 09:18

本文主要是介绍图神经网络框架DGL实现Graph Attention Network (GAT)笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考列表:

[1]深入理解图注意力机制
[2]DGL官方学习教程一 ——基础操作&消息传递
[3]Cora数据集介绍+python读取

一、DGL实现GAT分类机器学习论文

程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3])
在这里插入图片描述

1. 程序

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = gself.fc = nn.Linear(in_dim, out_dim, bias=False)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):alpha = F.softmax(nodes.mailbox['e'], dim=1)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):z = self.fc(h) # eq. 1self.g.ndata['z'] = z self.g.apply_edges(self.edge_attention) # eq. 2self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4return self.g.ndata.pop('h')class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':return torch.cat(head_outs, dim=1)else:return torch.mean(torch.stack(head_outs))class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return hfrom dgl import DGLGraph
from dgl.data import citation_graph as citegrhdef load_core_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, maskimport time 
import numpy as np
g, features, labels, mask = load_core_data()net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):if epoch >= 3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >= 3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.笔记
2.1 初始化一个graph的两种方式

对于如下图数据结构:
0->1
1->2
3->1

多称之为小括号方式

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
或中括号方式:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
note: 同一个graph,每次打印出来的各节点的位置是随机的。

2.2 DGL的update_all函数实际工作过程

利用如下例程说明:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dglN = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):return {'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):print("-batch size--pv size-------------")print(nodes.batch_size(), nodes.mailbox['pv'].size())msgs = torch.sum(nodes.mailbox['pv'], dim=1)pv = (1 - DAMP) / N + DAMP * msgsreturn {'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

打印g.ndata[‘deg’]信息(也即每个节点的入度信息)如下:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_func函数内的打印信息如下:

-batch size–pv size-------------
1 torch.Size([1, 3])
-batch size–pv size-------------
2 torch.Size([2, 4])
-batch size–pv size-------------
5 torch.Size([5, 5])
-batch size–pv size-------------
6 torch.Size([6, 6])
-batch size–pv size-------------
10 torch.Size([10, 7])
-batch size–pv size-------------
7 torch.Size([7, 8])
-batch size–pv size-------------
12 torch.Size([12, 9])
-batch size–pv size-------------
16 torch.Size([16, 10])
-batch size–pv size-------------
11 torch.Size([11, 11])
-batch size–pv size-------------
11 torch.Size([11, 12])
-batch size–pv size-------------
8 torch.Size([8, 13])
-batch size–pv size-------------
2 torch.Size([2, 14])
-batch size–pv size-------------
7 torch.Size([7, 15])
-batch size–pv size-------------
1 torch.Size([1, 17])
-batch size–pv size-------------
1 torch.Size([1, 19])

入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…

对比图的入度信息与pagerank_reduce_func函数内的打印信息,我们发现:入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…因此,得出:
1)函数update_all并不是将所有节点一起更新;
2)函数update_all将具有同等个数目标节点的节点放在一起更新,形成一个batch,这也是为什么reduce_func(nodes)中的入参中的入参type为dgl.udf.NodeBatch的原因。reduce_func(nodes)中的入参nodes的不同行代表与不同节点相关的数据。

这篇关于图神经网络框架DGL实现Graph Attention Network (GAT)笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147712

相关文章

C++中零拷贝的多种实现方式

《C++中零拷贝的多种实现方式》本文主要介绍了C++中零拷贝的实现示例,旨在在减少数据在内存中的不必要复制,从而提高程序性能、降低内存使用并减少CPU消耗,零拷贝技术通过多种方式实现,下面就来了解一下... 目录一、C++中零拷贝技术的核心概念二、std::string_view 简介三、std::stri

C++高效内存池实现减少动态分配开销的解决方案

《C++高效内存池实现减少动态分配开销的解决方案》C++动态内存分配存在系统调用开销、碎片化和锁竞争等性能问题,内存池通过预分配、分块管理和缓存复用解决这些问题,下面就来了解一下... 目录一、C++内存分配的性能挑战二、内存池技术的核心原理三、主流内存池实现:TCMalloc与Jemalloc1. TCM

OpenCV实现实时颜色检测的示例

《OpenCV实现实时颜色检测的示例》本文主要介绍了OpenCV实现实时颜色检测的示例,通过HSV色彩空间转换和色调范围判断实现红黄绿蓝颜色检测,包含视频捕捉、区域标记、颜色分析等功能,具有一定的参考... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取

基于Python实现一个Windows Tree命令工具

《基于Python实现一个WindowsTree命令工具》今天想要在Windows平台的CMD命令终端窗口中使用像Linux下的tree命令,打印一下目录结构层级树,然而还真有tree命令,但是发现... 目录引言实现代码使用说明可用选项示例用法功能特点添加到环境变量方法一:创建批处理文件并添加到PATH1

Java使用HttpClient实现图片下载与本地保存功能

《Java使用HttpClient实现图片下载与本地保存功能》在当今数字化时代,网络资源的获取与处理已成为软件开发中的常见需求,其中,图片作为网络上最常见的资源之一,其下载与保存功能在许多应用场景中都... 目录引言一、Apache HttpClient简介二、技术栈与环境准备三、实现图片下载与保存功能1.

canal实现mysql数据同步的详细过程

《canal实现mysql数据同步的详细过程》:本文主要介绍canal实现mysql数据同步的详细过程,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的... 目录1、canal下载2、mysql同步用户创建和授权3、canal admin安装和启动4、canal

Nexus安装和启动的实现教程

《Nexus安装和启动的实现教程》:本文主要介绍Nexus安装和启动的实现教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、Nexus下载二、Nexus安装和启动三、关闭Nexus总结一、Nexus下载官方下载链接:DownloadWindows系统根

SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程

《SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程》LiteFlow是一款专注于逻辑驱动流程编排的轻量级框架,它以组件化方式快速构建和执行业务流程,有效解耦复杂业务逻辑,下面给大... 目录一、基础概念1.1 组件(Component)1.2 规则(Rule)1.3 上下文(Conte

MySQL 横向衍生表(Lateral Derived Tables)的实现

《MySQL横向衍生表(LateralDerivedTables)的实现》横向衍生表适用于在需要通过子查询获取中间结果集的场景,相对于普通衍生表,横向衍生表可以引用在其之前出现过的表名,本文就来... 目录一、横向衍生表用法示例1.1 用法示例1.2 使用建议前面我们介绍过mysql中的衍生表(From子句