"""Source code for ``jittor_geometric.utils.one_hot``: one-hot encoding of index vars."""

from typing import Optional

import jittor as jt
from jittor import Var


def one_hot(
    index: Var,
    num_classes: Optional[int] = None,
    dtype=None,
) -> Var:
    r"""Takes a one-dimensional :obj:`index` var and returns a one-hot
    encoded representation of it with shape
    :obj:`[index.size(0), num_classes]`. The result is zero everywhere except
    at column :obj:`index[i]` of row :obj:`i`, where it is :obj:`1`.

    Args:
        index (jittor.Var): The one-dimensional input var of class indices.
        num_classes (int, optional): The total number of classes. If set to
            :obj:`None`, the number of classes will be inferred as one greater
            than the largest class value in the input var, or :obj:`0` if
            :obj:`index` is empty. (default: :obj:`None`)
        dtype (jittor.dtype or str, optional): The :obj:`dtype` of the output
            var. If set to :obj:`None`, the default dtype of
            :meth:`jittor.zeros` is used. (default: :obj:`None`)

    Returns:
        jittor.Var: The one-hot matrix of shape
        :obj:`[index.size(0), num_classes]`.

    Raises:
        ValueError: If :obj:`index` is not one-dimensional.
    """
    if index.dim() != 1:
        raise ValueError("'index' var needs to be one-dimensional")

    num_items = index.size(0)

    if num_classes is None:
        # ``max()`` of an empty var has no meaningful value, so an empty
        # index is treated as having zero classes.
        num_classes = int(index.max()) + 1 if num_items > 0 else 0

    shape = (num_items, num_classes)
    # Only forward ``dtype`` when given, so the default path is unchanged.
    out = jt.zeros(shape) if dtype is None else jt.zeros(shape, dtype=dtype)

    if num_items == 0:
        # Nothing to scatter; avoid calling scatter on an empty var.
        return out

    # Write a 1 at column ``index[i]`` of each row ``i``.
    return out.scatter_(1, index.unsqueeze(1), jt.Var([1]))