H2-FDetector模型解析

2024-05-15 10:36
文章标签 模型 解析 h2 fdetector

本文主要是介绍H2-FDetector模型解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 1. H2FDetector_layer 类
  • 2. RelationAware 类
  • 3. MultiRelationH2FDetectorLayer 类
  • 4. H2FDetector 类

这个实现包括三个主要部分:H2FDetector_layer、MultiRelationH2FDetectorLayer 和 H2FDetector。每个部分都有其独特的功能和职责。下面是这些组件的详细实现和解释。

1. H2FDetector_layer 类

这是一个基本的 GNN 层,处理图卷积和注意力机制。

  • 这是基本的图卷积层,包含注意力机制和关系感知的边签名计算。
class H2FDetector_layer(nn.Module):def __init__(self, input_dim, output_dim, head, relation_aware, etype, dropout, if_sum=False):super().__init__()self.etype = etypeself.head = headself.hd = output_dimself.if_sum = if_sumself.relation_aware = relation_awareself.w_liner = nn.Linear(input_dim, output_dim * head)self.atten = nn.Linear(2 * self.hd, 1)self.relu = nn.ReLU()self.leakyrelu = nn.LeakyReLU()self.softmax = nn.Softmax(dim=1)def forward(self, g, h):with g.local_scope():g.ndata['feat'] = hg.apply_edges(self.sign_edges, etype=self.etype)h = self.w_liner(h)g.ndata['h'] = hg.update_all(message_func=self.message, reduce_func=self.reduce, etype=self.etype)out = g.ndata['out']return outdef message(self, edges):src = edges.srcsrc_features = edges.data['sign'].view(-1, 1) * src['h']src_features = src_features.view(-1, self.head, self.hd)z = torch.cat([src_features, edges.dst['h'].view(-1, self.head, self.hd)], dim=-1)alpha = self.atten(z)alpha = self.leakyrelu(alpha)return {'atten': alpha, 'sf': src_features}def reduce(self, nodes):alpha = nodes.mailbox['atten']sf = nodes.mailbox['sf']alpha = self.softmax(alpha)out = torch.sum(alpha * sf, dim=1)if not self.if_sum:out = out.view(-1, self.head * self.hd)else:out = out.sum(dim=-2)return {'out': out}def sign_edges(self, edges):src = edges.src['feat']dst = edges.dst['feat']score = self.relation_aware(src, dst)return {'sign': torch.sign(score)}

这里是对 H2FDetector_layer 类的详细解释。这个类定义了一个图神经网络(GNN)层,它使用注意力机制来对图中的节点进行特征提取和更新。下面是对每一部分代码的详细解释。

class H2FDetector_layer(nn.Module):def __init__(self, input_dim, output_dim, head, relation_aware, etype, dropout, if_sum=False):super().__init__()self.etype = etypeself.head = headself.hd = output_dimself.if_sum = if_sumself.relation_aware = relation_awareself.w_liner = nn.Linear(input_dim, output_dim * head)self.atten = nn.Linear(2 * self.hd, 1)self.relu = nn.ReLU()self.leakyrelu = nn.LeakyReLU()self.softmax = nn.Softmax(dim=1)

在这里插入图片描述
2.

def forward(self, g, h):with g.local_scope():g.ndata['feat'] = hg.apply_edges(self.sign_edges, etype=self.etype)h = self.w_liner(h)g.ndata['h'] = hg.update_all(message_func=self.message, reduce_func=self.reduce, etype=self.etype)out = g.ndata['out']return out

在这里插入图片描述
3.

def message(self, edges):src = edges.srcsrc_features = edges.data['sign'].view(-1, 1) * src['h']src_features = src_features.view(-1, self.head, self.hd)z = torch.cat([src_features, edges.dst['h'].view(-1, self.head, self.hd)], dim=-1)alpha = self.atten(z)alpha = self.leakyrelu(alpha)return {'atten': alpha, 'sf': src_features}

在这里插入图片描述
4.

def reduce(self, nodes):alpha = nodes.mailbox['atten']sf = nodes.mailbox['sf']alpha = self.softmax(alpha)out = torch.sum(alpha * sf, dim=1)if not self.if_sum:out = out.view(-1, self.head * self.hd)else:out = out.sum(dim=-2)return {'out': out}

在这里插入图片描述
5.

