maduvantha committed on
Commit fc34ca6 · 1 Parent(s): 2917b9c

Upload 42 files

.dockerignore ADDED
@@ -0,0 +1,3 @@
1
+ /venv
2
+ .git
3
+ __pycache__
.gitattributes CHANGED
@@ -32,3 +32,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ sup-mat/absolute-demo.gif filter=lfs diff=lfs merge=lfs -text
36
+ sup-mat/face-swap.gif filter=lfs diff=lfs merge=lfs -text
37
+ sup-mat/fashion-teaser.gif filter=lfs diff=lfs merge=lfs -text
38
+ sup-mat/mgif-teaser.gif filter=lfs diff=lfs merge=lfs -text
39
+ sup-mat/relative-demo.gif filter=lfs diff=lfs merge=lfs -text
40
+ sup-mat/vox-teaser.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ /.vscode
2
+ __pycache__
3
+ /venv
Dockerfile ADDED
@@ -0,0 +1,14 @@
1
+ FROM nvcr.io/nvidia/cuda:10.0-cudnn7-runtime-ubuntu18.04
2
+
3
+ RUN DEBIAN_FRONTEND=noninteractive apt-get -qq update \
4
+ && DEBIAN_FRONTEND=noninteractive apt-get -qqy install python3-pip ffmpeg git less nano libsm6 libxext6 libxrender-dev \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ COPY . /app/
8
+ WORKDIR /app
9
+
10
+ RUN pip3 install --upgrade pip
11
+ RUN pip3 install \
12
+ https://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl \
13
+ git+https://github.com/1adrianb/face-alignment \
14
+ -r requirements.txt
animate.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+
7
+ from frames_dataset import PairedDataset
8
+ from logger import Logger, Visualizer
9
+ import imageio
10
+ from scipy.spatial import ConvexHull
11
+ import numpy as np
12
+
13
+ from sync_batchnorm import DataParallelWithCallback
14
+
15
+
16
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
17
+ use_relative_movement=False, use_relative_jacobian=False):
18
+ if adapt_movement_scale:
19
+ source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
20
+ driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
21
+ adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
22
+ else:
23
+ adapt_movement_scale = 1
24
+
25
+ kp_new = {k: v for k, v in kp_driving.items()}
26
+
27
+ if use_relative_movement:
28
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
29
+ kp_value_diff *= adapt_movement_scale
30
+ kp_new['value'] = kp_value_diff + kp_source['value']
31
+
32
+ if use_relative_jacobian:
33
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
34
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
35
+
36
+ return kp_new
37
+
38
+
39
+ def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
40
+ log_dir = os.path.join(log_dir, 'animation')
41
+ png_dir = os.path.join(log_dir, 'png')
42
+ animate_params = config['animate_params']
43
+
44
+ dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
45
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
46
+
47
+ if checkpoint is not None:
48
+ Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
49
+ else:
50
+ raise AttributeError("Checkpoint should be specified for mode='animate'.")
51
+
52
+ if not os.path.exists(log_dir):
53
+ os.makedirs(log_dir)
54
+
55
+ if not os.path.exists(png_dir):
56
+ os.makedirs(png_dir)
57
+
58
+ if torch.cuda.is_available():
59
+ generator = DataParallelWithCallback(generator)
60
+ kp_detector = DataParallelWithCallback(kp_detector)
61
+
62
+ generator.eval()
63
+ kp_detector.eval()
64
+
65
+ for it, x in tqdm(enumerate(dataloader)):
66
+ with torch.no_grad():
67
+ predictions = []
68
+ visualizations = []
69
+
70
+ driving_video = x['driving_video']
71
+ source_frame = x['source_video'][:, :, 0, :, :]
72
+
73
+ kp_source = kp_detector(source_frame)
74
+ kp_driving_initial = kp_detector(driving_video[:, :, 0])
75
+
76
+ for frame_idx in range(driving_video.shape[2]):
77
+ driving_frame = driving_video[:, :, frame_idx]
78
+ kp_driving = kp_detector(driving_frame)
79
+ kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
80
+ kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
81
+ out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
82
+
83
+ out['kp_driving'] = kp_driving
84
+ out['kp_source'] = kp_source
85
+ out['kp_norm'] = kp_norm
86
+
87
+ del out['sparse_deformed']
88
+
89
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
90
+
91
+ visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
92
+ driving=driving_frame, out=out)
93
+ visualization = visualization
94
+ visualizations.append(visualization)
95
+
96
+ predictions = np.concatenate(predictions, axis=1)
97
+ result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
98
+ imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
99
+
100
+ image_name = result_name + animate_params['format']
101
+ imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
augmentation.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ Code from https://github.com/hassony2/torch_videovision
3
+ """
4
+
5
+ import numbers
6
+
7
+ import random
8
+ import numpy as np
9
+ import PIL
10
+
11
+ from skimage.transform import resize, rotate
12
+ from skimage.util import pad
13
+ import torchvision
14
+
15
+ import warnings
16
+
17
+ from skimage import img_as_ubyte, img_as_float
18
+
19
+
20
+ def crop_clip(clip, min_h, min_w, h, w):
21
+ if isinstance(clip[0], np.ndarray):
22
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
23
+
24
+ elif isinstance(clip[0], PIL.Image.Image):
25
+ cropped = [
26
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
27
+ ]
28
+ else:
29
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
30
+ 'but got list of {0}'.format(type(clip[0])))
31
+ return cropped
32
+
33
+
34
+ def pad_clip(clip, h, w):
35
+ im_h, im_w = clip[0].shape[:2]
36
+ pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
37
+ pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
38
+
39
+ return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
40
+
41
+
42
+ def resize_clip(clip, size, interpolation='bilinear'):
43
+ if isinstance(clip[0], np.ndarray):
44
+ if isinstance(size, numbers.Number):
45
+ im_h, im_w, im_c = clip[0].shape
46
+ # Min spatial dim already matches minimal size
47
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
48
+ and im_h == size):
49
+ return clip
50
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
51
+ size = (new_w, new_h)
52
+ else:
53
+ size = size[1], size[0]
54
+
55
+ scaled = [
56
+ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
57
+ mode='constant', anti_aliasing=True) for img in clip
58
+ ]
59
+ elif isinstance(clip[0], PIL.Image.Image):
60
+ if isinstance(size, numbers.Number):
61
+ im_w, im_h = clip[0].size
62
+ # Min spatial dim already matches minimal size
63
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
64
+ and im_h == size):
65
+ return clip
66
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
67
+ size = (new_w, new_h)
68
+ else:
69
+ size = size[1], size[0]
70
+ if interpolation == 'bilinear':
71
+ pil_inter = PIL.Image.NEAREST
72
+ else:
73
+ pil_inter = PIL.Image.BILINEAR
74
+ scaled = [img.resize(size, pil_inter) for img in clip]
75
+ else:
76
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
77
+ 'but got list of {0}'.format(type(clip[0])))
78
+ return scaled
79
+
80
+
81
+ def get_resize_sizes(im_h, im_w, size):
82
+ if im_w < im_h:
83
+ ow = size
84
+ oh = int(size * im_h / im_w)
85
+ else:
86
+ oh = size
87
+ ow = int(size * im_w / im_h)
88
+ return oh, ow
89
+
90
+
91
+ class RandomFlip(object):
92
+ def __init__(self, time_flip=False, horizontal_flip=False):
93
+ self.time_flip = time_flip
94
+ self.horizontal_flip = horizontal_flip
95
+
96
+ def __call__(self, clip):
97
+ if random.random() < 0.5 and self.time_flip:
98
+ return clip[::-1]
99
+ if random.random() < 0.5 and self.horizontal_flip:
100
+ return [np.fliplr(img) for img in clip]
101
+
102
+ return clip
103
+
104
+
105
+ class RandomResize(object):
106
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
107
+ The larger the original image is, the more time it takes to
108
+ interpolate
109
+ Args:
110
+ interpolation (str): Can be one of 'nearest', 'bilinear'
111
+ defaults to nearest
112
+ size (tuple): (width, height)
113
+ """
114
+
115
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
116
+ self.ratio = ratio
117
+ self.interpolation = interpolation
118
+
119
+ def __call__(self, clip):
120
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
121
+
122
+ if isinstance(clip[0], np.ndarray):
123
+ im_h, im_w, im_c = clip[0].shape
124
+ elif isinstance(clip[0], PIL.Image.Image):
125
+ im_w, im_h = clip[0].size
126
+
127
+ new_w = int(im_w * scaling_factor)
128
+ new_h = int(im_h * scaling_factor)
129
+ new_size = (new_w, new_h)
130
+ resized = resize_clip(
131
+ clip, new_size, interpolation=self.interpolation)
132
+
133
+ return resized
134
+
135
+
136
+ class RandomCrop(object):
137
+ """Extract random crop at the same location for a list of videos
138
+ Args:
139
+ size (sequence or int): Desired output size for the
140
+ crop in format (h, w)
141
+ """
142
+
143
+ def __init__(self, size):
144
+ if isinstance(size, numbers.Number):
145
+ size = (size, size)
146
+
147
+ self.size = size
148
+
149
+ def __call__(self, clip):
150
+ """
151
+ Args:
152
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
153
+ in format (h, w, c) in numpy.ndarray
154
+ Returns:
155
+ PIL.Image or numpy.ndarray: Cropped list of videos
156
+ """
157
+ h, w = self.size
158
+ if isinstance(clip[0], np.ndarray):
159
+ im_h, im_w, im_c = clip[0].shape
160
+ elif isinstance(clip[0], PIL.Image.Image):
161
+ im_w, im_h = clip[0].size
162
+ else:
163
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
164
+ 'but got list of {0}'.format(type(clip[0])))
165
+
166
+ clip = pad_clip(clip, h, w)
167
+ im_h, im_w = clip.shape[1:3]
168
+ x1 = 0 if h == im_h else random.randint(0, im_w - w)
169
+ y1 = 0 if w == im_w else random.randint(0, im_h - h)
170
+ cropped = crop_clip(clip, y1, x1, h, w)
171
+
172
+ return cropped
173
+
174
+
175
+ class RandomRotation(object):
176
+ """Rotate entire clip randomly by a random angle within
177
+ given bounds
178
+ Args:
179
+ degrees (sequence or int): Range of degrees to select from
180
+ If degrees is a number instead of sequence like (min, max),
181
+ the range of degrees, will be (-degrees, +degrees).
182
+ """
183
+
184
+ def __init__(self, degrees):
185
+ if isinstance(degrees, numbers.Number):
186
+ if degrees < 0:
187
+ raise ValueError('If degrees is a single number,'
188
+ 'must be positive')
189
+ degrees = (-degrees, degrees)
190
+ else:
191
+ if len(degrees) != 2:
192
+ raise ValueError('If degrees is a sequence,'
193
+ 'it must be of len 2.')
194
+
195
+ self.degrees = degrees
196
+
197
+ def __call__(self, clip):
198
+ """
199
+ Args:
200
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
201
+ in format (h, w, c) in numpy.ndarray
202
+ Returns:
203
+ PIL.Image or numpy.ndarray: Cropped list of videos
204
+ """
205
+ angle = random.uniform(self.degrees[0], self.degrees[1])
206
+ if isinstance(clip[0], np.ndarray):
207
+ rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
208
+ elif isinstance(clip[0], PIL.Image.Image):
209
+ rotated = [img.rotate(angle) for img in clip]
210
+ else:
211
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
212
+ 'but got list of {0}'.format(type(clip[0])))
213
+
214
+ return rotated
215
+
216
+
217
+ class ColorJitter(object):
218
+ """Randomly change the brightness, contrast and saturation and hue of the clip
219
+ Args:
220
+ brightness (float): How much to jitter brightness. brightness_factor
221
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
222
+ contrast (float): How much to jitter contrast. contrast_factor
223
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
224
+ saturation (float): How much to jitter saturation. saturation_factor
225
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
226
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
227
+ [-hue, hue]. Should be >=0 and <= 0.5.
228
+ """
229
+
230
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
231
+ self.brightness = brightness
232
+ self.contrast = contrast
233
+ self.saturation = saturation
234
+ self.hue = hue
235
+
236
+ def get_params(self, brightness, contrast, saturation, hue):
237
+ if brightness > 0:
238
+ brightness_factor = random.uniform(
239
+ max(0, 1 - brightness), 1 + brightness)
240
+ else:
241
+ brightness_factor = None
242
+
243
+ if contrast > 0:
244
+ contrast_factor = random.uniform(
245
+ max(0, 1 - contrast), 1 + contrast)
246
+ else:
247
+ contrast_factor = None
248
+
249
+ if saturation > 0:
250
+ saturation_factor = random.uniform(
251
+ max(0, 1 - saturation), 1 + saturation)
252
+ else:
253
+ saturation_factor = None
254
+
255
+ if hue > 0:
256
+ hue_factor = random.uniform(-hue, hue)
257
+ else:
258
+ hue_factor = None
259
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
260
+
261
+ def __call__(self, clip):
262
+ """
263
+ Args:
264
+ clip (list): list of PIL.Image
265
+ Returns:
266
+ list PIL.Image : list of transformed PIL.Image
267
+ """
268
+ if isinstance(clip[0], np.ndarray):
269
+ brightness, contrast, saturation, hue = self.get_params(
270
+ self.brightness, self.contrast, self.saturation, self.hue)
271
+
272
+ # Create img transform function sequence
273
+ img_transforms = []
274
+ if brightness is not None:
275
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
276
+ if saturation is not None:
277
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
278
+ if hue is not None:
279
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
280
+ if contrast is not None:
281
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
282
+ random.shuffle(img_transforms)
283
+ img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
284
+ img_as_float]
285
+
286
+ with warnings.catch_warnings():
287
+ warnings.simplefilter("ignore")
288
+ jittered_clip = []
289
+ for img in clip:
290
+ jittered_img = img
291
+ for func in img_transforms:
292
+ jittered_img = func(jittered_img)
293
+ jittered_clip.append(jittered_img.astype('float32'))
294
+ elif isinstance(clip[0], PIL.Image.Image):
295
+ brightness, contrast, saturation, hue = self.get_params(
296
+ self.brightness, self.contrast, self.saturation, self.hue)
297
+
298
+ # Create img transform function sequence
299
+ img_transforms = []
300
+ if brightness is not None:
301
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
302
+ if saturation is not None:
303
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
304
+ if hue is not None:
305
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
306
+ if contrast is not None:
307
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
308
+ random.shuffle(img_transforms)
309
+
310
+ # Apply to all videos
311
+ jittered_clip = []
312
+ for img in clip:
313
+ for func in img_transforms:
314
+ jittered_img = func(img)
315
+ jittered_clip.append(jittered_img)
316
+
317
+ else:
318
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
319
+ 'but got list of {0}'.format(type(clip[0])))
320
+ return jittered_clip
321
+
322
+
323
+ class AllAugmentationTransform:
324
+ def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
325
+ self.transforms = []
326
+
327
+ if flip_param is not None:
328
+ self.transforms.append(RandomFlip(**flip_param))
329
+
330
+ if rotation_param is not None:
331
+ self.transforms.append(RandomRotation(**rotation_param))
332
+
333
+ if resize_param is not None:
334
+ self.transforms.append(RandomResize(**resize_param))
335
+
336
+ if crop_param is not None:
337
+ self.transforms.append(RandomCrop(**crop_param))
338
+
339
+ if jitter_param is not None:
340
+ self.transforms.append(ColorJitter(**jitter_param))
341
+
342
+ def __call__(self, clip):
343
+ for t in self.transforms:
344
+ clip = t(clip)
345
+ return clip
config/bair-256.yaml ADDED
@@ -0,0 +1,82 @@
1
+ dataset_params:
2
+ root_dir: data/bair
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: False
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ jitter_param:
10
+ brightness: 0.1
11
+ contrast: 0.1
12
+ saturation: 0.1
13
+ hue: 0.1
14
+
15
+
16
+ model_params:
17
+ common_params:
18
+ num_kp: 10
19
+ num_channels: 3
20
+ estimate_jacobian: True
21
+ kp_detector_params:
22
+ temperature: 0.1
23
+ block_expansion: 32
24
+ max_features: 1024
25
+ scale_factor: 0.25
26
+ num_blocks: 5
27
+ generator_params:
28
+ block_expansion: 64
29
+ max_features: 512
30
+ num_down_blocks: 2
31
+ num_bottleneck_blocks: 6
32
+ estimate_occlusion_map: True
33
+ dense_motion_params:
34
+ block_expansion: 64
35
+ max_features: 1024
36
+ num_blocks: 5
37
+ scale_factor: 0.25
38
+ discriminator_params:
39
+ scales: [1]
40
+ block_expansion: 32
41
+ max_features: 512
42
+ num_blocks: 4
43
+ sn: True
44
+
45
+ train_params:
46
+ num_epochs: 20
47
+ num_repeats: 1
48
+ epoch_milestones: [12, 18]
49
+ lr_generator: 2.0e-4
50
+ lr_discriminator: 2.0e-4
51
+ lr_kp_detector: 2.0e-4
52
+ batch_size: 36
53
+ scales: [1, 0.5, 0.25, 0.125]
54
+ checkpoint_freq: 10
55
+ transform_params:
56
+ sigma_affine: 0.05
57
+ sigma_tps: 0.005
58
+ points_tps: 5
59
+ loss_weights:
60
+ generator_gan: 1
61
+ discriminator_gan: 1
62
+ feature_matching: [10, 10, 10, 10]
63
+ perceptual: [10, 10, 10, 10, 10]
64
+ equivariance_value: 10
65
+ equivariance_jacobian: 10
66
+
67
+ reconstruction_params:
68
+ num_videos: 1000
69
+ format: '.mp4'
70
+
71
+ animate_params:
72
+ num_pairs: 50
73
+ format: '.mp4'
74
+ normalization_params:
75
+ adapt_movement_scale: False
76
+ use_relative_movement: True
77
+ use_relative_jacobian: True
78
+
79
+ visualizer_params:
80
+ kp_size: 5
81
+ draw_border: True
82
+ colormap: 'gist_rainbow'
config/fashion-256.yaml ADDED
@@ -0,0 +1,77 @@
1
+ dataset_params:
2
+ root_dir: data/fashion-png
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: False
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ jitter_param:
10
+ hue: 0.1
11
+
12
+ model_params:
13
+ common_params:
14
+ num_kp: 10
15
+ num_channels: 3
16
+ estimate_jacobian: True
17
+ kp_detector_params:
18
+ temperature: 0.1
19
+ block_expansion: 32
20
+ max_features: 1024
21
+ scale_factor: 0.25
22
+ num_blocks: 5
23
+ generator_params:
24
+ block_expansion: 64
25
+ max_features: 512
26
+ num_down_blocks: 2
27
+ num_bottleneck_blocks: 6
28
+ estimate_occlusion_map: True
29
+ dense_motion_params:
30
+ block_expansion: 64
31
+ max_features: 1024
32
+ num_blocks: 5
33
+ scale_factor: 0.25
34
+ discriminator_params:
35
+ scales: [1]
36
+ block_expansion: 32
37
+ max_features: 512
38
+ num_blocks: 4
39
+
40
+ train_params:
41
+ num_epochs: 100
42
+ num_repeats: 50
43
+ epoch_milestones: [60, 90]
44
+ lr_generator: 2.0e-4
45
+ lr_discriminator: 2.0e-4
46
+ lr_kp_detector: 2.0e-4
47
+ batch_size: 27
48
+ scales: [1, 0.5, 0.25, 0.125]
49
+ checkpoint_freq: 50
50
+ transform_params:
51
+ sigma_affine: 0.05
52
+ sigma_tps: 0.005
53
+ points_tps: 5
54
+ loss_weights:
55
+ generator_gan: 1
56
+ discriminator_gan: 1
57
+ feature_matching: [10, 10, 10, 10]
58
+ perceptual: [10, 10, 10, 10, 10]
59
+ equivariance_value: 10
60
+ equivariance_jacobian: 10
61
+
62
+ reconstruction_params:
63
+ num_videos: 1000
64
+ format: '.mp4'
65
+
66
+ animate_params:
67
+ num_pairs: 50
68
+ format: '.mp4'
69
+ normalization_params:
70
+ adapt_movement_scale: False
71
+ use_relative_movement: True
72
+ use_relative_jacobian: True
73
+
74
+ visualizer_params:
75
+ kp_size: 5
76
+ draw_border: True
77
+ colormap: 'gist_rainbow'
config/mgif-256.yaml ADDED
@@ -0,0 +1,84 @@
1
+ dataset_params:
2
+ root_dir: data/moving-gif
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: False
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+ crop_param:
10
+ size: [256, 256]
11
+ resize_param:
12
+ ratio: [0.9, 1.1]
13
+ jitter_param:
14
+ hue: 0.5
15
+
16
+ model_params:
17
+ common_params:
18
+ num_kp: 10
19
+ num_channels: 3
20
+ estimate_jacobian: True
21
+ kp_detector_params:
22
+ temperature: 0.1
23
+ block_expansion: 32
24
+ max_features: 1024
25
+ scale_factor: 0.25
26
+ num_blocks: 5
27
+ single_jacobian_map: True
28
+ generator_params:
29
+ block_expansion: 64
30
+ max_features: 512
31
+ num_down_blocks: 2
32
+ num_bottleneck_blocks: 6
33
+ estimate_occlusion_map: True
34
+ dense_motion_params:
35
+ block_expansion: 64
36
+ max_features: 1024
37
+ num_blocks: 5
38
+ scale_factor: 0.25
39
+ discriminator_params:
40
+ scales: [1]
41
+ block_expansion: 32
42
+ max_features: 512
43
+ num_blocks: 4
44
+ sn: True
45
+
46
+ train_params:
47
+ num_epochs: 100
48
+ num_repeats: 25
49
+ epoch_milestones: [60, 90]
50
+ lr_generator: 2.0e-4
51
+ lr_discriminator: 2.0e-4
52
+ lr_kp_detector: 2.0e-4
53
+
54
+ batch_size: 36
55
+ scales: [1, 0.5, 0.25, 0.125]
56
+ checkpoint_freq: 100
57
+ transform_params:
58
+ sigma_affine: 0.05
59
+ sigma_tps: 0.005
60
+ points_tps: 5
61
+ loss_weights:
62
+ generator_gan: 1
63
+ discriminator_gan: 1
64
+ feature_matching: [10, 10, 10, 10]
65
+ perceptual: [10, 10, 10, 10, 10]
66
+ equivariance_value: 10
67
+ equivariance_jacobian: 10
68
+
69
+ reconstruction_params:
70
+ num_videos: 1000
71
+ format: '.mp4'
72
+
73
+ animate_params:
74
+ num_pairs: 50
75
+ format: '.mp4'
76
+ normalization_params:
77
+ adapt_movement_scale: False
78
+ use_relative_movement: True
79
+ use_relative_jacobian: True
80
+
81
+ visualizer_params:
82
+ kp_size: 5
83
+ draw_border: True
84
+ colormap: 'gist_rainbow'
config/nemo-256.yaml ADDED
@@ -0,0 +1,76 @@
1
+ dataset_params:
2
+ root_dir: data/nemo-png
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: False
5
+ augmentation_params:
6
+ flip_param:
7
+ horizontal_flip: True
8
+ time_flip: True
9
+
10
+ model_params:
11
+ common_params:
12
+ num_kp: 10
13
+ num_channels: 3
14
+ estimate_jacobian: True
15
+ kp_detector_params:
16
+ temperature: 0.1
17
+ block_expansion: 32
18
+ max_features: 1024
19
+ scale_factor: 0.25
20
+ num_blocks: 5
21
+ generator_params:
22
+ block_expansion: 64
23
+ max_features: 512
24
+ num_down_blocks: 2
25
+ num_bottleneck_blocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 64
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ scale_factor: 0.25
32
+ discriminator_params:
33
+ scales: [1]
34
+ block_expansion: 32
35
+ max_features: 512
36
+ num_blocks: 4
37
+ sn: True
38
+
39
+ train_params:
40
+ num_epochs: 100
41
+ num_repeats: 8
42
+ epoch_milestones: [60, 90]
43
+ lr_generator: 2.0e-4
44
+ lr_discriminator: 2.0e-4
45
+ lr_kp_detector: 2.0e-4
46
+ batch_size: 36
47
+ scales: [1, 0.5, 0.25, 0.125]
48
+ checkpoint_freq: 50
49
+ transform_params:
50
+ sigma_affine: 0.05
51
+ sigma_tps: 0.005
52
+ points_tps: 5
53
+ loss_weights:
54
+ generator_gan: 1
55
+ discriminator_gan: 1
56
+ feature_matching: [10, 10, 10, 10]
57
+ perceptual: [10, 10, 10, 10, 10]
58
+ equivariance_value: 10
59
+ equivariance_jacobian: 10
60
+
61
+ reconstruction_params:
62
+ num_videos: 1000
63
+ format: '.mp4'
64
+
65
+ animate_params:
66
+ num_pairs: 50
67
+ format: '.mp4'
68
+ normalization_params:
69
+ adapt_movement_scale: False
70
+ use_relative_movement: True
71
+ use_relative_jacobian: True
72
+
73
+ visualizer_params:
74
+ kp_size: 5
75
+ draw_border: True
76
+ colormap: 'gist_rainbow'
config/taichi-256.yaml ADDED
@@ -0,0 +1,157 @@
1
+ # Dataset parameters
2
+ # Each dataset should contain 2 folders train and test
3
+ # Each video can be represented as:
4
+ # - an image of concatenated frames
5
+ # - '.mp4' or '.gif'
6
+ # - folder with all frames from a specific video
7
+ # In the case of TaiChi, the same (YouTube) video can be split into many parts (chunks). Each part has the following
8
+ # format: (id)#other#info.mp4. For example, '12335#adsbf.mp4' has the id 12335. In the case of TaiChi, the id is the YouTube
9
+ # video id.
10
+ dataset_params:
11
+ # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
12
+ root_dir: data/taichi-png
13
+ # Image shape, needed for stacked .png format.
14
+ frame_shape: [256, 256, 3]
15
+ # In the case of TaiChi, a single video can be split into many chunks, or there may be several videos for a single person.
16
+ # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
17
+ # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
18
+ id_sampling: True
19
+ # List with pairs for animation, None for random pairs
20
+ pairs_list: data/taichi256.csv
21
+ # Augmentation parameters; see augmentation.py for all possible augmentations
22
+ augmentation_params:
23
+ flip_param:
24
+ horizontal_flip: True
25
+ time_flip: True
26
+ jitter_param:
27
+ brightness: 0.1
28
+ contrast: 0.1
29
+ saturation: 0.1
30
+ hue: 0.1
31
+
32
+ # Defines model architecture
33
+ model_params:
34
+ common_params:
35
+ # Number of keypoints
36
+ num_kp: 10
37
+ # Number of channels per image
38
+ num_channels: 3
39
+ # Using first or zero order model
40
+ estimate_jacobian: True
41
+ kp_detector_params:
42
+ # Softmax temperature for keypoint heatmaps
43
+ temperature: 0.1
44
+ # Number of features multiplier
45
+ block_expansion: 32
46
+ # Maximum allowed number of features
47
+ max_features: 1024
48
+ # Number of blocks in the U-Net. Can be increased or decreased depending on resolution.
49
+ num_blocks: 5
50
+ # Keypoints are predicted on smaller images for better performance,
51
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
52
+ scale_factor: 0.25
53
+ generator_params:
54
+ # Number of features multiplier
55
+ block_expansion: 64
56
+ # Maximum allowed number of features
57
+ max_features: 512
58
+ # Number of downsampling blocks in the Johnson architecture.
59
+ # Can be increased or decreased depending on resolution.
60
+ num_down_blocks: 2
61
+ # Number of ResBlocks in the Johnson architecture.
62
+ num_bottleneck_blocks: 6
63
+ # Use occlusion map or not
64
+ estimate_occlusion_map: True
65
+
66
+ dense_motion_params:
67
+ # Number of features multiplier
68
+ block_expansion: 64
69
+ # Maximum allowed number of features
70
+ max_features: 1024
71
+ # Number of blocks in the U-Net. Can be increased or decreased depending on resolution.
72
+ num_blocks: 5
73
+ # Dense motion is predicted on smaller images for better performance,
74
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
75
+ scale_factor: 0.25
76
+ discriminator_params:
77
+ # The discriminator can be multiscale; if you want 2 discriminators, one at the original
78
+ # resolution and one at half of the original, specify scales: [1, 0.5]
79
+ scales: [1]
80
+ # Number of features multiplier
81
+ block_expansion: 32
82
+ # Maximum allowed number of features
83
+ max_features: 512
84
+ # Number of blocks. Can be increased or decreased depending on resolution.
85
+ num_blocks: 4
86
+
87
+ # Parameters of training
88
+ train_params:
89
+ # Number of training epochs
90
+ num_epochs: 100
91
+ # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
92
+ # Thus, effectively, with num_repeats=100 each epoch is 100 times larger.
93
+ num_repeats: 150
94
+ # Drop the learning rate by a factor of 10 after these epochs
95
+ epoch_milestones: [60, 90]
96
+ # Initial learning rate for all modules
97
+ lr_generator: 2.0e-4
98
+ lr_discriminator: 2.0e-4
99
+ lr_kp_detector: 2.0e-4
100
+ batch_size: 30
101
+ # Scales for the perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and the image resolution is 256x256,
102
+ # then the loss will be computed on resolutions 256x256, 128x128, 64x64, 32x32.
103
+ scales: [1, 0.5, 0.25, 0.125]
104
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
105
+ checkpoint_freq: 50
106
+ # Parameters of transform for equivariance loss
107
+ transform_params:
108
+ # Sigma for affine part
109
+ sigma_affine: 0.05
110
+ # Sigma for deformation part
111
+ sigma_tps: 0.005
112
+ # Number of point in the deformation grid
113
+ points_tps: 5
114
+ loss_weights:
115
+ # Weight for LSGAN loss in generator, 0 for no adversarial loss.
116
+ generator_gan: 0
117
+ # Weight for LSGAN loss in discriminator
118
+ discriminator_gan: 1
119
+ # Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
120
+ feature_matching: [10, 10, 10, 10]
121
+ # Weights for perceptual loss.
122
+ perceptual: [10, 10, 10, 10, 10]
123
+ # Weights for value equivariance.
124
+ equivariance_value: 10
125
+ # Weights for jacobian equivariance.
126
+ equivariance_jacobian: 10
127
+
128
+ # Parameters of reconstruction
129
+ reconstruction_params:
130
+ # Maximum number of videos for reconstruction
131
+ num_videos: 1000
132
+ # Format for visualization; note that results will also be stored as stacked .png.
133
+ format: '.mp4'
134
+
135
+ # Parameters of animation
136
+ animate_params:
137
+ # Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
138
+ num_pairs: 50
139
+ # Format for visualization; note that results will also be stored as stacked .png.
140
+ format: '.mp4'
141
+ # Normalization of driving keypoints
142
+ normalization_params:
143
+ # Increase or decrease relative movement scale depending on the size of the object
144
+ adapt_movement_scale: False
145
+ # Apply only relative displacement of the keypoint
146
+ use_relative_movement: True
147
+ # Apply only relative change in jacobian
148
+ use_relative_jacobian: True
149
+
150
+ # Visualization parameters
151
+ visualizer_params:
152
+ # Draw keypoints of this size, increase or decrease depending on resolution
153
+ kp_size: 5
154
+ # Draw white border around images
155
+ draw_border: True
156
+ # Color map for keypoints
157
+ colormap: 'gist_rainbow'
config/taichi-adv-256.yaml ADDED
@@ -0,0 +1,150 @@
1
+ # Dataset parameters
2
+ dataset_params:
3
+ # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
4
+ root_dir: data/taichi-png
5
+ # Image shape, needed for stacked .png format.
6
+ frame_shape: [256, 256, 3]
7
+ # In the case of TaiChi, a single video can be split into many chunks, or there may be several videos for a single person.
8
+ # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
9
+ # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
10
+ id_sampling: True
11
+ # List with pairs for animation, None for random pairs
12
+ pairs_list: data/taichi256.csv
13
+ # Augmentation parameters; see augmentation.py for all possible augmentations
14
+ augmentation_params:
15
+ flip_param:
16
+ horizontal_flip: True
17
+ time_flip: True
18
+ jitter_param:
19
+ brightness: 0.1
20
+ contrast: 0.1
21
+ saturation: 0.1
22
+ hue: 0.1
23
+
24
+ # Defines model architecture
25
+ model_params:
26
+ common_params:
27
+ # Number of keypoints
28
+ num_kp: 10
29
+ # Number of channels per image
30
+ num_channels: 3
31
+ # Using first or zero order model
32
+ estimate_jacobian: True
33
+ kp_detector_params:
34
+ # Softmax temperature for keypoint heatmaps
35
+ temperature: 0.1
36
+ # Number of features multiplier
37
+ block_expansion: 32
38
+ # Maximum allowed number of features
39
+ max_features: 1024
40
+ # Number of blocks in the U-Net. Can be increased or decreased depending on resolution.
41
+ num_blocks: 5
42
+ # Keypoints are predicted on smaller images for better performance,
43
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
44
+ scale_factor: 0.25
45
+ generator_params:
46
+ # Number of features multiplier
47
+ block_expansion: 64
48
+ # Maximum allowed number of features
49
+ max_features: 512
50
+ # Number of downsampling blocks in the Johnson architecture.
51
+ # Can be increased or decreased depending on resolution.
52
+ num_down_blocks: 2
53
+ # Number of ResBlocks in the Johnson architecture.
54
+ num_bottleneck_blocks: 6
55
+ # Use occlusion map or not
56
+ estimate_occlusion_map: True
57
+
58
+ dense_motion_params:
59
+ # Number of features multiplier
60
+ block_expansion: 64
61
+ # Maximum allowed number of features
62
+ max_features: 1024
63
+ # Number of blocks in the U-Net. Can be increased or decreased depending on resolution.
64
+ num_blocks: 5
65
+ # Dense motion is predicted on smaller images for better performance,
66
+ # scale_factor=0.25 means that 256x256 image will be resized to 64x64
67
+ scale_factor: 0.25
68
+ discriminator_params:
69
+ # The discriminator can be multiscale; if you want 2 discriminators, one at the original
70
+ # resolution and one at half of the original, specify scales: [1, 0.5]
71
+ scales: [1]
72
+ # Number of features multiplier
73
+ block_expansion: 32
74
+ # Maximum allowed number of features
75
+ max_features: 512
76
+ # Number of blocks. Can be increased or decreased depending on resolution.
77
+ num_blocks: 4
78
+ use_kp: True
79
+
80
+ # Parameters of training
81
+ train_params:
82
+ # Number of training epochs
83
+ num_epochs: 150
84
+ # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
85
+ # Thus, effectively, with num_repeats=100 each epoch is 100 times larger.
86
+ num_repeats: 150
87
+ # Drop the learning rate by a factor of 10 after these epochs
88
+ epoch_milestones: []
89
+ # Initial learning rate for all modules
90
+ lr_generator: 2.0e-4
91
+ lr_discriminator: 2.0e-4
92
+ lr_kp_detector: 0
93
+ batch_size: 27
94
+ # Scales for the perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and the image resolution is 256x256,
95
+ # then the loss will be computed on resolutions 256x256, 128x128, 64x64, 32x32.
96
+ scales: [1, 0.5, 0.25, 0.125]
97
+ # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
98
+ checkpoint_freq: 50
99
+ # Parameters of transform for equivariance loss
100
+ transform_params:
101
+ # Sigma for affine part
102
+ sigma_affine: 0.05
103
+ # Sigma for deformation part
104
+ sigma_tps: 0.005
105
+ # Number of point in the deformation grid
106
+ points_tps: 5
107
+ loss_weights:
108
+ # Weight for LSGAN loss in generator
109
+ generator_gan: 1
110
+ # Weight for LSGAN loss in discriminator
111
+ discriminator_gan: 1
112
+ # Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
113
+ feature_matching: [10, 10, 10, 10]
114
+ # Weights for perceptual loss.
115
+ perceptual: [10, 10, 10, 10, 10]
116
+ # Weights for value equivariance.
117
+ equivariance_value: 10
118
+ # Weights for jacobian equivariance.
119
+ equivariance_jacobian: 10
120
+
121
+ # Parameters of reconstruction
122
+ reconstruction_params:
123
+ # Maximum number of videos for reconstruction
124
+ num_videos: 1000
125
+ # Format for visualization; note that results will also be stored as stacked .png.
126
+ format: '.mp4'
127
+
128
+ # Parameters of animation
129
+ animate_params:
130
+ # Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
131
+ num_pairs: 50
132
+ # Format for visualization; note that results will also be stored as stacked .png.
133
+ format: '.mp4'
134
+ # Normalization of driving keypoints
135
+ normalization_params:
136
+ # Increase or decrease relative movement scale depending on the size of the object
137
+ adapt_movement_scale: False
138
+ # Apply only relative displacement of the keypoint
139
+ use_relative_movement: True
140
+ # Apply only relative change in jacobian
141
+ use_relative_jacobian: True
142
+
143
+ # Visualization parameters
144
+ visualizer_params:
145
+ # Draw keypoints of this size, increase or decrease depending on resolution
146
+ kp_size: 5
147
+ # Draw white border around images
148
+ draw_border: True
149
+ # Color map for keypoints
150
+ colormap: 'gist_rainbow'
config/vox-256.yaml ADDED
@@ -0,0 +1,83 @@
1
+ dataset_params:
2
+ root_dir: data/vox-png
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: True
5
+ pairs_list: data/vox256.csv
6
+ augmentation_params:
7
+ flip_param:
8
+ horizontal_flip: True
9
+ time_flip: True
10
+ jitter_param:
11
+ brightness: 0.1
12
+ contrast: 0.1
13
+ saturation: 0.1
14
+ hue: 0.1
15
+
16
+
17
+ model_params:
18
+ common_params:
19
+ num_kp: 10
20
+ num_channels: 3
21
+ estimate_jacobian: True
22
+ kp_detector_params:
23
+ temperature: 0.1
24
+ block_expansion: 32
25
+ max_features: 1024
26
+ scale_factor: 0.25
27
+ num_blocks: 5
28
+ generator_params:
29
+ block_expansion: 64
30
+ max_features: 512
31
+ num_down_blocks: 2
32
+ num_bottleneck_blocks: 6
33
+ estimate_occlusion_map: True
34
+ dense_motion_params:
35
+ block_expansion: 64
36
+ max_features: 1024
37
+ num_blocks: 5
38
+ scale_factor: 0.25
39
+ discriminator_params:
40
+ scales: [1]
41
+ block_expansion: 32
42
+ max_features: 512
43
+ num_blocks: 4
44
+ sn: True
45
+
46
+ train_params:
47
+ num_epochs: 100
48
+ num_repeats: 75
49
+ epoch_milestones: [60, 90]
50
+ lr_generator: 2.0e-4
51
+ lr_discriminator: 2.0e-4
52
+ lr_kp_detector: 2.0e-4
53
+ batch_size: 40
54
+ scales: [1, 0.5, 0.25, 0.125]
55
+ checkpoint_freq: 50
56
+ transform_params:
57
+ sigma_affine: 0.05
58
+ sigma_tps: 0.005
59
+ points_tps: 5
60
+ loss_weights:
61
+ generator_gan: 0
62
+ discriminator_gan: 1
63
+ feature_matching: [10, 10, 10, 10]
64
+ perceptual: [10, 10, 10, 10, 10]
65
+ equivariance_value: 10
66
+ equivariance_jacobian: 10
67
+
68
+ reconstruction_params:
69
+ num_videos: 1000
70
+ format: '.mp4'
71
+
72
+ animate_params:
73
+ num_pairs: 50
74
+ format: '.mp4'
75
+ normalization_params:
76
+ adapt_movement_scale: False
77
+ use_relative_movement: True
78
+ use_relative_jacobian: True
79
+
80
+ visualizer_params:
81
+ kp_size: 5
82
+ draw_border: True
83
+ colormap: 'gist_rainbow'
config/vox-adv-256.yaml ADDED
@@ -0,0 +1,84 @@
1
+ dataset_params:
2
+ root_dir: data/vox-png
3
+ frame_shape: [256, 256, 3]
4
+ id_sampling: True
5
+ pairs_list: data/vox256.csv
6
+ augmentation_params:
7
+ flip_param:
8
+ horizontal_flip: True
9
+ time_flip: True
10
+ jitter_param:
11
+ brightness: 0.1
12
+ contrast: 0.1
13
+ saturation: 0.1
14
+ hue: 0.1
15
+
16
+
17
+ model_params:
18
+ common_params:
19
+ num_kp: 10
20
+ num_channels: 3
21
+ estimate_jacobian: True
22
+ kp_detector_params:
23
+ temperature: 0.1
24
+ block_expansion: 32
25
+ max_features: 1024
26
+ scale_factor: 0.25
27
+ num_blocks: 5
28
+ generator_params:
29
+ block_expansion: 64
30
+ max_features: 512
31
+ num_down_blocks: 2
32
+ num_bottleneck_blocks: 6
33
+ estimate_occlusion_map: True
34
+ dense_motion_params:
35
+ block_expansion: 64
36
+ max_features: 1024
37
+ num_blocks: 5
38
+ scale_factor: 0.25
39
+ discriminator_params:
40
+ scales: [1]
41
+ block_expansion: 32
42
+ max_features: 512
43
+ num_blocks: 4
44
+ use_kp: True
45
+
46
+
47
+ train_params:
48
+ num_epochs: 150
49
+ num_repeats: 75
50
+ epoch_milestones: []
51
+ lr_generator: 2.0e-4
52
+ lr_discriminator: 2.0e-4
53
+ lr_kp_detector: 2.0e-4
54
+ batch_size: 36
55
+ scales: [1, 0.5, 0.25, 0.125]
56
+ checkpoint_freq: 50
57
+ transform_params:
58
+ sigma_affine: 0.05
59
+ sigma_tps: 0.005
60
+ points_tps: 5
61
+ loss_weights:
62
+ generator_gan: 1
63
+ discriminator_gan: 1
64
+ feature_matching: [10, 10, 10, 10]
65
+ perceptual: [10, 10, 10, 10, 10]
66
+ equivariance_value: 10
67
+ equivariance_jacobian: 10
68
+
69
+ reconstruction_params:
70
+ num_videos: 1000
71
+ format: '.mp4'
72
+
73
+ animate_params:
74
+ num_pairs: 50
75
+ format: '.mp4'
76
+ normalization_params:
77
+ adapt_movement_scale: False
78
+ use_relative_movement: True
79
+ use_relative_jacobian: True
80
+
81
+ visualizer_params:
82
+ kp_size: 5
83
+ draw_border: True
84
+ colormap: 'gist_rainbow'
crop-video.py ADDED
@@ -0,0 +1,158 @@
1
+ import face_alignment
2
+ import skimage.io
3
+ import numpy
4
+ from argparse import ArgumentParser
5
+ from skimage import img_as_ubyte
6
+ from skimage.transform import resize
7
+ from tqdm import tqdm
8
+ import os
9
+ import imageio
10
+ import numpy as np
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ def extract_bbox(frame, fa):
15
+ if max(frame.shape[0], frame.shape[1]) > 640:
16
+ scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
17
+ frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
18
+ frame = img_as_ubyte(frame)
19
+ else:
20
+ scale_factor = 1
21
+ frame = frame[..., :3]
22
+ bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
23
+ if len(bboxes) == 0:
24
+ return []
25
+ return np.array(bboxes)[:, :-1] * scale_factor
26
+
27
+
28
+
29
+ def bb_intersection_over_union(boxA, boxB):
30
+ xA = max(boxA[0], boxB[0])
31
+ yA = max(boxA[1], boxB[1])
32
+ xB = min(boxA[2], boxB[2])
33
+ yB = min(boxA[3], boxB[3])
34
+ interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
35
+ boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
36
+ boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
37
+ iou = interArea / float(boxAArea + boxBArea - interArea)
38
+ return iou
39
+
40
+
41
+ def join(tube_bbox, bbox):
42
+ xA = min(tube_bbox[0], bbox[0])
43
+ yA = min(tube_bbox[1], bbox[1])
44
+ xB = max(tube_bbox[2], bbox[2])
45
+ yB = max(tube_bbox[3], bbox[3])
46
+ return (xA, yA, xB, yB)
47
+
48
+
49
+ def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
50
+ left, top, right, bot = tube_bbox
51
+ width = right - left
52
+ height = bot - top
53
+
54
+ #Computing aspect preserving bbox
55
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
56
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
57
+
58
+ left = int(left - width_increase * width)
59
+ top = int(top - height_increase * height)
60
+ right = int(right + width_increase * width)
61
+ bot = int(bot + height_increase * height)
62
+
63
+ top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
64
+ h, w = bot - top, right - left
65
+
66
+ start = start / fps
67
+ end = end / fps
68
+ time = end - start
69
+
70
+ scale = f'{image_shape[0]}:{image_shape[1]}'
71
+
72
+ return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
73
+
74
+
75
+ def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
76
+ commands = []
77
+ for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
78
+ if (end - start) > args.min_frames:
79
+ command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
80
+ commands.append(command)
81
+ return commands
82
+
83
+
84
+ def process_video(args):
85
+ device = 'cpu' if args.cpu else 'cuda'
86
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
87
+ video = imageio.get_reader(args.inp)
88
+
89
+ trajectories = []
90
+ previous_frame = None
91
+ fps = video.get_meta_data()['fps']
92
+ commands = []
93
+ try:
94
+ for i, frame in tqdm(enumerate(video)):
95
+ frame_shape = frame.shape
96
+ bboxes = extract_bbox(frame, fa)
97
+ ## For each trajectory check the criterion
98
+ not_valid_trajectories = []
99
+ valid_trajectories = []
100
+
101
+ for trajectory in trajectories:
102
+ tube_bbox = trajectory[0]
103
+ intersection = 0
104
+ for bbox in bboxes:
105
+ intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
106
+ if intersection > args.iou_with_initial:
107
+ valid_trajectories.append(trajectory)
108
+ else:
109
+ not_valid_trajectories.append(trajectory)
110
+
111
+ commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
112
+ trajectories = valid_trajectories
113
+
114
+ ## Assign bbox to trajectories, create new trajectories
115
+ for bbox in bboxes:
116
+ intersection = 0
117
+ current_trajectory = None
118
+ for trajectory in trajectories:
119
+ tube_bbox = trajectory[0]
120
+ current_intersection = bb_intersection_over_union(tube_bbox, bbox)
121
+ if intersection < current_intersection and current_intersection > args.iou_with_initial:
122
+ intersection = bb_intersection_over_union(tube_bbox, bbox)
123
+ current_trajectory = trajectory
124
+
125
+ ## Create new trajectory
126
+ if current_trajectory is None:
127
+ trajectories.append([bbox, bbox, i, i])
128
+ else:
129
+ current_trajectory[3] = i
130
+ current_trajectory[1] = join(current_trajectory[1], bbox)
131
+
132
+
133
+ except IndexError as e:
134
+ raise (e)
135
+
136
+ commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
137
+ return commands
138
+
139
+
140
+ if __name__ == "__main__":
141
+ parser = ArgumentParser()
142
+
143
+ parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
144
+ help="Image shape")
145
+ parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
146
+ parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox")
147
+ parser.add_argument("--inp", required=True, help='Input image or video')
148
+ parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
149
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
150
+
151
+
152
+ args = parser.parse_args()
153
+
154
+ commands = process_video(args)
155
+ for command in commands:
156
+ print (command)
157
+
158
+
data/bair256.csv ADDED
@@ -0,0 +1,51 @@
1
+ distance,source,driving,frame
2
+ 0,000054.mp4,000048.mp4,0
3
+ 0,000050.mp4,000063.mp4,0
4
+ 0,000073.mp4,000007.mp4,0
5
+ 0,000021.mp4,000010.mp4,0
6
+ 0,000084.mp4,000046.mp4,0
7
+ 0,000031.mp4,000102.mp4,0
8
+ 0,000029.mp4,000111.mp4,0
9
+ 0,000090.mp4,000112.mp4,0
10
+ 0,000039.mp4,000010.mp4,0
11
+ 0,000008.mp4,000069.mp4,0
12
+ 0,000068.mp4,000076.mp4,0
13
+ 0,000051.mp4,000052.mp4,0
14
+ 0,000022.mp4,000098.mp4,0
15
+ 0,000096.mp4,000032.mp4,0
16
+ 0,000032.mp4,000099.mp4,0
17
+ 0,000006.mp4,000053.mp4,0
18
+ 0,000098.mp4,000020.mp4,0
19
+ 0,000029.mp4,000066.mp4,0
20
+ 0,000022.mp4,000007.mp4,0
21
+ 0,000027.mp4,000065.mp4,0
22
+ 0,000026.mp4,000059.mp4,0
23
+ 0,000015.mp4,000112.mp4,0
24
+ 0,000086.mp4,000123.mp4,0
25
+ 0,000103.mp4,000052.mp4,0
26
+ 0,000123.mp4,000103.mp4,0
27
+ 0,000051.mp4,000005.mp4,0
28
+ 0,000062.mp4,000125.mp4,0
29
+ 0,000126.mp4,000111.mp4,0
30
+ 0,000066.mp4,000090.mp4,0
31
+ 0,000075.mp4,000106.mp4,0
32
+ 0,000020.mp4,000010.mp4,0
33
+ 0,000076.mp4,000028.mp4,0
34
+ 0,000062.mp4,000002.mp4,0
35
+ 0,000095.mp4,000127.mp4,0
36
+ 0,000113.mp4,000072.mp4,0
37
+ 0,000027.mp4,000104.mp4,0
38
+ 0,000054.mp4,000124.mp4,0
39
+ 0,000019.mp4,000089.mp4,0
40
+ 0,000052.mp4,000072.mp4,0
41
+ 0,000108.mp4,000033.mp4,0
42
+ 0,000044.mp4,000118.mp4,0
43
+ 0,000029.mp4,000086.mp4,0
44
+ 0,000068.mp4,000066.mp4,0
45
+ 0,000014.mp4,000036.mp4,0
46
+ 0,000053.mp4,000071.mp4,0
47
+ 0,000022.mp4,000094.mp4,0
48
+ 0,000000.mp4,000121.mp4,0
49
+ 0,000071.mp4,000079.mp4,0
50
+ 0,000127.mp4,000005.mp4,0
51
+ 0,000085.mp4,000023.mp4,0
data/taichi-loading/README.md ADDED
@@ -0,0 +1,18 @@
1
+ # TaiChi dataset
2
+
3
+ The scripts for loading the TaiChi dataset.
4
+
5
+ We provide only the id of the corresponding video and the bounding box. The following script will download the videos from YouTube and crop them according to the provided bounding boxes.
6
+
7
+ 1) Download youtube-dl:
8
+ ```
9
+ wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl
10
+ chmod a+rx youtube-dl
11
+ ```
12
+
13
+ 2) Run the script to download the videos. There are 2 formats that can be used for storing videos: one is .mp4 and the other is a folder with .png images. While .png images occupy significantly more space, the format is lossless and has better i/o performance when training.
14
+
15
+ ```
16
+ python load_videos.py --metadata taichi-metadata.csv --format .mp4 --out_folder taichi --workers 8
17
+ ```
18
+ Select the number of workers based on the number of CPUs available. Note that the .png format takes approximately 80GB.
data/taichi-loading/load_videos.py ADDED
@@ -0,0 +1,113 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import imageio
4
+ import os
5
+ import subprocess
6
+ from multiprocessing import Pool
7
+ from itertools import cycle
8
+ import warnings
9
+ import glob
10
+ import time
11
+ from tqdm import tqdm
12
+ from argparse import ArgumentParser
13
+ from skimage import img_as_ubyte
14
+ from skimage.transform import resize
15
+ warnings.filterwarnings("ignore")
16
+
17
+ DEVNULL = open(os.devnull, 'wb')
18
+
19
+
20
+ def save(path, frames, format):
21
+ if format == '.mp4':
22
+ imageio.mimsave(path, frames)
23
+ elif format == '.png':
24
+ if os.path.exists(path):
25
+ print ("Warning: skiping video %s" % os.path.basename(path))
26
+ return
27
+ else:
28
+ os.makedirs(path)
29
+ for j, frame in enumerate(frames):
30
+ imageio.imsave(os.path.join(path, str(j).zfill(7) + '.png'), frames[j])
31
+ else:
32
+ print ("Unknown format %s" % format)
33
+ exit()
34
+
35
+
36
+ def download(video_id, args):
37
+ video_path = os.path.join(args.video_folder, video_id + ".mp4")
38
+ subprocess.call([args.youtube, '-f', "''best/mp4''", '--write-auto-sub', '--write-sub',
39
+ '--sub-lang', 'en', '--skip-unavailable-fragments',
40
+ "https://www.youtube.com/watch?v=" + video_id, "--output",
41
+ video_path], stdout=DEVNULL, stderr=DEVNULL)
42
+ return video_path
43
+
44
+
45
+ def run(data):
46
+ video_id, args = data
47
+ if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
48
+ download(video_id.split('#')[0], args)
49
+
50
+ if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
51
+ print ('Can not load video %s, broken link' % video_id.split('#')[0])
52
+ return
53
+ reader = imageio.get_reader(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4'))
54
+ fps = reader.get_meta_data()['fps']
55
+
56
+ df = pd.read_csv(args.metadata)
57
+ df = df[df['video_id'] == video_id]
58
+
59
+ all_chunks_dict = [{'start': df['start'].iloc[j], 'end': df['end'].iloc[j],
60
+ 'bbox': list(map(int, df['bbox'].iloc[j].split('-'))), 'frames':[]} for j in range(df.shape[0])]
61
+ ref_fps = df['fps'].iloc[0]
62
+ ref_height = df['height'].iloc[0]
63
+ ref_width = df['width'].iloc[0]
64
+ partition = df['partition'].iloc[0]
65
+ try:
66
+ for i, frame in enumerate(reader):
67
+ for entry in all_chunks_dict:
68
+ if (i * ref_fps >= entry['start'] * fps) and (i * ref_fps < entry['end'] * fps):
69
+ left, top, right, bot = entry['bbox']
70
+ left = int(left / (ref_width / frame.shape[1]))
71
+ top = int(top / (ref_height / frame.shape[0]))
72
+ right = int(right / (ref_width / frame.shape[1]))
73
+ bot = int(bot / (ref_height / frame.shape[0]))
74
+ crop = frame[top:bot, left:right]
75
+ if args.image_shape is not None:
76
+ crop = img_as_ubyte(resize(crop, args.image_shape, anti_aliasing=True))
77
+ entry['frames'].append(crop)
78
+ except imageio.core.format.CannotReadFrameError:
79
+ None
80
+
81
+ for entry in all_chunks_dict:
82
+ first_part = '#'.join(video_id.split('#')[::-1])
83
+ path = first_part + '#' + str(entry['start']).zfill(6) + '#' + str(entry['end']).zfill(6) + '.mp4'
84
+ save(os.path.join(args.out_folder, partition, path), entry['frames'], args.format)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ parser = ArgumentParser()
89
+ parser.add_argument("--video_folder", default='youtube-taichi', help='Path to youtube videos')
90
+ parser.add_argument("--metadata", default='taichi-metadata-new.csv', help='Path to metadata')
91
+ parser.add_argument("--out_folder", default='taichi-png', help='Path to output')
92
+ parser.add_argument("--format", default='.png', help='Storing format')
93
+ parser.add_argument("--workers", default=1, type=int, help='Number of workers')
94
+ parser.add_argument("--youtube", default='./youtube-dl', help='Path to youtube-dl')
95
+
96
+ parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
97
+ help="Image shape, None for no resize")
98
+
99
+ args = parser.parse_args()
100
+ if not os.path.exists(args.video_folder):
101
+ os.makedirs(args.video_folder)
102
+ if not os.path.exists(args.out_folder):
103
+ os.makedirs(args.out_folder)
104
+ for partition in ['test', 'train']:
105
+ if not os.path.exists(os.path.join(args.out_folder, partition)):
106
+ os.makedirs(os.path.join(args.out_folder, partition))
107
+
108
+ df = pd.read_csv(args.metadata)
109
+ video_ids = set(df['video_id'])
110
+ pool = Pool(processes=args.workers)
111
+ args_list = cycle([args])
112
+ for chunks_data in tqdm(pool.imap_unordered(run, zip(video_ids, args_list))):
113
+ pass
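
The crop loop in run() above rescales each metadata bounding box from the reference resolution recorded in the CSV to the resolution of the frames that were actually downloaded. A minimal sketch of that arithmetic with made-up numbers (the resolutions and bbox below are illustrative, not values taken from the metadata):

# Hypothetical example: bbox annotated on a 1280x720 reference video,
# while the downloaded frame is 640x360.
ref_width, ref_height = 1280, 720
frame_width, frame_height = 640, 360
left, top, right, bot = 320, 180, 960, 540       # bbox in reference coordinates

left = int(left / (ref_width / frame_width))      # 160
right = int(right / (ref_width / frame_width))    # 480
top = int(top / (ref_height / frame_height))      # 90
bot = int(bot / (ref_height / frame_height))      # 270
# crop = frame[top:bot, left:right], optionally resized to args.image_shape
print(left, top, right, bot)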
data/taichi-loading/taichi-metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/taichi256.csv ADDED
@@ -0,0 +1,51 @@
1
+ distance,source,driving,frame
2
+ 3.54437869822485,ab28GAufK8o#000261#000596.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3
+ 2.8639053254437887,DMEaUoA8EPE#000028#000354.mp4,0Q914by5A98#010440#010764.mp4,0
4
+ 2.153846153846153,L82WHgYRq6I#000021#000479.mp4,0Q914by5A98#010440#010764.mp4,0
5
+ 2.8994082840236666,oNkBx4CZuEg#000000#001024.mp4,DMEaUoA8EPE#000028#000354.mp4,0
6
+ 3.3905325443786998,ab28GAufK8o#000261#000596.mp4,uEqWZ9S_-Lw#000089#000581.mp4,0
7
+ 3.266272189349112,0Q914by5A98#010440#010764.mp4,ab28GAufK8o#000261#000596.mp4,0
8
+ 2.7514792899408294,WlDYrq8K6nk#008186#008512.mp4,OiblkvkAHWM#014331#014459.mp4,0
9
+ 3.0177514792899407,oNkBx4CZuEg#001024#002048.mp4,aDyyTMUBoLE#000375#000518.mp4,0
10
+ 3.4792899408284064,aDyyTMUBoLE#000164#000351.mp4,w2awOCDRtrc#001729#002009.mp4,0
11
+ 2.769230769230769,oNkBx4CZuEg#000000#001024.mp4,L82WHgYRq6I#000021#000479.mp4,0
12
+ 3.8047337278106514,ab28GAufK8o#000261#000596.mp4,w2awOCDRtrc#001729#002009.mp4,0
13
+ 3.4260355029585763,w2awOCDRtrc#001729#002009.mp4,oNkBx4CZuEg#000000#001024.mp4,0
14
+ 3.313609467455621,DMEaUoA8EPE#000028#000354.mp4,WlDYrq8K6nk#005943#006135.mp4,0
15
+ 3.8402366863905333,oNkBx4CZuEg#001024#002048.mp4,ab28GAufK8o#000261#000596.mp4,0
16
+ 3.3254437869822504,aDyyTMUBoLE#000164#000351.mp4,oNkBx4CZuEg#000000#001024.mp4,0
17
+ 1.2485207100591724,0Q914by5A98#010440#010764.mp4,aDyyTMUBoLE#000164#000351.mp4,0
18
+ 3.804733727810652,OiblkvkAHWM#006251#006533.mp4,aDyyTMUBoLE#000375#000518.mp4,0
19
+ 3.662721893491124,uEqWZ9S_-Lw#000089#000581.mp4,DMEaUoA8EPE#000028#000354.mp4,0
20
+ 3.230769230769233,A3ZmT97hAWU#000095#000678.mp4,ab28GAufK8o#000261#000596.mp4,0
21
+ 3.3668639053254434,w81Tr0Dp1K8#015329#015485.mp4,WlDYrq8K6nk#008186#008512.mp4,0
22
+ 3.313609467455621,WlDYrq8K6nk#005943#006135.mp4,DMEaUoA8EPE#000028#000354.mp4,0
23
+ 2.7514792899408294,OiblkvkAHWM#014331#014459.mp4,WlDYrq8K6nk#008186#008512.mp4,0
24
+ 1.964497041420118,L82WHgYRq6I#000021#000479.mp4,DMEaUoA8EPE#000028#000354.mp4,0
25
+ 3.78698224852071,FBuF0xOal9M#046824#047542.mp4,lCb5w6n8kPs#011879#012014.mp4,0
26
+ 3.92307692307692,ab28GAufK8o#000261#000596.mp4,L82WHgYRq6I#000021#000479.mp4,0
27
+ 3.8402366863905333,ab28GAufK8o#000261#000596.mp4,oNkBx4CZuEg#001024#002048.mp4,0
28
+ 3.828402366863905,ab28GAufK8o#000261#000596.mp4,OiblkvkAHWM#006251#006533.mp4,0
29
+ 2.041420118343196,L82WHgYRq6I#000021#000479.mp4,aDyyTMUBoLE#000164#000351.mp4,0
30
+ 3.2485207100591724,0Q914by5A98#010440#010764.mp4,w2awOCDRtrc#001729#002009.mp4,0
31
+ 3.2485207100591746,oNkBx4CZuEg#000000#001024.mp4,0Q914by5A98#010440#010764.mp4,0
32
+ 1.964497041420118,DMEaUoA8EPE#000028#000354.mp4,L82WHgYRq6I#000021#000479.mp4,0
33
+ 3.5266272189349115,kgvcI9oe3NI#001578#001763.mp4,lCb5w6n8kPs#004451#004631.mp4,0
34
+ 3.005917159763317,A3ZmT97hAWU#000095#000678.mp4,0Q914by5A98#010440#010764.mp4,0
35
+ 3.230769230769233,ab28GAufK8o#000261#000596.mp4,A3ZmT97hAWU#000095#000678.mp4,0
36
+ 3.5266272189349115,lCb5w6n8kPs#004451#004631.mp4,kgvcI9oe3NI#001578#001763.mp4,0
37
+ 2.769230769230769,L82WHgYRq6I#000021#000479.mp4,oNkBx4CZuEg#000000#001024.mp4,0
38
+ 3.165680473372782,WlDYrq8K6nk#005943#006135.mp4,w81Tr0Dp1K8#001375#001516.mp4,0
39
+ 2.8994082840236666,DMEaUoA8EPE#000028#000354.mp4,oNkBx4CZuEg#000000#001024.mp4,0
40
+ 2.4556213017751523,0Q914by5A98#010440#010764.mp4,mndSqTrxpts#000000#000175.mp4,0
41
+ 2.201183431952659,A3ZmT97hAWU#000095#000678.mp4,VMSqvTE90hk#007168#007312.mp4,0
42
+ 3.8047337278106514,w2awOCDRtrc#001729#002009.mp4,ab28GAufK8o#000261#000596.mp4,0
43
+ 3.769230769230769,uEqWZ9S_-Lw#000089#000581.mp4,0Q914by5A98#010440#010764.mp4,0
44
+ 3.6568047337278102,A3ZmT97hAWU#000095#000678.mp4,aDyyTMUBoLE#000164#000351.mp4,0
45
+ 3.7869822485207107,uEqWZ9S_-Lw#000089#000581.mp4,L82WHgYRq6I#000021#000479.mp4,0
46
+ 3.78698224852071,lCb5w6n8kPs#011879#012014.mp4,FBuF0xOal9M#046824#047542.mp4,0
47
+ 3.591715976331361,nAQEOC1Z10M#020177#020600.mp4,w81Tr0Dp1K8#004036#004218.mp4,0
48
+ 3.8757396449704156,uEqWZ9S_-Lw#000089#000581.mp4,aDyyTMUBoLE#000164#000351.mp4,0
49
+ 2.45562130177515,aDyyTMUBoLE#000164#000351.mp4,DMEaUoA8EPE#000028#000354.mp4,0
50
+ 3.5502958579881647,uEqWZ9S_-Lw#000089#000581.mp4,OiblkvkAHWM#006251#006533.mp4,0
51
+ 3.7928994082840224,aDyyTMUBoLE#000375#000518.mp4,ab28GAufK8o#000261#000596.mp4,0
demo.py ADDED
@@ -0,0 +1,165 @@
1
+ import sys
2
+ import yaml
3
+ from argparse import ArgumentParser
4
+ from tqdm.auto import tqdm
5
+
6
+ import imageio
7
+ import numpy as np
8
+ from skimage.transform import resize
9
+ from skimage import img_as_ubyte
10
+ import torch
11
+ from sync_batchnorm import DataParallelWithCallback
12
+
13
+ from modules.generator import OcclusionAwareGenerator
14
+ from modules.keypoint_detector import KPDetector
15
+ from animate import normalize_kp
16
+
17
+ import ffmpeg
18
+ from os.path import splitext
19
+ from shutil import copyfileobj
20
+ from tempfile import NamedTemporaryFile
21
+
22
+ if sys.version_info[0] < 3:
23
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
24
+
25
+ def load_checkpoints(config_path, checkpoint_path, cpu=False):
26
+
27
+ with open(config_path) as f:
28
+ config = yaml.full_load(f)
29
+
30
+ generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
31
+ **config['model_params']['common_params'])
32
+ if not cpu:
33
+ generator.cuda()
34
+
35
+ kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
36
+ **config['model_params']['common_params'])
37
+ if not cpu:
38
+ kp_detector.cuda()
39
+
40
+ if cpu:
41
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
42
+ else:
43
+ checkpoint = torch.load(checkpoint_path)
44
+
45
+ generator.load_state_dict(checkpoint['generator'])
46
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
47
+
48
+ if not cpu:
49
+ generator = DataParallelWithCallback(generator)
50
+ kp_detector = DataParallelWithCallback(kp_detector)
51
+
52
+ generator.eval()
53
+ kp_detector.eval()
54
+
55
+ return generator, kp_detector
56
+
57
+
58
+ def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
59
+ with torch.no_grad():
60
+ predictions = []
61
+ source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
62
+ if not cpu:
63
+ source = source.cuda()
64
+ driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
65
+ kp_source = kp_detector(source)
66
+ kp_driving_initial = kp_detector(driving[:, :, 0])
67
+
68
+ for frame_idx in tqdm(range(driving.shape[2])):
69
+ driving_frame = driving[:, :, frame_idx]
70
+ if not cpu:
71
+ driving_frame = driving_frame.cuda()
72
+ kp_driving = kp_detector(driving_frame)
73
+ kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
74
+ kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
75
+ use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
76
+ out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
77
+
78
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
79
+ return predictions
80
+
81
+ def find_best_frame(source, driving, cpu=False):
82
+ import face_alignment # type: ignore (local file)
83
+ from scipy.spatial import ConvexHull
84
+
85
+ def normalize_kp(kp):
86
+ kp = kp - kp.mean(axis=0, keepdims=True)
87
+ area = ConvexHull(kp[:, :2]).volume
88
+ area = np.sqrt(area)
89
+ kp[:, :2] = kp[:, :2] / area
90
+ return kp
91
+
92
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
93
+ device='cpu' if cpu else 'cuda')
94
+ kp_source = fa.get_landmarks(255 * source)[0]
95
+ kp_source = normalize_kp(kp_source)
96
+ norm = float('inf')
97
+ frame_num = 0
98
+ for i, image in tqdm(enumerate(driving)):
99
+ kp_driving = fa.get_landmarks(255 * image)[0]
100
+ kp_driving = normalize_kp(kp_driving)
101
+ new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
102
+ if new_norm < norm:
103
+ norm = new_norm
104
+ frame_num = i
105
+ return frame_num
106
+
107
+ if __name__ == "__main__":
108
+ parser = ArgumentParser()
109
+ parser.add_argument("--config", required=True, help="path to config")
110
+ parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
111
+
112
+ parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
113
+ parser.add_argument("--driving_video", default='driving.mp4', help="path to driving video")
114
+ parser.add_argument("--result_video", default='result.mp4', help="path to output")
115
+
116
+ parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
117
+ parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
118
+
119
+ parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
120
+ help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
121
+
122
+ parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, help="Set frame to start from.")
123
+
124
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
125
+
126
+ parser.add_argument("--audio", dest="audio", action="store_true", help="copy audio to output from the driving video" )
127
+
128
+ parser.set_defaults(relative=False)
129
+ parser.set_defaults(adapt_scale=False)
130
+ parser.set_defaults(audio=False)
131
+
132
+ opt = parser.parse_args()
133
+
134
+ source_image = imageio.imread(opt.source_image)
135
+ reader = imageio.get_reader(opt.driving_video)
136
+ fps = reader.get_meta_data()['fps']
137
+ driving_video = []
138
+ try:
139
+ for im in reader:
140
+ driving_video.append(im)
141
+ except RuntimeError:
142
+ pass
143
+ reader.close()
144
+
145
+ source_image = resize(source_image, (256, 256))[..., :3]
146
+ driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
147
+ generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
148
+
149
+ if opt.find_best_frame or opt.best_frame is not None:
150
+ i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
151
+ print ("Best frame: " + str(i))
152
+ driving_forward = driving_video[i:]
153
+ driving_backward = driving_video[:(i+1)][::-1]
154
+ predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
155
+ predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
156
+ predictions = predictions_backward[::-1] + predictions_forward[1:]
157
+ else:
158
+ predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
159
+ imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
160
+
161
+ if opt.audio:
162
+ with NamedTemporaryFile(suffix=splitext(opt.result_video)[1]) as output:
163
+ ffmpeg.output(ffmpeg.input(opt.result_video).video, ffmpeg.input(opt.driving_video).audio, output.name, c='copy').run()
164
+ with open(opt.result_video, 'wb') as result:
165
+ copyfileobj(output, result)
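
For reference, the same pipeline can be driven from Python instead of the command line. The sketch below is a hedged example and not part of this commit: the config, checkpoint, image and video paths are placeholders the user must supply, and it assumes the repository root is on PYTHONPATH and a CUDA device is available (pass cpu=True otherwise).

import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
from demo import load_checkpoints, make_animation

generator, kp_detector = load_checkpoints(config_path='config/vox-256.yaml',   # placeholder paths
                                          checkpoint_path='vox-cpk.pth.tar',
                                          cpu=False)
source_image = resize(imageio.imread('source.png'), (256, 256))[..., :3]
reader = imageio.get_reader('driving.mp4')
fps = reader.get_meta_data()['fps']
driving_video = []
try:
    for frame in reader:                                   # same guarded read as in the script
        driving_video.append(resize(frame, (256, 256))[..., :3])
except RuntimeError:
    pass
reader.close()
predictions = make_animation(source_image, driving_video, generator, kp_detector,
                             relative=True, adapt_movement_scale=True, cpu=False)
imageio.mimsave('result.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)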
frames_dataset.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ from skimage import io, img_as_float32
3
+ from skimage.color import gray2rgb
4
+ from sklearn.model_selection import train_test_split
5
+ from imageio import mimread
6
+
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+ import pandas as pd
10
+ from augmentation import AllAugmentationTransform
11
+ import glob
12
+
13
+
14
+ def read_video(name, frame_shape):
15
+ """
16
+ Read video which can be:
17
+ - an image of concatenated frames
18
+ - '.mp4', '.gif' or '.mov'
19
+ - folder with videos
20
+ """
21
+
22
+ if os.path.isdir(name):
23
+ frames = sorted(os.listdir(name))
24
+ num_frames = len(frames)
25
+ video_array = np.array(
26
+ [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
27
+ elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
28
+ image = io.imread(name)
29
+
30
+ if len(image.shape) == 2 or image.shape[2] == 1:
31
+ image = gray2rgb(image)
32
+
33
+ if image.shape[2] == 4:
34
+ image = image[..., :3]
35
+
36
+ image = img_as_float32(image)
37
+
38
+ video_array = np.moveaxis(image, 1, 0)
39
+
40
+ video_array = video_array.reshape((-1,) + frame_shape)
41
+ video_array = np.moveaxis(video_array, 1, 2)
42
+ elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
43
+ video = np.array(mimread(name))
44
+ if len(video.shape) == 3:
45
+ video = np.array([gray2rgb(frame) for frame in video])
46
+ if video.shape[-1] == 4:
47
+ video = video[..., :3]
48
+ video_array = img_as_float32(video)
49
+ else:
50
+ raise Exception("Unknown file extensions %s" % name)
51
+
52
+ return video_array
53
+
54
+
55
+ class FramesDataset(Dataset):
56
+ """
57
+ Dataset of videos, each video can be represented as:
58
+ - an image of concatenated frames
59
+ - '.mp4' or '.gif'
60
+ - folder with all frames
61
+ """
62
+
63
+ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
64
+ random_seed=0, pairs_list=None, augmentation_params=None):
65
+ self.root_dir = root_dir
66
+ self.videos = os.listdir(root_dir)
67
+ self.frame_shape = tuple(frame_shape)
68
+ self.pairs_list = pairs_list
69
+ self.id_sampling = id_sampling
70
+ if os.path.exists(os.path.join(root_dir, 'train')):
71
+ assert os.path.exists(os.path.join(root_dir, 'test'))
72
+ print("Use predefined train-test split.")
73
+ if id_sampling:
74
+ train_videos = {os.path.basename(video).split('#')[0] for video in
75
+ os.listdir(os.path.join(root_dir, 'train'))}
76
+ train_videos = list(train_videos)
77
+ else:
78
+ train_videos = os.listdir(os.path.join(root_dir, 'train'))
79
+ test_videos = os.listdir(os.path.join(root_dir, 'test'))
80
+ self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
81
+ else:
82
+ print("Use random train-test split.")
83
+ train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
84
+
85
+ if is_train:
86
+ self.videos = train_videos
87
+ else:
88
+ self.videos = test_videos
89
+
90
+ self.is_train = is_train
91
+
92
+ if self.is_train:
93
+ self.transform = AllAugmentationTransform(**augmentation_params)
94
+ else:
95
+ self.transform = None
96
+
97
+ def __len__(self):
98
+ return len(self.videos)
99
+
100
+ def __getitem__(self, idx):
101
+ if self.is_train and self.id_sampling:
102
+ name = self.videos[idx]
103
+ path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
104
+ else:
105
+ name = self.videos[idx]
106
+ path = os.path.join(self.root_dir, name)
107
+
108
+ video_name = os.path.basename(path)
109
+
110
+ if self.is_train and os.path.isdir(path):
111
+ frames = os.listdir(path)
112
+ num_frames = len(frames)
113
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
114
+ video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
115
+ else:
116
+ video_array = read_video(path, frame_shape=self.frame_shape)
117
+ num_frames = len(video_array)
118
+ frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
119
+ num_frames)
120
+ video_array = video_array[frame_idx]
121
+
122
+ if self.transform is not None:
123
+ video_array = self.transform(video_array)
124
+
125
+ out = {}
126
+ if self.is_train:
127
+ source = np.array(video_array[0], dtype='float32')
128
+ driving = np.array(video_array[1], dtype='float32')
129
+
130
+ out['driving'] = driving.transpose((2, 0, 1))
131
+ out['source'] = source.transpose((2, 0, 1))
132
+ else:
133
+ video = np.array(video_array, dtype='float32')
134
+ out['video'] = video.transpose((3, 0, 1, 2))
135
+
136
+ out['name'] = video_name
137
+
138
+ return out
139
+
140
+
141
+ class DatasetRepeater(Dataset):
142
+ """
143
+ Pass several times over the same dataset for better i/o performance
144
+ """
145
+
146
+ def __init__(self, dataset, num_repeats=100):
147
+ self.dataset = dataset
148
+ self.num_repeats = num_repeats
149
+
150
+ def __len__(self):
151
+ return self.num_repeats * self.dataset.__len__()
152
+
153
+ def __getitem__(self, idx):
154
+ return self.dataset[idx % self.dataset.__len__()]
155
+
156
+
157
+ class PairedDataset(Dataset):
158
+ """
159
+ Dataset of pairs for animation.
160
+ """
161
+
162
+ def __init__(self, initial_dataset, number_of_pairs, seed=0):
163
+ self.initial_dataset = initial_dataset
164
+ pairs_list = self.initial_dataset.pairs_list
165
+
166
+ np.random.seed(seed)
167
+
168
+ if pairs_list is None:
169
+ max_idx = min(number_of_pairs, len(initial_dataset))
170
+ nx, ny = max_idx, max_idx
171
+ xy = np.mgrid[:nx, :ny].reshape(2, -1).T
172
+ number_of_pairs = min(xy.shape[0], number_of_pairs)
173
+ self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
174
+ else:
175
+ videos = self.initial_dataset.videos
176
+ name_to_index = {name: index for index, name in enumerate(videos)}
177
+ pairs = pd.read_csv(pairs_list)
178
+ pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
179
+
180
+ number_of_pairs = min(pairs.shape[0], number_of_pairs)
181
+ self.pairs = []
182
+ self.start_frames = []
183
+ for ind in range(number_of_pairs):
184
+ self.pairs.append(
185
+ (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
186
+
187
+ def __len__(self):
188
+ return len(self.pairs)
189
+
190
+ def __getitem__(self, idx):
191
+ pair = self.pairs[idx]
192
+ first = self.initial_dataset[pair[0]]
193
+ second = self.initial_dataset[pair[1]]
194
+ first = {'driving_' + key: value for key, value in first.items()}
195
+ second = {'source_' + key: value for key, value in second.items()}
196
+
197
+ return {**first, **second}
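
A hedged, self-contained sketch of how these dataset classes fit together during training: it builds a throwaway folder-per-video dataset of random frames and reads one training pair. The augmentation keys follow the structure the training configs usually pass to AllAugmentationTransform and are an assumption here (augmentation.py is imported above but not shown on this page).

import os, tempfile
import numpy as np
import imageio
from frames_dataset import FramesDataset, DatasetRepeater

root = tempfile.mkdtemp()
for vid in ['vid0', 'vid1', 'vid2']:                        # tiny synthetic "videos"
    os.makedirs(os.path.join(root, vid))
    for t in range(3):
        frame = (np.random.rand(256, 256, 3) * 255).astype('uint8')
        imageio.imsave(os.path.join(root, vid, str(t).zfill(7) + '.png'), frame)

dataset = FramesDataset(root_dir=root, frame_shape=(256, 256, 3), is_train=True,
                        augmentation_params={'flip_param': {'horizontal_flip': True,
                                                            'time_flip': True}})   # assumed keys
dataset = DatasetRepeater(dataset, num_repeats=10)           # stretch the epoch, as in training
sample = dataset[0]
print(sample['source'].shape, sample['driving'].shape)       # (3, 256, 256) each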
logger.py ADDED
@@ -0,0 +1,211 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import imageio
5
+
6
+ import os
7
+ from skimage.draw import circle
8
+
9
+ import matplotlib.pyplot as plt
10
+ import collections
11
+
12
+
13
+ class Logger:
14
+ def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
15
+
16
+ self.loss_list = []
17
+ self.cpk_dir = log_dir
18
+ self.visualizations_dir = os.path.join(log_dir, 'train-vis')
19
+ if not os.path.exists(self.visualizations_dir):
20
+ os.makedirs(self.visualizations_dir)
21
+ self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
22
+ self.zfill_num = zfill_num
23
+ self.visualizer = Visualizer(**visualizer_params)
24
+ self.checkpoint_freq = checkpoint_freq
25
+ self.epoch = 0
26
+ self.best_loss = float('inf')
27
+ self.names = None
28
+
29
+ def log_scores(self, loss_names):
30
+ loss_mean = np.array(self.loss_list).mean(axis=0)
31
+
32
+ loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
33
+ loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
34
+
35
+ print(loss_string, file=self.log_file)
36
+ self.loss_list = []
37
+ self.log_file.flush()
38
+
39
+ def visualize_rec(self, inp, out):
40
+ image = self.visualizer.visualize(inp['driving'], inp['source'], out)
41
+ imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
42
+
43
+ def save_cpk(self, emergent=False):
44
+ cpk = {k: v.state_dict() for k, v in self.models.items()}
45
+ cpk['epoch'] = self.epoch
46
+ cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
47
+ if not (os.path.exists(cpk_path) and emergent):
48
+ torch.save(cpk, cpk_path)
49
+
50
+ @staticmethod
51
+ def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
52
+ optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
53
+ if torch.cuda.is_available():
54
+ map_location = None
55
+ else:
56
+ map_location = 'cpu'
57
+ checkpoint = torch.load(checkpoint_path, map_location)
58
+ if generator is not None:
59
+ generator.load_state_dict(checkpoint['generator'])
60
+ if kp_detector is not None:
61
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
62
+ if discriminator is not None:
63
+ try:
64
+ discriminator.load_state_dict(checkpoint['discriminator'])
65
+ except:
66
+ print('No discriminator in the state-dict. Discriminator will be randomly initialized')
67
+ if optimizer_generator is not None:
68
+ optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
69
+ if optimizer_discriminator is not None:
70
+ try:
71
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
72
+ except RuntimeError as e:
73
+ print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
74
+ if optimizer_kp_detector is not None:
75
+ optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
76
+
77
+ return checkpoint['epoch']
78
+
79
+ def __enter__(self):
80
+ return self
81
+
82
+ def __exit__(self, exc_type, exc_val, exc_tb):
83
+ if 'models' in self.__dict__:
84
+ self.save_cpk()
85
+ self.log_file.close()
86
+
87
+ def log_iter(self, losses):
88
+ losses = collections.OrderedDict(losses.items())
89
+ if self.names is None:
90
+ self.names = list(losses.keys())
91
+ self.loss_list.append(list(losses.values()))
92
+
93
+ def log_epoch(self, epoch, models, inp, out):
94
+ self.epoch = epoch
95
+ self.models = models
96
+ if (self.epoch + 1) % self.checkpoint_freq == 0:
97
+ self.save_cpk()
98
+ self.log_scores(self.names)
99
+ self.visualize_rec(inp, out)
100
+
101
+
102
+ class Visualizer:
103
+ def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
104
+ self.kp_size = kp_size
105
+ self.draw_border = draw_border
106
+ self.colormap = plt.get_cmap(colormap)
107
+
108
+ def draw_image_with_kp(self, image, kp_array):
109
+ image = np.copy(image)
110
+ spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
111
+ kp_array = spatial_size * (kp_array + 1) / 2
112
+ num_kp = kp_array.shape[0]
113
+ for kp_ind, kp in enumerate(kp_array):
114
+ rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
115
+ image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
116
+ return image
117
+
118
+ def create_image_column_with_kp(self, images, kp):
119
+ image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
120
+ return self.create_image_column(image_array)
121
+
122
+ def create_image_column(self, images):
123
+ if self.draw_border:
124
+ images = np.copy(images)
125
+ images[:, :, [0, -1]] = (1, 1, 1)
126
+ return np.concatenate(list(images), axis=0)
127
+
128
+ def create_image_grid(self, *args):
129
+ out = []
130
+ for arg in args:
131
+ if type(arg) == tuple:
132
+ out.append(self.create_image_column_with_kp(arg[0], arg[1]))
133
+ else:
134
+ out.append(self.create_image_column(arg))
135
+ return np.concatenate(out, axis=1)
136
+
137
+ def visualize(self, driving, source, out):
138
+ images = []
139
+
140
+ # Source image with keypoints
141
+ source = source.data.cpu()
142
+ kp_source = out['kp_source']['value'].data.cpu().numpy()
143
+ source = np.transpose(source, [0, 2, 3, 1])
144
+ images.append((source, kp_source))
145
+
146
+ # Equivariance visualization
147
+ if 'transformed_frame' in out:
148
+ transformed = out['transformed_frame'].data.cpu().numpy()
149
+ transformed = np.transpose(transformed, [0, 2, 3, 1])
150
+ transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
151
+ images.append((transformed, transformed_kp))
152
+
153
+ # Driving image with keypoints
154
+ kp_driving = out['kp_driving']['value'].data.cpu().numpy()
155
+ driving = driving.data.cpu().numpy()
156
+ driving = np.transpose(driving, [0, 2, 3, 1])
157
+ images.append((driving, kp_driving))
158
+
159
+ # Deformed image
160
+ if 'deformed' in out:
161
+ deformed = out['deformed'].data.cpu().numpy()
162
+ deformed = np.transpose(deformed, [0, 2, 3, 1])
163
+ images.append(deformed)
164
+
165
+ # Result with and without keypoints
166
+ prediction = out['prediction'].data.cpu().numpy()
167
+ prediction = np.transpose(prediction, [0, 2, 3, 1])
168
+ if 'kp_norm' in out:
169
+ kp_norm = out['kp_norm']['value'].data.cpu().numpy()
170
+ images.append((prediction, kp_norm))
171
+ images.append(prediction)
172
+
173
+
174
+ ## Occlusion map
175
+ if 'occlusion_map' in out:
176
+ occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
177
+ occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
178
+ occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
179
+ images.append(occlusion_map)
180
+
181
+ # Deformed images according to each individual transform
182
+ if 'sparse_deformed' in out:
183
+ full_mask = []
184
+ for i in range(out['sparse_deformed'].shape[1]):
185
+ image = out['sparse_deformed'][:, i].data.cpu()
186
+ image = F.interpolate(image, size=source.shape[1:3])
187
+ mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
188
+ mask = F.interpolate(mask, size=source.shape[1:3])
189
+ image = np.transpose(image.numpy(), (0, 2, 3, 1))
190
+ mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
191
+
192
+ if i != 0:
193
+ color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
194
+ else:
195
+ color = np.array((0, 0, 0))
196
+
197
+ color = color.reshape((1, 1, 1, 3))
198
+
199
+ images.append(image)
200
+ if i != 0:
201
+ images.append(mask * color)
202
+ else:
203
+ images.append(mask)
204
+
205
+ full_mask.append(mask * color)
206
+
207
+ images.append(sum(full_mask))
208
+
209
+ image = self.create_image_grid(*images)
210
+ image = (255 * image).astype(np.uint8)
211
+ return image
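
The Logger above is driven by the training loop (log_iter per batch, log_epoch per epoch, checkpoints via save_cpk), which is hard to exercise in isolation; the Visualizer, however, can be tried on its own. A small hedged sketch with a random image and made-up keypoints, assuming the older scikit-image API that still provides skimage.draw.circle (as imported above):

import numpy as np
from logger import Visualizer

vis = Visualizer(kp_size=3, colormap='gist_rainbow')
image = np.zeros((64, 64, 3), dtype=np.float32)
kp = np.array([[-0.5, -0.5], [0.0, 0.0], [0.5, 0.5]])   # keypoints in [-1, 1] image coordinates
overlay = vis.draw_image_with_kp(image, kp)              # one colored dot per keypoint
print(overlay.shape)                                     # (64, 64, 3)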
modules/dense_motion.py ADDED
@@ -0,0 +1,113 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
5
+
6
+
7
+ class DenseMotionNetwork(nn.Module):
8
+ """
9
+ Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
10
+ """
11
+
12
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
13
+ scale_factor=1, kp_variance=0.01):
14
+ super(DenseMotionNetwork, self).__init__()
15
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
16
+ max_features=max_features, num_blocks=num_blocks)
17
+
18
+ self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
19
+
20
+ if estimate_occlusion_map:
21
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
22
+ else:
23
+ self.occlusion = None
24
+
25
+ self.num_kp = num_kp
26
+ self.scale_factor = scale_factor
27
+ self.kp_variance = kp_variance
28
+
29
+ if self.scale_factor != 1:
30
+ self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
31
+
32
+ def create_heatmap_representations(self, source_image, kp_driving, kp_source):
33
+ """
34
+ Eq 6. in the paper H_k(z)
35
+ """
36
+ spatial_size = source_image.shape[2:]
37
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
38
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
39
+ heatmap = gaussian_driving - gaussian_source
40
+
41
+ #adding background feature
42
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
43
+ heatmap = torch.cat([zeros, heatmap], dim=1)
44
+ heatmap = heatmap.unsqueeze(2)
45
+ return heatmap
46
+
47
+ def create_sparse_motions(self, source_image, kp_driving, kp_source):
48
+ """
49
+ Eq 4. in the paper T_{s<-d}(z)
50
+ """
51
+ bs, _, h, w = source_image.shape
52
+ identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
53
+ identity_grid = identity_grid.view(1, 1, h, w, 2)
54
+ coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
55
+ if 'jacobian' in kp_driving:
56
+ jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
57
+ jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
58
+ jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
59
+ coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
60
+ coordinate_grid = coordinate_grid.squeeze(-1)
61
+
62
+ driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
63
+
64
+ #adding background feature
65
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
66
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
67
+ return sparse_motions
68
+
69
+ def create_deformed_source_image(self, source_image, sparse_motions):
70
+ """
71
+ Eq 7. in the paper \hat{T}_{s<-d}(z)
72
+ """
73
+ bs, _, h, w = source_image.shape
74
+ source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
75
+ source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
76
+ sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
77
+ sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
78
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
79
+ return sparse_deformed
80
+
81
+ def forward(self, source_image, kp_driving, kp_source):
82
+ if self.scale_factor != 1:
83
+ source_image = self.down(source_image)
84
+
85
+ bs, _, h, w = source_image.shape
86
+
87
+ out_dict = dict()
88
+ heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
89
+ sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
90
+ deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
91
+ out_dict['sparse_deformed'] = deformed_source
92
+
93
+ input = torch.cat([heatmap_representation, deformed_source], dim=2)
94
+ input = input.view(bs, -1, h, w)
95
+
96
+ prediction = self.hourglass(input)
97
+
98
+ mask = self.mask(prediction)
99
+ mask = F.softmax(mask, dim=1)
100
+ out_dict['mask'] = mask
101
+ mask = mask.unsqueeze(2)
102
+ sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
103
+ deformation = (sparse_motion * mask).sum(dim=1)
104
+ deformation = deformation.permute(0, 2, 3, 1)
105
+
106
+ out_dict['deformation'] = deformation
107
+
108
+ # Sec. 3.2 in the paper
109
+ if self.occlusion:
110
+ occlusion_map = torch.sigmoid(self.occlusion(prediction))
111
+ out_dict['occlusion_map'] = occlusion_map
112
+
113
+ return out_dict
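
A hedged shape-check for DenseMotionNetwork with random inputs. The hyper-parameters mirror the usual 256x256 setup but are assumptions here, not values read from a config in this commit; identity jacobians keep the dummy keypoints trivially invertible.

import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                         num_kp=10, num_channels=3, estimate_occlusion_map=True,
                         scale_factor=0.25)
source = torch.randn(2, 3, 256, 256)
kp = {'value': torch.rand(2, 10, 2) * 2 - 1,                          # keypoints in [-1, 1]
      'jacobian': torch.eye(2).view(1, 1, 2, 2).repeat(2, 10, 1, 1)}  # identity local affines
with torch.no_grad():
    out = net(source_image=source, kp_driving=kp, kp_source=kp)
print(out['deformation'].shape)    # (2, 64, 64, 2): dense flow at the 0.25-scaled resolution
print(out['occlusion_map'].shape)  # (2, 1, 64, 64)
print(out['mask'].shape)           # (2, 11, 64, 64): one mask per keypoint plus background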
modules/discriminator.py ADDED
@@ -0,0 +1,95 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ from modules.util import kp2gaussian
4
+ import torch
5
+
6
+
7
+ class DownBlock2d(nn.Module):
8
+ """
9
+ Simple block for processing video (encoder).
10
+ """
11
+
12
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
13
+ super(DownBlock2d, self).__init__()
14
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
15
+
16
+ if sn:
17
+ self.conv = nn.utils.spectral_norm(self.conv)
18
+
19
+ if norm:
20
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
21
+ else:
22
+ self.norm = None
23
+ self.pool = pool
24
+
25
+ def forward(self, x):
26
+ out = x
27
+ out = self.conv(out)
28
+ if self.norm:
29
+ out = self.norm(out)
30
+ out = F.leaky_relu(out, 0.2)
31
+ if self.pool:
32
+ out = F.avg_pool2d(out, (2, 2))
33
+ return out
34
+
35
+
36
+ class Discriminator(nn.Module):
37
+ """
38
+ Discriminator similar to Pix2Pix
39
+ """
40
+
41
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
42
+ sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
43
+ super(Discriminator, self).__init__()
44
+
45
+ down_blocks = []
46
+ for i in range(num_blocks):
47
+ down_blocks.append(
48
+ DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
49
+ min(max_features, block_expansion * (2 ** (i + 1))),
50
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
51
+
52
+ self.down_blocks = nn.ModuleList(down_blocks)
53
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
54
+ if sn:
55
+ self.conv = nn.utils.spectral_norm(self.conv)
56
+ self.use_kp = use_kp
57
+ self.kp_variance = kp_variance
58
+
59
+ def forward(self, x, kp=None):
60
+ feature_maps = []
61
+ out = x
62
+ if self.use_kp:
63
+ heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
64
+ out = torch.cat([out, heatmap], dim=1)
65
+
66
+ for down_block in self.down_blocks:
67
+ feature_maps.append(down_block(out))
68
+ out = feature_maps[-1]
69
+ prediction_map = self.conv(out)
70
+
71
+ return feature_maps, prediction_map
72
+
73
+
74
+ class MultiScaleDiscriminator(nn.Module):
75
+ """
76
+ Multi-scale discriminator
77
+ """
78
+
79
+ def __init__(self, scales=(), **kwargs):
80
+ super(MultiScaleDiscriminator, self).__init__()
81
+ self.scales = scales
82
+ discs = {}
83
+ for scale in scales:
84
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
85
+ self.discs = nn.ModuleDict(discs)
86
+
87
+ def forward(self, x, kp=None):
88
+ out_dict = {}
89
+ for scale, disc in self.discs.items():
90
+ scale = str(scale).replace('-', '.')
91
+ key = 'prediction_' + scale
92
+ feature_maps, prediction_map = disc(x[key], kp)
93
+ out_dict['feature_maps_' + scale] = feature_maps
94
+ out_dict['prediction_map_' + scale] = prediction_map
95
+ return out_dict
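
A hedged shape-check for the multi-scale discriminator; the scale list and channel settings below are illustrative (the original configs typically use a single scale). Note that forward expects a dict keyed like the ImagePyramide output defined later in modules/model.py.

import torch
from modules.discriminator import MultiScaleDiscriminator

disc = MultiScaleDiscriminator(scales=[1, 0.5], block_expansion=32, max_features=512,
                               num_blocks=4, sn=True)
x = {'prediction_1': torch.randn(2, 3, 256, 256),
     'prediction_0.5': torch.randn(2, 3, 128, 128)}
with torch.no_grad():
    out = disc(x)
print(out['prediction_map_1'].shape)    # (2, 1, 26, 26): patch-level real/fake map
print(out['prediction_map_0.5'].shape)  # (2, 1, 10, 10)
print(len(out['feature_maps_1']))       # 4: one feature map per DownBlock2d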
modules/generator.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
5
+ from modules.dense_motion import DenseMotionNetwork
6
+
7
+
8
+ class OcclusionAwareGenerator(nn.Module):
9
+ """
10
+ Generator that, given a source image and keypoints, tries to transform the image according to the movement trajectories
11
+ induced by the keypoints. The generator follows the Johnson architecture.
12
+ """
13
+
14
+ def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
15
+ num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
16
+ super(OcclusionAwareGenerator, self).__init__()
17
+
18
+ if dense_motion_params is not None:
19
+ self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
20
+ estimate_occlusion_map=estimate_occlusion_map,
21
+ **dense_motion_params)
22
+ else:
23
+ self.dense_motion_network = None
24
+
25
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
26
+
27
+ down_blocks = []
28
+ for i in range(num_down_blocks):
29
+ in_features = min(max_features, block_expansion * (2 ** i))
30
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
31
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
32
+ self.down_blocks = nn.ModuleList(down_blocks)
33
+
34
+ up_blocks = []
35
+ for i in range(num_down_blocks):
36
+ in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
37
+ out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
38
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
39
+ self.up_blocks = nn.ModuleList(up_blocks)
40
+
41
+ self.bottleneck = torch.nn.Sequential()
42
+ in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
43
+ for i in range(num_bottleneck_blocks):
44
+ self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
45
+
46
+ self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
47
+ self.estimate_occlusion_map = estimate_occlusion_map
48
+ self.num_channels = num_channels
49
+
50
+ def deform_input(self, inp, deformation):
51
+ _, h_old, w_old, _ = deformation.shape
52
+ _, _, h, w = inp.shape
53
+ if h_old != h or w_old != w:
54
+ deformation = deformation.permute(0, 3, 1, 2)
55
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
56
+ deformation = deformation.permute(0, 2, 3, 1)
57
+ return F.grid_sample(inp, deformation)
58
+
59
+ def forward(self, source_image, kp_driving, kp_source):
60
+ # Encoding (downsampling) part
61
+ out = self.first(source_image)
62
+ for i in range(len(self.down_blocks)):
63
+ out = self.down_blocks[i](out)
64
+
65
+ # Transforming feature representation according to deformation and occlusion
66
+ output_dict = {}
67
+ if self.dense_motion_network is not None:
68
+ dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
69
+ kp_source=kp_source)
70
+ output_dict['mask'] = dense_motion['mask']
71
+ output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
72
+
73
+ if 'occlusion_map' in dense_motion:
74
+ occlusion_map = dense_motion['occlusion_map']
75
+ output_dict['occlusion_map'] = occlusion_map
76
+ else:
77
+ occlusion_map = None
78
+ deformation = dense_motion['deformation']
79
+ out = self.deform_input(out, deformation)
80
+
81
+ if occlusion_map is not None:
82
+ if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
83
+ occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
84
+ out = out * occlusion_map
85
+
86
+ output_dict["deformed"] = self.deform_input(source_image, deformation)
87
+
88
+ # Decoding part
89
+ out = self.bottleneck(out)
90
+ for i in range(len(self.up_blocks)):
91
+ out = self.up_blocks[i](out)
92
+ out = self.final(out)
93
+ out = torch.sigmoid(out)
94
+
95
+ output_dict["prediction"] = out
96
+
97
+ return output_dict
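
A hedged end-to-end shape-check for the generator with random tensors and dummy keypoints; the hyper-parameters (including the dense_motion_params dict) mirror the usual 256x256 setup but are assumptions here rather than values taken from a config in this commit.

import torch
from modules.generator import OcclusionAwareGenerator

generator = OcclusionAwareGenerator(num_channels=3, num_kp=10, block_expansion=64,
                                    max_features=512, num_down_blocks=2,
                                    num_bottleneck_blocks=6, estimate_occlusion_map=True,
                                    dense_motion_params={'block_expansion': 64,
                                                         'max_features': 1024,
                                                         'num_blocks': 5,
                                                         'scale_factor': 0.25})
kp = {'value': torch.rand(1, 10, 2) * 2 - 1,
      'jacobian': torch.eye(2).view(1, 1, 2, 2).repeat(1, 10, 1, 1)}
with torch.no_grad():
    out = generator(torch.randn(1, 3, 256, 256), kp_source=kp, kp_driving=kp)
print(out['prediction'].shape)     # (1, 3, 256, 256)
print(out['occlusion_map'].shape)  # (1, 1, 64, 64): predicted at the dense-motion scale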
modules/keypoint_detector.py ADDED
@@ -0,0 +1,75 @@
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
5
+
6
+
7
+ class KPDetector(nn.Module):
8
+ """
9
+ Detects keypoints. Returns keypoint positions and a jacobian near each keypoint.
10
+ """
11
+
12
+ def __init__(self, block_expansion, num_kp, num_channels, max_features,
13
+ num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
14
+ single_jacobian_map=False, pad=0):
15
+ super(KPDetector, self).__init__()
16
+
17
+ self.predictor = Hourglass(block_expansion, in_features=num_channels,
18
+ max_features=max_features, num_blocks=num_blocks)
19
+
20
+ self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
21
+ padding=pad)
22
+
23
+ if estimate_jacobian:
24
+ self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
25
+ self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
26
+ out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
27
+ self.jacobian.weight.data.zero_()
28
+ self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
29
+ else:
30
+ self.jacobian = None
31
+
32
+ self.temperature = temperature
33
+ self.scale_factor = scale_factor
34
+ if self.scale_factor != 1:
35
+ self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
36
+
37
+ def gaussian2kp(self, heatmap):
38
+ """
39
+ Extract the mean from a heatmap
40
+ """
41
+ shape = heatmap.shape
42
+ heatmap = heatmap.unsqueeze(-1)
43
+ grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
44
+ value = (heatmap * grid).sum(dim=(2, 3))
45
+ kp = {'value': value}
46
+
47
+ return kp
48
+
49
+ def forward(self, x):
50
+ if self.scale_factor != 1:
51
+ x = self.down(x)
52
+
53
+ feature_map = self.predictor(x)
54
+ prediction = self.kp(feature_map)
55
+
56
+ final_shape = prediction.shape
57
+ heatmap = prediction.view(final_shape[0], final_shape[1], -1)
58
+ heatmap = F.softmax(heatmap / self.temperature, dim=2)
59
+ heatmap = heatmap.view(*final_shape)
60
+
61
+ out = self.gaussian2kp(heatmap)
62
+
63
+ if self.jacobian is not None:
64
+ jacobian_map = self.jacobian(feature_map)
65
+ jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
66
+ final_shape[3])
67
+ heatmap = heatmap.unsqueeze(2)
68
+
69
+ jacobian = heatmap * jacobian_map
70
+ jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
71
+ jacobian = jacobian.sum(dim=-1)
72
+ jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
73
+ out['jacobian'] = jacobian
74
+
75
+ return out
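
A hedged shape-check for KPDetector on a random image; the hyper-parameter values mirror the common configuration but are assumptions here.

import torch
from modules.keypoint_detector import KPDetector

kp_detector = KPDetector(block_expansion=32, num_kp=10, num_channels=3, max_features=1024,
                         num_blocks=5, temperature=0.1, estimate_jacobian=True,
                         scale_factor=0.25)
with torch.no_grad():
    out = kp_detector(torch.randn(1, 3, 256, 256))
print(out['value'].shape)     # (1, 10, 2): keypoint coordinates in [-1, 1]
print(out['jacobian'].shape)  # (1, 10, 2, 2): local affine component per keypoint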
modules/model.py ADDED
@@ -0,0 +1,259 @@
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
5
+ from torchvision import models
6
+ import numpy as np
7
+ from torch.autograd import grad
8
+
9
+
10
+ class Vgg19(torch.nn.Module):
11
+ """
12
+ Vgg19 network for perceptual loss. See Sec 3.3.
13
+ """
14
+ def __init__(self, requires_grad=False):
15
+ super(Vgg19, self).__init__()
16
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
17
+ self.slice1 = torch.nn.Sequential()
18
+ self.slice2 = torch.nn.Sequential()
19
+ self.slice3 = torch.nn.Sequential()
20
+ self.slice4 = torch.nn.Sequential()
21
+ self.slice5 = torch.nn.Sequential()
22
+ for x in range(2):
23
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
24
+ for x in range(2, 7):
25
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
26
+ for x in range(7, 12):
27
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
28
+ for x in range(12, 21):
29
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
30
+ for x in range(21, 30):
31
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
32
+
33
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
34
+ requires_grad=False)
35
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
36
+ requires_grad=False)
37
+
38
+ if not requires_grad:
39
+ for param in self.parameters():
40
+ param.requires_grad = False
41
+
42
+ def forward(self, X):
43
+ X = (X - self.mean) / self.std
44
+ h_relu1 = self.slice1(X)
45
+ h_relu2 = self.slice2(h_relu1)
46
+ h_relu3 = self.slice3(h_relu2)
47
+ h_relu4 = self.slice4(h_relu3)
48
+ h_relu5 = self.slice5(h_relu4)
49
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
50
+ return out
51
+
52
+
53
+ class ImagePyramide(torch.nn.Module):
54
+ """
55
+ Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3
56
+ """
57
+ def __init__(self, scales, num_channels):
58
+ super(ImagePyramide, self).__init__()
59
+ downs = {}
60
+ for scale in scales:
61
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
62
+ self.downs = nn.ModuleDict(downs)
63
+
64
+ def forward(self, x):
65
+ out_dict = {}
66
+ for scale, down_module in self.downs.items():
67
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
68
+ return out_dict
69
+
70
+
71
+ class Transform:
72
+ """
73
+ Random TPS transformation for equivariance constraints. See Sec 3.3
74
+ """
75
+ def __init__(self, bs, **kwargs):
76
+ noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
77
+ self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
78
+ self.bs = bs
79
+
80
+ if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
81
+ self.tps = True
82
+ self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
83
+ self.control_points = self.control_points.unsqueeze(0)
84
+ self.control_params = torch.normal(mean=0,
85
+ std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
86
+ else:
87
+ self.tps = False
88
+
89
+ def transform_frame(self, frame):
90
+ grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
91
+ grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
92
+ grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
93
+ return F.grid_sample(frame, grid, padding_mode="reflection")
94
+
95
+ def warp_coordinates(self, coordinates):
96
+ theta = self.theta.type(coordinates.type())
97
+ theta = theta.unsqueeze(1)
98
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
99
+ transformed = transformed.squeeze(-1)
100
+
101
+ if self.tps:
102
+ control_points = self.control_points.type(coordinates.type())
103
+ control_params = self.control_params.type(coordinates.type())
104
+ distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
105
+ distances = torch.abs(distances).sum(-1)
106
+
107
+ result = distances ** 2
108
+ result = result * torch.log(distances + 1e-6)
109
+ result = result * control_params
110
+ result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
111
+ transformed = transformed + result
112
+
113
+ return transformed
114
+
115
+ def jacobian(self, coordinates):
116
+ new_coordinates = self.warp_coordinates(coordinates)
117
+ grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
118
+ grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
119
+ jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
120
+ return jacobian
121
+
122
+
123
+ def detach_kp(kp):
124
+ return {key: value.detach() for key, value in kp.items()}
125
+
126
+
127
+ class GeneratorFullModel(torch.nn.Module):
128
+ """
129
+ Merge all generator-related updates into a single model for better multi-GPU usage
130
+ """
131
+
132
+ def __init__(self, kp_extractor, generator, discriminator, train_params):
133
+ super(GeneratorFullModel, self).__init__()
134
+ self.kp_extractor = kp_extractor
135
+ self.generator = generator
136
+ self.discriminator = discriminator
137
+ self.train_params = train_params
138
+ self.scales = train_params['scales']
139
+ self.disc_scales = self.discriminator.scales
140
+ self.pyramid = ImagePyramide(self.scales, generator.num_channels)
141
+ if torch.cuda.is_available():
142
+ self.pyramid = self.pyramid.cuda()
143
+
144
+ self.loss_weights = train_params['loss_weights']
145
+
146
+ if sum(self.loss_weights['perceptual']) != 0:
147
+ self.vgg = Vgg19()
148
+ if torch.cuda.is_available():
149
+ self.vgg = self.vgg.cuda()
150
+
151
+ def forward(self, x):
152
+ kp_source = self.kp_extractor(x['source'])
153
+ kp_driving = self.kp_extractor(x['driving'])
154
+
155
+ generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
156
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
157
+
158
+ loss_values = {}
159
+
160
+ pyramide_real = self.pyramid(x['driving'])
161
+ pyramide_generated = self.pyramid(generated['prediction'])
162
+
163
+ if sum(self.loss_weights['perceptual']) != 0:
164
+ value_total = 0
165
+ for scale in self.scales:
166
+ x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
167
+ y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
168
+
169
+ for i, weight in enumerate(self.loss_weights['perceptual']):
170
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
171
+ value_total += self.loss_weights['perceptual'][i] * value
172
+ loss_values['perceptual'] = value_total
173
+
174
+ if self.loss_weights['generator_gan'] != 0:
175
+ discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
176
+ discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
177
+ value_total = 0
178
+ for scale in self.disc_scales:
179
+ key = 'prediction_map_%s' % scale
180
+ value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
181
+ value_total += self.loss_weights['generator_gan'] * value
182
+ loss_values['gen_gan'] = value_total
183
+
184
+ if sum(self.loss_weights['feature_matching']) != 0:
185
+ value_total = 0
186
+ for scale in self.disc_scales:
187
+ key = 'feature_maps_%s' % scale
188
+ for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
189
+ if self.loss_weights['feature_matching'][i] == 0:
190
+ continue
191
+ value = torch.abs(a - b).mean()
192
+ value_total += self.loss_weights['feature_matching'][i] * value
193
+ loss_values['feature_matching'] = value_total
194
+
195
+ if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
196
+ transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
197
+ transformed_frame = transform.transform_frame(x['driving'])
198
+ transformed_kp = self.kp_extractor(transformed_frame)
199
+
200
+ generated['transformed_frame'] = transformed_frame
201
+ generated['transformed_kp'] = transformed_kp
202
+
203
+ ## Value loss part
204
+ if self.loss_weights['equivariance_value'] != 0:
205
+ value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
206
+ loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
207
+
208
+ ## jacobian loss part
209
+ if self.loss_weights['equivariance_jacobian'] != 0:
210
+ jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
211
+ transformed_kp['jacobian'])
212
+
213
+ normed_driving = torch.inverse(kp_driving['jacobian'])
214
+ normed_transformed = jacobian_transformed
215
+ value = torch.matmul(normed_driving, normed_transformed)
216
+
217
+ eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
218
+
219
+ value = torch.abs(eye - value).mean()
220
+ loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
221
+
222
+ return loss_values, generated
223
+
224
+
225
+ class DiscriminatorFullModel(torch.nn.Module):
226
+ """
227
+ Merge all discriminator-related updates into a single model for better multi-GPU usage
228
+ """
229
+
230
+ def __init__(self, kp_extractor, generator, discriminator, train_params):
231
+ super(DiscriminatorFullModel, self).__init__()
232
+ self.kp_extractor = kp_extractor
233
+ self.generator = generator
234
+ self.discriminator = discriminator
235
+ self.train_params = train_params
236
+ self.scales = self.discriminator.scales
237
+ self.pyramid = ImagePyramide(self.scales, generator.num_channels)
238
+ if torch.cuda.is_available():
239
+ self.pyramid = self.pyramid.cuda()
240
+
241
+ self.loss_weights = train_params['loss_weights']
242
+
243
+ def forward(self, x, generated):
244
+ pyramide_real = self.pyramid(x['driving'])
245
+ pyramide_generated = self.pyramid(generated['prediction'].detach())
246
+
247
+ kp_driving = generated['kp_driving']
248
+ discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
249
+ discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
250
+
251
+ loss_values = {}
252
+ value_total = 0
253
+ for scale in self.scales:
254
+ key = 'prediction_map_%s' % scale
255
+ value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
256
+ value_total += self.loss_weights['discriminator_gan'] * value.mean()
257
+ loss_values['disc_gan'] = value_total
258
+
259
+ return loss_values
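
The equivariance terms in GeneratorFullModel above rely on the random Transform defined in this file, which can be exercised on its own. A hedged sketch with random data; the sigma/points values mirror typical training configs and are assumptions here.

import torch
from modules.model import Transform

transform = Transform(bs=2, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
frame = torch.randn(2, 3, 64, 64)
warped = transform.transform_frame(frame)             # randomly warped copy of the batch
coords = torch.empty(2, 10, 2).uniform_(-1, 1).requires_grad_(True)
warped_coords = transform.warp_coordinates(coords)    # where those points land after warping
jac = transform.jacobian(coords)                      # local 2x2 jacobian of the warp
print(warped.shape, warped_coords.shape, jac.shape)   # (2, 3, 64, 64) (2, 10, 2) (2, 10, 2, 2)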
modules/util.py ADDED
@@ -0,0 +1,245 @@
1
+ from torch import nn
2
+
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+ from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
7
+
8
+
9
+ def kp2gaussian(kp, spatial_size, kp_variance):
10
+ """
11
+ Transform a keypoint into a gaussian-like representation
12
+ """
13
+ mean = kp['value']
14
+
15
+ coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
16
+ number_of_leading_dimensions = len(mean.shape) - 1
17
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
18
+ coordinate_grid = coordinate_grid.view(*shape)
19
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
20
+ coordinate_grid = coordinate_grid.repeat(*repeats)
21
+
22
+ # Preprocess kp shape
23
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
24
+ mean = mean.view(*shape)
25
+
26
+ mean_sub = (coordinate_grid - mean)
27
+
28
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
29
+
30
+ return out
31
+
32
+
33
+ def make_coordinate_grid(spatial_size, type):
34
+ """
35
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
36
+ """
37
+ h, w = spatial_size
38
+ x = torch.arange(w).type(type)
39
+ y = torch.arange(h).type(type)
40
+
41
+ x = (2 * (x / (w - 1)) - 1)
42
+ y = (2 * (y / (h - 1)) - 1)
43
+
44
+ yy = y.view(-1, 1).repeat(1, w)
45
+ xx = x.view(1, -1).repeat(h, 1)
46
+
47
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
48
+
49
+ return meshed
50
+
51
+
52
+ class ResBlock2d(nn.Module):
53
+ """
54
+ Residual block, preserves spatial resolution.
55
+ """
56
+
57
+ def __init__(self, in_features, kernel_size, padding):
58
+ super(ResBlock2d, self).__init__()
59
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
60
+ padding=padding)
61
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
62
+ padding=padding)
63
+ self.norm1 = BatchNorm2d(in_features, affine=True)
64
+ self.norm2 = BatchNorm2d(in_features, affine=True)
65
+
66
+ def forward(self, x):
67
+ out = self.norm1(x)
68
+ out = F.relu(out)
69
+ out = self.conv1(out)
70
+ out = self.norm2(out)
71
+ out = F.relu(out)
72
+ out = self.conv2(out)
73
+ out += x
74
+ return out
75
+
76
+
77
+ class UpBlock2d(nn.Module):
78
+ """
79
+ Upsampling block for use in decoder.
80
+ """
81
+
82
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
83
+ super(UpBlock2d, self).__init__()
84
+
85
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
86
+ padding=padding, groups=groups)
87
+ self.norm = BatchNorm2d(out_features, affine=True)
88
+
89
+ def forward(self, x):
90
+ out = F.interpolate(x, scale_factor=2)
91
+ out = self.conv(out)
92
+ out = self.norm(out)
93
+ out = F.relu(out)
94
+ return out
95
+
96
+
97
+ class DownBlock2d(nn.Module):
98
+ """
99
+ Downsampling block for use in encoder.
100
+ """
101
+
102
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
103
+ super(DownBlock2d, self).__init__()
104
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
105
+ padding=padding, groups=groups)
106
+ self.norm = BatchNorm2d(out_features, affine=True)
107
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
108
+
109
+ def forward(self, x):
110
+ out = self.conv(x)
111
+ out = self.norm(out)
112
+ out = F.relu(out)
113
+ out = self.pool(out)
114
+ return out
115
+
116
+
117
+ class SameBlock2d(nn.Module):
118
+ """
119
+ Simple block, preserves spatial resolution.
120
+ """
121
+
122
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
123
+ super(SameBlock2d, self).__init__()
124
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
125
+ kernel_size=kernel_size, padding=padding, groups=groups)
126
+ self.norm = BatchNorm2d(out_features, affine=True)
127
+
128
+ def forward(self, x):
129
+ out = self.conv(x)
130
+ out = self.norm(out)
131
+ out = F.relu(out)
132
+ return out
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """
137
+ Hourglass Encoder
138
+ """
139
+
140
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
141
+ super(Encoder, self).__init__()
142
+
143
+ down_blocks = []
144
+ for i in range(num_blocks):
145
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
146
+ min(max_features, block_expansion * (2 ** (i + 1))),
147
+ kernel_size=3, padding=1))
148
+ self.down_blocks = nn.ModuleList(down_blocks)
149
+
150
+ def forward(self, x):
151
+ outs = [x]
152
+ for down_block in self.down_blocks:
153
+ outs.append(down_block(outs[-1]))
154
+ return outs
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ """
159
+ Hourglass Decoder
160
+ """
161
+
162
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
163
+ super(Decoder, self).__init__()
164
+
165
+ up_blocks = []
166
+
167
+ for i in range(num_blocks)[::-1]:
168
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
169
+ out_filters = min(max_features, block_expansion * (2 ** i))
170
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
171
+
172
+ self.up_blocks = nn.ModuleList(up_blocks)
173
+ self.out_filters = block_expansion + in_features
174
+
175
+ def forward(self, x):
176
+ out = x.pop()
177
+ for up_block in self.up_blocks:
178
+ out = up_block(out)
179
+ skip = x.pop()
180
+ out = torch.cat([out, skip], dim=1)
181
+ return out
182
+
183
+
184
+ class Hourglass(nn.Module):
185
+ """
186
+ Hourglass architecture.
187
+ """
188
+
189
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
190
+ super(Hourglass, self).__init__()
191
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
192
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
193
+ self.out_filters = self.decoder.out_filters
194
+
195
+ def forward(self, x):
196
+ return self.decoder(self.encoder(x))
197
+
198
+
199
+ class AntiAliasInterpolation2d(nn.Module):
200
+ """
201
+ Band-limited downsampling, for better preservation of the input signal.
202
+ """
203
+ def __init__(self, channels, scale):
204
+ super(AntiAliasInterpolation2d, self).__init__()
205
+ sigma = (1 / scale - 1) / 2
206
+ kernel_size = 2 * round(sigma * 4) + 1
207
+ self.ka = kernel_size // 2
208
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
209
+
210
+ kernel_size = [kernel_size, kernel_size]
211
+ sigma = [sigma, sigma]
212
+ # The gaussian kernel is the product of the
213
+ # gaussian function of each dimension.
214
+ kernel = 1
215
+ meshgrids = torch.meshgrid(
216
+ [
217
+ torch.arange(size, dtype=torch.float32)
218
+ for size in kernel_size
219
+ ]
220
+ )
221
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
222
+ mean = (size - 1) / 2
223
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
224
+
225
+ # Make sure sum of values in gaussian kernel equals 1.
226
+ kernel = kernel / torch.sum(kernel)
227
+ # Reshape to depthwise convolutional weight
228
+ kernel = kernel.view(1, 1, *kernel.size())
229
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
230
+
231
+ self.register_buffer('weight', kernel)
232
+ self.groups = channels
233
+ self.scale = scale
234
+ inv_scale = 1 / scale
235
+ self.int_inv_scale = int(inv_scale)
236
+
237
+ def forward(self, input):
238
+ if self.scale == 1.0:
239
+ return input
240
+
241
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
242
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
243
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
244
+
245
+ return out
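
Note (not part of the commit): a quick shape check for the blocks defined above, assuming the repository root is on PYTHONPATH so that modules.util and sync_batchnorm import cleanly; the tensor sizes are illustrative only.

import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid

x = torch.randn(1, 3, 64, 64)

hg = Hourglass(block_expansion=32, in_features=3, num_blocks=3, max_features=256)
print(hg(x).shape)    # (1, 35, 64, 64): out_filters = block_expansion + in_features, input resolution preserved

down = AntiAliasInterpolation2d(channels=3, scale=0.25)
print(down(x).shape)  # (1, 3, 16, 16): Gaussian blur followed by subsampling with step 1/scale

grid = make_coordinate_grid((64, 64), x.type())
print(grid.shape)     # (64, 64, 2), coordinates in [-1, 1]
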
reconstruction.py ADDED
@@ -0,0 +1,67 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from logger import Logger, Visualizer
6
+ import numpy as np
7
+ import imageio
8
+ from sync_batchnorm import DataParallelWithCallback
9
+
10
+
11
+ def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
12
+ png_dir = os.path.join(log_dir, 'reconstruction/png')
13
+ log_dir = os.path.join(log_dir, 'reconstruction')
14
+
15
+ if checkpoint is not None:
16
+ Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
17
+ else:
18
+ raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
19
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
20
+
21
+ if not os.path.exists(log_dir):
22
+ os.makedirs(log_dir)
23
+
24
+ if not os.path.exists(png_dir):
25
+ os.makedirs(png_dir)
26
+
27
+ loss_list = []
28
+ if torch.cuda.is_available():
29
+ generator = DataParallelWithCallback(generator)
30
+ kp_detector = DataParallelWithCallback(kp_detector)
31
+
32
+ generator.eval()
33
+ kp_detector.eval()
34
+
35
+ for it, x in tqdm(enumerate(dataloader)):
36
+ if config['reconstruction_params']['num_videos'] is not None:
37
+ if it > config['reconstruction_params']['num_videos']:
38
+ break
39
+ with torch.no_grad():
40
+ predictions = []
41
+ visualizations = []
42
+ if torch.cuda.is_available():
43
+ x['video'] = x['video'].cuda()
44
+ kp_source = kp_detector(x['video'][:, :, 0])
45
+ for frame_idx in range(x['video'].shape[2]):
46
+ source = x['video'][:, :, 0]
47
+ driving = x['video'][:, :, frame_idx]
48
+ kp_driving = kp_detector(driving)
49
+ out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
50
+ out['kp_source'] = kp_source
51
+ out['kp_driving'] = kp_driving
52
+ del out['sparse_deformed']
53
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
54
+
55
+ visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
56
+ driving=driving, out=out)
57
+ visualizations.append(visualization)
58
+
59
+ loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
60
+
61
+ predictions = np.concatenate(predictions, axis=1)
62
+ imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
63
+
64
+ image_name = x['name'][0] + config['reconstruction_params']['format']
65
+ imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
66
+
67
+ print("Reconstruction loss: %s" % np.mean(loss_list))
run.py ADDED
@@ -0,0 +1,87 @@
1
+ import matplotlib
2
+
3
+ matplotlib.use('Agg')
4
+
5
+ import os, sys
6
+ import yaml
7
+ from argparse import ArgumentParser
8
+ from time import gmtime, strftime
9
+ from shutil import copy
10
+
11
+ from frames_dataset import FramesDataset
12
+
13
+ from modules.generator import OcclusionAwareGenerator
14
+ from modules.discriminator import MultiScaleDiscriminator
15
+ from modules.keypoint_detector import KPDetector
16
+
17
+ import torch
18
+
19
+ from train import train
20
+ from reconstruction import reconstruction
21
+ from animate import animate
22
+
23
+ if __name__ == "__main__":
24
+
25
+ if sys.version_info[0] < 3:
26
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
27
+
28
+ parser = ArgumentParser()
29
+ parser.add_argument("--config", required=True, help="path to config")
30
+ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
31
+ parser.add_argument("--log_dir", default='log', help="path to log into")
32
+ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
33
+ parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
34
+ help="Names of the devices comma separated.")
35
+ parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
36
+ parser.set_defaults(verbose=False)
37
+
38
+ opt = parser.parse_args()
39
+ with open(opt.config) as f:
40
+ config = yaml.load(f)
41
+
42
+ if opt.checkpoint is not None:
43
+ log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
44
+ else:
45
+ log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
46
+ log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
47
+
48
+ generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
49
+ **config['model_params']['common_params'])
50
+
51
+ if torch.cuda.is_available():
52
+ generator.to(opt.device_ids[0])
53
+ if opt.verbose:
54
+ print(generator)
55
+
56
+ discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
57
+ **config['model_params']['common_params'])
58
+ if torch.cuda.is_available():
59
+ discriminator.to(opt.device_ids[0])
60
+ if opt.verbose:
61
+ print(discriminator)
62
+
63
+ kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
64
+ **config['model_params']['common_params'])
65
+
66
+ if torch.cuda.is_available():
67
+ kp_detector.to(opt.device_ids[0])
68
+
69
+ if opt.verbose:
70
+ print(kp_detector)
71
+
72
+ dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
73
+
74
+ if not os.path.exists(log_dir):
75
+ os.makedirs(log_dir)
76
+ if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
77
+ copy(opt.config, log_dir)
78
+
79
+ if opt.mode == 'train':
80
+ print("Training...")
81
+ train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
82
+ elif opt.mode == 'reconstruction':
83
+ print("Reconstruction...")
84
+ reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
85
+ elif opt.mode == 'animate':
86
+ print("Animate...")
87
+ animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
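
Note (not part of the commit): typical invocations of this entry point; the config and checkpoint paths below are placeholders, substitute the YAML file and checkpoint you actually use.

python run.py --config config/dataset.yaml --mode train --device_ids 0,1
python run.py --config config/dataset.yaml --mode reconstruction --checkpoint path/to/checkpoint.pth.tar
python run.py --config config/dataset.yaml --mode animate --checkpoint path/to/checkpoint.pth.tar
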
sup-mat/absolute-demo.gif ADDED

Git LFS Details

  • SHA256: 02e5b75bccd0766244ea83f3a427e725055a7dc118c3f8555367e432be868dc0
  • Pointer size: 132 Bytes
  • Size of remote file: 5.6 MB
sup-mat/face-swap.gif ADDED

Git LFS Details

  • SHA256: 36994f1fa70750d095dddca623fda3c81bf67939017f134563f2d0ca564242b6
  • Pointer size: 133 Bytes
  • Size of remote file: 15.2 MB
sup-mat/fashion-teaser.gif ADDED

Git LFS Details

  • SHA256: b391e0e46ab39e9eae48b5987a319a14dfa7b558eba4de2fc060b2d5d1050e83
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
sup-mat/mgif-teaser.gif ADDED

Git LFS Details

  • SHA256: 53d509c00b63bbffbbfc88b65ce1f20e77acc59c128ec9da3d3e44325b202744
  • Pointer size: 132 Bytes
  • Size of remote file: 3.2 MB
sup-mat/relative-demo.gif ADDED

Git LFS Details

  • SHA256: 598c35e158d9c81305254ce82a2e6d2547e7b0e5c9cc710ab70e9953ff954db2
  • Pointer size: 132 Bytes
  • Size of remote file: 5.54 MB
sup-mat/vox-teaser.gif ADDED

Git LFS Details

  • SHA256: 1e70bf39d5e3299a6688a7aa2a654e3fcd0bf9d8c72cb883eac3a206040acc78
  • Pointer size: 133 Bytes
  • Size of remote file: 39.4 MB
sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,315 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimension"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dimensions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ return F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+
55
+ # Resize the input to (B, C, -1).
56
+ input_shape = input.size()
57
+ input = input.view(input.size(0), self.num_features, -1)
58
+
59
+ # Compute the sum and square-sum.
60
+ sum_size = input.size(0) * input.size(2)
61
+ input_sum = _sum_ft(input)
62
+ input_ssum = _sum_ft(input ** 2)
63
+
64
+ # Reduce-and-broadcast the statistics.
65
+ if self._parallel_id == 0:
66
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67
+ else:
68
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69
+
70
+ # Compute the output.
71
+ if self.affine:
72
+ # MJY:: Fuse the multiplication for speed.
73
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74
+ else:
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76
+
77
+ # Reshape it.
78
+ return output.view(input_shape)
79
+
80
+ def __data_parallel_replicate__(self, ctx, copy_id):
81
+ self._is_parallel = True
82
+ self._parallel_id = copy_id
83
+
84
+ # parallel_id == 0 means master device.
85
+ if self._parallel_id == 0:
86
+ ctx.sync_master = self._sync_master
87
+ else:
88
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89
+
90
+ def _data_parallel_master(self, intermediates):
91
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92
+
93
+ # Always using same "device order" makes the ReduceAdd operation faster.
94
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
95
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96
+
97
+ to_reduce = [i[1][:2] for i in intermediates]
98
+ to_reduce = [j for i in to_reduce for j in i] # flatten
99
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
100
+
101
+ sum_size = sum([i[1].sum_size for i in intermediates])
102
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104
+
105
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106
+
107
+ outputs = []
108
+ for i, rec in enumerate(intermediates):
109
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
110
+
111
+ return outputs
112
+
113
+ def _compute_mean_std(self, sum_, ssum, size):
114
+ """Compute the mean and standard-deviation with sum and square-sum. This method
115
+ also maintains the moving average on the master device."""
116
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117
+ mean = sum_ / size
118
+ sumvar = ssum - sum_ * mean
119
+ unbias_var = sumvar / (size - 1)
120
+ bias_var = sumvar / size
121
+
122
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124
+
125
+ return mean, bias_var.clamp(self.eps) ** -0.5
126
+
127
+
128
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130
+ mini-batch.
131
+
132
+ .. math::
133
+
134
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
135
+
136
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
137
+ standard-deviation are reduced across all devices during training.
138
+
139
+ For example, when one uses `nn.DataParallel` to wrap the network during
140
+ training, PyTorch's implementation normalizes the tensor on each device using
141
+ the statistics only on that device, which accelerates the computation and
142
+ is also easy to implement, but the statistics might be inaccurate.
143
+ Instead, in this synchronized version, the statistics will be computed
144
+ over all training samples distributed on multiple devices.
145
+
146
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
147
+ as the built-in PyTorch implementation.
148
+
149
+ The mean and standard-deviation are calculated per-dimension over
150
+ the mini-batches and gamma and beta are learnable parameter vectors
151
+ of size C (where C is the input size).
152
+
153
+ During training, this layer keeps a running estimate of its computed mean
154
+ and variance. The running sum is kept with a default momentum of 0.1.
155
+
156
+ During evaluation, this running mean/variance is used for normalization.
157
+
158
+ Because the BatchNorm is done over the `C` dimension, computing statistics
159
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
160
+
161
+ Args:
162
+ num_features: num_features from an expected input of size
163
+ `batch_size x num_features [x width]`
164
+ eps: a value added to the denominator for numerical stability.
165
+ Default: 1e-5
166
+ momentum: the value used for the running_mean and running_var
167
+ computation. Default: 0.1
168
+ affine: a boolean value that when set to ``True``, gives the layer learnable
169
+ affine parameters. Default: ``True``
170
+
171
+ Shape:
172
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
173
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
174
+
175
+ Examples:
176
+ >>> # With Learnable Parameters
177
+ >>> m = SynchronizedBatchNorm1d(100)
178
+ >>> # Without Learnable Parameters
179
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
180
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
181
+ >>> output = m(input)
182
+ """
183
+
184
+ def _check_input_dim(self, input):
185
+ if input.dim() != 2 and input.dim() != 3:
186
+ raise ValueError('expected 2D or 3D input (got {}D input)'
187
+ .format(input.dim()))
188
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
189
+
190
+
191
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
192
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
193
+ of 3d inputs
194
+
195
+ .. math::
196
+
197
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
198
+
199
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
200
+ standard-deviation are reduced across all devices during training.
201
+
202
+ For example, when one uses `nn.DataParallel` to wrap the network during
203
+ training, PyTorch's implementation normalizes the tensor on each device using
204
+ the statistics only on that device, which accelerates the computation and
205
+ is also easy to implement, but the statistics might be inaccurate.
206
+ Instead, in this synchronized version, the statistics will be computed
207
+ over all training samples distributed on multiple devices.
208
+
209
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
210
+ as the built-in PyTorch implementation.
211
+
212
+ The mean and standard-deviation are calculated per-dimension over
213
+ the mini-batches and gamma and beta are learnable parameter vectors
214
+ of size C (where C is the input size).
215
+
216
+ During training, this layer keeps a running estimate of its computed mean
217
+ and variance. The running sum is kept with a default momentum of 0.1.
218
+
219
+ During evaluation, this running mean/variance is used for normalization.
220
+
221
+ Because the BatchNorm is done over the `C` dimension, computing statistics
222
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
223
+
224
+ Args:
225
+ num_features: num_features from an expected input of
226
+ size batch_size x num_features x height x width
227
+ eps: a value added to the denominator for numerical stability.
228
+ Default: 1e-5
229
+ momentum: the value used for the running_mean and running_var
230
+ computation. Default: 0.1
231
+ affine: a boolean value that when set to ``True``, gives the layer learnable
232
+ affine parameters. Default: ``True``
233
+
234
+ Shape:
235
+ - Input: :math:`(N, C, H, W)`
236
+ - Output: :math:`(N, C, H, W)` (same shape as input)
237
+
238
+ Examples:
239
+ >>> # With Learnable Parameters
240
+ >>> m = SynchronizedBatchNorm2d(100)
241
+ >>> # Without Learnable Parameters
242
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
243
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
244
+ >>> output = m(input)
245
+ """
246
+
247
+ def _check_input_dim(self, input):
248
+ if input.dim() != 4:
249
+ raise ValueError('expected 4D input (got {}D input)'
250
+ .format(input.dim()))
251
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
252
+
253
+
254
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
255
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
256
+ of 4d inputs
257
+
258
+ .. math::
259
+
260
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
261
+
262
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
263
+ standard-deviation are reduced across all devices during training.
264
+
265
+ For example, when one uses `nn.DataParallel` to wrap the network during
266
+ training, PyTorch's implementation normalizes the tensor on each device using
267
+ the statistics only on that device, which accelerates the computation and
268
+ is also easy to implement, but the statistics might be inaccurate.
269
+ Instead, in this synchronized version, the statistics will be computed
270
+ over all training samples distributed on multiple devices.
271
+
272
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
273
+ as the built-in PyTorch implementation.
274
+
275
+ The mean and standard-deviation are calculated per-dimension over
276
+ the mini-batches and gamma and beta are learnable parameter vectors
277
+ of size C (where C is the input size).
278
+
279
+ During training, this layer keeps a running estimate of its computed mean
280
+ and variance. The running sum is kept with a default momentum of 0.1.
281
+
282
+ During evaluation, this running mean/variance is used for normalization.
283
+
284
+ Because the BatchNorm is done over the `C` dimension, computing statistics
285
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
286
+ or Spatio-temporal BatchNorm
287
+
288
+ Args:
289
+ num_features: num_features from an expected input of
290
+ size batch_size x num_features x depth x height x width
291
+ eps: a value added to the denominator for numerical stability.
292
+ Default: 1e-5
293
+ momentum: the value used for the running_mean and running_var
294
+ computation. Default: 0.1
295
+ affine: a boolean value that when set to ``True``, gives the layer learnable
296
+ affine parameters. Default: ``True``
297
+
298
+ Shape:
299
+ - Input: :math:`(N, C, D, H, W)`
300
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
301
+
302
+ Examples:
303
+ >>> # With Learnable Parameters
304
+ >>> m = SynchronizedBatchNorm3d(100)
305
+ >>> # Without Learnable Parameters
306
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
307
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
308
+ >>> output = m(input)
309
+ """
310
+
311
+ def _check_input_dim(self, input):
312
+ if input.dim() != 5:
313
+ raise ValueError('expected 5D input (got {}D input)'
314
+ .format(input.dim()))
315
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During replication, as data parallel triggers a callback on each module, all slave devices should
60
+ call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device gathers the information and determines the message to pass
64
+ back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register a slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages are first collected from each device (including the master device), and then
105
+ a callback is invoked to compute the message to be sent back to each device
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master want to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belong to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
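
Note (not part of the commit): a minimal, single-machine sketch of the SyncMaster / SlavePipe handshake above, using one worker thread as the "slave" and a toy callback that sums the collected messages; assumes the repository root is on PYTHONPATH.

import threading
from sync_batchnorm.comm import SyncMaster

def reduce_callback(intermediates):
    # intermediates is a list of (copy_id, message); reply with one result per copy
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(reduce_callback)
pipe = master.register_slave(1)

results = {}
worker = threading.Thread(target=lambda: results.update(slave=pipe.run_slave(2)))
worker.start()
results['master'] = master.run_master(1)   # blocks until the slave message arrives
worker.join()
print(results)   # both sides receive the reduced value 3
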
sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphic, we assign each sub-module a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ the original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
train.py ADDED
@@ -0,0 +1,87 @@
1
+ from tqdm import trange
2
+ import torch
3
+
4
+ from torch.utils.data import DataLoader
5
+
6
+ from logger import Logger
7
+ from modules.model import GeneratorFullModel, DiscriminatorFullModel
8
+
9
+ from torch.optim.lr_scheduler import MultiStepLR
10
+
11
+ from sync_batchnorm import DataParallelWithCallback
12
+
13
+ from frames_dataset import DatasetRepeater
14
+
15
+
16
+ def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
17
+ train_params = config['train_params']
18
+
19
+ optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
20
+ optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
21
+ optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
22
+
23
+ if checkpoint is not None:
24
+ start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
25
+ optimizer_generator, optimizer_discriminator,
26
+ None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
27
+ else:
28
+ start_epoch = 0
29
+
30
+ scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
31
+ last_epoch=start_epoch - 1)
32
+ scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
33
+ last_epoch=start_epoch - 1)
34
+ scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
35
+ last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
36
+
37
+ if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
38
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
39
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)
40
+
41
+ generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
42
+ discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
43
+
44
+ if torch.cuda.is_available():
45
+ generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
46
+ discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
47
+
48
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
49
+ for epoch in trange(start_epoch, train_params['num_epochs']):
50
+ for x in dataloader:
51
+ losses_generator, generated = generator_full(x)
52
+
53
+ loss_values = [val.mean() for val in losses_generator.values()]
54
+ loss = sum(loss_values)
55
+
56
+ loss.backward()
57
+ optimizer_generator.step()
58
+ optimizer_generator.zero_grad()
59
+ optimizer_kp_detector.step()
60
+ optimizer_kp_detector.zero_grad()
61
+
62
+ if train_params['loss_weights']['generator_gan'] != 0:
63
+ optimizer_discriminator.zero_grad()
64
+ losses_discriminator = discriminator_full(x, generated)
65
+ loss_values = [val.mean() for val in losses_discriminator.values()]
66
+ loss = sum(loss_values)
67
+
68
+ loss.backward()
69
+ optimizer_discriminator.step()
70
+ optimizer_discriminator.zero_grad()
71
+ else:
72
+ losses_discriminator = {}
73
+
74
+ losses_generator.update(losses_discriminator)
75
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
76
+ logger.log_iter(losses=losses)
77
+
78
+ scheduler_generator.step()
79
+ scheduler_discriminator.step()
80
+ scheduler_kp_detector.step()
81
+
82
+ logger.log_epoch(epoch, {'generator': generator,
83
+ 'discriminator': discriminator,
84
+ 'kp_detector': kp_detector,
85
+ 'optimizer_generator': optimizer_generator,
86
+ 'optimizer_discriminator': optimizer_discriminator,
87
+ 'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)