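"""Gradio Space that animates a still image into a short video with Stable Video Diffusion (SVD-XT 1.1)."""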
import os
import random
from glob import glob

import gradio as gr
import spaces
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video
from PIL import Image
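# Largest signed 64-bit integer; used as the upper bound when randomizing seeds.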
MAX_64_BIT_INT = 2**63 - 1
DEFAULT_SEED = 42
DEFAULT_OUTPUT_FOLDER = "outputs"
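# Load the SVD-XT 1.1 image-to-video weights in half precision (fp16 variant).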
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "vdo/stable-video-diffusion-img2vid-xt-1-1",
    torch_dtype=torch.float16,
    variant="fp16",
)
# Half-precision inference generally fails on CPU; run on the GPU that @spaces.GPU allocates.
pipe.to("cuda")
def resize_image(image, output_size=(1024, 576)):
    """Resize and center-crop a PIL image to `output_size`, preserving aspect ratio."""
    target_aspect = output_size[0] / output_size[1]
    image_aspect = image.width / image.height
    if image_aspect > target_aspect:
        # Image is wider than the target: match height, then crop the sides.
        new_height = output_size[1]
        new_width = int(new_height * image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        left = (new_width - output_size[0]) // 2
        top = 0
        right = left + output_size[0]
        bottom = output_size[1]
    else:
        # Image is taller than the target: match width, then crop top and bottom.
        new_width = output_size[0]
        new_height = int(new_width / image_aspect)
        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
        left = 0
        top = (new_height - output_size[1]) // 2
        right = output_size[0]
        bottom = top + output_size[1]
    return resized_image.crop((left, top, right, bottom))
def generate_video(image, seed, motion_bucket_id):
    """Run the SVD pipeline on a single image and return the generated frames."""
    if image.mode == "RGBA":
        image = image.convert("RGB")
    generator = torch.manual_seed(seed)
    frames = pipe(
        image,
        decode_chunk_size=3,
        generator=generator,
        motion_bucket_id=motion_bucket_id,
        noise_aug_strength=0.1,
        num_frames=25,
    ).frames[0]
    return frames
def export_video(frames, video_path, fps_id):
    export_to_video(frames, video_path, fps=fps_id)
@spaces.GPU(duration=120)
def sample(
    image,
    seed=DEFAULT_SEED,
    randomize_seed=True,
    motion_bucket_id=127,
    fps_id=6,
    output_folder=DEFAULT_OUTPUT_FOLDER,
):
    """Generate a video from `image`, save it to `output_folder`, and return its path, frames, and seed."""
    if randomize_seed:
        seed = random.randint(0, MAX_64_BIT_INT)
    os.makedirs(output_folder, exist_ok=True)
    # Name outputs sequentially: 000000.mp4, 000001.mp4, ...
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
    frames = generate_video(image, seed, motion_bucket_id)
    export_video(frames, video_path, fps_id)
    return video_path, frames, seed
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(label="Seed", value=DEFAULT_SEED, randomize=True, minimum=0, maximum=MAX_64_BIT_INT, step=1)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
                fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=6, minimum=5, maximum=30)
            generate_btn = gr.Button(value="Animate", variant="primary")
        with gr.Column():
            video = gr.Video(label="Generated video")
            gallery = gr.Gallery(label="Generated frames")
    # Fit uploads to the model's expected 1024x576 resolution before generation.
    image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
    generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, gallery, seed], api_name="video")
if __name__ == "__main__":
    demo.launch(share=True, show_api=False)