“Weisfeiler-Lehman Neural Machine for Link Prediction“文章复现工作

本文主要是介绍“Weisfeiler-Lehman Neural Machine for Link Prediction“文章复现工作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 复现文章:

  • 《Weisfeiler-Lehman Neural Machine for Link Prediction》

2 文章提出的方法思路:

  • 笔者希望能够通过提取目标连边的周围连边所构成的子图,并通过一种编码方法,保留住每个节点在子图中扮演的不同角色,即在不同子图当中扮演相同角色的节点能够有相近的编号。如下面两张图所示,在两个不同的子网络当中,扮演角色相同的节点会得到相同的编号。

Fig.1
在这里插入图片描述

  • 然后笔者要对于网络中的每一条连边生成这样的一张子图,然后将子图都转化为对应的邻接矩阵,输入到一个机器学习分类模型(文章中用到的是单个隐藏层的神经网络)当中进行训练。

3 难点:

  • 如何对于不同的子图当中的节点进行编码,并保证其能够保存节点的相对角色?

4 解决:

  • 笔者的思路就是,首先将一个连边所对应的两个节点的一阶、二阶…邻居节点提取出来,然后根据一个笔者提出的哈希函数对节点进行排序,大致流程如下:其做法为:首先根据子图中的节点与目标连边所连接的两个节点之间的距离初始化节点的标签,然后开始迭代,每次迭代,计算每个节点的哈希值,然后根据哈希值来更新该节点的标签,其中哈希值最小的编码为1,第二小的为2,若有相同取值的就分配到同样的数字。
    在这里插入图片描述
    哈希函数为:
    在这里插入图片描述
    其中 P ( n ) P(n) P(n)为第n个素数。

5 文章提出方法的流程图:

  • 整个算法的基本流程:首先对于每一条连边,提取K个以上邻居节点构成的子图。提取顺序是:先一阶邻居,再二阶邻居…;接着对提取的子图进行图编码,然后选择前K个进行提取。提取完子图之后为每个节点建立一个上三角邻接矩阵,将邻接矩阵输入到神经网络中进行学习。
    在这里插入图片描述

6 代码:

6.1 首先导入所需库
import networkx as nx
import numpy as np
import pandas as pd
import math
import random
import matplotlib.pyplot as plt
import scipy
from scipy.io import loadmat
from functools import partial
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from sklearn import metrics
6.2 载入数据,构建网络及可视化
data = loadmat('./USAir.mat')
G = nx.from_numpy_matrix(data['net'].todense(),create_using=nx.Graph()) #构建网络
print(nx.info(G))
pos=nx.spring_layout(G)
plt.figure(figsize=(10,8))
nx.draw(G,pos=pos,node_size=50,alpha=0.5,with_labels=False)
plt.title('USAir Network',fontsize=16,fontweight='bold')
plt.show()
  • 结果:
    在这里插入图片描述
    在这里插入图片描述
6.3 正采样
G_train = G.copy() 
G_test = nx.empty_graph(G_train.number_of_nodes())
n_links = int(G_train.number_of_edges()) # 获得网络中的连边总数ratio = 0.1 #取10%的连边作为测试集连边
n_test_link = int(np.ceil(n_links*ratio))
print(n_test_link)
selected_link_id = np.random.choice(np.arange(n_links),size=n_test_link,replace=False)
adj_matrix = nx.adj_matrix(G)
adj_matrix = scipy.sparse.triu(adj_matrix,k=1)
rol,col = adj_matrix.nonzero() #取非零元素
links = [(i,j) for i,j in zip(rol,col)]
selected_links = []
for link_id in selected_link_id:selected_links.append(links[link_id]) #根据连边的索引获得连边
G_train.remove_edges_from(selected_links)
G_test.add_edges_from(selected_links)
print(G_train.number_of_edges(),G_test.number_of_edges())
  • 结果:213;1913,213
6.4 负采样
  • 负采样需要首先建立一个与原图节点数相同的图,然后把原图当中不存在的连边全部加入到该图当中,然后随机抽取。
