使用DGL完成节点分类任务

2024-02-20 09:30

本文主要是介绍使用DGL完成节点分类任务,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

更多图神经网络和深度学习内容请关注:
在这里插入图片描述

节点分类任务概述

节点分类(node classification)任务是在图数据处理中最流行任务之一,一个模型需要预测每个节点属于哪个类别。

在图神经网络出现之前,用于结点分类任务的方法可归为两大类:

  • 仅使用连通性(如DeepWalk或node2vec)
  • 简单地结合连通性和节点自身的特征

相比之下,GNNs是一个通过结合局部邻域(广义上的邻居,包含结点自身)的连通性及其特征来获得节点表征的方法。

Kipf等人将节点分类问题描述为一个半监督的节点分类任务。图神经网络只需要一小部分已标记的节点,即可准确地预测其他节点的类别。

本文将展示如何在Cora数据集中(即以论文为节点,以论文引用为边的引文网络)使用少量标签构建半监督节点分类任务的GNN模型。其具体任务为预测给定论文的类别。每个论文节点均包含一个单词计数向量(word count vector)作为它的特征,这些特征进行了归一化(使其总和为1),参考论文第5.2节。

使用DGL完成节点分类

导入相对应的包

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
Using backend: pytorch

加载数据集

dataset = dgl.data.CoraGraphDataset()
  NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done loading data from cached files.

DGL数据集对象可以包含一个或多个图。一般情况下,整图分类任务数据集包含多个图,边预测节点分类数据集只包含一个图,如节点分类任务中的Cora数据集只包含一个图。

g = dataset[0]

DGL图将节点特征和边特征分别存储在两个类似字典的属性ndataedata中,在Cora数据集中,图包含以下节点特征(其他数据集也类似):

  • train_mask:布尔张量,表示节点是否在训练集中。
  • val_mask:布尔张量,表示节点是否在验证集中。
  • test_mask:布尔张量,表示节点是否在测试集中。
  • label:节点类别。
  • feat:节点特征。
print("Node feature")
print(g.ndata)print("Edge feature")
print(g.edata)
Node feature
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]])}
Edge feature
{}

定义图卷积网络(GCN)

本文将构建一个两层图卷积网络(GCN)。其中每一层都通过聚合邻居信息来计算新的节点表示,若需要构建多层GCN网络,我们可简单地堆叠dgl.nn.GraphConv模块,这些都模块继承于torch.nn.Module。(假设DGL使用的后端框架为PyTorch)

from dgl.nn import GraphConvclass GCN(nn.Module):def __init__(self, in_feats, h_feats, num_class):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_class)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h#设置参数
in_feats = g.ndata["feat"].shape[1]
h_feats = 16
num_class = (torch.max(g.ndata["label"]) + 1).item() #或者 num_class = dataset.num_classes
# 创建模型
model = GCN(in_feats, h_feats, num_class)

DGL提供了许多流行的邻居聚合模块的实现,我们可以使用一行代码即可轻松调用它们。

训练GCN模型

GCN模型训练过程类似其他PyTorch神经网络训练过程。

def train(g, model, learning_rate=0.01, num_epoch=100):optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]test_mask = g.ndata["test_mask"]val_mask = g.ndata["val_mask"]for epoch in range(num_epoch):result = model(g, features)pred = result.argmax(1)loss = F.cross_entropy(result[train_mask], labels[train_mask])train_acc = (pred[train_mask]==labels[train_mask]).float().mean()val_acc  = (pred[val_mask]==labels[val_mask]).float().mean()        test_acc  = (pred[test_mask]==labels[test_mask]).float().mean()if best_val_acc < val_acc:best_val_acc, best_test_acc = val_acc, test_accoptimizer.zero_grad()loss.backward()optimizer.step()if epoch % 5 == 0:print('In epoch {}, loss: {}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))if __name__ == "__main__":train(g, model, num_epoch=200, learning_rate=0.002)
In epoch 0, loss: 1.0601081612549024e-06, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 5, loss: 9.979492006095825e-07, val acc: 0.760 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 10, loss: 9.494142432231456e-07, val acc: 0.762 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 15, loss: 9.017308570946625e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 20, loss: 8.557504429518303e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 25, loss: 8.157304023370671e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 30, loss: 7.71452903336467e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 35, loss: 7.322842634494009e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 40, loss: 6.948185955479858e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 45, loss: 6.624618436035234e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 50, loss: 6.292536340879451e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 55, loss: 6.028573125149705e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 60, loss: 5.807185630146705e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 65, loss: 5.534708407139988e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 70, loss: 5.381440359997214e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 75, loss: 5.117477144267468e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 80, loss: 4.913119937555166e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 85, loss: 4.759851037761109e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 90, loss: 4.5640075541086844e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 95, loss: 4.368164354673354e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 100, loss: 4.2319251747358066e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 105, loss: 4.07865627494175e-07, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.764)
In epoch 110, loss: 3.993507107225014e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 115, loss: 3.840238207430957e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 120, loss: 3.755089039714221e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 125, loss: 3.6358795796331833e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 130, loss: 3.5081561122751737e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 135, loss: 3.414492084630183e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 140, loss: 3.363402356626466e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 145, loss: 3.218648316760664e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 150, loss: 3.159043444611598e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 155, loss: 3.0568645570383524e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 160, loss: 2.988745109178126e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 165, loss: 2.895080797316041e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 170, loss: 2.792901625525701e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 175, loss: 2.733296753376635e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 180, loss: 2.673692165444663e-07, val acc: 0.764 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 185, loss: 2.614087861729786e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 190, loss: 2.53745326972421e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 195, loss: 2.486363541720493e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)

