zero / text_to_video /text_to_video_pipeline.py
lev1's picture
T2V Tab improvements
2d7762b
raw
history blame
21.1 kB
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
from diffusers.utils import deprecate, logging, BaseOutput
from einops import rearrange, repeat
from torch.nn.functional import grid_sample
import torchvision.transforms as T
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    """Container returned by `TextToVideoPipeline.__call__` when `return_dict` is true."""

    # Decoded videos. The pipeline in this file fills this field with a list
    # holding one entry per generated video (torch tensors for
    # output_type="tensor", numpy arrays for output_type="numpy").
    videos: Union[torch.Tensor, np.ndarray]
    # Initial latent codes. The pipeline in this file fills this field with the
    # result of `torch.split(..., 1, dim=0)` on the CPU copy of the starting
    # latents, i.e. one chunk per sample.
    code: Union[torch.Tensor, np.ndarray]
def coords_grid(batch, ht, wd, device):
    """Build a batch of identical pixel-coordinate grids.

    Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py

    Args:
        batch: Number of copies of the grid to return.
        ht: Grid height (number of rows).
        wd: Grid width (number of columns).
        device: Device on which the grid is created.

    Returns:
        A float tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds the
        column (x) index and channel 1 the row (y) index of every position.
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    # Request "ij" indexing explicitly: it is what the legacy implicit default
    # produced, and omitting the argument triggers a warning on current PyTorch.
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    # Stack as (x, y) so the channel order matches what the warp code expects.
    coords = torch.stack((grid_x, grid_y), dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
class TextToVideoPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that denoises a stack of per-frame latents into a video.

    Latents are handled as 5-D tensors ``(batch, channels, frames, height, width)``
    and are flattened to ``(batch * frames, channels, height, width)`` around
    every UNet / VAE call.

    Background smoothing (``smooth_bg=True`` in ``__call__``) reads
    ``self.sod_model``; this class never sets that attribute, so it has to be
    attached to the instance by the caller.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components unchanged to `StableDiffusionPipeline`."""
        # Keyword arguments keep every component bound to the right parameter
        # even if the parent constructor gains parameters between the ones
        # listed here.
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    @staticmethod
    def _sample_noise(shape, generator, device, dtype):
        """Draw standard-normal noise of `shape` using `generator` and move it to `device`.

        `generator` may be None, a single `torch.Generator`, or a list with one
        generator per sample (first dimension of `shape`).
        """
        shape = tuple(shape)
        # Noise is drawn on the CPU when the target device is "mps".
        rand_device = "cpu" if device.type == "mps" else device
        if isinstance(generator, list):
            per_sample_shape = (1,) + shape[1:]
            chunks = [
                torch.randn(per_sample_shape, generator=gen,
                            device=rand_device, dtype=dtype)
                for gen in generator
            ]
            return torch.cat(chunks, dim=0).to(device)
        return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

    @staticmethod
    def _flow_to_sampling_grid(reference_flow, num_frames, out_size, device, dtype):
        """Convert a per-frame displacement field into a `grid_sample` grid of size `out_size`.

        `reference_flow` is ``(frames, 2, H, W)`` with channel 0 added to the x
        coordinate and channel 1 to the y coordinate. The result is
        ``(frames, h, w, 2)`` with coordinates mapped to ``[-1, 1]``.
        """
        _, _, H, W = reference_flow.size()
        coords0 = coords_grid(num_frames, H, W, device=device).to(dtype)
        coords_t0 = coords0 + reference_flow
        # Normalise x by W and y by H, then rescale from [0, 1] to [-1, 1].
        coords_t0[:, 0] /= W
        coords_t0[:, 1] /= H
        coords_t0 = coords_t0 * 2.0 - 1.0
        coords_t0 = T.Resize(out_size)(coords_t0)
        return rearrange(coords_t0, 'f c h w -> f h w c')

    def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
        """Noise `x0` forward from timestep `t0` to `tMax`.

        Computes ``sqrt(a) * x0 + sqrt(1 - a) * eps`` with
        ``a = prod(self.scheduler.alphas[t0:tMax])``. When `x0` is None, pure
        noise of `shape` is returned instead.

        Args:
            x0: Latents to noise, or None.
            t0: Start index into `self.scheduler.alphas`.
            tMax: End index (exclusive) into `self.scheduler.alphas`.
            generator: Optional generator (or per-sample list) for the noise.
            device: Device of the returned tensor.
            shape: Shape of the noise when `x0` is None.
            text_embeddings: Only its dtype is used, as the noise dtype.

        Returns:
            The noised latents (or pure noise) on `device`.
        """
        noise_dtype = text_embeddings.dtype
        if x0 is None:
            return self._sample_noise(shape, generator, device, noise_dtype)
        # Draw eps from `generator` as well, so seeded runs are reproducible.
        eps = self._sample_noise(x0.shape, generator, device, noise_dtype)
        alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
        return torch.sqrt(alpha_vec) * x0 + torch.sqrt(1 - alpha_vec) * eps

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Create (or adopt) the initial video latents and scale them for the scheduler.

        Args:
            batch_size: Number of videos.
            num_channels_latents: Latent channel count.
            video_length: Number of frames.
            height: Output height in pixels; divided by `self.vae_scale_factor`.
            width: Output width in pixels; divided by `self.vae_scale_factor`.
            dtype: dtype of freshly sampled latents.
            device: Target device.
            generator: None, one generator, or a list with `batch_size` generators.
            latents: Optional pre-made latents; only moved to `device`.

        Returns:
            Latents multiplied by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (
            batch_size,
            num_channels_latents,
            video_length,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = self._sample_noise(shape, generator, device, dtype)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        return latents * self.scheduler.init_noise_sigma

    def warp_latents(self, latents, reference_flow):
        """Warp the first frame of `latents` with every frame of `reference_flow`.

        The first frame is repeated once per flow frame and sampled with
        nearest-neighbour interpolation and reflection padding.

        Returns:
            Tensor rearranged to ``(b, c, f, h, w)`` where `f` is the frame
            count of `latents`.
        """
        _, _, f, h, w = latents.size()
        grid = self._flow_to_sampling_grid(
            reference_flow, f, (h, w), latents.device, latents.dtype)
        first_frame = latents[:, :, 0].repeat(f, 1, 1, 1)
        warped = grid_sample(first_frame, grid,
                             mode='nearest', padding_mode='reflection')
        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def warp_latents_independently(self, latents, reference_flow):
        """Warp every frame of a single-sample `latents` with its own flow frame.

        Args:
            latents: ``(1, c, f, h, w)`` tensor; frame `i` is warped with
                `reference_flow[i]`.
            reference_flow: ``(f, 2, H, W)`` displacement field.

        Returns:
            Warped tensor of shape ``(1, c, f, h, w)``.
        """
        b, _, f, h, w = latents.size()
        assert b == 1, "warp_latents_independently expects a batch of exactly one sample"
        grid = self._flow_to_sampling_grid(
            reference_flow, f, (h, w), latents.device, latents.dtype)
        frames = rearrange(latents[0], 'c f h w -> f c h w')
        warped = grid_sample(frames, grid,
                             mode='nearest', padding_mode='reflection')
        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def DDIM_backward(self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, null_embs, text_embeddings, latents_local, latents_dtype, guidance_scale, guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps):
        """Run the scheduler's reverse process on video latents.

        Timesteps greater than `skip_t` are skipped. While denoising, the
        latents obtained right before timestep `t0` (resp. `t1`) would be
        processed are captured.

        Args:
            num_inference_steps: Total used for the progress bar.
            timesteps: Scheduler timesteps to iterate over.
            skip_t: Steps with ``t > skip_t`` are not executed.
            t0: Timestep whose input latents are captured as ``"x_t0_1"``.
            t1: Timestep whose input latents are captured as ``"x_t1_1"``.
            do_classifier_free_guidance: Whether to run unconditional + text passes.
            null_embs: Optional per-step embeddings; when given,
                ``text_embeddings[0]`` is overwritten in place at each step.
            text_embeddings: Prompt embeddings, one row per (un)conditional prompt.
            latents_local: Start latents, ``(b, c, f, h, w)``.
            latents_dtype: dtype the noise prediction is cast to.
            guidance_scale: Classifier-free guidance weight.
            guidance_stop_step: Accepted for interface compatibility; it has no
                effect on the computation.
            callback: Optional ``callback(i, t, latents)``.
            callback_steps: Callback frequency in steps.
            extra_step_kwargs: Extra keyword arguments for `scheduler.step`.
            num_warmup_steps: Warm-up step count used for progress reporting.

        Returns:
            Dict with ``"x0"`` (final latents) and, when encountered,
            ``"x_t0_1"`` / ``"x_t1_1"``, all shaped ``(b, c, f, h, w)``.
        """
        entered = False
        f = latents_local.shape[2]
        latents_local = rearrange(latents_local, "b c f w h -> (b f) c w h")
        latents = latents_local.detach().clone()
        x_t0_1 = None
        x_t1_1 = None
        last_index = len(timesteps) - 1
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if t > skip_t:
                    continue
                if not entered:
                    print(
                        f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}")
                    entered = True
                latents = latents.detach()

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    if null_embs is not None:
                        text_embeddings[0] = null_embs[i][0]
                    # Repeat every embedding row once per frame, keeping rows
                    # grouped, so row order matches the "(b f)" latent layout
                    # with or without classifier-free guidance.
                    te = repeat(text_embeddings, "b c k -> (b f) c k", f=f)
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=te).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Capture the latents that feed the step at t0 / t1. The two
                # checks are independent so t0 == t1 yields both captures.
                if i < last_index and timesteps[i + 1] == t0:
                    x_t0_1 = latents.detach().clone()
                    print(f"latent t0 found at i = {i}, t = {t}")
                if i < last_index and timesteps[i + 1] == t1:
                    x_t1_1 = latents.detach().clone()
                    print(f"latent t1 found at i={i}, t = {t}")

                # call the callback, if provided
                if i == last_index or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
        res = {"x0": latents.detach().clone()}
        if x_t0_1 is not None:
            x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t0_1"] = x_t0_1.detach().clone()
        if x_t1_1 is not None:
            x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t1_1"] = x_t1_1.detach().clone()
        return res

    def decode_latents(self, latents):
        """Decode ``(b, c, f, h, w)`` latents to videos with values clamped to ``[0, 1]``."""
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        # Map the decoder output from [-1, 1] to [0, 1]; no dtype cast happens here.
        video = (video / 2 + 0.5).clamp(0, 1)
        return video

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_stop_step: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        xT: Optional[torch.FloatTensor] = None,
        null_embs: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        use_motion_field: bool = True,
        smooth_bg: bool = True,
        smooth_bg_strength: float = 0.4,
        **kwargs,
    ):
        """Generate videos for `prompt`.

        Args:
            prompt: One prompt or a list of prompts.
            video_length: Number of frames per video.
            height, width: Output size in pixels; default to the UNet sample
                size times `self.vae_scale_factor`.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is used
                when it is greater than 1.
            guidance_stop_step: Forwarded to `DDIM_backward`.
            negative_prompt: Optional negative prompt(s).
            num_videos_per_prompt: Videos generated per prompt.
            eta: Forwarded to `prepare_extra_step_kwargs`.
            generator: Random generator(s) used for latent sampling.
            xT: Optional starting latents ``(b, c, f, h, w)``.
            null_embs: Optional per-step embeddings, see `DDIM_backward`.
            motion_field_strength_x, motion_field_strength_y: Per-frame
                displacement added along x / y; frame `k` is shifted by
                ``k`` times the strength.
            output_type: ``"tensor"`` or ``"numpy"``; any other value yields an
                empty `videos` list.
            return_dict: Return a `TextToVideoPipelineOutput` when true,
                otherwise the bare list of videos.
            callback, callback_steps: Progress callback and its frequency.
            use_motion_field: Derive frames 2..N by warping the first frame's
                latents instead of sampling independent latents.
            smooth_bg: Blend the background across frames using foreground
                masks obtained from ``self.sod_model.process_data``.
            smooth_bg_strength: Weight of the per-frame latents in that blend.
            **kwargs: Must contain the timesteps ``t0`` and ``t1``.

        Raises:
            KeyError: If ``t0`` or ``t1`` is missing from `kwargs`.
            AttributeError: If `smooth_bg` is true but no ``sod_model`` is attached.
            ValueError: If the latents at ``t0`` / ``t1`` were needed but
                neither timestep occurs in the scheduler's timesteps.
        """
        print(motion_field_strength_x, motion_field_strength_y)
        print(f" Use: Motion field = {use_motion_field}")
        print(f" Use: Background smoothing = {smooth_bg}")

        # Fail before any expensive work if required inputs are missing.
        t0 = kwargs["t0"]
        t1 = kwargs["t1"]
        if smooth_bg and getattr(self, "sod_model", None) is None:
            raise AttributeError(
                "smooth_bg=True requires a salient-object-detection model to be"
                " attached as `sod_model`; pass smooth_bg=False otherwise.")

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        total_videos = batch_size * num_videos_per_prompt
        xT = self.prepare_latents(
            total_videos,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            xT,
        )
        dtype = xT.dtype

        if use_motion_field:
            # Only the first frame is denoised from noise; the others are warped from it.
            xT = xT[:, :, :1]
        elif xT.shape[2] < video_length:
            # when motion field is not used, augment with random latent codes
            xT_missing = self.prepare_latents(
                total_videos,
                num_channels_latents,
                video_length - xT.shape[2],
                height,
                width,
                text_embeddings.dtype,
                device,
                generator,
                None,
            )
            xT = torch.cat([xT, xT_missing], dim=2)
        xInit = xT.clone()

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order

        def run_ddim(start_latents, skip_t, capture_t0, capture_t1):
            # All DDIM passes share every argument except these four.
            return self.DDIM_backward(
                num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=skip_t,
                t0=capture_t0, t1=capture_t1,
                do_classifier_free_guidance=do_classifier_free_guidance,
                null_embs=null_embs, text_embeddings=text_embeddings,
                latents_local=start_latents, latents_dtype=dtype,
                guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
                callback=callback, callback_steps=callback_steps,
                extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)

        def missing_latents_error(name, value):
            return ValueError(
                f"No latents were captured for {name}={value}; the value must be one of"
                " the scheduler timesteps that follows an executed step.")

        ddim_res = run_ddim(xT, 1000, t0, t1)
        x0 = ddim_res["x0"].detach()
        x_t0_1 = ddim_res.get("x_t0_1")
        if x_t0_1 is not None:
            x_t0_1 = x_t0_1.detach()
        x_t1_1 = ddim_res.get("x_t1_1")
        if x_t1_1 is not None:
            x_t1_1 = x_t1_1.detach()
        del ddim_res
        del xT

        if use_motion_field:
            if x_t0_1 is None:
                raise missing_latents_error("t0", t0)
            if x_t1_1 is None:
                raise missing_latents_error("t1", t1)
            del x0
            shape = (batch_size, num_channels_latents, 1, height //
                     self.vae_scale_factor, width // self.vae_scale_factor)
            x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length - 1, 1, 1)

            # Constant per-frame translation field on a fixed 512x512 grid; the
            # warp normalises by the field size and resizes to latent resolution.
            reference_flow = torch.zeros(
                (video_length - 1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
            for fr_idx in range(video_length - 1):
                reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (fr_idx + 1)
                reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (fr_idx + 1)

            for idx, latent in enumerate(x_t0_k):
                x_t0_k[idx] = self.warp_latents_independently(
                    latent[None], reference_flow)

            # Bring the warped frames from t0 up to the noise level of t1.
            if t1 > t0:
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape, text_embeddings=text_embeddings, generator=generator)
            else:
                x_t1_k = x_t0_k

            x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()
            ddim_res = run_ddim(x_t1, t1, -1, -1)
            x0 = ddim_res["x0"].detach()
            del ddim_res
        elif smooth_bg:
            # Without a motion field all frames were denoised jointly, so the
            # t1 capture already holds every frame; split off the first one.
            if x_t1_1 is None:
                raise missing_latents_error("t1", t1)
            x_t1 = x_t1_1.clone()
            x_t1_1 = x_t1[:, :, :1, :, :].clone()

        # smooth background
        if smooth_bg:
            h, w = x0.shape[3], x0.shape[4]
            M_FG = torch.zeros((batch_size, video_length, h, w),
                               device=x0.device).to(x0.dtype)
            for batch_idx, x0_b in enumerate(x0):
                z0_b = self.decode_latents(x0_b[None]).detach()
                z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
                for frame_idx, z0_f in enumerate(z0_b):
                    z0_f = torch.round(
                        z0_f * 255).cpu().numpy().astype(np.uint8)
                    # apply SOD detection
                    m_f = torch.tensor(self.sod_model.process_data(
                        z0_f), device=x0.device).to(x0.dtype)
                    mask = T.Resize(
                        size=(h, w), interpolation=T.InterpolationMode.NEAREST)(m_f[None])
                    # Dilate the mask with a flat 5x5 window. A stride-1 max-pool
                    # computes exactly that; the `dilation` helper previously
                    # called here was never imported in this file.
                    mask = torch.nn.functional.max_pool2d(
                        mask[None].to(x0.device), kernel_size=5, stride=1, padding=2)[0]
                    M_FG[batch_idx, frame_idx, :, :] = mask

            # First-frame latents with the foreground zeroed out.
            x_t1_1_fg_masked = x_t1_1 * \
                (1 - repeat(M_FG[:, 0, :, :],
                            "b w h -> b c 1 w h", c=x_t1_1.shape[1]))

            # Propagate that background to every frame (warped if a motion field is used).
            x_t1_1_fg_masked_moved = []
            for x_t1_1_fg_masked_b in x_t1_1_fg_masked:
                x_t1_fg_masked_b = x_t1_1_fg_masked_b.clone()
                x_t1_fg_masked_b = x_t1_fg_masked_b.repeat(
                    1, video_length - 1, 1, 1)[None]
                if use_motion_field:
                    x_t1_fg_masked_b = self.warp_latents_independently(
                        x_t1_fg_masked_b, reference_flow)
                x_t1_fg_masked_b = torch.cat(
                    [x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
                x_t1_1_fg_masked_moved.append(x_t1_fg_masked_b)
            x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)

            # First-frame foreground mask propagated the same way.
            M_FG_1 = M_FG[:, :1, :, :]
            M_FG_warped = []
            for m_fg_1_b in M_FG_1:
                m_fg_1_b = m_fg_1_b[None, None]
                m_fg_b = m_fg_1_b.repeat(1, 1, video_length - 1, 1, 1)
                if use_motion_field:
                    m_fg_b = self.warp_latents_independently(
                        m_fg_b.clone(), reference_flow)
                M_FG_warped.append(
                    torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
            M_FG_warped = torch.cat(M_FG_warped, dim=0)

            # Background = neither current-frame foreground nor moved first-frame foreground.
            channels = x0.shape[1]
            M_BG = (1 - M_FG) * (1 - M_FG_warped)
            M_BG = repeat(M_BG, "b f h w -> b c f h w", c=channels)
            a_convex = smooth_bg_strength
            x_t1_blending = (1 - M_BG) * x_t1 + M_BG * (
                a_convex * x_t1 + (1 - a_convex) * x_t1_1_fg_masked_moved)

            ddim_res = run_ddim(x_t1_blending, t1, -1, -1)
            x0 = ddim_res["x0"].detach()
            del ddim_res

        # Post-processing: decode one frame at a time.
        video_list = []
        for latent in x0:
            tmp = latent[None]
            print("Frame spit shape", tmp.shape)
            frames = []
            for fr_split in range(tmp.shape[2]):
                print("frame decoding")
                frames.append(self.decode_latents(
                    tmp[:, :, fr_split, None]).detach())
            video_list.append(torch.cat(frames, dim=2).cpu().float().numpy())

        # Convert to the requested output type
        videos = []
        if output_type == "tensor":
            videos = [torch.from_numpy(video) for video in video_list]
        if output_type == 'numpy':
            videos = [rearrange(video, 'b c f h w -> (b f) h w c')
                      for video in video_list]

        if not return_dict:
            # Return every generated video, not just the last loop variable.
            return videos
        return TextToVideoPipelineOutput(videos=videos, code=torch.split(xInit.detach().cpu(), 1, dim=0))