File size: 1,442 Bytes
6b33608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Optional

import torch
from torch import Tensor


def one_hot(
    index: Tensor,
    num_classes: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Takes a one-dimensional :obj:`index` tensor and returns a one-hot
    encoded representation of it with shape
    :obj:`[index.size(0), num_classes]` that has zeros everywhere except
    where the index of the last dimension matches the corresponding value of
    the input tensor, in which case it will be :obj:`1`.

    .. note::
        This is a more memory-efficient version of
        :meth:`torch.nn.functional.one_hot` as you can customize the output
        :obj:`dtype`.

    Args:
        index (torch.Tensor): The one-dimensional input tensor.
        num_classes (int, optional): The total number of classes. If set to
            :obj:`None`, the number of classes will be inferred as one greater
            than the largest class value in the input tensor. An empty input
            tensor yields zero inferred classes. (default: :obj:`None`)
        dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor.
            If set to :obj:`None`, the default of :meth:`torch.zeros` is
            used. (default: :obj:`None`)

    Returns:
        torch.Tensor: The one-hot encoded tensor, allocated on the same
        device as :obj:`index`.

    Raises:
        ValueError: If :obj:`index` is not one-dimensional.
    """
    if index.dim() != 1:
        raise ValueError("'index' tensor needs to be one-dimensional")

    if num_classes is None:
        # `Tensor.max()` raises on an empty tensor, so there is no largest
        # class to infer from; fall back to zero classes in that case.
        num_classes = int(index.max()) + 1 if index.numel() > 0 else 0

    out = torch.zeros((index.size(0), num_classes), dtype=dtype,
                      device=index.device)
    # Write a `1` into column `index[i]` of row `i`, in place.
    return out.scatter_(1, index.unsqueeze(1), 1)