yukamu 发表于 2022-5-8 15:16

通过数据驱动节点采样优化GraphSAGE

今天我分享一篇来自ICLR上面的ADVANCING GRAPHSAGE WITH A DATA-DRIVEN NODE SAMPLING的论文。
      作为一种高效的、可扩展的图神经网络,GraphSAGE通过聚合下采样的局部邻域和以小批量梯度下降方式学习,具有推断不可见节点或图的归纳能力。GraphSAGE是图神经网络的一个典型变体模型。大家对GraphSAGE不了解的,推荐看一下这https://zhuanlan.zhihu.com/p/62750137这篇文章。GraphSAGE中使用的邻域抽样方法在并行推断不同程度的一批目标节点时,可以有效地提高计算和存储效率。尽管有这一优点,但默认均匀抽样在训练和推理中存在高方差,导致精度不佳。本文提出了一种新的数据驱动的采样方法,通过一个非线性回归器来推理邻域的实值重要性,并将其值作为次采样邻域的标准。回归器是通过基于value的强化学习来学习的。从GraphSAGE的负分类损失输出中归纳提取每个顶点和邻域组合的隐含重要性。因此,在使用三个数据集的归纳节点分类基准中,此方法使用均匀抽样增强了基线,在准确性方面优于图神经网络的最新变体。      
       基于图结构网络数据的机器学习在许多重要的应用中得到了广泛应用。例如,它在化学预测问题(Gilmer et al.(2017))、蛋白质功能理解和粒子物理实验(Henrion et al. (2017);Choma等人(2018))。学习关于图的结构信息的表示可以发现将节点(或子图)嵌入到低维向量空间中的点的映射。基于邻域聚合的图神经网络算法,通过利用节点的属性解决了这个问题(Kipf & Welling (2016);Hamilton等人(2017);Pham等人(2017))。GraphSAGE算法(Hamilton et al.(2017))通过多次跳跃对局部邻域的固定数目节点进行均匀抽样递归子样本,并学习一组聚合器模型,该模型通过向原点回溯来聚合下采样节点的隐藏特征。采样方法使并行计算中每个批次的计算足迹保持固定。然而,尽管GraphSAGE的特点比较全面,但均匀分布的无偏随机抽样在训练和测试中方差较大,导致准确率不理想。在目前的工作中,我们提出了一种新的方法来取代子采样算法在GraphSAGE与数据驱动的采样算法,训练与强化学习。      
         首先我们来了解一下什么是强化学习。强化学习主要由智能体(Agent)、环境(Environment)、状态(State)、动作(Action)、奖励(Reward)组成。智能体执行了某个动作后,环境将会转换到一个新的状态,对于该新的状态环境会给出奖励信号(正奖励或者负奖励)。随后,智能体根据新的状态和环境反馈的奖励,按照一定的策略执行新的动作。上述过程为智能体和环境通过状态、动作、奖励进行交互的方式。智能体通过强化学习,可以知道自己在什么状态下,应该采取什么样的动作使得自身获得最大奖励。由于智能体与环境的交互方式与人类与环境的交互方式类似,可以认为强化学习是一套通用的学习框架。


      本文主要分以下两步优化:      1:基于值函数的节点抽样强化学习         为了替换之前的统一采样器,本文考虑了一种强化学习方法,它可以帮助学习如何在新的数据集中快速找到一个好的采样分布。每一步奖励Rkv,u,是在给定k跳一致下采样邻域和直接连接的1跳邻域u1的节点v处计算的交叉熵损失的负值。注意,每一步奖励是一个不应用小批量目标节点求和的批量值,v∈v:



其中Fθ为GraphSAGE的聚合器,u1 ∪ · · · ∪uk表示输入目标节点v和k-hop下采样邻域。Ckv,u记录每一步的访问计数,Ckv,u,记录了(v, u)被索引的次数。


为了产生每一步的奖励,GraphSAGE预测所有中间层的类yk。为此,我们在最后一层旁边的每一中间层上添加辅助分类层。我们考虑由从第一跳传播到最后第k跳的每步奖励的折现和组成的返回G:


其中γ∈(0,1]是一个折现因子,它对来自未来奖励的贡献进行折现。换句话说,当γ较低时,我们认为距离较近的邻域对Gv,u的回归有较大的影响。为了避免计算所有每步奖励的开销,我们探索了一个近似方案,当k < K时,我们将Rk设为0.方程7可以用近似于全跳学习由最后一跳代替;,u = RKv u。访问次数Cv,u表示每一步的访问次数总和


这个返回使用强化学习对策略π进行了优化。策略的输入为目标节点及其邻域的候选节点,输出动作空间为1或0,表示是否被选为子样本。与此策略相关的值函数表示为Vv,u;我们回忆一下,它是在从目标节点v到邻近节点u的策略下,由Gv,u除以Cv,u得到的期望收益。价值函数和相邻节点之间的关系u∈u1连接到目标节点定义如下


2:非线性回归模型的代价函数一个可能的状态(v, u)并不局限于训练中观察到的有限节点集。这是因为它假设图表是不断发展的;也就是说,在测试期间可以观察到不可见的节点。因此,我们考虑使用状态(V, u)属性的非线性组合来逼近值函数V:


式中设xv和xu分别是节点v和邻域每个成员u∈N (v)的m维输入向量(属性)。θ表示可微分非线性回归函数的权值,G. a大小为1 × 2M的权值矩阵W和偏置b为单个待学习感知器层的参数。利用小批量梯度下降优化方法,训练该模型最小化方程9中得到的真值函数Vv,u与输出Vv,u之间的l2范数。在所有深度的采样邻域中共享学习到的权值。
       谢谢你看完了文章,欢迎点亮【赞】+【保存】+【转发】,让更多的人看到,也欢迎大家关注我的公众号GNN图神经网络。
这是原文链接地址:https://www.researchgate.net/publication/332779223_Advancing_GraphSAGE_with_A_Data-Driven_Node_Sampling
页: [1]
查看完整版本: 通过数据驱动节点采样优化GraphSAGE