Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
from mmcv.parallel import DataContainer as DC | |
from mmdet.datasets.builder import PIPELINES | |
from mmdet.datasets.pipelines.formating import DefaultFormatBundle | |
from mmocr.core.visualize import overlay_mask_img, show_feature | |
class CustomFormatBundle(DefaultFormatBundle): | |
"""Custom formatting bundle. | |
It formats common fields such as 'img' and 'proposals' as done in | |
DefaultFormatBundle, while other fields such as 'gt_kernels' and | |
'gt_effective_region_mask' will be formatted to DC as follows: | |
- gt_kernels: to DataContainer (cpu_only=True) | |
- gt_effective_mask: to DataContainer (cpu_only=True) | |
Args: | |
keys (list[str]): Fields to be formatted to DC only. | |
call_super (bool): If True, format common fields | |
by DefaultFormatBundle, else format fields in keys above only. | |
visualize (dict): If flag=True, visualize gt mask for debugging. | |
""" | |
def __init__(self, | |
keys=[], | |
call_super=True, | |
visualize=dict(flag=False, boundary_key=None)): | |
super().__init__() | |
self.visualize = visualize | |
self.keys = keys | |
self.call_super = call_super | |
def __call__(self, results): | |
if self.visualize['flag']: | |
img = results['img'].astype(np.uint8) | |
boundary_key = self.visualize['boundary_key'] | |
if boundary_key is not None: | |
img = overlay_mask_img(img, results[boundary_key].masks[0]) | |
features = [img] | |
names = ['img'] | |
to_uint8 = [1] | |
for k in results['mask_fields']: | |
for iter in range(len(results[k].masks)): | |
features.append(results[k].masks[iter]) | |
names.append(k + str(iter)) | |
to_uint8.append(0) | |
show_feature(features, names, to_uint8) | |
if self.call_super: | |
results = super().__call__(results) | |
for k in self.keys: | |
results[k] = DC(results[k], cpu_only=True) | |
return results | |
def __repr__(self): | |
return self.__class__.__name__ | |