# Source code for jittor_geometric.nn.conv.gcn_conv

'''
Description: 
Author: lusz
Date: 2024-11-29 10:14:19
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar
import jittor as jt
from jittor import Var,nn,Module
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC, CSR
from jittor_geometric.ops import SpmmCsr, aggregateWithWeight


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    r"""Apply symmetric GCN normalization to the weights of a COO edge list.

    Every edge weight ``w`` belonging to an edge ``(row, col)`` is rescaled to
    ``deg[row]^-0.5 * w * deg[col]^-0.5``, where ``deg`` is the sum of edge
    weights scattered onto the ``col`` index (computed after the optional
    self-loop insertion).

    Args:
        edge_index (Var): Edge list indexed as ``edge_index[0]`` (row) and
            ``edge_index[1]`` (col); ``edge_index.size(1)`` is taken as the
            number of edges.
        edge_weight (Var, optional): Per-edge weights. When ``None``, a
            weight of one is used for every edge.
        num_nodes (int, optional): Number of nodes; resolved through
            ``maybe_num_nodes`` (presumably inferred from ``edge_index`` when
            ``None`` -- confirm against that helper).
        improved (bool, optional): Selects the ``fill_value`` handed to
            ``add_remaining_self_loops``: ``2.0`` when ``True``, else ``1.0``.
        add_self_loops (bool, optional): Whether to call
            ``add_remaining_self_loops`` before computing degrees.
        dtype (optional): Accepted for signature compatibility; not used by
            this implementation.

    Returns:
        tuple: ``(edge_index, normalized_edge_weight)`` when ``edge_index`` is
        a ``Var``. ``None`` for any other input type, since only the ``Var``
        representation is handled here.
    """
    if not isinstance(edge_index, Var):
        # NOTE(review): other adjacency representations are not supported;
        # the original fell through and returned None implicitly, which is
        # preserved here so existing callers see no change.
        return None

    fill_value = 2. if improved else 1.
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = jt.ones((edge_index.size(1), ))

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]

    # Weighted degree: sum the edge weights over the `col` index.
    shape = list(edge_weight.shape)
    shape[0] = num_nodes
    deg = jt.zeros(shape)
    deg = jt.scatter(deg, 0, col, src=edge_weight, reduce='add')

    deg_inv_sqrt = deg.pow(-0.5)
    # Zero-degree nodes yield inf above. `masked_fill` returns a new Var
    # rather than mutating in place, so the result has to be assigned back;
    # dropping it would leak inf/nan into the normalized weights.
    deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    
    
class GCNConv(Module):
    r"""The graph convolutional operator from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper.

    This class implements the Graph Convolutional Network (GCN) layer, which
    updates node representations by aggregating information from neighboring
    nodes, taking the graph structure into account. It supports both
    message-passing and sparse matrix multiplication (SPMM) for propagation.

    Args:
        in_channels (int): Size of each input sample (number of input
            features per node).
        out_channels (int): Size of each output sample (number of output
            features per node).
        bias (bool, optional): If set to `True`, adds a learnable bias to the
            output. Default is `True`.
        spmm (bool, optional): If set to `True`, uses sparse matrix
            multiplication (SPMM) for propagation; this path is only taken
            when ``jt.flags.use_cuda == 1``. Default is `True`.
        **kwargs (optional): Additional arguments for the base `Module`.
            ``aggr`` defaults to ``'add'`` when not supplied.

    Example:
        >>> dataset = MyGraphDataset(root='/path/to/dataset')
        >>> data = dataset[0]  # Access the first graph object
        >>> conv = GCNConv(in_channels=16, out_channels=32, bias=True, spmm=True)
        >>> v_num = data.x.shape[0]  # Number of nodes
        >>> edge_index, edge_weight = data.edge_index, data.edge_attr
        >>> edge_index, edge_weight = gcn_norm(edge_index, edge_weight, v_num,
        ...                                    improved=False, add_self_loops=True)
        >>> with jt.no_grad():
        ...     csc = cootocsc(edge_index, edge_weight, v_num)
        ...     csr = cootocsr(edge_index, edge_weight, v_num)
        >>> x = jt.random((v_num, 16))  # Randomly initialize node features
        >>> out = conv(x, csc, csr)  # Apply GCN layer
    """

    # Cache slots; this class only ever resets them to None.
    _cached_edge_index: Optional[Tuple[Var, Var]]
    _cached_csc: Optional[CSC]

    def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
                 spmm: bool = True, **kwargs):
        """Create the weight (and optional bias) and initialize them."""
        kwargs.setdefault('aggr', 'add')
        super(GCNConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self._cached_edge_index = None
        self._cached_adj_t = None

        # Placeholder values; reset_parameters() below overwrites them.
        self.weight = jt.random((in_channels, out_channels))
        if bias:
            self.bias = jt.random((out_channels,))
        else:
            self.bias = None
        self.spmm = spmm
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight (glorot) and bias (zeros); clear the caches."""
        glorot(self.weight)
        # Skip the bias when the layer was built with bias=False so this does
        # not depend on the `zeros` helper tolerating None.
        if self.bias is not None:
            zeros(self.bias)
        self._cached_adj_t = None
        self._cached_csc = None

    def execute(self, x: Var, csc: OptVar, csr: OptVar) -> Var:
        """Apply the linear transform, propagate over the graph, add the bias.

        Args:
            x (Var): Node features; multiplied by ``self.weight``, so the last
                dimension must equal ``in_channels``.
            csc: Graph in CSC form (the propagate methods annotate it as
                ``CSC``); only used on the message-passing path.
            csr: Graph in CSR form (annotated as ``CSR`` by the propagate
                methods); used on both paths.

        Returns:
            Var: The propagated features, plus the bias when one is present.
        """
        x = x @ self.weight
        # SPMM is only used when requested and CUDA is enabled; otherwise
        # fall back to the message-passing aggregation.
        if self.spmm and jt.flags.use_cuda == 1:
            out = self.propagate_spmm(x=x, csr=csr)
        else:
            out = self.propagate_msg(x=x, csc=csc, csr=csr)
        if self.bias is not None:
            out += self.bias
        return out

    # propagate by message passing
    def propagate_msg(self, x, csc: CSC, csr: CSR):
        """Propagate ``x`` by calling ``aggregateWithWeight(x, csc, csr)``."""
        out = aggregateWithWeight(x, csc, csr)
        return out

    # propagate by spmm
    def propagate_spmm(self, x, csr: CSR):
        """Propagate ``x`` by calling ``SpmmCsr(x, csr)``."""
        out = SpmmCsr(x, csr)
        return out

    def __repr__(self):
        """Return ``ClassName(in_channels, out_channels)``."""
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'