Source code for jittor_geometric.dataloader.neighbor_loader


from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np

import jittor as jt
from jittor import sparse
import copy
from .general_loader import GeneralLoader
from ..utils import neighbor_sampler

class NeighborLoader(GeneralLoader):
    r'''
    The graph dataset loader, samples the neighbors of source nodes for graph
    convolution iteratively and randomly.

    The loader yields every (sampled neighbor, source node) pair as a graph
    edge. Each yielded data block carries:

    * ``node_map``: the sorted, de-duplicated original ids of all nodes that
      appear in the sampled edges;
    * ``edge_index``: the sampled edges, re-indexed into ``node_map``
      positions (row 0 = sampled neighbor, row 1 = source node);
    * ``central_nodes``: the mini-batch source nodes, re-indexed the same way;
    * every other attribute of the original graph, where ``jt.Var``
      attributes whose first dimension equals the node count are sliced by
      ``node_map`` and those whose first dimension equals the edge count are
      set to ``None``.

    Args:
        dataset (InMemoryDataset): The original graph dataset. Only
            ``dataset[0]`` is used.
        source_node (Var): The list of source nodes to draw mini-batches from.
        num_neighbors (List[int]): Number of sampled neighbors per layer.
            For example, ``[2, 3]`` samples 2 neighbors of every source node
            and then 3 neighbors of each sampled one.
        batch_size (int): Size of each mini-batch.
        fixed (bool, optional): If set to ``True``, the loader yields
            identical mini-batches every round. Otherwise the source nodes
            are reshuffled after each full pass. (default: ``False``)

    Raises:
        ValueError: If ``source_node`` is missing or ``batch_size`` is
            missing or not positive.
    '''

    def __init__(
        self,
        dataset,
        source_node: jt.Var = None,
        num_neighbors: List[int] = None,
        batch_size: int = None,
        fixed: bool = False
    ):
        if source_node is None:
            raise ValueError("NeighborLoader requires `source_node`.")
        if batch_size is None or batch_size <= 0:
            raise ValueError("NeighborLoader requires a positive `batch_size`.")

        # Shallow copies: the loader keeps the edge list privately and strips
        # it from its own copy of the graph, leaving the dataset untouched.
        self.data = copy.copy(dataset[0])
        self.N = self.data.num_nodes
        self.E = self.data.num_edges
        self.edge_index = copy.copy(self.data.edge_index)
        self.source_node = source_node
        self.data.edge_index = None
        self.num_neighbors = num_neighbors
        self.batch_size = batch_size
        self.fixed = fixed
        self.itercnt = 0
        # Integer ceil-division; yields 0 batches for an empty `source_node`.
        self.itermax = (source_node.size(0) - 1) // batch_size + 1
        self.n_ids = self.get_node_indices()

        # Build per-node offsets into the edge list.
        # NOTE(review): this only yields contiguous per-node ranges if
        # `edge_index` is grouped by source node (row 0) - confirm upstream.
        #
        # Nodes that actually occur as a source are tracked explicitly. A
        # stored offset of 0 cannot double as the "no edges" marker because
        # it is also the legitimate offset of the node owning edge 0.
        nodes_with_edges = set(self.edge_index[0].numpy().tolist())

        self.row_offsets = jt.zeros((self.N + 1,), dtype='int')
        # Walking backwards leaves the FIRST edge position of each source.
        for i in range(self.E - 1, -1, -1):
            self.row_offsets[self.edge_index[0, i]] = i
        self.row_offsets[-1] = self.E
        # Nodes without outgoing edges inherit the next node's offset, so
        # their range comes out as zero.
        for i in range(self.N - 1, -1, -1):
            if i not in nodes_with_edges:
                self.row_offsets[i] = self.row_offsets[i + 1]

        self.row_range = self.row_offsets[1:] - self.row_offsets[:-1]
        # A zero range is replaced by the full edge count, presumably so the
        # sampler never takes a modulo by zero for isolated nodes.
        for i in range(self.N):
            if self.row_range[i] == 0:
                self.row_range[i] = self.E

    def get_node_indices(self):
        r'''
        Shuffle the source nodes and split them into mini-batches.

        Returns:
            list: ``self.itermax`` entries, each holding the source nodes of
            one mini-batch (the last batch may be smaller).
        '''
        total = self.source_node.size(0)
        order = np.random.permutation(total)
        n_ids = []
        for batch in range(self.itermax):
            start = batch * self.batch_size
            stop = min(start + self.batch_size, total)
            n_ids.append(self.source_node[order[start:stop]])
        return n_ids

    def neighbor_sampling(self, source_nodes):
        r'''
        Sample neighbors layer by layer, starting from :obj:`source_nodes`.

        Args:
            source_nodes (Var): The nodes to start sampling from.

        Returns:
            Var: A ``(2, num_sampled_edges)`` integer array; row 0 holds the
            sampled neighbors and row 1 the nodes they were sampled for.
        '''
        results = jt.empty((2, 0), dtype=int)
        for fanout in self.num_neighbors:
            # One slot per requested neighbor of every current source node.
            source_nodes = jt.repeat_interleave(source_nodes, fanout)
            # NOTE(review): `neighbor_sampler` is defined elsewhere; it
            # presumably picks, per source node, a random edge inside
            # [row_offsets, row_offsets + row_range) and returns its target.
            target_nodes = neighbor_sampler(
                self.edge_index, self.row_offsets, self.row_range,
                source_nodes, self.N, self.E)
            edges = jt.stack([target_nodes, source_nodes])

            # Append this layer's edges to the accumulated result.
            known = results.shape[1]
            new_results = jt.empty((2, known + edges.shape[1]), dtype=int)
            if known != 0:
                new_results[:, :known] = results
            new_results[:, known:] = edges
            results = new_results

            # The freshly sampled neighbors seed the next layer.
            source_nodes = target_nodes
        return results

    def __reset__(self):
        # Rewind the batch counter; reshuffle unless batches are pinned.
        self.itercnt = 0
        if not self.fixed:
            self.n_ids = self.get_node_indices()

    def __iter__(self):
        return self

    def unique(self, x):
        r'''
        Drop consecutive duplicates from :obj:`x`.

        Args:
            x (Var): A non-empty 1-D array; equal values must be adjacent
                (e.g. sorted) for the result to be fully de-duplicated.

        Returns:
            Var: The values of :obj:`x` with runs of equal values collapsed.
        '''
        previous = x[0].item()
        res = [previous]
        for i in range(1, x.shape[0]):
            current = x[i].item()
            if current != previous:
                res.append(current)
            previous = current
        return jt.Var(res)

    def __next__(self):
        if self.itercnt >= self.itermax:
            # End of the pass: prepare the next one before signalling stop.
            self.__reset__()
            raise StopIteration

        central_nodes = self.n_ids[self.itercnt]
        edges = self.neighbor_sampling(central_nodes)
        # Sorted, de-duplicated original ids of every node in the subgraph.
        node_id = self.unique(jt.sort(jt.flatten(edges))[0])

        data = self.data.__class__()
        # Lookup table: original node id -> position inside `node_id`.
        node_map = jt.zeros(self.N, dtype='int')
        node_map[node_id] = jt.arange(0, node_id.size(0))
        data.node_map = node_id
        data.edge_index = node_map[edges]
        data.central_nodes = node_map[central_nodes]

        for key, item in self.data:
            if key == 'num_nodes':
                data[key] = node_id.size(0)
            elif isinstance(item, jt.Var) and item.size(0) == self.N:
                # Node-level attribute: keep only the sampled nodes.
                data[key] = item[node_id]
            elif isinstance(item, jt.Var) and item.size(0) == self.E:
                # Edge-level attribute: not carried over to the subgraph.
                data[key] = None
            else:
                data[key] = item

        self.itercnt += 1
        return data