from typing import List, Optional, Tuple

import gradio as gr
import torch
from tqdm import tqdm

from models import PipelineWrapper


def _timestep_key(t: torch.Tensor):
    """Return a hashable Python scalar for a 0-d timestep tensor.

    Integer dtypes map to ``int`` and floating dtypes to ``float``, so lookups
    work for whatever dtype the scheduler uses for its timesteps.
    """
    return float(t) if t.is_floating_point() else int(t)


def _timestep_index(timesteps: torch.Tensor) -> dict:
    """Map each timestep value (keyed by ``_timestep_key``) to its position in ``timesteps``."""
    return {_timestep_key(v): k for k, v in enumerate(timesteps)}


def _expand_etas(etas, num_steps: int):
    """Broadcast a scalar eta (``None`` meaning 0) to a per-step list; sequences pass through."""
    if etas is None:
        etas = 0
    if isinstance(etas, (int, float)):
        etas = [etas] * num_steps
    return etas


def _cat_or_none(first, second):
    """Concatenate along dim 0, or return None when ``first`` is None."""
    return torch.cat([first, second], dim=0) if first is not None else None


def _cfg_pair_forward(model: PipelineWrapper, xt_inp: torch.Tensor, t, uncond, cond):
    """Run the unconditional and conditional branches as one batch-2 forward pass.

    ``uncond`` and ``cond`` are ``(hidden_states, class_labels, attention_mask)``
    triples as returned by ``model.encode_text``. Returns ``(uncond_out, cond_out)``.
    """
    # expand() needs the input's rank: 4 dims when the pipeline has a `unet`, 3 otherwise.
    if hasattr(model.model, 'unet'):
        pair_inp = xt_inp.expand(2, -1, -1, -1)
    else:
        pair_inp = xt_inp.expand(2, -1, -1)
    comb_out, _, _ = model.unet_forward(
        pair_inp,
        timestep=t,
        encoder_hidden_states=_cat_or_none(uncond[0], cond[0]),
        class_labels=_cat_or_none(uncond[1], cond[1]),
        encoder_attention_mask=_cat_or_none(uncond[2], cond[2]),
    )
    return comb_out.sample.chunk(2, dim=0)


def _single_forward(model: PipelineWrapper, xt_inp: torch.Tensor, t, emb) -> torch.Tensor:
    """Run one denoiser forward pass for a single embedding triple and return its ``.sample``."""
    hidden_states, class_labels, attention_mask = emb
    return model.unet_forward(
        xt_inp,
        timestep=t,
        encoder_hidden_states=hidden_states,
        class_labels=class_labels,
        encoder_attention_mask=attention_mask,
    )[0].sample


