Spaces:
Running
on
Zero
Running
on
Zero
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler
from einops import repeat, rearrange
from PIL import Image, ImageFilter
from safetensors.torch import load_file as load_safetensors
from torchvision.transforms import ToTensor
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration

from diffusion_trainer.abstract_trainer import AbstractTrainer
from models.diffusion.wrappers import StreamingWrapper
from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion.inference_params import InferenceParams
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
from modules.params.i2v_enhance import I2VEnhanceParams
from modules.params.vfi import VFIParams
from utils import result_processor
from utils.inference_utils import resize_and_crop, get_padding_for_aspect_ratio
from utils.loader import download_ckpt


class StreamingSVD(AbstractTrainer):
    """
    Image-to-video trainer that extends a first SVD chunk autoregressively.

    The first chunk is produced by `self.svd_pipeline`; every further chunk is
    sampled with a `StreamingWrapper` (base diffusion model + control model)
    conditioned on the last frames of the previously generated chunk.
    Heavy sub-models are moved to `self.device` only while they are needed and
    are moved back to the CPU afterwards.
    """

    def __init__(self,
                 module_loader: GenericModuleLoader,
                 diff_trainer_params: DiffusionTrainerParams,
                 inference_params: InferenceParams,
                 vfi: VFIParams,
                 i2v_enhance: I2VEnhanceParams,
                 ):
        """
        Set up the trainer and store the post-processing parameter objects.
        Parameters:
        - module_loader: GenericModuleLoader
            Forwarded to the base class.
        - diff_trainer_params: DiffusionTrainerParams
            Forwarded to the base class.
        - inference_params: InferenceParams
            Forwarded to the base class.
        - vfi: VFIParams
            Stored as `self.vfi`.
        - i2v_enhance: I2VEnhanceParams
            Stored as `self.i2v_enhance`.
        """
        super().__init__(inference_params=inference_params,
                         diff_trainer_params=diff_trainer_params,
                         module_loader=module_loader,
                         )
        # network config is wrapped by OpenAIWrapper, so we dont need a direct reference anymore
        # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules
        del self.network_config
        # annotation only (no assignment): narrows the attribute type for type checkers
        self.diff_trainer_params: DiffusionTrainerParams
        self.vfi = vfi
        self.i2v_enhance = i2v_enhance

    def on_inference_epoch_start(self):
        """
        Build the model wrapper used for streaming inference.
        """
        super().on_inference_epoch_start()
        # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model.
        self.inference_model = StreamingWrapper(
            diffusion_model=self.model.diffusion_model,
            controlnet=self.controlnet,
            num_frame_conditioning=self.inference_params.num_conditional_frames
        )

    def post_init(self):
        """
        Prepare the auxiliary models: SVD pipeline offloading, the inpainting
        pipeline (`self.inpaint_pipe`) and the BLIP captioner (`self.blip`).
        """
        self.svd_pipeline.set_progress_bar_config(disable=True)
        # model offloading needs an accelerator; skip it when running on CPU
        use_cpu_offload = self.device.type != "cpu"
        if use_cpu_offload:
            self.svd_pipeline.enable_model_cpu_offload(gpu_id=self.device.index)

        # re-use the open clip already loaded for image conditioner for image_encoder_apm
        for embedder in self.conditioner.embedders:
            if getattr(embedder, "input_key", None) == "cond_frames_without_noise":
                self.image_encoder_apm = embedder.open_clip

        # keep the large sub-models on the CPU until they are actually used
        self.first_stage_model.to("cpu")
        self._move_conditioner_encoders("cpu")

        pipe = AutoPipelineForInpainting.from_pretrained(
            'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16",
            safety_checker=None, requires_safety_checker=False)
        pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(self.device)
        # same guard as for the SVD pipeline above
        if use_cpu_offload:
            pipe.enable_model_cpu_offload(gpu_id=self.device.index)
        self.inpaint_pipe = pipe

        processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large")
        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device)
        caption_device = self.device

        def blip(x):
            # inputs must live on the same device as the captioning model
            inputs = processor(x, return_tensors='pt').to(caption_device, torch.float16)
            token_ids = model.generate(**inputs)[0]
            return processor.decode(token_ids, skip_special_tokens=True)
        self.blip = blip

    def _move_conditioner_encoders(self, device):
        """Move the two conditioner encoders used for image conditioning to `device`."""
        # NOTE(review): the embedder positions (3 and 0) are hard-coded; they depend on the conditioner config
        self.conditioner.embedders[3].encoder.to(device)
        self.conditioner.embedders[0].open_clip.to(device)

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_unique_embedder_keys_from_conditioner(self, conditioner):
        """
        Collect the distinct `input_key` values of all embedders of `conditioner`.
        Returns:
        - list
            The unique keys (order is not defined).
        """
        return list({embedder.input_key for embedder in conditioner.embedders})

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_batch_sgm(self, keys, value_dict, N, T, device) -> Tuple[dict, dict]:
        """
        Build the conditional and unconditional batch dictionaries for the conditioner.
        Parameters:
        - keys: iterable of str
            The embedder input keys to fill.
        - value_dict: dict
            Source values, looked up by key.
        - N: list of int
            Batch dimensions; scalar conditions are repeated `prod(N)` times,
            frame conditions `N[0]` times.
        - T: int or None
            If not None, stored under "num_video_frames".
        - device:
            Device on which the scalar conditions are created.
        Returns:
        - tuple (batch, batch_uc)
            `batch_uc` holds a clone of every tensor entry of `batch`.
        """
        n_total = int(math.prod(N))
        batch = {}
        for key in keys:
            if key in ("fps_id", "motion_bucket_id"):
                batch[key] = torch.tensor([value_dict[key]]).to(device).repeat(n_total)
            elif key == "cond_aug":
                batch[key] = repeat(
                    torch.tensor([value_dict["cond_aug"]]).to(device),
                    "1 -> b",
                    b=math.prod(N),
                )
            elif key in ("cond_frames", "cond_frames_without_noise"):
                batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
            else:
                batch[key] = value_dict[key]
        if T is not None:
            batch["num_video_frames"] = T
        # the unconditional batch starts as a copy of all tensor-valued entries
        batch_uc = {
            key: torch.clone(value)
            for key, value in batch.items()
            if isinstance(value, torch.Tensor)
        }
        return batch, batch_uc

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py
    def decode_first_stage(self, z):
        """
        Decode latents with the first stage model, at most 8 samples at a time.
        Parameters:
        - z: torch.Tensor
            Latents; the first dimension is split into decoding rounds.
        Returns:
        - torch.Tensor
            The decoded outputs of all rounds, concatenated along dim 0.
        Raises:
        - ValueError
            If `z` has no samples.
        """
        if z.shape[0] == 0:
            raise ValueError("decode_first_stage received an empty latent tensor.")
        self.first_stage_model.to(self.device)
        try:
            z = 1.0 / self.diff_trainer_params.scale_factor * z
            n_samples = min(z.shape[0], 8)
            n_rounds = math.ceil(z.shape[0] / n_samples)
            is_video_decoder = isinstance(self.first_stage_model.decoder, VideoDecoder)
            all_out = []
            with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast):
                for n in range(n_rounds):
                    z_chunk = z[n * n_samples: (n + 1) * n_samples]
                    # the video decoder needs to know how many frames are in this round
                    kwargs = {"timesteps": len(z_chunk)} if is_video_decoder else {}
                    all_out.append(self.first_stage_model.decode(z_chunk, **kwargs))
            return torch.cat(all_out, dim=0)
        finally:
            # free accelerator memory even if decoding fails
            self.first_stage_model.to("cpu")

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params):
        """
        Sample one video chunk conditioned on an input frame and on control frames.
        Parameters:
        - svd_input_frame: torch.Tensor
            The conditioning frame; its last two dimensions are used as height and width.
        - inference_params: InferenceParams
            Unused here; `self.inference_params` is read instead.
        - params:
            Must contain "ctrl_frames" (forwarded to the model). Other entries are ignored.
        Returns:
        - torch.Tensor
            The decoded samples, clamped to [-1, 1].
        """
        C = 4
        F = 8  # spatial compression TODO read from model
        H = svd_input_frame.shape[-2]
        W = svd_input_frame.shape[-1]
        num_frames = self.sampler.guider.num_frames
        shape = (num_frames, C, H // F, W // F)
        batch_size = 1
        image = svd_input_frame[None, :]
        cond_aug = 0.02
        value_dict = {
            "motion_bucket_id": 127,
            "fps_id": 6,
            "cond_aug": cond_aug,
            "cond_frames_without_noise": image,
            # NOTE(review): uniform noise (rand_like); the referenced SVD sampling script
            # uses Gaussian noise (randn_like) - confirm this is intended
            "cond_frames": image + cond_aug * torch.rand_like(image),
        }
        batch, batch_uc = self.get_batch_sgm(
            self.get_unique_embedder_keys_from_conditioner(self.conditioner),
            value_dict,
            [1, num_frames],
            T=num_frames,
            device=self.device,
        )

        self._move_conditioner_encoders(self.device)
        try:
            c, uc = self.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )
        finally:
            self._move_conditioner_encoders("cpu")

        # expand the per-video conditions to one entry per frame
        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        randn = torch.randn(shape, device=self.device)
        additional_model_inputs = {
            "image_only_indicator": torch.zeros(2 * batch_size, num_frames).to(self.device),
            "num_video_frames": batch["num_video_frames"],
            # StreamingSVD inputs
            "batch_size": 2 * batch_size,
            "num_conditional_frames": self.inference_params.num_conditional_frames,
            "ctrl_frames": params["ctrl_frames"],
        }

        def denoiser(input, sigma, c):
            return self.denoiser(self.inference_model, input, sigma, c, **additional_model_inputs)

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to(
            self.device)
        self.inference_model.controlnet = self.inference_model.controlnet.to(
            self.device)
        try:
            c["vector"] = c["vector"].to(randn.dtype)
            uc["vector"] = uc["vector"].to(randn.dtype)
            samples_z = self.sampler(denoiser, randn, cond=c, uc=uc)
        finally:
            # release accelerator memory before decoding, also when sampling fails
            self.inference_model.diffusion_model = self.inference_model.diffusion_model.to("cpu")
            self.inference_model.controlnet = self.inference_model.controlnet.to("cpu")

        samples_x = self.decode_first_stage(samples_z)
        return torch.clamp(samples_x, min=-1.0, max=1.0)

    def _to_batched_video(self, video: torch.Tensor, input_range: List[int]) -> torch.Tensor:
        """
        Convert a video to value range [-1, 1] and to the layout "1 F C W H".
        Accepts channels-first ("F C W H") or channels-last ("F W H C") input
        with more than 3 frames; the input tensor is not modified.
        Raises:
        - NotImplementedError
            If the layout cannot be recognized.
        """
        video = result_processor.convert_range(
            video=video.clone(), input_range=input_range, output_range=[-1, 1])
        if video.shape[1] == 3 and video.shape[0] > 3:
            return rearrange(video, "F C W H -> 1 F C W H")
        if video.shape[0] > 3 and video.shape[-1] == 3:
            return rearrange(video, "F W H C -> 1 F C W H")
        raise NotImplementedError(f"Unexpected video input format: {video.shape}")

    def extract_anchor_frames(self, video, input_range, inference_params: InferenceParams):
        """
        Extracts anchor frames from the input video based on the provided inference parameters.
        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of input video.
        - inference_params: InferenceParams
            An object containing inference parameters.
            - anchor_frames: str
                Specifies how the anchor frames are encoded. It can be either a single number specifying which frame is used as the anchor frame,
                or a range in the format "a:b" indicating that frames from index a up to, but not including, index b are used as anchor frames.
        Returns:
        - torch.Tensor
            The extracted anchor frames from the input video.
        """
        video = self._to_batched_video(video, input_range)
        if ":" in inference_params.anchor_frames:
            bounds = [int(bound) for bound in inference_params.anchor_frames.split(":")]
            assert len(bounds) == 2, "Anchor frames encoding wrong."
            return video[:, bounds[0]:bounds[1]]
        anchor_frame = int(inference_params.anchor_frames)
        return video[:, anchor_frame].unsqueeze(0)

    def extract_ctrl_frames(self, video: torch.Tensor, input_range: List[int], inference_params: InferenceParams):
        """
        Extracts control frames from the input video.
        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of input video.
        - inference_params: InferenceParams
            An object containing inference parameters.
        Returns:
        - torch.Tensor
            The extracted control image encoding frames from the input video.
        """
        video = self._to_batched_video(video, input_range)
        # return the last num_conditional_frames frames
        return video[:, -inference_params.num_conditional_frames:]

    def _autoregressive_generation(self, initial_generation: Union[torch.Tensor, List[torch.Tensor]], inference_params: InferenceParams):
        """
        Perform autoregressive generation of video chunks based on the initial generation and inference parameters.
        Parameters:
        - initial_generation: torch.Tensor or list of torch.Tensor
            The initial generation or list of initial generation video chunks.
            A passed list is not modified.
        - inference_params: InferenceParams
            An object containing inference parameters.
        Returns:
        - torch.Tensor
            The generated video resulting from autoregressive generation.
        """
        # input is [-1,1] float
        if isinstance(initial_generation, list):
            # work on a copy so the caller's list is not extended
            result_chunks = list(initial_generation)
        else:
            result_chunks = [initial_generation]
        # make sure all chunks are channels-first (decided on the first chunk)
        first_chunk = result_chunks[0]
        if (first_chunk.shape[1] > 3) and (first_chunk.shape[-1] == 3):
            result_chunks = [rearrange(chunk, "F W H C -> F C W H") for chunk in result_chunks]

        # generating chunk by conditioning on the previous chunks
        for _ in tqdm(range(inference_params.n_autoregressive_generations), desc="StreamingSVD"):
            # extract anchor frames based on the entire, so far generated, video
            # note that we do not use anchor frames in StreamingSVD (apart from the anchor frame already used by SVD).
            anchor_frames = self.extract_anchor_frames(
                video=torch.cat(result_chunks),
                inference_params=inference_params,
                input_range=[-1, 1],
            )
            # extract control frames based on the last generated chunk
            ctrl_frames = self.extract_ctrl_frames(
                video=result_chunks[-1],
                input_range=[-1, 1],
                inference_params=inference_params,
            )
            # select the anchor frame for svd
            svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)]
            # generate the next chunk
            # result is [F, C, H, W], range is [-1,1] float.
            result = self._generate_conditional_output(
                svd_input_frame=svd_input_frame,
                inference_params=inference_params,
                anchor_frames=anchor_frames,
                ctrl_frames=ctrl_frames,
            )
            # from each generation, we keep all frames except for the first <num_conditional_frames> frames
            result_chunks.append(result[inference_params.num_conditional_frames:])
            torch.cuda.empty_cache()

        # concat all chunks to one long video
        result_chunks = [
            result_processor.convert_range(chunk, output_range=[0, 255], input_range=[-1, 1])
            for chunk in result_chunks
        ]
        result = result_processor.concat_chunks(result_chunks)
        torch.cuda.empty_cache()
        return result

    def ensure_image_ratio(self, source_image: PIL.Image.Image, target_aspect_ratio=16/9) -> Tuple[PIL.Image.Image, Optional[PIL.Image.Image]]:
        """
        Pad an image to the target aspect ratio and outpaint the padded area.
        Parameters:
        - source_image: PIL.Image.Image
            The input image.
        - target_aspect_ratio: float
            Width / height ratio; images that already match are returned unchanged.
        Returns:
        - tuple (image, mask)
            The (possibly outpainted) image and the inverted padded alpha mask,
            or `None` as mask when no outpainting was necessary.
        """
        if source_image.width / source_image.height == target_aspect_ratio:
            return source_image, None

        image = source_image.copy().convert("RGBA")
        mask = image.split()[-1]  # alpha channel
        image = image.convert("RGB")
        # NOTE(review): target_aspect_ratio is not forwarded here, so the helper's own
        # default ratio is used - confirm against get_padding_for_aspect_ratio
        padding = get_padding_for_aspect_ratio(image)

        mask_padded = TF.pad(mask, padding)
        mask_padded_size = mask_padded.size
        mask_padded_resized = TF.resize(mask_padded, (512, 512),
                                        interpolation=TF.InterpolationMode.NEAREST)
        mask_padded_resized = TF.invert(mask_padded_resized)

        # image
        padded_input_image = TF.pad(image, padding, padding_mode="reflect")
        resized_image = TF.resize(padded_input_image, (512, 512))

        # start the inpainting from the (almost fully) noised latent of the reflect-padded image
        image_tensor = self.inpaint_pipe.image_processor.preprocess(
            resized_image).to(device=self.device, dtype=torch.float16)
        latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None)
        self.inpaint_pipe.scheduler.set_timesteps(999)
        noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise(
            latent_tensor,
            torch.randn_like(latent_tensor),
            self.inpaint_pipe.scheduler.timesteps[:1],
        )

        # caption the source image and use it as the inpainting prompt
        prompt = self.blip(source_image)
        if prompt.startswith("there is "):
            prompt = prompt[len("there is "):]

        output_image_normalized_size = self.inpaint_pipe(
            prompt=prompt,
            image=resized_image,
            mask_image=mask_padded_resized,
            latents=noisy_latent_tensor,
        ).images[0]
        # PIL size is (width, height); TF.resize expects (height, width)
        output_image_extended_size = TF.resize(
            output_image_normalized_size, mask_padded_size[::-1])

        # blend the outpainted border into the original content with a soft mask
        outpainting_mask = TF.invert(mask_padded)
        blurred_outpainting_mask = outpainting_mask.filter(
            ImageFilter.GaussianBlur(radius=5))
        final_image = Image.composite(
            output_image_extended_size, padded_input_image, blurred_outpainting_mask)
        return final_image, TF.invert(mask_padded)

    def image_to_video(self, batch, inference_params: InferenceParams, batch_idx):
        """
        Performs image to video based on the input batch and inference parameters.
        It runs SVD-XT once to generate the first chunk, then auto-regressively applies StreamingSVD.
        Parameters:
        - batch: dict
            The input batch containing the start image for generating the video.
        - inference_params: InferenceParams
            An object containing inference parameters.
        - batch_idx: int
            The index of the batch (unused).
        Returns:
        - tuple (video, scaled_outpainted_image, expanded_size)
            The generated video, the 1024x576 image it was generated from and the
            second return value of `resize_and_crop`.
        """
        batch_key = "image"
        input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy())
        # TODO remove conversion forth and back
        outpainted_image, _ = self.ensure_image_ratio(input_image)
        scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image)
        assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}."

        # Generating first chunk
        with torch.autocast(device_type="cuda", enabled=False):
            video_chunks = self.svd_pipeline(
                scaled_outpainted_image, decode_chunk_size=8).frames[0]
        video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks])
        video_chunks = video_chunks * 2.0 - 1  # [-1,1], float
        video_chunks = video_chunks.to(self.device)

        video = self._autoregressive_generation(
            initial_generation=video_chunks,
            inference_params=inference_params)
        return video, scaled_outpainted_image, expanded_size

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        """
        Generate output video based on the input batch and inference parameters.
        Parameters:
        - batch: dict
            The input batch containing data for generating the output video.
        - batch_idx: int
            The index of the batch.
        - inference_params: InferenceParams
            An object containing inference parameters.
        Returns:
        - torch.Tensor
            The generated video. Note the result is also accessible via self.trainer.generated_video
        """
        sample_id = batch["sample_id"].item()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            batch, inference_params=inference_params, batch_idx=sample_id)
        # .numpy() requires a CPU tensor without grad; this is a no-op if it already is one
        self.trainer.generated_video = video.detach().cpu().numpy()
        self.trainer.expanded_size = expanded_size
        self.trainer.scaled_outpainted_image = scaled_outpainted_image
        return video