图神经网络(GNN)最简单全面原理与代码实现!
深度学习新手入门福利,本文将给你带来最简单全面的图神经网络理解与代码实现!GNN,即图神经网络,是一种用于处理图形数据的深度学习技术。
1. 什么是图数据?在图神经网络中,图数据是以什么形式表示的?
图数据是由节点(Node)和边(Edge)组成的数据,最简单的方式是使用邻接矩阵来表示图形结构,从而捕捉图形中的节点和边的相关性。假设图中的节点数为n,那么邻接矩阵就是一个n*n的矩阵,如果节点之间有关联,则在邻接矩阵中表示为1,无关联则为0。在图中,鲁班与其他英雄都没有关联,表现在邻接矩阵当中就是它所在的行与列为全零。
王者荣耀当中的图和邻接矩阵
图数据的信息包含3个层面,分别是节点信息(V)、边信息(E)、图整体(U)信息,它们通常是用向量来表示。而图神经网络就是通过学习数据从而得到3个层面向量的最优表示。
2. 对于图数据而言有怎样的任务?
●图层面的任务(分类/回归)
例:分子是天然的图,原子是节点,化学键是边。现在要做一个分类,有一个苯环的分子分一类,两个苯环的分子分一类。这是图分类任务。
●边层面的任务(分类/回归)
例:UFO拳击赛上,首先通过语义分割把台上的人和环境分离开来。赛场上的人都是节点,现在要做一个预测,预测的是这些人之间的关系,是对抗关系?还是观众watch的关系?还是裁判watch的关系?这是边分类任务。
●节点层面的任务(分类/回归)
例:假设一个跆拳道俱乐部里有A、B两个教练,所有的会员都是节点。有一天A、B两个跆拳道教练决裂,那么各个学员是愿意和A在一个阵营还是愿意和B在一个阵营?这是节点分类任务。
3. 图神经网络是如何工作的?
GNN工作流程图
GNN是对图上的所有属性进行的一个可以优化的变换,它的输入是一个图,输出也是个图。它只对属性向量(即上文所述的V、E、U)进行变换,但它不会改变图的连接性(即哪些点互相连接经过GNN后是不会变的)。在获取优化后的属性向量之后,再根据实际的任务,后接全连接神经网络,进行分类和回归。大家可以把图神经网络看做是一个图数据的在三个维度的特征提取器。
GNN对属性向量优化的方法叫做消息传递机制。比如最原始的GNN是SUM求和传递机制;到后面发展成图卷积网络(GCN)就考虑到了节点的度,度越大,权重越小,使用了加权的SUM;再到后面发展为图注意力网络GAT,在消息传递过程中引入了注意力机制;目前的SOTA模型研究也都专注在了消息传递机制的研究。见下图所示。
三种不同的图神经网络模型的消息传递机制差异
但是!即使消息传递机制你不完全明白也没有关系,你只要记住:不同GNN的本质差别就在于它们如何进行节点之间的信息传递和计算,也就是它们的消息传递机制不同。就可以了!
4. 图神经网络代码实现
我常用的包是PyG(PyTorch Geometric),它是一个为图形数据的处理和学习提供支持的PyTorch扩展库,提供了一系列工具来帮助开发者轻松地实现基于图形的机器学习任务,例如图分类、图回归、图生成等。下面我将使用PyG的内置数据进行三个任务的代码实现:
4.1 节点分类任务代码实现
Cora数据集是PyG内置的节点分类数据集,代表着学术论文的相关性分类问题(即把每一篇学术论文都看成是节点),Cora数据集有2708个节点,1433维特征,边数为5429,7类节点分类。
下面是代码【需要使用美国的IP,否则好像不能下载Cora数据,大伙可以试试】:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
#载入数据
dataset = Planetoid(root='~/tmp/Cora', name='Cora')
data = dataset
#定义网络架构
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_features, 16)#输入=节点特征维度,16是中间隐藏神经元个数
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
#模型训练
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index) #模型的输入有节点特征还有边特征,使用的是全部数据
loss = F.nll_loss(out, data.y) #损失仅仅计算的是训练集的损失
loss.backward()
optimizer.step()
#测试:
model.eval()
test_predict = model(data.x, data.edge_index)
max_index = torch.argmax(test_predict, dim=1)
test_true = data.y
correct = 0
for i in range(len(max_index)):
if max_index == test_true:
correct += 1
print('测试集准确率为:{}%'.format(correct*100/len(test_true)))对于这个节点7分类的问题,最终在测试集(1000个样本)上的分类准确率为79.9%(见下图)。因为我们只是使用了一个很简单的模型架构,所以这个结果还说得过去。
测试结果
—————————————————————————————————————
未完待续。明天继续更。
页:
[1]