def inversion_forward_process(model: PipelineWrapper,
                              x0: torch.Tensor,
                              etas: Optional[float] = None,
                              prompts: List[str] = [""],
                              cfg_scales: List[float] = [3.5],
                              num_inference_steps: int = 50,
                              numerical_fix: bool = False,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple:
    """Invert ``x0`` into per-step noise maps ``zs``.

    Samples ``xts`` are drawn from ``x0`` with ``model.sample_xts_from_x0``. At each
    timestep the denoiser's noise prediction (classifier-free guided when a prompt
    is given) is passed to ``model.get_zs_from_xts``, which returns the noise map
    ``z`` linking consecutive samples.

    Args:
        model: Pipeline wrapper providing the scheduler, text encoder and denoiser.
        x0: Input to invert.
        etas: Per-step eta values, or one scalar used for every step. ``None`` is
            treated as 0.
        prompts: Source prompts. ``[""]`` means unconditional inversion (no guidance).
            Negative prompts are not supported here.
        cfg_scales: Guidance scales; only ``cfg_scales[0]`` is used.
        num_inference_steps: Step count passed to the sampling helpers and used to
            derive each step's index.
        numerical_fix: Forwarded to ``model.get_zs_from_xts``.
        duration: Forwarded as ``audio_end_in_s`` to ``model.setup_extra_inputs``.
        first_order: Forwarded to ``model.get_zs_from_xts``.
        save_compute: When the first prompt is non-empty, evaluate the unconditional
            and conditional branches in one batch-2 forward pass instead of two.
        progress: Gradio progress tracker shown alongside the console tqdm bar.

    Returns:
        ``(xt, zs, xts, extra_info)``:
        - ``xt`` is the last model input.
        - ``zs`` holds the noise maps, with ``zs[0]`` set to zeros.
        - ``xts`` holds the samples, overwritten in place with the values returned
          by ``get_zs_from_xts``.
        - ``extra_info`` holds the per-step extras returned by ``get_zs_from_xts``.
    """
    # The list defaults above are only read, never mutated, so sharing them is safe.
    has_prompt = len(prompts) > 1 or prompts[0] != ""
    # The batched CFG pass is only used when the first prompt is non-empty.
    batched = save_compute and prompts[0] != ""

    if has_prompt:
        text_hidden, text_labels, text_mask = model.encode_text(prompts)
        text_emb = (text_hidden, text_labels, text_mask)
        # In the forward pass negative prompts are not supported currently (TODO)
        uncond_hidden, uncond_labels, uncond_mask = model.encode_text(
            [""], negative=True, save_compute=save_compute,
            cond_length=text_labels.shape[1] if text_labels is not None else None)
    else:
        text_emb = None
        uncond_hidden, uncond_labels, uncond_mask = model.encode_text(
            [""], negative=True, save_compute=False)
    uncond_emb = (uncond_hidden, uncond_labels, uncond_mask)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
    etas = _expand_etas(etas, model.model.scheduler.num_inference_steps)

    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    t_to_idx = _timestep_index(timesteps)

    xt = x0
    op = tqdm(timesteps, desc="Inverting")
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
                             save_compute=batched)
    app_op = progress.tqdm(timesteps, desc="Inverting")
    for t, _ in zip(op, app_op):
        # Position of t in `timesteps`, counted from the end (offset by num_inference_steps).
        idx = num_inference_steps - t_to_idx[_timestep_key(t)] - 1

        # 1. predict noise residual at xts[idx + 1]
        xt = xts[idx + 1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t)

        with torch.no_grad():
            if batched:
                out, cond_out = _cfg_pair_forward(model, xt_inp, t, uncond_emb, text_emb)
            else:
                out = _single_forward(model, xt_inp, t, uncond_emb)
                if has_prompt:
                    cond_out = _single_forward(model, xt_inp, t, text_emb)

        if has_prompt:
            # classifier-free guidance; per-prompt differences are summed over dim 0
            noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else:
            noise_pred = out

        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
                                               eta=etas[idx], numerical_fix=numerical_fix,
                                               first_order=first_order)
        zs[idx] = z
        # Store the (possibly corrected) x_{t-1} so later steps start from it.
        xts[idx] = xtm1
        extra_info[idx] = extra

    # zs[0] is the noise map the reverse process consumes on its final iteration; it is zeroed.
    zs[0] = torch.zeros_like(zs[0])

    # NOTE(review): reaches into gradio's Progress internals to drop this loop's tracker.
    del app_op.iterables[0]
    return xt, zs, xts, extra_info


