# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential

from mmocr.models.builder import HEADS
from .head_mixin import HeadMixin
class DBHead(HeadMixin, BaseModule):
    """The class for DBNet head.

    This was partially adapted from https://github.com/MhLiao/DB

    Args:
        in_channels (int): The number of input channels of the db head.
        with_bias (bool): Whether add bias in Conv2d layer.
        downsample_ratio (float): The downsample ratio of ground truths.
        loss (dict): Config of loss for dbnet.
        postprocessor (dict): Config of postprocessor for dbnet.
        init_cfg (dict or list[dict]): Initialization configs.
        train_cfg (dict, optional): Training config (stored, unused here).
        test_cfg (dict, optional): Testing config (stored, unused here).
    """

    def __init__(self,
                 in_channels,
                 with_bias=False,
                 downsample_ratio=1.0,
                 loss=dict(type='DBLoss'),
                 postprocessor=dict(
                     type='DBPostprocessor', text_repr_type='quad'),
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv'),
                     dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
                 ],
                 train_cfg=None,
                 test_cfg=None,
                 **kwargs):
        # Shallow-copy before any mutation: ``postprocessor`` may be the
        # shared mutable default dict, and writing deprecated kwargs into it
        # directly would leak the override into every later instance that
        # relies on the default.
        postprocessor = dict(postprocessor)
        old_keys = ['text_repr_type', 'decoding_type']
        for key in old_keys:
            if kwargs.get(key, None):
                postprocessor[key] = kwargs.get(key)
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)
        BaseModule.__init__(self, init_cfg=init_cfg)
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # The probability branch is structurally identical to the threshold
        # branch; the only difference is that its first conv may carry a bias.
        self.binarize = self._init_thr(in_channels, bias=with_bias)
        self.threshold = self._init_thr(in_channels)

    def diff_binarize(self, prob_map, thr_map, k):
        """Differentiable binarization: sigmoid(k * (prob_map - thr_map)).

        Written via ``reciprocal``/``exp`` rather than ``torch.sigmoid`` to
        match the reference DB implementation.
        """
        return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): Shape (batch_size, hidden_size, h, w).

        Returns:
            Tensor: Shape (batch_size, 3, 4 * h, 4 * w); the probability,
            threshold and differentiable-binarization maps concatenated
            along the channel dim. Each branch upsamples 4x via two
            stride-2 transposed convolutions.
        """
        prob_map = self.binarize(inputs)
        thr_map = self.threshold(inputs)
        # k=50 is the amplification factor from the DB paper.
        binary_map = self.diff_binarize(prob_map, thr_map, k=50)
        outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
        return outputs

    def _init_thr(self, inner_channels, bias=False):
        """Build one prediction branch: conv -> BN -> ReLU, then two 2x
        transposed-conv upsamplings ending in a single-channel sigmoid map.

        Args:
            inner_channels (int): Number of input channels.
            bias (bool): Whether the first conv layer uses a bias.

        Returns:
            Sequential: The branch module.
        """
        in_channels = inner_channels
        seq = Sequential(
            nn.Conv2d(
                in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
            nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
        return seq