# Source code for jittor_geometric.utils.induced_graph

import jittor
from typing import Tuple

def induced_graph(
    edge_index: jittor.Var,
    node_selected: jittor.Var,
    max_nodes: int,
) -> Tuple[jittor.Var, jittor.Var]:
    r"""Compute the node relabelling and edge selection of a node-induced subgraph.

    The subgraph is induced by ``node_selected`` on the graph described by
    ``edge_index``: an edge is kept only when both of its endpoints are
    selected.

    Args:
        edge_index (jittor.Var): The edge list describing the whole graph.
            It is indexed as ``edge_index[0]`` (one endpoint per edge) and
            ``edge_index[1]`` (the other endpoint per edge).
        node_selected (jittor.Var): The node indices that make up the
            induced graph.
        max_nodes (int): Length of the internal per-node lookup tables.
            Because the tables are indexed directly by node index, this
            must be strictly greater than every node index appearing in
            ``edge_index`` and ``node_selected`` (i.e. the node count for
            0-based indices); a smaller value leads to out-of-bounds
            indexing.

    Returns:
        Tuple[jittor.Var, jittor.Var]: ``(node_map, edge_selected)`` where

        * ``node_map`` has length ``max_nodes`` and holds, for each selected
          node, its position inside ``node_selected``. Nodes that were not
          selected keep the fill value ``0``, which is indistinguishable
          from the first selected node, so only look up selected nodes.
        * ``edge_selected`` is a flat vector of positions into the edge
          dimension of ``edge_index`` for the edges whose two endpoints are
          both selected.
    """
    # Boolean membership table: True at every selected node index.
    node_mask = jittor.zeros(max_nodes, dtype="bool")
    node_mask[node_selected] = True

    # Old node index -> new (compact) index, following node_selected's order.
    node_map = jittor.zeros(max_nodes, dtype="int")
    node_map[node_selected] = jittor.arange(0, node_selected.size(0))

    # Keep an edge only if both of its endpoints are selected.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    # nonzero yields one row per hit; flatten to a 1-D list of edge positions.
    edge_selected = jittor.nonzero(edge_mask).view(-1)

    return node_map, edge_selected