import os
import argparse

import torch
from torchvision import utils

from model.sg2_model import Generator
from tqdm import tqdm
from pathlib import Path

import numpy as np

import subprocess
import shutil
import copy

from styleclip.styleclip_global import style_tensor_to_style_dict, style_dict_to_style_tensor

VALID_EDITS = ["pose", "age", "smile", "gender", "hair_length", "beard"]

SUGGESTED_DISTANCES = {
                       "pose": 3.0,
                       "smile": 2.0,
                       "age": 4.0,
                       "gender": 3.0,
                       "hair_length": -4.0,
                       "beard": 2.0
                      }
                    
def project_code(latent_code, boundary, distance=3.0):
    # InterfaceGAN-style edit: shift the latent code along a semantic
    # boundary direction by `distance` units.
    if boundary.ndim == 2:
        boundary = boundary.reshape(1, 1, -1)

    return latent_code + distance * boundary
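
# A (1, 512) boundary reshaped to (1, 1, 512) broadcasts across all layers of
# a (1, n_latents, 512) W+ code, so every layer receives the same offset. The
# sign of `distance` picks the direction of travel along the boundary, which
# is why SUGGESTED_DISTANCES["hair_length"] is negative.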

def project_code_by_edit_name(latent_code, name, strength):
    # Pre-computed InterfaceGAN boundaries ship with the repo.
    boundary_dir = Path(__file__).resolve().parent.joinpath("editing", "interfacegan_boundaries")

    distance = SUGGESTED_DISTANCES[name] * strength
    boundary = torch.load(boundary_dir.joinpath(f"{name}.pt"), map_location="cpu").numpy()

    return project_code(latent_code, boundary, distance)
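
# Usage sketch (illustrative; `w_plus` is a hypothetical inverted latent of
# shape (1, n_latents, 512) obtained elsewhere in the pipeline):
#
#   older = project_code_by_edit_name(w_plus, "age", strength=1.0)
#   subtle_smile = project_code_by_edit_name(w_plus, "smile", strength=0.5)
#
# `strength` scales the suggested distance for the chosen edit; negative
# values reverse the edit direction.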

def generate_frames(source_latent, target_latents, g_ema_list, output_dir):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A flattened StyleCLIP S-space (style space) code for StyleGAN2 at 1024px
    # resolution has 9088 channels in total, which distinguishes it from W/W+
    # codes (512 channels per layer).
    code_is_s = target_latents[0].size()[1] == 9088

    if code_is_s:
        # Convert the source W+ code to the same flattened S-space layout as
        # the targets before interpolating.
        source_s_dict = g_ema_list[0].get_s_code(source_latent, input_is_latent=True)[0]
        np_latent = style_dict_to_style_tensor(source_s_dict, g_ema_list[0]).cpu().detach().numpy()
    else:
        np_latent = source_latent.squeeze(0).cpu().detach().numpy()

    np_target_latents = [target_latent.cpu().detach().numpy() for target_latent in target_latents]

    # S-space videos use a fixed 20 interpolation steps; W+ videos use at most
    # 10, reduced further when many targets are edited so the total video
    # length stays manageable.
    num_alphas = 20 if code_is_s else min(10, 30 // len(target_latents))

    alphas = np.linspace(0, 1, num=num_alphas)
    
    latents = interpolate_with_target_latents(np_latent, np_target_latents, alphas)

    # With several generators, the video cross-fades between consecutive
    # checkpoints: one segment of frames per adjacent pair.
    segments = len(g_ema_list) - 1

    if segments:
        segment_length = len(latents) / segments

        # Blend into a deep copy so the source checkpoints stay untouched.
        g_ema = copy.deepcopy(g_ema_list[0])

        src_pars = dict(g_ema.named_parameters())
        mix_pars = [dict(model.named_parameters()) for model in g_ema_list]
    else:
        g_ema = g_ema_list[0]

    print("Generating frames for video...")
    for idx, latent in tqdm(enumerate(latents), total=len(latents)):

        if segments:
            # Cross-fade between the two generator checkpoints that bracket
            # this segment by linearly blending their weights.
            mix_alpha = (idx % segment_length) * 1.0 / segment_length
            segment_id = int(idx // segment_length)

            for k in src_pars.keys():
                src_pars[k].data.copy_(mix_pars[segment_id][k] * (1 - mix_alpha) + mix_pars[segment_id + 1][k] * mix_alpha)

        # Only re-generate when the latent (or the blended generator) changed;
        # during the hold at a target, consecutive entries are the same object,
        # so the previous image is simply saved again.
        if idx == 0 or segments or latent is not latents[idx - 1]:
            latent_tensor = torch.from_numpy(latent).float().to(device)

            with torch.no_grad():
                if code_is_s:
                    latent_for_gen = style_tensor_to_style_dict(latent_tensor, g_ema)
                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
                else:
                    img, _ = g_ema([latent_tensor], input_is_latent=True, truncation=1, randomize_noise=False)

        # Note: newer torchvision releases renamed save_image's `range` keyword
        # to `value_range`; adjust the keyword if this call raises a TypeError.
        utils.save_image(img, f"{output_dir}/{str(idx).zfill(3)}.jpg", nrow=1, normalize=True, scale_each=True, range=(-1, 1))

def interpolate_forward_backward(source_latent, target_latent, alphas):
    latents_forward  = [a * target_latent + (1-a) * source_latent for a in alphas] # interpolate from source to target
    latents_backward = latents_forward[::-1]                                       # interpolate from target to source
    return latents_forward + [target_latent] * len(alphas) + latents_backward      # forward + short delay at target + return
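
# With num_alphas = 10, for example, each target contributes 30 frames:
# 10 interpolating toward the target, a 10-frame hold at the target, and
# 10 returning to the source.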

def interpolate_with_target_latents(source_latent, target_latents, alphas):    
    # interpolate latent codes with all targets

    print("Interpolating latent codes...")
    
    latents = []
    for target_latent in target_latents:
        latents.extend(interpolate_forward_backward(source_latent, target_latent, alphas))

    return latents

def video_from_interpolations(fps, output_dir):

    # Combine the saved frames into an H.264 video; check=True surfaces ffmpeg
    # failures instead of letting them pass silently.
    command = ["ffmpeg",
               "-r", f"{fps}",
               "-i", f"{output_dir}/%03d.jpg",
               "-c:v", "libx264",
               "-vf", f"fps={fps}",
               "-pix_fmt", "yuv420p",
               f"{output_dir}/out.mp4"]

    subprocess.run(command, check=True)
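
# For reference, the assembled command is equivalent to running, e.g. with
# fps=25:
#
#   ffmpeg -r 25 -i <output_dir>/%03d.jpg -c:v libx264 -vf fps=25 \
#       -pix_fmt yuv420p <output_dir>/out.mp4
#
# where -r sets the frame rate of the input image sequence and yuv420p keeps
# the output playable in common players.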