Spaces:
Runtime error
Runtime error
File size: 2,772 Bytes
54bf1bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import numpy as np
import imageio
from sync_batchnorm import DataParallelWithCallback
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Reconstruct each video in ``dataset`` from its first frame and report the mean L1 error.

    For every video, frame 0 is used as the source image. Each frame of the
    same video (frame 0 included) is used in turn as the driving frame: key
    points are detected on it and the generator is asked to re-create it from
    the source. Two artifacts are written per video:

    * ``<log_dir>/reconstruction/png/<name>.png`` -- all predicted frames
      concatenated along axis 1 into one image;
    * ``<log_dir>/reconstruction/<name><format>`` -- the per-frame
      visualizations produced by ``Visualizer.visualize``.

    Args:
        config: Configuration mapping. Reads
            ``config['reconstruction_params']['num_videos']`` (maximum number
            of videos to process, or ``None`` for no limit),
            ``config['reconstruction_params']['format']`` (suffix appended to
            the video name for the visualization file) and
            ``config['visualizer_params']`` (keyword arguments for
            ``Visualizer``).
        generator: Module called as
            ``generator(source, kp_source=..., kp_driving=...)``; its result
            must be a dict containing ``'prediction'`` and
            ``'sparse_deformed'``.
        kp_detector: Module called on a single frame batch to obtain key points.
        checkpoint: Checkpoint handed to ``Logger.load_cpk``; must not be ``None``.
        log_dir: Base output directory; results go to its ``reconstruction``
            sub-directory.
        dataset: Dataset whose items provide ``'video'`` and ``'name'``.
            The frame axis of ``'video'`` is dimension 2 after batching
            (presumably batch x channel x frame x height x width -- confirm
            against the dataset class).

    Raises:
        AttributeError: If ``checkpoint`` is ``None``.
    """
    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    # Values that do not change between videos/frames are resolved once here
    # instead of on every iteration.
    reconstruction_params = config['reconstruction_params']
    num_videos = reconstruction_params['num_videos']
    visualizer = Visualizer(**config['visualizer_params'])

    loss_list = []
    for it, x in tqdm(enumerate(dataloader)):
        # `>=` (not `>`): `it` is zero-based, so this stops after exactly
        # `num_videos` videos rather than `num_videos + 1`.
        if num_videos is not None and it >= num_videos:
            break

        with torch.no_grad():
            video = x['video'].cuda() if use_cuda else x['video']

            # The first frame is the fixed source for the whole video.
            source = video[:, :, 0]
            kp_source = kp_detector(source)

            predictions = []
            visualizations = []
            for frame_idx in range(video.shape[2]):
                driving = video[:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                # Removed before `out` is handed to the visualizer.
                del out['sparse_deformed']

                # Move channels last and keep the single batch item.
                prediction = out['prediction'].detach().cpu().numpy()
                predictions.append(np.transpose(prediction, [0, 2, 3, 1])[0])

                visualizations.append(
                    visualizer.visualize(source=source, driving=driving, out=out))

                # Mean absolute error between the reconstruction and the real frame.
                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

        video_name = x['name'][0]

        # All predicted frames joined along axis 1 into a single strip image.
        predictions = np.concatenate(predictions, axis=1)
        imageio.imsave(os.path.join(png_dir, video_name + '.png'),
                       (255 * predictions).astype(np.uint8))

        image_name = video_name + reconstruction_params['format']
        imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
|