GED / one_hot.py
ubuntu
Initial Commit
6b33608
from typing import Optional
import torch
from torch import Tensor
def one_hot(
index: Tensor,
num_classes: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
r"""Taskes a one-dimensional :obj:`index` tensor and returns a one-hot
encoded representation of it with shape :obj:`[*, num_classes]` that has
zeros everywhere except where the index of last dimension matches the
corresponding value of the input tensor, in which case it will be :obj:`1`.
.. note::
This is a more memory-efficient version of
:meth:`torch.nn.functional.one_hot` as you can customize the output
:obj:`dtype`.
Args:
index (torch.Tensor): The one-dimensional input tensor.
num_classes (int, optional): The total number of classes. If set to
:obj:`None`, the number of classes will be inferred as one greater
than the largest class value in the input tensor.
(default: :obj:`None`)
dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor.
"""
if index.dim() != 1:
raise ValueError("'index' tensor needs to be one-dimensional")
if num_classes is None:
num_classes = int(index.max()) + 1
out = torch.zeros((index.size(0), num_classes), dtype=dtype,
device=index.device)
return out.scatter_(1, index.unsqueeze(1), 1)