# Source code for jittor_geometric.ops.aggregateWithWeight

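'''
Description: Weighted sparse neighbor aggregation for jittor_geometric, backed by the
custom C++ operator in cpp/aggregate_op.cc.
Author: lusz
Date: 2024-06-21 14:50:39
'''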
import jittor as jt
import os
import sys
from jittor import nn,Var
from jittor import Function
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from jittor_geometric.data import CSC, CSR
# The C++ kernel lives in a `cpp/` directory next to this module; it is
# JIT-compiled by jittor when this module is imported.
module_path = os.path.dirname(__file__)
src, header = (
    os.path.join(module_path, relative_path)
    for relative_path in ("cpp/aggregate_op.cc", "cpp/aggregate_op.h")
)
aggregate_op = jt.compile_custom_ops((src, header))

class AggregateFunc(Function):
    """Weighted neighbor aggregation with a hand-written backward pass.

    The forward pass runs the custom ``aggregate`` op with the CSC arrays of
    the graph (``row_indices`` / ``column_offset``). The backward pass runs
    the same op with the CSR arrays (``column_indices`` / ``row_offset``),
    which presumably computes the transposed aggregation -- TODO confirm
    against cpp/aggregate_op.cc. Gradients flow only to ``x``; edge weights
    are treated as constants.
    """

    def execute(self, x, csc, csr, edge_weight):
        """Aggregate ``x`` along the edges of the graph.

        Args:
            x: vertex feature Var. The result is allocated with
                ``jt.zeros_like(x)``, so it has the same shape and dtype.
            csc: CSC graph. Its ``row_indices`` and ``column_offset`` are read,
                and so is its ``edge_weight`` when no explicit weight is given.
            csr: CSR graph, kept for the backward pass.
            edge_weight: per-edge weight Var, or any non-``Var`` value
                (e.g. ``None``) to fall back to ``csc.edge_weight``.

        Returns:
            The aggregated features, same shape as ``x``.
        """
        self.csc = csc
        self.csr = csr
        # Remember the caller's argument (not the fallback) so that grad() can
        # pick the CSR-side fallback weights on its own.
        self.weight = edge_weight
        if not isinstance(edge_weight, Var):
            edge_weight = csc.edge_weight
        output = jt.zeros_like(x)
        # The op's return value is discarded and ``output`` is returned, so the
        # kernel presumably fills ``output`` in place. fetch_sync() forces the
        # lazily-evaluated op to actually run. The trailing flag is True here
        # and False in grad(); its meaning is defined by the C++ op (TODO confirm).
        aggregate_op.aggregate(
            output, x, csc.row_indices, csc.column_offset, edge_weight, True
        ).fetch_sync()
        return output

    def grad(self, grad_output):
        """Back-propagate through the aggregation using the CSR arrays.

        Args:
            grad_output: gradient w.r.t. the output of ``execute``.

        Returns:
            One entry per ``execute`` input: the gradient for ``x`` followed
            by ``None`` for ``csc``, ``csr`` and ``edge_weight`` (edge weights
            are not differentiated).
        """
        if isinstance(self.weight, Var):
            # NOTE(review): an explicit weight Var is reused unchanged with the
            # CSR indices, although execute() paired it with the CSC indices.
            # Unless both layouts store edges in the same order (or the weights
            # are symmetric), this looks wrong -- confirm against how CSC/CSR
            # edge weights are built.
            edge_weight = self.weight
        else:
            edge_weight = self.csr.edge_weight
        output_grad = jt.zeros_like(grad_output)
        aggregate_op.aggregate(
            output_grad,
            grad_output,
            self.csr.column_indices,
            self.csr.row_offset,
            edge_weight,
            False,
        ).fetch_sync()
        # jittor's Function expects exactly one gradient per execute() input.
        # Returning only three left edge_weight without an entry, which breaks
        # backward when it is a grad-requiring Var.
        return output_grad, None, None, None
    
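'''
description: Performs weighted aggregation on the vertex embedding matrix using the CSC (Compressed Sparse Column)
and CSR (Compressed Sparse Row) representations of the graph.
param {*} x The vertex embedding matrix of shape (v_num, dim), where `v_num` is the number of vertices and `dim` is the dimension of the embeddings.
param {*} csc The CSC representation of the graph, used in the forward pass.
param {*} csr The CSR representation of the graph, used in the backward pass.
param {*} weight Optional per-edge weights (jt.Var); if omitted, the graph's own edge weights are used.
return {*} The aggregated embedding matrix, with the same shape as `x`.
author: xuchaoxin
'''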
def aggregateWithWeight(x, csc, csr, weight=None):
    """Aggregate vertex embeddings over a graph, scaling each edge by a weight.

    Thin wrapper around ``AggregateFunc.apply``. The forward pass uses the
    CSC structure and the backward pass uses the CSR structure.

    Args:
        x: vertex embedding matrix of shape (v_num, dim), where ``v_num`` is
            the number of vertices and ``dim`` the embedding dimension.
        csc: CSC (Compressed Sparse Column) representation of the graph.
        csr: CSR (Compressed Sparse Row) representation of the graph.
        weight: optional per-edge weight ``jt.Var``. When it is not a ``Var``
            (the default ``None``), the graphs' own ``edge_weight`` arrays
            are used instead.

    Returns:
        The aggregated embeddings, with the same shape as ``x``.
    """
    return AggregateFunc.apply(x, csc, csr, weight)