DGL

    Deep Graph Library(DGL),一款面向图神经网络以及图机器学习的全新框架。DGL基于主流框架进行开发。用户可以使用他们偏爱的框架编写常见的CNN和注意力层,而当遇到图相关的计算时可以切换到DGL。用户和DGL的交互主要通过自定义函数UDF(user-defined function)。目前DGL支持Pytorch和MXNet/Gluon作为系统后端。

GNN 教程:DGL框架-消息和GCN的实现

1. DGL

DGL是基于消息传递message passing的编程模型。原因在于图上的计算往往可以表示为2步:

  1. 发送节点:根据自身的特征计算需要向外分发的消息。
  2. 接收节点:对收到的消息进行累加并更新自身的特征。
    用户需要自定义消息分发函数消息聚合函数,来构造新的模型。
  • 消息分发函数(message function):将结点自身的消息传递传递给其邻居。因为对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递给目的节点。对于每个目的节点来说,它可能会收到多个源节点传过来的消息,它会将这些消息存储在mailbox中。
  • 消息聚合函数(reduce function):聚合函数的目的是根据邻居传过来的消息更新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox[‘m’])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox[‘m’])内的消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。

GCN的公式如下所示:

上面的数学描述可以利用消息传递的机制实现:
(1)在GCN中,每个节点都有属于自己的表示$h_i$
(2)根据消息传递(message passing),每个节点将会收到邻居节点发来的Embedding
(3)每个节点将聚合邻居节点的Embedding,得到中间表示$\hat{h_i}$
(4)对中间节点表示$\hat{h_i}$进行线性变换,然后利用非线性函数$f$进行计算:$h^{new}_u = f(W_u\hat{h}_u)$
(5)利用新的节点表示$h^{new}_u$对该节点的表示$h_u$进行更新。

2. 函数介绍

官方文档
实现GCN例子

2.1. 创建图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import dgl
g = dgl.DGLGraph()
#为图添加节点和边
g.add_node(34)#添加34个节点
#一共有78条边
edge_list = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
(4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
(7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
(10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
(13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
(21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
(27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
(31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
(32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
(32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
(33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
(33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
(33, 31), (33, 32)]
#添加边的源节点和目的节点
drc,dst = tuple(zip(*edge_list))
g.add_edges(src,dst)
#边是双向的
g.add_edges(dst,src)

2.2. 获取节点和边的个数

1
2
3
4
5
6
#查看节点和边个数
g.number_of_nodes()
g.number_of_edges()

#查看节点和边类型
g.node_attr_schemes()

2.3. 分配节点和边的特征

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch  
#分配节点特征
g.ndata['feature'] = torch.eye(34)

#获取某个节点的特征
G.nodes[2].data['feature']
G.nodes[[10,11]].data['feature']

#分配边特征,9条边,每条边特征有2个
g.edata['edge_feature'] = torch.randn(9,2)

#单独为每条边赋值
g.edata[1].data['edge_feature'] = torch.randn(1,2)
g.edata[[0,1,2]].data['edge_feature'] = torch.randn(3,2)
#同时指定起点和终点
g.edata[[1,2,3],[0,0,0]].data['edge_feature'] = torch.randn(3,2)

#查看图的节点特征和边特征
g.ndata,g.edata

2.4. 删除节点和边特征

1
2
g.ndata.pop('feature')
g.edata.pop('edge_feature')

2.5. 自定义message函数

1
2
3
4
5
6
7
8
9
def message_func(edges):
"""
在该函数中,接收一个参数edges,edges有3个成员变量:
edges.src:获取源节点
edges.dst:获取目的节点
edges.data:获取边
主要是向目的节点传递消息,返回的格式是dict
"""
return { 'alpha': alpha, 'state': edge.src['state'] }

2.6. 自定义reduce函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def reduce_func(nodes):
"""
源节点通过message函数,将消息发送给目的节点,目的节点接收多个邻居发来的消息,并存储在mailbox中,reduce函数聚合多个邻居发来的消息,并以dict的形式返回。
reduce函数,接收一个参数nodes,nodes有2个成员变量
nodes.data:获取节点的特征
nodes.mailbox:获取message函数返回的值

"""
state = nodes.mailbox['state']
alpha = nodes.mailbox['alpha']
alpha = nd.softmax(alpha, axis=1)

new_state = nd.relu(nd.sum(alpha * state, axis=1))
return { 'new_state': new_state }

2.7. 注册message和reduce函数

1
2
3
4
5
6
7
8
9
10
11
#自定了message和reduce函数,在graph中注册,以便后续使用 
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)

def pagerank_batch(g):
g.send(g.edges())
g.recv(g.nodes())

#如果没有将自定义的message和reduce函数注册,使用以下语句
g.send(g.edges(),message_func)
g.recv(g.nodes(),reduce_func)

2.8. update_all更新节点特征

该方法是上面方法的高级版本
DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')
传入的参数是message函数名,reduce函数名,UDF函数名,如果不传入,使用默认值。

1
2
3
def pagerank_level2(g):
# g.update_all()
g.update_all(self.message_func,self.reduce_func)

2.9. 高级用法

PageRank实现

  • dgl.function.copy_src(src, out):需要指定源节点的名称,和message的key值

    1
    2
    3
    4
    5
    import dgl
    message_func = dgl.function.copy_src('feature', 'state')
    等价于
    def message_func(edges):
    return {'state': edges.src['feature']}
  • dgl.function.sum(msg, out):用于对目的节点的mailbox进行求和,需要指定message的key值和输出的名称。

    1
    2
    3
    4
    5
    import dgl
    reduce_func = dgl.function.sum('state', 'new_state')
    等价于
    def reduce_func(nodes):
    return {'new_state': torch.sum(nodes.mailbox['state'], dim=1)}
1
2
3
4
5
6
7
8
9
import dgl.function as fn

def pagerank_builtin(g):
N = 100 # number of nodes
DAMP = 0.85 # damping factor
g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
g.update_all(message_func=fn.copy_src(src='pv', out='m',
reduce_func=fn.sum(msg='m',out='m_sum'))
g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']
打赏
0%