Arnaudding001 committed
Commit f32a850
1 Parent(s): 3315f0a

Create train_vtoonify_d.py

Files changed (1)
  1. train_vtoonify_d.py +515 -0
train_vtoonify_d.py ADDED
@@ -0,0 +1,515 @@
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import math
import random

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from util import *

from model.stylegan import lpips
from model.stylegan.model import Generator, Downsample
from model.vtoonify import VToonify, ConditionalDiscriminator
from model.bisenet.model import BiSeNet
from model.simple_augment import random_apply_affine
from model.stylegan.distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)

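# Note: the wildcard import above is assumed to supply the helpers used throughout this
# script (requires_grad, accumulate, weights_init, load_psp_standalone, d_logistic_loss,
# g_nonsaturating_loss); none of them are defined in this file.
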
class TrainOptions():
    def __init__(self):

        self.parser = argparse.ArgumentParser(description="Train VToonify-D")
        self.parser.add_argument("--iter", type=int, default=2000, help="total training iterations")
        self.parser.add_argument("--batch", type=int, default=8, help="batch size for each GPU")
        self.parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
        self.parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
        self.parser.add_argument("--start_iter", type=int, default=0, help="start iteration")
        self.parser.add_argument("--save_every", type=int, default=30000, help="interval of saving a checkpoint")
        self.parser.add_argument("--save_begin", type=int, default=30000, help="when to start saving a checkpoint")
        self.parser.add_argument("--log_every", type=int, default=200, help="interval of saving intermediate sample images")

        self.parser.add_argument("--adv_loss", type=float, default=0.01, help="the weight of adversarial loss")
        self.parser.add_argument("--grec_loss", type=float, default=0.1, help="the weight of MSE reconstruction loss")
        self.parser.add_argument("--perc_loss", type=float, default=0.01, help="the weight of perceptual loss")
        self.parser.add_argument("--tmp_loss", type=float, default=1.0, help="the weight of temporal consistency loss")
        self.parser.add_argument("--msk_loss", type=float, default=0.0005, help="the weight of attention mask loss")

        self.parser.add_argument("--fix_degree", action="store_true", help="use a fixed style degree")
        self.parser.add_argument("--fix_style", action="store_true", help="use a fixed style image")
        self.parser.add_argument("--fix_color", action="store_true", help="use the original color (no color transfer)")
        self.parser.add_argument("--exstyle_path", type=str, default='./checkpoint/cartoon/refined_exstyle_code.npy', help="path of the extrinsic style code")
        self.parser.add_argument("--style_id", type=int, default=26, help="the id of the style image")
        self.parser.add_argument("--style_degree", type=float, default=0.5, help="style degree for VToonify-D")

        self.parser.add_argument("--encoder_path", type=str, default=None, help="path to the pretrained encoder model")
        self.parser.add_argument("--direction_path", type=str, default='./checkpoint/directions.npy', help="path to the editing direction latents")
        self.parser.add_argument("--stylegan_path", type=str, default='./checkpoint/cartoon/generator.pt', help="path to the stylegan model")
        self.parser.add_argument("--faceparsing_path", type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")
        self.parser.add_argument("--style_encoder_path", type=str, default='./checkpoint/encoder.pt', help="path of the style encoder")

        self.parser.add_argument("--name", type=str, default='vtoonify_d_cartoon', help="saved model name")
        self.parser.add_argument("--pretrain", action="store_true", help="if true, only pretrain the encoder")

    def parse(self):
        self.opt = self.parser.parse_args()
        if self.opt.encoder_path is None:
            self.opt.encoder_path = os.path.join('./checkpoint/', self.opt.name, 'pretrain.pt')
        args = vars(self.opt)
        if self.opt.local_rank == 0:
            print('Load options')
            for name, value in sorted(args.items()):
                print('%s: %s' % (str(name), str(value)))
        return self.opt

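# Example usage (illustrative sketch, not part of the original commit): the script reads
# WORLD_SIZE and --local_rank, so it can be launched through torch.distributed.launch,
# which sets both. GPU count, paths and options below are placeholders, not verified values.
#
#   # Stage 1: pretrain the encoder of VToonify-D
#   python -m torch.distributed.launch --nproc_per_node=8 train_vtoonify_d.py \
#       --name vtoonify_d_cartoon --pretrain \
#       --stylegan_path ./checkpoint/cartoon/generator.pt \
#       --exstyle_path ./checkpoint/cartoon/refined_exstyle_code.npy
#
#   # Stage 2: train the full VToonify-D (encoder + fusion modules + discriminator)
#   python -m torch.distributed.launch --nproc_per_node=8 train_vtoonify_d.py \
#       --name vtoonify_d_cartoon \
#       --stylegan_path ./checkpoint/cartoon/generator.pt \
#       --exstyle_path ./checkpoint/cartoon/refined_exstyle_code.npy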

# pretrain E of VToonify.
# We train E so that its last-layer feature matches the original 8th-layer input feature of G1.
# See "Model initialization" in Sec. 4.2.2 for details.
def pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, styles, device):
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    recon_loss = torch.tensor(0.0, device=device)
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
    else:
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))

    requires_grad(g_module.encoder, True)

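    # Note on the EMA copy: accum = 0.5 ** (32 / 10000) ~= 0.9978 is the usual StyleGAN2
    # per-step decay (a half-life of about 10k images at batch 32). accumulate() comes from
    # util and is assumed to perform the standard in-place parameter lerp, roughly
    #   p_ema = accum * p_ema + (1 - accum) * p
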
    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        # During pretraining, the last 11 layers of DualStyleGAN (for color transfer) are not used,
        # so args.fix_color has no effect and the last 11 elements of weight are ignored.
        if args.fix_degree:
            d_s = args.style_degree
        else:
            d_s = 0 if i <= args.iter / 4.0 else np.random.rand(1)[0]
        weight = [d_s] * 18

        # sample a pre-saved w''=E_s(s)
        if args.fix_style:
            style = styles[args.style_id:args.style_id+1].repeat(args.batch,1,1)
        else:
            style = styles[torch.randint(0, styles.size(0), (args.batch,))]

        with torch.no_grad():
            # during pretraining, no geometric transformations are applied.
            noise_sample = torch.randn(args.batch, 512).cuda()
            ws_ = g_ema.stylegan().style(noise_sample).unsqueeze(1).repeat(1,18,1) # random w
            ws_[:, 3:7] += directions[torch.randint(0, directions.shape[0], (args.batch,)), 3:7] # w'=w+n
            img_gen, _ = g_ema.stylegan()([ws_], input_is_latent=True, truncation=0.5, truncation_latent=0)
            img_gen = torch.clamp(img_gen, -1, 1).detach() # x''
            img_gen512 = down(img_gen.detach())
            img_gen256 = down(img_gen512.detach()) # image part of x''_down
            mask512 = parsingpredictor(2*torch.clamp(img_gen512, -1, 1))[0]
            real_input = torch.cat((img_gen256, down(mask512)/16.0), dim=1) # x''_down
            # f_G1^(8)(w', w'', d_s)
            real_feat, real_skip = g_ema.generator([ws_], style, input_is_latent=True, return_feat=True,
                                                   truncation=0.5, truncation_latent=0, use_res=True, interp_weights=weight)

        real_input = real_input.detach()
        real_feat = real_feat.detach()
        real_skip = real_skip.detach()

        # f_E^(last)(x''_down, w'', d_s)
        fake_feat, fake_skip = generator(real_input, style, d_s, return_feat=True)

        # L_E in Eq.(8)
        recon_loss = F.mse_loss(fake_feat, real_feat) + F.mse_loss(fake_skip, real_skip)

        loss_dict["emse"] = recon_loss

        generator.zero_grad()
        recon_loss.backward()
        g_optim.step()

        accumulate(g_ema.encoder, g_module.encoder, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        emse_loss_val = loss_reduced["emse"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"iter: {i:d}; emse: {emse_loss_val:.3f}"
                )
            )

            if ((i+1) >= args.save_begin and (i+1) % args.save_every == 0) or (i+1) == args.iter:
                if (i+1) == args.iter:
                    savename = "checkpoint/%s/pretrain.pt"%(args.name)
                else:
                    savename = "checkpoint/%s/pretrain-%05d.pt"%(args.name, i+1)
                torch.save(
                    {
                        #"g": g_module.encoder.state_dict(),
                        "g_ema": g_ema.encoder.state_dict(),
                    },
                    savename,
                )


# generate paired data and train VToonify, see Sec. 4.2.2 for details
def train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, styles, device):
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, smoothing=0.01, ncols=130, dynamic_ncols=False)

    d_loss = torch.tensor(0.0, device=device)
    g_loss = torch.tensor(0.0, device=device)
    grec_loss = torch.tensor(0.0, device=device)
    gfeat_loss = torch.tensor(0.0, device=device)
    temporal_loss = torch.tensor(0.0, device=device)
    gmask_loss = torch.tensor(0.0, device=device)
    loss_dict = {}

    # suffix of the saved checkpoint name, encoding the fixed style/degree/color settings
    surffix = '_s'
    if args.fix_style:
        surffix += '%03d'%(args.style_id)
    surffix += '_d'
    if args.fix_degree:
        surffix += '%1.1f'%(args.style_degree)
    if not args.fix_color:
        surffix += '_c'

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        # sample a style degree
        if args.fix_degree or idx == 0 or i == 0:
            d_s = args.style_degree
        else:
            d_s = np.random.randint(0,6) / 5.0
        if args.fix_color:
            weight = [d_s] * 7 + [0] * 11
        else:
            weight = [d_s] * 7 + [1] * 11
        # style degree condition for the discriminator
        degree_label = torch.zeros(args.batch, 1).to(device) + d_s

        # style index condition for the discriminator
        style_ind = torch.randint(0, styles.size(0), (args.batch,))
        if args.fix_style or idx == 0 or i == 0:
            style_ind = style_ind * 0 + args.style_id
        # sample a pre-saved E_s(s)
        style = styles[style_ind]

        with torch.no_grad():
            noise_sample = torch.randn(args.batch, 512).cuda()
            wc = g_ema.stylegan().style(noise_sample).unsqueeze(1).repeat(1,18,1) # random w
            wc[:, 3:7] += directions[torch.randint(0, directions.shape[0], (args.batch,)), 3:7] # w'=w+n
            wc = wc.detach()
            xc, _ = g_ema.stylegan()([wc], input_is_latent=True, truncation=0.5, truncation_latent=0)
            xc = torch.clamp(xc, -1, 1).detach() # x''
            if not args.fix_color and args.fix_style: # only transfer this fixed style's color
                xl = style.clone()
            else:
                xl = pspencoder(F.adaptive_avg_pool2d(xc, 256))
                xl = g_ema.zplus2wplus(xl) # E_s(x''_down)
            xl = torch.cat((style[:,0:7], xl[:,7:18]), dim=1).detach() # w'': concatenate E_s(s) and E_s(x''_down)
            xs, _ = g_ema.generator([wc], xl, input_is_latent=True,
                                    truncation=0.5, truncation_latent=0, use_res=True, interp_weights=weight)
            xs = torch.clamp(xs, -1, 1).detach() # y'=G1(w', w'', d_s, d_c)
            # apply color jitter to w': fuse w' of the current iteration with w' of the last iteration
            if idx > 0 and i >= (args.iter/2.0) and (not args.fix_color and not args.fix_style):
                wcfuse = wc.clone()
                wcfuse[:,7:] = wc_[:,7:] * (i/(args.iter/2.0)-1) + wcfuse[:,7:] * (2-i/(args.iter/2.0))
                xc, _ = g_ema.stylegan()([wcfuse], input_is_latent=True, truncation=0.5, truncation_latent=0)
                xc = torch.clamp(xc, -1, 1).detach() # x'
            wc_ = wc.clone() # wc_ is the w' of the last iteration
            # during training, random geometric transformations are applied.
            imgs, _ = random_apply_affine(torch.cat((xc.detach(),xs), dim=1), 0.2, None)
            real_input1024 = imgs[:,0:3].detach() # image part of x
            real_input512 = down(real_input1024).detach()
            real_input256 = down(real_input512).detach()
            mask512 = parsingpredictor(2*real_input512)[0]
            mask256 = down(mask512).detach()
            mask = F.adaptive_avg_pool2d(mask512, 1024).detach() # parsing part of x
            real_output = imgs[:,3:].detach() # y
            real_input = torch.cat((real_input256, mask256/16.0), dim=1) # x_down
            # for logging, keep a fixed input-output pair (x_down, y, w'', d_s)
            if idx == 0 or i == 0:
                samplein = real_input.clone().detach()
                sampleout = real_output.clone().detach()
                samplexl = xl.clone().detach()
                sampleds = d_s

        ###### This part is for training the discriminator

        requires_grad(g_module.encoder, False)
        requires_grad(g_module.fusion_out, False)
        requires_grad(g_module.fusion_skip, False)
        requires_grad(discriminator, True)

        fake_output = generator(real_input, xl, d_s)
        fake_pred = discriminator(F.adaptive_avg_pool2d(fake_output, 256), degree_label, style_ind)
        real_pred = discriminator(F.adaptive_avg_pool2d(real_output, 256), degree_label, style_ind)

        # L_adv in Eq.(3)
        d_loss = d_logistic_loss(real_pred, fake_pred) * args.adv_loss
        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

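        # The GAN losses used above and below come from util.py; they are assumed to be the
        # standard StyleGAN2 non-saturating logistic losses, roughly
        #   d_logistic_loss(r, f)   = softplus(-r).mean() + softplus(f).mean()
        #   g_nonsaturating_loss(f) = softplus(-f).mean()
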
        ###### This part is for training the generator (encoder and fusion modules)

        requires_grad(g_module.encoder, True)
        requires_grad(g_module.fusion_out, True)
        requires_grad(g_module.fusion_skip, True)
        requires_grad(discriminator, False)

        fake_output, m_Es = generator(real_input, xl, d_s, return_mask=True)
        fake_pred = discriminator(F.adaptive_avg_pool2d(fake_output, 256), degree_label, style_ind)

        # L_adv in Eq.(3)
        g_loss = g_nonsaturating_loss(fake_pred) * args.adv_loss
        # L_rec in Eq.(2)
        grec_loss = F.mse_loss(fake_output, real_output) * args.grec_loss
        gfeat_loss = percept(F.adaptive_avg_pool2d(fake_output, 512),    # 1024 runs out of memory
                             F.adaptive_avg_pool2d(real_output, 512)).sum() * args.perc_loss # 256 gives blurry outputs

        # L_msk in Eq.(9)
        gmask_loss = torch.tensor(0.0, device=device)
        if not args.fix_degree or args.msk_loss > 0:
            for jj, m_E in enumerate(m_Es):
                gd_s = (1 - d_s) ** 2 * 0.9 + 0.1
                gmask_loss += F.relu(torch.mean(m_E)-gd_s) * args.msk_loss

        loss_dict["g"] = g_loss
        loss_dict["gr"] = grec_loss
        loss_dict["gf"] = gfeat_loss
        loss_dict["msk"] = gmask_loss

        # random 896x896 crop used for the temporal/flicker consistency term
        w = random.randint(0,1024-896)
        h = random.randint(0,1024-896)
        crop_input = torch.cat((real_input1024[:,:,w:w+896,h:h+896], mask[:,:,w:w+896,h:h+896]/16.0), dim=1).detach()
        crop_input = down(down(crop_input))
        crop_fake_output = fake_output[:,:,w:w+896,h:h+896]
        fake_crop_output = generator(crop_input, xl, d_s)
        # L_tmp in Eq.(4); its weight is gradually increased during the second half of training
        temporal_loss = ((fake_crop_output-crop_fake_output)**2).mean() * max(idx/(args.iter/2.0)-1, 0) * args.tmp_loss
        loss_dict["tp"] = temporal_loss

        generator.zero_grad()
        (g_loss + grec_loss + gfeat_loss + temporal_loss + gmask_loss).backward()
        g_optim.step()

        accumulate(g_ema.encoder, g_module.encoder, accum)
        accumulate(g_ema.fusion_out, g_module.fusion_out, accum)
        accumulate(g_ema.fusion_skip, g_module.fusion_skip, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        gf_loss_val = loss_reduced["gf"].mean().item()
        tmp_loss_val = loss_reduced["tp"].mean().item()
        msk_loss_val = loss_reduced["msk"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"iter: {i:d}; advd: {d_loss_val:.3f}; advg: {g_loss_val:.3f}; mse: {gr_loss_val:.3f}; "
                    f"perc: {gf_loss_val:.3f}; tmp: {tmp_loss_val:.3f}; msk: {msk_loss_val:.3f}"
                )
            )

            if i == 0 or (i+1) % args.log_every == 0 or (i+1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample1 = g_ema(samplein, samplexl, sampleds)
                    if args.fix_degree:
                        sample = F.interpolate(torch.cat((sampleout, sample1), dim=0), 256)
                    else:
                        sample2 = g_ema(samplein, samplexl, d_s)
                        sample = F.interpolate(torch.cat((sampleout, sample1, sample2), dim=0), 256)
                    utils.save_image(
                        sample,
                        "log/%s/%05d.jpg"%(args.name, (i+1)),
                        nrow=int(args.batch),
                        normalize=True,
                        range=(-1, 1),
                    )

            if ((i+1) >= args.save_begin and (i+1) % args.save_every == 0) or (i+1) == args.iter:
                if (i+1) == args.iter:
                    savename = "checkpoint/%s/vtoonify%s.pt"%(args.name, surffix)
                else:
                    savename = "checkpoint/%s/vtoonify%s_%05d.pt"%(args.name, surffix, i+1)
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                    },
                    savename,
                )



if __name__ == "__main__":

    device = "cuda"
    parser = TrainOptions()
    args = parser.parse()
    if args.local_rank == 0:
        print('*'*98)
        if not os.path.exists("log/%s/"%(args.name)):
            os.makedirs("log/%s/"%(args.name))
        if not os.path.exists("checkpoint/%s/"%(args.name)):
            os.makedirs("checkpoint/%s/"%(args.name))

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    generator = VToonify(backbone = 'dualstylegan').to(device)
    generator.apply(weights_init)
    g_ema = VToonify(backbone = 'dualstylegan').to(device)
    g_ema.eval()

    ckpt = torch.load(args.stylegan_path, map_location=lambda storage, loc: storage)
    generator.generator.load_state_dict(ckpt["g_ema"], strict=False)
    # load the ModRes blocks of DualStyleGAN into the modified ModRes blocks (with dilation)
    generator.res.load_state_dict(generator.generator.res.state_dict(), strict=False)
    g_ema.generator.load_state_dict(ckpt["g_ema"], strict=False)
    g_ema.res.load_state_dict(g_ema.generator.res.state_dict(), strict=False)
    requires_grad(generator.generator, False)
    requires_grad(generator.res, False)
    requires_grad(g_ema.generator, False)
    requires_grad(g_ema.res, False)

    if not args.pretrain:
        generator.encoder.load_state_dict(torch.load(args.encoder_path, map_location=lambda storage, loc: storage)["g_ema"])
        # initialize the fusion modules to map f_G \otimes f_E to f_G: scale the weights down
        # and add an identity on the f_G input channels so each fusion layer initially passes
        # the StyleGAN feature through (almost) unchanged.
        for k in generator.fusion_out:
            k.conv.weight.data *= 0.01
            k.conv.weight[:,0:k.conv.weight.shape[0],1,1].data += torch.eye(k.conv.weight.shape[0]).cuda()
        for k in generator.fusion_skip:
            k.weight.data *= 0.01
            k.weight[:,0:k.weight.shape[0],1,1].data += torch.eye(k.weight.shape[0]).cuda()

    accumulate(g_ema.encoder, generator.encoder, 0)
    accumulate(g_ema.fusion_out, generator.fusion_out, 0)
    accumulate(g_ema.fusion_skip, generator.fusion_skip, 0)

    g_parameters = list(generator.encoder.parameters())
    if not args.pretrain:
        g_parameters = g_parameters + list(generator.fusion_out.parameters()) + list(generator.fusion_skip.parameters())

    g_optim = optim.Adam(
        g_parameters,
        lr=args.lr,
        betas=(0.9, 0.99),
    )

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()
    requires_grad(parsingpredictor, False)

    # apply a blur kernel when downsampling to avoid flickering artifacts caused by aliasing
    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device)
    requires_grad(down, False)

    directions = torch.tensor(np.load(args.direction_path)).to(device)

    # load the extrinsic style codes of DualStyleGAN
    exstyles = np.load(args.exstyle_path, allow_pickle='TRUE').item()
    if args.local_rank == 0 and not os.path.exists('checkpoint/%s/exstyle_code.npy'%(args.name)):
        np.save('checkpoint/%s/exstyle_code.npy'%(args.name), exstyles, allow_pickle=True)
    styles = []
    with torch.no_grad():
        for stylename in exstyles.keys():
            exstyle = torch.tensor(exstyles[stylename]).to(device)
            exstyle = g_ema.zplus2wplus(exstyle)
            styles += [exstyle]
    styles = torch.cat(styles, dim=0)

    if not args.pretrain:
        discriminator = ConditionalDiscriminator(256, use_condition=True, style_num = styles.size(0)).to(device)

        d_optim = optim.Adam(
            discriminator.parameters(),
            lr=args.lr,
            betas=(0.9, 0.99),
        )

        if args.distributed:
            discriminator = nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True,
            )

    percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda"), gpu_ids=[args.local_rank])
    requires_grad(percept.model.net, False)

    pspencoder = load_psp_standalone(args.style_encoder_path, device)

    if args.local_rank == 0:
        print('Models and data successfully loaded!')

    if args.pretrain:
        pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, styles, device)
    else:
        train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, styles, device)