Deep Graph Library(DGL),一款面向图神经网络以及图机器学习的全新框架。DGL基于主流框架进行开发。用户可以使用他们偏爱的框架编写常见的CNN和注意力层,而当遇到图相关的计算时可以切换到DGL。用户和DGL的交互主要通过自定义函数UDF(user-defined function)。目前DGL支持Pytorch和MXNet/Gluon作为系统后端。
1. DGL
DGL是基于消息传递message passing的编程模型。原因在于图上的计算往往可以表示为2步:
- 发送节点:根据自身的特征计算需要向外分发的消息。
- 接收节点:对收到的消息进行累加并更新自身的特征。
用户需要自定义消息分发函数和消息聚合函数,来构造新的模型。
- 消息分发函数(message function):将结点自身的消息传递传递给其邻居。因为对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递给目的节点。对于每个目的节点来说,它可能会收到多个源节点传过来的消息,它会将这些消息存储在mailbox中。
- 消息聚合函数(reduce function):聚合函数的目的是根据邻居传过来的消息更新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox[‘m’])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox[‘m’])内的消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。
GCN的公式如下所示:
上面的数学描述可以利用消息传递的机制实现:
(1)在GCN中,每个节点都有属于自己的表示$h_i$
(2)根据消息传递(message passing),每个节点将会收到邻居节点发来的Embedding
(3)每个节点将聚合邻居节点的Embedding,得到中间表示$\hat{h_i}$
(4)对中间节点表示$\hat{h_i}$进行线性变换,然后利用非线性函数$f$进行计算:$h^{new}_u = f(W_u\hat{h}_u)$
(5)利用新的节点表示$h^{new}_u$对该节点的表示$h_u$进行更新。
2. 函数介绍
2.1. 创建图
1 | import dgl |
2.2. 获取节点和边的个数
1 | #查看节点和边个数 |
2.3. 分配节点和边的特征
1 | import torch |
2.4. 删除节点和边特征
1 | g.ndata.pop('feature') |
2.5. 自定义message函数
1 | def message_func(edges): |
2.6. 自定义reduce函数
1 | def reduce_func(nodes): |
2.7. 注册message和reduce函数
1 | #自定了message和reduce函数,在graph中注册,以便后续使用 |
2.8. update_all更新节点特征
该方法是上面方法的高级版本DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')
传入的参数是message函数名,reduce函数名,UDF函数名,如果不传入,使用默认值。
1 | def pagerank_level2(g): |
2.9. 高级用法
dgl.function.copy_src(src, out)
:需要指定源节点的名称,和message的key值1
2
3
4
5import dgl
message_func = dgl.function.copy_src('feature', 'state')
等价于
def message_func(edges):
return {'state': edges.src['feature']}dgl.function.sum(msg, out)
:用于对目的节点的mailbox进行求和,需要指定message的key值和输出的名称。1
2
3
4
5import dgl
reduce_func = dgl.function.sum('state', 'new_state')
等价于
def reduce_func(nodes):
return {'new_state': torch.sum(nodes.mailbox['state'], dim=1)}
1 | import dgl.function as fn |