import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional


class LabelSmoothCELoss(nn.Module):
    """
    Overview:
        Label-smoothed cross entropy loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, ratio: float) -> None:
        """
        Overview:
            Initialize the LabelSmoothCELoss object using the given arguments.
        Arguments:
            - ratio (:obj:`float`): The ratio of label smoothing (the value is in [0, 1]). A larger ratio \
                means a greater extent of label smoothing.
        """
        super().__init__()
        self.ratio = ratio

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate label-smoothed cross entropy loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits of shape :math:`(B, N)`.
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape :math:`(B, )`.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
        """
        B, N = logits.shape
        val = float(self.ratio) / (N - 1)
        # Build the smoothed target distribution: every class receives ``val``,
        # then the ground-truth entry is overwritten with ``1 - val``.
        one_hot = torch.full_like(logits, val)
        one_hot.scatter_(1, labels.unsqueeze(1), 1 - val)
        logits = F.log_softmax(logits, dim=1)
        return -torch.sum(logits * (one_hot.detach())) / B
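
# A minimal usage sketch of ``LabelSmoothCELoss`` (the 0.1 ratio and the tensor
# shapes below are illustrative, not values required by this module):
#
#     criterion = LabelSmoothCELoss(ratio=0.1)
#     logits = torch.randn(4, 10)            # batch of 4 samples, 10 classes
#     labels = torch.randint(0, 10, (4, ))   # ground-truth class indices
#     loss = criterion(logits, labels)       # scalar, averaged over the batch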


class SoftFocalLoss(nn.Module):
    """
    Overview:
        Soft focal loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self, gamma: int = 2, weight: Any = None, size_average: bool = True, reduce: Optional[bool] = None
    ) -> None:
        """
        Overview:
            Initialize the SoftFocalLoss object using the given arguments.
        Arguments:
            - gamma (:obj:`int`): The extent of focus on hard samples. A smaller ``gamma`` will lead to more \
                focus on easy samples, while a larger ``gamma`` will lead to more focus on hard samples.
            - weight (:obj:`Any`): The weight for the loss of each class.
            - size_average (:obj:`bool`): By default, the losses are averaged over each loss element in the \
                batch. Note that for some losses, there are multiple elements per sample. If the field \
                ``size_average`` is set to ``False``, the losses are instead summed for each minibatch. \
                Ignored when ``reduce`` is ``False``.
            - reduce (:obj:`Optional[bool]`): By default, the losses are averaged or summed over observations \
                for each minibatch depending on ``size_average``. When ``reduce`` is ``False``, returns a loss \
                per batch element instead and ignores ``size_average``.
        """
        super().__init__()
        self.gamma = gamma
        # ``torch.nn.NLLLoss2d`` is deprecated and removed in recent PyTorch
        # releases; ``nn.NLLLoss`` is its documented drop-in replacement.
        self.nll_loss = nn.NLLLoss(weight, size_average=size_average, reduce=reduce)

    def forward(self, inputs: torch.Tensor, targets: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate soft focal loss.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Predicted logits.
            - targets (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
        """
        # Down-weight well-classified samples by scaling the log-probabilities
        # with the modulating factor ``(1 - p) ** gamma``.
        return self.nll_loss((1 - F.softmax(inputs, 1)) ** self.gamma * F.log_softmax(inputs, 1), targets)
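
# A minimal usage sketch of ``SoftFocalLoss`` (shapes are illustrative; with
# ``nn.NLLLoss`` underneath, 4D image-style logits of shape ``(B, N, H, W)``
# with ``(B, H, W)`` targets also work):
#
#     criterion = SoftFocalLoss(gamma=2)
#     inputs = torch.randn(4, 10)            # batch of 4 samples, 10 classes
#     targets = torch.randint(0, 10, (4, ))  # ground-truth class indices
#     loss = criterion(inputs, targets)      # scalar loss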


def build_ce_criterion(cfg: dict) -> nn.Module:
    """
    Overview:
        Get a cross entropy loss instance according to the given config.
    Arguments:
        - cfg (:obj:`dict`): Config dict supporting attribute access (e.g. an ``EasyDict``). It contains:
            - type (:obj:`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', \
                'soft_focal_loss'].
            - kwargs (:obj:`dict`): Arguments for the corresponding loss function.
    Returns:
        - loss (:obj:`nn.Module`): The loss function instance.
    """
    if cfg.type == 'cross_entropy':
        return nn.CrossEntropyLoss()
    elif cfg.type == 'label_smooth_ce':
        return LabelSmoothCELoss(cfg.kwargs.smooth_ratio)
    elif cfg.type == 'soft_focal_loss':
        return SoftFocalLoss()
    else:
        raise ValueError("invalid criterion type: {}".format(cfg.type))
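
# A minimal usage sketch of ``build_ce_criterion``, assuming an attribute-style
# config such as ``easydict.EasyDict`` (not imported by this module); the
# ``smooth_ratio`` value is illustrative:
#
#     from easydict import EasyDict
#     cfg = EasyDict({'type': 'label_smooth_ce', 'kwargs': {'smooth_ratio': 0.1}})
#     criterion = build_ce_criterion(cfg)  # returns LabelSmoothCELoss(0.1)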