File size: 4,184 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
from . import PANLoss


@LOSSES.register_module()
class PSELoss(PANLoss):
    r"""The class for implementing PSENet loss. This is partially adapted from
    https://github.com/whai362/PSENet.

    PSENet: `Shape Robust Text Detection with
    Progressive Scale Expansion Network <https://arxiv.org/abs/1806.02559>`_.

    Args:
        alpha (float): Text loss coefficient, and :math:`1-\alpha` is the
            kernel loss coefficient.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum".
        kernel_sample_type (str): How pixels are selected for the kernel
            loss. Available options are "hard" (pixels where the text target
            exceeds 0.5) and "adaptive" (pixels where the predicted text
            logit is positive). In both cases the selection is further
            restricted to the effective mask. Defaults to "adaptive".
    """

    def __init__(self,
                 alpha=0.7,
                 ohem_ratio=3,
                 reduction='mean',
                 kernel_sample_type='adaptive'):
        super().__init__()
        assert reduction in ['mean', 'sum'
                             ], "reduction must be either of ['mean','sum']"
        # Validate eagerly, like ``reduction``, so a misconfigured loss fails
        # at construction time instead of at the first training step.
        assert kernel_sample_type in [
            'hard', 'adaptive'
        ], "kernel_sample_type must be either of ['hard','adaptive']"
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.kernel_sample_type = kernel_sample_type

    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (tensor): The output logits with size of
                N x (1 + K) x H x W (e.g. Nx6xHxW). Channel 0 is the text
                prediction and channels 1..K are the kernel predictions.
            downsample_ratio (float): The downsample ratio between score_maps
                and the input img. Must be a Python ``float``.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel masks for one img. After conversion by
                ``bitmasks2tensor`` there must be ``K + 1`` targets: index 0
                supervises the text map and indices 1..K the kernel maps.
            gt_mask (list[BitmapMasks]): The effective mask list
                with each element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text`` and ``loss_kernel``, each
            reduced according to ``self.reduction``.

        Raises:
            NotImplementedError: If ``kernel_sample_type`` or ``reduction``
                holds an unsupported value.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        pred_texts = score_maps[:, 0, :, :]
        pred_kernels = score_maps[:, 1:, :, :]
        feature_hw = score_maps.size()[2:]
        device = score_maps.device

        # Rescale the ground truth to the score-map resolution, convert it
        # with ``bitmasks2tensor`` and move it to the predictions' device.
        gt_kernels = [item.rescale(downsample_ratio) for item in gt_kernels]
        gt_kernels = self.bitmasks2tensor(gt_kernels, feature_hw)
        gt_kernels = [item.to(device) for item in gt_kernels]

        gt_mask = [item.rescale(downsample_ratio) for item in gt_mask]
        gt_mask = self.bitmasks2tensor(gt_mask, feature_hw)
        gt_mask = [item.to(device) for item in gt_mask]

        gt_texts = gt_kernels[0]
        effective_mask = gt_mask[0]

        # Text loss. The sampling mask is built from detached predictions so
        # no gradient flows through the mask selection.
        sampled_masks_text = self.ohem_batch(pred_texts.detach(), gt_texts,
                                             effective_mask)
        loss_text = self.alpha * self.dice_loss_with_logits(
            pred_texts, gt_texts, sampled_masks_text)

        # Pixels on which the kernel maps are supervised.
        if self.kernel_sample_type == 'hard':
            # Inside the ground-truth text target.
            sampled_masks_kernel = (gt_texts > 0.5).float() * (
                effective_mask.float())
        elif self.kernel_sample_type == 'adaptive':
            # Where the model currently predicts text (positive logit).
            sampled_masks_kernel = (pred_texts > 0).float() * (
                effective_mask.float())
        else:
            raise NotImplementedError(
                f'Unsupported kernel_sample_type: {self.kernel_sample_type!r}')

        # Kernel loss: average of per-kernel dice losses.
        num_kernel = pred_kernels.shape[1]
        # Without at least one kernel channel the average below would
        # divide by zero.
        assert num_kernel > 0, \
            'score_maps must contain at least one kernel channel'
        assert num_kernel == len(gt_kernels) - 1
        loss_list = [
            self.dice_loss_with_logits(pred_kernel, gt_kernel,
                                       sampled_masks_kernel)
            for pred_kernel, gt_kernel in zip(
                pred_kernels.unbind(dim=1), gt_kernels[1:])
        ]
        loss_kernel = (1 - self.alpha) * sum(loss_list) / len(loss_list)

        if self.reduction == 'mean':
            loss_text, loss_kernel = loss_text.mean(), loss_kernel.mean()
        elif self.reduction == 'sum':
            loss_text, loss_kernel = loss_text.sum(), loss_kernel.sum()
        else:
            raise NotImplementedError(
                f'Unsupported reduction: {self.reduction!r}')
        return dict(loss_text=loss_text, loss_kernel=loss_kernel)