# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn

from mmocr.models.builder import LOSSES
@LOSSES.register_module()
class CTCLoss(nn.Module):
    """Implementation of loss module for CTC-loss based text recognition.

    Args:
        flatten (bool): If True, use flattened targets, else padded targets.
            Default: True.
        blank (int): Blank label. Default: 0.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum').
            Default: 'mean'.
        zero_infinity (bool): Whether to zero infinite losses and the
            associated gradients. Default: False. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets.
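
    Example:
        >>> # Illustrative construction only; the argument values here are
        >>> # arbitrary, not recommended defaults.
        >>> ctc_loss = CTCLoss(flatten=True, blank=0, reduction='mean',
        ...                    zero_infinity=True)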
""" | |
    def __init__(self,
                 flatten=True,
                 blank=0,
                 reduction='mean',
                 zero_infinity=False,
                 **kwargs):
        super().__init__()
        assert isinstance(flatten, bool)
        assert isinstance(blank, int)
        assert isinstance(reduction, str)
        assert isinstance(zero_infinity, bool)

        self.flatten = flatten
        self.blank = blank
        self.ctc_loss = nn.CTCLoss(
            blank=blank, reduction=reduction, zero_infinity=zero_infinity)
    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            targets_dict (dict): A dict with 3 keys ``target_lengths``,
                ``flatten_targets`` and ``targets``.

                - | ``target_lengths`` (Tensor): A tensor of shape :math:`(N)`.
                    Each item is the length of a word.

                - | ``flatten_targets`` (Tensor): Used if ``self.flatten=True``
                    (default). A tensor of shape
                    (sum(targets_dict['target_lengths'])). Each item is the
                    index of a character.

                - | ``targets`` (Tensor): Used if ``self.flatten=False``. A
                    tensor of shape :math:`(N, T)`. Empty slots are padded
                    with ``self.blank``.

            img_metas (list[dict]): A list of dicts, each containing meta
                information of one input image. Preferably with the key
                ``valid_ratio``.

        Returns:
            dict: The loss dict with key ``loss_ctc``.
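
        Example:
            >>> # A minimal sketch with dummy shapes and values, assuming the
            >>> # default ``flatten=True`` and ``blank=0``.
            >>> import torch
            >>> ctc_loss = CTCLoss(flatten=True, blank=0)
            >>> outputs = torch.randn(2, 40, 37)  # (N, T, C) raw logits
            >>> targets_dict = dict(
            ...     flatten_targets=torch.tensor([2, 5, 9, 1, 4]),
            ...     targets=torch.tensor([[2, 5, 9], [1, 4, 0]]),
            ...     target_lengths=torch.tensor([3, 2]))
            >>> losses = ctc_loss(outputs, targets_dict)
            >>> loss_value = losses['loss_ctc']  # scalar, reduction='mean'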
""" | |
        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]

        # nn.CTCLoss expects log probabilities with the layout (T, N, C).
        outputs = torch.log_softmax(outputs, dim=2)
        bsz, seq_len = outputs.size(0), outputs.size(1)
        outputs_for_loss = outputs.permute(1, 0, 2).contiguous()  # T * N * C

        if self.flatten:
            targets = targets_dict['flatten_targets']
        else:
            # Re-pad each target sequence into a fixed (N, T) tensor filled
            # with the blank label.
            targets = torch.full(
                size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long)
            for idx, tensor in enumerate(targets_dict['targets']):
                valid_len = min(tensor.size(0), seq_len)
                targets[idx, :valid_len] = tensor[:valid_len]

        target_lengths = targets_dict['target_lengths']
        target_lengths = torch.clamp(
            target_lengths, min=1, max=seq_len).long()

        # By default, every sample uses the full sequence length as its input
        # length; shrink it by the valid width ratio when available.
        input_lengths = torch.full(
            size=(bsz, ), fill_value=seq_len, dtype=torch.long)
        if not self.flatten and valid_ratios is not None:
            input_lengths = [
                math.ceil(valid_ratio * seq_len)
                for valid_ratio in valid_ratios
            ]
            input_lengths = torch.Tensor(input_lengths).long()

        loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
                                 target_lengths)
        losses = dict(loss_ctc=loss_ctc)

        return losses