import argparse
import logging
import math
import os
import gc
import copy
import json  # used to parse the --prompts JSON string in the CLI entry point
import cv2   # used in main() to probe the training video's resolution and frame count
from omegaconf import OmegaConf
import torch
import torch.utils.checkpoint
import diffusers
import transformers
from tqdm.auto import tqdm
from accelerate import Accelerator
from accelerate.logging import get_logger
from models.unet.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from utils.ddim_utils import inverse_video
from utils.gpu_utils import handle_memory_attention, unet_and_text_g_c
from utils.func_utils import *
import imageio
import numpy as np
from dataset import *
from loss import *
from noise_init import *
from attn_ctrl import register_attention_control
import shutil

logger = get_logger(__name__, log_level="INFO")

def log_validation(accelerator, config, batch, global_step, text_prompt, unet, text_encoder, vae, output_dir):
    with accelerator.autocast():
        unet.eval()
        text_encoder.eval()
        unet_and_text_g_c(unet, text_encoder, False, False)

        # Handle the spatial LoRA: zero its scale so it does not affect validation sampling.
        if config.loss.type == 'DebiasedHybrid':
            loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
            for lora_i in loras:
                lora_i.scale = 0

        pipeline = TextToVideoSDPipeline.from_pretrained(
            config.model.pretrained_model_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet
        )

        prompt_list = text_prompt if len(config.val.prompt) <= 0 else config.val.prompt

        for seed in config.val.seeds:
            noisy_latent = batch['inversion_noise']
            shape = noisy_latent.shape
            noise = torch.randn(
                shape,
                device=noisy_latent.device,
                generator=torch.Generator(noisy_latent.device).manual_seed(seed)
            ).to(noisy_latent.dtype)

            # Handle the configured noise-initialization strategy.
            init_func_name = f'{config.noise_init.type}'
            # config.noise_init is an OmegaConf DictConfig; convert it to a plain dict of kwargs.
            init_params_dict = OmegaConf.to_container(config.noise_init, resolve=True)
            # Remove the 'type' key; the default None avoids an error if the key is absent.
            init_params_dict.pop('type', None)
            init_func_to_call = globals().get(init_func_name)
            init_noise = init_func_to_call(noisy_latent, noise, **init_params_dict)

            for prompt in prompt_list:
                file_name = f"{prompt.replace(' ', '_')}_seed_{seed}.mp4"
                file_path = f"{output_dir}/samples_{global_step}/"
                if not os.path.exists(file_path):
                    os.makedirs(file_path)

                with torch.no_grad():
                    video_frames = pipeline(
                        prompt=prompt,
                        negative_prompt=config.val.negative_prompt,
                        width=config.val.width,
                        height=config.val.height,
                        num_frames=config.val.num_frames,
                        num_inference_steps=config.val.num_inference_steps,
                        guidance_scale=config.val.guidance_scale,
                        latents=init_noise,
                    ).frames[0]
                export_to_video(video_frames, os.path.join(file_path, file_name), config.dataset.fps)
                logger.info(f"Saved a new sample to {os.path.join(file_path, file_name)}")

        del pipeline
        torch.cuda.empty_cache()
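
# NOTE (illustrative): `config.noise_init.type` names a function pulled in via
# `from noise_init import *`, and every other key under `noise_init` is forwarded to it
# as a keyword argument. A hypothetical config entry such as
#
#   noise_init:
#     type: mixed_noise_init   # hypothetical function name
#     blend_ratio: 0.5         # hypothetical extra kwarg
#
# would resolve to the call `mixed_noise_init(noisy_latent, noise, blend_ratio=0.5)`.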

def create_logging(logging, logger, accelerator):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

def accelerate_set_verbose(accelerator):
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

def export_to_video(video_frames, output_video_path, fps):
    video_writer = imageio.get_writer(output_video_path, fps=fps)
    for img in video_frames:
        video_writer.append_data(np.array(img))
    video_writer.close()
    return output_video_path
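
# Writing .mp4 files through imageio relies on an ffmpeg backend being available
# (typically provided by the `imageio-ffmpeg` package).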

def create_output_folders(output_dir, config):
    out_dir = os.path.join(output_dir)
    os.makedirs(out_dir, exist_ok=True)
    OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
    shutil.copyfile(config.dataset.single_video_path, os.path.join(out_dir, 'source.mp4'))
    return out_dir

def load_primary_models(pretrained_model_path):
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
    return noise_scheduler, tokenizer, text_encoder, vae, unet
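
# `pretrained_model_path` is expected to point at a diffusers-format text-to-video
# checkpoint (zeroscope_v2_576w by default, see `main`) containing the `scheduler/`,
# `tokenizer/`, `text_encoder/`, `vae/` and `unet/` subfolders referenced above.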

def freeze_models(models_to_freeze):
    for model in models_to_freeze:
        if model is not None:
            model.requires_grad_(False)

def is_mixed_precision(accelerator):
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    return weight_dtype

def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None:
            model.to(accelerator.device, dtype=weight_dtype)

def handle_cache_latents(
        should_cache,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        unet,
        pretrained_model_path,
        cached_latent_dir=None,
):
    # Cache VAE latents (and their DDIM inversion noise) to disk up front.
    # Speeds up training and saves memory by not encoding during the train loop.
    if not should_cache:
        return None

    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()

    pipe = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_path,
        vae=vae,
        unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16)
    )
    pipe.text_encoder.to('cuda', dtype=torch.float16)

    cached_latent_dir = (
        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
    )

    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
            save_name = f"cached_{i}"
            full_out_path = f"{cache_save_dir}/{save_name}.pt"

            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
            batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50)
            for k, v in batch.items():
                batch[k] = v[0]

            torch.save(batch, full_out_path)
            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()
    else:
        cache_save_dir = cached_latent_dir

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0
    )
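
# Each cached item is one pre-processed batch saved as a .pt file with its batch
# dimension stripped (the `v[0]` above), containing at least `latents` and
# `inversion_noise`. A quick sanity check on one file (path is illustrative):
#
#   item = torch.load('outputs/example/cached_latents/cached_0.pt')
#   print(item['latents'].shape, item['inversion_noise'].shape)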

def should_sample(global_step, validation_steps, validation_data):
    return (global_step == 1 or global_step % validation_steps == 0) and validation_data.sample_preview
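
# Sketch (not wired up in the training loop below as provided): should_sample is the
# usual gate for log_validation, e.g.
#
#   if accelerator.is_main_process and should_sample(global_step, validation_steps, config.val):
#       log_validation(accelerator, config, batch, global_step, text_prompt,
#                      unet, text_encoder, vae, output_dir)
#
# where `validation_steps` and `text_prompt` come from the training config/dataset;
# their exact names are assumptions, not guaranteed by this file.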

def save_pipe(
        path,
        global_step,
        accelerator,
        unet,
        text_encoder,
        vae,
        output_dir,
        is_checkpoint=False,
        save_pretrained_model=False,
        **extra_params
):
    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype

    # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
    unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False))

    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float32)

    lora_managers_spatial = extra_params.get('lora_managers_spatial', [None])
    lora_manager_spatial = lora_managers_spatial[-1]
    if lora_manager_spatial is not None:
        lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path + '/spatial', step=global_step)

    save_motion_embeddings(unet_out, os.path.join(save_path, 'motion_embed.pt'))

    if save_pretrained_model:
        pipeline.save_pretrained(save_path)

    if is_checkpoint:
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]

    logger.info(f"Saved model at {save_path} on step {global_step}")

    del pipeline
    del unet_out
    del text_encoder_out
    torch.cuda.empty_cache()
    gc.collect()
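
# On-disk layout produced by save_pipe: checkpoints land in
# `{output_dir}/checkpoint-{global_step}/`, spatial LoRA weights (when a LoRA manager is
# passed via extra_params) in its `spatial/` subfolder, the motion embeddings in
# `motion_embed.pt`, and the full diffusers pipeline only when `save_pretrained_model=True`.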

def main(config):
    # Initialize the Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=config.train.gradient_accumulation_steps,
        mixed_precision=config.train.mixed_precision,
        log_with=config.train.logger_type,
        project_dir=config.train.output_dir
    )

    # Probe the training video so the dataset and validation settings match its resolution and length.
    video_path = config.dataset.single_video_path
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = 8
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    config.dataset.width = width
    config.dataset.height = height
    config.dataset.fps = fps
    config.dataset.n_sample_frames = frame_count
    config.dataset.single_video_path = video_path
    config.val.width = width
    config.val.height = height
    config.val.num_frames = frame_count
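
    # Note: validation resolution and frame count are pinned to the source clip probed
    # above, while the sampling fps is fixed at 8 regardless of the clip's native frame rate.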

    # Create output directories and set up logging
    if accelerator.is_main_process:
        output_dir = create_output_folders(config.train.output_dir, config)
    create_logging(logging, logger, accelerator)
    accelerate_set_verbose(accelerator)

    # Load primary models
    noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(config.model.pretrained_model_path)

    # Load the VideoCrafter2 UNet for better video quality, if needed
    if config.model.unet == 'videoCrafter2':
        unet = UNet3DConditionModel.from_pretrained("/hpc2hdd/home/lwang592/ziyang/cache/videocrafterv2", subfolder='unet')
    elif config.model.unet == 'zeroscope_v2_576w':
        # zeroscope_v2_576w is the default, so this UNet is already loaded
        pass
    else:
        raise ValueError("Invalid UNet model")

    freeze_models([vae, text_encoder])
    handle_memory_attention(unet)

    train_dataloader, train_dataset = prepare_data(config, tokenizer)

    # Handle latents caching
    cached_data_loader = handle_cache_latents(
        config.train.cache_latents,
        output_dir,
        train_dataloader,
        config.train.train_batch_size,
        vae,
        unet,
        config.model.pretrained_model_path,
        config.train.cached_latent_dir,
    )
    if cached_data_loader is not None:
        train_dataloader = cached_data_loader

    # Prepare parameters and optimization
    params, extra_params = prepare_params(unet, config, train_dataset)
    optimizers, lr_schedulers = prepare_optimizers(params, config, **extra_params)

    # Prepare models and data for training
    unet, optimizers, train_dataloader, lr_schedulers, text_encoder = accelerator.prepare(
        unet, optimizers, train_dataloader, lr_schedulers, text_encoder
    )

    # Additional model setups
    unet_and_text_g_c(unet, text_encoder)
    vae.enable_slicing()

    # Setup for mixed precision training
    weight_dtype = is_mixed_precision(accelerator)
    cast_to_gpu_and_type([text_encoder, vae], accelerator, weight_dtype)

    # Recalculate training steps and epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.train.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.train.max_train_steps / num_update_steps_per_epoch)

    # Initialize trackers and store configuration
    if accelerator.is_main_process:
        accelerator.init_trackers("motion-inversion")

    # Train!
    total_batch_size = config.train.train_batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.train.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, config.train.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Register the attention control, for Motion Value Embedding(s)
    register_attention_control(unet, config=config)

    for epoch in range(first_epoch, num_train_epochs):
        train_loss_temporal = 0.0

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if config.train.resume_from_checkpoint and epoch == first_epoch and step < config.train.resume_step:
                if step % config.train.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                for optimizer in optimizers:
                    optimizer.zero_grad(set_to_none=True)

                with accelerator.autocast():
                    if global_step == 0:
                        unet.train()

                    # Dispatch to the loss function named in the config (imported via `from loss import *`).
                    loss_func_to_call = globals().get(f'{config.loss.type}')
                    loss_temporal, train_loss_temporal = loss_func_to_call(
                        train_loss_temporal,
                        accelerator,
                        optimizers,
                        lr_schedulers,
                        unet,
                        vae,
                        text_encoder,
                        noise_scheduler,
                        batch,
                        step,
                        config
                    )

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
                train_loss_temporal = 0.0

                if global_step % config.train.checkpointing_steps == 0 and global_step > 0:
                    save_pipe(
                        config.model.pretrained_model_path,
                        global_step,
                        accelerator,
                        unet,
                        text_encoder,
                        vae,
                        output_dir,
                        is_checkpoint=True,
                        **extra_params
                    )

            if loss_temporal is not None:
                accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)

            if global_step >= config.train.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='configs/config.yaml')
    parser.add_argument("--single_video_path", type=str)
    parser.add_argument("--prompts", type=str, help="JSON string of prompts")
    args = parser.parse_args()

    # Load and merge configurations
    config = OmegaConf.load(args.config)

    # Update the config with the command-line arguments
    if args.single_video_path:
        config.dataset.single_video_path = args.single_video_path
        # Set the output dir
        config.train.output_dir = os.path.join(config.train.output_dir, os.path.basename(args.single_video_path).split('.')[0])

    if args.prompts:
        config.val.prompt = json.loads(args.prompts)

    main(config)
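
# Example launch (script name and paths are illustrative; --prompts takes a
# JSON-encoded list of strings):
#
#   python train.py --config configs/config.yaml \
#       --single_video_path assets/example.mp4 \
#       --prompts '["a car driving along a mountain road"]'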