k = 2
neg_train_link = k*G_train.number_of_edges() # 取正样本数的两倍作为负样本数
neg_test_link = k*G_test.number_of_edges()
G_neg = nx.empty_graph(G.number_of_edges())
neg_links = list(nx.non_edges(G)) # 返回原始网络中不存在的连边
G_neg.add_edges_from(neg_links)
print(G_neg.number_of_edges())selected_link_idd = np.random.choice(np.arange(G_neg.number_of_edges()),size=neg_train_link+neg_test_link,replace=False)
G_train_neg = nx.empty_graph(G.number_of_nodes())
G_test_neg = nx.empty_graph(G.number_of_nodes())selected_links = []
for i in range(neg_train_link):inx = selected_link_idd[i]selected_links.append(neg_links[inx])
G_train_neg.add_edges_from(selected_links)selected_links = []
for i in range(neg_test_link):inx = selected_link_idd[i]selected_links.append(neg_links[inx])
G_test_neg.add_edges_from(selected_links)
  • 结果:52820
6.5 将正样本和负样本组合起来
all_train_links = list(G_train.edges) + list(G_train_neg.edges)
label_train = [1]*G_train.number_of_edges()+[0]*neg_train_linkall_test_links = list(G_test.edges) + list(G_test_neg.edges)
label_test = [1]*G_test.number_of_edges()+[0]*neg_test_linky_train,y_test = np.array(label_train),np.array(label_test)
6.6 提取子图
def enclosing_subgraph(fringe,G,subgraph,distance):"""构建enclosing subgraphInput:fringe:用来寻找下一阶邻居的列表G:原网络subgraph:要提取的子图distance:描述连边的距离return:fringe,subgraph:更新后的内容"""neighbor_link = []temp_subgraph = subgraph.copy()for link in fringe:u = link[0]v = link[1]neighbor_link += list(G.edges(u))neighbor_link += list(G.edges(v))temp_subgraph.add_edges_from(neighbor_link)# 除去已有的连边neighbor_link = [l for l in temp_subgraph.edges() if l not in subgraph.edges()] temp_subgraph.add_edges_from(neighbor_link,distance=distance,inverse_distance=1/distance)return neighbor_link,temp_subgraphdef subgraph_extractor(G,link,K):"""为每一条连边提取一个子图Input:G:网络link:目标连边K:子图大小return:subgraph:连边的子图 """distance = 0subgraph = nx.Graph()fringe = [link] #用来存放邻居节点的列表subgraph.add_edge(link[0],link[1],distance=distance)while subgraph.number_of_nodes()<K and len(fringe)>0:distance += 1 fringe,subgraph = enclosing_subgraph(fringe,G,subgraph,distance)temp_subgraph = G.subgraph(subgraph.nodes)additional_edges = [l for l in temp_subgraph.edges() if l not in subgraph.edges]subgraph.add_edges_from(additional_edges,distance=distance+1,inverse_distance=1/(distance+1))return subgraphif __name__ == '__main__':G_512 = subgraph_extractor(G_train,(5,12),10)nx.draw(G_512,with_labels=True)plt.show()
  • 结果:
    在这里插入图片描述
