Arnaudding001 committed on
Commit
f13deb4
1 Parent(s): f32a850

Create train_vtoonify_t.py

Files changed (1)
  1. train_vtoonify_t.py +432 -0
train_vtoonify_t.py ADDED
@@ -0,0 +1,432 @@
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import math
import random

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from util import *
from model.stylegan import lpips
from model.stylegan.model import Generator, Downsample
from model.vtoonify import VToonify, ConditionalDiscriminator
from model.bisenet.model import BiSeNet
from model.simple_augment import random_apply_affine
from model.stylegan.distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)

# In the paper, --weight for each style is set as follows,
# cartoon: default
# caricature: default
# pixar: 1 1 1 1 1 1 1 1 1 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
# comic: 0.5 0.5 0.5 0.5 0.5 0.5 0.5 1 1 1 1 1 1 1 1 1 1 1
# arcane: 0.5 0.5 0.5 0.5 0.5 0.5 0.5 1 1 1 1 1 1 1 1 1 1 1

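# A hypothetical launch example (GPU count, port, and iteration budget are placeholders; the flags are
# the ones defined in TrainOptions below, with the script's own default checkpoint paths):
#   # stage 1: pretrain the encoder E
#   python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train_vtoonify_t.py \
#          --pretrain --name vtoonify_t_cartoon --finetunegan_path ./checkpoint/cartoon/finetune-000600.pt
#   # stage 2: train the full VToonify-T model
#   python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train_vtoonify_t.py \
#          --iter 2000 --name vtoonify_t_cartoon --finetunegan_path ./checkpoint/cartoon/finetune-000600.pt
#   For the other styles, pass --weight with the 18 values listed above.
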
class TrainOptions():
    def __init__(self):

        self.parser = argparse.ArgumentParser(description="Train VToonify-T")
        self.parser.add_argument("--iter", type=int, default=2000, help="total training iterations")
        self.parser.add_argument("--batch", type=int, default=8, help="batch size per GPU")
        self.parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
        self.parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
        self.parser.add_argument("--start_iter", type=int, default=0, help="start iteration")
        self.parser.add_argument("--save_every", type=int, default=30000, help="interval of saving a checkpoint")
        self.parser.add_argument("--save_begin", type=int, default=30000, help="when to start saving a checkpoint")
        self.parser.add_argument("--log_every", type=int, default=200, help="interval of saving an intermediate image result")

        self.parser.add_argument("--adv_loss", type=float, default=0.01, help="the weight of adversarial loss")
        self.parser.add_argument("--grec_loss", type=float, default=0.1, help="the weight of mse reconstruction loss")
        self.parser.add_argument("--perc_loss", type=float, default=0.01, help="the weight of perceptual loss")
        self.parser.add_argument("--tmp_loss", type=float, default=1.0, help="the weight of temporal consistency loss")

        self.parser.add_argument("--encoder_path", type=str, default=None, help="path to the pretrained encoder model")
        self.parser.add_argument("--direction_path", type=str, default='./checkpoint/directions.npy', help="path to the editing direction latents")
        self.parser.add_argument("--stylegan_path", type=str, default='./checkpoint/stylegan2-ffhq-config-f.pt', help="path to the stylegan model")
        self.parser.add_argument("--finetunegan_path", type=str, default='./checkpoint/cartoon/finetune-000600.pt', help="path to the finetuned stylegan model")
        self.parser.add_argument("--weight", type=float, nargs=18, default=[1]*9+[0]*9, help="the weight for blending two models")
        self.parser.add_argument("--faceparsing_path", type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")
        self.parser.add_argument("--style_encoder_path", type=str, default='./checkpoint/encoder.pt', help="path of the style encoder")

        self.parser.add_argument("--name", type=str, default='vtoonify_t_cartoon', help="saved model name")
        self.parser.add_argument("--pretrain", action="store_true", help="if true, only pretrain the encoder")

    def parse(self):
        self.opt = self.parser.parse_args()
        if self.opt.encoder_path is None:
            self.opt.encoder_path = os.path.join('./checkpoint/', self.opt.name, 'pretrain.pt')
        args = vars(self.opt)
        if self.opt.local_rank == 0:
            print('Load options')
            for name, value in sorted(args.items()):
                print('%s: %s' % (str(name), str(value)))
        return self.opt


# pretrain E of vtoonify.
# We train E so that its last-layer feature matches the original 8th-layer input feature of G1.
# See Model initialization in Sec. 4.1.2 for the details.
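# Written out (matching the recon_loss computed below):
#   L_E = || f_E(x'_down) - f_G1^(8)(w'') ||_2^2 + || skip_E - skip_G1 ||_2^2,
# i.e., an MSE on both the feature tensor and the skip tensor.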
def pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, basemodel, device):
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    recon_loss = torch.tensor(0.0, device=device)
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
    else:
        g_module = generator

    accum = 0.5 ** (32 / (10 * 1000))

    requires_grad(g_module.encoder, True)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        with torch.no_grad():
            # during pretraining, no geometric transformations are applied.
            noise_sample = torch.randn(args.batch, 512).cuda()
            ws_ = basemodel.style(noise_sample).unsqueeze(1).repeat(1,18,1) # random w
            ws_[:, 3:7] += directions[torch.randint(0, directions.shape[0], (args.batch,)), 3:7] # w''=w'=w+n
            img_gen, _ = basemodel([ws_], input_is_latent=True, truncation=0.5, truncation_latent=0) # image part of x'
            img_gen = torch.clamp(img_gen, -1, 1).detach()
            img_gen512 = down(img_gen.detach())
            img_gen256 = down(img_gen512.detach()) # image part of x'_down
            mask512 = parsingpredictor(2*torch.clamp(img_gen512, -1, 1))[0]
            real_input = torch.cat((img_gen256, down(mask512)/16.0), dim=1).detach() # x'_down
            # f_G1^(8)(w'')
            real_feat, real_skip = g_ema.generator([ws_], input_is_latent=True, return_feature_ind = 6, truncation=0.5, truncation_latent=0)
            real_feat = real_feat.detach()
            real_skip = real_skip.detach()

        # f_E^(last)(x'_down)
        fake_feat, fake_skip = generator(real_input, style=None, return_feat=True)

        # L_E in Eq.(1)
        recon_loss = F.mse_loss(fake_feat, real_feat) + F.mse_loss(fake_skip, real_skip)

        loss_dict["emse"] = recon_loss

        generator.zero_grad()
        recon_loss.backward()
        g_optim.step()

        accumulate(g_ema.encoder, g_module.encoder, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        emse_loss_val = loss_reduced["emse"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"iter: {i:d}; emse: {emse_loss_val:.3f}"
                )
            )

            if ((i+1) >= args.save_begin and (i+1) % args.save_every == 0) or (i+1) == args.iter:
                if (i+1) == args.iter:
                    savename = "checkpoint/%s/pretrain.pt" % (args.name)
                else:
                    savename = "checkpoint/%s/pretrain-%05d.pt" % (args.name, i+1)
                torch.save(
                    {
                        #"g": g_module.encoder.state_dict(),
                        "g_ema": g_ema.encoder.state_dict(),
                    },
                    savename,
                )


# generate paired data and train vtoonify, see Sec. 4.1.2 for the details
def train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, basemodel, device):
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, smoothing=0.01, ncols=120, dynamic_ncols=False)

    d_loss = torch.tensor(0.0, device=device)
    g_loss = torch.tensor(0.0, device=device)
    grec_loss = torch.tensor(0.0, device=device)
    gfeat_loss = torch.tensor(0.0, device=device)
    temporal_loss = torch.tensor(0.0, device=device)
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        ###### This part is for data generation. Generate pair (x, y, w'') as in Fig. 5 of the paper
        with torch.no_grad():
            noise_sample = torch.randn(args.batch, 512).cuda()
            wc = basemodel.style(noise_sample).unsqueeze(1).repeat(1,18,1) # random w
            wc[:, 3:7] += directions[torch.randint(0, directions.shape[0], (args.batch,)), 3:7] # w'=w+n
            wc = wc.detach()
            xc, _ = basemodel([wc], input_is_latent=True, truncation=0.5, truncation_latent=0)
            xc = torch.clamp(xc, -1, 1).detach() # x'
            xl = pspencoder(F.adaptive_avg_pool2d(xc, 256))
            xl = basemodel.style(xl.reshape(xl.shape[0]*xl.shape[1], xl.shape[2])).reshape(xl.shape) # E_s(x'_down)
            xl = torch.cat((wc[:,0:7]*0.5, xl[:,7:18]), dim=1).detach() # w'' = concatenate w' and E_s(x'_down)
            xs, _ = g_ema.generator([xl], input_is_latent=True)
            xs = torch.clamp(xs, -1, 1).detach() # y'
            # during training, random geometric transformations are applied.
            imgs, _ = random_apply_affine(torch.cat((xc.detach(),xs), dim=1), 0.2, None)
            real_input1024 = imgs[:,0:3].detach() # image part of x
            real_input512 = down(real_input1024).detach()
            real_input256 = down(real_input512).detach()
            mask512 = parsingpredictor(2*real_input512)[0]
            mask256 = down(mask512).detach()
            mask = F.adaptive_avg_pool2d(mask512, 1024).detach() # parsing part of x
            real_output = imgs[:,3:].detach() # y
            real_input = torch.cat((real_input256, mask256/16.0), dim=1) # x_down
            # for logging, sample a fixed input-output pair (x_down, y, w'')
            if idx == 0 or i == 0:
                samplein = real_input.clone().detach()
                sampleout = real_output.clone().detach()
                samplexl = xl.clone().detach()

        ###### This part is for training the discriminator

        requires_grad(g_module.encoder, False)
        requires_grad(g_module.fusion_out, False)
        requires_grad(g_module.fusion_skip, False)
        requires_grad(discriminator, True)

        fake_output = generator(real_input, xl)
        fake_pred = discriminator(F.adaptive_avg_pool2d(fake_output, 256))
        real_pred = discriminator(F.adaptive_avg_pool2d(real_output, 256))

        # L_adv in Eq.(3)
        d_loss = d_logistic_loss(real_pred, fake_pred) * args.adv_loss
        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        ###### This part is for training the generator (encoder and fusion modules)

        requires_grad(g_module.encoder, True)
        requires_grad(g_module.fusion_out, True)
        requires_grad(g_module.fusion_skip, True)
        requires_grad(discriminator, False)

        fake_output = generator(real_input, xl)
        fake_pred = discriminator(F.adaptive_avg_pool2d(fake_output, 256))
        # L_adv in Eq.(3)
        g_loss = g_nonsaturating_loss(fake_pred) * args.adv_loss
        # L_rec in Eq.(2)
        grec_loss = F.mse_loss(fake_output, real_output) * args.grec_loss
        gfeat_loss = percept(F.adaptive_avg_pool2d(fake_output, 512), # 1024 runs out of memory
                             F.adaptive_avg_pool2d(real_output, 512)).sum() * args.perc_loss # 256 gives blurry output

        loss_dict["g"] = g_loss
        loss_dict["gr"] = grec_loss
        loss_dict["gf"] = gfeat_loss

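        # Temporal-consistency term: a random 896x896 crop of the full-resolution input is toonified on
        # its own and compared against the same crop taken from the full-frame output, so the two paths
        # must agree. Its weight, max(idx/(args.iter/2)-1, 0), is 0 for the first half of training and
        # then ramps linearly up to args.tmp_loss at the final iteration.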
        w = random.randint(0,1024-896)
        h = random.randint(0,1024-896)
        crop_input = torch.cat((real_input1024[:,:,w:w+896,h:h+896], mask[:,:,w:w+896,h:h+896]/16.0), dim=1).detach()
        crop_input = down(down(crop_input))
        crop_fake_output = fake_output[:,:,w:w+896,h:h+896]
        fake_crop_output = generator(crop_input, xl)
        # L_tmp in Eq.(4), gradually increase the weight of L_tmp
        temporal_loss = ((fake_crop_output-crop_fake_output)**2).mean() * max(idx/(args.iter/2.0)-1, 0) * args.tmp_loss
        loss_dict["tp"] = temporal_loss

        generator.zero_grad()
        (g_loss + grec_loss + gfeat_loss + temporal_loss).backward()
        g_optim.step()

        accumulate(g_ema.encoder, g_module.encoder, accum)
        accumulate(g_ema.fusion_out, g_module.fusion_out, accum)
        accumulate(g_ema.fusion_skip, g_module.fusion_skip, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        gf_loss_val = loss_reduced["gf"].mean().item()
        tmp_loss_val = loss_reduced["tp"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"iter: {i:d}; advd: {d_loss_val:.3f}; advg: {g_loss_val:.3f}; mse: {gr_loss_val:.3f}; "
                    f"perc: {gf_loss_val:.3f}; tmp: {tmp_loss_val:.3f}"
                )
            )

            if i % args.log_every == 0 or (i+1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample = g_ema(samplein, samplexl)
                    sample = F.interpolate(torch.cat((sampleout, sample), dim=0), 256)
                    utils.save_image(
                        sample,
                        "log/%s/%05d.jpg" % (args.name, i),
                        nrow=int(args.batch),
                        normalize=True,
                        range=(-1, 1),
                    )

            if ((i+1) >= args.save_begin and (i+1) % args.save_every == 0) or (i+1) == args.iter:
                if (i+1) == args.iter:
                    savename = "checkpoint/%s/vtoonify.pt" % (args.name)
                else:
                    savename = "checkpoint/%s/vtoonify_%05d.pt" % (args.name, i+1)
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                    },
                    savename,
                )


if __name__ == "__main__":

    device = "cuda"
    parser = TrainOptions()
    args = parser.parse()
    if args.local_rank == 0:
        print('*'*98)
        if not os.path.exists("log/%s/"%(args.name)):
            os.makedirs("log/%s/"%(args.name))
        if not os.path.exists("checkpoint/%s/"%(args.name)):
            os.makedirs("checkpoint/%s/"%(args.name))

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    generator = VToonify(backbone = 'toonify').to(device)
    generator.apply(weights_init)
    g_ema = VToonify(backbone = 'toonify').to(device)
    g_ema.eval()

    basemodel = Generator(1024, 512, 8, 2).to(device) # G0
    finetunemodel = Generator(1024, 512, 8, 2).to(device)
    basemodel.load_state_dict(torch.load(args.stylegan_path, map_location=lambda storage, loc: storage)['g_ema'])
    finetunemodel.load_state_dict(torch.load(args.finetunegan_path, map_location=lambda storage, loc: storage)['g_ema'])
    fused_state_dict = blend_models(finetunemodel, basemodel, args.weight) # G1
    generator.generator.load_state_dict(fused_state_dict) # load G1
    g_ema.generator.load_state_dict(fused_state_dict)
    requires_grad(basemodel, False)
    requires_grad(generator.generator, False)
    requires_grad(g_ema.generator, False)
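    # blend_models (used above) comes from util. As a rough sketch of the assumed behavior (not the
    # actual helper), layer-wise blending interpolates the two generators' parameters per style layer, e.g.
    #   blended[k] = w_l * finetuned_state[k] + (1 - w_l) * base_state[k]
    # with w_l taken from args.weight for the layer that parameter k belongs to, so --weight controls how
    # much each resolution level of G1 follows the finetuned (toonified) model versus the FFHQ base G0.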

    if not args.pretrain:
        generator.encoder.load_state_dict(torch.load(args.encoder_path, map_location=lambda storage, loc: storage)["g_ema"])
        # we initialize the fusion modules to map f_G \otimes f_E to f_G.
        for k in generator.fusion_out:
            k.weight.data *= 0.01
            k.weight[:,0:k.weight.shape[0],1,1].data += torch.eye(k.weight.shape[0]).cuda()
        for k in generator.fusion_skip:
            k.weight.data *= 0.01
            k.weight[:,0:k.weight.shape[0],1,1].data += torch.eye(k.weight.shape[0]).cuda()
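    # In other words: each fusion convolution sees the concatenated channels [f_G, f_E]; shrinking its
    # weights by 0.01 and adding an identity on the first block of input channels at the 3x3 kernel
    # center makes the layer start out as (approximately) a pass-through of f_G, so training begins
    # close to the frozen G1.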

    accumulate(g_ema.encoder, generator.encoder, 0)
    accumulate(g_ema.fusion_out, generator.fusion_out, 0)
    accumulate(g_ema.fusion_skip, generator.fusion_skip, 0)

    g_parameters = list(generator.encoder.parameters())
    if not args.pretrain:
        g_parameters = g_parameters + list(generator.fusion_out.parameters()) + list(generator.fusion_skip.parameters())

    g_optim = optim.Adam(
        g_parameters,
        lr=args.lr,
        betas=(0.9, 0.99),
    )

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()
    requires_grad(parsingpredictor, False)

    # we apply a Gaussian blur to the images to avoid flickering caused by downsampling
    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device)
    requires_grad(down, False)
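    # The [1, 3, 3, 1] kernel is a small binomial blur applied before each 2x downsampling step (simple
    # anti-aliasing); pretrain() and train() reuse this `down` module to build the 512/256-resolution
    # images and parsing maps.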

    directions = torch.tensor(np.load(args.direction_path)).to(device)

    if not args.pretrain:
        discriminator = ConditionalDiscriminator(256).to(device)

        d_optim = optim.Adam(
            discriminator.parameters(),
            lr=args.lr,
            betas=(0.9, 0.99),
        )

        if args.distributed:
            discriminator = nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True,
            )

        percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda"), gpu_ids=[args.local_rank])
        requires_grad(percept.model.net, False)

        pspencoder = load_psp_standalone(args.style_encoder_path, device)

    if args.local_rank == 0:
        print('Models and data loaded successfully!')

    if args.pretrain:
        pretrain(args, generator, g_optim, g_ema, parsingpredictor, down, directions, basemodel, device)
    else:
        train(args, generator, discriminator, g_optim, d_optim, g_ema, percept, parsingpredictor, down, pspencoder, directions, basemodel, device)