tomofi's picture
Add application file
2366e36
raw
history blame
No virus
3.55 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential
from mmocr.models.builder import HEADS
from .head_mixin import HeadMixin
@HEADS.register_module()
class DBHead(HeadMixin, BaseModule):
    """The class for DBNet head.

    This was partially adapted from https://github.com/MhLiao/DB

    The head owns two structurally identical branches (``binarize`` and
    ``threshold``). Each branch maps the input feature map to a single-channel
    map in ``(0, 1)`` whose spatial size is 4x that of the input (two
    stride-2 transposed convolutions).

    Args:
        in_channels (int): The number of input channels of the db head.
        with_bias (bool): Whether add bias in the first Conv2d layer of the
            ``binarize`` branch. The ``threshold`` branch is always built
            with ``bias=False``.
        downsample_ratio (float): The downsample ratio of ground truths.
        loss (dict): Config of loss for dbnet.
        postprocessor (dict): Config of postprocessor for dbnet. The dict
            passed in is never modified by this class.
        init_cfg (dict or list[dict]): Initialization config forwarded to
            ``BaseModule.__init__``.
        train_cfg (dict, optional): Stored as ``self.train_cfg``.
        test_cfg (dict, optional): Stored as ``self.test_cfg``.
        **kwargs: Only the deprecated keys ``text_repr_type`` and
            ``decoding_type`` are read; when truthy they override the
            corresponding entries of ``postprocessor`` and a ``UserWarning``
            is emitted. Any other keyword is ignored.
    """

    def __init__(
            self,
            in_channels,
            with_bias=False,
            downsample_ratio=1.0,
            loss=dict(type='DBLoss'),
            postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
            init_cfg=[
                dict(type='Kaiming', layer='Conv'),
                dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
            ],
            train_cfg=None,
            test_cfg=None,
            **kwargs):
        # Collect deprecated keyword overrides (falsy values are ignored,
        # matching the historical behaviour).
        old_keys = ['text_repr_type', 'decoding_type']
        overrides = {
            key: kwargs[key]
            for key in old_keys if kwargs.get(key, None)
        }
        if overrides:
            # Work on a copy: ``postprocessor`` may be the shared default
            # dict of this signature or a config dict owned by the caller,
            # and writing into it would leak the override into every other
            # DBHead built from the same object.
            postprocessor = postprocessor.copy()
            for key, value in overrides.items():
                postprocessor[key] = value
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)

        BaseModule.__init__(self, init_cfg=init_cfg)
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)

        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # Probability branch: 3x3 conv -> two 2x transposed-conv upsamplings
        # -> sigmoid, producing a 1-channel map.
        self.binarize = Sequential(
            nn.Conv2d(
                in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
            nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())

        # Threshold branch: same layout as ``binarize``.
        self.threshold = self._init_thr(in_channels)

    def diff_binarize(self, prob_map, thr_map, k):
        """Compute the differentiable binarization of a probability map.

        Evaluates ``1 / (1 + exp(-k * (prob_map - thr_map)))`` element-wise.

        Args:
            prob_map (Tensor): Probability map.
            thr_map (Tensor): Threshold map, broadcastable with ``prob_map``.
            k (float): Amplification factor applied to the difference.

        Returns:
            Tensor: The approximate binary map.
        """
        return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

    def forward(self, inputs):
        """Run both branches and stack their outputs.

        Args:
            inputs (Tensor): Shape (batch_size, in_channels, h, w).

        Returns:
            Tensor: Shape (batch_size, 3, 4 * h, 4 * w). Along dim 1 the
            channels are, in order: the probability map, the threshold map
            and the approximate binary map computed with ``k=50``.
        """
        prob_map = self.binarize(inputs)
        thr_map = self.threshold(inputs)
        binary_map = self.diff_binarize(prob_map, thr_map, k=50)
        outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
        return outputs

    def _init_thr(self, inner_channels, bias=False):
        """Build the threshold branch.

        Args:
            inner_channels (int): Number of input channels of the branch.
            bias (bool): Whether the first Conv2d layer has a bias term.

        Returns:
            Sequential: 3x3 conv, two stride-2 transposed convolutions (each
            preceded by BatchNorm + ReLU) and a final sigmoid, yielding a
            single-channel map.
        """
        seq = Sequential(
            nn.Conv2d(
                inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
            nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
        return seq