详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

2024-05-09 04:44

本文主要是介绍详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Unionimport torch
import torch.nn.functional as F
from torch import Tensorfrom torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (Adj,NoneType,OptTensor,PairTensor,SparseTensor,
)
from torch_geometric.utils import softmaxif typing.TYPE_CHECKING:from typing import overload
else:from torch.jit import _overload_method as overload[docs]class TransformerConv(MessagePassing):r"""The graph transformer operator from the `"Masked Label Prediction:Unified Message Passing Model for Semi-Supervised Classification"<https://arxiv.org/abs/2009.03509>`_ paper... math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},where the attention coefficients :math:`\alpha_{i,j}` are computed viamulti-head dot product attention:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}{\sqrt{d}} \right)Args:in_channels (int or tuple): Size of each input sample, or :obj:`-1` toderive the size from the first input(s) to the forward method.A tuple corresponds to the sizes of source and targetdimensionalities.out_channels (int): Size of each output sample.heads (int, optional): Number of multi-head-attentions.(default: :obj:`1`)concat (bool, optional): If set to :obj:`False`, the multi-headattentions are averaged instead of concatenated.(default: :obj:`True`)beta (bool, optional): If set, will combine aggregation andskip information via.. math::\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)dropout (float, optional): Dropout probability of the normalizedattention coefficients which exposes each node to a stochasticallysampled neighborhood during training. (default: :obj:`0`)edge_dim (int, optional): Edge feature dimensionality (in casethere are any). Edge features are added to the keys afterlinear transformation, that is, prior to computing theattention dot product. They are also added to final valuesafter the same linear transformation. The model is:.. math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}\right),where the attention coefficients :math:`\alpha_{i,j}` are nowcomputed via:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}{\sqrt{d}} \right)(default :obj:`None`)bias (bool, optional): If set to :obj:`False`, the layer will not learnan additive bias. (default: :obj:`True`)root_weight (bool, optional): If set to :obj:`False`, the layer willnot add the transformed root node features to the output and theoption  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)**kwargs (optional): Additional arguments of:class:`torch_geometric.nn.conv.MessagePassing`."""_alpha: OptTensordef __init__(self,in_channels: Union[int, Tuple[int, int]],out_channels: int,heads: int = 1,concat: bool = True,beta: bool = False,dropout: float = 0.,edge_dim: Optional[int] = None,bias: bool = True,root_weight: bool = True,**kwargs,):kwargs.setdefault('aggr', 'add')super().__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.beta = beta and root_weightself.root_weight = root_weightself.concat = concatself.dropout = dropoutself.edge_dim = edge_dimself._alpha = Noneif isinstance(in_channels, int):in_channels = (in_channels, in_channels)self.lin_key = Linear(in_channels[0], heads * out_channels)self.lin_query = Linear(in_channels[1], heads * out_channels)self.lin_value = Linear(in_channels[0], heads * out_channels)if edge_dim is not None:self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)else:self.lin_edge = self.register_parameter('lin_edge', None)if concat:self.lin_skip = Linear(in_channels[1], heads * out_channels,bias=bias)if self.beta:self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)else:self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)if self.beta:self.lin_beta = Linear(3 * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)self.reset_parameters()[docs]    def reset_parameters(self):super().reset_parameters()self.lin_key.reset_parameters()self.lin_query.reset_parameters()self.lin_value.reset_parameters()if self.edge_dim:self.lin_edge.reset_parameters()self.lin_skip.reset_parameters()if self.beta:self.lin_beta.reset_parameters()@overloaddef forward(self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: NoneType = None,) -> Tensor:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Tensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: SparseTensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, SparseTensor]:pass[docs]    def forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: Optional[bool] = None,) -> Union[Tensor,Tuple[Tensor, Tuple[Tensor, Tensor]],Tuple[Tensor, SparseTensor],]:r"""Runs the forward pass of the module.Args:x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input nodefeatures.edge_index (torch.Tensor or SparseTensor): The edge indices.edge_attr (torch.Tensor, optional): The edge features.(default: :obj:`None`)return_attention_weights (bool, optional): If set to :obj:`True`,will additionally return the tuple:obj:`(edge_index, attention_weights)`, holding the computedattention weights for each edge. (default: :obj:`None`)"""H, C = self.heads, self.out_channelsif isinstance(x, Tensor):x = (x, x)query = self.lin_query(x[1]).view(-1, H, C)key = self.lin_key(x[0]).view(-1, H, C)value = self.lin_value(x[0]).view(-1, H, C)# propagate_type: (query: Tensor, key:Tensor, value: Tensor,#                  edge_attr: OptTensor)out = self.propagate(edge_index, query=query, key=key, value=value,edge_attr=edge_attr)alpha = self._alphaself._alpha = Noneif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)if self.root_weight:x_r = self.lin_skip(x[1])if self.lin_beta is not None:beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))beta = beta.sigmoid()out = beta * x_r + (1 - beta) * outelse:out = out + x_rif isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return outdef message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,edge_attr: OptTensor, index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:if self.lin_edge is not None:assert edge_attr is not Noneedge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels)key_j = key_j + edge_attralpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)out = value_jif edge_attr is not None:out = out + edge_attrout = out * alpha.view(-1, self.heads, 1)return outdef __repr__(self) -> str:return (f'{self.__class__.__name__}({self.in_channels}, 'f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)# 使用模型进行前向传播
output = conv_layer(x)print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

