GCN学习:Pytorch-Geometric教程(二)

2024-02-01 08:18

本文主要是介绍GCN学习:Pytorch-Geometric教程(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyG教程二

  • 数据转换
  • GCN网络

数据转换

PyTorch Geometric带有自己的变换,该变换期望将Data对象作为输入并返回一个新的变换后的Data对象。 可以使用torch_geometric.transforms.Compose将变换链接在一起,并在将处理后的数据集保存到磁盘之前(pre_transform)或访问数据集中的图形之前(transform)应用变换。
让我们看一个示例,其中我们对ShapeNet数据集(包含17,000个3D形状点clouds和来自16个形状类别的每个点标签)应用变换。

from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
print(dataset[0])
>>> Data(pos=[2518, 3], y=[2518])

通过transform将其从点云转化成图。

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='ShapeNet', categories=['Airplane'],pre_transform=T.KNNGraph(k=6))
print(dataset[0])
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

在将数据保存到磁盘之前,我们使用pre_transform进行了转换(从而缩短了加载时间)。 请注意,下一次初始化数据集时,即使我们不传递任何变换,也将已经包含图的边。

GCN网络

我们建立如下一个GCN网络
在这里插入图片描述先获取数据集:

from torch_geometric.datasets import Planetoid
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
print(dataset)
print(dataset.num_node_features)
print(dataset.num_classes)
>>>Cora()
>>>1433
>>>7

Cora是一个机器学习论文数据集,其中共有7个类别(num_classes:基于案例、遗传算法、 神经网络、概率方法、强化学习 、规则学习、理论。整个数据集中共有2708篇论文,在词干堵塞和去除词尾后,只剩下1433个独特的单词(num_node_features),文档频率小于10的所有单词都被删除。
这里我们并不需要使用dataloader和transform。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.module):def __init__(self):super(Net, self).__init__()#GCNConv的两个参数为input channel size和Output channel size#conv1将每个顶点的1433个特征压缩到16个特征值#conv2根据之前得到的16个特征值将其再压缩为7self.conv1=GCNConv(dataset.num_node_features, 16)self.conv2=GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index=data.x, data.edge_indexx=self.conv1(x, edge_index)x=F.relu(x)#dropout用于降低过拟合情况x=F.dropout((x, training = self.training))x=self.conv2(x, edge_index)#dim=0对一列所有元素的进行softmax运算#dim=1对一行所有元素的进行softmax运算return F.log_softmax(x,dim=1)
class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data)#在训练集上计算loss,out为图在gcn网络中的计算结果,data.y即7类的概率大小loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()
#计算准确率
model.eval()
#选取7种类别中概率最大的类别为预测的节点类别
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>> Accuracy: 0.8150

这篇关于GCN学习:Pytorch-Geometric教程(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/666646

相关文章

SpringBoot集成redisson实现延时队列教程

《SpringBoot集成redisson实现延时队列教程》文章介绍了使用Redisson实现延迟队列的完整步骤,包括依赖导入、Redis配置、工具类封装、业务枚举定义、执行器实现、Bean创建、消费... 目录1、先给项目导入Redisson依赖2、配置redis3、创建 RedissonConfig 配

基于C#实现PDF转图片的详细教程

《基于C#实现PDF转图片的详细教程》在数字化办公场景中,PDF文件的可视化处理需求日益增长,本文将围绕Spire.PDFfor.NET这一工具,详解如何通过C#将PDF转换为JPG、PNG等主流图片... 目录引言一、组件部署二、快速入门:PDF 转图片的核心 C# 代码三、分辨率设置 - 清晰度的决定因

Java Scanner类解析与实战教程

《JavaScanner类解析与实战教程》JavaScanner类(java.util包)是文本输入解析工具,支持基本类型和字符串读取,基于Readable接口与正则分隔符实现,适用于控制台、文件输... 目录一、核心设计与工作原理1.底层依赖2.解析机制A.核心逻辑基于分隔符(delimiter)和模式匹

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

spring AMQP代码生成rabbitmq的exchange and queue教程

《springAMQP代码生成rabbitmq的exchangeandqueue教程》使用SpringAMQP代码直接创建RabbitMQexchange和queue,并确保绑定关系自动成立,简... 目录spring AMQP代码生成rabbitmq的exchange and 编程queue执行结果总结s

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

Python pandas库自学超详细教程

《Pythonpandas库自学超详细教程》文章介绍了Pandas库的基本功能、安装方法及核心操作,涵盖数据导入(CSV/Excel等)、数据结构(Series、DataFrame)、数据清洗、转换... 目录一、什么是Pandas库(1)、Pandas 应用(2)、Pandas 功能(3)、数据结构二、安

2025版mysql8.0.41 winx64 手动安装详细教程

《2025版mysql8.0.41winx64手动安装详细教程》本文指导Windows系统下MySQL安装配置,包含解压、设置环境变量、my.ini配置、初始化密码获取、服务安装与手动启动等步骤,... 目录一、下载安装包二、配置环境变量三、安装配置四、启动 mysql 服务,修改密码一、下载安装包安装地

电脑提示d3dx11_43.dll缺失怎么办? DLL文件丢失的多种修复教程

《电脑提示d3dx11_43.dll缺失怎么办?DLL文件丢失的多种修复教程》在使用电脑玩游戏或运行某些图形处理软件时,有时会遇到系统提示“d3dx11_43.dll缺失”的错误,下面我们就来分享超... 在计算机使用过程中,我们可能会遇到一些错误提示,其中之一就是缺失某个dll文件。其中,d3dx11_4