import os
import argparse
import torch
from torchvision import utils
from model.sg2_model import Generator
from tqdm import tqdm
from pathlib import Path
import numpy as np
import subprocess
import shutil
import copy
from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor
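
# InterFaceGAN edit directions bundled with the repo. Each suggested distance is
# a signed multiplier along the corresponding boundary; negative values (e.g.
# hair_length) move against the boundary normal.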
VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]
SUGGESTED_DISTANCES = {
    "pose": 3.0,
    "smile": 2.0,
    "age": 4.0,
    "gender": 3.0,
    "hair_length": -4.0,
    "beard": 2.0,
}
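
# Move a latent code a given distance along an edit boundary. Boundaries stored
# as 2D arrays are reshaped so they broadcast over every W+ layer.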
def project_code(latent_code, boundary, distance=3.0):
    if boundary.ndim == 2:  # e.g. (1, 512) -> (1, 1, 512)
        boundary = boundary.reshape(1, 1, -1)

    return latent_code + distance * boundary
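
# Boundaries live in editing/interfacegan_boundaries, one {name}.pt tensor per
# edit in VALID_EDITS; strength scales the suggested distance for that edit.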
def project_code_by_edit_name(latent_code, name, strength):
    boundary_dir = Path(os.path.abspath(__file__)).parents[0].joinpath("editing", "interfacegan_boundaries")

    distance = SUGGESTED_DISTANCES[name] * strength
    boundary = torch.load(os.path.join(boundary_dir, f'{name}.pt'), map_location="cpu").numpy()

    return project_code(latent_code, boundary, distance)
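
# 9088 is the flattened StyleSpace (S-space) width of the 1024px StyleGAN2
# generator, so the second latent dimension distinguishes StyleCLIP S codes
# from W+ codes.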
def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    code_is_s = target_latents[0].size()[1] == 9088

    if code_is_s:
        source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
        np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
    else:
        np_latent = source_latent.squeeze(0).cpu().detach().numpy()

    np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]

    num_alphas = 20 if code_is_s else min(10, 30 // len(target_latents))
    alphas = np.linspace(0, 1, num=num_alphas)

    latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)
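
    # More than one generator means the video also blends between models: the
    # frames are split into len(g_ema_list) - 1 segments, and within each
    # segment the weights of the two neighboring generators are linearly mixed.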
    segments = len(g_ema_list) - 1

    if segments:
        segment_length = len(latents) / segments

        g_ema = copy.deepcopy(g_ema_list[0])
        src_pars = dict(g_ema.named_parameters())
        mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
    else:
        g_ema = g_ema_list[0]
print("Generating frames for video...")
for idx, latent in tqdm(enumerate(latents), total=len(latents)):
if segments:
mix_alpha = (idx % segment_length) * 1.0 / segment_length
segment_id = int(idx // segment_length)
for k in src_pars.keys():
src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)
if idx == 0 or segments or latent is not latents[idx - 1]:
latent_tensor = torch.from_numpy(latent).float().to(device)
with torch.no_grad():
if code_is_s:
latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
else:
img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)
utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))
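
# Each target contributes 3 * len(alphas) frames: a forward sweep toward the
# target, a hold at the target, and the reversed sweep back to the source.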
def interpolate_forward_backward(source_latent, target_latent, alphas):
    latents_forward = [a * target_latent + (1 - a) * source_latent for a in alphas]  # interpolate from source to target
    latents_backward = latents_forward[::-1]  # interpolate from target to source

    return latents_forward + [target_latent] * len(alphas) + latents_backward  # forward + short delay at target + return
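
# Targets are visited in order, so a multi-edit video plays the edits
# back-to-back, returning to the source latent between them.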
def interpolate_with_target_latents(source_latent, target_latents, alphas):
    # interpolate latent codes with all targets
    print("Interpolating latent codes...")

    latents = []
    for target_latent in target_latents:
        latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))

    return latents
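
# ffmpeg reads 000.jpg, 001.jpg, ... (its default -start_number is 0), which is
# why generate_frames writes a frame for every index.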
def video_from_interpolations(fps, output_dir):
    # combine frames to a video
    command = ["ffmpeg",
               "-r", f"{fps}",
               "-i", f"{output_dir}/%03d.jpg",
               "-c:v", "libx264",
               "-vf", f"fps={fps}",
               "-pix_fmt", "yuv420p",
               f"{output_dir}/out.mp4"]

    subprocess.call(command)
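
# Minimal usage sketch (not part of the original script; assumes g_ema is a
# loaded Generator and source is a (1, 18, 512) W+ torch tensor):
#
#   target_np = project_code_by_edit_name(source.cpu().numpy(), "age", strength=1.0)
#   generate_frames(source, [torch.from_numpy(target_np)], [g_ema], "frames")
#   video_from_interpolations(15, "frames")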