# Source code for jittor_geometric.utils.neighbor_sampler

import numpy as np
from typing import Optional, Tuple
import jittor as jt
import copy

def neighbor_sampler(
    neighbor_list: jt.Var,
    neighbor_offs: jt.Var,
    neighbor_nums: jt.Var,
    source_node: jt.Var,
    max_nodes: Optional[int] = None,
    max_edges: Optional[int] = None,
) -> jt.Var:
    r"""Sample one neighbor for every source node of a CSR-style graph.

    The graph is described by ``neighbor_list``, ``neighbor_offs`` and
    ``neighbor_nums``: the neighbors of node ``i`` occupy columns
    ``neighbor_offs[i]`` up to (but excluding) ``neighbor_offs[i + 1]`` of
    ``neighbor_list``.

    Args:
        neighbor_list (jittor.Var): Edge array, indexed as
            ``neighbor_list[1, idx]``. It must therefore be (at least) 2-D
            with edges along the last axis, e.g. an ``edge_index`` of shape
            ``[2, E]`` whose columns are grouped by source node. Row 1
            supplies the sampled neighbor ids.
        neighbor_offs (jittor.Var): For each node ``i``, the column of its
            first neighbor in ``neighbor_list``. ``neighbor_offs[max_nodes]``
            should equal the number of edges.
        neighbor_nums (jittor.Var): For each node ``i``, its number of
            neighbors, ``neighbor_offs[i + 1] - neighbor_offs[i]``. Entries
            equal to 0 must be replaced beforehand (usually by
            ``max_edges``), because the value is used as a modulus.
        source_node (jittor.Var): Ids of the nodes to sample from.
        max_nodes (int, optional): Total number of nodes (expected to be
            ``len(neighbor_offs) - 1`` and ``len(neighbor_nums)``). Not used
            by the sampler itself; accepted for API compatibility.
        max_edges (int, optional): Total number of edges. Defaults to the
            size of the last axis of ``neighbor_list``.

    Returns:
        jittor.Var: One sampled neighbor id per entry of ``source_node``.
    """
    if max_edges is None:
        # Edges live along the last axis (see ``neighbor_list[1, idx]``
        # below). ``size(0)`` would be the number of rows (2 for an
        # edge_index), not the number of edges.
        max_edges = neighbor_list.shape[-1]
    # Draw a random integer per source node, fold it into that node's
    # degree, then shift it into the node's slice of ``neighbor_list``.
    idx = jt.randint_like(source_node, 0, max_edges)
    # The trailing ``% max_edges`` keeps the index in range, e.g. for
    # zero-degree nodes whose count was replaced by ``max_edges``.
    idx = (idx % neighbor_nums[source_node] + neighbor_offs[source_node]) % max_edges
    return neighbor_list[1, idx]
def randomwalk_sampler(
    neighbor_list: jt.Var,
    neighbor_offs: jt.Var,
    neighbor_nums: jt.Var,
    source_node: jt.Var,
    walk_length: int,
    max_nodes: Optional[int] = None,
    max_edges: Optional[int] = None,
) -> jt.Var:
    r"""Sample a random walk of ``walk_length`` steps from every source node.

    Each step draws one neighbor per walker with :func:`neighbor_sampler`.
    The graph is described by ``neighbor_list``, ``neighbor_offs`` and
    ``neighbor_nums`` exactly as for :func:`neighbor_sampler`.

    Args:
        neighbor_list (jittor.Var): Edge array with edges along the last
            axis (e.g. an ``edge_index`` of shape ``[2, E]`` whose columns
            are grouped by source node). Row 1 supplies neighbor ids.
        neighbor_offs (jittor.Var): For each node ``i``, the column of its
            first neighbor in ``neighbor_list``. ``neighbor_offs[max_nodes]``
            should equal the number of edges.
        neighbor_nums (jittor.Var): For each node ``i``, its number of
            neighbors, ``neighbor_offs[i + 1] - neighbor_offs[i]``. Entries
            equal to 0 must be replaced beforehand (usually by
            ``max_edges``).
        source_node (jittor.Var): 1-D ids of the walks' starting nodes.
        walk_length (int): Number of steps per walk; must be non-negative.
        max_nodes (int, optional): Total number of nodes. Defaults to
            ``neighbor_nums.size(0)``.
        max_edges (int, optional): Total number of edges. Defaults to the
            size of the last axis of ``neighbor_list``.

    Returns:
        jittor.Var: ``int`` tensor of shape
        ``[source_node.size(0), walk_length + 1]``. Column 0 holds the
        source nodes and column ``k`` holds the node reached after ``k``
        steps.

    Raises:
        ValueError: If ``walk_length`` is negative.
    """
    if walk_length < 0:
        raise ValueError(f"walk_length must be non-negative, got {walk_length}")
    if max_nodes is None:
        max_nodes = neighbor_nums.size(0)
    if max_edges is None:
        # Edges are indexed along the last axis (``neighbor_list[1, idx]``);
        # ``size(0)`` would be the row count (2 for an edge_index).
        max_edges = neighbor_list.shape[-1]
    walks = jt.zeros([source_node.size(0), walk_length + 1], dtype='int')
    walks[:, 0] = source_node
    # Vars are only rebound here, never mutated in place, so no copies are
    # needed between steps.
    current = source_node
    for step in range(1, walk_length + 1):
        current = neighbor_sampler(
            neighbor_list, neighbor_offs, neighbor_nums, current, max_nodes, max_edges
        )
        walks[:, step] = current
    return walks