Post

【DGL教程】第3章 构建GNN模块

官方文档:https://docs.dgl.ai/en/latest/guide/nn.html

DGL的NN模块用于构建GNN模型,相当于PyTorch的NN模块或TensorFlow的层

以PyTorch后端为例,DGL NN模块的使用方法和PyTorch相同——在构造函数中注册参数,在forward方法中进行张量运算,因此可以和其他PyTorch的NN模块无缝集成,主要的不同是消息传递

dgl.nn.pytorch.conv包实现了一些常用的图卷积模块

1.DGL NN模块的构造函数

构造函数的工作通常包括:

  • 设置选项(例如输出、输出和隐藏层的维数)
  • 注册可学习的参数或子模块
  • 重置(初始化)参数

2.DGL NN模块的forward方法

在NN模块中,forward方法执行实际的消息传递和计算

与PyTorch相比,DGL NN模块的forward方法接受一个额外的参数dgl.DGLGraph

forward方法的工作通常包括三部分:

  • 对图进行检查
  • 消息传递和归约(使用update_all()
  • 更新特征

官方文档以SAGEConv模块(GraphSAGE模型的实现)为例进行解析: https://docs.dgl.ai/en/latest/guide/nn-forward.html

3.异构图卷积模块

dgl.nn.pytorch.HeteroGraphConv

官方文档:https://docs.dgl.ai/en/latest/guide/nn-heterograph.html

该模块是在异构图上运行DGL的NN模块的模块级别的封装,实现逻辑与消息传递API multi_update_all()相同(相当于nn.LinearF.linear()的区别),包括:

  • 在每个关系上运行NN模块
  • 合并同一种顶点类型收到的来自多个关系的结果

用公式表示为

\[h_{dst}={AGG}_{r \in R(dst)}f_r(g_r,h_{src},h_{dst})\]

构造函数参数

1
2
3
4
5
6
7
8
9
def __init__(self, mods, aggregate='sum'):
    super(HeteroGraphConv, self).__init__()
    self.mods = nn.ModuleDict(mods)
    # ...
    if isinstance(aggregate, str):
        self.agg_fn = get_aggregate_fn(aggregate)
    else:
        self.agg_fn = aggregate
    self.agg_fn = aggregate

边类型到NN模块的映射mods、用于聚集不同关系生成的顶点特征的函数agg,其中每个NN模块forward()函数的前两个参数应当是g和feat,g是关系二分图,feat是顶点输入特征或源顶点和目的顶点输入特征的二元组

forward()方法

简化的forward()方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward(self, g, inputs):
    outputs = {nty : [] for nty in g.dsttypes}
    for stype, etype, dtype in g.canonical_etypes:
        rel_graph = g[stype, etype, dtype]
        if rel_graph.number_of_edges() == 0:
            continue
        dstdata = self.mods[etype](
            rel_graph,
            (inputs[stype], inputs[dtype]))
        outputs[dtype].append(dstdata)
    rsts = {}
    for nty, alist in outputs.items():
        if len(alist) != 0:
            rsts[nty] = self.agg_fn(alist, nty)
    return rsts

输入异构图g和顶点类型到输入特征的映射inputs,对于每种边类型(stype, etype, dtype),使用mods[etype](g[stype, etype, dtype], inputs[stype])计算dtype类型顶点的输出特征并添加到一个列表(边数为0的关系子图将被忽略),最后将每个顶点类型收到的所有结果使用agg聚集起来得到最终输出特征,返回顶点类型到输出特征的映射

聚集函数

接收两个参数:输出特征列表List[tensor(N, d)]和该顶点类型名称

内置聚集函数:

名称实现逻辑返回值
sumtorch.sum(torch.stack(outputs, dim=0), dim=0)tensor(N, d)
meantorch.mean(torch.stack(outputs, dim=0), dim=0)tensor(N, d)
maxtorch.max(torch.stack(outputs, dim=0), dim=0)[0]tensor(N, d)
mintorch.min(torch.stack(outputs, dim=0), dim=0)[0]tensor(N, d)
stacktorch.stack(outputs, dim=1)tensor(N, R, d)

自定义聚集函数示例:

1
2
3
4
def agg(outputs, dtype):
    outputs = torch.stack(outputs, dim=1)  # (N, R, d)
    outputs = ... # aggregation (N, R, d) -> (N, d)
    return outputs

例如:异构图由(A, ab, B), (A, ac, C), (B, bc, C)三种关系构成,则输出为

1
2
3
4
output = {
    'B': agg([mods['ab'](g['A', 'ab', 'B'], inputs['A'])]),
    'C': agg([mods['ac'](g['A', 'ac', 'C'], inputs['A']), mods['bc'](g['B', 'bc', 'C'], inputs['B']))
}
This post is licensed under CC BY 4.0 by the author.