GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?

2023-12-13 00:52

本文主要是介绍GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 一个端到端的 同构图(Cora数据集)节点分类代码:

import argparseimport dgl
import dgl.nn as dglnnimport torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDatasetclass SAGE(nn.Module):def __init__(self, in_size, hid_size, out_size):super().__init__()self.layers = nn.ModuleList()# two-layer GraphSAGE-meanself.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))self.dropout = nn.Dropout(0.5)def forward(self, graph, x):h = self.dropout(x)for l, layer in enumerate(self.layers):h = layer(graph, h)if l != len(self.layers) - 1:h = F.relu(h)h = self.dropout(h)return hdef evaluate(g, features, labels, mask, model):model.eval()with torch.no_grad():logits = model(g, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)def train(g, features, labels, masks, model):# define train/val samples, loss function and optimizertrain_mask, val_mask = masksloss_fcn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)# training loopfor epoch in range(200):model.train()logits = model(g, features)loss = loss_fcn(logits[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()acc = evaluate(g, features, labels, val_mask, model)print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(epoch, loss.item(), acc))if __name__ == "__main__":parser = argparse.ArgumentParser(description="GraphSAGE")parser.add_argument("--dataset",type=str,default="cora",help="Dataset name ('cora', 'citeseer', 'pubmed')",)parser.add_argument("--dt",type=str,default="float",help="data type(float, bfloat16)",)args = parser.parse_args()print(f"Training with DGL built-in GraphSage module")# load and preprocess datasettransform = (AddSelfLoop())  # by default, it will first remove self-loops to prevent duplicationif args.dataset == "cora":data = CoraGraphDataset(transform=transform)elif args.dataset == "citeseer":data = CiteseerGraphDataset(transform=transform)elif args.dataset == "pubmed":data = PubmedGraphDataset(transform=transform)else:raise ValueError("Unknown dataset: {}".format(args.dataset))g = data[0]device = torch.device("cuda" if torch.cuda.is_available() else "cpu")g = g.int().to(device)features = g.ndata["feat"]labels = g.ndata["label"]masks = g.ndata["train_mask"], g.ndata["val_mask"]# create GraphSAGE modelin_size = features.shape[1]out_size = data.num_classesmodel = SAGE(in_size, 16, out_size).to(device)# convert model and graph to bfloat16 if neededif args.dt == "bfloat16":g = dgl.to_bfloat16(g)features = features.to(dtype=torch.bfloat16)model = model.to(dtype=torch.bfloat16)# model trainingprint("Training...")train(g, features, labels, masks, model)# test the modelprint("Testing...")acc = evaluate(g, features, labels, g.ndata["test_mask"], model)print("Test accuracy {:.4f}".format(acc))

2. GraphSAGE的实现 : SAGEConv 类:

我们先来介绍一下DGL对GraphSAGE这个模型的实现:SAGEConv() 在三方库的下述位置:

这篇关于GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/486576

相关文章

怎么用idea创建一个SpringBoot项目

《怎么用idea创建一个SpringBoot项目》本文介绍了在IDEA中创建SpringBoot项目的步骤,包括环境准备(JDK1.8+、Maven3.2.5+)、使用SpringInitializr... 目录如何在idea中创建一个SpringBoot项目环境准备1.1打开IDEA,点击New新建一个项

qt5cored.dll报错怎么解决? 电脑qt5cored.dll文件丢失修复技巧

《qt5cored.dll报错怎么解决?电脑qt5cored.dll文件丢失修复技巧》在进行软件安装或运行程序时,有时会遇到由于找不到qt5core.dll,无法继续执行代码,这个问题可能是由于该文... 遇到qt5cored.dll文件错误时,可能会导致基于 Qt 开发的应用程序无法正常运行或启动。这种错

电脑提示xlstat4.dll丢失怎么修复? xlstat4.dll文件丢失处理办法

《电脑提示xlstat4.dll丢失怎么修复?xlstat4.dll文件丢失处理办法》长时间使用电脑,大家多少都会遇到类似dll文件丢失的情况,不过,解决这一问题其实并不复杂,下面我们就来看看xls... 在Windows操作系统中,xlstat4.dll是一个重要的动态链接库文件,通常用于支持各种应用程序

Mac备忘录怎么导出/备份和云同步? Mac备忘录使用技巧

《Mac备忘录怎么导出/备份和云同步?Mac备忘录使用技巧》备忘录作为iOS里简单而又不可或缺的一个系统应用,上手容易,可以满足我们日常生活中各种记录的需求,今天我们就来看看Mac备忘录的导出、... 「备忘录」是 MAC 上的一款常用应用,它可以帮助我们捕捉灵感、记录待办事项或保存重要信息。为了便于在不同

springboot+vue项目怎么解决跨域问题详解

《springboot+vue项目怎么解决跨域问题详解》:本文主要介绍springboot+vue项目怎么解决跨域问题的相关资料,包括前端代理、后端全局配置CORS、注解配置和Nginx反向代理,... 目录1. 前端代理(开发环境推荐)2. 后端全局配置 CORS(生产环境推荐)3. 后端注解配置(按接口

一文带你搞懂Python中__init__.py到底是什么

《一文带你搞懂Python中__init__.py到底是什么》朋友们,今天我们来聊聊Python里一个低调却至关重要的文件——__init__.py,有些人可能听说过它是“包的标志”,也有人觉得它“没... 目录先搞懂 python 模块(module)Python 包(package)是啥?那么 __in

电脑死机无反应怎么强制重启? 一文读懂方法及注意事项

《电脑死机无反应怎么强制重启?一文读懂方法及注意事项》在日常使用电脑的过程中,我们难免会遇到电脑无法正常启动的情况,本文将详细介绍几种常见的电脑强制开机方法,并探讨在强制开机后应注意的事项,以及如何... 在日常生活和工作中,我们经常会遇到电脑突然无反应的情况,这时候强制重启就成了解决问题的“救命稻草”。那

电脑开机提示krpt.dll丢失怎么解决? krpt.dll文件缺失的多种解决办法

《电脑开机提示krpt.dll丢失怎么解决?krpt.dll文件缺失的多种解决办法》krpt.dll是Windows操作系统中的一个动态链接库文件,它对于系统的正常运行起着重要的作用,本文将详细介绍... 在使用 Windows 操作系统的过程中,用户有时会遇到各种错误提示,其中“找不到 krpt.dll”

MySql死锁怎么排查的方法实现

《MySql死锁怎么排查的方法实现》本文主要介绍了MySql死锁怎么排查的方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录前言一、死锁排查方法1. 查看死锁日志方法 1:启用死锁日志输出方法 2:检查 mysql 错误

Rsnapshot怎么用? 基于Rsync的强大Linux备份工具使用指南

《Rsnapshot怎么用?基于Rsync的强大Linux备份工具使用指南》Rsnapshot不仅可以备份本地文件,还能通过SSH备份远程文件,接下来详细介绍如何安装、配置和使用Rsnaps... Rsnapshot 是一款开源的文件系统快照工具。它结合了 Rsync 和 SSH 的能力,可以帮助你在 li