Source code for jittor_geometric.data.temporal

import copy
import jittor
import numpy as np


class TemporalData(object):
    """Attribute container for a sequence of events stored as parallel arrays.

    The constructor stores ``src``, ``dst``, ``t``, ``msg``, ``y`` and
    ``edge_ids`` plus any extra keyword arguments as instance attributes.
    Attributes whose value is ``None`` are treated as absent by :attr:`keys`,
    iteration, ``len`` and ``in``.

    The number of events is the leading dimension of ``src``.  Indexing with
    anything other than a string returns a shallow copy in which every
    attribute whose leading dimension equals the number of events has been
    indexed the same way; all other attributes (scalars, strings, arrays of
    another length, ...) are carried over unchanged.

    NOTE(review): the field names presumably mean source node ids,
    destination node ids, timestamps, messages and labels, mirroring the
    PyTorch Geometric class of the same name -- confirm against callers.
    """

    def __init__(self, src=None, dst=None, t=None, msg=None, y=None,
                 edge_ids=None, **kwargs):
        """Store the given values as attributes; ``kwargs`` become extra ones."""
        self.src = src
        self.dst = dst
        self.t = t
        self.msg = msg
        self.y = y
        self.edge_ids = edge_ids
        for key, item in kwargs.items():
            self[key] = item

    @staticmethod
    def _leading_dim(item):
        """Return ``item.shape[0]``, or ``None`` if ``item`` has no such thing.

        Covers values without a ``shape`` attribute (str, int, list, ...) as
        well as 0-dimensional values whose ``shape`` is empty.
        """
        try:
            return item.shape[0]
        except (AttributeError, IndexError, TypeError):
            return None

    def __getitem__(self, idx):
        """Return one attribute (string key) or an indexed copy of the data.

        Args:
            idx: A string returns the attribute of that name (``None`` if it
                does not exist).  An int, list or tuple is wrapped in a
                ``jittor.Var``; slices, NumPy arrays and int32/int64/bool
                ``jittor.Var`` objects are used as given.

        Returns:
            The attribute value for string keys, otherwise a shallow copy of
            this object with every per-event attribute indexed by ``idx``.

        Raises:
            IndexError: If ``idx`` is of an unsupported type or dtype.
        """
        if isinstance(idx, str):
            return getattr(self, idx, None)

        if isinstance(idx, int):
            idx = jittor.Var([idx])
        elif isinstance(idx, (list, tuple)):
            idx = jittor.Var(idx)
        elif isinstance(idx, (slice, np.ndarray)):
            pass
        elif (isinstance(idx, jittor.Var)
              and idx.dtype in (jittor.int32, jittor.int64, jittor.bool)):
            # int64 is accepted alongside int32 so that "long" Vars, which
            # the error message below advertises, really are valid indices.
            pass
        else:
            raise IndexError(
                f'Only strings, integers, slices (`:`), list, tuples, and '
                f'long or bool Vars are valid indices (got '
                f'{type(idx).__name__}).')

        # Always compare against the size of the un-indexed data.
        num_events = self.num_events
        data = copy.copy(self)
        for key, item in data:
            # Only per-event attributes are indexed; anything without a
            # matching leading dimension is kept untouched instead of
            # raising on a missing or empty ``shape``.
            if self._leading_dim(item) == num_events:
                data[key] = item[idx]
        return data

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        """Names of all instance attributes whose value is not ``None``."""
        return [key for key in self.__dict__.keys() if self[key] is not None]

    def __len__(self):
        """Number of present (non-``None``) attributes."""
        return len(self.keys)

    def __contains__(self, key):
        """Whether ``key`` names a present (non-``None``) attribute."""
        return key in self.keys

    def __iter__(self):
        """Yield ``(key, value)`` pairs of present attributes, sorted by key."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        """Yield ``(key, value)`` pairs restricted to ``keys``.

        Without arguments all present attributes are yielded in sorted order.
        With arguments the given order is kept and absent keys are skipped.
        """
        for key in sorted(self.keys) if not keys else keys:
            if key in self:
                yield key, self[key]

    @property
    def num_nodes(self):
        """Largest value found in ``src`` or ``dst``, plus one."""
        return max(int(self.src.max()), int(self.dst.max())) + 1

    @property
    def num_events(self):
        """Leading dimension of ``src``."""
        # ``shape[0]`` is available on NumPy arrays and jittor Vars alike, so
        # no type switch is needed.
        return self.src.shape[0]

    @property
    def num_edges(self):
        """Alias of :attr:`num_events`."""
        return self.num_events

    def __apply__(self, item, func):
        """Apply ``func`` to every jittor Var found in ``item``.

        Tuples, lists and dicts are traversed recursively; tuples come back
        as lists.  Any other value is returned unchanged.
        """
        if jittor.is_Var(item):
            return func(item)
        elif isinstance(item, (tuple, list)):
            return [self.__apply__(v, func) for v in item]
        elif isinstance(item, dict):
            return {k: self.__apply__(v, func) for k, v in item.items()}
        else:
            return item

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all Var attributes
        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
        all present attributes.

        The attributes are replaced in place and ``self`` is returned.
        """
        for key, item in self(*keys):
            self[key] = self.__apply__(item, func)
        return self

    def to(self, device, *keys, **kwargs):
        """Call ``x.to(device, **kwargs)`` on the selected Var attributes.

        Selection and in-place replacement follow :meth:`apply`.
        """
        return self.apply(lambda x: x.to(device, **kwargs), *keys)

    def train_val_test_split(self, val_ratio=0.15, test_ratio=0.15):
        """Split the events into three consecutive positional slices.

        The two boundaries are derived from the ``1 - val_ratio - test_ratio``
        and ``1 - test_ratio`` quantiles of ``t``.

        NOTE(review): the slices are positional, so they only correspond to
        time ranges if the events are ordered by ``t`` -- presumably
        guaranteed by the caller.

        Returns:
            A ``(train, val, test)`` tuple of indexed copies.
        """
        t = self.t
        # ``t`` may already be a NumPy array, in line with the NumPy handling
        # elsewhere in this class; only Vars need converting.
        t_np = t if isinstance(t, np.ndarray) else t.cpu().numpy()
        val_time, test_time = np.quantile(
            t_np, [1. - val_ratio - test_ratio, 1. - test_ratio])

        # NOTE(review): the ``+ 1`` moves each boundary one event past the
        # count of events at or before the quantile -- kept as found,
        # confirm that this is intended.
        val_idx = int((t <= val_time).sum()) + 1
        test_idx = int((t <= test_time).sum()) + 1

        return self[:val_idx], self[val_idx:test_idx], self[test_idx:]

    def train_val_test_split_w_mask(self):
        """Split using the ``train_mask``, ``val_mask`` and ``test_mask``
        attributes as indices.

        Returns:
            A ``(train, val, test)`` tuple of indexed copies.
        """
        return (self[(self.train_mask)], self[(self.val_mask)],
                self[(self.test_mask)])

    def seq_batches(self, batch_size):
        """Yield consecutive slices holding at most ``batch_size`` events."""
        for start in range(0, self.num_events, batch_size):
            yield self[start:start + batch_size]

    def __cat_dim__(self, key, value):
        """Return 0 for every attribute.

        NOTE(review): presumably the dimension along which attributes are
        concatenated when batching -- no caller is visible here.
        """
        return 0

    def __inc__(self, key, value):
        """Return 0 for every attribute.

        NOTE(review): presumably the per-sample increment applied when
        batching -- no caller is visible here.
        """
        return 0

    def __repr__(self):
        """Return ``ClassName(key=[shape], ...)`` over the present attributes.

        Attributes without a ``shape`` are shown by their ``repr`` instead.
        """
        cls = str(self.__class__.__name__)
        parts = []
        for key, value in self:
            shape = getattr(value, 'shape', None)
            if shape is not None:
                parts.append(f'{key}={list(shape)}')
            else:
                parts.append(f'{key}={value!r}')
        shapes = ', '.join(parts)
        return f'{cls}({shapes})'