tomofi's picture
Add application file
2366e36
raw
history blame
No virus
4.66 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmocr.models.builder import LOSSES
@LOSSES.register_module()
class CELoss(nn.Module):
    """Cross-entropy loss for encoder-decoder based text recognition methods.

    Args:
        ignore_index (int): Target value that is ignored and does not
            contribute to the input gradient. Defaults to -1.
        reduction (str): Reduction applied to the output, one of
            ('none', 'mean', 'sum'). Defaults to 'none'.
        ignore_first_char (bool): Whether to drop the first token of the
            target (usually the start token). When ``True``, the last step of
            the output sequence is dropped as well so that both sequences keep
            the same length. Defaults to False.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 ignore_first_char=False):
        super().__init__()
        # Argument validation; invalid values raise AssertionError.
        assert isinstance(ignore_index, int)
        assert isinstance(reduction, str) and reduction in [
            'none', 'mean', 'sum'
        ]
        assert isinstance(ignore_first_char, bool)

        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)
        self.ignore_first_char = ignore_first_char

    def format(self, outputs, targets_dict):
        """Rearrange logits and targets into the layout expected by
        ``nn.CrossEntropyLoss`` (class dimension second).

        Args:
            outputs (Tensor): Raw logits of shape :math:`(N, T, C)`.
            targets_dict (dict): Dict holding ``padded_targets``, a tensor of
                character indices of shape :math:`(N, T)`.

        Returns:
            tuple(Tensor, Tensor): Logits permuted to :math:`(N, C, T)` and
            the targets. When ``ignore_first_char`` is set, the first target
            token and the last output step are removed before permuting.
        """
        targets = targets_dict['padded_targets']
        if not self.ignore_first_char:
            return outputs.permute(0, 2, 1).contiguous(), targets

        # Drop the leading target token and the trailing output step so the
        # two sequences remain aligned.
        shifted_targets = targets[:, 1:].contiguous()
        trimmed_logits = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
        return trimmed_logits, shifted_targets

    def forward(self, outputs, targets_dict, img_metas=None):
        """Compute the cross-entropy loss.

        Args:
            outputs (Tensor): A raw logit tensor of shape :math:`(N, T, C)`.
            targets_dict (dict): A dict with a key ``padded_targets``, which is
                a tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (None): Unused.

        Returns:
            dict: A loss dict with the key ``loss_ce``.
        """
        logits, targets = self.format(outputs, targets_dict)
        # Targets may live on a different device than the logits.
        return {'loss_ce': self.loss_ce(logits, targets.to(logits.device))}
@LOSSES.register_module()
class SARLoss(CELoss):
    """Loss module used by `SAR <https://arxiv.org/abs/1811.00751>`_.

    Args:
        ignore_index (int): Target value that is ignored and does not
            contribute to the input gradient. Defaults to 0.
        reduction (str): Reduction applied to the output, one of
            ("none", "mean", "sum"). Defaults to "mean".
        **kwargs: Extra keyword arguments; accepted but not forwarded to
            :class:`CELoss`.

    Warning:
        SARLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self, ignore_index=0, reduction='mean', **kwargs):
        # Only ignore_index and reduction reach the parent; kwargs are dropped.
        super().__init__(ignore_index, reduction)

    def format(self, outputs, targets_dict):
        """Shift targets and trim logits so output step ``t`` is scored
        against target token ``t + 1``.

        Args:
            outputs (Tensor): Raw logits of shape :math:`(N, T, C)`.
            targets_dict (dict): Dict holding ``padded_targets``.

        Returns:
            tuple(Tensor, Tensor): Logits of shape :math:`(N, C, T - 1)` and
            targets with their first token removed.
        """
        # padded_targets[i]: [start_idx, idx1, idx2, ..., end_idx, pad_idx...]
        # outputs[i, :, 0]:  [idx1, idx2, ..., end_idx, ...]
        # The start token is never predicted, so it is dropped from targets,
        # and the final output step is dropped to match the target length.
        shifted_targets = targets_dict['padded_targets'][:, 1:].contiguous()
        trimmed_logits = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
        return trimmed_logits, shifted_targets
@LOSSES.register_module()
class TFLoss(CELoss):
    """Loss module for transformer-based recognizers.

    Args:
        ignore_index (int, optional): The character index to be ignored in
            loss computation. Defaults to -1.
        reduction (str): Type of reduction to apply to the output, one of
            ("none", "mean", "sum"). Defaults to "none".
        flatten (bool): Whether to merge the batch and time dimensions before
            computing the loss. Defaults to True.
        **kwargs: Extra keyword arguments; accepted but not forwarded to
            :class:`CELoss`.

    Warning:
        TFLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 flatten=True,
                 **kwargs):
        super().__init__(ignore_index, reduction)
        assert isinstance(flatten, bool)
        self.flatten = flatten

    def format(self, outputs, targets_dict):
        """Drop the start token and arrange tensors for the loss.

        Args:
            outputs (Tensor): Raw logits of shape :math:`(N, T, C)`.
            targets_dict (dict): Dict holding ``padded_targets``.

        Returns:
            tuple(Tensor, Tensor): With ``flatten`` set, logits reshaped to
            ``(-1, C)`` and targets reshaped to ``(-1,)``; otherwise logits
            permuted to :math:`(N, C, T - 1)` and targets kept 2-D. In both
            cases the last output step and the first target token are removed.
        """
        trimmed_logits = outputs[:, :-1, :].contiguous()
        shifted_targets = targets_dict['padded_targets'][:, 1:].contiguous()

        if not self.flatten:
            return trimmed_logits.permute(0, 2, 1).contiguous(), shifted_targets

        num_classes = trimmed_logits.size(-1)
        return trimmed_logits.view(-1, num_classes), shifted_targets.view(-1)