from typing import Optional
from jittor_geometric.typing import OptVar
import jittor as jt
from jittor import Var
from jittor.nn import Linear
from jittor_geometric.nn.conv import MessagePassing
from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight
class SGConv(MessagePassing):
r"""The simple graph convolutional operator from the `"Simplifying Graph
Convolutional Networks" <https://arxiv.org/abs/1902.07153>`_ paper.
This class implements the Simplified Graph Convolution (SGC) layer, which removes nonlinearities and collapses weight
matrices across layers to achieve a simplified and computationally efficient graph convolution.
Mathematical Formulation:
.. math::
\mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta},
where:
- :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the adjacency matrix with added self-loops.
- :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` is its diagonal degree matrix.
- :math:`K` controls the number of propagation steps.
- The adjacency matrix can include other values than :obj:`1`, representing edge weights via the optional `edge_weight` variable.
Args:
in_channels (int): Number of input features per node.
out_channels (int): Number of output features per node.
K (int, optional): Number of propagation steps. Default is `1`.
bias (bool, optional): Whether to include a learnable bias term. Default is `True`.
**kwargs (optional): Additional arguments for the `MessagePassing` class.
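
    Example:
        A minimal usage sketch, assuming the graph has already been
        normalized and converted to the :obj:`CSC`/:obj:`CSR` layouts this
        layer consumes (the conversion helpers named below are assumptions
        for illustration, not part of this module):

        .. code-block:: python

            from jittor_geometric.ops import cootocsc, cootocsr  # assumed helpers

            csc = cootocsc(edge_index, edge_weight, num_nodes)
            csr = cootocsr(edge_index, edge_weight, num_nodes)
            conv = SGConv(in_channels=1433, out_channels=7, K=2)
            out = conv(x, csc, csr)  # x: [num_nodes, in_channels]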
"""
    _cached_x: Optional[Var]

    def __init__(self, in_channels: int, out_channels: int, K: int = 1,
                 bias: bool = True, spmm: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(SGConv, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        # Single collapsed weight matrix, applied once after K propagation steps.
        self.lin = Linear(in_channels, out_channels, bias=bias)
        self.spmm = spmm
        self.reset_parameters()
    def reset_parameters(self):
        # Glorot-initialize the weight; zero the bias when it exists.
        glorot(self.lin.parameters()[0])
        if len(self.lin.parameters()) > 1:
            zeros(self.lin.parameters()[1])
        # Clear any cached propagation results.
        self._cached_x = None
        self._cached_adj_t = None
        self._cached_csc = None
    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        """Perform forward propagation: K propagation steps followed by a
        single linear transform."""
        for k in range(self.K):
            if self.spmm and jt.flags.use_cuda == 1:
                # CUDA path: sparse matrix-matrix multiplication over the CSR graph.
                x = self.propagate_spmm(x=x, csr=csr)
            else:
                # CPU/fallback path: explicit message passing with edge weights.
                x = self.propagate_msg(x=x, csc=csc, csr=csr)
        return self.lin(x)
    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        out = aggregateWithWeight(x, csc, csr)
        return out
    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        out = SpmmCsr(x, csr)
        return out
    def __repr__(self):
        return '{}({}, {}, K={})'.format(self.__class__.__name__,
                                         self.in_channels, self.out_channels,
                                         self.K)
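

# For reference, a minimal dense sketch of the propagation SGConv performs,
# matching the formula in the class docstring. This is illustrative only
# (the layer itself uses the CSC/CSR sparse kernels); `_dense_sgc_propagate`
# is a hypothetical helper, and `adj` is assumed to be a dense (N, N) 0/1
# adjacency Var without self-loops.
def _dense_sgc_propagate(x: Var, adj: Var, K: int) -> Var:
    n = adj.shape[0]
    # \hat{A} = A + I: add self-loops.
    a_hat = adj + jt.init.eye(n)
    # Symmetric normalization \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}.
    deg = a_hat.sum(1)
    d_inv_sqrt = deg ** -0.5
    norm_a = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)
    # K propagation steps; SGConv applies its single Linear afterwards,
    # so conv.lin(_dense_sgc_propagate(x, adj, conv.K)) mirrors conv(x, csc, csr)
    # up to sparse-kernel numerics.
    for _ in range(K):
        x = jt.matmul(norm_a, x)
    return x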