完整代码为

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConvclass GCN(nn.Module):"""GCN network"""def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model, num_epoch = 100, learning_rate =  0.001):"""train function"""optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)best_val_accurate = 0best_test_accurate = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]test_mask = g.ndata["test_mask"]val_mask = g.ndata["val_mask"]for e in range(num_epoch):#forwardresult = model(g, features)#predictionpred = result.argmax(dim=1)#Lossloss = F.cross_entropy(result[train_mask], labels[train_mask])#compute accuratetrain_accurate = (pred[train_mask]==labels[train_mask]).float().mean()test_accurate = (pred[test_mask]==labels[test_mask]).float().mean()val_accurate = (pred[val_mask]==labels[val_mask]).float().mean()if best_val_accurate < val_accurate:best_val_accurate, best_test_accurate = val_accurate, test_accurate#backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(e, loss, val_accurate, best_val_accurate, test_accurate, best_test_accurate))def main():dataset = CoraGraphDataset()g = dataset[0]in_feats = g.ndata["feat"].shape[1]h_feats = 16num_classes = dataset.num_classesmodel = GCN(in_feats, h_feats, num_classes)train(g, model)if __name__ == "__main__":main()
  NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.946, val acc: 0.104 (best 0.104), test acc: 0.114 (best 0.114)
In epoch 5, loss: 1.942, val acc: 0.276 (best 0.276), test acc: 0.314 (best 0.314)
In epoch 10, loss: 1.936, val acc: 0.452 (best 0.452), test acc: 0.452 (best 0.452)
In epoch 15, loss: 1.929, val acc: 0.546 (best 0.546), test acc: 0.549 (best 0.549)
In epoch 20, loss: 1.921, val acc: 0.612 (best 0.612), test acc: 0.631 (best 0.631)
In epoch 25, loss: 1.913, val acc: 0.640 (best 0.640), test acc: 0.647 (best 0.647)
In epoch 30, loss: 1.904, val acc: 0.654 (best 0.654), test acc: 0.670 (best 0.670)
In epoch 35, loss: 1.895, val acc: 0.684 (best 0.684), test acc: 0.692 (best 0.692)
In epoch 40, loss: 1.886, val acc: 0.690 (best 0.692), test acc: 0.695 (best 0.693)
In epoch 45, loss: 1.876, val acc: 0.700 (best 0.700), test acc: 0.694 (best 0.694)
In epoch 50, loss: 1.866, val acc: 0.706 (best 0.708), test acc: 0.701 (best 0.699)
In epoch 55, loss: 1.855, val acc: 0.710 (best 0.710), test acc: 0.698 (best 0.698)
In epoch 60, loss: 1.844, val acc: 0.708 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 65, loss: 1.833, val acc: 0.704 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 70, loss: 1.821, val acc: 0.702 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 75, loss: 1.809, val acc: 0.704 (best 0.712), test acc: 0.705 (best 0.699)
In epoch 80, loss: 1.796, val acc: 0.706 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 85, loss: 1.783, val acc: 0.702 (best 0.712), test acc: 0.706 (best 0.699)
In epoch 90, loss: 1.769, val acc: 0.694 (best 0.712), test acc: 0.703 (best 0.699)
In epoch 95, loss: 1.755, val acc: 0.692 (best 0.712), test acc: 0.706 (best 0.699)

使用GPU进行训练

在GPU上进行训练需要使用to方法将模型和图都放到GPU上,PyTorch训练其他神经网络模型类似。

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

参考

翻译整理自Node Classification with DGL

这篇关于使用DGL完成节点分类任务的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/727761

相关文章

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Java使用jar命令配置服务器端口的完整指南

《Java使用jar命令配置服务器端口的完整指南》本文将详细介绍如何使用java-jar命令启动应用,并重点讲解如何配置服务器端口,同时提供一个实用的Web工具来简化这一过程,希望对大家有所帮助... 目录1. Java Jar文件简介1.1 什么是Jar文件1.2 创建可执行Jar文件2. 使用java

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

Java中的抽象类与abstract 关键字使用详解

《Java中的抽象类与abstract关键字使用详解》:本文主要介绍Java中的抽象类与abstract关键字使用详解,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、抽象类的概念二、使用 abstract2.1 修饰类 => 抽象类2.2 修饰方法 => 抽象方法,没有

MyBatis ParameterHandler的具体使用

《MyBatisParameterHandler的具体使用》本文主要介绍了MyBatisParameterHandler的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、概述二、源码1 关键属性2.setParameters3.TypeHandler1.TypeHa

Spring 中的切面与事务结合使用完整示例

《Spring中的切面与事务结合使用完整示例》本文给大家介绍Spring中的切面与事务结合使用完整示例,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录 一、前置知识:Spring AOP 与 事务的关系 事务本质上就是一个“切面”二、核心组件三、完

使用docker搭建嵌入式Linux开发环境

《使用docker搭建嵌入式Linux开发环境》本文主要介绍了使用docker搭建嵌入式Linux开发环境,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录1、前言2、安装docker3、编写容器管理脚本4、创建容器1、前言在日常开发全志、rk等不同

使用Python实现Word文档的自动化对比方案

《使用Python实现Word文档的自动化对比方案》我们经常需要比较两个Word文档的版本差异,无论是合同修订、论文修改还是代码文档更新,人工比对不仅效率低下,还容易遗漏关键改动,下面通过一个实际案例... 目录引言一、使用python-docx库解析文档结构二、使用difflib进行差异比对三、高级对比方

sky-take-out项目中Redis的使用示例详解

《sky-take-out项目中Redis的使用示例详解》SpringCache是Spring的缓存抽象层,通过注解简化缓存管理,支持Redis等提供者,适用于方法结果缓存、更新和删除操作,但无法实现... 目录Spring Cache主要特性核心注解1.@Cacheable2.@CachePut3.@Ca