# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is cast
            to float before being applied.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class. It
            is forwarded unchanged as the ``weight`` argument of
            ``F.cross_entropy``.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index

    # element-wise losses
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # average loss over non-ignored elements
    # pytorch's official cross_entropy average loss over non-ignored elements
    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    # Only derive the factor when the caller did not supply one explicitly.
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = label.numel() - (label == ignore_index).sum().item()

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Expand onehot labels to match the size of prediction.

    Args:
        labels (torch.Tensor): Class indices; the first dimension is the
            number of samples.
        label_weights (torch.Tensor | None): Per-sample weights, reshaped to
            one column and repeated ``label_channels`` times. The input
            tensor is not modified.
        label_channels (int): Number of columns of the expanded outputs.
        ignore_index (int): Label value whose samples are masked out.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``bin_labels`` (same
        dtype as ``labels``, one-hot rows), ``bin_label_weights`` (the
        expanded weights multiplied by the valid mask, or the valid mask
        itself when ``label_weights`` is None) and ``valid_mask`` (float,
        1 for samples whose label is non-negative and not ``ignore_index``).
    """
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    # Labels that are valid but >= label_channels keep an all-zero row while
    # still counting as valid in the mask.
    inds = torch.nonzero(
        valid_mask & (labels < label_channels), as_tuple=False)

    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
                                               label_channels).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
        # Out-of-place multiply in the weights' own dtype: an in-place
        # ``*=`` with the float mask raises for integer-typed weights because
        # the float result cannot be written back into the integer tensor.
        bin_label_weights = bin_label_weights * valid_mask.to(
            bin_label_weights.dtype)
    return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
            When the shape of pred is (N, 1), label will be expanded to
            one-hot format, and when the shape of pred is (N, ), label
            will not be expanded to one-hot format.
        label (torch.Tensor): The learning label of the prediction,
            with shape (N, ).
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class. It
            is forwarded unchanged as ``pos_weight`` of
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides to whether the loss is
            only averaged over non-ignored targets. Default: False.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # The default value of ignore_index is the same as F.cross_entropy
    ignore_index = -100 if ignore_index is None else ignore_index

    if pred.dim() != label.dim():
        # Labels are class indices: expand them (and the weights) to the
        # width of the prediction.
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.size(-1), ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            # The inplace writing method will have a mismatched broadcast
            # shape error if the weight and valid_mask dimensions
            # are inconsistent such as (B,N,1) and (B,N,C).
            weight = weight * valid_mask
        else:
            weight = valid_mask

    # average loss over non-ignored elements
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = valid_mask.sum().item()

    # weighted element-wise losses
    weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask
            corresponding object. This will be used to select the mask of
            the class which the object belongs to when the mask prediction
            is not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Only "mean" is currently
            accepted.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None. Must currently be None.
        class_weight (list[float], optional): The weight for each class. It
            is forwarded unchanged as ``weight`` of
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.
        **kwargs: Accepted and ignored, so that callers can pass the same
            keyword arguments as to the other loss functions.

    Returns:
        torch.Tensor: The calculated loss, with shape (1,).

    Raises:
        AssertionError: If ``ignore_index`` is not None, ``reduction`` is not
            'mean', or ``avg_factor`` is not None.

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        >>>                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    # Pick, for every sample, the channel of its own class: (N, C, *) -> (N, *)
    pred_slice = pred[inds, label]
    # Historically dim 1 was always squeezed. Doing so unconditionally fails
    # for predictions without trailing dims and corrupts the shape when the
    # first trailing dim has size 1 and ``target`` keeps it. Squeeze only
    # when that is what makes the logits line up with ``target``.
    if pred_slice.dim() > 1 and pred_slice.shape != target.shape:
        pred_slice = pred_slice.squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]
# NOTE(review): ``LOSSES`` is imported at the top of this file but is not used
# anywhere in the visible code - confirm whether a registration decorator
# (e.g. ``@LOSSES.register_module()``) was dropped from this class.
class CrossEntropyLoss(nn.Module):
    """Configurable cross-entropy loss module.

    Depending on the constructor flags, ``forward`` dispatches to
    ``binary_cross_entropy`` (``use_sigmoid``), ``mask_cross_entropy``
    (``use_mask``) or ``cross_entropy`` (neither), and scales the result by
    ``loss_weight``.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0,
                 avg_non_ignore=False):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False. Cannot be combined with ``use_sigmoid``.
            reduction (str, optional): The method used to reduce the loss.
                Defaults to 'mean'. Options are "none", "mean" and "sum".
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): The flag decides to whether the loss is
                only averaged over non-ignored targets. Default: False.
        """
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.avg_non_ignore = avg_non_ignore
        # Warn that averaging still counts ignored labels unless the caller
        # opts in to ``avg_non_ignore``.
        if ((ignore_index is not None) and not self.avg_non_ignore
                and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False, if you would like to '
                'ignore the certain label and average loss over non-ignore '
                'labels, which is the same with PyTorch official '
                'cross_entropy, set ``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight. It is
                passed to the criterion as the third positional argument.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value. Default: None.
            **kwargs: Extra keyword arguments forwarded to the criterion.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if ignore_index is None:
            ignore_index = self.ignore_index

        if self.class_weight is not None:
            # Build the weight tensor with the dtype/device of the prediction.
            class_weight = cls_score.new_tensor(
                self.class_weight, device=cls_score.device)
        else:
            class_weight = None
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            ignore_index=ignore_index,
            avg_non_ignore=self.avg_non_ignore,
            **kwargs)
        return loss_cls