6.7 对节点进行编码排序
def primes(x):"""判断是否为素数"""if x <2:return Falseif x ==2 or x == 3:return Truefor i in range(2,x):if x% i ==0:return Falsereturn Truedef cal_mean_geo_distance(G_subgraph,link):"""计算目标节点到图中其他节点的距离"""u = link[0]v = link[1]G_subgraph.remove_edge(u,v) # 因为不需要计算节点u和节点v之间的距离,所以先去除n = G_subgraph.number_of_nodes()u_reachable = nx.descendants(G_subgraph,source=u)# 找到节点u可以到达的节点v_reachable = nx.descendants(G_subgraph,source=v)for each in G_subgraph.nodes():distance_to_u = 0distance_to_v = 0if each != u: #计算节点到u的距离,如果节点无法到u就设其为2**ndistance_to_u = nx.shortest_path_length(G_subgraph,source=each,target=u) if each in u_reachable else 2**nif each != v:distance_to_v = nx.shortest_path_length(G_subgraph,source=each,target=v) if each in v_reachable else 2**n# 将信息存到节点属性当中G_subgraph.nodes[each]['ave_d'] = math.sqrt(distance_to_u*distance_to_v) G_subgraph.add_edge(u,v,distance=0)return G_subgraphdef PWL(G_sub,link,prime_list):"""对子图中的节点上色Input:G_sub:目标子图link:确定该子图的连边prime_list:素数列表return:nodelist :编码后的节点列表"""tem_subgraph = G_sub.copy()if tem_subgraph.has_edge(link[0],link[1]):tem_subgraph.remove_edge(link[0],link[1])ave_d = nx.get_node_attributes(tem_subgraph,'ave_d') # 获取初始化所需要的数据df = pd.DataFrame.from_dict(ave_d,orient='index',columns=['hash_value']) #转化为一个DataFramedf = df.sort_index() #按照index进行排序df['order'] = df['hash_value'].rank(axis=0,method='min').astype(np.int) # 按照hash_value来获得每个节点的排序编号df['previous_order'] = np.zeros(len(ave_d)) #用来存放前一轮的编号adj_matrix = nx.adj_matrix(tem_subgraph,nodelist=sorted(tem_subgraph.nodes)).todense()while any(df['order']!= df['previous_order']): #只要排序还在变,就执行下面的代码df['log_priem'] = np.log(prime_list[df['order'].values])total_log = np.ceil(np.sum(df['log_priem'].values))df['hash_value'] = adj_matrix*df['log_priem'].values.reshape(-1,1)/total_log + df['order'].values.reshape(-1,1)df['previous_order'] = df['order']df['order'] = df['hash_value'].rank(axis=0,method='min').astype(np.int)nodelist = df['order'].sort_values().index.values #根据排序输出子图中的节点集return nodelistif __name__ == '__main__':e_subgraph = cal_mean_geo_distance(G_01,(5,12))prime_list = np.array([i for i in range(10000) if primes(i)],dtype=np.int)nodelist = PWL(e_subgraph,(5,12),prime_list)print(nodelist)
  • 结果:
    在这里插入图片描述
6.8 用邻居矩阵来embedding每个子图
def sample(subgraph,nodelist,weight='weight',size=10):adj_mat = nx.adj_matrix(subgraph,weight=weight,nodelist=nodelist).todense()vector = np.asarray(adj_mat)[np.triu_indices(len(adj_mat),k=1)] # np.triu_indices:获取矩阵上三角元素,k=1就是不需要对角线元素d = size*(size-1)//2if len(vector) <d:vector = np.append(vector,np.zeros(d-len(vector)))return vector[1:]if __name__ == '__main__':s = sample(e_subgraph_, nodelist, size=10)print(s)
  • 结果:
    在这里插入图片描述
6.9 对每条连边进行embedding
def encode_link(link,G,prime_list,weight='weight',K=10):e_subgraph = subgraph_extractor(G_train,link,K) #首先提取子图e_subgraph = cal_mean_geo_distance(e_subgraph,link) # 然后获得距离属性nodelist = PWL(e_subgraph,link,prime_list) # 排序if len(nodelist)>K: # 返回固定大小的节点集nodelist = nodelist[:K]e_subgraph = e_subgraph.subgraph(nodelist)embedd = sample(e_subgraph,nodelist,weight='weight',size=K)#进行embeddingreturn embeddif __name__ == '__main__':X_train = np.array(list(map(partial(encode_link,G=G_train,prime_list=prime_list,weight='weight',K=10),all_train_links)))X_test = np.array(list(map(partial(encode_link,G=G_test,prime_list=prime_list,weight='weight',K=10),all_test_links)))
6.10 训练模型
X_train_shuffle, y_train_shuffle = shuffle(X_train, y_train)
model1 = MLPClassifier(hidden_layer_sizes=(32, 32, 16),alpha=1e-3,batch_size=128,learning_rate_init=0.001,max_iter=100,verbose=True,early_stopping=False,tol=-10000)
model1.fit(X_train_shuffle,y_train_shuffle)
predictions = model1.predict(X_test)
fpr, tpr, thresholds = metrics.roc_curve(label_test,predictions,pos_label=1)
auc = metrics.auc(fpr, tpr)
print(auc)
  • 结果:
    在这里插入图片描述

