gyrojeff committed on
Commit
964201e
1 Parent(s): 49159ce

feat: add detector data pipeline

Browse files
Files changed (1) hide show
  1. detector/data.py +125 -0
detector/data.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from font_dataset.fontlabel import FontLabel
2
+ from font_dataset.font import DSFont, load_font_with_exclusion
3
+ from . import config
4
+
5
+
6
+ import math
7
+ import os
8
+ import pickle
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+ from typing import List, Dict, Tuple
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from pytorch_lightning import LightningDataModule
14
+ from PIL import Image
15
+
16
+
17
class FontDataset(Dataset):
    """Dataset of font images paired with pickled ``FontLabel`` files.

    Every ``*.jpg`` directly inside ``path`` is one sample. Its label is read
    from the file with the same name and a ``.bin`` extension, then encoded as
    a 12-element float tensor by :meth:`fontlabel2tensor`.
    """

    def __init__(self, path: str, config_path: str = "configs/font.yml"):
        """Index the images under ``path`` and load the font table.

        Args:
            path: Directory holding the ``.jpg`` images and ``.bin`` labels.
            config_path: Font configuration file handed to
                ``load_font_with_exclusion``.
        """
        self.path = path
        # Looked up by ``label.font.path`` in ``fontlabel2tensor``.
        self.fonts = load_font_with_exclusion(config_path)

        # Sorted so that a given index refers to the same file on every run.
        self.images = sorted(
            os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
        )

        # Built once here instead of on every ``__getitem__`` call; the
        # pipeline does not depend on the sample being loaded.
        self.transform = transforms.Compose(
            [
                transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    @staticmethod
    def _label_path_for(image_path: str) -> str:
        """Return the ``.bin`` label path that sits next to ``image_path``."""
        # Swap only the extension: ``str.replace`` would also rewrite a
        # ".jpg" occurring earlier in the path (e.g. in a directory name).
        return os.path.splitext(image_path)[0] + ".bin"

    def fontlabel2tensor(self, label: FontLabel, label_path) -> torch.Tensor:
        """Encode ``label`` as a 12-element float tensor.

        Layout of the returned tensor:

        * ``[0]``     value of ``self.fonts[label.font.path]``
        * ``[1]``     0 when ``text_direction == "ltr"``, otherwise 1
        * ``[2:5]``   text colour channels divided by 255
        * ``[5]``     ``text_size / image_width``
        * ``[6]``     ``stroke_width / image_width``
        * ``[7:10]``  stroke colour channels divided by 255, or 0.5 for each
          channel when the label has no stroke colour
        * ``[10]``    ``line_spacing / image_width``
        * ``[11]``    ``angle / 180 + 0.5``

        Args:
            label: Label to encode.
            label_path: Path the label was loaded from; used only in the
                error message.

        Raises:
            KeyError: If ``label.font.path`` is not in ``self.fonts``.
        """
        out = torch.zeros(12, dtype=torch.float)
        try:
            out[0] = self.fonts[label.font.path]
        except KeyError as err:
            # Keep the exception type, but carry the offending font and label
            # file in the message so the failure is traceable from a worker.
            raise KeyError(
                f"Unqualified font: {label.font.path} (label path: {label_path})"
            ) from err

        out[1] = 0 if label.text_direction == "ltr" else 1

        # Colour channels are scaled into [0, 1].
        out[2] = label.text_color[0] / 255.0
        out[3] = label.text_color[1] / 255.0
        out[4] = label.text_color[2] / 255.0

        # Lengths are expressed relative to the image width.
        out[5] = label.text_size / label.image_width
        out[6] = label.stroke_width / label.image_width

        if label.stroke_color:
            out[7] = label.stroke_color[0] / 255.0
            out[8] = label.stroke_color[1] / 255.0
            out[9] = label.stroke_color[2] / 255.0
        else:
            # No stroke colour: fill the three channels with the mid value.
            out[7:10] = 0.5

        out[10] = label.line_spacing / label.image_width
        out[11] = label.angle / 180.0 + 0.5

        return out

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` tensors for the sample at ``index``.

        The image is converted to RGB, resized to
        ``config.INPUT_SIZE x config.INPUT_SIZE`` and turned into a tensor.
        The label is unpickled from the matching ``.bin`` file and encoded
        with :meth:`fontlabel2tensor`.
        """
        image_path = self.images[index]
        image = self.transform(Image.open(image_path).convert("RGB"))

        label_path = self._label_path_for(image_path)
        # NOTE: unpickling can execute arbitrary code; only point this dataset
        # at label files that come from a trusted source.
        with open(label_path, "rb") as f:
            label: FontLabel = pickle.load(f)

        return image, self.fontlabel2tensor(label, label_path)
78
+
79
+
80
class FontDataModule(LightningDataModule):
    """Lightning data module wrapping train/val/test ``FontDataset`` splits.

    Any extra keyword arguments are stored and forwarded unchanged to every
    ``DataLoader`` this module creates (e.g. ``batch_size``, ``num_workers``).
    """

    def __init__(
        self,
        config_path: str = "configs/font.yml",
        train_path: str = "./dataset/font_img/train",
        val_path: str = "./dataset/font_img/train",
        test_path: str = "./dataset/font_img/train",
        train_shuffle: bool = True,
        val_shuffle: bool = False,
        test_shuffle: bool = False,
        **kwargs,
    ):
        """Create the three datasets and remember the loader settings.

        NOTE(review): ``val_path`` and ``test_path`` default to the training
        directory -- confirm this is intended and not a placeholder.

        Args:
            config_path: Font configuration file passed to each dataset.
            train_path: Image directory of the training split.
            val_path: Image directory of the validation split.
            test_path: Image directory of the test split.
            train_shuffle: Whether the training loader shuffles.
            val_shuffle: Whether the validation loader shuffles.
            test_shuffle: Whether the test loader shuffles.
            **kwargs: Forwarded to every ``DataLoader``.
        """
        super().__init__()
        self.dataloader_args = kwargs
        self.train_shuffle = train_shuffle
        self.val_shuffle = val_shuffle
        self.test_shuffle = test_shuffle
        self.train_dataset = FontDataset(train_path, config_path)
        self.val_dataset = FontDataset(val_path, config_path)
        self.test_dataset = FontDataset(test_path, config_path)

    def get_train_num_iter(self, num_device: int) -> int:
        """Return the number of training iterations per epoch.

        Computed as ``ceil(len(train_dataset) / (batch_size * num_device))``.

        Args:
            num_device: Number of devices the training set is divided across.
        """
        # ``DataLoader`` uses a batch size of 1 when none is given, so mirror
        # that default instead of raising ``KeyError``. ``batch_size=None``
        # yields samples one at a time, which is also one sample per step.
        batch_size = self.dataloader_args.get("batch_size", 1)
        if batch_size is None:
            batch_size = 1
        return math.ceil(len(self.train_dataset) / (batch_size * num_device))

    def train_dataloader(self):
        """Return the ``DataLoader`` over the training dataset."""
        return DataLoader(
            self.train_dataset,
            shuffle=self.train_shuffle,
            **self.dataloader_args,
        )

    def val_dataloader(self):
        """Return the ``DataLoader`` over the validation dataset."""
        return DataLoader(
            self.val_dataset,
            shuffle=self.val_shuffle,
            **self.dataloader_args,
        )

    def test_dataloader(self):
        """Return the ``DataLoader`` over the test dataset."""
        return DataLoader(
            self.test_dataset,
            shuffle=self.test_shuffle,
            **self.dataloader_args,
        )