这篇关于详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/972422

相关文章

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Python的pandas库基础知识超详细教程

《Python的pandas库基础知识超详细教程》Pandas是Python数据处理核心库,提供Series和DataFrame结构,支持CSV/Excel/SQL等数据源导入及清洗、合并、统计等功能... 目录一、配置环境二、序列和数据表2.1 初始化2.2  获取数值2.3 获取索引2.4 索引取内容2

Python中的sort方法、sorted函数与lambda表达式及用法详解

《Python中的sort方法、sorted函数与lambda表达式及用法详解》文章对比了Python中list.sort()与sorted()函数的区别,指出sort()原地排序返回None,sor... 目录1. sort()方法1.1 sort()方法1.2 基本语法和参数A. reverse参数B.

uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)

《uni-app小程序项目中实现前端图片压缩实现方式(附详细代码)》在uni-app开发中,文件上传和图片处理是很常见的需求,但也经常会遇到各种问题,下面:本文主要介绍uni-app小程序项目中实... 目录方式一:使用<canvas>实现图片压缩(推荐,兼容性好)示例代码(小程序平台):方式二:使用uni

Python屏幕抓取和录制的详细代码示例

《Python屏幕抓取和录制的详细代码示例》随着现代计算机性能的提高和网络速度的加快,越来越多的用户需要对他们的屏幕进行录制,:本文主要介绍Python屏幕抓取和录制的相关资料,需要的朋友可以参考... 目录一、常用 python 屏幕抓取库二、pyautogui 截屏示例三、mss 高性能截图四、Pill

java时区时间转为UTC的代码示例和详细解释

《java时区时间转为UTC的代码示例和详细解释》作为一名经验丰富的开发者,我经常被问到如何将Java中的时间转换为UTC时间,:本文主要介绍java时区时间转为UTC的代码示例和详细解释,文中通... 目录前言步骤一:导入必要的Java包步骤二:获取指定时区的时间步骤三:将指定时区的时间转换为UTC时间步

MySQL批量替换数据库字符集的实用方法(附详细代码)

《MySQL批量替换数据库字符集的实用方法(附详细代码)》当需要修改数据库编码和字符集时,通常需要对其下属的所有表及表中所有字段进行修改,下面:本文主要介绍MySQL批量替换数据库字符集的实用方法... 目录前言为什么要批量修改字符集?整体脚本脚本逻辑解析1. 设置目标参数2. 生成修改表默认字符集的语句3

Python函数的基本用法、返回值特性、全局变量修改及异常处理技巧

《Python函数的基本用法、返回值特性、全局变量修改及异常处理技巧》本文将通过实际代码示例,深入讲解Python函数的基本用法、返回值特性、全局变量修改以及异常处理技巧,感兴趣的朋友跟随小编一起看看... 目录一、python函数定义与调用1.1 基本函数定义1.2 函数调用二、函数返回值详解2.1 有返