# Source code for jittor_geometric.data.batch

import jittor as jt
import numpy as np
from typing import Any, List, Optional, Type, Union
from collections.abc import Sequence
from jittor_geometric.data import Data


class Batch(Data):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.

    Inherits from :class:`jittor_geometric.data.Data`. Besides the merged
    graph attributes, a batch built by :meth:`from_data_list` carries an
    integer :obj:`batch` vector that maps every node to the index of the
    graph it came from.
    """

    @staticmethod
    def _cat_dim(key: str) -> int:
        r"""Returns the dimension along which tensor attribute :obj:`key` is
        concatenated/sliced: ``1`` for :obj:`edge_index` (one column per
        edge), ``0`` for every other tensor attribute."""
        return 1 if key == 'edge_index' else 0

    @classmethod
    def from_data_list(
        cls,
        data_list: List[Data],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> 'Batch':
        r"""Constructs a :class:`~jittor_geometric.data.Batch` object from a
        list of :class:`~jittor_geometric.data.Data` objects.

        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list: Graphs to merge. The keys of the first graph decide
                which attributes are merged; every graph must provide them.
            follow_batch: Keys for which an extra ``<key>_batch`` int32
                assignment vector (one entry per row, or per edge for
                :obj:`edge_index`) is created.
            exclude_keys: Keys that are not copied into the batch.

        Returns:
            The merged batch. Tensor attributes are concatenated
            (:obj:`edge_index` along dim 1 with each graph's node indices
            shifted by the number of nodes in the preceding graphs, all
            others along dim 0). Python ``int``/``float`` attributes are
            collected into a tensor. Anything else is kept as a list.
        """
        if exclude_keys is None:
            exclude_keys = []

        batch = cls()
        batch._num_graphs = len(data_list)

        # key -> cumulative per-graph sizes (no leading zero): entry i is the
        # end offset of graph i inside the merged attribute.
        slice_dict = {}
        # key -> per-graph value that was added while merging (only
        # edge_index); get_example subtracts it again.
        inc_dict = {}

        for key in data_list[0].keys:
            if key in exclude_keys:
                continue
            items = [data[key] for data in data_list]

            if isinstance(items[0], jt.Var):
                dim = cls._cat_dim(key)
                if key == 'edge_index':
                    # Shift node indices so each graph addresses its own
                    # rows of the merged node matrix.
                    offsets = []
                    cum_nodes = 0
                    for data in data_list:
                        offsets.append(cum_nodes)
                        cum_nodes += data.num_nodes
                    items = [item + offset
                             for item, offset in zip(items, offsets)]
                    inc_dict[key] = jt.array(offsets)
                batch[key] = jt.concat(items, dim=dim)
                # Sizes are taken along the same dim used for concatenation
                # so that get_example can slice the matching span.
                sizes = [item.shape[dim] for item in items]
            else:
                if isinstance(items[0], (int, float)):
                    batch[key] = jt.array(items)
                else:
                    batch[key] = items
                sizes = [1] * len(items)

            slice_dict[key] = jt.cumsum(jt.array(sizes), dim=0)

        # Node -> graph assignment vector. The dtype is explicit because
        # jt.full defaults to float32, which cannot serve as an index.
        batch.batch = jt.concat([
            jt.full((data.num_nodes,), i, dtype='int32')
            for i, data in enumerate(data_list)
        ])

        for key in follow_batch or []:
            if key in exclude_keys:
                continue
            dim = cls._cat_dim(key)
            batch[f'{key}_batch'] = jt.concat([
                jt.full((data[key].shape[dim],), i, dtype='int32')
                for i, data in enumerate(data_list)
            ])

        # Bookkeeping needed by get_example / to_data_list.
        batch._slice_dict = slice_dict
        batch._inc_dict = inc_dict
        return batch

    def get_example(self, idx: int) -> Data:
        r"""Gets the :class:`~jittor_geometric.data.Data` object at index
        :obj:`idx`.
        The :class:`~jittor_geometric.data.Batch` object must have been
        created via :meth:`from_data_list` in order to be able to
        reconstruct the initial object.

        Negative indices count from the end. :obj:`edge_index` is sliced
        along its edge dimension and its node indices are shifted back so
        they are local to the returned graph.

        Raises:
            RuntimeError: If the batch was not created by
                :meth:`from_data_list`.
            IndexError: If :obj:`idx` is out of range.
        """
        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                "Cannot reconstruct 'Data' object from 'Batch' because "
                "'Batch' was not created via 'Batch.from_data_list()'")

        num_graphs = self.num_graphs
        pos = idx + num_graphs if idx < 0 else idx
        if not 0 <= pos < num_graphs:
            raise IndexError(
                f"Graph index {idx} is out of range for a batch of "
                f"{num_graphs} graphs")

        inc_dict = getattr(self, '_inc_dict', None) or {}
        data = Data()
        # Only keys merged by from_data_list have slice information; this
        # also avoids private bookkeeping attributes that may appear in
        # self.keys.
        for key, slices in self._slice_dict.items():
            if key == 'batch' or key.endswith('_batch'):
                continue
            value = self[key]
            if not isinstance(value, jt.Var):
                data[key] = value[pos]
                continue
            start = 0 if pos == 0 else slices[pos - 1].item()
            end = slices[pos].item()
            if self._cat_dim(key) == 1:
                value = value[:, start:end]
            else:
                value = value[start:end]
            if key in inc_dict:
                # Undo the offset added in from_data_list.
                value = value - inc_dict[key][pos].item()
            data[key] = value
        return data

    def to_data_list(self) -> List[Data]:
        r"""Reconstructs the list of :class:`~jittor_geometric.data.Data`
        objects from the :class:`~jittor_geometric.data.Batch` object."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch."""
        if hasattr(self, '_num_graphs'):
            return self._num_graphs
        elif hasattr(self, 'batch'):
            # Graph ids in the assignment vector are 0-based.
            return int(self.batch.max().item()) + 1
        else:
            raise ValueError("Can not infer the number of graphs")

    @property
    def batch_size(self) -> int:
        r"""Alias for :obj:`num_graphs`."""
        return self.num_graphs

    def __len__(self) -> int:
        return self.num_graphs