rinong commited on
Commit
7143243
1 Parent(s): 919dd48

Cloned official repo for the party

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +3 -3
  2. app.py +387 -0
  3. e4e/LICENSE +21 -0
  4. e4e/README.md +142 -0
  5. e4e/configs/__init__.py +0 -0
  6. e4e/configs/data_configs.py +41 -0
  7. e4e/configs/paths_config.py +28 -0
  8. e4e/configs/transforms_config.py +62 -0
  9. e4e/models/__init__.py +0 -0
  10. e4e/models/discriminator.py +20 -0
  11. e4e/models/encoders/__init__.py +0 -0
  12. e4e/models/encoders/helpers.py +140 -0
  13. e4e/models/encoders/model_irse.py +84 -0
  14. e4e/models/encoders/psp_encoders.py +200 -0
  15. e4e/models/latent_codes_pool.py +55 -0
  16. e4e/models/psp.py +99 -0
  17. e4e/models/stylegan2/__init__.py +0 -0
  18. e4e/models/stylegan2/model.py +678 -0
  19. e4e/models/stylegan2/op/__init__.py +0 -0
  20. e4e/models/stylegan2/op/fused_act.py +85 -0
  21. e4e/models/stylegan2/op/fused_bias_act.cpp +21 -0
  22. e4e/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
  23. e4e/models/stylegan2/op/upfirdn2d.cpp +23 -0
  24. e4e/models/stylegan2/op/upfirdn2d.py +184 -0
  25. e4e/models/stylegan2/op/upfirdn2d_kernel.cu +272 -0
  26. e4e/options/__init__.py +0 -0
  27. e4e/options/train_options.py +84 -0
  28. e4e/scripts/calc_losses_on_images.py +87 -0
  29. e4e/scripts/inference.py +133 -0
  30. e4e/scripts/train.py +88 -0
  31. e4e/utils/__init__.py +0 -0
  32. e4e/utils/alignment.py +115 -0
  33. e4e/utils/common.py +55 -0
  34. e4e/utils/data_utils.py +25 -0
  35. e4e/utils/model_utils.py +35 -0
  36. e4e/utils/train_utils.py +13 -0
  37. editing/interfacegan_boundaries/age.pt +0 -0
  38. editing/interfacegan_boundaries/beard.pt +0 -0
  39. editing/interfacegan_boundaries/gender.pt +0 -0
  40. editing/interfacegan_boundaries/hair_length.pt +0 -0
  41. editing/interfacegan_boundaries/pose.pt +0 -0
  42. editing/interfacegan_boundaries/smile.pt +0 -0
  43. generate_videos.py +129 -0
  44. model/sg2_model.py +817 -0
  45. op/__init__.py +0 -0
  46. op/__pycache__/__init__.cpython-37.pyc +0 -0
  47. op/__pycache__/__init__.cpython-38.pyc +0 -0
  48. op/__pycache__/conv2d_gradfix.cpython-37.pyc +0 -0
  49. op/__pycache__/conv2d_gradfix.cpython-38.pyc +0 -0
  50. op/__pycache__/fused_act.cpython-37.pyc +0 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: StyleGAN NADA
3
- emoji: 💩
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.0.2
8
  app_file: app.py
 
1
  ---
