MMOCR / mmocr /datasets /pipelines /custom_format_bundle.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
2.21 kB
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.formating import DefaultFormatBundle
from mmocr.core.visualize import overlay_mask_img, show_feature
@PIPELINES.register_module()
class CustomFormatBundle(DefaultFormatBundle):
"""Custom formatting bundle.
It formats common fields such as 'img' and 'proposals' as done in
DefaultFormatBundle, while other fields such as 'gt_kernels' and
'gt_effective_region_mask' will be formatted to DC as follows:
- gt_kernels: to DataContainer (cpu_only=True)
- gt_effective_mask: to DataContainer (cpu_only=True)
Args:
keys (list[str]): Fields to be formatted to DC only.
call_super (bool): If True, format common fields
by DefaultFormatBundle, else format fields in keys above only.
visualize (dict): If flag=True, visualize gt mask for debugging.
"""
def __init__(self,
keys=[],
call_super=True,
visualize=dict(flag=False, boundary_key=None)):
super().__init__()
self.visualize = visualize
self.keys = keys
self.call_super = call_super
def __call__(self, results):
if self.visualize['flag']:
img = results['img'].astype(np.uint8)
boundary_key = self.visualize['boundary_key']
if boundary_key is not None:
img = overlay_mask_img(img, results[boundary_key].masks[0])
features = [img]
names = ['img']
to_uint8 = [1]
for k in results['mask_fields']:
for iter in range(len(results[k].masks)):
features.append(results[k].masks[iter])
names.append(k + str(iter))
to_uint8.append(0)
show_feature(features, names, to_uint8)
if self.call_super:
results = super().__call__(results)
for k in self.keys:
results[k] = DC(results[k], cpu_only=True)
return results
def __repr__(self):
return self.__class__.__name__