Arnaudding001 committed on
Commit
c642d93
1 Parent(s): 5914e7c

Create stylegan_prepare_data.py

Browse files
Files changed (1) hide show
  1. stylegan_prepare_data.py +105 -0
stylegan_prepare_data.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from io import BytesIO
3
+ import multiprocessing
4
+ from functools import partial
5
+
6
+ import os
7
+ from PIL import Image
8
+ import lmdb
9
+ from tqdm import tqdm
10
+ from torchvision import datasets
11
+ from torchvision.transforms import functional as trans_fn
12
+
13
+
14
def resize_and_convert(img, size, resample, quality=100):
    """Resize and center-crop an image, then return it as JPEG-encoded bytes.

    Args:
        img: Image handed to ``trans_fn.resize`` (the worker in this file
            passes an RGB PIL image).
        size: Target size forwarded to both the resize and the center crop.
        resample: Resampling filter forwarded to ``trans_fn.resize``.
        quality: JPEG quality setting used when encoding.

    Returns:
        The JPEG-encoded image data as ``bytes``.
    """
    resized = trans_fn.resize(img, size, resample)
    cropped = trans_fn.center_crop(resized, size)

    # Encode into an in-memory stream so the caller receives raw bytes
    # instead of a file on disk.
    stream = BytesIO()
    cropped.save(stream, format="jpeg", quality=quality)

    return stream.getvalue()
22
+
23
+
24
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    """Encode one image at several resolutions.

    Args:
        img: Source image passed unchanged to ``resize_and_convert``.
        sizes: Iterable of target sizes; one encoded image is produced per
            entry, in the same order.
        resample: Resampling filter forwarded to ``resize_and_convert``.
        quality: JPEG quality forwarded to ``resize_and_convert``.

    Returns:
        A list with the result of ``resize_and_convert`` for each size.
    """
    return [
        resize_and_convert(img, edge, resample, quality) for edge in sizes
    ]
33
+
34
+
35
def resize_worker(img_file, sizes, resample):
    """Load one image file and encode it at every requested size.

    Intended to run inside a worker process (see ``prepare``), so it takes a
    single picklable argument plus keyword options.

    Args:
        img_file: ``(index, path)`` pair; the index is returned unchanged so
            the caller can match unordered results back to their input.
        sizes: Target sizes forwarded to ``resize_multiple``.
        resample: Resampling filter forwarded to ``resize_multiple``.

    Returns:
        ``(index, encoded_images)`` where ``encoded_images`` is the list
        returned by ``resize_multiple``.
    """
    index, path = img_file

    # Image.open is lazy and may keep the underlying file handle open; the
    # context manager closes it deterministically so long-lived workers do
    # not accumulate open descriptors. convert() returns a new image that
    # stays usable after the source is closed.
    with Image.open(path) as source:
        rgb = source.convert("RGB")

    return index, resize_multiple(rgb, sizes=sizes, resample=resample)
42
+
43
+
44
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    """Resize every image in ``dataset`` and store the results in ``env``.

    Each image is stored once per size under the key
    ``"{size}-{index}"`` (index zero-padded to at least five digits), and the
    number of processed images is stored under the key ``"length"``.

    Args:
        env: Open LMDB environment that receives the encoded images.
        dataset: Object exposing ``imgs``, a sequence of ``(path, label)``
            pairs; images are indexed in order of sorted path.
        n_worker: Number of worker processes used for resizing.
        sizes: Target sizes; one entry is stored per image per size.
        resample: Resampling filter forwarded to ``resize_worker``.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # Sort by path so the index assigned to each image is deterministic.
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        results = pool.imap_unordered(resize_fn, files)

        # imap_unordered has no length, so pass the total explicitly to let
        # tqdm render a real progress bar.
        for i, imgs in tqdm(results, total=len(files)):
            # One write transaction per image: all of its sizes are
            # committed together instead of paying a commit per size.
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                    txn.put(key, img)

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
65
+
66
+
67
if __name__ == "__main__":
    # Command-line entry point: build an LMDB dataset of resized images from
    # an image folder.

    # Supported --resample names mapped to the PIL filter constants.
    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}

    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    # --out is mandatory: without it there is no destination for the LMDB
    # environment, so let argparse report that instead of failing later.
    parser.add_argument(
        "--out",
        type=str,
        required=True,
        help="filename of the result lmdb dataset",
    )
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    # Restrict to the known filters so a typo yields a usage error rather
    # than a KeyError traceback.
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        choices=sorted(resample_map),
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.out, exist_ok=True)

    resample = resample_map[args.resample]

    # --size is a comma-separated list of integers, e.g. "128,256".
    sizes = [int(s.strip()) for s in args.size.split(",")]

    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)