# Source code for jittor_geometric.nn.dense.time_encoder

import jittor as jt
import jittor.nn as nn
import numpy as np

class TimeEncoder(nn.Module):
    """Cosine time encoder mapping each scalar timestamp to ``time_dim`` features.

    A timestamp ``t`` is encoded as ``cos(W t + b)``, where ``W`` and ``b`` are
    the weight and bias of an ``nn.Linear(1, time_dim)`` layer.
    """

    def __init__(self, time_dim: int, parameter_requires_grad: bool = True):
        """
        Time encoder.
        :param time_dim: int, dimension of time encodings
        :param parameter_requires_grad: boolean, whether the parameter in TimeEncoder needs gradient
        """
        super().__init__()

        self.time_dim = time_dim
        # trainable parameters for time encoding
        self.w = nn.Linear(1, time_dim)
        # Deterministic initialization: one frequency per output channel,
        # decaying geometrically from 1 (10**-0) down to 1e-9 (10**-9).
        # Reshaped to (time_dim, 1) to serve as the weight of Linear(1, time_dim).
        freqs = 1 / 10 ** np.linspace(0, 9, time_dim, dtype=np.float32)
        self.w.weight = nn.Parameter(jt.Var(freqs).reshape(time_dim, -1))
        # Zero phase offset for every channel.
        self.w.bias = nn.Parameter(jt.zeros(time_dim))

        if not parameter_requires_grad:
            # Freeze the encoder: both weight and bias stay at their init values.
            self.w.weight.requires_grad = False
            self.w.bias.requires_grad = False

    def execute(self, timestamps: jt.Var):
        """
        Compute time encodings of the times in ``timestamps``.
        :param timestamps: Var of timestamps, typically shape (batch_size, seq_len);
            the encoding is applied element-wise, so other leading shapes
            (e.g. a flat (batch_size,) vector) are accepted as well
        :return: Var, shape (*timestamps.shape, time_dim), holding cos(W t + b)
            for every timestamp t
        """
        # Append a trailing size-1 feature axis so Linear(1, time_dim) maps each
        # timestamp independently: (..., ) -> (..., 1). Using -1 instead of a
        # fixed axis gives the same result for (batch_size, seq_len) input while
        # not assuming a particular input rank.
        timestamps = timestamps.unsqueeze(dim=-1)
        # Var, shape (..., time_dim)
        output = jt.cos(self.w(timestamps))
        return output