def sign_edges(self, edges):src = edges.src['feat']dst = edges.dst['feat']score = self.relation_aware(src, dst)return {'sign': torch.sign(score)}

在这里插入图片描述
6.
在这里插入图片描述

2. RelationAware 类

这是一个关系感知的模块,用于计算边的关系权重。

  • 关系感知模块,用于计算边的关系权重。
class RelationAware(nn.Module):def __init__(self, input_dim, output_dim, dropout):super().__init__()self.d_liner = nn.Linear(input_dim, output_dim)self.f_liner = nn.Linear(3 * output_dim, 1)self.tanh = nn.Tanh()self.dropout = nn.Dropout(dropout)def forward(self, src, dst):src = self.d_liner(src)dst = self.d_liner(dst)diff = src - dste_feats = torch.cat([src, dst, diff], dim=1)e_feats = self.dropout(e_feats)score = self.f_liner(e_feats).squeeze()score = self.tanh(score)return score

3. MultiRelationH2FDetectorLayer 类

这是一个处理多种关系的 GNN 层。

  • 处理多种关系的图卷积层,包含对不同关系类型的处理逻辑。
class MultiRelationH2FDetectorLayer(nn.Module):def __init__(self, input_dim, output_dim, head, dataset, dropout, if_sum=False):super().__init__()self.relation = copy.deepcopy(dataset.etypes)self.relation.remove('homo')self.n_relation = len(self.relation)if not if_sum:self.liner = nn.Linear(self.n_relation * output_dim * head, output_dim * head)else:self.liner = nn.Linear(self.n_relation * output_dim, output_dim)self.relation_aware = RelationAware(input_dim, output_dim * head, dropout)self.minelayers = nn.ModuleDict()self.dropout = nn.Dropout(dropout)for e in self.relation:self.minelayers[e] = H2FDetector_layer(input_dim, output_dim, head, self.relation_aware, e, dropout, if_sum)def forward(self, g, h):hs = []for e in self.relation:he = self.minelayers[e](g, h)hs.append(he)h = torch.cat(hs, dim=1)h = self.dropout(h)h = self.liner(h)return hdef loss(self, g, h):with g.local_scope():g.ndata['feat'] = hagg_h = self.forward(g, h)g.apply_edges(self.score_edges, etype='homo')edges_score = g.edges['homo'].data['score']edge_train_mask = g.edges['homo'].data['train_mask'].bool()edge_train_label = g.edges['homo'].data['label'][edge_train_mask]edge_train_pos = edge_train_label == 1edge_train_neg = edge_train_label == -1edge_train_pos_index = edge_train_pos.nonzero().flatten().detach().cpu().numpy()edge_train_neg_index = edge_train_neg.nonzero().flatten().detach().cpu().numpy()edge_train_pos_index = np.random.choice(edge_train_pos_index, size=len(edge_train_neg_index))index = np.concatenate([edge_train_pos_index, edge_train_neg_index])index.sort()edge_train_score = edges_score[edge_train_mask]# hinge lossedge_diff_loss = hinge_loss(edge_train_label[index], edge_train_score[index])train_mask = g.ndata['train_mask'].bool()train_h = agg_h[train_mask]train_label = g.ndata['label'][train_mask]train_pos = train_label == 1train_neg = train_label == 0train_pos_index = train_pos.nonzero().flatten().detach().cpu().numpy()train_neg_index = train_neg.nonzero().flatten().detach().cpu().numpy()train_neg_index = np.random.choice(train_neg_index, size=len(train_pos_index))node_index = np.concatenate([train_neg_index, train_pos_index])node_index.sort()pos_prototype = torch.mean(train_h[train_pos], dim=0).view(1, -1)neg_prototype = torch.mean(train_h[train_neg], dim=0).view(1, -1)train_h_loss = train_h[node_index]pos_prototypes = pos_prototype.expand(train_h_loss.shape)neg_prototypes = neg_prototype.expand(train_h_loss.shape)diff_pos = -F.pairwise_distance(train_h_loss, pos_prototypes)diff_neg = -F.pairwise_distance(train_h_loss, neg_prototypes)diff_pos = diff_pos.view(-1, 1)diff_neg = diff_neg.view(-1, 1)diff = torch.cat([diff_neg, diff_pos], dim=1)diff_loss = F.cross_entropy(diff, train_label[node_index])return agg_h, edge_diff_loss, diff_lossdef score_edges(self, edges):src = edges.src['feat']dst = edges.dst['feat']score = self.relation_aware(src, dst)return {'score': score}

