fffiloni's picture
Upload 33 files
fcb4edd verified
raw
history blame
No virus
4.66 kB
import os
import gradio as gr
import torch
# import argparse
checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore,
register_temporal_self_attention_control,
register_temporal_self_attention_flip_control,
)
# Build the interpolation pipeline on top of SVD-XT, using the project's Euler scheduler.
pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    scheduler=noise_scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
)
# Second UNet exposed by the custom pipeline; it receives the reference attention
# controller registered further down.
ref_unet = pipe.ori_unet
state_dict = pipe.unet.state_dict()

# Computing delta w: (fine-tuned UNet - original img2vid UNet), restricted to the value and
# output projections of the first temporal self-attention block, added onto pipe.unet.
finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    checkpoint_dir,
    subfolder="unet",
    torch_dtype=torch.float16,
)
# The checkpoint must be configured for 14 frames. Raise explicitly instead of using
# ``assert`` so the check is not stripped when Python runs with -O.
if finetuned_unet.config.num_frames != 14:
    raise ValueError(
        f"Expected fine-tuned UNet with num_frames=14, got {finetuned_unet.config.num_frames}"
    )
ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    variant='fp16',
    torch_dtype=torch.float16,
)
finetuned_state_dict = finetuned_unet.state_dict()
ori_state_dict = ori_unet.state_dict()
for name, param in finetuned_state_dict.items():
    if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
        delta_w = param - ori_state_dict[name]
        state_dict[name] = state_dict[name] + delta_w
pipe.unet.load_state_dict(state_dict)
# The two donor UNets were only needed to compute the delta; drop them so their fp16
# weights are not kept alive for the lifetime of the app.
del finetuned_unet, ori_unet, finetuned_state_dict, ori_state_dict

# Attach attention controllers (implemented in attn_ctrl.attention_control):
# ``controller_ref`` is registered on the reference UNet and handed, together with
# ``controller``, to the flip control on ``pipe.unet``.
controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)
controller = AttentionStore()
register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)

device = "cuda"
pipe = pipe.to(device)
def check_outputs_folder(folder_path):
    """Empty *folder_path* in place, keeping the folder itself.

    Files and symlinks are unlinked and sub-directories are removed recursively.
    Failures on individual entries are printed and skipped, so clearing is
    best-effort. If *folder_path* does not exist or is not a directory, a
    message is printed and nothing is deleted.

    Args:
        folder_path: Directory whose contents should be removed.
    """
    # ``shutil`` is not imported at module level; import it here so that
    # ``rmtree`` is actually available for sub-directories.
    import shutil

    # os.path.isdir is False for missing paths, so this covers "exists and is a dir".
    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return

    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            # islink first-class: a symlink to a directory is unlinked, not recursed into.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove file or link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove directory
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
def infer(
    frame1_path,
    frame2_path,
    *,
    seed=42,
    num_inference_steps=25,
    noise_injection_steps=0,
    noise_injection_ratio=0.5,
    weighted_average=True,
    out_dir="result",
    out_name="video_result.mp4",
):
    """Interpolate a video between two keyframes and write it to disk.

    Both keyframes are loaded and resized to 1024x576 before being passed to the
    module-level ``pipe``. The output folder is emptied on every call, so only
    the latest result is kept on disk.

    Args:
        frame1_path: Path to the first keyframe (value of ``gr.Image(type="filepath")``).
        frame2_path: Path to the second keyframe.
        seed: Seed for a ``torch.Generator`` on ``device``; ``None`` leaves it unseeded.
        num_inference_steps: Denoising steps forwarded to the pipeline.
        noise_injection_steps: Forwarded to the pipeline unchanged.
        noise_injection_ratio: Forwarded to the pipeline unchanged.
        weighted_average: Forwarded to the pipeline unchanged.
        out_dir: Folder the result is written into (cleared first).
        out_name: File name inside *out_dir*. A ``.gif`` suffix writes an animated
            GIF; anything else is written with ``export_to_video`` at 7 fps.

    Returns:
        The path of the written file.

    Raises:
        gr.Error: If either keyframe path is missing.
    """
    # gr.Image(type="filepath") yields None when nothing was uploaded; report that
    # clearly instead of failing inside load_image.
    if frame1_path is None or frame2_path is None:
        raise gr.Error("Please upload both keyframes.")

    generator = torch.Generator(device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    frame1 = load_image(frame1_path).resize((1024, 576))
    frame2 = load_image(frame2_path).resize((1024, 576))

    frames = pipe(
        image1=frame1,
        image2=frame2,
        num_inference_steps=num_inference_steps,
        generator=generator,
        weighted_average=weighted_average,
        noise_injection_steps=noise_injection_steps,
        noise_injection_ratio=noise_injection_ratio,
    ).frames[0]

    # Clear results of previous runs, then make sure the folder exists.
    check_outputs_folder(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, out_name)
    if out_path.endswith('.gif'):
        # 142 ms per frame is roughly 7 fps, matching the video branch below.
        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, out_path, fps=7)
    return out_path
# Gradio UI: two keyframe uploads and a button on the left, the resulting video on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
        with gr.Row():
            # Inputs column: both keyframes are handed to ``infer`` as file paths.
            with gr.Column():
                image_input1 = gr.Image(type="filepath")
                image_input2 = gr.Image(type="filepath")
                submit_btn = gr.Button("Submit")
            # Output column: displays the file path returned by ``infer``.
            with gr.Column():
                output = gr.Video()

    # Run inference on click; this event is kept off the "view API" page.
    submit_btn.click(
        fn=infer,
        inputs=[image_input1, image_input2],
        outputs=[output],
        show_api=False,
    )

# Enable the request queue and start the server, showing errors in the browser.
demo.queue().launch(show_api=False, show_error=True)