import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class LabelSmoothCELoss(nn.Module):
    """
    Overview:
        Label smooth cross entropy loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, ratio: float) -> None:
        """
        Overview:
            Initialize the LabelSmoothCELoss object using the given arguments.
        Arguments:
            - ratio (:obj:`float`): The label-smoothing ratio, in the range [0, 1). The ground-truth class is \
                assigned probability ``1 - ratio`` while the remaining ``ratio`` is spread uniformly over the \
                other classes, so a larger ratio means stronger smoothing.
        """
        super().__init__()
        self.ratio = ratio

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate label smooth cross entropy loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits.
            - labels (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
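        Examples:
            >>> # Minimal usage sketch; shapes (batch 4, 10 classes) are illustrative assumptions.
            >>> criterion = LabelSmoothCELoss(ratio=0.1)
            >>> logits = torch.randn(4, 10)
            >>> labels = torch.randint(0, 10, (4, ))
            >>> loss = criterion(logits, labels)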
        """
        B, N = logits.shape
        # Smoothed targets: the ground-truth class receives probability ``1 - ratio`` and
        # the remaining ``ratio`` is spread uniformly over the other ``N - 1`` classes, so
        # every row of the target distribution sums to 1.
        val = float(self.ratio) / (N - 1)
        one_hot = torch.full_like(logits, val)
        one_hot.scatter_(1, labels.unsqueeze(1), 1 - self.ratio)
        logits = F.log_softmax(logits, dim=1)
        # Cross entropy against the smoothed targets, averaged over the batch.
        return -torch.sum(logits * (one_hot.detach())) / B


class SoftFocalLoss(nn.Module):
    """
    Overview:
        Soft focal loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self, gamma: int = 2, weight: Optional[torch.Tensor] = None, size_average: bool = True,
            reduce: Optional[bool] = None
    ) -> None:
        """
        Overview:
            Initialize the SoftFocalLoss object using the given arguments.
        Arguments:
            - gamma (:obj:`int`): The extent of focus on hard samples. ``gamma=0`` recovers the standard \
                cross entropy, while a larger ``gamma`` down-weights well-classified (easy) samples more \
                strongly, focusing training on hard samples.
            - weight (:obj:`Optional[torch.Tensor]`): A manual rescaling weight applied to the loss of each class.
            - size_average (:obj:`bool`): By default, the losses are averaged over each loss element in the batch. \
                Note that for some losses, there are multiple elements per sample. If the field ``size_average`` is \
                set to ``False``, the losses are instead summed for each minibatch. Ignored when ``reduce`` is \
                ``False``.
            - reduce (:obj:`Optional[bool]`): By default, the losses are averaged or summed over observations for \
                each minibatch depending on size_average. When ``reduce`` is ``False``, returns a loss for each batch \
                element instead and ignores ``size_average``.
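        Examples:
            >>> # Hedged sketch: a per-class rescaling weight for an assumed 10-class problem.
            >>> criterion = SoftFocalLoss(gamma=2, weight=torch.ones(10))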
        """
        super().__init__()
        self.gamma = gamma
        # ``nn.NLLLoss2d`` is deprecated (and removed in recent PyTorch releases);
        # ``nn.NLLLoss`` covers both 1D and higher-dimensional targets and still
        # accepts the legacy ``size_average``/``reduce`` arguments.
        self.nll_loss = nn.NLLLoss(weight, size_average, reduce=reduce)

    def forward(self, inputs: torch.Tensor, targets: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate soft focal loss.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Predicted logits.
            - targets (:obj:`torch.LongTensor`): Ground truth labels.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
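        Examples:
            >>> # Minimal usage sketch; shapes (batch 4, 10 classes) are illustrative assumptions.
            >>> criterion = SoftFocalLoss(gamma=2)
            >>> inputs = torch.randn(4, 10)
            >>> targets = torch.randint(0, 10, (4, ))
            >>> loss = criterion(inputs, targets)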
        """
        # Down-weight confidently-classified samples: scale the log-probabilities by
        # ``(1 - p) ** gamma`` before applying the negative log-likelihood loss.
        focal_weight = (1 - F.softmax(inputs, dim=1)) ** self.gamma
        return self.nll_loss(focal_weight * F.log_softmax(inputs, dim=1), targets)


def build_ce_criterion(cfg: dict) -> nn.Module:
    """
    Overview:
        Get a cross entropy loss instance according to given config.
    Arguments:
        - cfg (:obj:`dict`) : Config dict. It contains:
            - type (:obj:`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', \
                'soft_focal_loss'].
            - kwargs (:obj:`dict`): Arguments for the corresponding loss function (currently only used by \
                ``label_smooth_ce``).
    Returns:
        - loss (:obj:`nn.Module`): The corresponding loss function instance.
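    Examples:
        >>> # Minimal usage sketch with a plain dict config.
        >>> cfg = {'type': 'label_smooth_ce', 'kwargs': {'smooth_ratio': 0.1}}
        >>> criterion = build_ce_criterion(cfg)
        >>> logits, labels = torch.randn(4, 10), torch.randint(0, 10, (4, ))
        >>> loss = criterion(logits, labels)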
    """
    if cfg['type'] == 'cross_entropy':
        return nn.CrossEntropyLoss()
    elif cfg['type'] == 'label_smooth_ce':
        return LabelSmoothCELoss(cfg['kwargs']['smooth_ratio'])
    elif cfg['type'] == 'soft_focal_loss':
        return SoftFocalLoss()
    else:
        raise ValueError("invalid criterion type: {}".format(cfg['type']))
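

if __name__ == '__main__':
    # Quick smoke test (a sketch, not part of the library API): build each supported
    # criterion and run it on random data with assumed shapes (batch 4, 10 classes).
    logits = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4, ))
    for cfg in (
            {'type': 'cross_entropy', 'kwargs': {}},
            {'type': 'label_smooth_ce', 'kwargs': {'smooth_ratio': 0.1}},
            {'type': 'soft_focal_loss', 'kwargs': {}},
    ):
        criterion = build_ce_criterion(cfg)
        print(cfg['type'], criterion(logits, labels).item())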