这篇关于“Weisfeiler-Lehman Neural Machine for Link Prediction“文章复现工作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/395580

相关文章

Spring @Scheduled注解及工作原理

《Spring@Scheduled注解及工作原理》Spring的@Scheduled注解用于标记定时任务,无需额外库,需配置@EnableScheduling,设置fixedRate、fixedDe... 目录1.@Scheduled注解定义2.配置 @Scheduled2.1 开启定时任务支持2.2 创建

SpringBoot整合Flowable实现工作流的详细流程

《SpringBoot整合Flowable实现工作流的详细流程》Flowable是一个使用Java编写的轻量级业务流程引擎,Flowable流程引擎可用于部署BPMN2.0流程定义,创建这些流程定义的... 目录1、流程引擎介绍2、创建项目3、画流程图4、开发接口4.1 Java 类梳理4.2 查看流程图4

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程

《SpringBoot集成LiteFlow实现轻量级工作流引擎的详细过程》LiteFlow是一款专注于逻辑驱动流程编排的轻量级框架,它以组件化方式快速构建和执行业务流程,有效解耦复杂业务逻辑,下面给大... 目录一、基础概念1.1 组件(Component)1.2 规则(Rule)1.3 上下文(Conte

详解如何使用Python构建从数据到文档的自动化工作流

《详解如何使用Python构建从数据到文档的自动化工作流》这篇文章将通过真实工作场景拆解,为大家展示如何用Python构建自动化工作流,让工具代替人力完成这些数字苦力活,感兴趣的小伙伴可以跟随小编一起... 目录一、Excel处理:从数据搬运工到智能分析师二、PDF处理:文档工厂的智能生产线三、邮件自动化:

基于Python开发一个有趣的工作时长计算器

《基于Python开发一个有趣的工作时长计算器》随着远程办公和弹性工作制的兴起,个人及团队对于工作时长的准确统计需求日益增长,本文将使用Python和PyQt5打造一个工作时长计算器,感兴趣的小伙伴可... 目录概述功能介绍界面展示php软件使用步骤说明代码详解1.窗口初始化与布局2.工作时长计算核心逻辑3

RabbitMQ工作模式中的RPC通信模式详解

《RabbitMQ工作模式中的RPC通信模式详解》在RabbitMQ中,RPC模式通过消息队列实现远程调用功能,这篇文章给大家介绍RabbitMQ工作模式之RPC通信模式,感兴趣的朋友一起看看吧... 目录RPC通信模式概述工作流程代码案例引入依赖常量类编写客户端代码编写服务端代码RPC通信模式概述在R

Go 语言中的select语句详解及工作原理

《Go语言中的select语句详解及工作原理》在Go语言中,select语句是用于处理多个通道(channel)操作的一种控制结构,它类似于switch语句,本文给大家介绍Go语言中的select语... 目录Go 语言中的 select 是做什么的基本功能语法工作原理示例示例 1:监听多个通道示例 2:带

微信公众号脚本-获取热搜自动新建草稿并发布文章

《微信公众号脚本-获取热搜自动新建草稿并发布文章》本来想写一个自动化发布微信公众号的小绿书的脚本,但是微信公众号官网没有小绿书的接口,那就写一个获取热搜微信普通文章的脚本吧,:本文主要介绍微信公众... 目录介绍思路前期准备环境要求获取接口token获取热搜获取热搜数据下载热搜图片给图片加上标题文字上传图片

kotlin中的模块化结构组件及工作原理

《kotlin中的模块化结构组件及工作原理》本文介绍了Kotlin中模块化结构组件,包括ViewModel、LiveData、Room和Navigation的工作原理和基础使用,本文通过实例代码给大家... 目录ViewModel 工作原理LiveData 工作原理Room 工作原理Navigation 工