Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from mmdet.core import multi_apply | |
from torch import nn | |
from mmocr.models.builder import LOSSES | |
class FCELoss(nn.Module): | |
"""The class for implementing FCENet loss. | |
FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text | |
Detection <https://arxiv.org/abs/2104.10442>`_ | |
Args: | |
fourier_degree (int) : The maximum Fourier transform degree k. | |
num_sample (int) : The sampling points number of regression | |
loss. If it is too small, fcenet tends to be overfitting. | |
ohem_ratio (float): the negative/positive ratio in OHEM. | |
""" | |
def __init__(self, fourier_degree, num_sample, ohem_ratio=3.): | |
super().__init__() | |
self.fourier_degree = fourier_degree | |
self.num_sample = num_sample | |
self.ohem_ratio = ohem_ratio | |
def forward(self, preds, _, p3_maps, p4_maps, p5_maps): | |
"""Compute FCENet loss. | |
Args: | |
preds (list[list[Tensor]]): The outer list indicates images | |
in a batch, and the inner list indicates the classification | |
prediction map (with shape :math:`(N, C, H, W)`) and | |
regression map (with shape :math:`(N, C, H, W)`). | |
p3_maps (list[ndarray]): List of leval 3 ground truth target map | |
with shape :math:`(C, H, W)`. | |
p4_maps (list[ndarray]): List of leval 4 ground truth target map | |
with shape :math:`(C, H, W)`. | |
p5_maps (list[ndarray]): List of leval 5 ground truth target map | |
with shape :math:`(C, H, W)`. | |
Returns: | |
dict: A loss dict with ``loss_text``, ``loss_center``, | |
``loss_reg_x`` and ``loss_reg_y``. | |
""" | |
assert isinstance(preds, list) | |
assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\ | |
'fourier degree not equal in FCEhead and FCEtarget' | |
device = preds[0][0].device | |
# to tensor | |
gts = [p3_maps, p4_maps, p5_maps] | |
for idx, maps in enumerate(gts): | |
gts[idx] = torch.from_numpy(np.stack(maps)).float().to(device) | |
losses = multi_apply(self.forward_single, preds, gts) | |
loss_tr = torch.tensor(0., device=device).float() | |
loss_tcl = torch.tensor(0., device=device).float() | |
loss_reg_x = torch.tensor(0., device=device).float() | |
loss_reg_y = torch.tensor(0., device=device).float() | |
for idx, loss in enumerate(losses): | |
if idx == 0: | |
loss_tr += sum(loss) | |
elif idx == 1: | |
loss_tcl += sum(loss) | |
elif idx == 2: | |
loss_reg_x += sum(loss) | |
else: | |
loss_reg_y += sum(loss) | |
results = dict( | |
loss_text=loss_tr, | |
loss_center=loss_tcl, | |
loss_reg_x=loss_reg_x, | |
loss_reg_y=loss_reg_y, | |
) | |
return results | |
def forward_single(self, pred, gt): | |
cls_pred = pred[0].permute(0, 2, 3, 1).contiguous() | |
reg_pred = pred[1].permute(0, 2, 3, 1).contiguous() | |
gt = gt.permute(0, 2, 3, 1).contiguous() | |
k = 2 * self.fourier_degree + 1 | |
tr_pred = cls_pred[:, :, :, :2].view(-1, 2) | |
tcl_pred = cls_pred[:, :, :, 2:].view(-1, 2) | |
x_pred = reg_pred[:, :, :, 0:k].view(-1, k) | |
y_pred = reg_pred[:, :, :, k:2 * k].view(-1, k) | |
tr_mask = gt[:, :, :, :1].view(-1) | |
tcl_mask = gt[:, :, :, 1:2].view(-1) | |
train_mask = gt[:, :, :, 2:3].view(-1) | |
x_map = gt[:, :, :, 3:3 + k].view(-1, k) | |
y_map = gt[:, :, :, 3 + k:].view(-1, k) | |
tr_train_mask = train_mask * tr_mask | |
device = x_map.device | |
# tr loss | |
loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long()) | |
# tcl loss | |
loss_tcl = torch.tensor(0.).float().to(device) | |
tr_neg_mask = 1 - tr_train_mask | |
if tr_train_mask.sum().item() > 0: | |
loss_tcl_pos = F.cross_entropy( | |
tcl_pred[tr_train_mask.bool()], | |
tcl_mask[tr_train_mask.bool()].long()) | |
loss_tcl_neg = F.cross_entropy(tcl_pred[tr_neg_mask.bool()], | |
tcl_mask[tr_neg_mask.bool()].long()) | |
loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg | |
# regression loss | |
loss_reg_x = torch.tensor(0.).float().to(device) | |
loss_reg_y = torch.tensor(0.).float().to(device) | |
if tr_train_mask.sum().item() > 0: | |
weight = (tr_mask[tr_train_mask.bool()].float() + | |
tcl_mask[tr_train_mask.bool()].float()) / 2 | |
weight = weight.contiguous().view(-1, 1) | |
ft_x, ft_y = self.fourier2poly(x_map, y_map) | |
ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred) | |
loss_reg_x = torch.mean(weight * F.smooth_l1_loss( | |
ft_x_pre[tr_train_mask.bool()], | |
ft_x[tr_train_mask.bool()], | |
reduction='none')) | |
loss_reg_y = torch.mean(weight * F.smooth_l1_loss( | |
ft_y_pre[tr_train_mask.bool()], | |
ft_y[tr_train_mask.bool()], | |
reduction='none')) | |
return loss_tr, loss_tcl, loss_reg_x, loss_reg_y | |
def ohem(self, predict, target, train_mask): | |
device = train_mask.device | |
pos = (target * train_mask).bool() | |
neg = ((1 - target) * train_mask).bool() | |
n_pos = pos.float().sum() | |
if n_pos.item() > 0: | |
loss_pos = F.cross_entropy( | |
predict[pos], target[pos], reduction='sum') | |
loss_neg = F.cross_entropy( | |
predict[neg], target[neg], reduction='none') | |
n_neg = min( | |
int(neg.float().sum().item()), | |
int(self.ohem_ratio * n_pos.float())) | |
else: | |
loss_pos = torch.tensor(0.).to(device) | |
loss_neg = F.cross_entropy( | |
predict[neg], target[neg], reduction='none') | |
n_neg = 100 | |
if len(loss_neg) > n_neg: | |
loss_neg, _ = torch.topk(loss_neg, n_neg) | |
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float() | |
def fourier2poly(self, real_maps, imag_maps): | |
"""Transform Fourier coefficient maps to polygon maps. | |
Args: | |
real_maps (tensor): A map composed of the real parts of the | |
Fourier coefficients, whose shape is (-1, 2k+1) | |
imag_maps (tensor):A map composed of the imag parts of the | |
Fourier coefficients, whose shape is (-1, 2k+1) | |
Returns | |
x_maps (tensor): A map composed of the x value of the polygon | |
represented by n sample points (xn, yn), whose shape is (-1, n) | |
y_maps (tensor): A map composed of the y value of the polygon | |
represented by n sample points (xn, yn), whose shape is (-1, n) | |
""" | |
device = real_maps.device | |
k_vect = torch.arange( | |
-self.fourier_degree, | |
self.fourier_degree + 1, | |
dtype=torch.float, | |
device=device).view(-1, 1) | |
i_vect = torch.arange( | |
0, self.num_sample, dtype=torch.float, device=device).view(1, -1) | |
transform_matrix = 2 * np.pi / self.num_sample * torch.mm( | |
k_vect, i_vect) | |
x1 = torch.einsum('ak, kn-> an', real_maps, | |
torch.cos(transform_matrix)) | |
x2 = torch.einsum('ak, kn-> an', imag_maps, | |
torch.sin(transform_matrix)) | |
y1 = torch.einsum('ak, kn-> an', real_maps, | |
torch.sin(transform_matrix)) | |
y2 = torch.einsum('ak, kn-> an', imag_maps, | |
torch.cos(transform_matrix)) | |
x_maps = x1 - x2 | |
y_maps = y1 + y2 | |
return x_maps, y_maps | |