开发者

Pytorch-Geometric中的Message Passing使用及说明

目录
  • Pytorch-Geometric中Message Passing使用
    • 具体函数说明如下
    • GCN 的计算公式如下
    • 实际计算工程可以分为下面几步
  • 总结

    Pytorch-Geometric中Message Passing使用

    图中的卷积计算通常被称为邻域聚合或者消息传递 (neighborhood aggregation or message passing).

    定义

    Pytorch-Geometric中的Message Passing使用及说明

    为节点i在第(k−1)层的特征,ej,i表示节点j到节点i的边特征,在GNN中消息传递可以表示为

    Pytorch-Geometric中的Message Passing使用及说明

    其中 □ 表示具有置换不变性并且可微的函数,例如 sum, mean, max 等, γ 和 表示可微函数。

    在 PyTorch Gemetric 中,所有卷积算子都是由 MessagePassing 类派生而来,理解 MessagePasing 有助于我们理解 PyG 中消息传递的计算方式和编写自定义的卷积。

    在自定义卷积中,用户只需定义消息传递函数 message(), 节点更新函数 γ update() 以及聚合方式 aggr='add', aggr='mean' 或则 aggr=max.

    具体函数说明如下

    • MessagePassing(aggr='add', flow='source_to_target', node_dim=-2) 定义聚合计算的方式 ('add', 'mean' or max ) 以及消息的传递方向 (source_to_target or target_to_source ). 在 PyG 中,中心节点为目标 target,邻域节点为源 source. node_dim 为消息聚合的维度
    • MessagePassing.propagate(edge_index, size=None, **kwargs): 该函数接受边信息 edge_ind编程ex 和其他额外的数据来执行消息传递并更新节点嵌入
    • MessagePassing.message(...): 该函数的作用是计算节点消息,就是公式中的函数 \phi . 如果 flow='source_to_target' ,那么消息将由邻域节点 j j j 传向中心节点 i i i ;如果 flow='target_to_source',消息则由中心节点 i i i 传向邻域节点 j j j . 传入参数的节点类型可以通过变量名后缀来确定,例如中心节点嵌入变量一般以 _i 为结尾,邻域节点嵌入变量以 x_j 为结尾
    • MessagePassing.update(arr_out, ...): 该函数为节点嵌入的更新函数 γ \gamma γ , 输入参数为聚合函数 MessagePassing.aggregate 计算的结果

    为了更好的理解 PyG 中 MessagePassing 的计算过程,我们来分析一下源代码。

    class MessagePassing(torch.nn.Module):
    
        special_args: Set[str] = {
            'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
            'size_i', 'size_j', 'ptr', 'index', 'dim_size'
        }
    
        def __init__(self, aggr: Optional[str] = "add",
                     flow: str = "source_to_target", node_dim: int = -2):
    
            super(MessagePassing, self).__init__()
    
            self.aggr = aggr
            assert self.aggr in ['add', 'mean', 'max', None]
    
            self.flow = flow
            assert self.flow in ['source_to_target', 'target_to_source']
    
            self.node_dim = node_dim
    
            self.inspector = Inspector(self)
            self.inspector.inspect(self.message)
            self.inspector.inspect(self.aggregate, pop_first=True)
            self.inspector.inspect(self.message_and_aggregate, pop_first=True)
            self.inspector.inspect(self.update, pop_first=True)
    
            self.__user_args__ = self.inspector.keys(
                ['message', 'aggregate', 'update']).difference(self.special_args)
            self.__fused_user_args__ = self.inspector.keys(
                ['message_and_aggregate', 'update']).difference(self.special_args)
    
            # Support for "fused" message passing.
            self.fuse = self.inspector.implements('message_and_aggregate')
    
            # Support for GNNExplainer.
            self.__explain__ = False
            self.__edge_mask__ = None
    

    在初始化函数中,MessagePassing 定义了一个 Inspector . Inspector 的中文意思是检查员的意思,这个类的作用就是检查各个函数的输入参数,并保存到 Inspector的参数列表字典中 Inspector.params中。

    如果 message的输入参数为 x_i, x_j,那么Inspector.params['message']={'x_i': Parameter, 'x_j': Parameter} (注:这里仅作示意,实际 Inspector.params['message'] 类型为 OrderedDict). Inspector.implements 检查函数是否实现.

    MessagePasing 中最核心的是 propgate 函数,假设邻接矩阵 edge_index 的类型为 Torch.LongTensor,消息由 edge_index[0] 传向 edge_index[1] ,代码实现如下

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        # 为了简化问题,这里不讨论 edge_index 为 SparseTensor 的情况,感兴趣的可阅读 PyG 原始代码
    
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                     kwargs)
    
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        out = self.message(**msg_kwargs)
    
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        out = self.aggregate(out, **aggr_kwargs)
    
        update_kwargs = self.inspector.distribute('update', coll_dict)
        return self.update(out, **update_kwargs)        
    
    

    在这段代码中,首先是检查节点数量和用户自定义的输入变量,然后依次执行 message, aggregateupdate 函数。

    如果是自定义图卷积,一般会重写 messageupdate,这一点随后再以 GCN 为例解释,这里首先来看一下 aggregate 的实现

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        if ptr is not None:
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
    

    ptr 变量是针对邻接矩阵 edge_indexSparseTensor的情况,此处暂且不论

    inputsmessage计算得到的消息, index 就是待更新节点的索引,实际上就是 edge_index_i. 聚合计算通过 scatter 函数实现。scatter 具体实现参考链接

    下面以 GCN 为例,我们来看一下 MessagePassing 的计算过程。

    GCN 的计算公式如下

    Pytorch-Geometric中的Message Passing使用及说明

    实际计算工程可以分为下面几步

    • 1.在邻接矩阵中增加自循环,即把邻接矩阵的对角线上的元素设为1
    • 2.对节点特征矩阵做线性变换
    • 3.计算节点的归一化系数,也就是节点度乘积的开方
    • 4.对节点特征做归一化处理
    • 5.聚合(求和)节点特征得到新的节点嵌入

    代码如下

    import torch
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree
    
    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
            self.lin = torch.nn.Linear(in_channels, out_channels)
    
        def forwwww.devze.comard(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]
    
            # Step 1: Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
            # Step 2: Linearly transform node feature matrix.
            x = self.lphpin(x)
    
            # Step 3: Compute normalization.
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
    
            # Steppython 4-5: Start propagating messages.
            return self.propagate(edge_index, x=x, norm=norm)
    
        def message(self, x_j, norm):
            # x_j has shape [E, out_channels]
    
            # Step 4: Normalize node features.
            return norm.view(-1, 1) * x_j
    

    forward 函数中,首先是给节点边增加自循环。设输入变量如下

    edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long)
    x = torch.rand((4, 3)) 
    conv = GCNConv(3, 8)
    

    注意到默认消息传递方向为 source_to_target,此时edge_index[0]=x_j 为 source, edge_index[1]=x_i 为 target.

    在 GCN 中,第一步是增加节点的自循环,add_self_loops 计算前后变化如下

    # before add_self_loops
    # edge_index=
    tensor([[0, 0, 2],
            [1, 2, 3]])
    # after add_self_loops
    # edge_index=
    tensor([[0, 0, 2, 0, 1, 2, 3],
            [1, 2,编程 3, 0, 1, 2, 3]])
    # norm=
    tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000]
    

    此处的 propagate 的输出参数由 edge_index, x, norm , edge_indexpropagete 必须输入的参数,x, no开发者_Python入门rm 为用户自定义参数。

    __collect__ 会根据变量名称来收集 message 需要的输入参数。

    在 GCN 中,norm 保持不变,x 将被映射到 x_j ,并且经过 __lift__ 函数,其值也会发生变化。__lift__ 函数如下

    def __lift__(self, src, edge_index, dim):
        if isinstance(edge_index, Tensor):
            index = edge_index[dim]
            return src.index_select(self.node_dim, index)
    

    在本例中,输入的特征 shape=[4, 8],经过 __lift__ 后,节点特征 shape=[7, 8] . 经过 message 计算后,就可以执行 aggregateupdate 了。

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