# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class SegLoss(nn.Module):
    """Implementation of loss module for segmentation based text recognition
    method.

    Args:
        seg_downsample_ratio (float): Downsample ratio of
            segmentation map.
        seg_with_loss_weight (bool): If True, set weight for
            segmentation loss.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self,
                 seg_downsample_ratio=0.5,
                 seg_with_loss_weight=True,
                 ignore_index=255,
                 **kwargs):
        super().__init__()

        assert isinstance(seg_downsample_ratio, (int, float))
        assert 0 < seg_downsample_ratio <= 1
        assert isinstance(ignore_index, int)

        self.seg_downsample_ratio = seg_downsample_ratio
        self.seg_with_loss_weight = seg_with_loss_weight
        self.ignore_index = ignore_index

    def seg_loss(self, out_head, gt_kernels):
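        """Compute the segmentation cross-entropy loss.

        Args:
            out_head (Tensor): Output of the head, of shape
                :math:`(N, C, H, W)`.
            gt_kernels (list[BitmapMasks]): The ground truth masks.

        Returns:
            Tensor: The segmentation loss.
        """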
        seg_map = out_head  # bsz * num_classes * H/2 * W/2
        seg_target = [
            item[1].rescale(self.seg_downsample_ratio).to_tensor(
                torch.long, seg_map.device) for item in gt_kernels
        ]
        seg_target = torch.stack(seg_target).squeeze(1)

        loss_weight = None
        if self.seg_with_loss_weight:
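            # Weight foreground classes by the background-to-foreground
            # pixel ratio (ignored pixels excluded) to counter the class
            # imbalance against the dominant background class.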
            N = torch.sum(seg_target != self.ignore_index)
            N_neg = torch.sum(seg_target == 0)
            weight_val = 1.0 * N_neg / (N - N_neg)
            loss_weight = torch.ones(seg_map.size(1), device=seg_map.device)
            loss_weight[1:] = weight_val

        loss_seg = F.cross_entropy(
            seg_map,
            seg_target,
            weight=loss_weight,
            ignore_index=self.ignore_index)

        return loss_seg

    def forward(self, out_neck, out_head, gt_kernels):
        """
        Args:
            out_neck (None): Unused.
            out_head (Tensor): The output from head whose shape
                is :math:`(N, C, H, W)`.
            gt_kernels (BitmapMasks): The ground truth masks.

        Returns:
            dict: A loss dictionary with the key ``loss_seg``.
        """

        losses = {}

        loss_seg = self.seg_loss(out_head, gt_kernels)

        losses['loss_seg'] = loss_seg

        return losses
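

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes ``BitmapMasks`` from ``mmdet.core`` and that each element of
# ``gt_kernels`` is a BitmapMasks holding two masks, the second of which is
# the target class-index map (matching how ``seg_loss`` indexes ``item[1]``
# above); shapes and values below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from mmdet.core import BitmapMasks

    num_classes, h, w = 37, 64, 64
    # Head output at half resolution (seg_downsample_ratio=0.5).
    out_head = torch.randn(1, num_classes, h // 2, w // 2)

    # Fake ground truth: background (0) everywhere except a small square
    # region labelled with class 1 in the second mask.
    masks = np.zeros((2, h, w), dtype=np.uint8)
    masks[1, 16:32, 16:32] = 1
    gt_kernels = [BitmapMasks(masks, h, w)]

    losses = SegLoss()(
        out_neck=None, out_head=out_head, gt_kernels=gt_kernels)
    print(losses)  # {'loss_seg': tensor(...)}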