使用DGL完成节点分类任务

2024-02-20 09:30

本文主要是介绍使用DGL完成节点分类任务,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

更多图神经网络和深度学习内容请关注:
在这里插入图片描述

节点分类任务概述

节点分类(node classification)任务是在图数据处理中最流行任务之一,一个模型需要预测每个节点属于哪个类别。

在图神经网络出现之前,用于结点分类任务的方法可归为两大类:

  • 仅使用连通性(如DeepWalk或node2vec)
  • 简单地结合连通性和节点自身的特征

相比之下,GNNs是一个通过结合局部邻域(广义上的邻居,包含结点自身)的连通性及其特征来获得节点表征的方法。

Kipf等人将节点分类问题描述为一个半监督的节点分类任务。图神经网络只需要一小部分已标记的节点,即可准确地预测其他节点的类别。

本文将展示如何在Cora数据集中(即以论文为节点,以论文引用为边的引文网络)使用少量标签构建半监督节点分类任务的GNN模型。其具体任务为预测给定论文的类别。每个论文节点均包含一个单词计数向量(word count vector)作为它的特征,这些特征进行了归一化(使其总和为1),参考论文第5.2节。

使用DGL完成节点分类

导入相对应的包

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
Using backend: pytorch

加载数据集

dataset = dgl.data.CoraGraphDataset()
  NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done loading data from cached files.

DGL数据集对象可以包含一个或多个图。一般情况下,整图分类任务数据集包含多个图,边预测节点分类数据集只包含一个图,如节点分类任务中的Cora数据集只包含一个图。

g = dataset[0]

DGL图将节点特征和边特征分别存储在两个类似字典的属性ndataedata中,在Cora数据集中,图包含以下节点特征(其他数据集也类似):

  • train_mask:布尔张量,表示节点是否在训练集中。
  • val_mask:布尔张量,表示节点是否在验证集中。
  • test_mask:布尔张量,表示节点是否在测试集中。
  • label:节点类别。
  • feat:节点特征。
print("Node feature")
print(g.ndata)print("Edge feature")
print(g.edata)
Node feature
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]])}
Edge feature
{}

定义图卷积网络(GCN)

本文将构建一个两层图卷积网络(GCN)。其中每一层都通过聚合邻居信息来计算新的节点表示,若需要构建多层GCN网络,我们可简单地堆叠dgl.nn.GraphConv模块,这些都模块继承于torch.nn.Module。(假设DGL使用的后端框架为PyTorch)

from dgl.nn import GraphConvclass GCN(nn.Module):def __init__(self, in_feats, h_feats, num_class):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_class)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h#设置参数
in_feats = g.ndata["feat"].shape[1]
h_feats = 16
num_class = (torch.max(g.ndata["label"]) + 1).item() #或者 num_class = dataset.num_classes
# 创建模型
model = GCN(in_feats, h_feats, num_class)

DGL提供了许多流行的邻居聚合模块的实现,我们可以使用一行代码即可轻松调用它们。

训练GCN模型

GCN模型训练过程类似其他PyTorch神经网络训练过程。

