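# Gradio demo: keyframe interpolation with Stable Video Diffusion.
# Given two key frames, the app generates the in-between video using a
# fine-tuned SVD UNet with temporal self-attention flip control.
# (This Space runs on ZeroGPU; on that hardware, GPU-bound functions such as
# infer() are usually wrapped with the @spaces.GPU decorator.)
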
import os
import shutil

import gradio as gr
import torch

from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (
    AttentionStore,
    register_temporal_self_attention_control,
    register_temporal_self_attention_flip_control,
)

checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
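
# Load the SVD-XT base pipeline in fp16, swapping in the custom Euler discrete
# scheduler used by the frame-interpolation-with-noise-injection pipeline.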
pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    scheduler=noise_scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
)

ref_unet = pipe.ori_unet
state_dict = pipe.unet.state_dict()
# Compute the weight delta between the fine-tuned UNet and the original SVD
# UNet, then add it to the SVD-XT weights (only the temporal self-attention
# to_v / to_out.0 projections, per the filter in the loop below).
finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    checkpoint_dir,
    subfolder="unet",
    torch_dtype=torch.float16,
)
assert finetuned_unet.config.num_frames == 14
ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    variant="fp16",
    torch_dtype=torch.float16,
)

finetuned_state_dict = finetuned_unet.state_dict()
ori_state_dict = ori_unet.state_dict()
for name, param in finetuned_state_dict.items():
    if "temporal_transformer_blocks.0.attn1.to_v" in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
        delta_w = param - ori_state_dict[name]
        state_dict[name] = state_dict[name] + delta_w
pipe.unet.load_state_dict(state_dict)
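
# Register attention controllers: the reference UNet records its temporal
# self-attention, and the main UNet reuses those maps with flipped frame order
# (the "attnflip" in the checkpoint name), presumably so forward and reverse
# motion stay consistent during interpolation.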
controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)

controller = AttentionStore()
register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)

device = "cuda"
pipe = pipe.to(device)

def check_outputs_folder(folder_path):
    # Check if the folder exists
    if os.path.exists(folder_path) and os.path.isdir(folder_path):
        # Delete all contents inside the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        print(f'The folder {folder_path} does not exist.')
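
# Interpolate between two key frames and export the result as an MP4.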
def infer(frame1_path, frame2_path):
    seed = 42
    num_inference_steps = 25
    noise_injection_steps = 0
    noise_injection_ratio = 0.5
    weighted_average = True

    generator = torch.Generator(device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    # SVD conditions on 1024x576 frames, so resize both key frames.
    frame1 = load_image(frame1_path)
    frame1 = frame1.resize((1024, 576))
    frame2 = load_image(frame2_path)
    frame2 = frame2.resize((1024, 576))

    frames = pipe(
        image1=frame1,
        image2=frame2,
        num_inference_steps=num_inference_steps,
        generator=generator,
        weighted_average=weighted_average,
        noise_injection_steps=noise_injection_steps,
        noise_injection_ratio=noise_injection_ratio,
    ).frames[0]

    # Recreate the output folder, then write the frames out as a video.
    out_dir = "result"
    check_outputs_folder(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = "result/video_result.mp4"
    if out_path.endswith('.gif'):
        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, out_path, fps=7)
    return out_path
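
# Gradio UI: two keyframe image inputs and a Submit button on the left, the
# interpolated video output on the right.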
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(type="filepath")
                image_input2 = gr.Image(type="filepath")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output = gr.Video()
    submit_btn.click(
        fn=infer,
        inputs=[image_input1, image_input2],
        outputs=[output],
        show_api=False,
    )

demo.queue().launch(show_api=False, show_error=True)