Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import imgaug | |
import imgaug.augmenters as iaa | |
import mmcv | |
import numpy as np | |
from mmdet.core.mask import PolygonMasks | |
from mmdet.datasets.builder import PIPELINES | |
class AugmenterBuilder: | |
"""Build imgaug object according ImgAug argmentations.""" | |
def __init__(self): | |
pass | |
def build(self, args, root=True): | |
if args is None: | |
return None | |
if isinstance(args, (int, float, str)): | |
return args | |
if isinstance(args, list): | |
if root: | |
sequence = [self.build(value, root=False) for value in args] | |
return iaa.Sequential(sequence) | |
arg_list = [self.to_tuple_if_list(a) for a in args[1:]] | |
return getattr(iaa, args[0])(*arg_list) | |
if isinstance(args, dict): | |
if 'cls' in args: | |
cls = getattr(iaa, args['cls']) | |
return cls( | |
**{ | |
k: self.to_tuple_if_list(v) | |
for k, v in args.items() if not k == 'cls' | |
}) | |
else: | |
return { | |
key: self.build(value, root=False) | |
for key, value in args.items() | |
} | |
raise RuntimeError('unknown augmenter arg: ' + str(args)) | |
def to_tuple_if_list(self, obj): | |
if isinstance(obj, list): | |
return tuple(obj) | |
return obj | |
class ImgAug:
    """A wrapper to use imgaug https://github.com/aleju/imgaug.

    Args:
        args ([list[list|dict]]): The augmentation list. For details, please
            refer to imgaug document. Take args=[['Fliplr', 0.5],
            dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an
            example. The args horizontally flip images with probability 0.5,
            followed by random rotation with angles in range [-10, 10], and
            resize with an independent scale in range [0.5, 3.0] for each
            side of images.
    """

    def __init__(self, args=None):
        self.augmenter_args = args
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)

    def __call__(self, results):
        """Augment ``results['img']`` and the annotations that follow it.

        Args:
            results (dict): Pipeline dict. ``'img'`` is read; when an
                augmenter is configured, ``'img'``, ``'img_shape'``,
                ``'flip'`` and ``'flip_direction'`` are written and every
                key listed in ``'mask_fields'`` / ``'bbox_fields'`` is
                transformed with the same sampled augmentation.

        Returns:
            dict: The same ``results`` object, updated in place.
        """
        # img is bgr
        image = results['img']
        aug = None
        shape = image.shape

        # A falsy augmenter (e.g. ``None`` when no args were given) means
        # the transform is a no-op.
        if self.augmenter:
            # Freeze the random state so the image and its annotations get
            # the exact same geometric transform.
            aug = self.augmenter.to_deterministic()
            results['img'] = aug.augment_image(image)
            results['img_shape'] = results['img'].shape
            results['flip'] = 'unknown'  # it's unknown
            results['flip_direction'] = 'unknown'  # it's unknown
            target_shape = results['img_shape']

            self.may_augment_annotation(aug, shape, target_shape, results)

        return results

    def may_augment_annotation(self, aug, shape, target_shape, results):
        """Apply ``aug`` to all mask and bbox fields of ``results``.

        Args:
            aug: Deterministic augmenter, or ``None`` to do nothing.
            shape (tuple): Shape of the image before augmentation.
            target_shape (tuple): Shape of the image after augmentation.
            results (dict): Pipeline dict holding ``'mask_fields'`` and
                ``'bbox_fields'`` plus the fields they name.

        Returns:
            dict: The same ``results`` object, updated in place.
        """
        if aug is None:
            return results

        # augment polygon mask
        for key in results['mask_fields']:
            masks = self.may_augment_poly(aug, shape, results[key])
            # Rebuild even when there are no polygons: otherwise an empty
            # mask container keeps the pre-augmentation height/width while
            # the image itself has changed size.
            results[key] = PolygonMasks(masks, *target_shape[:2])

        # augment bbox
        for key in results['bbox_fields']:
            bboxes = self.may_augment_poly(
                aug, shape, results[key], mask_flag=False)
            results[key] = np.stack(bboxes) if len(bboxes) > 0 \
                else np.zeros(0)

        return results

    def may_augment_poly(self, aug, img_shape, polys, mask_flag=True):
        """Transform polygons (or boxes) as keypoints with ``aug``.

        Args:
            aug: Deterministic augmenter providing ``augment_keypoints``.
            img_shape (tuple): Shape of the image before augmentation.
            polys (iterable): Polygons to transform. With ``mask_flag`` set,
                each item is indexed with ``[0]`` to get its coordinate
                array; otherwise each item is the coordinate array itself.
                Coordinates are read as consecutive ``(x, y)`` pairs.
            mask_flag (bool): Whether ``polys`` uses the mask layout.

        Returns:
            list: One flattened coordinate array per input polygon, each
            wrapped in a single-element list when ``mask_flag`` is set.
        """
        key_points, poly_point_nums = [], []
        for poly in polys:
            if mask_flag:
                poly = poly[0]
            poly = poly.reshape(-1, 2)
            key_points.extend(imgaug.Keypoint(x, y) for x, y in poly)
            poly_point_nums.append(poly.shape[0])

        # All points go through the augmenter in one batch; the per-polygon
        # counts recorded above are used to split them apart again.
        key_points = aug.augment_keypoints([
            imgaug.KeypointsOnImage(keypoints=key_points, shape=img_shape)
        ])[0].keypoints

        new_polys = []
        start_idx = 0
        for poly_point_num in poly_point_nums:
            points = key_points[start_idx:start_idx + poly_point_num]
            start_idx += poly_point_num
            new_poly = np.array([[point.x, point.y]
                                 for point in points]).flatten()
            new_polys.append([new_poly] if mask_flag else new_poly)

        return new_polys

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(args={self.augmenter_args})'
        return repr_str
