import copy
import jittor
import numpy as np
class TemporalData(object):
    r"""A container for a stream of timestamped events.

    Each event is described by parallel per-event attributes: ``src`` and
    ``dst`` (node ids, see :attr:`num_nodes`), ``t`` (timestamp), and
    optionally ``msg``, ``y`` and ``edge_ids`` (presumably per-event
    features, labels and ids). Extra keyword arguments are stored as
    additional attributes. Attributes that are ``None`` are treated as absent
    (see :attr:`keys`).

    Attribute values may be :class:`jittor.Var` or :class:`numpy.ndarray`;
    :attr:`num_events` handles both.
    """

    def __init__(self, src=None, dst=None, t=None, msg=None, y=None,
                 edge_ids=None, **kwargs):
        self.src = src
        self.dst = dst
        self.t = t
        self.msg = msg
        self.y = y
        self.edge_ids = edge_ids
        for key, item in kwargs.items():
            self[key] = item

    def __getitem__(self, idx):
        r"""Return an attribute (string key) or a subset of events.

        Args:
            idx: A string returns the attribute of that name (``None`` if it
                does not exist). Any other value selects events: a Python or
                NumPy integer, a list or tuple of indices, a slice, a NumPy
                array, or an int32/int64/bool :class:`jittor.Var`.

        Returns:
            For a string key, the attribute value. Otherwise, a shallow copy
            in which every attribute whose leading dimension equals
            :attr:`num_events` is indexed with ``idx``. All other attributes
            (scalars, strings, arrays of a different length) are carried over
            unchanged.

        Raises:
            IndexError: If ``idx`` is of an unsupported type.
        """
        if isinstance(idx, str):
            return getattr(self, idx, None)

        if isinstance(idx, np.integer):
            idx = int(idx)

        if isinstance(idx, int):
            # Wrap a single position so the event dimension is preserved.
            idx = jittor.Var([idx])
        elif isinstance(idx, (list, tuple)):
            idx = jittor.Var(idx)
        elif isinstance(idx, (slice, np.ndarray)):
            # Checked before jittor.Var so NumPy-only usage never touches jittor.
            pass
        elif not (isinstance(idx, jittor.Var) and idx.dtype in (
                jittor.int32, jittor.int64, jittor.bool)):
            raise IndexError(
                f'Only strings, integers, slices (`:`), list, tuples, and '
                f'long or bool Vars are valid indices (got '
                f'{type(idx).__name__}).')

        num_events = self.num_events
        data = copy.copy(self)
        for key, item in data:
            shape = getattr(item, 'shape', None)
            # Only per-event attributes are indexed; anything without a
            # matching leading dimension is kept as-is.
            if (shape is not None and len(shape) > 0
                    and shape[0] == num_events):
                data[key] = item[idx]
        return data

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        """Names of all attributes that are not ``None``, in insertion order."""
        return [key for key in self.__dict__.keys() if self[key] is not None]

    def __len__(self):
        """Number of present (non-``None``) attributes, not number of events."""
        return len(self.keys)

    def __contains__(self, key):
        """Whether ``key`` names a present (non-``None``) attribute."""
        return key in self.keys

    def __iter__(self):
        """Yield ``(key, value)`` for every present attribute, sorted by key."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        """Yield ``(key, value)`` for the requested ``keys`` that are present.

        If no keys are given, every present attribute is yielded, sorted by
        key.
        """
        for key in sorted(self.keys) if not keys else keys:
            if key in self:
                yield key, self[key]

    @property
    def num_nodes(self):
        """Largest node id appearing in ``src`` or ``dst``, plus one."""
        return max(int(self.src.max()), int(self.dst.max())) + 1

    @property
    def num_events(self):
        """Number of events, i.e. the length of ``src``."""
        if isinstance(self.src, np.ndarray):
            return self.src.shape[0]
        return self.src.size(0)

    @property
    def num_edges(self):
        """Alias of :attr:`num_events`; every event is one edge."""
        return self.num_events

    def __apply__(self, item, func):
        """Recursively apply ``func`` to every jittor Var inside ``item``.

        Lists and tuples are rebuilt as lists, and dicts as dicts. Any other
        value is returned unchanged.
        """
        if isinstance(item, jittor.Var):
            return func(item)
        elif isinstance(item, (tuple, list)):
            return [self.__apply__(v, func) for v in item]
        elif isinstance(item, dict):
            return {k: self.__apply__(v, func) for k, v in item.items()}
        else:
            return item

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all Var attributes
        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
        all present attributes.

        The attributes are updated in place, and ``self`` is returned.
        """
        for key, item in self(*keys):
            self[key] = self.__apply__(item, func)
        return self

    def to(self, device, *keys, **kwargs):
        """Call ``x.to(device, **kwargs)`` on every jittor Var attribute.

        This uses :meth:`apply`, so ``keys`` restricts which attributes are
        touched.

        NOTE(review): assumes the Var type provides ``.to(device, ...)`` —
        confirm for jittor.
        """
        return self.apply(lambda x: x.to(device, **kwargs), *keys)

    def train_val_test_split(self, val_ratio=0.15, test_ratio=0.15):
        r"""Split the events chronologically into train/val/test subsets.

        The cut times are the ``1 - val_ratio - test_ratio`` and
        ``1 - test_ratio`` quantiles of ``t``. Events with
        ``t <= val_time`` go to train, ``val_time < t <= test_time`` go to
        validation, and the rest go to test. The subsets are taken as
        contiguous positional slices, so this assumes the events are already
        sorted by ``t`` in ascending order.

        Args:
            val_ratio (float): Fraction of events used for validation.
            test_ratio (float): Fraction of events used for testing.

        Returns:
            tuple: ``(train_data, val_data, test_data)``.
        """
        t = self.t
        t_np = t if isinstance(t, np.ndarray) else t.numpy()
        val_time, test_time = np.quantile(
            t_np, [1. - val_ratio - test_ratio, 1. - test_ratio])
        # With t sorted ascending, the number of events at or before a cut
        # time is exactly the position of the first event after it; adding
        # one would leak a later-period event into the earlier split.
        val_idx = int((t_np <= val_time).sum())
        test_idx = int((t_np <= test_time).sum())
        return self[:val_idx], self[val_idx:test_idx], self[test_idx:]

    def train_val_test_split_w_mask(self):
        """Split events with the ``train_mask``/``val_mask``/``test_mask`` attributes.

        Each mask may be any event index accepted by :meth:`__getitem__`.
        Raises ``AttributeError`` if a mask attribute was never set.
        """
        return self[self.train_mask], self[self.val_mask], self[self.test_mask]

    def seq_batches(self, batch_size):
        """Yield consecutive, in-order slices of at most ``batch_size`` events."""
        for start in range(0, self.num_events, batch_size):
            yield self[start:start + batch_size]

    def __cat_dim__(self, key, value):
        # Returns 0 for every attribute. Presumably consumed by a batching
        # routine elsewhere as the concatenation dimension — confirm.
        return 0

    def __inc__(self, key, value):
        # Returns 0 for every attribute. Presumably consumed by a batching
        # routine elsewhere as the per-batch index increment — confirm.
        return 0

    def __repr__(self):
        """Show each present attribute's shape, or its value if it has no shape."""
        cls = self.__class__.__name__
        parts = []
        for key, value in self:
            if hasattr(value, 'shape'):
                parts.append(f'{key}={list(value.shape)}')
            else:
                parts.append(f'{key}={value}')
        return f'{cls}({", ".join(parts)})'