Spaces:
Runtime error
Runtime error
File size: 2,507 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import LOSSES
@LOSSES.register_module()
class SegLoss(nn.Module):
    """Implementation of loss module for segmentation based text recognition
    method.

    The loss is a (optionally class-weighted) pixel-wise cross entropy between
    the head's segmentation logits and ground-truth kernel masks rescaled by
    ``seg_downsample_ratio``.

    Args:
        seg_downsample_ratio (int | float): Downsample ratio of the
            segmentation map relative to the ground-truth masks. Must lie in
            ``(0, 1]``.
        seg_with_loss_weight (bool): If True, every non-zero class is weighted
            by ``#background pixels / #foreground pixels`` (class 0 keeps
            weight 1) to counter foreground/background imbalance.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
        **kwargs: Ignored.
    """

    def __init__(self,
                 seg_downsample_ratio=0.5,
                 seg_with_loss_weight=True,
                 ignore_index=255,
                 **kwargs):
        super().__init__()

        assert isinstance(seg_downsample_ratio, (int, float))
        assert 0 < seg_downsample_ratio <= 1
        assert isinstance(ignore_index, int)

        self.seg_downsample_ratio = seg_downsample_ratio
        self.seg_with_loss_weight = seg_with_loss_weight
        self.ignore_index = ignore_index

    def seg_loss(self, out_head, gt_kernels):
        """Compute the segmentation cross-entropy loss.

        Args:
            out_head (Tensor): Segmentation logits of shape
                :math:`(N, C, H', W')`, where ``H', W'`` must match the
                ground-truth masks after rescaling by
                ``seg_downsample_ratio``.
            gt_kernels (Sequence): One entry per sample; ``item[1]`` of each
                entry must provide ``rescale(ratio)`` and
                ``to_tensor(dtype, device)`` (presumably a ``BitmapMasks``
                holding a single mask -- TODO confirm against the caller).

        Returns:
            Tensor: Scalar cross-entropy loss (``mean`` reduction).
        """
        seg_map = out_head  # assumed (bsz, num_classes, H', W'), see above

        seg_target = [
            item[1].rescale(self.seg_downsample_ratio).to_tensor(
                torch.long, seg_map.device) for item in gt_kernels
        ]
        # Stack per-sample targets and drop the singleton mask dimension
        # (assumes each sample holds exactly one mask).
        seg_target = torch.stack(seg_target).squeeze(1)

        loss_weight = None
        if self.seg_with_loss_weight:
            num_valid = torch.sum(seg_target != self.ignore_index)
            num_neg = torch.sum(seg_target == 0)
            # Clamp so a batch without foreground (or without any valid
            # pixel) yields a finite weight instead of inf/nan. Classes absent
            # from the target are never indexed by cross_entropy, so the loss
            # itself is unaffected by this guard.
            num_pos = (num_valid - num_neg).clamp(min=1)
            # Match the logits' dtype so half-precision inputs do not hit a
            # weight/input dtype mismatch.
            loss_weight = torch.ones(
                seg_map.size(1), dtype=seg_map.dtype, device=seg_map.device)
            loss_weight[1:] = 1.0 * num_neg / num_pos

        loss_seg = F.cross_entropy(
            seg_map,
            seg_target,
            weight=loss_weight,
            ignore_index=self.ignore_index)

        return loss_seg

    def forward(self, out_neck, out_head, gt_kernels):
        """
        Args:
            out_neck (None): Unused.
            out_head (Tensor): The output from head whose shape
                is :math:`(N, C, H, W)`.
            gt_kernels (Sequence): Per-sample ground truth; ``item[1]`` of
                each entry is the kernel mask consumed by :meth:`seg_loss`.

        Returns:
            dict: A loss dictionary with the key ``loss_seg``.
        """
        losses = {}

        loss_seg = self.seg_loss(out_head, gt_kernels)

        losses['loss_seg'] = loss_seg

        return losses
|