MMOCR / mmocr /models /textdet /modules /proposal_local_graph.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
18.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch
from lanms import merge_quadrangle_n9 as la_nms
from mmcv.ops import RoIAlignRotated
from mmocr.models.textdet.postprocess.utils import fill_hole
from .utils import (euclidean_distance_matrix, feature_embedding,
normalize_adjacent_matrix)
class ProposalLocalGraphs:
"""Propose text components and generate local graphs for GCN to classify
the k-nearest neighbors of a pivot in DRRG: Deep Relational Reasoning Graph
Network for Arbitrary Shape Text Detection.
[https://arxiv.org/abs/2003.07493]. This code was partially adapted from
https://github.com/GXYM/DRRG licensed under the MIT license.
Args:
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
num_adjacent_linkages (int): The number of linkages when constructing
adjacent matrix.
node_geo_feat_len (int): The length of embedded geometric feature
vector of a text component.
pooling_scale (float): The spatial scale of rotated RoI-Align.
pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
nms_thr (float): The locality-aware NMS threshold for text components.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_w_h_ratio (float): The width to height ratio of text components.
comp_score_thr (float): The score threshold of text component.
text_region_thr (float): The threshold for text region probability map.
center_region_thr (float): The threshold for text center region
probability map.
center_region_area_thr (int): The threshold for filtering small-sized
text center region.
"""
def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
pooling_scale, pooling_output_size, nms_thr, min_width,
max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
text_region_thr, center_region_thr, center_region_area_thr):
assert len(k_at_hops) == 2
assert isinstance(k_at_hops, tuple)
assert isinstance(num_adjacent_linkages, int)
assert isinstance(node_geo_feat_len, int)
assert isinstance(pooling_scale, float)
assert isinstance(pooling_output_size, tuple)
assert isinstance(nms_thr, float)
assert isinstance(min_width, float)
assert isinstance(max_width, float)
assert isinstance(comp_shrink_ratio, float)
assert isinstance(comp_w_h_ratio, float)
assert isinstance(comp_score_thr, float)
assert isinstance(text_region_thr, float)
assert isinstance(center_region_thr, float)
assert isinstance(center_region_area_thr, int)
self.k_at_hops = k_at_hops
self.active_connection = num_adjacent_linkages
self.local_graph_depth = len(self.k_at_hops)
self.node_geo_feat_dim = node_geo_feat_len
self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
self.nms_thr = nms_thr
self.min_width = min_width
self.max_width = max_width
self.comp_shrink_ratio = comp_shrink_ratio
self.comp_w_h_ratio = comp_w_h_ratio
self.comp_score_thr = comp_score_thr
self.text_region_thr = text_region_thr
self.center_region_thr = center_region_thr
self.center_region_area_thr = center_region_area_thr
def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
cos_map, comp_score_thr, min_width, max_width,
comp_shrink_ratio, comp_w_h_ratio):
"""Propose text components.
Args:
score_map (ndarray): The score map for NMS.
top_height_map (ndarray): The predicted text height map from each
pixel in text center region to top sideline.
bot_height_map (ndarray): The predicted text height map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
comp_score_thr (float): The score threshold of text component.
min_width (float): The minimum width of text components.
max_width (float): The maximum width of text components.
comp_shrink_ratio (float): The shrink ratio of text components.
comp_w_h_ratio (float): The width to height ratio of text
components.
Returns:
text_comps (ndarray): The text components.
"""
comp_centers = np.argwhere(score_map > comp_score_thr)
comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
y = comp_centers[:, 0]
x = comp_centers[:, 1]
top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
top_mid_pts = comp_centers + np.hstack(
[top_height * sin, top_height * cos])
bot_mid_pts = comp_centers - np.hstack(
[bot_height * sin, bot_height * cos])
width = (top_height + bot_height) * comp_w_h_ratio
width = np.clip(width, min_width, max_width)
r = width / 2
tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
score = score_map[y, x].reshape((-1, 1))
text_comps = np.hstack([text_comps, score])
return text_comps
def propose_comps_and_attribs(self, text_region_map, center_region_map,
top_height_map, bot_height_map, sin_map,
cos_map):
"""Generate text components and attributes.
Args:
text_region_map (ndarray): The predicted text region probability
map.
center_region_map (ndarray): The predicted text center region
probability map.
top_height_map (ndarray): The predicted text height map from each
pixel in text center region to top sideline.
bot_height_map (ndarray): The predicted text height map from each
pixel in text center region to bottom sideline.
sin_map (ndarray): The predicted sin(theta) map.
cos_map (ndarray): The predicted cos(theta) map.
Returns:
comp_attribs (ndarray): The text component attributes.
text_comps (ndarray): The text components.
"""
assert (text_region_map.shape == center_region_map.shape ==
top_height_map.shape == bot_height_map.shape == sin_map.shape
== cos_map.shape)
text_mask = text_region_map > self.text_region_thr
center_region_mask = (center_region_map >
self.center_region_thr) * text_mask
scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
sin_map, cos_map = sin_map * scale, cos_map * scale
center_region_mask = fill_hole(center_region_mask)
center_region_contours, _ = cv2.findContours(
center_region_mask.astype(np.uint8), cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
mask_sz = center_region_map.shape
comp_list = []
for contour in center_region_contours:
current_center_mask = np.zeros(mask_sz)
cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
if current_center_mask.sum() <= self.center_region_area_thr:
continue
score_map = text_region_map * current_center_mask
text_comps = self.propose_comps(score_map, top_height_map,
bot_height_map, sin_map, cos_map,
self.comp_score_thr,
self.min_width, self.max_width,
self.comp_shrink_ratio,
self.comp_w_h_ratio)
text_comps = la_nms(text_comps, self.nms_thr)
text_comp_mask = np.zeros(mask_sz)
text_comp_boxes = text_comps[:, :8].reshape(
(-1, 4, 2)).astype(np.int32)
cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
continue
if text_comps.shape[-1] > 0:
comp_list.append(text_comps)
if len(comp_list) <= 0:
return None, None
text_comps = np.vstack(comp_list)
text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
x = centers[:, 0]
y = centers[:, 1]
scores = []
for text_comp_box in text_comp_boxes:
text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
mask_sz[1] - 1)
text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
mask_sz[0] - 1)
min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
text_comp_box = text_comp_box - min_coord
box_sz = (max_coord - min_coord + 1)
temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] +
1),
min_coord[0]:(max_coord[0] +
1)]
score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
scores.append(score)
scores = np.array(scores).reshape((-1, 1))
text_comps = np.hstack([text_comps[:, :-1], scores])
h = top_height_map[y, x].reshape(
(-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
sin = sin_map[y, x].reshape((-1, 1))
cos = cos_map[y, x].reshape((-1, 1))
x = x.reshape((-1, 1))
y = y.reshape((-1, 1))
comp_attribs = np.hstack([x, y, h, w, cos, sin])
return comp_attribs, text_comps
def generate_local_graphs(self, sorted_dist_inds, node_feats):
"""Generate local graphs and graph convolution network input data.
Args:
sorted_dist_inds (ndarray): The node indices sorted according to
the Euclidean distance.
node_feats (tensor): The features of nodes in graph.
Returns:
local_graphs_node_feats (tensor): The features of nodes in local
graphs.
adjacent_matrices (tensor): The adjacent matrices.
pivots_knn_inds (tensor): The k-nearest neighbor indices in
local graphs.
pivots_local_graphs (tensor): The indices of nodes in local
graphs.
"""
assert sorted_dist_inds.ndim == 2
assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
node_feats.shape[0])
knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
pivot_local_graphs = []
pivot_knns = []
device = node_feats.device
for pivot_ind, knn in enumerate(knn_graph):
local_graph_neighbors = set(knn)
for neighbor_ind in knn:
local_graph_neighbors.update(
set(sorted_dist_inds[neighbor_ind,
1:self.k_at_hops[1] + 1]))
local_graph_neighbors.discard(pivot_ind)
pivot_local_graph = list(local_graph_neighbors)
pivot_local_graph.insert(0, pivot_ind)
pivot_knn = [pivot_ind] + list(knn)
pivot_local_graphs.append(pivot_local_graph)
pivot_knns.append(pivot_knn)
num_max_nodes = max([
len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
])
local_graphs_node_feat = []
adjacent_matrices = []
pivots_knn_inds = []
pivots_local_graphs = []
for graph_ind, pivot_knn in enumerate(pivot_knns):
pivot_local_graph = pivot_local_graphs[graph_ind]
num_nodes = len(pivot_local_graph)
pivot_ind = pivot_local_graph[0]
node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
knn_inds = torch.tensor([node2ind_map[i]
for i in pivot_knn[1:]]).long().to(device)
pivot_feats = node_feats[pivot_ind]
normalized_feats = node_feats[pivot_local_graph] - pivot_feats
adjacent_matrix = np.zeros((num_nodes, num_nodes))
for node in pivot_local_graph:
neighbors = sorted_dist_inds[node,
1:self.active_connection + 1]
for neighbor in neighbors:
if neighbor in pivot_local_graph:
adjacent_matrix[node2ind_map[node],
node2ind_map[neighbor]] = 1
adjacent_matrix[node2ind_map[neighbor],
node2ind_map[node]] = 1
adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
pad_adjacent_matrix = torch.zeros((num_max_nodes, num_max_nodes),
dtype=torch.float,
device=device)
pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
adjacent_matrix)
pad_normalized_feats = torch.cat([
normalized_feats,
torch.zeros(
(num_max_nodes - num_nodes, normalized_feats.shape[1]),
dtype=torch.float,
device=device)
],
dim=0)
local_graph_nodes = torch.tensor(pivot_local_graph)
local_graph_nodes = torch.cat([
local_graph_nodes,
torch.zeros(num_max_nodes - num_nodes, dtype=torch.long)
],
dim=-1)
local_graphs_node_feat.append(pad_normalized_feats)
adjacent_matrices.append(pad_adjacent_matrix)
pivots_knn_inds.append(knn_inds)
pivots_local_graphs.append(local_graph_nodes)
local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
adjacent_matrices = torch.stack(adjacent_matrices, 0)
pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
pivots_local_graphs = torch.stack(pivots_local_graphs, 0)
return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivots_local_graphs)
def __call__(self, preds, feat_maps):
"""Generate local graphs and graph convolutional network input data.
Args:
preds (tensor): The predicted maps.
feat_maps (tensor): The feature maps to extract content feature of
text components.
Returns:
none_flag (bool): The flag showing whether the number of proposed
text components is 0.
local_graphs_node_feats (tensor): The features of nodes in local
graphs.
adjacent_matrices (tensor): The adjacent matrices.
pivots_knn_inds (tensor): The k-nearest neighbor indices in
local graphs.
pivots_local_graphs (tensor): The indices of nodes in local
graphs.
text_comps (ndarray): The predicted text components.
"""
if preds.ndim == 4:
assert preds.shape[0] == 1
preds = torch.squeeze(preds)
pred_text_region = torch.sigmoid(preds[0]).data.cpu().numpy()
pred_center_region = torch.sigmoid(preds[1]).data.cpu().numpy()
pred_sin_map = preds[2].data.cpu().numpy()
pred_cos_map = preds[3].data.cpu().numpy()
pred_top_height_map = preds[4].data.cpu().numpy()
pred_bot_height_map = preds[5].data.cpu().numpy()
device = preds.device
comp_attribs, text_comps = self.propose_comps_and_attribs(
pred_text_region, pred_center_region, pred_top_height_map,
pred_bot_height_map, pred_sin_map, pred_cos_map)
if comp_attribs is None or len(comp_attribs) < 2:
none_flag = True
return none_flag, (0, 0, 0, 0, 0)
comp_centers = comp_attribs[:, 0:2]
distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
geo_feats = torch.from_numpy(geo_feats).to(preds.device)
batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
comp_attribs = comp_attribs.astype(np.float32)
angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
angle = angle.reshape((-1, 1))
rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
rois = torch.from_numpy(rotated_rois).to(device)
content_feats = self.pooling(feat_maps, rois)
content_feats = content_feats.view(content_feats.shape[0],
-1).to(device)
node_feats = torch.cat([content_feats, geo_feats], dim=-1)
sorted_dist_inds = np.argsort(distance_matrix, axis=1)
(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
pivots_local_graphs) = self.generate_local_graphs(
sorted_dist_inds, node_feats)
none_flag = False
return none_flag, (local_graphs_node_feat, adjacent_matrices,
pivots_knn_inds, pivots_local_graphs, text_comps)