# Source code for jittor_geometric.ops.scatterToEdge

'''
Description: COO version of scatter to edge for GAT
Author: lusz
Date: 2024-06-28 17:10:12
'''

import jittor as jt
import os
import sys
from jittor import nn
from jittor import Function
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from jittor_geometric.data import CSC, CSR
# Directory that contains this module; the custom-op sources are looked up
# in its "cpp" subdirectory.
module_path = os.path.dirname(__file__)
# Source/header pair for the op used in the forward pass
# (called below as scatter_op.scattertoedge).
src = os.path.join(module_path, "cpp/scattertoedge_op.cc")
header = os.path.join(module_path, "cpp/scattertoedge_op.h")
# Source/header pair for the op used in the backward pass
# (called below as scatter_backward_op.edgetovertex).
srcb = os.path.join(module_path, "cpp/edgetovertex_op.cc")
headerb = os.path.join(module_path, "cpp/edgetovertex_op.h")
# Both ops are compiled when this module is imported, so importing it
# requires the C++ files above to be present.
scatter_op = jt.compile_custom_ops((src, header))
scatter_backward_op = jt.compile_custom_ops((srcb, headerb))

class ScatterToEdgeFunc(Function):
    """Differentiable wrapper around the ``scattertoedge`` custom op.

    The forward pass hands a zero-initialised edge-sized tensor and the
    vertex tensor ``x`` to ``scatter_op.scattertoedge``; the backward pass
    hands a zero-initialised vertex-sized tensor and the incoming gradient to
    ``scatter_backward_op.edgetovertex``. What each op writes per edge is
    defined in the C++ sources, which are not visible from this module.
    """

    def execute(self, x, csc, flow):
        """Run the forward op and return its edge-sized output tensor.

        Args:
            x: Input tensor. Dimension 0 is recorded as the vertex count and
                dimension 1 as the feature width.
            csc: Graph structure exposing ``row_indices`` and
                ``column_offset``; the length of ``row_indices`` is used as
                the edge count.
            flow: ``"src"`` selects op flag 0; any other value selects flag 1.

        Returns:
            The tensor created as ``jt.zeros(e_num, feature_dim)`` and passed
            to the op as its first argument (presumably filled in place by
            the op -- confirm against cpp/scattertoedge_op.cc).
        """
        # State that grad() reads back later.
        self.flow = flow
        self.csc = csc

        edge_count = jt.size(csc.row_indices, 0)
        feat_width = jt.size(x, 1)
        vertex_count = jt.size(x, 0)
        self.e_num = edge_count
        self.v_num = vertex_count
        self.feature_dim = feat_width

        # Integer direction switch forwarded to both custom ops.
        direction = 0 if flow == "src" else 1
        self.flag = direction

        edge_out = jt.zeros(edge_count, feat_width)
        pending = scatter_op.scattertoedge(
            edge_out, x, csc.row_indices, csc.column_offset, False, direction
        )
        # Force the op to execute now rather than lazily.
        pending.fetch_sync()

        return edge_out

    def grad(self, grad_output):
        """Run the backward op and return the gradients for ``execute``.

        Args:
            grad_output: Gradient with respect to the tensor returned by
                ``execute``.

        Returns:
            A 3-tuple ``(vertex_grad, None, None)``: a
            ``(v_num, feature_dim)`` tensor followed by two ``None`` slots,
            lining up with the ``x``, ``csc`` and ``flow`` arguments of
            ``execute``.
        """
        graph = self.csc
        vertex_grad = jt.zeros(self.v_num, self.feature_dim)
        pending = scatter_backward_op.edgetovertex(
            vertex_grad, grad_output, graph.row_indices, graph.column_offset, self.flag
        )
        pending.fetch_sync()
        return vertex_grad, None, None
    

def ScatterToEdge(x, csc, flow):
    """Apply ``ScatterToEdgeFunc`` to ``x`` and return the edge-sized result.

    Thin functional entry point around ``ScatterToEdgeFunc.apply`` so callers
    do not have to instantiate the ``Function`` themselves.

    Args:
        x: Input vertex tensor with at least two dimensions
            (vertices x features).
        csc: Graph structure exposing ``row_indices`` and ``column_offset``.
        flow: Direction selector; ``"src"`` maps to op flag 0, any other
            value maps to flag 1.

    Returns:
        Whatever ``ScatterToEdgeFunc.apply(x, csc, flow)`` returns.
    """
    return ScatterToEdgeFunc.apply(x, csc, flow)