gyrojeff committed
Commit a976004
1 Parent(s): 00a4b21

feat: add data augmentation

Files changed (2)
  1. detector/data.py +117 -14
  2. train.py +2 -0
detector/data.py CHANGED
@@ -5,20 +5,102 @@ from . import config
 
 import math
 import os
+import random
 import pickle
 import torch
 import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
 from typing import List, Dict, Tuple
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
 from PIL import Image
 
 
+class RandomColorJitter(object):
+    def __init__(
+        self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05, preserve=0.2
+    ):
+        self.brightness = brightness
+        self.contrast = contrast
+        self.saturation = saturation
+        self.hue = hue
+        self.preserve = preserve
+
+    def __call__(self, batch):
+        if random.random() < self.preserve:
+            return batch
+
+        image, label = batch
+        text_color = label[2:5].clone().view(3, 1, 1)
+        stroke_color = label[7:10].clone().view(3, 1, 1)
+
+        brightness = random.uniform(1 - self.brightness, 1 + self.brightness)
+        image = TF.adjust_brightness(image, brightness)
+        text_color = TF.adjust_brightness(text_color, brightness)
+        stroke_color = TF.adjust_brightness(stroke_color, brightness)
+
+        contrast = random.uniform(1 - self.contrast, 1 + self.contrast)
+        image = TF.adjust_contrast(image, contrast)
+        text_color = TF.adjust_contrast(text_color, contrast)
+        stroke_color = TF.adjust_contrast(stroke_color, contrast)
+
+        saturation = random.uniform(1 - self.saturation, 1 + self.saturation)
+        image = TF.adjust_saturation(image, saturation)
+        text_color = TF.adjust_saturation(text_color, saturation)
+        stroke_color = TF.adjust_saturation(stroke_color, saturation)
+
+        hue = random.uniform(-self.hue, self.hue)
+        image = TF.adjust_hue(image, hue)
+        text_color = TF.adjust_hue(text_color, hue)
+        stroke_color = TF.adjust_hue(stroke_color, hue)
+
+        label[2:5] = text_color.view(3)
+        label[7:10] = stroke_color.view(3)
+        return image, label
+
+
+class RandomCrop(object):
+    def __init__(self, crop_factor: float = 0.1, preserve: float = 0.2):
+        self.crop_factor = crop_factor
+        self.preserve = preserve
+
+    def __call__(self, batch):
+        if random.random() < self.preserve:
+            return batch
+
+        image, label = batch
+        width, height = image.size
+
+        # use random value to decide scaling factor on x and y axis
+        random_height = random.random() * self.crop_factor
+        random_width = random.random() * self.crop_factor
+        # use random value again to decide scaling factor for 4 borders
+        random_top = random.random() * random_height
+        random_left = random.random() * random_width
+        # calculate new width and height and position
+        top = int(random_top * height)
+        left = int(random_left * width)
+        height = int(height - random_height * height)
+        width = int(width - random_width * width)
+        # crop image
+        image = TF.crop(image, top, left, height, width)
+
+        label[[5, 6, 10]] = label[[5, 6, 10]] * (1 - random_height)
+        return image, label
+
+
 class FontDataset(Dataset):
-    def __init__(self, path: str, config_path: str = "configs/font.yml", regression_use_tanh: bool=False):
+    def __init__(
+        self,
+        path: str,
+        config_path: str = "configs/font.yml",
+        regression_use_tanh: bool = False,
+        transforms: bool = False,
+    ):
         self.path = path
         self.fonts = load_font_with_exclusion(config_path)
         self.regression_use_tanh = regression_use_tanh
+        self.transforms = transforms
 
         self.images = [
             os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
@@ -51,9 +133,6 @@ class FontDataset(Dataset):
         out[7:10] = out[2:5]
         out[10] = label.line_spacing / label.image_width
         out[11] = label.angle / 180.0 + 0.5
-
-        if self.regression_use_tanh:
-            out[2:12] = out[2:12] * 2 - 1
 
         return out
 
@@ -62,6 +141,25 @@ class FontDataset(Dataset):
         image_path = self.images[index]
         image = Image.open(image_path).convert("RGB")
 
+        # Load label
+        label_path = image_path.replace(".jpg", ".bin")
+        with open(label_path, "rb") as f:
+            label: FontLabel = pickle.load(f)
+
+        # encode label
+        label = self.fontlabel2tensor(label, label_path)
+
+        # data augmentation
+        if self.transforms:
+            transform = transforms.Compose(
+                [
+                    RandomColorJitter(),
+                    RandomCrop(),
+                ]
+            )
+            image, label = transform((image, label))
+
+        # resize and to tensor
         transform = transforms.Compose(
             [
                 transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
@@ -70,13 +168,9 @@ class FontDataset(Dataset):
         )
         image = transform(image)
 
-        # Load label
-        label_path = image_path.replace(".jpg", ".bin")
-        with open(label_path, "rb") as f:
-            label: FontLabel = pickle.load(f)
-
-        # encode label
-        label = self.fontlabel2tensor(label, label_path)
+        # normalize label
+        if self.regression_use_tanh:
+            label[2:12] = label[2:12] * 2 - 1
 
         return image, label
 
@@ -91,6 +185,9 @@ class FontDataModule(LightningDataModule):
         train_shuffle: bool = True,
         val_shuffle: bool = False,
         test_shuffle: bool = False,
+        train_transforms: bool = False,
+        val_transforms: bool = False,
+        test_transforms: bool = False,
         regression_use_tanh: bool = False,
         **kwargs,
     ):
@@ -99,9 +196,15 @@ class FontDataModule(LightningDataModule):
         self.train_shuffle = train_shuffle
         self.val_shuffle = val_shuffle
        self.test_shuffle = test_shuffle
-        self.train_dataset = FontDataset(train_path, config_path, regression_use_tanh)
-        self.val_dataset = FontDataset(val_path, config_path, regression_use_tanh)
-        self.test_dataset = FontDataset(test_path, config_path, regression_use_tanh)
+        self.train_dataset = FontDataset(
+            train_path, config_path, regression_use_tanh, train_transforms
+        )
+        self.val_dataset = FontDataset(
+            val_path, config_path, regression_use_tanh, val_transforms
+        )
+        self.test_dataset = FontDataset(
+            test_path, config_path, regression_use_tanh, test_transforms
+        )
 
     def get_train_num_iter(self, num_device: int) -> int:
         return math.ceil(
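A minimal usage sketch of the augmented dataset follows. The directory "data/train" is a placeholder, not a path from this repository, and the comments on the returned label rely on the encoding shown in fontlabel2tensor above.

from detector.data import FontDataset

# Placeholder directory containing paired <name>.jpg images and <name>.bin pickled FontLabel files.
dataset = FontDataset(
    "data/train",
    config_path="configs/font.yml",
    regression_use_tanh=True,
    transforms=True,  # enable RandomColorJitter + RandomCrop on this split
)

image, label = dataset[0]
# image: the augmented PIL image after Resize(config.INPUT_SIZE) and conversion to a tensor
# label: encoded target tensor; text color (label[2:5]) and stroke color (label[7:10]) were
#        jittered together with the image, the crop rescaled label[[5, 6, 10]], and with
#        regression_use_tanh the entries label[2:12] are mapped from [0, 1] to [-1, 1]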
train.py CHANGED
@@ -31,6 +31,7 @@ lambda_direction = 0.5
 lambda_regression = 1.0
 
 regression_use_tanh = True
+augmentation = True
 
 num_warmup_epochs = 1
 num_epochs = 100
@@ -47,6 +48,7 @@ data_module = FontDataModule(
     val_shuffle=False,
     test_shuffle=False,
     regression_use_tanh=regression_use_tanh,
+    train_transforms=augmentation,
 )
 
 num_iters = data_module.get_train_num_iter(num_device) * num_epochs
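Only the training split opts into augmentation here; val_transforms and test_transforms keep their False defaults, so validation and test stay deterministic. A sketch of an equivalent explicit construction is below; the paths are placeholders, the path parameter names are inferred from the constructor body rather than shown in this diff, and any extra loader arguments taken via **kwargs (batch sizes, workers) are omitted.

from detector.data import FontDataModule

data_module = FontDataModule(
    train_path="data/train",   # placeholder paths, not the values used in train.py
    val_path="data/val",
    test_path="data/test",
    config_path="configs/font.yml",
    train_shuffle=True,
    val_shuffle=False,
    test_shuffle=False,
    regression_use_tanh=True,
    train_transforms=True,     # augment only the training set
    val_transforms=False,      # defaults written out for clarity
    test_transforms=False,
)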