import jittor as jt
import jittor.nn as nn
import numpy as np
class TimeEncoder(nn.Module):
    """
    Encode timestamps as cosine features: ``cos(t * w + b)``.

    ``w`` holds ``time_dim`` frequencies initialised to ``1 / 10 ** linspace(0, 9, time_dim)``
    (geometrically spaced from 1 down to 1e-9) and ``b`` holds ``time_dim`` phases
    initialised to zero.
    """

    def __init__(self, time_dim: int, parameter_requires_grad: bool = True):
        """
        Time encoder.
        :param time_dim: int, dimension of time encodings
        :param parameter_requires_grad: boolean, whether the parameter in TimeEncoder needs gradient
        """
        super(TimeEncoder, self).__init__()
        self.time_dim = time_dim
        # trainable parameters for time encoding: a single input feature (the timestamp)
        # projected onto time_dim frequencies
        self.w = nn.Linear(1, time_dim)
        # In Jittor every Var member of a Module is already a parameter, so the initial
        # values are assigned directly; nn.Parameter is only a compatibility shim that
        # clones its argument and logs a warning on every call.
        frequencies = 1 / 10 ** np.linspace(0, 9, time_dim, dtype=np.float32)
        # Linear weights are laid out as (out_features, in_features) = (time_dim, 1).
        self.w.weight = jt.array(frequencies).reshape(time_dim, -1)
        self.w.bias = jt.zeros(time_dim)
        if not parameter_requires_grad:
            # freeze the encoding so optimizers leave it at its initial values
            self.w.weight.requires_grad = False
            self.w.bias.requires_grad = False

    def execute(self, timestamps: jt.Var) -> jt.Var:
        """
        Compute time encodings of the values in ``timestamps``.
        :param timestamps: Var, shape (batch_size, seq_len); the feature axis is appended
            last, so other leading shapes (e.g. (batch_size,)) are intended to work as well
        :return: Var, shape (batch_size, seq_len, time_dim), i.e. timestamps.shape + (time_dim,),
            holding cos(timestamp * w + b) for every element
        """
        # Var, shape (batch_size, seq_len, 1): append a size-1 feature axis for the Linear
        timestamps = timestamps.unsqueeze(dim=-1)
        # Var, shape (batch_size, seq_len, time_dim)
        output = jt.cos(self.w(timestamps))
        return output