Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from torch import nn | |
from mmocr.models.builder import LOSSES | |
from mmocr.models.common.losses.focal_loss import FocalLoss | |
class MaskedFocalLoss(nn.Module):
    """Focal loss computed only over the non-padding positions of a sequence.

    An attention mask marks real tokens with 1 and padding tokens with 0;
    only the positions whose mask value equals 1 are passed to the
    underlying :class:`FocalLoss` criterion.

    Args:
        num_labels (int, optional): Number of classes in labels. When left
            as ``None`` the class count is taken from the size of the last
            dimension of ``logits`` at call time.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient. It is forwarded
            unchanged to :class:`FocalLoss`.
    """

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = FocalLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Compute the masked focal loss.

        Args:
            logits (Tensor): Model output. It is flattened to shape
                ``[-1, num_labels]`` before the loss is computed, so the
                class scores are expected in the last dimension.
            img_metas (dict): A dict from which the following keys are read:

                - labels (Tensor): The label of each token; flattened to
                  one dimension so that it lines up with the flattened
                  logits.
                - attention_masks (Tensor | None): The mask of each token,
                  1 for real tokens and 0 for padding tokens. When it is
                  ``None`` every position contributes to the loss.

                Any other keys (e.g. ``img``, ``texts``, ``input_ids``,
                ``token_type_ids``) are not used here.

        Returns:
            dict: ``{'loss_cls': loss}`` where ``loss`` is whatever the
            focal loss criterion returns for the selected positions.

        Raises:
            KeyError: If ``labels`` or ``attention_masks`` is missing from
                ``img_metas``.
        """
        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # ``num_labels`` defaults to None in ``__init__``; fall back to the
        # size of the class dimension instead of failing inside the reshape.
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = logits.size(-1)

        # ``reshape`` behaves like ``view`` on contiguous tensors but also
        # accepts non-contiguous inputs (it copies only when it has to).
        flat_logits = logits.reshape(-1, num_labels)
        flat_labels = labels.reshape(-1)

        # Only keep active parts of the loss.
        if attention_masks is not None:
            active = attention_masks.reshape(-1) == 1
            flat_logits = flat_logits[active]
            flat_labels = flat_labels[active]

        return {'loss_cls': self.criterion(flat_logits, flat_labels)}