class EastRandomCrop:
    """Randomly crop an image around text polygons and pad to a fixed size.

    A crop window whose borders lie on text-free rows/columns is sampled,
    resized (keeping aspect ratio) to fit ``target_size`` and pasted into
    the top-left corner of a zero-filled canvas.

    Args:
        target_size (tuple[int, int]): Output size, read as (width, height).
        max_tries (int): Maximum number of sampling attempts before falling
            back to the whole image.
        min_crop_side_ratio (float): Minimum crop side length, as a fraction
            of the corresponding image side.
    """

    def __init__(self,
                 target_size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1):
        self.target_size = target_size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio

    def __call__(self, results):
        """Crop and pad the image, boxes and masks stored in ``results``.

        Args:
            results (dict): Pipeline dict. Reads ``'img'``, ``'gt_masks'``,
                ``'bbox_fields'`` and ``'mask_fields'``; rewrites ``'img'``,
                ``'img_shape'``, every listed bbox/mask field and, when
                ``'gt_masks'`` is a mask field, ``'gt_labels'``.

        Returns:
            dict: The same ``results`` object, updated in place.
        """
        # sampling crop
        # crop image, boxes, masks
        img = results['img']
        crop_x, crop_y, crop_w, crop_h = self.crop_area(
            img, results['gt_masks'])
        scale_w = self.target_size[0] / crop_w
        scale_h = self.target_size[1] / crop_h
        scale = min(scale_w, scale_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        padded_img = np.zeros(
            (self.target_size[1], self.target_size[0], img.shape[2]),
            img.dtype)
        padded_img[:h, :w] = mmcv.imresize(
            img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))

        # for bboxes
        for key in results['bbox_fields']:
            lines = []
            for box in results[key]:
                box = box.reshape(2, 2)
                poly = ((box - (crop_x, crop_y)) * scale)
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    lines.append(poly.flatten())
            results[key] = np.array(lines)

        # for masks
        for key in results['mask_fields']:
            polys = []
            polys_label = []
            for poly in results[key]:
                poly = np.array(poly).reshape(-1, 2)
                poly = ((poly - (crop_x, crop_y)) * scale)
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    polys.append([poly])
                    polys_label.append(0)
            # PolygonMasks takes (masks, height, width) while target_size is
            # (width, height); use the canvas shape so non-square targets
            # get the right dimensions.
            results[key] = PolygonMasks(polys, *padded_img.shape[:2])
            if key == 'gt_masks':
                results['gt_labels'] = polys_label

        results['img'] = padded_img
        results['img_shape'] = padded_img.shape

        return results

    def is_poly_in_rect(self, poly, x, y, w, h):
        """Return True if every point of ``poly`` lies inside the rect."""
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """Return True if the bounds of ``poly`` do not touch the rect."""
        poly = np.array(poly).reshape(-1, 2)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split sorted indices into runs of consecutive values.

        Args:
            axis (np.ndarray): 1-D array of increasing integer indices.

        Returns:
            list[np.ndarray]: One array per run, including the final run
            (which used to be dropped).
        """
        if axis.shape[0] == 0:
            return []
        # A new run starts wherever the step between neighbours is not 1.
        breaks = np.flatnonzero(np.diff(axis) != 1) + 1
        return np.split(axis, breaks)

    def random_select(self, axis, max_size):
        """Pick two random values from ``axis`` and return them ordered."""
        xx = np.random.choice(axis, size=2)
        xmin = np.min(xx)
        xmax = np.max(xx)
        xmin = np.clip(xmin, 0, max_size - 1)
        xmax = np.clip(xmax, 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions):
        """Pick one value from each of two randomly chosen regions.

        The two regions are drawn with replacement, so both values may come
        from the same region.

        Returns:
            tuple[int, int]: The smaller and the larger selected value.
        """
        selected_index = np.random.choice(len(regions), 2)
        # Draw a scalar directly: converting a size-1 array with ``int`` is
        # deprecated in recent NumPy releases.
        selected_values = [
            int(np.random.choice(regions[index])) for index in selected_index
        ]
        return min(selected_values), max(selected_values)

    def crop_area(self, img, polys):
        """Sample a crop window that contains at least one polygon.

        Args:
            img (np.ndarray): Image of shape (h, w, c).
            polys (iterable): Polygons; each item must be convertible to an
                array of consecutive ``(x, y)`` pairs.

        Returns:
            tuple: ``(x, y, w, h)`` of the crop. The whole image is returned
            when no free row/column exists or every attempt fails.
        """
        h, w, _ = img.shape
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in polys:
            points = np.round(
                points, decimals=0).astype(np.int32).reshape(-1, 2)
            # Clip to the image: a negative start would be read as an index
            # from the end and leave the covered rows/columns unmarked.
            min_x, max_x = np.clip(
                [points[:, 0].min(), points[:, 0].max()], 0, w)
            w_array[min_x:max_x] = 1
            min_y, max_y = np.clip(
                [points[:, 1].min(), points[:, 1].max()], 0, h)
            h_array[min_y:max_y] = 1
        # ensure the cropped area not across a text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]

        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h

        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for _ in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions)
            else:
                ymin, ymax = self.random_select(h_axis, h)

            if (xmax - xmin < self.min_crop_side_ratio * w
                    or ymax - ymin < self.min_crop_side_ratio * h):
                # area too small
                continue

            # Accept the window as soon as one polygon touches it.
            if any(not self.is_poly_outside_rect(poly, xmin, ymin,
                                                 xmax - xmin, ymax - ymin)
                   for poly in polys):
                return xmin, ymin, xmax - xmin, ymax - ymin

        return 0, 0, w, h