Source code for jittor_geometric.nn.conv.message_passing

import os
import os.path as osp
import sys
from typing import List, Optional, Set
from inspect import Parameter

import jittor as jt
from jittor import nn, Module
from jittor import Var
from jittor_geometric.typing import Adj, Size

from .utils.inspector import Inspector


class MessagePassing(Module):
    r"""Base class for message passing layers.

    :meth:`propagate` gathers the arguments requested by :meth:`message`,
    :meth:`aggregate` and :meth:`update` (discovered through the
    :class:`Inspector`), then calls those three hooks in that order.

    Args:
        aggr: Aggregation scheme, one of ``'add'``, ``'mean'``, ``'max'`` or
            ``None``. Defaults to ``'add'``.
        flow: Direction of message passing, either ``'source_to_target'`` or
            ``'target_to_source'``.
        node_dim: Axis along which nodes are indexed. Defaults to ``-2``.
    """

    # Argument names that ``__collect__`` fills in itself; they are excluded
    # from the user-supplied arguments looked up in ``propagate``'s kwargs.
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):
        super().__init__()

        # Aggregation scheme ('add', 'mean', 'max'); defaults to 'add'.
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        # Direction of message passing
        # ('source_to_target' or 'target_to_source').
        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        # Record the signatures of the three hooks so that `propagate` can
        # route collected arguments to each of them by parameter name.
        # `pop_first=True` is passed for the hooks whose first positional
        # argument is the output of the previous stage.
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        # Hook parameters that must be supplied by the caller of `propagate`.
        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)

    def __check_input__(self, edge_index, size):
        """Validate ``edge_index`` and return a mutable ``[size_0, size_1]``.

        Returns a two-element list initialised from ``size`` when it is
        given, otherwise ``[None, None]``.

        Raises:
            ValueError: If ``edge_index`` is not a :class:`Var`.
            AssertionError: If ``edge_index`` is a :class:`Var` that fails
                the dtype / rank / leading-dimension checks.
        """
        the_size: List[Optional[int]] = [None, None]

        if isinstance(edge_index, Var):
            assert edge_index.dtype == Var.int32
            assert edge_index.ndim == 2
            assert edge_index.size(0) == 2
            if size is not None:
                the_size[0] = size[0]
                the_size[1] = size[1]
            return the_size

        raise ValueError(
            ('`MessagePassing.propagate` only supports `jittor Var int32` of '
             'shape `[2, num_messages]`'))

    def __set_size__(self, size: List[Optional[int]], dim: int, src: Var):
        """Record or verify the node count of ``src`` in ``size[dim]``.

        ``size`` is updated in place when the entry is still ``None``.

        Raises:
            ValueError: If ``size[dim]`` is already set and differs from the
                length of ``src`` along ``self.node_dim``.
        """
        the_size = size[dim]
        if the_size is None:
            size[dim] = src.size(self.node_dim)
        elif the_size != src.size(self.node_dim):
            raise ValueError(
                (f'Encountered Var with size {src.size(self.node_dim)} in '
                 f'dimension {self.node_dim}, but expected size {the_size}.'))

    def __lift__(self, src, edge_index, dim):
        """Gather the entries of ``src`` selected by row ``dim`` of ``edge_index``.

        The gather runs along the node axis ``self.node_dim``.

        Raises:
            ValueError: If ``edge_index`` is not a :class:`Var`.
        """
        if isinstance(edge_index, Var):
            index = edge_index[dim]
            # Resolve a negative `node_dim` to a non-negative axis first:
            # multiplying a tuple by a negative number yields `()`, which
            # would silently index axis 0 whatever `node_dim` says.
            node_axis = self.node_dim
            if node_axis < 0:
                node_axis += src.ndim
            return src[(slice(None),) * node_axis + (index,)]
        raise ValueError

    def __collect__(self, args, edge_index, size, kwargs):
        """Build the dictionary of arguments handed to the three hooks.

        For every name in ``args``: names without an ``_i`` / ``_j`` suffix
        are copied from ``kwargs`` unchanged; suffixed names are looked up
        under their base name and, when the value is a :class:`Var`, lifted
        to one entry per edge. Missing values are represented by
        ``Parameter.empty``. ``size`` is filled in place as a side effect.
        """
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

        out = {}
        for arg in args:
            if arg[-2:] not in ['_i', '_j']:
                out[arg] = kwargs.get(arg, Parameter.empty)
            else:
                dim = 0 if arg[-2:] == '_j' else 1
                data = kwargs.get(arg[:-2], Parameter.empty)

                # If the data is a tuple or list, it holds one entry per
                # side; record the size of the other side, keep this one.
                if isinstance(data, (tuple, list)):
                    assert len(data) == 2
                    if isinstance(data[1 - dim], Var):
                        self.__set_size__(size, 1 - dim, data[1 - dim])
                    data = data[dim]

                if isinstance(data, Var):
                    self.__set_size__(size, dim, data)
                    data = self.__lift__(data, edge_index,
                                         j if arg[-2:] == '_j' else i)

                out[arg] = data

        if isinstance(edge_index, Var):
            out['adj_t'] = None
            out['edge_index'] = edge_index
            out['edge_index_i'] = edge_index[i]
            out['edge_index_j'] = edge_index[j]
            out['ptr'] = None

        out['index'] = out['edge_index_i']
        out['size'] = size
        # Fall back to the other side's size when one side is unknown.
        out['size_i'] = size[1] or size[0]
        out['size_j'] = size[0] or size[1]
        out['dim_size'] = out['size_i']

        return out

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """Run one round of message passing.

        Validates ``edge_index``, collects the hook arguments from
        ``kwargs`` and then calls :meth:`message`, :meth:`aggregate` and
        :meth:`update` in that order, returning the result of
        :meth:`update`.

        Args:
            edge_index: Edge indices; must pass ``__check_input__``.
            size: Optional pair of node counts. When ``None`` the counts are
                taken from the Vars found in ``kwargs``.
            **kwargs: Any additional data the hooks request by name.
        """
        size = self.__check_input__(edge_index, size)

        if isinstance(edge_index, Var):
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

    def message(self, x_j: Var) -> Var:
        r"""Constructs messages from node :math:`j` to node :math:`i`
        in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :obj:`edge_index`.
        This function can take any argument as input which was initially
        passed to :meth:`propagate`.
        Furthermore, Var passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        return x_j

    def aggregate(self, inputs: Var, index: Var,
                  ptr: Optional[Var] = None,
                  dim_size: Optional[int] = None) -> Var:
        """Reduce the per-edge messages ``inputs`` into per-node outputs.

        Allocates a zero Var shaped like ``inputs`` except that the
        ``self.node_dim`` axis has length ``dim_size``, then scatters
        ``inputs`` into it with ``jt.scatter`` using ``reduce=self.aggr``.
        ``ptr`` is accepted but not used here.
        """
        shape = list(inputs.shape)
        shape[self.node_dim] = dim_size
        out = jt.zeros(shape)
        # NOTE(review): the scatter axis is hard-coded to 0 while the output
        # shape is derived from `self.node_dim`; the two only agree when
        # `node_dim` resolves to axis 0 (e.g. the default -2 with 2-D
        # inputs). Also confirm which `reduce` values `jt.scatter` accepts
        # for the 'mean' / 'max' / None settings allowed by `__init__`.
        return jt.scatter(out, 0, index, src=inputs, reduce=self.aggr)

    def update(self, inputs: Var) -> Var:
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """
        return inputs