File size: 3,653 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from typing import Optional, Callable


def levenshtein_distance(
        pred: torch.LongTensor,
        target: torch.LongTensor,
        pred_extra: Optional[torch.Tensor] = None,
        target_extra: Optional[torch.Tensor] = None,
        extra_fn: Optional[Callable] = None
) -> torch.FloatTensor:
    """
    Overview:
        Levenshtein Distance, i.e. Edit Distance. Without ``extra_fn`` this is the minimum number of \
        single-element insertions, deletions and substitutions needed to turn ``pred`` into ``target``. \
        With ``extra_fn``, aligning two equal elements ``pred[i] == target[j]`` costs \
        ``extra_fn(pred_extra[i], target_extra[j])`` instead of 0.
    Arguments:
        - pred (:obj:`torch.LongTensor`): The first tensor to calculate the distance, shape: (N1, )  (N1 >= 0).
        - target (:obj:`torch.LongTensor`): The second tensor to calculate the distance, shape: (N2, )  (N2 >= 0).
        - pred_extra (:obj:`Optional[torch.Tensor]`): Extra tensor indexed in parallel with ``pred``. Only used \
            when ``extra_fn`` is not ``None``, and must be ``None`` otherwise.
        - target_extra (:obj:`Optional[torch.Tensor]`): Extra tensor indexed in parallel with ``target``. Only \
            used when ``extra_fn`` is not ``None``; must have the same type as ``pred_extra``.
        - extra_fn (:obj:`Optional[Callable]`): The distance function for ``pred_extra`` and \
            ``target_extra``, called with one element of each. If set to ``None``, this distance will not be \
            considered.
    Returns:
        - distance (:obj:`torch.FloatTensor`): distance(scalar), shape: (1, ), on the same device as ``pred``.
    """
    assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
    assert (pred.dtype == torch.long and target.dtype == torch.long), '{}\t{}'.format(pred.dtype, target.dtype)
    assert (pred.device == target.device)
    assert (type(pred_extra) == type(target_extra))
    if extra_fn is None:
        # ``is None`` instead of ``not pred_extra``: truth-testing a multi-element tensor raises RuntimeError,
        # and a one-element zero tensor would wrongly pass.
        assert (pred_extra is None)
    N1, N2 = pred.shape[0], target.shape[0]
    assert (N1 >= 0 and N2 >= 0)
    # Compare tokens as plain Python ints; indexing tensors element-by-element inside the
    # O(N1 * N2) double loop is needlessly slow.
    pred_tokens, target_tokens = pred.tolist(), target.tolist()
    # dp_array[i, j] holds the distance between pred[:i] and target[:j]. Row/column 0 represent the
    # empty prefix, so the table is (N1 + 1) x (N2 + 1) and the first element of each sequence takes
    # part in the recurrence like every other element (this also covers N1 == 0 or N2 == 0).
    dp_array = torch.zeros(N1 + 1, N2 + 1, dtype=torch.float32)
    dp_array[:, 0] = torch.arange(0, N1 + 1)
    dp_array[0, :] = torch.arange(0, N2 + 1)
    for i in range(1, N1 + 1):
        for j in range(1, N2 + 1):
            if pred_tokens[i - 1] == target_tokens[j - 1]:
                # NOTE: equal elements are always aligned diagonally. Without extra_fn this is optimal;
                # with extra_fn returning values above 1 the result can exceed the plain edit distance.
                if extra_fn is not None:
                    dp_array[i, j] = dp_array[i - 1, j - 1] + extra_fn(pred_extra[i - 1], target_extra[j - 1])
                else:
                    dp_array[i, j] = dp_array[i - 1, j - 1]
            else:
                # deletion, insertion or substitution, each costing 1
                dp_array[i, j] = min(dp_array[i - 1, j], dp_array[i, j - 1], dp_array[i - 1, j - 1]) + 1
    distance = dp_array[N1, N2].item()
    return torch.FloatTensor([distance]).to(pred.device)


def hamming_distance(pred: torch.LongTensor, target: torch.LongTensor, weight=1.) -> torch.FloatTensor:
    """
    Overview:
        Hamming Distance: for each row, the number of positions at which ``pred`` and ``target`` differ \
        (counted along dimension 1), multiplied by ``weight``.
    Arguments:
        - pred (:obj:`torch.LongTensor`): Pred input, typically a boolean vector (0 or 1); any integer \
            values are compared element-wise.
        - target (:obj:`torch.LongTensor`): Target input, same shape, dtype and device as ``pred``.
        - weight (:obj:`Union[float, torch.Tensor]`): Weight to multiply the per-row count by, default ``1.``. \
            Either a scalar or a tensor broadcastable to shape :math:`(B, )`, because it is applied in place.
    Returns:
        - distance (:obj:`torch.FloatTensor`): Weighted distance of every row, shape :math:`(B, )`.
    Shapes:
        - pred & target (:obj:`torch.LongTensor`): shape :math:`(B, N)`, \
            where B is the batch size and N is the dimension.
    """
    assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
    assert (pred.dtype == torch.long and target.dtype == torch.long)
    assert (pred.device == target.device)
    assert (pred.shape == target.shape)
    # Cast the per-row mismatch count to float so a fractional weight can be applied in place.
    return pred.ne(target).sum(dim=1).float().mul_(weight)