import os
import os.path as osp
import sys
from typing import List, Optional, Set
from inspect import Parameter
import jittor as jt
from jittor import nn, Module
from jittor import Var
from jittor_geometric.typing import Adj, Size
from .utils.inspector import Inspector
class MessagePassing(Module):
    r"""Base class for message-passing graph layers.

    Subclasses customise :meth:`message`, :meth:`aggregate` and
    :meth:`update`; :meth:`propagate` wires them together. Every parameter
    name used by those three methods that is not in :attr:`special_args` is
    looked up in the keyword arguments given to :meth:`propagate`. Names
    ending in ``_i`` / ``_j`` are gathered per edge from the node-level
    variable with the same stem (e.g. ``x_j`` is built from ``x``).

    Args:
        aggr: Aggregation scheme: ``'add'``, ``'mean'``, ``'max'`` or
            ``None``. With ``None`` the subclass must override
            :meth:`aggregate`. Default: ``'add'``.
        flow: Direction of message passing, ``'source_to_target'`` or
            ``'target_to_source'``. Default: ``'source_to_target'``.
        node_dim: Axis of node-level variables whose length is checked
            against the number of nodes. Default: ``-2``.
    """

    # Argument names that propagate() fills in itself instead of taking
    # them from the user's keyword arguments.
    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):
        super(MessagePassing, self).__init__()

        # Aggregation method ('add', 'mean', 'max' or None); default 'add'.
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        # Message direction ('source_to_target' or 'target_to_source').
        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        # Record the signatures of the user-overridable hooks. For aggregate
        # and update the first parameter (the incoming data) is skipped
        # because propagate() passes it positionally.
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        # Parameter names that must come from propagate(**kwargs); presumably
        # Inspector.keys returns the set of parameter names across the
        # listed hooks (it supports .difference) -- confirm in Inspector.
        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)

    def __check_input__(self, edge_index, size):
        """Validate ``edge_index`` and normalise ``size`` to a 2-item list.

        Args:
            edge_index: Var of shape ``[2, num_messages]``.
            size: Optional pair of node counts, or ``None``.

        Returns:
            ``[size_0, size_1]``: ``size_0`` is the node count for ``_j``
            arguments, ``size_1`` for ``_i`` arguments. Unknown entries are
            ``None``.

        Raises:
            AssertionError: if ``edge_index`` has the wrong dtype or shape.
            ValueError: if ``edge_index`` is not a jittor ``Var``.
        """
        the_size: List[Optional[int]] = [None, None]

        if isinstance(edge_index, Var):
            assert edge_index.dtype == Var.int32, \
                'edge_index must have dtype int32'
            assert edge_index.ndim == 2, 'edge_index must be 2-dimensional'
            assert edge_index.size(0) == 2, \
                'edge_index must have shape [2, num_messages]'
            if size is not None:
                the_size[0] = size[0]
                the_size[1] = size[1]
            return the_size

        raise ValueError(
            ('`MessagePassing.propagate` only supports `jittor Var int32` of '
             'shape `[2, num_messages]`'))

    def __set_size__(self, size: List[Optional[int]], dim: int, src: Var):
        """Record or check the node count of ``src`` in ``size[dim]``.

        If ``size[dim]`` is still ``None`` it is set from the length of
        ``src`` along ``self.node_dim``; otherwise the two must agree.

        Raises:
            ValueError: if ``src`` disagrees with an already known size.
        """
        the_size = size[dim]
        if the_size is None:
            size[dim] = src.size(self.node_dim)
        elif the_size != src.size(self.node_dim):
            raise ValueError(
                (f'Encountered Var with size {src.size(self.node_dim)} in '
                 f'dimension {self.node_dim}, but expected size {the_size}.'))

    def __lift__(self, src, edge_index, dim):
        """Gather rows of node-level ``src`` for every edge endpoint.

        Args:
            src: Node-level Var.
            edge_index: ``[2, num_messages]`` Var.
            dim: Which row of ``edge_index`` (0 or 1) supplies the indices.

        Returns:
            ``src`` indexed by ``edge_index[dim]``.

        Raises:
            ValueError: if ``edge_index`` is not a jittor ``Var``.
        """
        if isinstance(edge_index, Var):
            index = edge_index[dim]
            # NOTE(review): for a negative node_dim, (slice(None),) * node_dim
            # is the empty tuple, so this always indexes dim 0. That matches
            # node_dim only when src.ndim == -node_dim (e.g. 2-D features with
            # the default -2). aggregate() also scatters on dim 0, so the two
            # stay consistent.
            return src[(slice(None),) * self.node_dim + (index,)]
        raise ValueError

    def __collect__(self, args, edge_index, size, kwargs):
        """Build the keyword dictionary shared by message/aggregate/update.

        Args:
            args: User argument names to resolve.
            edge_index: ``[2, num_messages]`` Var.
            size: Two-item size list from :meth:`__check_input__`. It is
                updated in place as node-level variables are seen.
            kwargs: Keyword arguments passed to :meth:`propagate`.

        Returns:
            Dict mapping every user and special argument name to its value.
            Missing user arguments map to ``inspect.Parameter.empty``.
        """
        # i = row of edge_index used for `_i` arguments, j = row for `_j`.
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

        out = {}
        for arg in args:
            if arg[-2:] not in ['_i', '_j']:
                out[arg] = kwargs.get(arg, Parameter.empty)
                continue

            dim = 0 if arg[-2:] == '_j' else 1
            data = kwargs.get(arg[:-2], Parameter.empty)

            if isinstance(data, (tuple, list)):
                # A (j-side, i-side) pair, e.g. for bipartite graphs: record
                # the size of the other side, then keep the matching element.
                assert len(data) == 2
                if isinstance(data[1 - dim], Var):
                    self.__set_size__(size, 1 - dim, data[1 - dim])
                data = data[dim]

            if isinstance(data, Var):
                self.__set_size__(size, dim, data)
                data = self.__lift__(data, edge_index,
                                     j if arg[-2:] == '_j' else i)

            out[arg] = data

        if isinstance(edge_index, Var):
            out['adj_t'] = None
            out['edge_index'] = edge_index
            out['edge_index_i'] = edge_index[i]
            out['edge_index_j'] = edge_index[j]
            out['ptr'] = None

        out['index'] = out['edge_index_i']
        out['size'] = size
        # Explicit None checks: a known size of 0 must not fall back to the
        # other side's size (which `size[1] or size[0]` would do).
        out['size_i'] = size[1] if size[1] is not None else size[0]
        out['size_j'] = size[0] if size[0] is not None else size[1]
        out['dim_size'] = out['size_i']

        return out

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """Run message -> aggregate -> update over ``edge_index``.

        Args:
            edge_index: ``[2, num_messages]`` int32 Var.
            size: Optional pair of node counts (``_j`` side, ``_i`` side).
            **kwargs: Data needed by :meth:`message`, :meth:`aggregate` and
                :meth:`update`.

        Returns:
            The value returned by :meth:`update`.

        Raises:
            ValueError: if ``edge_index`` is not a jittor ``Var``.
        """
        size = self.__check_input__(edge_index, size)

        if isinstance(edge_index, Var):
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

    def message(self, x_j: Var) -> Var:
        r"""Constructs messages from node :math:`j` to node :math:`i`
        in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :obj:`edge_index`.

        This function can take any argument as input which was initially
        passed to :meth:`propagate`.
        Furthermore, Var passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        return x_j

    def aggregate(self, inputs: Var, index: Var,
                  ptr: Optional[Var] = None,
                  dim_size: Optional[int] = None) -> Var:
        r"""Aggregate per-edge messages into per-node outputs.

        Args:
            inputs: Messages returned by :meth:`message`.
            index: The ``_i``-side endpoint (``edge_index_i``) of every
                message; messages are scattered to these positions.
            ptr: Not used by this implementation.
            dim_size: Number of output nodes (``size_i``).

        Returns:
            A Var shaped like ``inputs`` but with the ``node_dim`` axis
            resized to ``dim_size``. For ``'add'`` and ``'mean'``, nodes that
            receive no message are zero.

        Raises:
            NotImplementedError: if ``aggr`` is ``None`` and this method was
                not overridden by the subclass.
        """
        if self.aggr is None:
            raise NotImplementedError(
                f'{type(self).__name__} was constructed with aggr=None and '
                'must override `aggregate`.')

        shape = list(inputs.shape)
        shape[self.node_dim] = dim_size

        # NOTE(review): the scatter runs along dim 0, which equals node_dim
        # only for e.g. 2-D messages with node_dim=-2, or for node_dim=0.
        if self.aggr == 'mean':
            # Computed explicitly as sum / count so the result does not depend
            # on whether jt.scatter offers a 'mean' reduction.
            total = jt.scatter(jt.zeros(shape), 0, index, src=inputs,
                               reduce='add')
            count = jt.scatter(jt.zeros(shape), 0, index,
                               src=jt.ones(list(inputs.shape)), reduce='add')
            # Clamp the count to >= 1 so nodes without messages give 0
            # instead of 0 / 0.
            return total / jt.maximum(count, jt.ones(shape))

        # NOTE(review): 'max' is forwarded to jt.scatter unchanged, with the
        # buffer starting at zero. Confirm jt.scatter accepts 'max', and that
        # a zero floor is intended when all of a node's messages are negative.
        out = jt.zeros(shape)
        return jt.scatter(out, 0, index, src=inputs, reduce=self.aggr)

    def update(self, inputs: Var) -> Var:
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.

        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """
        return inputs