4. H2FDetector 类

这是一个多层的 GNN 模型,用于构建一个关系感知的图神经网络模型。

  • 多层的关系感知图神经网络模型,包含前向传播和损失计算方法。
class H2FDetector(nn.Module):def __init__(self, args, g):super().__init__()self.n_layer = args.n_layerself.input_dim = g.nodes['r'].data['feature'].shape[1]self.intra_dim = args.intra_dimself.n_class = args.n_classself.gamma1 = args.gamma1self.gamma2 = args.gamma2self.n_layer = args.n_layerself.mine_layers = nn.ModuleList()if args.n_layer == 1:self.mine_layers.append(MultiRelationH2FDetectorLayer(self.input_dim, self.n_class, args.head, g, args.dropout, if_sum=True))else:self.mine_layers.append(MultiRelationH2FDetectorLayer(self.input_dim, self.intra_dim, args.head, g, args.dropout))for _ in range(1, self.n_layer - 1):self.mine_layers.append(MultiRelationH2FDetectorLayer(self.intra_dim * args.head, self.intra_dim, args.head, g, args.dropout))self.mine_layers.append(MultiRelationH2FDetectorLayer(self.intra_dim * args.head, self.n_class, args.head, g, args.dropout, if_sum=True))self.dropout = nn.Dropout(args.dropout)self.relu = nn.ReLU()def forward(self, g):feats = g.ndata['feature'].float()h = self.mine_layers[0](g, feats)if self.n_layer > 1:h = self.relu(h)h = self.dropout(h)for i in range(1, len(self.mine_layers) - 1):h = self.mine_layers[i](g, h)h = self.relu(h)h = self.dropout(h)h = self.mine_layers[-1](g, h)return hdef loss(self, g):feats = g.ndata['feature'].float()train_mask = g.ndata['train_mask'].bool()train_label = g.ndata['label'][train_mask]train_pos = train_label == 1train_neg = train_label == 0pos_index = train_pos.nonzero().flatten().detach().cpu().numpy()neg_index = train_neg.nonzero().flatten().detach().cpu().numpy()neg_index = np.random.choice(neg_index, size=len(pos_index), replace=False)index = np.concatenate([pos_index, neg_index])index.sort()h, edge_loss, prototype_loss = self.mine_layers[0].loss(g, feats)if self.n_layer > 1:h = self.relu(h)h = self.dropout(h)for i in range(1, len(self.mine_layers) - 1):h, e_loss, p_loss = self.mine_layers[i].loss(g, h)h = self.relu(h)h = self.dropout(h)edge_loss += e_lossprototype_loss += p_lossh, e_loss, p_loss = self.mine_layers[-1].loss(g, h)edge_loss += e_lossprototype_loss += p_lossmodel_loss = F.cross_entropy(h[train_mask][index], train_label[index])loss = model_loss + self.gamma1 * edge_loss + self.gamma2 * prototype_lossreturn loss

这篇关于H2-FDetector模型解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/991597

相关文章

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

深入解析C++ 中std::map内存管理

《深入解析C++中std::map内存管理》文章详解C++std::map内存管理,指出clear()仅删除元素可能不释放底层内存,建议用swap()与空map交换以彻底释放,针对指针类型需手动de... 目录1️、基本清空std::map2️、使用 swap 彻底释放内存3️、map 中存储指针类型的对象

Java Scanner类解析与实战教程

《JavaScanner类解析与实战教程》JavaScanner类(java.util包)是文本输入解析工具,支持基本类型和字符串读取,基于Readable接口与正则分隔符实现,适用于控制台、文件输... 目录一、核心设计与工作原理1.底层依赖2.解析机制A.核心逻辑基于分隔符(delimiter)和模式匹

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?

99%的人都选错了! 路由器WiFi双频合一还是分开好的专业解析与适用场景探讨

《99%的人都选错了!路由器WiFi双频合一还是分开好的专业解析与适用场景探讨》关于双频路由器的“双频合一”与“分开使用”两种模式,用户往往存在诸多疑问,本文将从多个维度深入探讨这两种模式的优缺点,... 在如今“没有WiFi就等于与世隔绝”的时代,越来越多家庭、办公室都开始配置双频无线路由器。但你有没有注