2
  title: StyleGAN NADA
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.0.2
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ import gradio as gr
6
+
7
+ from e4e.models.psp import pSp
8
+ from util import *
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ import tempfile
12
+ from argparse import Namespace
13
+ import shutil
14
+
15
+ import dlib
16
+ import numpy as np
17
+ import torchvision.transforms as transforms
18
+ from torchvision import utils
19
+
20
+ from model.sg2_model import Generator
21
+ from generate_videos import generate_frames, video_from_interpolations, project_code_by_edit_name
22
+ from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict
23
+
24
+ import clip
25
+
26
+ model_dir = "models"
27
+ os.makedirs(model_dir, exist_ok=True)
28
+
29
+ model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
30
+ "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
31
+ "sc_fs3": ("rinong/stylegan-nada-models", "fs3.npy"),
32
+ "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
33
+ "sketch": ("rinong/stylegan-nada-models", "sketch.pt"),
34
+ "joker": ("rinong/stylegan-nada-models", "joker.pt"),
35
+ "pixar": ("rinong/stylegan-nada-models", "pixar.pt"),
36
+ "botero": ("rinong/stylegan-nada-models", "botero.pt"),
37
+ "white_walker": ("rinong/stylegan-nada-models", "white_walker.pt"),
38
+ "zuckerberg": ("rinong/stylegan-nada-models", "zuckerberg.pt"),
39
+ "simpson": ("rinong/stylegan-nada-models", "simpson.pt"),
40
+ "ssj": ("rinong/stylegan-nada-models", "ssj.pt"),
41
+ "cubism": ("rinong/stylegan-nada-models", "cubism.pt"),
42
+ "disney_princess": ("rinong/stylegan-nada-models", "disney_princess.pt"),
43
+ "edvard_munch": ("rinong/stylegan-nada-models", "edvard_munch.pt"),
44
+ "van_gogh": ("rinong/stylegan-nada-models", "van_gogh.pt"),
45
+ "oil": ("rinong/stylegan-nada-models", "oil.pt"),
46
+ "rick_morty": ("rinong/stylegan-nada-models", "rick_morty.pt"),
47
+ "anime": ("rinong/stylegan-nada-models", "anime.pt"),
48
+ "shrek": ("rinong/stylegan-nada-models", "shrek.pt"),
49
+ "thanos": ("rinong/stylegan-nada-models", "thanos.pt"),
50
+ "ukiyoe": ("rinong/stylegan-nada-models", "ukiyoe.pt"),
51
+ "groot": ("rinong/stylegan-nada-models", "groot.pt"),
52
+ "witcher": ("rinong/stylegan-nada-models", "witcher.pt"),
53
+ "grafitti_on_wall": ("rinong/stylegan-nada-models", "grafitti_on_wall.pt"),
54
+ "modernism": ("rinong/stylegan-nada-models", "modernism.pt"),
55
+ "marble": ("rinong/stylegan-nada-models", "marble.pt"),
56
+ "vintage_comics": ("rinong/stylegan-nada-models", "vintage_comics.pt"),
57
+ "crochet": ("rinong/stylegan-nada-models", "crochet.pt"),
58
+ "modigliani": ("rinong/stylegan-nada-models", "modigliani.pt"),
59
+ "ghibli": ("rinong/stylegan-nada-models", "ghibli.pt"),
60
+ "elf": ("rinong/stylegan-nada-models", "elf.pt"),
61
+ "zombie": ("rinong/stylegan-nada-models", "zombie.pt"),
62
+ "werewolf": ("rinong/stylegan-nada-models", "werewolf.pt"),
63
+ "plastic_puppet": ("rinong/stylegan-nada-models", "plastic_puppet.pt"),
64
+ "mona_lisa": ("rinong/stylegan-nada-models", "mona_lisa.pt"),
65
+ }
66
+
67
+ def get_models():
68
+ os.makedirs(model_dir, exist_ok=True)
69
+
70
+ model_paths = {}
71
+
72
+ for model_name, repo_details in model_repos.items():
73
+ download_path = hf_hub_download(repo_id=repo_details[0], filename=repo_details[1])
74
+ model_paths[model_name] = download_path
75
+
76
+ return model_paths
77
+
78
+ model_paths = get_models()
79
+
80
+ class ImageEditor(object):
81
+ def __init__(self):
82
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
83
+
84
+ latent_size = 512
85
+ n_mlp = 8
86
+ channel_mult = 2
87
+ model_size = 1024
88
+
89
+ self.generators = {}
90
+
91
+ self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib", "sc_fs3"]]
92
+
93
+ for model in self.model_list:
94
+ g_ema = Generator(
95
+ model_size, latent_size, n_mlp, channel_multiplier=channel_mult
96
+ ).to(self.device)
97
+
98
+ checkpoint = torch.load(model_paths[model], map_location=self.device)
99
+
100
+ g_ema.load_state_dict(checkpoint['g_ema'])
101
+
102
+ self.generators[model] = g_ema
103
+
104
+ self.experiment_args = {"model_path": model_paths["e4e"]}
105
+ self.experiment_args["transform"] = transforms.Compose(
106
+ [
107
+ transforms.Resize((256, 256)),
108
+ transforms.ToTensor(),
109
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
110
+ ]
111
+ )
112
+ self.resize_dims = (256, 256)
113
+
114
+ model_path = self.experiment_args["model_path"]
115
+
116
+ ckpt = torch.load(model_path, map_location="cpu")
117
+ opts = ckpt["opts"]
118
+
119
+ opts["checkpoint_path"] = model_path
120
+ opts = Namespace(**opts)
121
+
122
+ self.e4e_net = pSp(opts, self.device)
123
+ self.e4e_net.eval()
124
+
125
+ self.shape_predictor = dlib.shape_predictor(
126
+ model_paths["dlib"]
127
+ )
128
+
129
+ self.styleclip_fs3 = torch.from_numpy(np.load(model_paths["sc_fs3"])).to(self.device)
130
+
131
+ self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
132
+
133
+ print("setup complete")
134
+
135
+ def get_style_list(self):
136
+ style_list = []
137
+
138
+ for key in self.generators:
139
+ style_list.append(key)
140
+
141
+ return style_list
142
+
143
+ def invert_image(self, input_image):
144
+ input_image = self.run_alignment(str(input_image))
145
+
146
+ input_image = input_image.resize(self.resize_dims)
147
+
148
+ img_transforms = self.experiment_args["transform"]
149
+ transformed_image = img_transforms(input_image)
150
+
151
+ with torch.no_grad():
152
+ images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
153
+ result_image, latent = images[0], latents[0]
154
+
155
+ inverted_latent = latent.unsqueeze(0).unsqueeze(1)
156
+
157
+ return inverted_latent
158
+
159
+ def get_generators_for_styles(self, output_styles, loop_styles=False):
160
+
161
+ if "base" in output_styles: # always start with base if chosen
162
+ output_styles.insert(0, output_styles.pop(output_styles.index("base")))
163
+ if loop_styles:
164
+ output_styles.append(output_styles[0])
165
+
166
+ return [self.generators[style] for style in output_styles]
167
+
168
+ def _pack_edits(func):
169
+ def inner(self,
170
+ edit_type_choice,
171
+ pose_slider,
172
+ smile_slider,
173
+ gender_slider,
174
+ age_slider,
175
+ hair_slider,
176
+ src_text_styleclip,
177
+ tar_text_styleclip,
178
+ alpha_styleclip,
179
+ beta_styleclip,
180
+ *args):
181
+
182
+ edit_choices = {"edit_type": edit_type_choice,
183
+ "pose": pose_slider,
184
+ "smile": smile_slider,
185
+ "gender": gender_slider,
186
+ "age": age_slider,
187
+ "hair_length": hair_slider,
188
+ "src_text": src_text_styleclip,
189
+ "tar_text": tar_text_styleclip,
190
+ "alpha": alpha_styleclip,
191
+ "beta": beta_styleclip}
192
+
193
+
194
+ return func(self, *args, edit_choices)
195
+
196
+ return inner
197
+
198
+ def get_target_latents(self, source_latent, edit_choices, generators):
199
+
200
+ target_latents = []
201
+
202
+ if edit_choices["edit_type"] == "InterFaceGAN":
203
+ np_source_latent = source_latent.squeeze(0).cpu().detach().numpy()
204
+
205
+ for attribute_name in ["pose", "smile", "gender", "age", "hair_length"]:
206
+ strength = edit_choices[attribute_name]
207
+ if strength != 0.0:
208
+ projected_code_np = project_code_by_edit_name(np_source_latent, attribute_name, strength)
209
+ target_latents.append(torch.from_numpy(projected_code_np).float().to(self.device))
210
+
211
+ elif edit_choices["edit_type"] == "StyleCLIP":
212
+ if edit_choices["alpha"] != 0.0:
213
+ source_s_dict = generators[0].get_s_code(source_latent, input_is_latent=True)[0]
214
+ target_latents.append(project_code_with_styleclip(source_s_dict,
215
+ edit_choices["src_text"],
216
+ edit_choices["tar_text"],
217
+ edit_choices["alpha"],
218
+ edit_choices["beta"],
219
+ generators[0],
220
+ self.styleclip_fs3,
221
+ self.clip_model))
222
+
223
+ # if edit type is none or if all sliders were set to 0
224
+ if not target_latents:
225
+ target_latents = [source_latent.squeeze(0), ] * max((len(generators) - 1), 1)
226
+
227
+ return target_latents
228
+
229
+ @_pack_edits
230
+ def edit_image(self, input, output_styles, edit_choices):
231
+ return self.predict(input, output_styles, edit_choices=edit_choices)
232
+
233
+ @_pack_edits
234
+ def edit_video(self, input, output_styles, loop_styles, edit_choices):
235
+ return self.predict(input, output_styles, generate_video=True, loop_styles=loop_styles, edit_choices=edit_choices)
236
+
237
+ def predict(
238
+ self,
239
+ input, # Input image path
240
+ output_styles, # Style checkbox options.
241
+ generate_video = False, # Generate a video instead of an output image
242
+ loop_styles = False, # Loop back to the initial style
243
+ edit_choices = None, # Optional dictionary with edit choice arguments
244
+ ):
245
+
246
+ if edit_choices is None:
247
+ edit_choices = {"edit_type": "None"}
248
+
249
+ # @title Align image
250
+ out_dir = tempfile.mkdtemp()
251
+
252
+ inverted_latent = self.invert_image(input)
253
+ generators = self.get_generators_for_styles(output_styles, loop_styles)
254
+
255
+ target_latents = self.get_target_latents(inverted_latent, edit_choices, generators)
256
+
257
+ if not generate_video:
258
+ output_paths = []
259
+
260
+ with torch.no_grad():
261
+ for g_ema in generators:
262
+ latent_for_gen = random.choice(target_latents)
263
+
264
+ if edit_choices["edit_type"] == "StyleCLIP":
265
+ latent_for_gen = style_tensor_to_style_dict(latent_for_gen, g_ema)
266
+ img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
267
+ else:
268
+ img, _ = g_ema([latent_for_gen], input_is_latent=True, truncation=1, randomize_noise=False)
269
+
270
+ output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
271
+ utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))
272
+
273
+ output_paths.append(output_path)
274
+
275
+ return output_paths
276
+
277
+ return self.generate_vid(generators, inverted_latent, target_latents, out_dir)
278
+
279
+ def generate_vid(self, generators, source_latent, target_latents, out_dir):
280
+
281
+ fps = 24
282
+
283
+ with tempfile.TemporaryDirectory() as dirpath:
284
+ generate_frames(source_latent, target_latents, generators, dirpath)
285
+ video_from_interpolations(fps, dirpath)
286
+
287
+ gen_path = os.path.join(dirpath, "out.mp4")
288
+ out_path = os.path.join(out_dir, "out.mp4")
289
+
290
+ shutil.copy2(gen_path, out_path)
291
+
292
+ return out_path
293
+
294
+ def run_alignment(self, image_path):
295
+ aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
296
+ print("Aligned image has shape: {}".format(aligned_image.size))
297
+ return aligned_image
298
+
299
+ def run_on_batch(self, inputs):
300
+ images, latents = self.e4e_net(
301
+ inputs.to(self.device).float(), randomize_noise=False, return_latents=True
302
+ )
303
+ return images, latents
304
+
305
+ editor = ImageEditor()
306
+
307
+ blocks = gr.Blocks()
308
+
309
+ with blocks:
310
+ gr.Markdown("<h1><center>StyleGAN-NADA</center></h1>")
311
+ gr.Markdown(
312
+ "<h4 style='font-size: 110%;margin-top:.5em'>Inference demo for StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators (SIGGRAPH 2022).</h4>"
313
+ )
314
+ gr.Markdown(
315
+ "<h4 style='font-size: 110%;margin-top:.5em'>Usage</h4><div>Upload an image of your face, pick your desired output styles, and apply StyleGAN-based editing.</div>"
316
+ "<div>Choose the edit image tab to create static images in all chosen styles. Choose the video tab in order to interpolate between all chosen styles</div><div>(To make it easier on the servers, we've limited video length. If you add too many styles (we recommend no more than 3!), they'll pass in the blink of an eye! 🤗)</div>"
317
+ )
318
+ gr.Markdown(
319
+ "For more information about the paper and code for training your own models (with text or images), please visit our <a href='https://stylegan-nada.github.io/' target='_blank'>project page</a> or the <a href='https://github.com/rinongal/StyleGAN-nada' target='_blank'>official repository</a>."
320
+ )
321
+
322
+ gr.Markdown("<h4 style='font-size: 110%;margin-top:.5em'>A note on social impact</h4><div>This model relies on StyleGAN and CLIP, both of which are prone to biases inherited from their training data and their architecture. These may include (but are not limited to) poor representation of minorities or the perpetution of societal biases, such as gender norms. In particular, StyleGAN editing may induce undesired changes in skin tones. Moreover, generative models can, and have been used to create deep fake imagery which may assist in the spread of propaganda. However, <a href='https://github.com/NVlabs/stylegan3-detector' target='_blank'>tools are available</a> for identifying StyleGAN generated imagery, and any 'realistic' results produced by this model should be easily identifiable through such tools.</div>")
323
+
324
+ with gr.Row():
325
+ with gr.Column():
326
+ input_img = gr.inputs.Image(type="filepath", label="Input image")
327
+
328
+ with gr.Column():
329
+ style_choice = gr.inputs.CheckboxGroup(choices=editor.get_style_list(), type="value", label="Choose your styles!")
330
+
331
+ editing_type_choice = gr.Radio(choices=["None", "InterFaceGAN", "StyleCLIP"], label="Choose latent space editing option. For InterFaceGAN and StyleCLIP, set the options below:")
332
+
333
+ with gr.Row():
334
+ with gr.Column():
335
+ with gr.Tabs():
336
+ with gr.TabItem("Edit Images"):
337
+ img_button = gr.Button("Edit Image")
338
+ img_output = gr.Gallery(label="Output Images")
339
+
340
+ with gr.TabItem("Create Video"):
341
+ with gr.Row():
342
+ vid_button = gr.Button("Generate Video")
343
+ loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
344
+ with gr.Row():
345
+ with gr.Column():
346
+ gr.Markdown("Warning: Videos generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
347
+ gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we recommend cloning the gradio app, adjusting <b>num_alphas</b> in <b>generate_videos.py</b>, and running the code locally.")
348
+ vid_output = gr.outputs.Video(label="Output Video")
349
+
350
+ with gr.Column():
351
+ with gr.Tabs():
352
+ with gr.TabItem("InterFaceGAN Editing Options"):
353
+ gr.Markdown("Move the sliders to make the chosen attribute stronger (e.g. the person older) or leave at 0 to disable editing.")
354
+ gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together.")
355
+ gr.Markdown("Please note that some directions may be entangled. For example, hair length adjustments are likely to also modify the perceived gender.")
356
+
357
+ gr.Markdown("For more information about InterFaceGAN, please visit <a href='https://github.com/genforce/interfacegan' target='_blank'>the official repository</a>")
358
+
359
+ pose_slider = gr.Slider(label="Pose", minimum=-1, maximum=1, value=0, step=0.05)
360
+ smile_slider = gr.Slider(label="Smile", minimum=-1, maximum=1, value=0, step=0.05)
361
+ gender_slider = gr.Slider(label="Perceived Gender", minimum=-1, maximum=1, value=0, step=0.05)
362
+ age_slider = gr.Slider(label="Age", minimum=-1, maximum=1, value=0, step=0.05)
363
+ hair_slider = gr.Slider(label="Hair Length", minimum=-1, maximum=1, value=0, step=0.05)
364
+
365
+ ig_edit_choices = [pose_slider, smile_slider, gender_slider, age_slider, hair_slider]
366
+
367
+ with gr.TabItem("StyleCLIP Editing Options"):
368
+ gr.Markdown("Choose source and target descriptors, such as 'face with hair' to 'face with curly hair'")
369
+ gr.Markdown("Editing strength controls the magnitude of change. Disentanglement thresholds limits the number of channels the network can modify, reducing possible leak into other attributes. Setting the threshold too high may lead to no available channels. If you see an error, lower the threshold and try again.")
370
+ gr.Markdown("For more information about StyleCLIP, please visit <a href='https://github.com/orpatashnik/StyleCLIP' target='_blank'>the official repository</a>")
371
+
372
+ src_text_styleclip = gr.Textbox(label="Source text")
373
+ tar_text_styleclip = gr.Textbox(label="Target text")
374
+
375
+ alpha_styleclip = gr.Slider(label="Edit strength", minimum=-10, maximum=10, value=0, step=0.1)
376
+ beta_styleclip = gr.Slider(label="Disentanglement Threshold", minimum=0.08, maximum=0.3, value=0.14, step=0.01)
377
+
378
+ sc_edit_choices = [src_text_styleclip, tar_text_styleclip, alpha_styleclip, beta_styleclip]
379
+
380
+ edit_inputs = [editing_type_choice] + ig_edit_choices + sc_edit_choices
381
+ img_button.click(fn=editor.edit_image, inputs=edit_inputs + [input_img, style_choice], outputs=img_output)
382
+ vid_button.click(fn=editor.edit_video, inputs=edit_inputs + [input_img, style_choice, loop_styles], outputs=vid_output)
383
+
384
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.00946' target='_blank'>StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators</a> | <a href='https://stylegan-nada.github.io/' target='_blank'>Project Page</a> | <a href='https://github.com/rinongal/StyleGAN-nada' target='_blank'>Code</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=rinong_sgnada' alt='visitor badge'></center>"
385
+ gr.Markdown(article)
386
+
387
+ blocks.launch(enable_queue=True)
e4e/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 omertov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
e4e/README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Designing an Encoder for StyleGAN Image Manipulation
2
+ <a href="https://arxiv.org/abs/2102.02766"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
3
+ <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
4
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/omertov/encoder4editing/blob/main/notebooks/inference_playground.ipynb)
5
+
6
+ > Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop.
7
+
8
+ <p align="center">
9
+ <img src="docs/teaser.jpg" width="800px"/>
10
+ </p>
11
+
12
+ ## Description
13
+ Official Implementation of "<a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>" paper for both training and evaluation.
14
+ The e4e encoder is specifically designed to complement existing image manipulation techniques performed over StyleGAN's latent space.
15
+
16
+ ## Recent Updates
17
+ `2021.03.25`: Add pose editing direction.
18
+
19
+ ## Getting Started
20
+ ### Prerequisites
21
+ - Linux or macOS
22
+ - NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
23
+ - Python 3
24
+
25
+ ### Installation
26
+ - Clone the repository:
27
+ ```
28
+ git clone https://github.com/omertov/encoder4editing.git
29
+ cd encoder4editing
30
+ ```
31
+ - Dependencies:
32
+ We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
33
+ All dependencies for defining the environment are provided in `environment/e4e_env.yaml`.
34
+
35
+ ### Inference Notebook
36
+ We provide a Jupyter notebook found in `notebooks/inference_playground.ipynb` that allows one to encode and perform several editings on real images using StyleGAN.
37
+
38
+ ### Pretrained Models
39
+ Please download the pre-trained models from the following links. Each e4e model contains the entire pSp framework architecture, including the encoder and decoder weights.
40
+ | Path | Description
41
+ | :--- | :----------
42
+ |[FFHQ Inversion](https://drive.google.com/file/d/1cUv_reLE6k3604or78EranS7XzuVMWeO/view?usp=sharing) | FFHQ e4e encoder.
43
+ |[Cars Inversion](https://drive.google.com/file/d/17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV/view?usp=sharing) | Cars e4e encoder.
44
+ |[Horse Inversion](https://drive.google.com/file/d/1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX/view?usp=sharing) | Horse e4e encoder.
45
+ |[Church Inversion](https://drive.google.com/file/d/1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa/view?usp=sharing) | Church e4e encoder.
46
+
47
+ If you wish to use one of the pretrained models for training or inference, you may do so using the flag `--checkpoint_path`.
48
+
49
+ In addition, we provide various auxiliary models needed for training your own e4e model from scratch.
50
+ | Path | Description
51
+ | :--- | :----------
52
+ |[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
53
+ |[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
54
+ |[MOCOv2 Model](https://drive.google.com/file/d/18rLcNGdteX5LwT7sv_F7HWr12HpVEzVe/view?usp=sharing) | Pretrained ResNet-50 model trained using MOCOv2 for use in our simmilarity loss for domains other then human faces during training.
55
+
56
+ By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
57
+
58
+ ## Training
59
+ To train the e4e encoder, make sure the paths to the required models, as well as training and testing data is configured in `configs/path_configs.py` and `configs/data_configs.py`.
60
+ #### **Training the e4e Encoder**
61
+ ```
62
+ python scripts/train.py \
63
+ --dataset_type cars_encode \
64
+ --exp_dir new/experiment/directory \
65
+ --start_from_latent_avg \
66
+ --use_w_pool \
67
+ --w_discriminator_lambda 0.1 \
68
+ --progressive_start 20000 \
69
+ --id_lambda 0.5 \
70
+ --val_interval 10000 \
71
+ --max_steps 200000 \
72
+ --stylegan_size 512 \
73
+ --stylegan_weights path/to/pretrained/stylegan.pt \
74
+ --workers 8 \
75
+ --batch_size 8 \
76
+ --test_batch_size 4 \
77
+ --test_workers 4
78
+ ```
79
+
80
+ #### Training on your own dataset
81
+ In order to train the e4e encoder on a custom dataset, perform the following adjustments:
82
+ 1. Insert the paths to your train and test data into the `dataset_paths` variable defined in `configs/paths_config.py`:
83
+ ```
84
+ dataset_paths = {
85
+ 'my_train_data': '/path/to/train/images/directory',
86
+ 'my_test_data': '/path/to/test/images/directory'
87
+ }
88
+ ```
89
+ 2. Configure a new dataset under the DATASETS variable defined in `configs/data_configs.py`:
90
+ ```
91
+ DATASETS = {
92
+ 'my_data_encode': {
93
+ 'transforms': transforms_config.EncodeTransforms,
94
+ 'train_source_root': dataset_paths['my_train_data'],
95
+ 'train_target_root': dataset_paths['my_train_data'],
96
+ 'test_source_root': dataset_paths['my_test_data'],
97
+ 'test_target_root': dataset_paths['my_test_data']
98
+ }
99
+ }
100
+ ```
101
+ Refer to `configs/transforms_config.py` for the transformations applied to the train and test images during training.
102
+
103
+ 3. Finally, run a training session with `--dataset_type my_data_encode`.
104
+
105
+ ## Inference
106
+ Having trained your model, you can use `scripts/inference.py` to apply the model on a set of images.
107
+ For example,
108
+ ```
109
+ python scripts/inference.py \
110
+ --images_dir=/path/to/images/directory \
111
+ --save_dir=/path/to/saving/directory \
112
+ path/to/checkpoint.pt
113
+ ```
114
+
115
+ ## Latent Editing Consistency (LEC)
116
+ As described in the paper, we suggest a new metric, Latent Editing Consistency (LEC), for evaluating the encoder's
117
+ performance.
118
+ We provide an example for calculating the metric over the FFHQ StyleGAN using the aging editing direction in
119
+ `metrics/LEC.py`.
120
+
121
+ To run the example:
122
+ ```
123
+ cd metrics
124
+ python LEC.py \
125
+ --images_dir=/path/to/images/directory \
126
+ path/to/checkpoint.pt
127
+ ```
128
+
129
+ ## Acknowledgments
130
+ This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel)
131
+
132
+ ## Citation
133
+ If you use this code for your research, please cite our paper <a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>:
134
+
135
+ ```
136
+ @article{tov2021designing,
137
+ title={Designing an Encoder for StyleGAN Image Manipulation},
138
+ author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel},
139
+ journal={arXiv preprint arXiv:2102.02766},
140
+ year={2021}
141
+ }
142
+ ```
e4e/configs/__init__.py ADDED
File without changes
e4e/configs/data_configs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['celeba_test'],
11
+ 'test_target_root': dataset_paths['celeba_test'],
12
+ },
13
+ 'cars_encode': {
14
+ 'transforms': transforms_config.CarsEncodeTransforms,
15
+ 'train_source_root': dataset_paths['cars_train'],
16
+ 'train_target_root': dataset_paths['cars_train'],
17
+ 'test_source_root': dataset_paths['cars_test'],
18
+ 'test_target_root': dataset_paths['cars_test'],
19
+ },
20
+ 'horse_encode': {
21
+ 'transforms': transforms_config.EncodeTransforms,
22
+ 'train_source_root': dataset_paths['horse_train'],
23
+ 'train_target_root': dataset_paths['horse_train'],
24
+ 'test_source_root': dataset_paths['horse_test'],
25
+ 'test_target_root': dataset_paths['horse_test'],
26
+ },
27
+ 'church_encode': {
28
+ 'transforms': transforms_config.EncodeTransforms,
29
+ 'train_source_root': dataset_paths['church_train'],
30
+ 'train_target_root': dataset_paths['church_train'],
31
+ 'test_source_root': dataset_paths['church_test'],
32
+ 'test_target_root': dataset_paths['church_test'],
33
+ },
34
+ 'cats_encode': {
35
+ 'transforms': transforms_config.EncodeTransforms,
36
+ 'train_source_root': dataset_paths['cats_train'],
37
+ 'train_target_root': dataset_paths['cats_train'],
38
+ 'test_source_root': dataset_paths['cats_test'],
39
+ 'test_target_root': dataset_paths['cats_test'],
40
+ }
41
+ }
e4e/configs/paths_config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_paths = {
2
+ # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
3
+ 'ffhq': '',
4
+ 'celeba_test': '',
5
+
6
+ # Cars Dataset (In the paper: Stanford cars)
7
+ 'cars_train': '',
8
+ 'cars_test': '',
9
+
10
+ # Horse Dataset (In the paper: LSUN Horse)
11
+ 'horse_train': '',
12
+ 'horse_test': '',
13
+
14
+ # Church Dataset (In the paper: LSUN Church)
15
+ 'church_train': '',
16
+ 'church_test': '',
17
+
18
+ # Cats Dataset (In the paper: LSUN Cat)
19
+ 'cats_train': '',
20
+ 'cats_test': ''
21
+ }
22
+
23
+ model_paths = {
24
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
25
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
26
+ 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
27
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
28
+ }
e4e/configs/transforms_config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+
4
+
5
+ class TransformsConfig(object):
6
+
7
+ def __init__(self, opts):
8
+ self.opts = opts
9
+
10
+ @abstractmethod
11
+ def get_transforms(self):
12
+ pass
13
+
14
+
15
+ class EncodeTransforms(TransformsConfig):
16
+
17
+ def __init__(self, opts):
18
+ super(EncodeTransforms, self).__init__(opts)
19
+
20
+ def get_transforms(self):
21
+ transforms_dict = {
22
+ 'transform_gt_train': transforms.Compose([
23
+ transforms.Resize((256, 256)),
24
+ transforms.RandomHorizontalFlip(0.5),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
27
+ 'transform_source': None,
28
+ 'transform_test': transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
32
+ 'transform_inference': transforms.Compose([
33
+ transforms.Resize((256, 256)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
36
+ }
37
+ return transforms_dict
38
+
39
+
40
+ class CarsEncodeTransforms(TransformsConfig):
41
+
42
+ def __init__(self, opts):
43
+ super(CarsEncodeTransforms, self).__init__(opts)
44
+
45
+ def get_transforms(self):
46
+ transforms_dict = {
47
+ 'transform_gt_train': transforms.Compose([
48
+ transforms.Resize((192, 256)),
49
+ transforms.RandomHorizontalFlip(0.5),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
52
+ 'transform_source': None,
53
+ 'transform_test': transforms.Compose([
54
+ transforms.Resize((192, 256)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
57
+ 'transform_inference': transforms.Compose([
58
+ transforms.Resize((192, 256)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
61
+ }
62
+ return transforms_dict
e4e/models/__init__.py ADDED
File without changes
e4e/models/discriminator.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ class LatentCodesDiscriminator(nn.Module):
5
+ def __init__(self, style_dim, n_mlp):
6
+ super().__init__()
7
+
8
+ self.style_dim = style_dim
9
+
10
+ layers = []
11
+ for i in range(n_mlp-1):
12
+ layers.append(
13
+ nn.Linear(style_dim, style_dim)
14
+ )
15
+ layers.append(nn.LeakyReLU(0.2))
16
+ layers.append(nn.Linear(512, 1))
17
+ self.mlp = nn.Sequential(*layers)
18
+
19
+ def forward(self, w):
20
+ return self.mlp(w)
e4e/models/encoders/__init__.py ADDED
File without changes
e4e/models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
5
+
6
+ """
7
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8
+ """
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, input):
13
+ return input.view(input.size(0), -1)
14
+
15
+
16
+ def l2_norm(input, axis=1):
17
+ norm = torch.norm(input, 2, axis, True)
18
+ output = torch.div(input, norm)
19
+ return output
20
+
21
+
22
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
23
+ """ A named tuple describing a ResNet block. """
24
+
25
+
26
+ def get_block(in_channel, depth, num_units, stride=2):
27
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
28
+
29
+
30
+ def get_blocks(num_layers):
31
+ if num_layers == 50:
32
+ blocks = [
33
+ get_block(in_channel=64, depth=64, num_units=3),
34
+ get_block(in_channel=64, depth=128, num_units=4),
35
+ get_block(in_channel=128, depth=256, num_units=14),
36
+ get_block(in_channel=256, depth=512, num_units=3)
37
+ ]
38
+ elif num_layers == 100:
39
+ blocks = [
40
+ get_block(in_channel=64, depth=64, num_units=3),
41
+ get_block(in_channel=64, depth=128, num_units=13),
42
+ get_block(in_channel=128, depth=256, num_units=30),
43
+ get_block(in_channel=256, depth=512, num_units=3)
44
+ ]
45
+ elif num_layers == 152:
46
+ blocks = [
47
+ get_block(in_channel=64, depth=64, num_units=3),
48
+ get_block(in_channel=64, depth=128, num_units=8),
49
+ get_block(in_channel=128, depth=256, num_units=36),
50
+ get_block(in_channel=256, depth=512, num_units=3)
51
+ ]
52
+ else:
53
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
54
+ return blocks
55
+
56
+
57
+ class SEModule(Module):
58
+ def __init__(self, channels, reduction):
59
+ super(SEModule, self).__init__()
60
+ self.avg_pool = AdaptiveAvgPool2d(1)
61
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
62
+ self.relu = ReLU(inplace=True)
63
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
64
+ self.sigmoid = Sigmoid()
65
+
66
+ def forward(self, x):
67
+ module_input = x
68
+ x = self.avg_pool(x)
69
+ x = self.fc1(x)
70
+ x = self.relu(x)
71
+ x = self.fc2(x)
72
+ x = self.sigmoid(x)
73
+ return module_input * x
74
+
75
+
76
+ class bottleneck_IR(Module):
77
+ def __init__(self, in_channel, depth, stride):
78
+ super(bottleneck_IR, self).__init__()
79
+ if in_channel == depth:
80
+ self.shortcut_layer = MaxPool2d(1, stride)
81
+ else:
82
+ self.shortcut_layer = Sequential(
83
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
84
+ BatchNorm2d(depth)
85
+ )
86
+ self.res_layer = Sequential(
87
+ BatchNorm2d(in_channel),
88
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
89
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
90
+ )
91
+
92
+ def forward(self, x):
93
+ shortcut = self.shortcut_layer(x)
94
+ res = self.res_layer(x)
95
+ return res + shortcut
96
+
97
+
98
+ class bottleneck_IR_SE(Module):
99
+ def __init__(self, in_channel, depth, stride):
100
+ super(bottleneck_IR_SE, self).__init__()
101
+ if in_channel == depth:
102
+ self.shortcut_layer = MaxPool2d(1, stride)
103
+ else:
104
+ self.shortcut_layer = Sequential(
105
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
106
+ BatchNorm2d(depth)
107
+ )
108
+ self.res_layer = Sequential(
109
+ BatchNorm2d(in_channel),
110
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
111
+ PReLU(depth),
112
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
113
+ BatchNorm2d(depth),
114
+ SEModule(depth, 16)
115
+ )
116
+
117
+ def forward(self, x):
118
+ shortcut = self.shortcut_layer(x)
119
+ res = self.res_layer(x)
120
+ return res + shortcut
121
+
122
+
123
+ def _upsample_add(x, y):
124
+ """Upsample and add two feature maps.
125
+ Args:
126
+ x: (Variable) top feature map to be upsampled.
127
+ y: (Variable) lateral feature map.
128
+ Returns:
129
+ (Variable) added feature map.
130
+ Note in PyTorch, when input size is odd, the upsampled feature map
131
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
132
+ maybe not equal to the lateral feature map size.
133
+ e.g.
134
+ original input size: [N,_,15,15] ->
135
+ conv2d feature map size: [N,_,8,8] ->
136
+ upsampled feature map size: [N,_,16,16]
137
+ So we choose bilinear upsample which supports arbitrary output sizes.
138
+ """
139
+ _, _, H, W = y.size()
140
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
e4e/models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
e4e/models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from e4e.models.stylegan2.model import EqualLinear
10
+
11
+
12
+ class ProgressiveStage(Enum):
13
+ WTraining = 0
14
+ Delta1Training = 1
15
+ Delta2Training = 2
16
+ Delta3Training = 3
17
+ Delta4Training = 4
18
+ Delta5Training = 5
19
+ Delta6Training = 6
20
+ Delta7Training = 7
21
+ Delta8Training = 8
22
+ Delta9Training = 9
23
+ Delta10Training = 10
24
+ Delta11Training = 11
25
+ Delta12Training = 12
26
+ Delta13Training = 13
27
+ Delta14Training = 14
28
+ Delta15Training = 15
29
+ Delta16Training = 16
30
+ Delta17Training = 17
31
+ Inference = 18
32
+
33
+
34
+ class GradualStyleBlock(Module):
35
+ def __init__(self, in_c, out_c, spatial):
36
+ super(GradualStyleBlock, self).__init__()
37
+ self.out_c = out_c
38
+ self.spatial = spatial
39
+ num_pools = int(np.log2(spatial))
40
+ modules = []
41
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
42
+ nn.LeakyReLU()]
43
+ for i in range(num_pools - 1):
44
+ modules += [
45
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
46
+ nn.LeakyReLU()
47
+ ]
48
+ self.convs = nn.Sequential(*modules)
49
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
50
+
51
+ def forward(self, x):
52
+ x = self.convs(x)
53
+ x = x.view(-1, self.out_c)
54
+ x = self.linear(x)
55
+ return x
56
+
57
+
58
+ class GradualStyleEncoder(Module):
59
+ def __init__(self, num_layers, mode='ir', opts=None):
60
+ super(GradualStyleEncoder, self).__init__()
61
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
62
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
63
+ blocks = get_blocks(num_layers)
64
+ if mode == 'ir':
65
+ unit_module = bottleneck_IR
66
+ elif mode == 'ir_se':
67
+ unit_module = bottleneck_IR_SE
68
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
69
+ BatchNorm2d(64),
70
+ PReLU(64))
71
+ modules = []
72
+ for block in blocks:
73
+ for bottleneck in block:
74
+ modules.append(unit_module(bottleneck.in_channel,
75
+ bottleneck.depth,
76
+ bottleneck.stride))
77
+ self.body = Sequential(*modules)
78
+
79
+ self.styles = nn.ModuleList()
80
+ log_size = int(math.log(opts.stylegan_size, 2))
81
+ self.style_count = 2 * log_size - 2
82
+ self.coarse_ind = 3
83
+ self.middle_ind = 7
84
+ for i in range(self.style_count):
85
+ if i < self.coarse_ind:
86
+ style = GradualStyleBlock(512, 512, 16)
87
+ elif i < self.middle_ind:
88
+ style = GradualStyleBlock(512, 512, 32)
89
+ else:
90
+ style = GradualStyleBlock(512, 512, 64)
91
+ self.styles.append(style)
92
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
93
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, x):
96
+ x = self.input_layer(x)
97
+
98
+ latents = []
99
+ modulelist = list(self.body._modules.values())
100
+ for i, l in enumerate(modulelist):
101
+ x = l(x)
102
+ if i == 6:
103
+ c1 = x
104
+ elif i == 20:
105
+ c2 = x
106
+ elif i == 23:
107
+ c3 = x
108
+
109
+ for j in range(self.coarse_ind):
110
+ latents.append(self.styles[j](c3))
111
+
112
+ p2 = _upsample_add(c3, self.latlayer1(c2))
113
+ for j in range(self.coarse_ind, self.middle_ind):
114
+ latents.append(self.styles[j](p2))
115
+
116
+ p1 = _upsample_add(p2, self.latlayer2(c1))
117
+ for j in range(self.middle_ind, self.style_count):
118
+ latents.append(self.styles[j](p1))
119
+
120
+ out = torch.stack(latents, dim=1)
121
+ return out
122
+
123
+
124
+ class Encoder4Editing(Module):
125
+ def __init__(self, num_layers, mode='ir', opts=None):
126
+ super(Encoder4Editing, self).__init__()
127
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
128
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
129
+ blocks = get_blocks(num_layers)
130
+ if mode == 'ir':
131
+ unit_module = bottleneck_IR
132
+ elif mode == 'ir_se':
133
+ unit_module = bottleneck_IR_SE
134
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
135
+ BatchNorm2d(64),
136
+ PReLU(64))
137
+ modules = []
138
+ for block in blocks:
139
+ for bottleneck in block:
140
+ modules.append(unit_module(bottleneck.in_channel,
141
+ bottleneck.depth,
142
+ bottleneck.stride))
143
+ self.body = Sequential(*modules)
144
+
145
+ self.styles = nn.ModuleList()
146
+ log_size = int(math.log(opts.stylegan_size, 2))
147
+ self.style_count = 2 * log_size - 2
148
+ self.coarse_ind = 3
149
+ self.middle_ind = 7
150
+
151
+ for i in range(self.style_count):
152
+ if i < self.coarse_ind:
153
+ style = GradualStyleBlock(512, 512, 16)
154
+ elif i < self.middle_ind:
155
+ style = GradualStyleBlock(512, 512, 32)
156
+ else:
157
+ style = GradualStyleBlock(512, 512, 64)
158
+ self.styles.append(style)
159
+
160
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
161
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
162
+
163
+ self.progressive_stage = ProgressiveStage.Inference
164
+
165
+ def get_deltas_starting_dimensions(self):
166
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
167
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
168
+
169
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
170
+ self.progressive_stage = new_stage
171
+ print('Changed progressive stage to: ', new_stage)
172
+
173
+ def forward(self, x):
174
+ x = self.input_layer(x)
175
+
176
+ modulelist = list(self.body._modules.values())
177
+ for i, l in enumerate(modulelist):
178
+ x = l(x)
179
+ if i == 6:
180
+ c1 = x
181
+ elif i == 20:
182
+ c2 = x
183
+ elif i == 23:
184
+ c3 = x
185
+
186
+ # Infer main W and duplicate it
187
+ w0 = self.styles[0](c3)
188
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
189
+ stage = self.progressive_stage.value
190
+ features = c3
191
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
192
+ if i == self.coarse_ind:
193
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
194
+ features = p2
195
+ elif i == self.middle_ind:
196
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
197
+ features = p1
198
+ delta_i = self.styles[i](features)
199
+ w[:, i] += delta_i
200
+ return w
e4e/models/latent_codes_pool.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
+ class LatentCodesPool:
6
+ """This class implements latent codes buffer that stores previously generated w latent codes.
7
+ This buffer enables us to update discriminators using a history of generated w's
8
+ rather than the ones produced by the latest encoder.
9
+ """
10
+
11
+ def __init__(self, pool_size):
12
+ """Initialize the ImagePool class
13
+ Parameters:
14
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
15
+ """
16
+ self.pool_size = pool_size
17
+ if self.pool_size > 0: # create an empty pool
18
+ self.num_ws = 0
19
+ self.ws = []
20
+
21
+ def query(self, ws):
22
+ """Return w's from the pool.
23
+ Parameters:
24
+ ws: the latest generated w's from the generator
25
+ Returns w's from the buffer.
26
+ By 50/100, the buffer will return input w's.
27
+ By 50/100, the buffer will return w's previously stored in the buffer,
28
+ and insert the current w's to the buffer.
29
+ """
30
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
31
+ return ws
32
+ return_ws = []
33
+ for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
34
+ # w = torch.unsqueeze(image.data, 0)
35
+ if w.ndim == 2:
36
+ i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
37
+ w = w[i]
38
+ self.handle_w(w, return_ws)
39
+ return_ws = torch.stack(return_ws, 0) # collect all the images and return
40
+ return return_ws
41
+
42
+ def handle_w(self, w, return_ws):
43
+ if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
44
+ self.num_ws = self.num_ws + 1
45
+ self.ws.append(w)
46
+ return_ws.append(w)
47
+ else:
48
+ p = random.uniform(0, 1)
49
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
50
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
51
+ tmp = self.ws[random_id].clone()
52
+ self.ws[random_id] = w
53
+ return_ws.append(tmp)
54
+ else: # by another 50% chance, the buffer will return the current image
55
+ return_ws.append(w)
e4e/models/psp.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+
3
+ matplotlib.use('Agg')
4
+ import torch
5
+ from torch import nn
6
+ from e4e.models.encoders import psp_encoders
7
+ from e4e.models.stylegan2.model import Generator
8
+ from e4e.configs.paths_config import model_paths
9
+
10
+
11
+ def get_keys(d, name):
12
+ if 'state_dict' in d:
13
+ d = d['state_dict']
14
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
15
+ return d_filt
16
+
17
+
18
+ class pSp(nn.Module):
19
+
20
+ def __init__(self, opts, device):
21
+ super(pSp, self).__init__()
22
+ self.opts = opts
23
+ self.device = device
24
+ # Define architecture
25
+ self.encoder = self.set_encoder()
26
+ self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
27
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
28
+ # Load weights if needed
29
+ self.load_weights()
30
+
31
+ def set_encoder(self):
32
+ if self.opts.encoder_type == 'GradualStyleEncoder':
33
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
34
+ elif self.opts.encoder_type == 'Encoder4Editing':
35
+ encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
36
+ else:
37
+ raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
38
+ return encoder
39
+
40
+ def load_weights(self):
41
+ if self.opts.checkpoint_path is not None:
42
+ print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
43
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
44
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
45
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
46
+ self.__load_latent_avg(ckpt)
47
+ else:
48
+ print('Loading encoders weights from irse50!')
49
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
50
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
51
+ print('Loading decoder weights from pretrained!')
52
+ ckpt = torch.load(self.opts.stylegan_weights)
53
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
54
+ self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
55
+
56
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
57
+ inject_latent=None, return_latents=False, alpha=None):
58
+ if input_code:
59
+ codes = x
60
+ else:
61
+ codes = self.encoder(x)
62
+ # normalize with respect to the center of an average face
63
+ if self.opts.start_from_latent_avg:
64
+ if codes.ndim == 2:
65
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
66
+ else:
67
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
68
+
69
+ if latent_mask is not None:
70
+ for i in latent_mask:
71
+ if inject_latent is not None:
72
+ if alpha is not None:
73
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
74
+ else:
75
+ codes[:, i] = inject_latent[:, i]
76
+ else:
77
+ codes[:, i] = 0
78
+
79
+ input_is_latent = not input_code
80
+ images, result_latent = self.decoder([codes],
81
+ input_is_latent=input_is_latent,
82
+ randomize_noise=randomize_noise,
83
+ return_latents=return_latents)
84
+
85
+ if resize:
86
+ images = self.face_pool(images)
87
+
88
+ if return_latents:
89
+ return images, result_latent
90
+ else:
91
+ return images
92
+
93
+ def __load_latent_avg(self, ckpt, repeat=None):
94
+ if 'latent_avg' in ckpt:
95
+ self.latent_avg = ckpt['latent_avg'].to(self.device)
96
+ if repeat is not None:
97
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
98
+ else:
99
+ self.latent_avg = None
e4e/models/stylegan2/__init__.py ADDED
File without changes
e4e/models/stylegan2/model.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ if torch.cuda.is_available():
8
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
9
+ from op.upfirdn2d import upfirdn2d
10
+ else:
11
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
12
+ from op.upfirdn2d_cpu import upfirdn2d
13
+
14
+
15
+ class PixelNorm(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, input):
20
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
21
+
22
+
23
+ def make_kernel(k):
24
+ k = torch.tensor(k, dtype=torch.float32)
25
+
26
+ if k.ndim == 1:
27
+ k = k[None, :] * k[:, None]
28
+
29
+ k /= k.sum()
30
+
31
+ return k
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, kernel, factor=2):
36
+ super().__init__()
37
+
38
+ self.factor = factor
39
+ kernel = make_kernel(kernel) * (factor ** 2)
40
+ self.register_buffer('kernel', kernel)
41
+
42
+ p = kernel.shape[0] - factor
43
+
44
+ pad0 = (p + 1) // 2 + factor - 1
45
+ pad1 = p // 2
46
+
47
+ self.pad = (pad0, pad1)
48
+
49
+ def forward(self, input):
50
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
51
+
52
+ return out
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, kernel, factor=2):
57
+ super().__init__()
58
+
59
+ self.factor = factor
60
+ kernel = make_kernel(kernel)
61
+ self.register_buffer('kernel', kernel)
62
+
63
+ p = kernel.shape[0] - factor
64
+
65
+ pad0 = (p + 1) // 2
66
+ pad1 = p // 2
67
+
68
+ self.pad = (pad0, pad1)
69
+
70
+ def forward(self, input):
71
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
+
73
+ return out
74
+
75
+
76
+ class Blur(nn.Module):
77
+ def __init__(self, kernel, pad, upsample_factor=1):
78
+ super().__init__()
79
+
80
+ kernel = make_kernel(kernel)
81
+
82
+ if upsample_factor > 1:
83
+ kernel = kernel * (upsample_factor ** 2)
84
+
85
+ self.register_buffer('kernel', kernel)
86
+
87
+ self.pad = pad
88
+
89
+ def forward(self, input):
90
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
91
+
92
+ return out
93
+
94
+
95
+ class EqualConv2d(nn.Module):
96
+ def __init__(
97
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
98
+ ):
99
+ super().__init__()
100
+
101
+ self.weight = nn.Parameter(
102
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
103
+ )
104
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
105
+
106
+ self.stride = stride
107
+ self.padding = padding
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(out_channel))
111
+
112
+ else:
113
+ self.bias = None
114
+
115
+ def forward(self, input):
116
+ out = F.conv2d(
117
+ input,
118
+ self.weight * self.scale,
119
+ bias=self.bias,
120
+ stride=self.stride,
121
+ padding=self.padding,
122
+ )
123
+
124
+ return out
125
+
126
+ def __repr__(self):
127
+ return (
128
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
129
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
130
+ )
131
+
132
+
133
+ class EqualLinear(nn.Module):
134
+ def __init__(
135
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
136
+ ):
137
+ super().__init__()
138
+
139
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
140
+
141
+ if bias:
142
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
143
+
144
+ else:
145
+ self.bias = None
146
+
147
+ self.activation = activation
148
+
149
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
150
+ self.lr_mul = lr_mul
151
+
152
+ def forward(self, input):
153
+ if self.activation:
154
+ out = F.linear(input, self.weight * self.scale)
155
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
156
+
157
+ else:
158
+ out = F.linear(
159
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
160
+ )
161
+
162
+ return out
163
+
164
+ def __repr__(self):
165
+ return (
166
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
167
+ )
168
+
169
+
170
+ class ScaledLeakyReLU(nn.Module):
171
+ def __init__(self, negative_slope=0.2):
172
+ super().__init__()
173
+
174
+ self.negative_slope = negative_slope
175
+
176
+ def forward(self, input):
177
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
178
+
179
+ return out * math.sqrt(2)
180
+
181
+
182
+ class ModulatedConv2d(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channel,
186
+ out_channel,
187
+ kernel_size,
188
+ style_dim,
189
+ demodulate=True,
190
+ upsample=False,
191
+ downsample=False,
192
+ blur_kernel=[1, 3, 3, 1],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.eps = 1e-8
197
+ self.kernel_size = kernel_size
198
+ self.in_channel = in_channel
199
+ self.out_channel = out_channel
200
+ self.upsample = upsample
201
+ self.downsample = downsample
202
+
203
+ if upsample:
204
+ factor = 2
205
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
206
+ pad0 = (p + 1) // 2 + factor - 1
207
+ pad1 = p // 2 + 1
208
+
209
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
210
+
211
+ if downsample:
212
+ factor = 2
213
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
214
+ pad0 = (p + 1) // 2
215
+ pad1 = p // 2
216
+
217
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
218
+
219
+ fan_in = in_channel * kernel_size ** 2
220
+ self.scale = 1 / math.sqrt(fan_in)
221
+ self.padding = kernel_size // 2
222
+
223
+ self.weight = nn.Parameter(
224
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
225
+ )
226
+
227
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
228
+
229
+ self.demodulate = demodulate
230
+
231
+ def __repr__(self):
232
+ return (
233
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
234
+ f'upsample={self.upsample}, downsample={self.downsample})'
235
+ )
236
+
237
+ def forward(self, input, style):
238
+ batch, in_channel, height, width = input.shape
239
+
240
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
241
+ weight = self.scale * self.weight * style
242
+
243
+ if self.demodulate:
244
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
245
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
246
+
247
+ weight = weight.view(
248
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
249
+ )
250
+
251
+ if self.upsample:
252
+ input = input.view(1, batch * in_channel, height, width)
253
+ weight = weight.view(
254
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
+ )
256
+ weight = weight.transpose(1, 2).reshape(
257
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
258
+ )
259
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
260
+ _, _, height, width = out.shape
261
+ out = out.view(batch, self.out_channel, height, width)
262
+ out = self.blur(out)
263
+
264
+ elif self.downsample:
265
+ input = self.blur(input)
266
+ _, _, height, width = input.shape
267
+ input = input.view(1, batch * in_channel, height, width)
268
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
269
+ _, _, height, width = out.shape
270
+ out = out.view(batch, self.out_channel, height, width)
271
+
272
+ else:
273
+ input = input.view(1, batch * in_channel, height, width)
274
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
275
+ _, _, height, width = out.shape
276
+ out = out.view(batch, self.out_channel, height, width)
277
+
278
+ return out
279
+
280
+
281
+ class NoiseInjection(nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+
285
+ self.weight = nn.Parameter(torch.zeros(1))
286
+
287
+ def forward(self, image, noise=None):
288
+ if noise is None:
289
+ batch, _, height, width = image.shape
290
+ noise = image.new_empty(batch, 1, height, width).normal_()
291
+
292
+ return image + self.weight * noise
293
+
294
+
295
+ class ConstantInput(nn.Module):
296
+ def __init__(self, channel, size=4):
297
+ super().__init__()
298
+
299
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
300
+
301
+ def forward(self, input):
302
+ batch = input.shape[0]
303
+ out = self.input.repeat(batch, 1, 1, 1)
304
+
305
+ return out
306
+
307
+
308
+ class StyledConv(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channel,
312
+ out_channel,
313
+ kernel_size,
314
+ style_dim,
315
+ upsample=False,
316
+ blur_kernel=[1, 3, 3, 1],
317
+ demodulate=True,
318
+ ):
319
+ super().__init__()
320
+
321
+ self.conv = ModulatedConv2d(
322
+ in_channel,
323
+ out_channel,
324
+ kernel_size,
325
+ style_dim,
326
+ upsample=upsample,
327
+ blur_kernel=blur_kernel,
328
+ demodulate=demodulate,
329
+ )
330
+
331
+ self.noise = NoiseInjection()
332
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
333
+ # self.activate = ScaledLeakyReLU(0.2)
334
+ self.activate = FusedLeakyReLU(out_channel)
335
+
336
+ def forward(self, input, style, noise=None):
337
+ out = self.conv(input, style)
338
+ out = self.noise(out, noise=noise)
339
+ # out = out + self.bias
340
+ out = self.activate(out)
341
+
342
+ return out
343
+
344
+
345
+ class ToRGB(nn.Module):
346
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
347
+ super().__init__()
348
+
349
+ if upsample:
350
+ self.upsample = Upsample(blur_kernel)
351
+
352
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
353
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
354
+
355
+ def forward(self, input, style, skip=None):
356
+ out = self.conv(input, style)
357
+ out = out + self.bias
358
+
359
+ if skip is not None:
360
+ skip = self.upsample(skip)
361
+
362
+ out = out + skip
363
+
364
+ return out
365
+
366
+
367
+ class Generator(nn.Module):
368
+ def __init__(
369
+ self,
370
+ size,
371
+ style_dim,
372
+ n_mlp,
373
+ channel_multiplier=2,
374
+ blur_kernel=[1, 3, 3, 1],
375
+ lr_mlp=0.01,
376
+ ):
377
+ super().__init__()
378
+
379
+ self.size = size
380
+
381
+ self.style_dim = style_dim
382
+
383
+ layers = [PixelNorm()]
384
+
385
+ for i in range(n_mlp):
386
+ layers.append(
387
+ EqualLinear(
388
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
389
+ )
390
+ )
391
+
392
+ self.style = nn.Sequential(*layers)
393
+
394
+ self.channels = {
395
+ 4: 512,
396
+ 8: 512,
397
+ 16: 512,
398
+ 32: 512,
399
+ 64: 256 * channel_multiplier,
400
+ 128: 128 * channel_multiplier,
401
+ 256: 64 * channel_multiplier,
402
+ 512: 32 * channel_multiplier,
403
+ 1024: 16 * channel_multiplier,
404
+ }
405
+
406
+ self.input = ConstantInput(self.channels[4])
407
+ self.conv1 = StyledConv(
408
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
409
+ )
410
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
411
+
412
+ self.log_size = int(math.log(size, 2))
413
+ self.num_layers = (self.log_size - 2) * 2 + 1
414
+
415
+ self.convs = nn.ModuleList()
416
+ self.upsamples = nn.ModuleList()
417
+ self.to_rgbs = nn.ModuleList()
418
+ self.noises = nn.Module()
419
+
420
+ in_channel = self.channels[4]
421
+
422
+ for layer_idx in range(self.num_layers):
423
+ res = (layer_idx + 5) // 2
424
+ shape = [1, 1, 2 ** res, 2 ** res]
425
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
426
+
427
+ for i in range(3, self.log_size + 1):
428
+ out_channel = self.channels[2 ** i]
429
+
430
+ self.convs.append(
431
+ StyledConv(
432
+ in_channel,
433
+ out_channel,
434
+ 3,
435
+ style_dim,
436
+ upsample=True,
437
+ blur_kernel=blur_kernel,
438
+ )
439
+ )
440
+
441
+ self.convs.append(
442
+ StyledConv(
443
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
444
+ )
445
+ )
446
+
447
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
448
+
449
+ in_channel = out_channel
450
+
451
+ self.n_latent = self.log_size * 2 - 2
452
+
453
+ def make_noise(self):
454
+ device = self.input.input.device
455
+
456
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
457
+
458
+ for i in range(3, self.log_size + 1):
459
+ for _ in range(2):
460
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
461
+
462
+ return noises
463
+
464
+ def mean_latent(self, n_latent):
465
+ latent_in = torch.randn(
466
+ n_latent, self.style_dim, device=self.input.input.device
467
+ )
468
+ latent = self.style(latent_in).mean(0, keepdim=True)
469
+
470
+ return latent
471
+
472
+ def get_latent(self, input):
473
+ return self.style(input)
474
+
475
+ def forward(
476
+ self,
477
+ styles,
478
+ return_latents=False,
479
+ return_features=False,
480
+ inject_index=None,
481
+ truncation=1,
482
+ truncation_latent=None,
483
+ input_is_latent=False,
484
+ noise=None,
485
+ randomize_noise=True,
486
+ ):
487
+ if not input_is_latent:
488
+ styles = [self.style(s) for s in styles]
489
+
490
+ if noise is None:
491
+ if randomize_noise:
492
+ noise = [None] * self.num_layers
493
+ else:
494
+ noise = [
495
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
496
+ ]
497
+
498
+ if truncation < 1:
499
+ style_t = []
500
+
501
+ for style in styles:
502
+ style_t.append(
503
+ truncation_latent + truncation * (style - truncation_latent)
504
+ )
505
+
506
+ styles = style_t
507
+
508
+ if len(styles) < 2:
509
+ inject_index = self.n_latent
510
+
511
+ if styles[0].ndim < 3:
512
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
513
+ else:
514
+ latent = styles[0]
515
+
516
+ else:
517
+ if inject_index is None:
518
+ inject_index = random.randint(1, self.n_latent - 1)
519
+
520
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
521
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
522
+
523
+ latent = torch.cat([latent, latent2], 1)
524
+
525
+ out = self.input(latent)
526
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
527
+
528
+ skip = self.to_rgb1(out, latent[:, 1])
529
+
530
+ i = 1
531
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
532
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
533
+ ):
534
+ out = conv1(out, latent[:, i], noise=noise1)
535
+ out = conv2(out, latent[:, i + 1], noise=noise2)
536
+ skip = to_rgb(out, latent[:, i + 2], skip)
537
+
538
+ i += 2
539
+
540
+ image = skip
541
+
542
+ if return_latents:
543
+ return image, latent
544
+ elif return_features:
545
+ return image, out
546
+ else:
547
+ return image, None
548
+
549
+
550
+ class ConvLayer(nn.Sequential):
551
+ def __init__(
552
+ self,
553
+ in_channel,
554
+ out_channel,
555
+ kernel_size,
556
+ downsample=False,
557
+ blur_kernel=[1, 3, 3, 1],
558
+ bias=True,
559
+ activate=True,
560
+ ):
561
+ layers = []
562
+
563
+ if downsample:
564
+ factor = 2
565
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
566
+ pad0 = (p + 1) // 2
567
+ pad1 = p // 2
568
+
569
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
570
+
571
+ stride = 2
572
+ self.padding = 0
573
+
574
+ else:
575
+ stride = 1
576
+ self.padding = kernel_size // 2
577
+
578
+ layers.append(
579
+ EqualConv2d(
580
+ in_channel,
581
+ out_channel,
582
+ kernel_size,
583
+ padding=self.padding,
584
+ stride=stride,
585
+ bias=bias and not activate,
586
+ )
587
+ )
588
+
589
+ if activate:
590
+ if bias:
591
+ layers.append(FusedLeakyReLU(out_channel))
592
+
593
+ else:
594
+ layers.append(ScaledLeakyReLU(0.2))
595
+
596
+ super().__init__(*layers)
597
+
598
+
599
+ class ResBlock(nn.Module):
600
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
601
+ super().__init__()
602
+
603
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
604
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
605
+
606
+ self.skip = ConvLayer(
607
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
608
+ )
609
+
610
+ def forward(self, input):
611
+ out = self.conv1(input)
612
+ out = self.conv2(out)
613
+
614
+ skip = self.skip(input)
615
+ out = (out + skip) / math.sqrt(2)
616
+
617
+ return out
618
+
619
+
620
+ class Discriminator(nn.Module):
621
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
622
+ super().__init__()
623
+
624
+ channels = {
625
+ 4: 512,
626
+ 8: 512,
627
+ 16: 512,
628
+ 32: 512,
629
+ 64: 256 * channel_multiplier,
630
+ 128: 128 * channel_multiplier,
631
+ 256: 64 * channel_multiplier,
632
+ 512: 32 * channel_multiplier,
633
+ 1024: 16 * channel_multiplier,
634
+ }
635
+
636
+ convs = [ConvLayer(3, channels[size], 1)]
637
+
638
+ log_size = int(math.log(size, 2))
639
+
640
+ in_channel = channels[size]
641
+
642
+ for i in range(log_size, 2, -1):
643
+ out_channel = channels[2 ** (i - 1)]
644
+
645
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
646
+
647
+ in_channel = out_channel
648
+
649
+ self.convs = nn.Sequential(*convs)
650
+
651
+ self.stddev_group = 4
652
+ self.stddev_feat = 1
653
+
654
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
655
+ self.final_linear = nn.Sequential(
656
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
657
+ EqualLinear(channels[4], 1),
658
+ )
659
+
660
+ def forward(self, input):
661
+ out = self.convs(input)
662
+
663
+ batch, channel, height, width = out.shape
664
+ group = min(batch, self.stddev_group)
665
+ stddev = out.view(
666
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
667
+ )
668
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
669
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
670
+ stddev = stddev.repeat(group, 1, height, width)
671
+ out = torch.cat([out, stddev], 1)
672
+
673
+ out = self.final_conv(out)
674
+
675
+ out = out.view(batch, -1)
676
+ out = self.final_linear(out)
677
+
678
+ return out
e4e/models/stylegan2/op/__init__.py ADDED
File without changes
e4e/models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+ module_path = os.path.dirname(__file__)
9
+ fused = load(
10
+ 'fused',
11
+ sources=[
12
+ os.path.join(module_path, 'fused_bias_act.cpp'),
13
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
14
+ ],
15
+ )
16
+
17
+
18
+ class FusedLeakyReLUFunctionBackward(Function):
19
+ @staticmethod
20
+ def forward(ctx, grad_output, out, negative_slope, scale):
21
+ ctx.save_for_backward(out)
22
+ ctx.negative_slope = negative_slope
23
+ ctx.scale = scale
24
+
25
+ empty = grad_output.new_empty(0)
26
+
27
+ grad_input = fused.fused_bias_act(
28
+ grad_output, empty, out, 3, 1, negative_slope, scale
29
+ )
30
+
31
+ dim = [0]
32
+
33
+ if grad_input.ndim > 2:
34
+ dim += list(range(2, grad_input.ndim))
35
+
36
+ grad_bias = grad_input.sum(dim).detach()
37
+
38
+ return grad_input, grad_bias
39
+
40
+ @staticmethod
41
+ def backward(ctx, gradgrad_input, gradgrad_bias):
42
+ out, = ctx.saved_tensors
43
+ gradgrad_out = fused.fused_bias_act(
44
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
45
+ )
46
+
47
+ return gradgrad_out, None, None, None
48
+
49
+
50
+ class FusedLeakyReLUFunction(Function):
51
+ @staticmethod
52
+ def forward(ctx, input, bias, negative_slope, scale):
53
+ empty = input.new_empty(0)
54
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
55
+ ctx.save_for_backward(out)
56
+ ctx.negative_slope = negative_slope
57
+ ctx.scale = scale
58
+
59
+ return out
60
+
61
+ @staticmethod
62
+ def backward(ctx, grad_output):
63
+ out, = ctx.saved_tensors
64
+
65
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
66
+ grad_output, out, ctx.negative_slope, ctx.scale
67
+ )
68
+
69
+ return grad_input, grad_bias, None, None
70
+
71
+
72
+ class FusedLeakyReLU(nn.Module):
73
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
74
+ super().__init__()
75
+
76
+ self.bias = nn.Parameter(torch.zeros(channel))
77
+ self.negative_slope = negative_slope
78
+ self.scale = scale
79
+
80
+ def forward(self, input):
81
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
82
+
83
+
84
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
85
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
e4e/models/stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
e4e/models/stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
e4e/models/stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
e4e/models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
+ module_path = os.path.dirname(__file__)
8
+ upfirdn2d_op = load(
9
+ 'upfirdn2d',
10
+ sources=[
11
+ os.path.join(module_path, 'upfirdn2d.cpp'),
12
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
13
+ ],
14
+ )
15
+
16
+
17
+ class UpFirDn2dBackward(Function):
18
+ @staticmethod
19
+ def forward(
20
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
21
+ ):
22
+ up_x, up_y = up
23
+ down_x, down_y = down
24
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
25
+
26
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
27
+
28
+ grad_input = upfirdn2d_op.upfirdn2d(
29
+ grad_output,
30
+ grad_kernel,
31
+ down_x,
32
+ down_y,
33
+ up_x,
34
+ up_y,
35
+ g_pad_x0,
36
+ g_pad_x1,
37
+ g_pad_y0,
38
+ g_pad_y1,
39
+ )
40
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
41
+
42
+ ctx.save_for_backward(kernel)
43
+
44
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
45
+
46
+ ctx.up_x = up_x
47
+ ctx.up_y = up_y
48
+ ctx.down_x = down_x
49
+ ctx.down_y = down_y
50
+ ctx.pad_x0 = pad_x0
51
+ ctx.pad_x1 = pad_x1
52
+ ctx.pad_y0 = pad_y0
53
+ ctx.pad_y1 = pad_y1
54
+ ctx.in_size = in_size
55
+ ctx.out_size = out_size
56
+
57
+ return grad_input
58
+
59
+ @staticmethod
60
+ def backward(ctx, gradgrad_input):
61
+ kernel, = ctx.saved_tensors
62
+
63
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
64
+
65
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
66
+ gradgrad_input,
67
+ kernel,
68
+ ctx.up_x,
69
+ ctx.up_y,
70
+ ctx.down_x,
71
+ ctx.down_y,
72
+ ctx.pad_x0,
73
+ ctx.pad_x1,
74
+ ctx.pad_y0,
75
+ ctx.pad_y1,
76
+ )
77
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
78
+ gradgrad_out = gradgrad_out.view(
79
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
80
+ )
81
+
82
+ return gradgrad_out, None, None, None, None, None, None, None, None
83
+
84
+
85
+ class UpFirDn2d(Function):
86
+ @staticmethod
87
+ def forward(ctx, input, kernel, up, down, pad):
88
+ up_x, up_y = up
89
+ down_x, down_y = down
90
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
91
+
92
+ kernel_h, kernel_w = kernel.shape
93
+ batch, channel, in_h, in_w = input.shape
94
+ ctx.in_size = input.shape
95
+
96
+ input = input.reshape(-1, in_h, in_w, 1)
97
+
98
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99
+
100
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102
+ ctx.out_size = (out_h, out_w)
103
+
104
+ ctx.up = (up_x, up_y)
105
+ ctx.down = (down_x, down_y)
106
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107
+
108
+ g_pad_x0 = kernel_w - pad_x0 - 1
109
+ g_pad_y0 = kernel_h - pad_y0 - 1
110
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112
+
113
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114
+
115
+ out = upfirdn2d_op.upfirdn2d(
116
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117
+ )
118
+ # out = out.view(major, out_h, out_w, minor)
119
+ out = out.view(-1, channel, out_h, out_w)
120
+
121
+ return out
122
+
123
+ @staticmethod
124
+ def backward(ctx, grad_output):
125
+ kernel, grad_kernel = ctx.saved_tensors
126
+
127
+ grad_input = UpFirDn2dBackward.apply(
128
+ grad_output,
129
+ kernel,
130
+ grad_kernel,
131
+ ctx.up,
132
+ ctx.down,
133
+ ctx.pad,
134
+ ctx.g_pad,
135
+ ctx.in_size,
136
+ ctx.out_size,
137
+ )
138
+
139
+ return grad_input, None, None, None, None
140
+
141
+
142
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
143
+ out = UpFirDn2d.apply(
144
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
145
+ )
146
+
147
+ return out
148
+
149
+
150
+ def upfirdn2d_native(
151
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
152
+ ):
153
+ _, in_h, in_w, minor = input.shape
154
+ kernel_h, kernel_w = kernel.shape
155
+
156
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
157
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
158
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
159
+
160
+ out = F.pad(
161
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
162
+ )
163
+ out = out[
164
+ :,
165
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
166
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
167
+ :,
168
+ ]
169
+
170
+ out = out.permute(0, 3, 1, 2)
171
+ out = out.reshape(
172
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
173
+ )
174
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175
+ out = F.conv2d(out, w)
176
+ out = out.reshape(
177
+ -1,
178
+ minor,
179
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181
+ )
182
+ out = out.permute(0, 2, 3, 1)
183
+
184
+ return out[:, ::down_y, ::down_x, :]
e4e/models/stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h;
175
+ int tile_out_w;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
272
+ }
e4e/options/__init__.py ADDED
File without changes
e4e/options/train_options.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ from configs.paths_config import model_paths
3
+
4
+
5
+ class TrainOptions:
6
+
7
+ def __init__(self):
8
+ self.parser = ArgumentParser()
9
+ self.initialize()
10
+
11
+ def initialize(self):
12
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
13
+ self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
14
+ help='Type of dataset/experiment to run')
15
+ self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
16
+
17
+ self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
18
+ self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
19
+ self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
20
+ self.parser.add_argument('--test_workers', default=2, type=int,
21
+ help='Number of test/inference dataloader workers')
22
+
23
+ self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
24
+ self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
25
+ self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
26
+ self.parser.add_argument('--start_from_latent_avg', action='store_true',
27
+ help='Whether to add average latent vector to generate codes from encoder.')
28
+ self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')
29
+
30
+ self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
31
+ self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
32
+ self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
33
+
34
+ self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
35
+ help='Path to StyleGAN model weights')
36
+ self.parser.add_argument('--stylegan_size', default=1024, type=int,
37
+ help='size of pretrained StyleGAN Generator')
38
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
39
+
40
+ self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
41
+ self.parser.add_argument('--image_interval', default=100, type=int,
42
+ help='Interval for logging train images during training')
43
+ self.parser.add_argument('--board_interval', default=50, type=int,
44
+ help='Interval for logging metrics to tensorboard')
45
+ self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
46
+ self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')
47
+
48
+ # Discriminator flags
49
+ self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
50
+ self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
51
+ self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
52
+ self.parser.add_argument("--d_reg_every", type=int, default=16,
53
+ help="interval for applying r1 regularization")
54
+ self.parser.add_argument('--use_w_pool', action='store_true',
55
+ help='Whether to store a latnet codes pool for the discriminator\'s training')
56
+ self.parser.add_argument("--w_pool_size", type=int, default=50,
57
+ help="W\'s pool size, depends on --use_w_pool")
58
+
59
+ # e4e specific
60
+ self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
61
+ self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")
62
+
63
+ # Progressive training
64
+ self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
65
+ help="The training steps of training new deltas. steps[i] starts the delta_i training")
66
+ self.parser.add_argument('--progressive_start', type=int, default=None,
67
+ help="The training step to start training the deltas, overrides progressive_steps")
68
+ self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
69
+ help="Amount of training steps for each progressive step")
70
+
71
+ # Save additional training info to enable future training continuation from produced checkpoints
72
+ self.parser.add_argument('--save_training_data', action='store_true',
73
+ help='Save intermediate training data to resume training from the checkpoint')
74
+ self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
75
+ self.parser.add_argument('--keep_optimizer', action='store_true',
76
+ help='Whether to continue from the checkpoint\'s optimizer')
77
+ self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
78
+ help='Path to training checkpoint, works when --save_training_data was set to True')
79
+ self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
80
+ help="Name of training parameters to update the loaded training checkpoint")
81
+
82
+ def parse(self):
83
+ opts = self.parser.parse_args()
84
+ return opts
e4e/scripts/calc_losses_on_images.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import os
3
+ import json
4
+ import sys
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import torchvision.transforms as transforms
10
+
11
+ sys.path.append(".")
12
+ sys.path.append("..")
13
+
14
+ from criteria.lpips.lpips import LPIPS
15
+ from datasets.gt_res_dataset import GTResDataset
16
+
17
+
18
+ def parse_args():
19
+ parser = ArgumentParser(add_help=False)
20
+ parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
21
+ parser.add_argument('--data_path', type=str, default='results')
22
+ parser.add_argument('--gt_path', type=str, default='gt_images')
23
+ parser.add_argument('--workers', type=int, default=4)
24
+ parser.add_argument('--batch_size', type=int, default=4)
25
+ parser.add_argument('--is_cars', action='store_true')
26
+ args = parser.parse_args()
27
+ return args
28
+
29
+
30
+ def run(args):
31
+ resize_dims = (256, 256)
32
+ if args.is_cars:
33
+ resize_dims = (192, 256)
34
+ transform = transforms.Compose([transforms.Resize(resize_dims),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
+
38
+ print('Loading dataset')
39
+ dataset = GTResDataset(root_path=args.data_path,
40
+ gt_dir=args.gt_path,
41
+ transform=transform)
42
+
43
+ dataloader = DataLoader(dataset,
44
+ batch_size=args.batch_size,
45
+ shuffle=False,
46
+ num_workers=int(args.workers),
47
+ drop_last=True)
48
+
49
+ if args.mode == 'lpips':
50
+ loss_func = LPIPS(net_type='alex')
51
+ elif args.mode == 'l2':
52
+ loss_func = torch.nn.MSELoss()
53
+ else:
54
+ raise Exception('Not a valid mode!')
55
+ loss_func.cuda()
56
+
57
+ global_i = 0
58
+ scores_dict = {}
59
+ all_scores = []
60
+ for result_batch, gt_batch in tqdm(dataloader):
61
+ for i in range(args.batch_size):
62
+ loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda()))
63
+ all_scores.append(loss)
64
+ im_path = dataset.pairs[global_i][0]
65
+ scores_dict[os.path.basename(im_path)] = loss
66
+ global_i += 1
67
+
68
+ all_scores = list(scores_dict.values())
69
+ mean = np.mean(all_scores)
70
+ std = np.std(all_scores)
71
+ result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
72
+ print('Finished with ', args.data_path)
73
+ print(result_str)
74
+
75
+ out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
76
+ if not os.path.exists(out_path):
77
+ os.makedirs(out_path)
78
+
79
+ with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
80
+ f.write(result_str)
81
+ with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
82
+ json.dump(scores_dict, f)
83
+
84
+
85
+ if __name__ == '__main__':
86
+ args = parse_args()
87
+ run(args)
e4e/scripts/inference.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+ import dlib
8
+
9
+ sys.path.append(".")
10
+ sys.path.append("..")
11
+
12
+ from configs import data_configs, paths_config
13
+ from datasets.inference_dataset import InferenceDataset
14
+ from torch.utils.data import DataLoader
15
+ from utils.model_utils import setup_model
16
+ from utils.common import tensor2im
17
+ from utils.alignment import align_face
18
+ from PIL import Image
19
+
20
+
21
+ def main(args):
22
+ net, opts = setup_model(args.ckpt, device)
23
+ is_cars = 'cars_' in opts.dataset_type
24
+ generator = net.decoder
25
+ generator.eval()
26
+ args, data_loader = setup_data_loader(args, opts)
27
+
28
+ # Check if latents exist
29
+ latents_file_path = os.path.join(args.save_dir, 'latents.pt')
30
+ if os.path.exists(latents_file_path):
31
+ latent_codes = torch.load(latents_file_path).to(device)
32
+ else:
33
+ latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
34
+ torch.save(latent_codes, latents_file_path)
35
+
36
+ if not args.latents_only:
37
+ generate_inversions(args, generator, latent_codes, is_cars=is_cars)
38
+
39
+
40
+ def setup_data_loader(args, opts):
41
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
42
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
43
+ images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
44
+ print(f"images path: {images_path}")
45
+ align_function = None
46
+ if args.align:
47
+ align_function = run_alignment
48
+ test_dataset = InferenceDataset(root=images_path,
49
+ transform=transforms_dict['transform_test'],
50
+ preprocess=align_function,
51
+ opts=opts)
52
+
53
+ data_loader = DataLoader(test_dataset,
54
+ batch_size=args.batch,
55
+ shuffle=False,
56
+ num_workers=2,
57
+ drop_last=True)
58
+
59
+ print(f'dataset length: {len(test_dataset)}')
60
+
61
+ if args.n_sample is None:
62
+ args.n_sample = len(test_dataset)
63
+ return args, data_loader
64
+
65
+
66
+ def get_latents(net, x, is_cars=False):
67
+ codes = net.encoder(x)
68
+ if net.opts.start_from_latent_avg:
69
+ if codes.ndim == 2:
70
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
71
+ else:
72
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
73
+ if codes.shape[1] == 18 and is_cars:
74
+ codes = codes[:, :16, :]
75
+ return codes
76
+
77
+
78
+ def get_all_latents(net, data_loader, n_images=None, is_cars=False):
79
+ all_latents = []
80
+ i = 0
81
+ with torch.no_grad():
82
+ for batch in data_loader:
83
+ if n_images is not None and i > n_images:
84
+ break
85
+ x = batch
86
+ inputs = x.to(device).float()
87
+ latents = get_latents(net, inputs, is_cars)
88
+ all_latents.append(latents)
89
+ i += len(latents)
90
+ return torch.cat(all_latents)
91
+
92
+
93
+ def save_image(img, save_dir, idx):
94
+ result = tensor2im(img)
95
+ im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg")
96
+ Image.fromarray(np.array(result)).save(im_save_path)
97
+
98
+
99
+ @torch.no_grad()
100
+ def generate_inversions(args, g, latent_codes, is_cars):
101
+ print('Saving inversion images')
102
+ inversions_directory_path = os.path.join(args.save_dir, 'inversions')
103
+ os.makedirs(inversions_directory_path, exist_ok=True)
104
+ for i in range(args.n_sample):
105
+ imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
106
+ if is_cars:
107
+ imgs = imgs[:, :, 64:448, :]
108
+ save_image(imgs[0], inversions_directory_path, i + 1)
109
+
110
+
111
+ def run_alignment(image_path):
112
+ predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
113
+ aligned_image = align_face(filepath=image_path, predictor=predictor)
114
+ print("Aligned image has shape: {}".format(aligned_image.size))
115
+ return aligned_image
116
+
117
+
118
+ if __name__ == "__main__":
119
+ device = "cuda"
120
+
121
+ parser = argparse.ArgumentParser(description="Inference")
122
+ parser.add_argument("--images_dir", type=str, default=None,
123
+ help="The directory of the images to be inverted")
124
+ parser.add_argument("--save_dir", type=str, default=None,
125
+ help="The directory to save the latent codes and inversion images. (default: images_dir")
126
+ parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
127
+ parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
128
+ parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
129
+ parser.add_argument("--align", action="store_true", help="align face images before inference")
130
+ parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
131
+
132
+ args = parser.parse_args()
133
+ main(args)
e4e/scripts/train.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file runs the main training/val loop
3
+ """
4
+ import os
5
+ import json
6
+ import math
7
+ import sys
8
+ import pprint
9
+ import torch
10
+ from argparse import Namespace
11
+
12
+ sys.path.append(".")
13
+ sys.path.append("..")
14
+
15
+ from options.train_options import TrainOptions
16
+ from training.coach import Coach
17
+
18
+
19
+ def main():
20
+ opts = TrainOptions().parse()
21
+ previous_train_ckpt = None
22
+ if opts.resume_training_from_ckpt:
23
+ opts, previous_train_ckpt = load_train_checkpoint(opts)
24
+ else:
25
+ setup_progressive_steps(opts)
26
+ create_initial_experiment_dir(opts)
27
+
28
+ coach = Coach(opts, previous_train_ckpt)
29
+ coach.train()
30
+
31
+
32
+ def load_train_checkpoint(opts):
33
+ train_ckpt_path = opts.resume_training_from_ckpt
34
+ previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu')
35
+ new_opts_dict = vars(opts)
36
+ opts = previous_train_ckpt['opts']
37
+ opts['resume_training_from_ckpt'] = train_ckpt_path
38
+ update_new_configs(opts, new_opts_dict)
39
+ pprint.pprint(opts)
40
+ opts = Namespace(**opts)
41
+ if opts.sub_exp_dir is not None:
42
+ sub_exp_dir = opts.sub_exp_dir
43
+ opts.exp_dir = os.path.join(opts.exp_dir, sub_exp_dir)
44
+ create_initial_experiment_dir(opts)
45
+ return opts, previous_train_ckpt
46
+
47
+
48
+ def setup_progressive_steps(opts):
49
+ log_size = int(math.log(opts.stylegan_size, 2))
50
+ num_style_layers = 2*log_size - 2
51
+ num_deltas = num_style_layers - 1
52
+ if opts.progressive_start is not None: # If progressive delta training
53
+ opts.progressive_steps = [0]
54
+ next_progressive_step = opts.progressive_start
55
+ for i in range(num_deltas):
56
+ opts.progressive_steps.append(next_progressive_step)
57
+ next_progressive_step += opts.progressive_step_every
58
+
59
+ assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \
60
+ "Invalid progressive training input"
61
+
62
+
63
+ def is_valid_progressive_steps(opts, num_style_layers):
64
+ return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0
65
+
66
+
67
+ def create_initial_experiment_dir(opts):
68
+ if os.path.exists(opts.exp_dir):
69
+ raise Exception('Oops... {} already exists'.format(opts.exp_dir))
70
+ os.makedirs(opts.exp_dir)
71
+
72
+ opts_dict = vars(opts)
73
+ pprint.pprint(opts_dict)
74
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
75
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
76
+
77
+
78
+ def update_new_configs(ckpt_opts, new_opts):
79
+ for k, v in new_opts.items():
80
+ if k not in ckpt_opts:
81
+ ckpt_opts[k] = v
82
+ if new_opts['update_param_list']:
83
+ for param in new_opts['update_param_list']:
84
+ ckpt_opts[param] = new_opts[param]
85
+
86
+
87
+ if __name__ == '__main__':
88
+ main()
e4e/utils/__init__.py ADDED
File without changes
e4e/utils/alignment.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL
3
+ import PIL.Image
4
+ import scipy
5
+ import scipy.ndimage
6
+ import dlib
7
+
8
+
9
+ def get_landmark(filepath, predictor):
10
+ """get landmark with dlib
11
+ :return: np.array shape=(68, 2)
12
+ """
13
+ detector = dlib.get_frontal_face_detector()
14
+
15
+ img = dlib.load_rgb_image(filepath)
16
+ dets = detector(img, 1)
17
+
18
+ for k, d in enumerate(dets):
19
+ shape = predictor(img, d)
20
+
21
+ t = list(shape.parts())
22
+ a = []
23
+ for tt in t:
24
+ a.append([tt.x, tt.y])
25
+ lm = np.array(a)
26
+ return lm
27
+
28
+
29
+ def align_face(filepath, predictor):
30
+ """
31
+ :param filepath: str
32
+ :return: PIL Image
33
+ """
34
+
35
+ lm = get_landmark(filepath, predictor)
36
+
37
+ lm_chin = lm[0: 17] # left-right
38
+ lm_eyebrow_left = lm[17: 22] # left-right
39
+ lm_eyebrow_right = lm[22: 27] # left-right
40
+ lm_nose = lm[27: 31] # top-down
41
+ lm_nostrils = lm[31: 36] # top-down
42
+ lm_eye_left = lm[36: 42] # left-clockwise
43
+ lm_eye_right = lm[42: 48] # left-clockwise
44
+ lm_mouth_outer = lm[48: 60] # left-clockwise
45
+ lm_mouth_inner = lm[60: 68] # left-clockwise
46
+
47
+ # Calculate auxiliary vectors.
48
+ eye_left = np.mean(lm_eye_left, axis=0)
49
+ eye_right = np.mean(lm_eye_right, axis=0)
50
+ eye_avg = (eye_left + eye_right) * 0.5
51
+ eye_to_eye = eye_right - eye_left
52
+ mouth_left = lm_mouth_outer[0]
53
+ mouth_right = lm_mouth_outer[6]
54
+ mouth_avg = (mouth_left + mouth_right) * 0.5
55
+ eye_to_mouth = mouth_avg - eye_avg
56
+
57
+ # Choose oriented crop rectangle.
58
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
59
+ x /= np.hypot(*x)
60
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
61
+ y = np.flipud(x) * [-1, 1]
62
+ c = eye_avg + eye_to_mouth * 0.1
63
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
64
+ qsize = np.hypot(*x) * 2
65
+
66
+ # read image
67
+ img = PIL.Image.open(filepath)
68
+
69
+ output_size = 256
70
+ transform_size = 256
71
+ enable_padding = True
72
+
73
+ # Shrink.
74
+ shrink = int(np.floor(qsize / output_size * 0.5))
75
+ if shrink > 1:
76
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
77
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
78
+ quad /= shrink
79
+ qsize /= shrink
80
+
81
+ # Crop.
82
+ border = max(int(np.rint(qsize * 0.1)), 3)
83
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
84
+ int(np.ceil(max(quad[:, 1]))))
85
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
86
+ min(crop[3] + border, img.size[1]))
87
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
88
+ img = img.crop(crop)
89
+ quad -= crop[0:2]
90
+
91
+ # Pad.
92
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
93
+ int(np.ceil(max(quad[:, 1]))))
94
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
95
+ max(pad[3] - img.size[1] + border, 0))
96
+ if enable_padding and max(pad) > border - 4:
97
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
98
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
99
+ h, w, _ = img.shape
100
+ y, x, _ = np.ogrid[:h, :w, :1]
101
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
102
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
103
+ blur = qsize * 0.02
104
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
105
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
106
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
107
+ quad += pad[:2]
108
+
109
+ # Transform.
110
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
111
+ if output_size < transform_size:
112
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
113
+
114
+ # Return aligned image.
115
+ return img
e4e/utils/common.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ # Log images
6
+ def log_input_image(x, opts):
7
+ return tensor2im(x)
8
+
9
+
10
+ def tensor2im(var):
11
+ # var shape: (3, H, W)
12
+ var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
13
+ var = ((var + 1) / 2)
14
+ var[var < 0] = 0
15
+ var[var > 1] = 1
16
+ var = var * 255
17
+ return Image.fromarray(var.astype('uint8'))
18
+
19
+
20
+ def vis_faces(log_hooks):
21
+ display_count = len(log_hooks)
22
+ fig = plt.figure(figsize=(8, 4 * display_count))
23
+ gs = fig.add_gridspec(display_count, 3)
24
+ for i in range(display_count):
25
+ hooks_dict = log_hooks[i]
26
+ fig.add_subplot(gs[i, 0])
27
+ if 'diff_input' in hooks_dict:
28
+ vis_faces_with_id(hooks_dict, fig, gs, i)
29
+ else:
30
+ vis_faces_no_id(hooks_dict, fig, gs, i)
31
+ plt.tight_layout()
32
+ return fig
33
+
34
+
35
+ def vis_faces_with_id(hooks_dict, fig, gs, i):
36
+ plt.imshow(hooks_dict['input_face'])
37
+ plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
38
+ fig.add_subplot(gs[i, 1])
39
+ plt.imshow(hooks_dict['target_face'])
40
+ plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
41
+ float(hooks_dict['diff_target'])))
42
+ fig.add_subplot(gs[i, 2])
43
+ plt.imshow(hooks_dict['output_face'])
44
+ plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
45
+
46
+
47
+ def vis_faces_no_id(hooks_dict, fig, gs, i):
48
+ plt.imshow(hooks_dict['input_face'], cmap="gray")
49
+ plt.title('Input')
50
+ fig.add_subplot(gs[i, 1])
51
+ plt.imshow(hooks_dict['target_face'])
52
+ plt.title('Target')
53
+ fig.add_subplot(gs[i, 2])
54
+ plt.imshow(hooks_dict['output_face'])
55
+ plt.title('Output')
e4e/utils/data_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adopted from pix2pixHD:
3
+ https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py
4
+ """
5
+ import os
6
+
7
+ IMG_EXTENSIONS = [
8
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
9
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
10
+ ]
11
+
12
+
13
+ def is_image_file(filename):
14
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15
+
16
+
17
+ def make_dataset(dir):
18
+ images = []
19
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
20
+ for root, _, fnames in sorted(os.walk(dir)):
21
+ for fname in fnames:
22
+ if is_image_file(fname):
23
+ path = os.path.join(root, fname)
24
+ images.append(path)
25
+ return images
e4e/utils/model_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ from models.psp import pSp
4
+ from models.encoders.psp_encoders import Encoder4Editing
5
+
6
+
7
+ def setup_model(checkpoint_path, device='cuda'):
8
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
9
+ opts = ckpt['opts']
10
+
11
+ opts['checkpoint_path'] = checkpoint_path
12
+ opts['device'] = device
13
+ opts = argparse.Namespace(**opts)
14
+
15
+ net = pSp(opts)
16
+ net.eval()
17
+ net = net.to(device)
18
+ return net, opts
19
+
20
+
21
+ def load_e4e_standalone(checkpoint_path, device='cuda'):
22
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
23
+ opts = argparse.Namespace(**ckpt['opts'])
24
+ e4e = Encoder4Editing(50, 'ir_se', opts)
25
+ e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')}
26
+ e4e.load_state_dict(e4e_dict)
27
+ e4e.eval()
28
+ e4e = e4e.to(device)
29
+ latent_avg = ckpt['latent_avg'].to(device)
30
+
31
+ def add_latent_avg(model, inputs, outputs):
32
+ return outputs + latent_avg.repeat(outputs.shape[0], 1, 1)
33
+
34
+ e4e.register_forward_hook(add_latent_avg)
35
+ return e4e
e4e/utils/train_utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def aggregate_loss_dict(agg_loss_dict):
3
+ mean_vals = {}
4
+ for output in agg_loss_dict:
5
+ for key in output:
6
+ mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]]
7
+ for key in mean_vals:
8
+ if len(mean_vals[key]) > 0:
9
+ mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key])
10
+ else:
11
+ print('{} has no value'.format(key))
12
+ mean_vals[key] = 0
13
+ return mean_vals
editing/interfacegan_boundaries/age.pt ADDED
Binary file (2.8 kB). View file
 