def inversion_reverse_process(model: PipelineWrapper,
                              xT: torch.Tensor,
                              tstart: torch.Tensor,
                              etas: float = 0,
                              prompts: List[str] = [""],
                              neg_prompts: List[str] = [""],
                              cfg_scales: Optional[List[float]] = None,
                              zs: Optional[torch.Tensor] = None,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              extra_info: Optional[List] = None,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
    """Regenerate from inverted noise maps while guiding with the given prompts.

    Starting from ``xT[tstart.max()]``, this runs the last ``zs.shape[0]`` scheduler
    timesteps. Each step applies classifier-free guidance between ``neg_prompts``
    and ``prompts``. It then calls ``model.reverse_step_with_custom_noise`` with
    ``zs[idx]`` as the variance noise.

    Args:
        model: Pipeline wrapper providing the scheduler, text encoder and denoiser.
        xT: Indexed by ``tstart.max()`` to pick the starting sample.
        tstart: Tensor whose maximum selects the starting entry of ``xT``.
        etas: Per-step etas or one scalar (``None`` is treated as 0). After expansion
            the length must equal the scheduler's ``num_inference_steps``.
        prompts: Target prompts for the conditional branch.
        neg_prompts: Prompts for the unconditional/negative branch.
        cfg_scales: Guidance scales (required); only ``cfg_scales[0]`` is used.
        zs: Noise maps (required), presumably the ``zs`` returned by
            ``inversion_forward_process``. Their count sets the number of steps run.
        duration: Forwarded as ``audio_end_in_s`` to ``model.setup_extra_inputs``.
        first_order: Forwarded to ``model.reverse_step_with_custom_noise``.
        extra_info: Forwarded to ``model.setup_extra_inputs``.
        save_compute: Evaluate both guidance branches in one batch-2 forward pass.
        progress: Gradio progress tracker shown alongside the console tqdm bar.

    Returns:
        ``(xt, zs)``: the final sample and the (unchanged) noise maps.

    Raises:
        ValueError: if ``zs`` or ``cfg_scales`` is None.
        AssertionError: if the expanded ``etas`` has the wrong length.
    """
    # Both are indexed unconditionally below; fail fast instead of mid-loop.
    if zs is None:
        raise ValueError("inversion_reverse_process requires the noise maps `zs`")
    if cfg_scales is None:
        raise ValueError("inversion_reverse_process requires `cfg_scales`")

    text_hidden, text_labels, text_mask = model.encode_text(prompts)
    text_emb = (text_hidden, text_labels, text_mask)
    uncond_hidden, uncond_labels, uncond_mask = model.encode_text(
        neg_prompts, negative=True, save_compute=save_compute,
        cond_length=text_labels.shape[1] if text_labels is not None else None)
    uncond_emb = (uncond_hidden, uncond_labels, uncond_mask)

    xt = xT[tstart.max()].unsqueeze(0)

    num_steps = model.model.scheduler.num_inference_steps
    etas = _expand_etas(etas, num_steps)
    assert len(etas) == num_steps

    timesteps = model.model.scheduler.timesteps.to(model.device)
    num_edit_steps = zs.shape[0]
    # Only the last `num_edit_steps` timesteps are run, one per noise map.
    edit_timesteps = timesteps[-num_edit_steps:]
    op = tqdm(edit_timesteps, desc="Editing")
    t_to_idx = _timestep_index(edit_timesteps)

    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=edit_timesteps[0],
                             audio_end_in_s=duration, save_compute=save_compute)
    app_op = progress.tqdm(edit_timesteps, desc="Editing")
    for t, _ in zip(op, app_op):
        # Position of t within edit_timesteps counted from the end, so the last
        # timestep uses zs[0] (equivalent to N - k - (N - K + 1)).
        idx = num_edit_steps - t_to_idx[_timestep_key(t)] - 1
        xt_inp = model.model.scheduler.scale_model_input(xt, t)

        with torch.no_grad():
            if save_compute:
                uncond_out, cond_out = _cfg_pair_forward(model, xt_inp, t, uncond_emb, text_emb)
            else:
                uncond_out = _single_forward(model, xt_inp, t, uncond_emb)
                cond_out = _single_forward(model, xt_inp, t, text_emb)

        z = zs[idx].unsqueeze(0)
        # classifier-free guidance; per-prompt differences are summed over dim 0
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)

        # 2. compute less noisy sample and set x_t -> x_t-1
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
                                                  eta=etas[idx], first_order=first_order)

    # NOTE(review): reaches into gradio's Progress internals to drop this loop's tracker.
    del app_op.iterables[0]
    return xt, zs