MMOCR / mmocr /datasets /kie_dataset.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
8.71 kB
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from os import path as osp
import numpy as np
import torch
from mmdet.datasets.builder import DATASETS
from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.utils import is_type_list, list_from_file
@DATASETS.register_module()
class KIEDataset(BaseDataset):
    """Dataset for key information extraction (KIE) over text boxes.

    Each image's annotations are turned into a graph: node inputs are the
    encoded box texts, pairwise relation features are computed from box
    geometry, and targets are per-box labels concatenated with a pairwise
    edge matrix.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        dict_file (str): Character dict file path. Line ``i`` (1-based) is
            mapped to index ``i``; index 0 is reserved for the empty string.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Norm to map value from one range to another. Box
            offsets in the relation features are divided by it.
        directed (bool): If True, a link (i, j) between two boxes sharing
            the same ``edge`` id is kept only where ``(link & label_i) == 1``.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.

    Note:
        When both ``ann_file`` and ``loader`` are None the dataset runs in
        demo mode: ``BaseDataset.__init__`` is not called, and only the
        text dictionary, ``norm`` and ``directed`` are set up.
    """

    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        # NOTE: extra ``kwargs`` are accepted but currently ignored.
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(
                ann_file,
                loader,
                pipeline,
                img_prefix=img_prefix,
                test_mode=test_mode)
            # Checking for None first gives a clear assertion instead of the
            # TypeError ``osp.exists(None)`` would raise.
            assert dict_file is not None and osp.exists(dict_file), (
                f'dict_file {dict_file!r} does not exist')

        self.norm = norm
        self.directed = directed
        # Index 0 is reserved for the empty string; characters start at 1.
        self.dict = {
            '': 0,
            **{
                line.rstrip('\r\n'): ind
                for ind, line in enumerate(list_from_file(dict_file), 1)
            }
        }

    def pre_pipeline(self, results):
        """Fill in the keys the pipeline expects before it runs.

        Args:
            results (dict): Must contain ``img_info`` (with ``filename``) and
                ``ann_info`` (with ``ori_texts``); updated in place.
        """
        results['img_prefix'] = self.img_prefix
        results['bbox_fields'] = []
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['filename'] = osp.join(self.img_prefix,
                                       results['img_info']['filename'])
        results['ori_filename'] = results['img_info']['filename']
        # a dummy img data: an empty placeholder array, no image is read here
        results['img'] = np.zeros((0, 0, 0), dtype=np.uint8)

    def _parse_anno_info(self, annotations):
        """Parse annotations of boxes, texts and labels for one image.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict is for one text box. ``box`` and ``text`` are
                required (only the first 8 values of ``box`` are used);
                ``label`` and ``edge`` are optional and default to 0.

        Returns:
            dict: A dict containing the following keys:

                - bboxes (np.ndarray): Axis-aligned bbox of each box with
                  shape box_num * 4 (x_min, y_min, x_max, y_max). The box
                  vertices are passed through ``sort_vertex8`` first.
                - relations (np.ndarray): Relations between bbox with shape:
                  box_num * box_num * 5.
                - texts (np.ndarray): Text index with shape:
                  box_num * text_max_len, padded with -1.
                - ori_texts (list[str]): Raw text of each box.
                - labels (np.ndarray): Box Labels with shape:
                  box_num * (box_num + 1).
        """
        assert is_type_list(annotations, dict)
        assert len(annotations) > 0, (
            'Please remove data with empty annotation')
        assert 'box' in annotations[0]
        assert 'text' in annotations[0]

        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            boxes.append(sort_vertex8(ann['box'][:8]))
            text = ann['text']
            texts.append(text)
            # Characters missing from the dict are silently dropped.
            text_inds.append([self.dict[c] for c in text if c in self.dict])
            labels.append(ann.get('label', 0))
            edges.append(ann.get('edge', 0))

        ann_infos = dict(
            boxes=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)

        return self.list_to_numpy(ann_infos)

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=dict(macro_f1=dict(ignores=[])),
                 **kwargs):
        """Evaluate node predictions against the dataset annotations.

        Args:
            results (list[dict]): One result per data info, in dataset
                order; each must have a ``nodes`` tensor.
            metric (str | list[str]): Metric(s) to evaluate. Only
                ``'macro_f1'`` is supported.
            metric_options (dict): Per-metric keyword arguments; the
                ``macro_f1`` entry is forwarded to ``compute_macro_f1``.
            **kwargs: Only ``logger`` is accepted; it is not used.

        Returns:
            dict: ``{'macro_f1': ...}``.

        Raises:
            KeyError: If an unsupported metric is requested.
        """
        # allow some kwargs to pass through
        assert set(kwargs).issubset(['logger'])

        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results, **metric_options['macro_f1'])

    def compute_macro_f1(self, results, ignores=None):
        """Average the scores returned by ``compute_f1_score`` over nodes.

        Args:
            results (list[dict]): ``results[i]`` holds a ``nodes`` tensor of
                predictions for ``self.data_infos[i]``.
            ignores (list | None): Forwarded to ``compute_f1_score`` as its
                ``ignores`` argument (presumably labels to exclude — confirm
                in ``mmocr.core``). None means an empty list.

        Returns:
            dict: ``{'macro_f1': mean of the scores from compute_f1_score}``.
        """
        # None instead of a shared mutable ``[]`` default.
        if ignores is None:
            ignores = []

        node_preds = []
        node_gts = []
        for idx, result in enumerate(results):
            node_preds.append(result['nodes'].cpu())
            box_ann_infos = self.data_infos[idx]['annotations']
            # Missing labels default to 0, matching ``_parse_anno_info`` so
            # evaluation targets agree with training targets.
            node_gt = [
                box_ann_info.get('label', 0) for box_ann_info in box_ann_infos
            ]
            node_gts.append(torch.Tensor(node_gt))

        node_preds = torch.cat(node_preds)
        node_gts = torch.cat(node_gts).int()

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray.

        Args:
            ann_infos (dict): Lists under ``boxes``, ``texts``,
                ``text_inds`` and optionally ``labels`` and ``edges``.

        Returns:
            dict: Keys ``bboxes``, ``relations``, ``texts`` (padded text
            indices), ``ori_texts`` and ``labels``. ``labels`` is None when
            no labels are given. When edges are also given it has shape
            (box_num, box_num + 1): the node label in column 0 followed by
            the pairwise link matrix, whose diagonal is -1.
        """
        boxes = np.array(ann_infos['boxes'], np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Two boxes are linked iff they share the same edge id.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # ``&`` binds tighter than ``==``; the parentheses only
                    # make that explicit. ``labels`` broadcasts along rows,
                    # so row i is masked by box i's label.
                    edges = ((edges & labels) == 1).astype(np.int32)
                # Self-links on the diagonal are marked with -1.
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)

        padded_text_inds = self.pad_text_indices(ann_infos['text_inds'])

        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=padded_text_inds,
            ori_texts=ann_infos['texts'],
            labels=labels)

    def pad_text_indices(self, text_inds):
        """Pad text index to same length.

        Args:
            text_inds (list[list[int]]): Character indices of each text.

        Returns:
            np.ndarray: int32 array of shape (len(text_inds), max_len) where
            positions past the end of each text are filled with -1.
        """
        # ``default=0`` keeps an empty input from raising ValueError.
        max_len = max((len(text_ind) for text_ind in text_inds), default=0)
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds

    def compute_relation(self, boxes):
        """Compute relation between every two boxes.

        Args:
            boxes (np.ndarray): (box_num, 2k) array of interleaved x, y
                vertex coordinates.

        Returns:
            tuple[np.ndarray, np.ndarray]:

            - relation: float32, shape (box_num, box_num, 5). For pair
              (i, j) the features are ``(x1_j - x1_i) / norm``,
              ``(y1_j - y1_i) / norm``, ``w_i / h_i``, ``h_j / h_i`` and
              ``w_j / h_i``.
            - bboxes: float32, shape (box_num, 4), the axis-aligned
              (x_min, y_min, x_max, y_max) of each box.
        """
        # Get minimal axis-aligned bounding boxes for each of the boxes
        # yapf: disable
        bboxes = np.concatenate(
            [boxes[:, 0::2].min(axis=1, keepdims=True),
             boxes[:, 1::2].min(axis=1, keepdims=True),
             boxes[:, 0::2].max(axis=1, keepdims=True),
             boxes[:, 1::2].max(axis=1, keepdims=True)],
            axis=1).astype(np.float32)
        # yapf: enable
        x1, y1 = bboxes[:, 0:1], bboxes[:, 1:2]
        x2, y2 = bboxes[:, 2:3], bboxes[:, 3:4]
        # Inclusive pixel extents, clamped to >= 1 so the ratios below never
        # divide by zero.
        w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
        # Column vector vs. its transpose gives [i, j] = value_j - value_i.
        dx = (x1.T - x1) / self.norm
        dy = (y1.T - y1) / self.norm
        xhh, xwh = h.T / h, w.T / h
        # Broadcast box i's own aspect ratio across every j.
        whs = w / h + np.zeros_like(xhh)
        relation = np.stack([dx, dy, whs, xhh, xwh], -1).astype(np.float32)
        return relation, bboxes