editing/interfacegan_boundaries/beard.pt ADDED
Binary file (2.8 kB). View file
 
editing/interfacegan_boundaries/gender.pt ADDED
Binary file (2.8 kB). View file
 
editing/interfacegan_boundaries/hair_length.pt ADDED
Binary file (2.8 kB). View file
 
editing/interfacegan_boundaries/pose.pt ADDED
Binary file (37.6 kB). View file
 
editing/interfacegan_boundaries/smile.pt ADDED
Binary file (2.8 kB). View file
 
generate_videos.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+
4
+ import torch
5
+ from torchvision import utils
6
+
7
+ from model.sg2_model import Generator
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+
13
+ import subprocess
14
+ import shutil
15
+ import copy
16
+
17
+ from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor
18
+
19
+ VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]
20
+
21
+ SUGGESTED_DISTANCES = {
22
+ "pose": 3.0,
23
+ "smile": 2.0,
24
+ "age": 4.0,
25
+ "gender": 3.0,
26
+ "hair_length": -4.0,
27
+ "beard": 2.0
28
+ }
29
+
30
+ def project_code(latent_code, boundary, distance=3.0):
31
+
32
+ if len(boundary) == 2:
33
+ boundary = boundary.reshape(1, 1, -1)
34
+
35
+ return latent_code + distance * boundary
36
+
37
+ def project_code_by_edit_name(latent_code, name, strength):
38
+ boundary_dir = Path(os.path.abspath(__file__)).parents[0].joinpath("editing", "interfacegan_boundaries")
39
+
40
+ distance = SUGGESTED_DISTANCES[name] * strength
41
+ boundary = torch.load(os.path.join(boundary_dir, f'{name}.pt'), map_location="cpu").numpy()
42
+
43
+ return project_code(latent_code, boundary, distance)
44
+
45
+ def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
46
+
47
+ device = "cuda" if torch.cuda.is_available() else "cpu"
48
+
49
+ code_is_s = target_latents[0].size()[1] == 9088
50
+
51
+ if code_is_s:
52
+ source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
53
+ np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
54
+ else:
55
+ np_latent = source_latent.squeeze(0).cpu().detach().numpy()
56
+
57
+ np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]
58
+
59
+ num_alphas = 20 if code_is_s else min(10, 30 // len(target_latents))
60
+
61
+ alphas = np.linspace(0, 1, num=num_alphas)
62
+
63
+ latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)
64
+
65
+ segments = len(g_ema_list) - 1
66
+
67
+ if segments:
68
+ segment_length = len(latents) / segments
69
+
70
+ g_ema = copy.deepcopy(g_ema_list[0])
71
+
72
+ src_pars = dict(g_ema.named_parameters())
73
+ mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
74
+ else:
75
+ g_ema = g_ema_list[0]
76
+
77
+ print("Generating frames for video...")
78
+ for idx, latent in tqdm(enumerate(latents), total=len(latents)):
79
+
80
+ if segments:
81
+ mix_alpha = (idx % segment_length) * 1.0 / segment_length
82
+ segment_id = int(idx // segment_length)
83
+
84
+ for k in src_pars.keys():
85
+ src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)
86
+
87
+ if idx == 0 or segments or latent is not latents[idx - 1]:
88
+ latent_tensor = torch.from_numpy(latent).float().to(device)
89
+
90
+ with torch.no_grad():
91
+ if code_is_s:
92
+ latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
93
+ img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
94
+ else:
95
+ img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)
96
+
97
+ utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))
98
+
99
+ def interpolate_forward_backward(source_latent, target_latent, alphas):
100
+ latents_forward = [a * target_latent + (1-a) * source_latent for a in alphas] # interpolate from source to target
101
+ latents_backward = latents_forward[::-1] # interpolate from target to source
102
+ return latents_forward + [target_latent] * len(alphas) + latents_backward # forward + short delay at target + return
103
+
104
+ def interpolate_with_target_latents(source_latent, target_latents, alphas):
105
+ # interpolate latent codes with all targets
106
+
107
+ print("Interpolating latent codes...")
108
+
109
+ latents = []
110
+ for target_latent in target_latents:
111
+ latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))
112
+
113
+ return latents
114
+
115
+ def video_from_interpolations(fps, output_dir):
116
+
117
+ # combine frames to a video
118
+ command = ["ffmpeg",
119
+ "-r", f"{fps}",
120
+ "-i", f"{output_dir}/%03d.jpg",
121
+ "-c:v", "libx264",
122
+ "-vf", f"fps={fps}",
123
+ "-pix_fmt", "yuv420p",
124
+ f"{output_dir}/out.mp4"]
125
+
126
+ subprocess.call(command)
127
+
128
+
129
+
model/sg2_model.py ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from op import conv2d_gradfix
12
+
13
+ if torch.cuda.is_available():
14
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
15
+ from op.upfirdn2d import upfirdn2d
16
+ else:
17
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
18
+ from op.upfirdn2d_cpu import upfirdn2d
19
+
20
+
21
+ class PixelNorm(nn.Module):
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ def forward(self, input):
26
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
27
+
28
+
29
+ def make_kernel(k):
30
+ k = torch.tensor(k, dtype=torch.float32)
31
+
32
+ if k.ndim == 1:
33
+ k = k[None, :] * k[:, None]
34
+
35
+ k /= k.sum()
36
+
37
+ return k
38
+
39
+
40
+ class Upsample(nn.Module):
41
+ def __init__(self, kernel, factor=2):
42
+ super().__init__()
43
+
44
+ self.factor = factor
45
+ kernel = make_kernel(kernel) * (factor ** 2)
46
+ self.register_buffer("kernel", kernel)
47
+
48
+ p = kernel.shape[0] - factor
49
+
50
+ pad0 = (p + 1) // 2 + factor - 1
51
+ pad1 = p // 2
52
+
53
+ self.pad = (pad0, pad1)
54
+
55
+ def forward(self, input):
56
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
57
+
58
+ return out
59
+
60
+
61
+ class Downsample(nn.Module):
62
+ def __init__(self, kernel, factor=2):
63
+ super().__init__()
64
+
65
+ self.factor = factor
66
+ kernel = make_kernel(kernel)
67
+ self.register_buffer("kernel", kernel)
68
+
69
+ p = kernel.shape[0] - factor
70
+
71
+ pad0 = (p + 1) // 2
72
+ pad1 = p // 2
73
+
74
+ self.pad = (pad0, pad1)
75
+
76
+ def forward(self, input):
77
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
78
+
79
+ return out
80
+
81
+
82
+ class Blur(nn.Module):
83
+ def __init__(self, kernel, pad, upsample_factor=1):
84
+ super().__init__()
85
+
86
+ kernel = make_kernel(kernel)
87
+
88
+ if upsample_factor > 1:
89
+ kernel = kernel * (upsample_factor ** 2)
90
+
91
+ self.register_buffer("kernel", kernel)
92
+
93
+ self.pad = pad
94
+
95
+ def forward(self, input):
96
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
97
+
98
+ return out
99
+
100
+
101
+ class EqualConv2d(nn.Module):
102
+ def __init__(
103
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
104
+ ):
105
+ super().__init__()
106
+
107
+ self.weight = nn.Parameter(
108
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
109
+ )
110
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
111
+
112
+ self.stride = stride
113
+ self.padding = padding
114
+
115
+ if bias:
116
+ self.bias = nn.Parameter(torch.zeros(out_channel))
117
+
118
+ else:
119
+ self.bias = None
120
+
121
+ def forward(self, input):
122
+ out = conv2d_gradfix.conv2d(
123
+ input,
124
+ self.weight * self.scale,
125
+ bias=self.bias,
126
+ stride=self.stride,
127
+ padding=self.padding,
128
+ )
129
+
130
+ return out
131
+
132
+ def __repr__(self):
133
+ return (
134
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
135
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
136
+ )
137
+
138
+
139
+ class EqualLinear(nn.Module):
140
+ def __init__(
141
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
142
+ ):
143
+ super().__init__()
144
+
145
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
146
+
147
+ if bias:
148
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
149
+
150
+ else:
151
+ self.bias = None
152
+
153
+ self.activation = activation
154
+
155
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
156
+ self.lr_mul = lr_mul
157
+
158
+ def forward(self, input):
159
+ if self.activation:
160
+ out = F.linear(input, self.weight * self.scale)
161
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
162
+
163
+ else:
164
+ out = F.linear(
165
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
166
+ )
167
+
168
+ return out
169
+
170
+ def __repr__(self):
171
+ return (
172
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
173
+ )
174
+
175
+
176
+ class ModulatedConv2d(nn.Module):
177
+ def __init__(
178
+ self,
179
+ in_channel,
180
+ out_channel,
181
+ kernel_size,
182
+ style_dim,
183
+ demodulate=True,
184
+ upsample=False,
185
+ downsample=False,
186
+ blur_kernel=[1, 3, 3, 1],
187
+ fused=True,
188
+ ):
189
+ super().__init__()
190
+
191
+ self.eps = 1e-8
192
+ self.kernel_size = kernel_size
193
+ self.in_channel = in_channel
194
+ self.out_channel = out_channel
195
+ self.upsample = upsample
196
+ self.downsample = downsample
197
+
198
+ if upsample:
199
+ factor = 2
200
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
201
+ pad0 = (p + 1) // 2 + factor - 1
202
+ pad1 = p // 2 + 1
203
+
204
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
205
+
206
+ if downsample:
207
+ factor = 2
208
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
209
+ pad0 = (p + 1) // 2
210
+ pad1 = p // 2
211
+
212
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
213
+
214
+ fan_in = in_channel * kernel_size ** 2
215
+ self.scale = 1 / math.sqrt(fan_in)
216
+ self.padding = kernel_size // 2
217
+
218
+ self.weight = nn.Parameter(
219
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
220
+ )
221
+
222
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
223
+
224
+ self.demodulate = demodulate
225
+ self.fused = fused
226
+
227
+ def __repr__(self):
228
+ return (
229
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
230
+ f"upsample={self.upsample}, downsample={self.downsample})"
231
+ )
232
+
233
+ def forward(self, input, style, is_s_code=False):
234
+ batch, in_channel, height, width = input.shape
235
+
236
+ if not self.fused:
237
+ weight = self.scale * self.weight.squeeze(0)
238
+
239
+ if is_s_code:
240
+ style = style[self.modulation]
241
+ else:
242
+ style = self.modulation(style)
243
+
244
+ if self.demodulate:
245
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
246
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
247
+
248
+ input = input * style.reshape(batch, in_channel, 1, 1)
249
+
250
+ if self.upsample:
251
+ weight = weight.transpose(0, 1)
252
+ out = conv2d_gradfix.conv_transpose2d(
253
+ input, weight, padding=0, stride=2
254
+ )
255
+ out = self.blur(out)
256
+
257
+ elif self.downsample:
258
+ input = self.blur(input)
259
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
260
+
261
+ else:
262
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
263
+
264
+ if self.demodulate:
265
+ out = out * dcoefs.view(batch, -1, 1, 1)
266
+
267
+ return out
268
+
269
+ if is_s_code:
270
+ style = style[self.modulation]
271
+ else:
272
+ style = self.modulation(style)
273
+
274
+ style = style.view(batch, 1, in_channel, 1, 1)
275
+ weight = self.scale * self.weight * style
276
+
277
+ if self.demodulate:
278
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
279
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
280
+
281
+ weight = weight.view(
282
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
283
+ )
284
+
285
+ if self.upsample:
286
+ input = input.view(1, batch * in_channel, height, width)
287
+ weight = weight.view(
288
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
289
+ )
290
+ weight = weight.transpose(1, 2).reshape(
291
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
292
+ )
293
+ out = conv2d_gradfix.conv_transpose2d(
294
+ input, weight, padding=0, stride=2, groups=batch
295
+ )
296
+ _, _, height, width = out.shape
297
+ out = out.view(batch, self.out_channel, height, width)
298
+ out = self.blur(out)
299
+
300
+ elif self.downsample:
301
+ input = self.blur(input)
302
+ _, _, height, width = input.shape
303
+ input = input.view(1, batch * in_channel, height, width)
304
+ out = conv2d_gradfix.conv2d(
305
+ input, weight, padding=0, stride=2, groups=batch
306
+ )
307
+ _, _, height, width = out.shape
308
+ out = out.view(batch, self.out_channel, height, width)
309
+
310
+ else:
311
+ input = input.view(1, batch * in_channel, height, width)
312
+ out = conv2d_gradfix.conv2d(
313
+ input, weight, padding=self.padding, groups=batch
314
+ )
315
+ _, _, height, width = out.shape
316
+ out = out.view(batch, self.out_channel, height, width)
317
+
318
+ return out
319
+
320
+
321
+ class NoiseInjection(nn.Module):
322
+ def __init__(self):
323
+ super().__init__()
324
+
325
+ self.weight = nn.Parameter(torch.zeros(1))
326
+
327
+ def forward(self, image, noise=None):
328
+ if noise is None:
329
+ batch, _, height, width = image.shape
330
+ noise = image.new_empty(batch, 1, height, width).normal_()
331
+
332
+ return image + self.weight * noise
333
+
334
+
335
+ class ConstantInput(nn.Module):
336
+ def __init__(self, channel, size=4):
337
+ super().__init__()
338
+
339
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
340
+
341
+ def forward(self, input, is_s_code=False):
342
+ if not is_s_code:
343
+ batch = input.shape[0]
344
+ else:
345
+ batch = next(iter(input.values())).shape[0]
346
+
347
+ out = self.input.repeat(batch, 1, 1, 1)
348
+
349
+ return out
350
+
351
+
352
+ class StyledConv(nn.Module):
353
+ def __init__(
354
+ self,
355
+ in_channel,
356
+ out_channel,
357
+ kernel_size,
358
+ style_dim,
359
+ upsample=False,
360
+ blur_kernel=[1, 3, 3, 1],
361
+ demodulate=True,
362
+ ):
363
+ super().__init__()
364
+
365
+ self.conv = ModulatedConv2d(
366
+ in_channel,
367
+ out_channel,
368
+ kernel_size,
369
+ style_dim,
370
+ upsample=upsample,
371
+ blur_kernel=blur_kernel,
372
+ demodulate=demodulate,
373
+ )
374
+
375
+ self.noise = NoiseInjection()
376
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
377
+ # self.activate = ScaledLeakyReLU(0.2)
378
+ self.activate = FusedLeakyReLU(out_channel)
379
+
380
+ def forward(self, input, style, noise=None, is_s_code=False):
381
+ out = self.conv(input, style, is_s_code=is_s_code)
382
+ out = self.noise(out, noise=noise)
383
+ # out = out + self.bias
384
+ out = self.activate(out)
385
+
386
+ return out
387
+
388
+
389
+ class ToRGB(nn.Module):
390
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
391
+ super().__init__()
392
+
393
+ if upsample:
394
+ self.upsample = Upsample(blur_kernel)
395
+
396
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
397
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
398
+
399
+ def forward(self, input, style, skip=None, is_s_code=False):
400
+ out = self.conv(input, style, is_s_code=is_s_code)
401
+ out = out + self.bias
402
+
403
+ if skip is not None:
404
+ skip = self.upsample(skip)
405
+
406
+ out = out + skip
407
+
408
+ return out
409
+
410
+
411
+ class Generator(nn.Module):
412
+ def __init__(
413
+ self,
414
+ size,
415
+ style_dim,
416
+ n_mlp,
417
+ channel_multiplier=2,
418
+ blur_kernel=[1, 3, 3, 1],
419
+ lr_mlp=0.01,
420
+ ):
421
+ super().__init__()
422
+
423
+ self.size = size
424
+
425
+ self.style_dim = style_dim
426
+
427
+ layers = [PixelNorm()]
428
+
429
+ for i in range(n_mlp):
430
+ layers.append(
431
+ EqualLinear(
432
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
433
+ )
434
+ )
435
+
436
+ self.style = nn.Sequential(*layers)
437
+
438
+ self.channels = {
439
+ 4: 512,
440
+ 8: 512,
441
+ 16: 512,
442
+ 32: 512,
443
+ 64: 256 * channel_multiplier,
444
+ 128: 128 * channel_multiplier,
445
+ 256: 64 * channel_multiplier,
446
+ 512: 32 * channel_multiplier,
447
+ 1024: 16 * channel_multiplier,
448
+ }
449
+
450
+ self.input = ConstantInput(self.channels[4])
451
+ self.conv1 = StyledConv(
452
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
453
+ )
454
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
455
+
456
+ self.log_size = int(math.log(size, 2))
457
+ self.num_layers = (self.log_size - 2) * 2 + 1
458
+
459
+ self.convs = nn.ModuleList()
460
+ self.upsamples = nn.ModuleList()
461
+ self.to_rgbs = nn.ModuleList()
462
+ self.noises = nn.Module()
463
+
464
+ in_channel = self.channels[4]
465
+
466
+ for layer_idx in range(self.num_layers):
467
+ res = (layer_idx + 5) // 2
468
+ shape = [1, 1, 2 ** res, 2 ** res]
469
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
470
+
471
+ for i in range(3, self.log_size + 1):
472
+ out_channel = self.channels[2 ** i]
473
+
474
+ self.convs.append(
475
+ StyledConv(
476
+ in_channel,
477
+ out_channel,
478
+ 3,
479
+ style_dim,
480
+ upsample=True,
481
+ blur_kernel=blur_kernel,
482
+ )
483
+ )
484
+
485
+ self.convs.append(
486
+ StyledConv(
487
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
488
+ )
489
+ )
490
+
491
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
492
+
493
+ in_channel = out_channel
494
+
495
+ self.n_latent = self.log_size * 2 - 2
496
+
497
+
498
+ self.modulation_layers = [self.conv1.conv.modulation, self.to_rgb1.conv.modulation] + \
499
+ [layer.conv.modulation for layer in self.convs] + \
500
+ [layer.conv.modulation for layer in self.to_rgbs]
501
+
502
+ def make_noise(self):
503
+ device = self.input.input.device
504
+
505
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
506
+
507
+ for i in range(3, self.log_size + 1):
508
+ for _ in range(2):
509
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
510
+
511
+ return noises
512
+
513
+ def mean_latent(self, n_latent):
514
+ latent_in = torch.randn(
515
+ n_latent, self.style_dim, device=self.input.input.device
516
+ )
517
+ latent = self.style(latent_in).mean(0, keepdim=True)
518
+
519
+ return latent
520
+
521
+ def get_latent(self, input):
522
+ return self.style(input)
523
+
524
+ def get_s_code(self, styles, input_is_latent):
525
+
526
+ if not input_is_latent:
527
+ styles = [self.style(s) for s in styles]
528
+
529
+ s_codes = [{# const block
530
+ self.modulation_layers[0]: self.modulation_layers[0](style[:, 0]), #s0
531
+ self.modulation_layers[1]: self.modulation_layers[1](style[:, 1]), #s1
532
+ # conv layers
533
+ self.modulation_layers[2]: self.modulation_layers[2](style[:, 1]), #s2
534
+ self.modulation_layers[3]: self.modulation_layers[3](style[:, 2]), #s3
535
+ self.modulation_layers[4]: self.modulation_layers[4](style[:, 3]), #s5
536
+ self.modulation_layers[5]: self.modulation_layers[5](style[:, 4]), #s6
537
+ self.modulation_layers[6]: self.modulation_layers[6](style[:, 5]), #s8
538
+ self.modulation_layers[7]: self.modulation_layers[7](style[:, 6]), #s9
539
+ self.modulation_layers[8]: self.modulation_layers[8](style[:, 7]), #s11
540
+ self.modulation_layers[9]: self.modulation_layers[9](style[:, 8]), #s12
541
+ self.modulation_layers[10]: self.modulation_layers[10](style[:, 9]), #s14
542
+ self.modulation_layers[11]: self.modulation_layers[11](style[:, 10]), #s15
543
+ self.modulation_layers[12]: self.modulation_layers[12](style[:, 11]), #s17
544
+ self.modulation_layers[13]: self.modulation_layers[13](style[:, 12]), #s18
545
+ self.modulation_layers[14]: self.modulation_layers[14](style[:, 13]), #s20
546
+ self.modulation_layers[15]: self.modulation_layers[15](style[:, 14]), #s21
547
+ self.modulation_layers[16]: self.modulation_layers[16](style[:, 15]), #s23
548
+ self.modulation_layers[17]: self.modulation_layers[17](style[:, 16]), #s24
549
+ # toRGB layers
550
+ self.modulation_layers[18]: self.modulation_layers[18](style[:, 3]), #s4
551
+ self.modulation_layers[19]: self.modulation_layers[19](style[:, 5]), #s7
552
+ self.modulation_layers[20]: self.modulation_layers[20](style[:, 7]), #s10
553
+ self.modulation_layers[21]: self.modulation_layers[21](style[:, 9]), #s13
554
+ self.modulation_layers[22]: self.modulation_layers[22](style[:, 11]), #s16
555
+ self.modulation_layers[23]: self.modulation_layers[23](style[:, 13]), #s19
556
+ self.modulation_layers[24]: self.modulation_layers[24](style[:, 15]), #s22
557
+ self.modulation_layers[25]: self.modulation_layers[25](style[:, 17]), #s25
558
+ } for style in styles]
559
+
560
+ return s_codes
561
+
562
+
563
+ def forward(
564
+ self,
565
+ styles,
566
+ return_latents=False,
567
+ inject_index=None,
568
+ truncation=1,
569
+ truncation_latent=None,
570
+ input_is_latent=False,
571
+ input_is_s_code=False,
572
+ noise=None,
573
+ randomize_noise=True,
574
+ ):
575
+ if not input_is_s_code:
576
+ return self.forward_with_w(styles, return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise, randomize_noise)
577
+
578
+ return self.forward_with_s(styles, return_latents, noise, randomize_noise)
579
+
580
+ def forward_with_w(
581
+ self,
582
+ styles,
583
+ return_latents=False,
584
+ inject_index=None,
585
+ truncation=1,
586
+ truncation_latent=None,
587
+ input_is_latent=False,
588
+ noise=None,
589
+ randomize_noise=True,
590
+ ):
591
+ if not input_is_latent:
592
+ styles = [self.style(s) for s in styles]
593
+
594
+ if noise is None:
595
+ if randomize_noise:
596
+ noise = [None] * self.num_layers
597
+ else:
598
+ noise = [
599
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
600
+ ]
601
+
602
+ if truncation < 1:
603
+ style_t = []
604
+
605
+ for style in styles:
606
+ style_t.append(
607
+ truncation_latent + truncation * (style - truncation_latent)
608
+ )
609
+
610
+ styles = style_t
611
+
612
+ if len(styles) < 2:
613
+ inject_index = self.n_latent
614
+
615
+ if styles[0].ndim < 3:
616
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
617
+
618
+ else:
619
+ latent = styles[0]
620
+
621
+ else:
622
+ if inject_index is None:
623
+ inject_index = random.randint(1, self.n_latent - 1)
624
+
625
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
626
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
627
+
628
+ latent = torch.cat([latent, latent2], 1)
629
+
630
+ out = self.input(latent)
631
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
632
+
633
+ skip = self.to_rgb1(out, latent[:, 1])
634
+
635
+ i = 1
636
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
637
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
638
+ ):
639
+ out = conv1(out, latent[:, i], noise=noise1)
640
+ out = conv2(out, latent[:, i + 1], noise=noise2)
641
+ skip = to_rgb(out, latent[:, i + 2], skip)
642
+
643
+ i += 2
644
+
645
+ image = skip
646
+
647
+ if return_latents:
648
+ return image, latent
649
+
650
+ else:
651
+ return image, None
652
+
653
+ def forward_with_s(
654
+ self,
655
+ styles,
656
+ return_latents=False,
657
+ noise=None,
658
+ randomize_noise=True,
659
+ ):
660
+
661
+ if noise is None:
662
+ if randomize_noise:
663
+ noise = [None] * self.num_layers
664
+ else:
665
+ noise = [
666
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
667
+ ]
668
+
669
+ out = self.input(styles, is_s_code=True)
670
+ out = self.conv1(out, styles, is_s_code=True, noise=noise[0])
671
+
672
+ skip = self.to_rgb1(out, styles, is_s_code=True)
673
+
674
+ i = 1
675
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
676
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
677
+ ):
678
+ out = conv1(out, styles, is_s_code=True, noise=noise1)
679
+ out = conv2(out, styles, is_s_code=True, noise=noise2)
680
+ skip = to_rgb(out, styles, skip, is_s_code=True)
681
+
682
+ i += 2
683
+
684
+ image = skip
685
+
686
+ if return_latents:
687
+ return image, styles
688
+
689
+ else:
690
+ return image, None
691
+
692
+ class ConvLayer(nn.Sequential):
693
+ def __init__(
694
+ self,
695
+ in_channel,
696
+ out_channel,
697
+ kernel_size,
698
+ downsample=False,
699
+ blur_kernel=[1, 3, 3, 1],
700
+ bias=True,
701
+ activate=True,
702
+ ):
703
+ layers = []
704
+
705
+ if downsample:
706
+ factor = 2
707
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
708
+ pad0 = (p + 1) // 2
709
+ pad1 = p // 2
710
+
711
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
712
+
713
+ stride = 2
714
+ self.padding = 0
715
+
716
+ else:
717
+ stride = 1
718
+ self.padding = kernel_size // 2
719
+
720
+ layers.append(
721
+ EqualConv2d(
722
+ in_channel,
723
+ out_channel,
724
+ kernel_size,
725
+ padding=self.padding,
726
+ stride=stride,
727
+ bias=bias and not activate,
728
+ )
729
+ )
730
+
731
+ if activate:
732
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
733
+
734
+ super().__init__(*layers)
735
+
736
+
737
+ class ResBlock(nn.Module):
738
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
739
+ super().__init__()
740
+
741
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
742
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
743
+
744
+ self.skip = ConvLayer(
745
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
746
+ )
747
+
748
+ def forward(self, input):
749
+ out = self.conv1(input)
750
+ out = self.conv2(out)
751
+
752
+ skip = self.skip(input)
753
+ out = (out + skip) / math.sqrt(2)
754
+
755
+ return out
756
+
757
+
758
+ class Discriminator(nn.Module):
759
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
760
+ super().__init__()
761
+
762
+ channels = {
763
+ 4: 512,
764
+ 8: 512,
765
+ 16: 512,
766
+ 32: 512,
767
+ 64: 256 * channel_multiplier,
768
+ 128: 128 * channel_multiplier,
769
+ 256: 64 * channel_multiplier,
770
+ 512: 32 * channel_multiplier,
771
+ 1024: 16 * channel_multiplier,
772
+ }
773
+
774
+ convs = [ConvLayer(3, channels[size], 1)]
775
+
776
+ log_size = int(math.log(size, 2))
777
+
778
+ in_channel = channels[size]
779
+
780
+ for i in range(log_size, 2, -1):
781
+ out_channel = channels[2 ** (i - 1)]
782
+
783
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
784
+
785
+ in_channel = out_channel
786
+
787
+ self.convs = nn.Sequential(*convs)
788
+
789
+ self.stddev_group = 4
790
+ self.stddev_feat = 1
791
+
792
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
793
+ self.final_linear = nn.Sequential(
794
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
795
+ EqualLinear(channels[4], 1),
796
+ )
797
+
798
+ def forward(self, input):
799
+ out = self.convs(input)
800
+
801
+ batch, channel, height, width = out.shape
802
+ group = min(batch, self.stddev_group)
803
+ stddev = out.view(
804
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
805
+ )
806
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
807
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
808
+ stddev = stddev.repeat(group, 1, height, width)
809
+ out = torch.cat([out, stddev], 1)
810
+
811
+ out = self.final_conv(out)
812
+
813
+ out = out.view(batch, -1)
814
+ out = self.final_linear(out)
815
+
816
+ return out
817
+
op/__init__.py ADDED
File without changes
op/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (227 Bytes). View file
 
op/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (231 Bytes). View file
 
op/__pycache__/conv2d_gradfix.cpython-37.pyc ADDED
Binary file (5.23 kB). View file
 
op/__pycache__/conv2d_gradfix.cpython-38.pyc ADDED
Binary file (5.3 kB). View file
 
op/__pycache__/fused_act.cpython-37.pyc ADDED
Binary file (2.78 kB). View file