fix: data augmentation
Browse files- detector/data.py +27 -18
detector/data.py
CHANGED
@@ -17,13 +17,19 @@ from PIL import Image
|
|
17 |
|
18 |
|
19 |
class RandomColorJitter(object):
|
20 |
-
def __init__(
|
|
|
|
|
21 |
self.brightness = brightness
|
22 |
self.contrast = contrast
|
23 |
self.saturation = saturation
|
24 |
self.hue = hue
|
|
|
25 |
|
26 |
def __call__(self, batch):
|
|
|
|
|
|
|
27 |
image, label = batch
|
28 |
text_color = label[2:5].clone().view(3, 1, 1)
|
29 |
stroke_color = label[7:10].clone().view(3, 1, 1)
|
@@ -54,10 +60,14 @@ class RandomColorJitter(object):
|
|
54 |
|
55 |
|
56 |
class RandomCrop(object):
|
57 |
-
def __init__(self, crop_factor: float = 0.1):
|
58 |
self.crop_factor = crop_factor
|
|
|
59 |
|
60 |
def __call__(self, batch):
|
|
|
|
|
|
|
61 |
image, label = batch
|
62 |
width, height = image.size
|
63 |
|
@@ -80,10 +90,14 @@ class RandomCrop(object):
|
|
80 |
|
81 |
|
82 |
class RandomRotate(object):
|
83 |
-
def __init__(self, max_angle: int = 15):
|
84 |
self.max_angle = max_angle
|
|
|
85 |
|
86 |
def __call__(self, batch):
|
|
|
|
|
|
|
87 |
image, label = batch
|
88 |
|
89 |
angle = random.uniform(-self.max_angle, self.max_angle)
|
@@ -177,8 +191,8 @@ class FontDataset(Dataset):
|
|
177 |
if self.transforms is not None:
|
178 |
transform = transforms.Compose(
|
179 |
[
|
180 |
-
|
181 |
-
|
182 |
]
|
183 |
)
|
184 |
image, label = transform((image, label))
|
@@ -210,20 +224,15 @@ class FontDataset(Dataset):
|
|
210 |
|
211 |
transform = transforms.Compose(
|
212 |
[
|
213 |
-
|
214 |
-
RandomCrop(crop_factor=0.54),
|
215 |
-
|
216 |
]
|
217 |
)
|
218 |
image, label = transform((image, label))
|
219 |
|
220 |
-
transform = transforms.
|
221 |
-
|
222 |
-
transforms.RandomApply(
|
223 |
-
transforms.GaussianBlur(random.randint(2, 5), sigma=(0.1, 5.0)),
|
224 |
-
p=0.8,
|
225 |
-
),
|
226 |
-
]
|
227 |
)
|
228 |
|
229 |
image = transform(image)
|
@@ -259,9 +268,9 @@ class FontDataModule(LightningDataModule):
|
|
259 |
train_shuffle: bool = True,
|
260 |
val_shuffle: bool = False,
|
261 |
test_shuffle: bool = False,
|
262 |
-
train_transforms: bool =
|
263 |
-
val_transforms: bool =
|
264 |
-
test_transforms: bool =
|
265 |
crop_roi_bbox: bool = False,
|
266 |
regression_use_tanh: bool = False,
|
267 |
**kwargs,
|
|
|
17 |
|
18 |
|
19 |
class RandomColorJitter(object):
|
20 |
+
def __init__(
|
21 |
+
self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05, preserve=0.2
|
22 |
+
):
|
23 |
self.brightness = brightness
|
24 |
self.contrast = contrast
|
25 |
self.saturation = saturation
|
26 |
self.hue = hue
|
27 |
+
self.preserve = preserve
|
28 |
|
29 |
def __call__(self, batch):
|
30 |
+
if random.random() < self.preserve:
|
31 |
+
return batch
|
32 |
+
|
33 |
image, label = batch
|
34 |
text_color = label[2:5].clone().view(3, 1, 1)
|
35 |
stroke_color = label[7:10].clone().view(3, 1, 1)
|
|
|
60 |
|
61 |
|
62 |
class RandomCrop(object):
|
63 |
+
def __init__(self, crop_factor: float = 0.1, preserve: float = 0.2):
|
64 |
self.crop_factor = crop_factor
|
65 |
+
self.preserve = preserve
|
66 |
|
67 |
def __call__(self, batch):
|
68 |
+
if random.random() < self.preserve:
|
69 |
+
return batch
|
70 |
+
|
71 |
image, label = batch
|
72 |
width, height = image.size
|
73 |
|
|
|
90 |
|
91 |
|
92 |
class RandomRotate(object):
|
93 |
+
def __init__(self, max_angle: int = 15, preserve: float = 0.2):
|
94 |
self.max_angle = max_angle
|
95 |
+
self.preserve = preserve
|
96 |
|
97 |
def __call__(self, batch):
|
98 |
+
if random.random() < self.preserve:
|
99 |
+
return batch
|
100 |
+
|
101 |
image, label = batch
|
102 |
|
103 |
angle = random.uniform(-self.max_angle, self.max_angle)
|
|
|
191 |
if self.transforms is not None:
|
192 |
transform = transforms.Compose(
|
193 |
[
|
194 |
+
RandomColorJitter(preserve=0.2),
|
195 |
+
RandomCrop(preserve=0.2),
|
196 |
]
|
197 |
)
|
198 |
image, label = transform((image, label))
|
|
|
224 |
|
225 |
transform = transforms.Compose(
|
226 |
[
|
227 |
+
RandomColorJitter(preserve=0.2),
|
228 |
+
RandomCrop(crop_factor=0.54, preserve=0),
|
229 |
+
RandomRotate(preserve=0.2),
|
230 |
]
|
231 |
)
|
232 |
image, label = transform((image, label))
|
233 |
|
234 |
+
transform = transforms.GaussianBlur(
|
235 |
+
random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
|
|
|
|
|
|
|
|
|
|
|
236 |
)
|
237 |
|
238 |
image = transform(image)
|
|
|
268 |
train_shuffle: bool = True,
|
269 |
val_shuffle: bool = False,
|
270 |
test_shuffle: bool = False,
|
271 |
+
train_transforms: bool = None,
|
272 |
+
val_transforms: bool = None,
|
273 |
+
test_transforms: bool = None,
|
274 |
crop_roi_bbox: bool = False,
|
275 |
regression_use_tanh: bool = False,
|
276 |
**kwargs,
|