tomofi's picture
Add application file
2366e36
raw
history blame
10.7 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmocr.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.utils import check_argument
from .head_mixin import HeadMixin
@HEADS.register_module()
class DRRGHead(HeadMixin, BaseModule):
"""The class for DRRG head: `Deep Relational Reasoning Graph Network for
Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
num_adjacent_linkages (int): The number of linkages when constructing
adjacent matrix.
node_geo_feat_len (int): The length of embedded geometric feature
vector of a component.
pooling_scale (float): The spatial scale of rotated RoI-Align.
pooling_output_size (tuple(int)): The output size of RRoI-Aligning.
nms_thr (float): The locality-aware NMS threshold of text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_ratio (float): The reciprocal of aspect ratio of text components.
comp_score_thr (float): The score threshold of text components.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering small-sized
text center region.
local_graph_thr (float): The threshold to filter identical local
graphs.
loss (dict): The config of loss that DRRGHead uses..
postprocessor (dict): Config of postprocessor for Drrg.
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(self,
in_channels,
k_at_hops=(8, 4),
num_adjacent_linkages=3,
node_geo_feat_len=120,
pooling_scale=1.0,
pooling_output_size=(4, 3),
nms_thr=0.3,
min_width=8.0,
max_width=24.0,
comp_shrink_ratio=1.03,
comp_ratio=0.4,
comp_score_thr=0.3,
text_region_thr=0.2,
center_region_thr=0.2,
center_region_area_thr=50,
local_graph_thr=0.7,
loss=dict(type='DRRGLoss'),
postprocessor=dict(type='DRRGPostprocessor', link_thr=0.85),
train_cfg=None,
test_cfg=None,
init_cfg=dict(
type='Normal',
override=dict(name='out_conv'),
mean=0,
std=0.01),
**kwargs):
old_keys = ['text_repr_type', 'decoding_type', 'link_thr']
for key in old_keys:
if kwargs.get(key, None):
postprocessor[key] = kwargs.get(key)
warnings.warn(
f'{key} is deprecated, please specify '
'it in postprocessor config dict. See '
'https://github.com/open-mmlab/mmocr/pull/640'
' for details.', UserWarning)
BaseModule.__init__(self, init_cfg=init_cfg)
HeadMixin.__init__(self, loss, postprocessor)
assert isinstance(in_channels, int)
assert isinstance(k_at_hops, tuple)
assert isinstance(num_adjacent_linkages, int)
assert isinstance(node_geo_feat_len, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_ratio, float)
assert isinstance(comp_score_thr, float)
assert isinstance(text_region_thr, float)
assert isinstance(center_region_thr, float)
assert isinstance(center_region_area_thr, int)
assert isinstance(local_graph_thr, float)
self.in_channels = in_channels
self.out_channels = 6
self.downsample_ratio = 1.0
self.k_at_hops = k_at_hops
self.num_adjacent_linkages = num_adjacent_linkages
self.node_geo_feat_len = node_geo_feat_len
self.pooling_scale = pooling_scale
self.pooling_output_size = pooling_output_size
self.comp_shrink_ratio = comp_shrink_ratio
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_ratio = comp_ratio
self.comp_score_thr = comp_score_thr
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
self.local_graph_thr = local_graph_thr
self.loss_module = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.out_conv = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0)
self.graph_train = LocalGraphs(self.k_at_hops,
self.num_adjacent_linkages,
self.node_geo_feat_len,
self.pooling_scale,
self.pooling_output_size,
self.local_graph_thr)
self.graph_test = ProposalLocalGraphs(
self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
self.pooling_scale, self.pooling_output_size, self.nms_thr,
self.min_width, self.max_width, self.comp_shrink_ratio,
self.comp_ratio, self.comp_score_thr, self.text_region_thr,
self.center_region_thr, self.center_region_area_thr)
pool_w, pool_h = self.pooling_output_size
node_feat_len = (pool_w * pool_h) * (
self.in_channels + self.out_channels) + self.node_geo_feat_len
self.gcn = GCN(node_feat_len)
def forward(self, inputs, gt_comp_attribs):
"""
Args:
inputs (Tensor): Shape of :math:`(N, C, H, W)`.
gt_comp_attribs (list[ndarray]): The padded text component
attributes. Shape: (num_component, 8).
Returns:
tuple: Returns (pred_maps, (gcn_pred, gt_labels)).
- | pred_maps (Tensor): Prediction map with shape
:math:`(N, C_{out}, H, W)`.
- | gcn_pred (Tensor): Prediction from GCN module, with
shape :math:`(N, 2)`.
- | gt_labels (Tensor): Ground-truth label with shape
:math:`(N, 8)`.
"""
pred_maps = self.out_conv(inputs)
feat_maps = torch.cat([inputs, pred_maps], dim=1)
node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
feat_maps, np.stack(gt_comp_attribs))
gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
return pred_maps, (gcn_pred, gt_labels)
def single_test(self, feat_maps):
r"""
Args:
feat_maps (Tensor): Shape of :math:`(N, C, H, W)`.
Returns:
tuple: Returns (edge, score, text_comps).
- | edge (ndarray): The edge array of shape :math:`(N, 2)`
where each row is a pair of text component indices
that makes up an edge in graph.
- | score (ndarray): The score array of shape :math:`(N,)`,
corresponding to the edge above.
- | text_comps (ndarray): The text components of shape
:math:`(N, 9)` where each row corresponds to one box and
its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
"""
pred_maps = self.out_conv(feat_maps)
feat_maps = torch.cat([feat_maps, pred_maps], dim=1)
none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivot_local_graphs, text_comps) = graph_data
if none_flag:
return None, None, None
gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
pivots_knn_inds)
pred_labels = F.softmax(gcn_pred, dim=1)
edges = []
scores = []
pivot_local_graphs = pivot_local_graphs.long().squeeze().cpu().numpy()
for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
pivot = pivot_local_graph[0]
for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
neighbor = pivot_local_graph[neighbor_ind.item()]
edges.append([pivot, neighbor])
scores.append(
pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind,
1].item())
edges = np.asarray(edges)
scores = np.asarray(scores)
return edges, scores, text_comps
def get_boundary(self, edges, scores, text_comps, img_metas, rescale):
"""Compute text boundaries via post processing.
Args:
edges (ndarray): The edge array of shape N * 2, each row is a pair
of text component indices that makes up an edge in graph.
scores (ndarray): The edge score array.
text_comps (ndarray): The text components.
img_metas (list[dict]): The image meta infos.
rescale (bool): Rescale boundaries to the original image
resolution.
Returns:
dict: The result dict containing key `boundary_result`.
"""
assert check_argument.is_type_list(img_metas, dict)
assert isinstance(rescale, bool)
boundaries = []
if edges is not None:
boundaries = self.postprocessor(edges, scores, text_comps)
if rescale:
boundaries = self.resize_boundary(
boundaries,
1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
results = dict(boundary_result=boundaries)
return results