Source code for jittor_geometric.ops.saparse_ops

'''
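Description: Neighbor-lookup helpers that gather the neighbors of a set of nodes from CSC / CSR sparse adjacency structures.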
Author: lusz
Date: 2025-01-11 13:40:29
'''
import jittor as jt
from jittor_geometric.data import CSC, CSR

def from_nodes(csc, nodes):
    """Gather the neighbors of ``nodes`` from a CSC adjacency structure.

    For every node id ``v`` in ``nodes`` the slice
    ``csc.row_indices[csc.column_offset[v]:csc.column_offset[v + 1]]`` is
    taken, and the slices are concatenated in the order of ``nodes``.

    Parameters:
        csc (CSC): Compressed Sparse Column structure exposing
            ``column_offset`` and ``row_indices``.
        nodes (Var): Input node IDs; indexed element-wise along dim 0
            (presumably 1-D -- confirm against callers).

    Returns:
        Var: Concatenation of all neighbor slices. Duplicates are NOT
        removed; apply ``jt.unique`` to the result if a set is needed.
        An empty slice of ``row_indices`` is returned when ``nodes`` is
        empty.

    Raises:
        ValueError: If ``csc`` is None.
    """
    if csc is None:
        raise ValueError("CSC structure must be provided.")

    column_offset = csc.column_offset
    row_indices = csc.row_indices

    neighbors = []
    for i in range(nodes.shape[0]):
        # Convert to Python ints so the slice below uses plain integer
        # bounds instead of Var-typed ones.
        node = int(nodes[i].item())
        start = int(column_offset[node].item())
        end = int(column_offset[node + 1].item())
        neighbors.append(row_indices[start:end])

    if not neighbors:
        # concat of an empty list fails; return an empty slice that keeps
        # the dtype of row_indices.
        return row_indices[0:0]

    return jt.contrib.concat(neighbors)
def to_nodes(csr, nodes):
    """Gather the neighbors of ``nodes`` from a CSR adjacency structure.

    For every node id ``v`` in ``nodes`` the slice
    ``csr.column_indices[csr.row_offset[v]:csr.row_offset[v + 1]]`` is
    taken, and the slices are concatenated in the order of ``nodes``.

    Parameters:
        csr (CSR): Compressed Sparse Row structure exposing
            ``row_offset`` and ``column_indices``.
        nodes (Var): Input node IDs; indexed element-wise along dim 0
            (presumably 1-D -- confirm against callers).

    Returns:
        Var: Concatenation of all neighbor slices. Duplicates are NOT
        removed; apply ``jt.unique`` to the result if a set is needed.
        An empty slice of ``column_indices`` is returned when ``nodes`` is
        empty.

    Raises:
        ValueError: If ``csr`` is None.
    """
    if csr is None:
        raise ValueError("CSR structure must be provided.")

    row_offset = csr.row_offset
    column_indices = csr.column_indices

    neighbors = []
    for i in range(nodes.shape[0]):
        # Convert to Python ints so the slice below uses plain integer
        # bounds instead of Var-typed ones.
        node = int(nodes[i].item())
        start = int(row_offset[node].item())
        end = int(row_offset[node + 1].item())
        neighbors.append(column_indices[start:end])

    if not neighbors:
        # concat of an empty list fails; return an empty slice that keeps
        # the dtype of column_indices.
        return column_indices[0:0]

    return jt.contrib.concat(neighbors)