def train(g, model, learning_rate=0.01, num_epoch=100):optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]test_mask = g.ndata["test_mask"]val_mask = g.ndata["val_mask"]for epoch in range(num_epoch):result = model(g, features)pred = result.argmax(1)loss = F.cross_entropy(result[train_mask], labels[train_mask])train_acc = (pred[train_mask]==labels[train_mask]).float().mean()val_acc  = (pred[val_mask]==labels[val_mask]).float().mean()        test_acc  = (pred[test_mask]==labels[test_mask]).float().mean()if best_val_acc < val_acc:best_val_acc, best_test_acc = val_acc, test_accoptimizer.zero_grad()loss.backward()optimizer.step()if epoch % 5 == 0:print('In epoch {}, loss: {}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))if __name__ == "__main__":train(g, model, num_epoch=200, learning_rate=0.002)
In epoch 0, loss: 1.0601081612549024e-06, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 5, loss: 9.979492006095825e-07, val acc: 0.760 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 10, loss: 9.494142432231456e-07, val acc: 0.762 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 15, loss: 9.017308570946625e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 20, loss: 8.557504429518303e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 25, loss: 8.157304023370671e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 30, loss: 7.71452903336467e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 35, loss: 7.322842634494009e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 40, loss: 6.948185955479858e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 45, loss: 6.624618436035234e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 50, loss: 6.292536340879451e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 55, loss: 6.028573125149705e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 60, loss: 5.807185630146705e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 65, loss: 5.534708407139988e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 70, loss: 5.381440359997214e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 75, loss: 5.117477144267468e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 80, loss: 4.913119937555166e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 85, loss: 4.759851037761109e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 90, loss: 4.5640075541086844e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 95, loss: 4.368164354673354e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 100, loss: 4.2319251747358066e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 105, loss: 4.07865627494175e-07, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.764)
In epoch 110, loss: 3.993507107225014e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 115, loss: 3.840238207430957e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 120, loss: 3.755089039714221e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 125, loss: 3.6358795796331833e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 130, loss: 3.5081561122751737e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 135, loss: 3.414492084630183e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 140, loss: 3.363402356626466e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 145, loss: 3.218648316760664e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 150, loss: 3.159043444611598e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 155, loss: 3.0568645570383524e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 160, loss: 2.988745109178126e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 165, loss: 2.895080797316041e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 170, loss: 2.792901625525701e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 175, loss: 2.733296753376635e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 180, loss: 2.673692165444663e-07, val acc: 0.764 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 185, loss: 2.614087861729786e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 190, loss: 2.53745326972421e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 195, loss: 2.486363541720493e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)

完整代码为

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConvclass GCN(nn.Module):"""GCN network"""def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model, num_epoch = 100, learning_rate =  0.001):"""train function"""optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)best_val_accurate = 0best_test_accurate = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]test_mask = g.ndata["test_mask"]val_mask = g.ndata["val_mask"]for e in range(num_epoch):#forwardresult = model(g, features)#predictionpred = result.argmax(dim=1)#Lossloss = F.cross_entropy(result[train_mask], labels[train_mask])#compute accuratetrain_accurate = (pred[train_mask]==labels[train_mask]).float().mean()test_accurate = (pred[test_mask]==labels[test_mask]).float().mean()val_accurate = (pred[val_mask]==labels[val_mask]).float().mean()if best_val_accurate < val_accurate:best_val_accurate, best_test_accurate = val_accurate, test_accurate#backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(e, loss, val_accurate, best_val_accurate, test_accurate, best_test_accurate))def main():dataset = CoraGraphDataset()g = dataset[0]in_feats = g.ndata["feat"].shape[1]h_feats = 16num_classes = dataset.num_classesmodel = GCN(in_feats, h_feats, num_classes)train(g, model)if __name__ == "__main__":main()
  NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.946, val acc: 0.104 (best 0.104), test acc: 0.114 (best 0.114)
