# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.dice_loss import DiceLoss

@LOSSES.register_module()
class DBLoss(nn.Module):
    """The class for implementing DBNet loss.

    This is partially adapted from https://github.com/MhLiao/DB.

    Args:
        alpha (float): The binary loss coefficient.
        beta (float): The threshold loss coefficient.
        reduction (str): The way to reduce the loss.
        negative_ratio (float): The maximum ratio of negative pixels to
            positive pixels kept by the balanced BCE loss.
        eps (float): Epsilon in the threshold loss function.
        bbce_loss (bool): Whether to use balanced bce for probability loss.
            If False, dice loss will be used instead.
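
    Example:
        A minimal construction sketch; the coefficient values below are
        illustrative, not taken from any particular DBNet config:

        >>> db_loss = DBLoss(alpha=5.0, beta=10.0, bbce_loss=True)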
""" | |

    def __init__(self,
                 alpha=1,
                 beta=1,
                 reduction='mean',
                 negative_ratio=3.0,
                 eps=1e-6,
                 bbce_loss=False):
        super().__init__()
        assert reduction in ['mean',
                             'sum'], "reduction must be in ['mean', 'sum']"
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction
        self.negative_ratio = negative_ratio
        self.eps = eps
        self.bbce_loss = bbce_loss
        self.dice_loss = DiceLoss(eps=eps)

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
                one kernel level.
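
        Example:
            A shape-only sketch, assuming ``BitmapMasks`` from ``mmdet`` and
            ``numpy`` imported as ``np``:

            >>> db_loss = DBLoss()
            >>> masks = [
            ...     BitmapMasks(
            ...         np.zeros((2, 10, 10), dtype=np.uint8), 10, 10)
            ...     for _ in range(3)
            ... ]
            >>> levels = db_loss.bitmasks2tensor(masks, (16, 16))
            >>> [t.shape for t in levels]
            [torch.Size([3, 16, 16]), torch.Size([3, 16, 16])]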
""" | |
        assert isinstance(bitmasks, list)
        assert isinstance(target_sz, tuple)

        batch_size = len(bitmasks)
        num_levels = len(bitmasks[0])

        result_tensors = []

        for level_inx in range(num_levels):
            kernel = []
            for batch_inx in range(batch_size):
                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
                # Pad each mask up to the target size so masks from the whole
                # batch can be stacked into one tensor per level.
                mask_sz = mask.shape
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                mask = F.pad(mask, pad, mode='constant', value=0)
                kernel.append(mask)
            kernel = torch.stack(kernel)
            result_tensors.append(kernel)

        return result_tensors

    def balance_bce_loss(self, pred, gt, mask):
        """Balanced BCE with hard negative mining: all positive pixels are
        kept, while only the ``negative_ratio * positive_count`` negative
        pixels with the highest loss contribute."""
        positive = (gt * mask)
        negative = ((1 - gt) * mask)
        positive_count = int(positive.float().sum())
        negative_count = min(
            int(negative.float().sum()),
            int(positive_count * self.negative_ratio))

        assert gt.max() <= 1 and gt.min() >= 0
        assert pred.max() <= 1 and pred.min() >= 0
        loss = F.binary_cross_entropy(pred, gt, reduction='none')
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()

        # Keep only the hardest negatives via top-k on the per-pixel loss.
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
            positive_count + negative_count + self.eps)

        return balance_loss

    def l1_thr_loss(self, pred, gt, mask):
        """Masked L1 loss for the threshold map."""
        thr_loss = torch.abs((pred - gt) * mask).sum() / (
            mask.sum() + self.eps)
        return thr_loss

    def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask,
                gt_thr, gt_thr_mask):
        """Compute DBNet loss.

        Args:
            preds (Tensor): The output tensor with size :math:`(N, 3, H, W)`.
            downsample_ratio (float): The downsample ratio for the
                ground truths.
            gt_shrink (list[BitmapMasks]): The mask list with each element
                being the shrunk text mask for one img.
            gt_shrink_mask (list[BitmapMasks]): The effective mask list with
                each element being the shrunk effective mask for one img.
            gt_thr (list[BitmapMasks]): The mask list with each element
                being the threshold text mask for one img.
            gt_thr_mask (list[BitmapMasks]): The effective mask list with
                each element being the threshold effective mask for one img.

        Returns:
            dict: The dict for dbnet losses with "loss_prob", "loss_db" and
            "loss_thr".
        """
        assert isinstance(downsample_ratio, float)
        assert isinstance(gt_shrink, list)
        assert isinstance(gt_shrink_mask, list)
        assert isinstance(gt_thr, list)
        assert isinstance(gt_thr_mask, list)

        pred_prob = preds[:, 0, :, :]
        pred_thr = preds[:, 1, :, :]
        pred_db = preds[:, 2, :, :]
        feature_sz = preds.size()

        keys = ['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']
        gt = {}
        for k in keys:
            # Look up the like-named argument, rescale it to the prediction
            # resolution and convert it to per-level padded tensors.
            gt[k] = eval(k)
            gt[k] = [item.rescale(downsample_ratio) for item in gt[k]]
            gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:])
            gt[k] = [item.to(preds.device) for item in gt[k]]
        gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float()

        if self.bbce_loss:
            loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0],
                                              gt['gt_shrink_mask'][0])
        else:
            loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0],
                                       gt['gt_shrink_mask'][0])

        loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0],
                                 gt['gt_shrink_mask'][0])
        loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0],
                                    gt['gt_thr_mask'][0])

        results = dict(
            loss_prob=self.alpha * loss_prob,
            loss_db=loss_db,
            loss_thr=self.beta * loss_thr)

        return results
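

if __name__ == '__main__':
    # Minimal smoke-test sketch for the forward() interface above; it is not
    # part of the original module. It assumes BitmapMasks is importable from
    # mmdet.core (as in the MMOCR 0.x pipeline that builds these targets) and
    # uses random tensors/masks purely to illustrate input shapes and types.
    import numpy as np
    from mmdet.core import BitmapMasks

    batch_size, height, width = 2, 32, 32
    # Channel 0: probability map, 1: threshold map, 2: binary (DB) map.
    # Values stay in [0, 1], which balance_bce_loss asserts.
    preds = torch.rand(batch_size, 3, height, width)

    def random_bitmasks():
        masks = np.random.randint(0, 2, (1, height, width)).astype(np.uint8)
        return BitmapMasks(masks, height, width)

    gt_shrink = [random_bitmasks() for _ in range(batch_size)]
    gt_shrink_mask = [random_bitmasks() for _ in range(batch_size)]
    gt_thr = [random_bitmasks() for _ in range(batch_size)]
    gt_thr_mask = [random_bitmasks() for _ in range(batch_size)]

    db_loss = DBLoss(bbce_loss=True)
    losses = db_loss(preds, 1.0, gt_shrink, gt_shrink_mask, gt_thr,
                     gt_thr_mask)
    print({k: v.item() for k, v in losses.items()})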