Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import unittest.mock as mock | |
import numpy as np | |
import pytest | |
from mmocr.datasets.pipelines import (OneOfWrapper, RandomWrapper, | |
TorchVisionWrapper) | |
from mmocr.datasets.pipelines.transforms import ColorJitter | |
def test_torchvision_wrapper(): | |
x = {'img': np.ones((128, 100, 3), dtype=np.uint8)} | |
# object not found error | |
with pytest.raises(Exception): | |
TorchVisionWrapper(op='NonExist') | |
with pytest.raises(TypeError): | |
TorchVisionWrapper() | |
f = TorchVisionWrapper('Grayscale') | |
with pytest.raises(AssertionError): | |
f({}) | |
results = f(x) | |
assert results['img'].shape == (128, 100) | |
assert results['img_shape'] == (128, 100) | |
def test_oneof(rand_choice): | |
color_jitter = dict(type='TorchVisionWrapper', op='ColorJitter') | |
gray_scale = dict(type='TorchVisionWrapper', op='Grayscale') | |
x = {'img': np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)} | |
f = OneOfWrapper([color_jitter, gray_scale]) | |
# Use color_jitter at the first call | |
rand_choice.side_effect = lambda x: x[0] | |
results = f(x) | |
assert results['img'].shape == (128, 100, 3) | |
# Use gray_scale at the second call | |
rand_choice.side_effect = lambda x: x[1] | |
results = f(x) | |
assert results['img'].shape == (128, 100) | |
# Passing object | |
f = OneOfWrapper([ColorJitter(), gray_scale]) | |
# Use color_jitter at the first call | |
results = f(x) | |
assert results['img'].shape == (128, 100) | |
# Test invalid inputs | |
with pytest.raises(AssertionError): | |
f = OneOfWrapper(None) | |
with pytest.raises(AssertionError): | |
f = OneOfWrapper([]) | |
with pytest.raises(AssertionError): | |
f = OneOfWrapper({}) | |
def test_runwithprob(np_random_uniform): | |
np_random_uniform.side_effect = [0.1, 0.9] | |
f = RandomWrapper([dict(type='TorchVisionWrapper', op='Grayscale')], 0.5) | |
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8) | |
results = f({'img': copy.deepcopy(img)}) | |
assert results['img'].shape == (128, 100) | |
results = f({'img': copy.deepcopy(img)}) | |
assert results['img'].shape == (128, 100, 3) | |