# text2live/models/video_model.py
import random
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from models.image_model import Model
class VideoModel(Model):
    """Applies the image generator to video-atlas crops and maps edits back to frames.

    Extends the base image ``Model`` with UV grid-sampling so that an edit
    computed on a texture-atlas crop can be rendered onto the corresponding
    per-frame video crops (layered-neural-atlas style representation,
    presumably — confirm against the dataset/caller).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Identity preprocessing: video crops are fed to the network as-is.
        self.net_preprocess = transforms.Compose([])

    @staticmethod
    def resize_crops(crops, resize_factor):
        """Downscale ``crops`` (a [..., H, W] tensor) by an integer ``resize_factor``.

        Output spatial size is floor(H / f) x floor(W / f), using antialiased
        bilinear interpolation.
        """
        return torchvision.transforms.functional.resize(
            crops,
            [
                crops.shape[-2] // resize_factor,
                crops.shape[-1] // resize_factor,
            ],
            InterpolationMode.BILINEAR,
            antialias=True,
        )

    def process_crops(self, uv_values, crops, original_crops, alpha=None):
        """Edit an atlas crop and sample the edit onto the per-frame crops.

        Args:
            uv_values: per-frame UV sampling grids (one grid per frame),
                consumed by ``F.grid_sample`` to map atlas space to frame space.
            crops: sequence whose first element is the atlas crop to edit.
            original_crops: per-frame original RGB crops; one randomly chosen
                frame is also passed through the generator (reconstruction branch).
            alpha: optional per-frame opacity maps. When given, sampled frame
                edits and the reconstruction CNN output are multiplied by it,
                and the green-screen composite fills the transparent regions.

        Returns:
            Tuple ``(render_dict, resized_crops, cnn_output_crops)`` where
            ``render_dict`` maps each render key ("edit", "alpha",
            "edit_on_greenscreen", "composite") to a list of per-frame tensors.
        """
        resized_crops = []
        cnn_output_crops = []
        render_dict = {"edit": [], "alpha": [], "edit_on_greenscreen": [], "composite": []}
        # Generalized from a hard-coded 3: use however many frame grids we got.
        n_frames = len(uv_values)
        atlas_crop = crops[0]
        for i in range(n_frames):
            grid_sampled_atlas_crop = F.grid_sample(
                atlas_crop,
                uv_values[i],
                mode="bilinear",
                align_corners=self.config["align_corners"],
            ).clamp(min=0.0, max=1.0)
            resized_crops.append(grid_sampled_atlas_crop)
        cnn_output = self.netG(atlas_crop)
        cnn_output_crops.append(cnn_output[:, :3])  # keep RGB channels only
        rendered_atlas_crops = self.render(cnn_output, bg_image=atlas_crop)
        for key, value in rendered_atlas_crops.items():
            for i in range(n_frames):
                sampled_frame_crop = F.grid_sample(
                    value,
                    uv_values[i],
                    mode="bilinear",
                    align_corners=self.config["align_corners"],
                ).clamp(min=0.0, max=1.0)
                if alpha is not None:
                    sampled_frame_crop = sampled_frame_crop * alpha[i]
                if key == "edit_on_greenscreen":
                    # Composite the edit over a green screen, RGB = (0, 177, 64)/255.
                    # zeros_like already matches device/dtype, so the former
                    # redundant .to(device) call was dropped.
                    greenscreen = torch.zeros_like(sampled_frame_crop)
                    greenscreen[:, 1, :, :] = 177 / 255
                    greenscreen[:, 2, :, :] = 64 / 255
                    if alpha is not None:
                        # Fill transparent regions with green. Previously this
                        # line raised TypeError when alpha was None (e.g. the
                        # background path); now it is skipped in that case.
                        sampled_frame_crop += (1 - alpha[i]) * greenscreen
                render_dict[key].append(sampled_frame_crop.squeeze(0))
        # Reconstruction branch: pass one randomly chosen original frame crop
        # through the network as well.
        frame_index = random.randint(0, n_frames - 1)
        rec_crop = original_crops[frame_index]
        resized_crops.append(rec_crop)
        cnn_output = self.netG(rec_crop)
        if alpha is not None:
            cnn_output = cnn_output * alpha[frame_index]
        cnn_output_crops.append(cnn_output[:, :3])
        rendered_frame_crop = self.render(cnn_output, bg_image=rec_crop)
        for key, value in rendered_frame_crop.items():
            render_dict[key].append(value.squeeze(0))
        return render_dict, resized_crops, cnn_output_crops

    def process_atlas(self, atlas):
        """Edit a full atlas image and render the edit over the atlas itself."""
        atlas_edit = self.netG(atlas)
        rendered_atlas = self.render(atlas_edit, bg_image=atlas)
        return rendered_atlas

    def forward(self, input_dict):
        """Run the model on ``input_dict["global_crops"]``.

        ``config["finetune_foreground"]`` / ``config["finetune_background"]``
        select which atlas layer is edited (foreground takes precedence).
        Returns a dict with "foreground" and "background" sub-dicts holding
        "output_crop", "cnn_inputs", "cnn_outputs" and, when "input_image" is
        present in ``input_dict``, the rendered full atlas as "output_image".
        """
        inputs = input_dict["global_crops"]
        outputs = {"background": {}, "foreground": {}}
        if self.config["finetune_foreground"]:
            # Alpha multiplication is optional and only meaningful for the
            # foreground layer.
            alpha = inputs["foreground_alpha"] if self.config["multiply_foreground_alpha"] else None
            foreground_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["foreground_uvs"],
                inputs["foreground_atlas_crops"],
                inputs["original_foreground_crops"],
                alpha=alpha,
            )
            outputs["foreground"]["output_crop"] = foreground_outputs
            outputs["foreground"]["cnn_inputs"] = resized_crops
            outputs["foreground"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["foreground"]["output_image"] = self.process_atlas(input_dict["input_image"])
        elif self.config["finetune_background"]:
            background_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["background_uvs"],
                inputs["background_atlas_crops"],
                inputs["original_background_crops"],
            )
            outputs["background"]["output_crop"] = background_outputs
            outputs["background"]["cnn_inputs"] = resized_crops
            outputs["background"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["background"]["output_image"] = self.process_atlas(input_dict["input_image"])
        return outputs