# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument


@LOSSES.register_module()
class PANLoss(nn.Module):
"""The class for implementing PANet loss. This was partially adapted from
https://github.com/WenmuZhou/PAN.pytorch.
PANet: `Efficient and Accurate Arbitrary-
Shaped Text Detection with Pixel Aggregation Network
<https://arxiv.org/abs/1908.05900>`_.
    Args:
        alpha (float): The kernel loss coefficient.
        beta (float): The aggregation and discriminative loss coefficient.
        delta_aggregation (float): The constant for the aggregation loss.
        delta_discrimination (float): The constant for the discriminative
            loss.
        ohem_ratio (float): The negative/positive sample ratio in OHEM.
        reduction (str): The way to reduce the loss, either ``'mean'`` or
            ``'sum'``.
        speedup_bbox_thr (int): Subsample at most this many text instances
            per image when ``0 < speedup_bbox_thr <`` the number of
            instances; non-positive values disable the speed-up.
    """

    def __init__(self,
alpha=0.5,
beta=0.25,
delta_aggregation=0.5,
delta_discrimination=3,
ohem_ratio=3,
reduction='mean',
speedup_bbox_thr=-1):
super().__init__()
        assert reduction in ['mean', 'sum'], \
            "reduction must be in ['mean', 'sum']"
self.alpha = alpha
self.beta = beta
self.delta_aggregation = delta_aggregation
self.delta_discrimination = delta_discrimination
self.ohem_ratio = ohem_ratio
self.reduction = reduction
self.speedup_bbox_thr = speedup_bbox_thr

    def bitmasks2tensor(self, bitmasks, target_sz):
"""Convert Bitmasks to tensor.
Args:
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
for one img.
            target_sz (tuple(int, int)): The target tensor size
                :math:`(H, W)`.
Returns:
list[Tensor]: The list of kernel tensors. Each element stands for
one kernel level.
"""
assert check_argument.is_type_list(bitmasks, BitmapMasks)
assert isinstance(target_sz, tuple)
batch_size = len(bitmasks)
num_masks = len(bitmasks[0])
results = []
for level_inx in range(num_masks):
kernel = []
for batch_inx in range(batch_size):
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
# hxw
mask_sz = mask.shape
# left, right, top, bottom
pad = [
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
]
mask = F.pad(mask, pad, mode='constant', value=0)
kernel.append(mask)
kernel = torch.stack(kernel)
results.append(kernel)
return results

    def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
"""Compute PANet loss.
Args:
preds (Tensor): The output tensor of size :math:`(N, 6, H, W)`.
downsample_ratio (float): The downsample ratio between preds
and the input img.
gt_kernels (list[BitmapMasks]): The kernel list with each element
being the text kernel mask for one img.
gt_mask (list[BitmapMasks]): The effective mask list
with each element being the effective mask for one img.
Returns:
dict: A loss dict with ``loss_text``, ``loss_kernel``,
``loss_aggregation`` and ``loss_discrimination``.
"""
assert check_argument.is_type_list(gt_kernels, BitmapMasks)
assert check_argument.is_type_list(gt_mask, BitmapMasks)
assert isinstance(downsample_ratio, float)
pred_texts = preds[:, 0, :, :]
pred_kernels = preds[:, 1, :, :]
inst_embed = preds[:, 2:, :, :]
feature_sz = preds.size()
mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask}
gt = {}
for key, value in mapping.items():
gt[key] = value
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
gt[key] = [item.to(preds.device) for item in gt[key]]
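        # gt_kernels[0] holds the full text instance maps and gt_kernels[1]
        # the shrunk kernel maps (pixel values are instance indices).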
loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
gt['gt_kernels'][0], gt['gt_kernels'][1], inst_embed)
# compute text loss
sampled_mask = self.ohem_batch(pred_texts.detach(),
gt['gt_kernels'][0], gt['gt_mask'][0])
loss_texts = self.dice_loss_with_logits(pred_texts,
gt['gt_kernels'][0],
sampled_mask)
# compute kernel loss
sampled_masks_kernel = (gt['gt_kernels'][0] > 0.5).float() * (
gt['gt_mask'][0].float())
loss_kernels = self.dice_loss_with_logits(pred_kernels,
gt['gt_kernels'][1],
sampled_masks_kernel)
losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
if self.reduction == 'mean':
losses = [item.mean() for item in losses]
elif self.reduction == 'sum':
losses = [item.sum() for item in losses]
else:
raise NotImplementedError
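        # Overall loss: L = L_text + alpha * L_kernel + beta * (L_agg + L_dis)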
coefs = [1, self.alpha, self.beta, self.beta]
losses = [item * scale for item, scale in zip(losses, coefs)]
results = dict()
results.update(
loss_text=losses[0],
loss_kernel=losses[1],
loss_aggregation=losses[2],
loss_discrimination=losses[3])
return results

    def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
                                        inst_embeds):
"""Compute the aggregation and discrimnative losses.
Args:
gt_texts (Tensor): The ground truth text mask of size
:math:`(N, 1, H, W)`.
gt_kernels (Tensor): The ground truth text kernel mask of
size :math:`(N, 1, H, W)`.
            inst_embeds (Tensor): The text instance embedding tensor
                of size :math:`(N, 4, H, W)`.
Returns:
(Tensor, Tensor): A tuple of aggregation loss and discriminative
loss before reduction.
"""
batch_size = gt_texts.size()[0]
gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
assert inst_embeds.shape[1] == 4
inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)
loss_aggrs = []
loss_discrs = []
for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):
# for each image
text_num = int(text.max().item())
loss_aggr_img = []
kernel_avgs = []
select_num = self.speedup_bbox_thr
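            # Optionally subsample text instances to bound the quadratic cost
            # of the pairwise discrimination term.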
if 0 < select_num < text_num:
inds = np.random.choice(
text_num, select_num, replace=False) + 1
else:
inds = range(1, text_num + 1)
for i in inds:
# for each text instance
kernel_i = (kernel == i) # 0.2ms
if kernel_i.sum() == 0 or (text == i).sum() == 0: # 0.2ms
continue
# compute G_Ki in Eq (2)
avg = embed[:, kernel_i].mean(1) # 0.5ms
kernel_avgs.append(avg)
embed_i = embed[:, text == i] # 0.6ms
# ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
distance = (embed_i - avg.reshape(4, 1)).norm( # 0.5ms
2, dim=0) - self.delta_aggregation
# compute D(p,K_i) in Eq (2)
hinge = torch.max(
distance,
torch.tensor(0, device=distance.device,
dtype=torch.float)).pow(2)
aggr = torch.log(hinge + 1).mean()
loss_aggr_img.append(aggr)
num_inst = len(loss_aggr_img)
if num_inst > 0:
loss_aggr_img = torch.stack(loss_aggr_img).mean()
else:
loss_aggr_img = torch.tensor(
0, device=gt_texts.device, dtype=torch.float)
loss_aggrs.append(loss_aggr_img)
loss_discr_img = 0
for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
# delta_discrimination - ||G(K_i) - G(K_j)||
distance_ij = self.delta_discrimination - (avg_i -
avg_j).norm(2)
# D(K_i,K_j)
D_ij = torch.max(
distance_ij,
torch.tensor(
0, device=distance_ij.device,
dtype=torch.float)).pow(2)
loss_discr_img += torch.log(D_ij + 1)
if num_inst > 1:
loss_discr_img /= (num_inst * (num_inst - 1))
else:
loss_discr_img = torch.tensor(
0, device=gt_texts.device, dtype=torch.float)
if num_inst == 0:
warnings.warn('num of instance is 0')
loss_discrs.append(loss_discr_img)
return torch.stack(loss_aggrs), torch.stack(loss_discrs)

    def dice_loss_with_logits(self, pred, target, mask):
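        """Compute the dice loss ``1 - 2|PT| / (|P|^2 + |T|^2)`` between
        ``sigmoid(pred)`` and the binarized ``target``, restricted to the
        sampled ``mask``; ``smooth`` keeps the ratio finite when both sides
        are empty.
        """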
smooth = 0.001
pred = torch.sigmoid(pred)
target[target <= 0.5] = 0
target[target > 0.5] = 1
pred = pred.contiguous().view(pred.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1)
mask = mask.contiguous().view(mask.size()[0], -1)
pred = pred * mask
target = target * mask
a = torch.sum(pred * target, 1) + smooth
b = torch.sum(pred * pred, 1) + smooth
c = torch.sum(target * target, 1) + smooth
d = (2 * a) / (b + c)
return 1 - d

    def ohem_img(self, text_score, gt_text, gt_mask):
"""Sample the top-k maximal negative samples and all positive samples.
Args:
text_score (Tensor): The text score of size :math:`(H, W)`.
gt_text (Tensor): The ground truth text mask of size
:math:`(H, W)`.
gt_mask (Tensor): The effective region mask of size :math:`(H, W)`.
Returns:
Tensor: The sampled pixel mask of size :math:`(H, W)`.
"""
assert isinstance(text_score, torch.Tensor)
assert isinstance(gt_text, torch.Tensor)
assert isinstance(gt_mask, torch.Tensor)
assert len(text_score.shape) == 2
assert text_score.shape == gt_text.shape
assert gt_text.shape == gt_mask.shape
        pos_num = int(torch.sum(gt_text > 0.5).item()) - int(
            torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item())
        neg_num = int(torch.sum(gt_text <= 0.5).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))
if pos_num == 0 or neg_num == 0:
warnings.warn('pos_num = 0 or neg_num = 0')
return gt_mask.bool()
neg_score = text_score[gt_text <= 0.5]
neg_score_sorted, _ = torch.sort(neg_score, descending=True)
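        # The neg_num-th highest negative score is the cut-off: only the
        # hardest negatives (plus all positives) are kept in the mask.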
threshold = neg_score_sorted[neg_num - 1]
sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * (
gt_mask > 0.5)
return sampled_mask

    def ohem_batch(self, text_scores, gt_texts, gt_mask):
"""OHEM sampling for a batch of imgs.
Args:
            text_scores (Tensor): The text scores of size :math:`(N, H, W)`.
            gt_texts (Tensor): The ground truth text masks of size
                :math:`(N, H, W)`.
            gt_mask (Tensor): The ground truth effective masks of size
                :math:`(N, H, W)`.
        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
"""
assert isinstance(text_scores, torch.Tensor)
assert isinstance(gt_texts, torch.Tensor)
assert isinstance(gt_mask, torch.Tensor)
assert len(text_scores.shape) == 3
assert text_scores.shape == gt_texts.shape
assert gt_texts.shape == gt_mask.shape
sampled_masks = []
for i in range(text_scores.shape[0]):
sampled_masks.append(
self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i]))
sampled_masks = torch.stack(sampled_masks)
return sampled_masks
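

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test showing
# one plausible way to call PANLoss with random predictions and dummy
# BitmapMasks. All shapes and values below (batch of 2, 64x64 feature maps,
# a 4x downsample ratio, a single rectangular text instance) are illustrative
# assumptions, not values mandated by the class.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch, h, w = 2, 64, 64
    preds = torch.randn(batch, 6, h, w)

    # Two mask levels per image: level 0 is the full text instance map,
    # level 1 the shrunk kernel map; pixel values are instance indices.
    text_map = np.zeros((4 * h, 4 * w), dtype=np.uint8)
    text_map[40:120, 40:200] = 1
    kernel_map = np.zeros_like(text_map)
    kernel_map[60:100, 60:180] = 1
    gt_kernels = [
        BitmapMasks(np.stack([text_map, kernel_map]), 4 * h, 4 * w)
        for _ in range(batch)
    ]
    gt_mask = [
        BitmapMasks(np.ones((1, 4 * h, 4 * w), dtype=np.uint8), 4 * h, 4 * w)
        for _ in range(batch)
    ]

    loss = PANLoss()
    print(loss(preds, 0.25, gt_kernels, gt_mask))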