# Source code for jittor_geometric.dataloader.cluster_loader

import jittor as jt
from jittor import sparse
import copy
from .general_loader import GeneralLoader
from ..utils import induced_graph
from jittor_geometric.partition.chunk_manager import ChunkManager
import numpy as np


class ClusterLoader(GeneralLoader):
    r"""Graph dataset loader based on METIS clustering (Cluster-GCN style).

    METIS partitions the graph into ``mini_splits`` clusters, which are
    randomly merged into ``num_parts`` mini-batches for training. Iterating
    the loader yields, for each mini-batch, the subgraph induced by the
    batch's node set.

    Args:
        dataset (InMemoryDataset): The original graph dataset.
        num_parts (int): Number of mini-batches produced per epoch.
        mini_splits (int): Number of clusters METIS partitions the graph into.
        fixed (bool, optional): If set to ``True``, the loader yields
            identical mini-batches every epoch (the cluster-to-batch
            assignment is not reshuffled). Defaults to ``False``.
    """

    def __init__(self, dataset, num_parts: int, mini_splits: int, fixed: bool = False):
        self.data = copy.copy(dataset[0])
        self.N = self.data.num_nodes
        self.E = self.data.num_edges
        # Keep the full edge_index separately; each yielded batch gets a
        # re-indexed copy, so the cached template data drops it.
        self.edge_index = copy.copy(self.data.edge_index)
        self.data.edge_index = None
        self.num_parts = num_parts
        self.mini_splits = mini_splits
        self.fixed = fixed
        self.itermax = num_parts
        # parts[i] holds the node ids assigned to METIS cluster i.
        self.parts = self.partition(self.edge_index, self.N, self.mini_splits)
        self.itercnt = 0
        # n_splits[b] holds the cluster indices merged into mini-batch b.
        self.n_splits = self.get_node_indices()

    def partition(self, edge_index, num_nodes, mini_splits):
        """Partition the graph with METIS and group node ids by cluster.

        Returns:
            list[jt.Var]: ``mini_splits`` tensors; entry ``i`` contains the
            node indices of cluster ``i`` (possibly empty).
        """
        chunk_manager = ChunkManager(output_dir=None)
        partition = chunk_manager.metis_partition(edge_index, num_nodes, mini_splits)
        # sort() yields (sorted cluster labels, permutation of node ids), so
        # nodes of the same cluster occupy one contiguous run.
        partition = jt.Var(partition)
        partition = partition.sort()
        parts = []
        part_begin = 0
        for _ in range(self.mini_splits):
            # Advance part_end to the end of the current run of equal labels.
            part_end = part_begin + 1
            while part_end <= num_nodes - 1 and partition[0][part_end] == partition[0][part_end - 1]:
                part_end += 1
            if part_begin >= part_end:
                parts.append(jt.zeros(0, dtype='int'))
            elif part_end >= num_nodes:
                parts.append(partition[1][part_begin:])
            else:
                parts.append(partition[1][part_begin: part_end])
            part_begin = part_end
        return parts

    def get_node_indices(self):
        """Randomly assign the METIS clusters to ``num_parts`` mini-batches.

        Returns:
            list: ``num_parts`` index tensors; entry ``b`` contains the
            cluster indices merged into mini-batch ``b``.
        """
        # Random permutation modulo num_parts gives each cluster a batch id.
        n_id = np.random.permutation(self.mini_splits) % self.num_parts
        return [jt.nonzero(n_id == i).view(-1) for i in range(self.num_parts)]

    def __reset__(self):
        self.itercnt = 0
        # Reshuffle the cluster-to-batch assignment unless batches are fixed.
        # NOTE: must write self.n_splits — __next__ reads it; the previous
        # code wrote self.n_ids, so reshuffling silently had no effect.
        if not self.fixed:
            self.n_splits = self.get_node_indices()

    def __iter__(self):
        return self

    def __next__(self):
        if self.itercnt >= self.itermax:
            self.__reset__()
            raise StopIteration
        # Concatenate the node ids of every cluster assigned to this batch.
        node_id = jt.zeros(0, dtype='int')
        for i in self.n_splits[self.itercnt]:
            node_id = jt.concat([node_id, self.parts[i]])
        # node_map maps global node ids to batch-local ids; edge_id selects
        # the edges whose endpoints are both inside the batch.
        node_map, edge_id = induced_graph(self.edge_index, node_id, self.N)
        data = self.data.__class__()
        data.edge_index = node_map[self.edge_index[:, edge_id]]
        # Slice node-level (first dim == N) and edge-level (first dim == E)
        # attributes down to the batch; copy everything else verbatim.
        for key, item in self.data:
            if key in ['num_nodes']:
                data[key] = node_id.size(0)
            elif isinstance(item, jt.Var) and item.size(0) == self.N:
                data[key] = item[node_id]
            elif isinstance(item, jt.Var) and item.size(0) == self.E:
                data[key] = item[edge_id]
            else:
                data[key] = item
        self.itercnt += 1
        return data