from typing import Optional

import torch
from torch import Tensor


def one_hot(
    index: Tensor,
    num_classes: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Takes a one-dimensional :obj:`index` tensor and returns a one-hot
    encoded representation of it with shape
    :obj:`[index.size(0), num_classes]` that has zeros everywhere except
    where the index of the last dimension matches the corresponding value of
    the input tensor, in which case it will be :obj:`1`.

    .. note::
        This is a more memory-efficient version of
        :meth:`torch.nn.functional.one_hot` as you can customize the output
        :obj:`dtype`.

    Args:
        index (torch.Tensor): The one-dimensional input tensor.
        num_classes (int, optional): The total number of classes. If set to
            :obj:`None`, the number of classes will be inferred as one
            greater than the largest class value in the input tensor, or as
            :obj:`0` if the input tensor is empty. (default: :obj:`None`)
        dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor.
            If set to :obj:`None`, the default of :meth:`torch.zeros` is
            used. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`index` is not one-dimensional.
    """
    if index.dim() != 1:
        raise ValueError("'index' tensor needs to be one-dimensional")

    if num_classes is None:
        # `max()` is undefined on an empty tensor and would raise an opaque
        # `RuntimeError`; an empty index simply references no classes.
        num_classes = int(index.max()) + 1 if index.numel() > 0 else 0

    # Allocate directly in the requested `dtype` on the device of `index`
    # rather than building an int64 matrix first and casting afterwards.
    out = torch.zeros(
        (index.size(0), num_classes),
        dtype=dtype,
        device=index.device,
    )

    # Write a single `1` per row at the column given by `index`.
    return out.scatter_(1, index.unsqueeze(1), 1)