In epoch 5, loss: 1.942, val acc: 0.276 (best 0.276), test acc: 0.314 (best 0.314)
In epoch 10, loss: 1.936, val acc: 0.452 (best 0.452), test acc: 0.452 (best 0.452)
In epoch 15, loss: 1.929, val acc: 0.546 (best 0.546), test acc: 0.549 (best 0.549)
In epoch 20, loss: 1.921, val acc: 0.612 (best 0.612), test acc: 0.631 (best 0.631)
In epoch 25, loss: 1.913, val acc: 0.640 (best 0.640), test acc: 0.647 (best 0.647)
In epoch 30, loss: 1.904, val acc: 0.654 (best 0.654), test acc: 0.670 (best 0.670)
In epoch 35, loss: 1.895, val acc: 0.684 (best 0.684), test acc: 0.692 (best 0.692)
In epoch 40, loss: 1.886, val acc: 0.690 (best 0.692), test acc: 0.695 (best 0.693)
In epoch 45, loss: 1.876, val acc: 0.700 (best 0.700), test acc: 0.694 (best 0.694)
In epoch 50, loss: 1.866, val acc: 0.706 (best 0.708), test acc: 0.701 (best 0.699)
In epoch 55, loss: 1.855, val acc: 0.710 (best 0.710), test acc: 0.698 (best 0.698)
In epoch 60, loss: 1.844, val acc: 0.708 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 65, loss: 1.833, val acc: 0.704 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 70, loss: 1.821, val acc: 0.702 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 75, loss: 1.809, val acc: 0.704 (best 0.712), test acc: 0.705 (best 0.699)
In epoch 80, loss: 1.796, val acc: 0.706 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 85, loss: 1.783, val acc: 0.702 (best 0.712), test acc: 0.706 (best 0.699)
In epoch 90, loss: 1.769, val acc: 0.694 (best 0.712), test acc: 0.703 (best 0.699)
In epoch 95, loss: 1.755, val acc: 0.692 (best 0.712), test acc: 0.706 (best 0.699)

使用GPU进行训练

在GPU上进行训练需要使用to方法将模型和图都放到GPU上,PyTorch训练其他神经网络模型类似。

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

参考

翻译整理自Node Classification with DGL

这篇关于使用DGL完成节点分类任务的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/huanghelouzi/article/details/116430387
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/727761

相关文章

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

Python并行处理实战之如何使用ProcessPoolExecutor加速计算

《Python并行处理实战之如何使用ProcessPoolExecutor加速计算》Python提供了多种并行处理的方式,其中concurrent.futures模块的ProcessPoolExecu... 目录简介完整代码示例代码解释1. 导入必要的模块2. 定义处理函数3. 主函数4. 生成数字列表5.

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

Linux脚本(shell)的使用方式

《Linux脚本(shell)的使用方式》:本文主要介绍Linux脚本(shell)的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录概述语法详解数学运算表达式Shell变量变量分类环境变量Shell内部变量自定义变量:定义、赋值自定义变量:引用、修改、删

Java使用HttpClient实现图片下载与本地保存功能

《Java使用HttpClient实现图片下载与本地保存功能》在当今数字化时代,网络资源的获取与处理已成为软件开发中的常见需求,其中,图片作为网络上最常见的资源之一,其下载与保存功能在许多应用场景中都... 目录引言一、Apache HttpClient简介二、技术栈与环境准备三、实现图片下载与保存功能1.

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

使用Python开发一个现代化屏幕取色器

《使用Python开发一个现代化屏幕取色器》在UI设计、网页开发等场景中,颜色拾取是高频需求,:本文主要介绍如何使用Python开发一个现代化屏幕取色器,有需要的小伙伴可以参考一下... 目录一、项目概述二、核心功能解析2.1 实时颜色追踪2.2 智能颜色显示三、效果展示四、实现步骤详解4.1 环境配置4.

使用jenv工具管理多个JDK版本的方法步骤

《使用jenv工具管理多个JDK版本的方法步骤》jenv是一个开源的Java环境管理工具,旨在帮助开发者在同一台机器上轻松管理和切换多个Java版本,:本文主要介绍使用jenv工具管理多个JD... 目录一、jenv到底是干啥的?二、jenv的核心功能(一)管理多个Java版本(二)支持插件扩展(三)环境隔

SQL中JOIN操作的条件使用总结与实践

《SQL中JOIN操作的条件使用总结与实践》在SQL查询中,JOIN操作是多表关联的核心工具,本文将从原理,场景和最佳实践三个方面总结JOIN条件的使用规则,希望可以帮助开发者精准控制查询逻辑... 目录一、ON与WHERE的本质区别二、场景化条件使用规则三、最佳实践建议1.优先使用ON条件2.WHERE用