# AniDoc / gradio_app.py
# Hugging Face Space file by fffiloni — commit 04b5ae3 ("Update gradio_app.py", verified), 4.53 kB.
import os
import sys
import shutil
import uuid
import subprocess
import gradio as gr
import shutil
from glob import glob
from huggingface_hub import snapshot_download, hf_hub_download
# Download model weights. This runs every time the module is loaded.
# Every download below lands under this one directory.
WEIGHTS_DIR = "./pretrained_weights"
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Subdirectories to pre-create inside WEIGHTS_DIR.
subfolders = [
    "stable-video-diffusion-img2vid-xt",
]
for subfolder in subfolders:
    os.makedirs(os.path.join(WEIGHTS_DIR, subfolder), exist_ok=True)

# Stable Video Diffusion base weights.
# NOTE(review): the repo fetched is "stable-video-diffusion-img2vid" (non-xt)
# but it is stored in a folder named "...-img2vid-xt". Confirm whether the
# xt checkpoint ("stabilityai/stable-video-diffusion-img2vid-xt") was intended
# or whether the non-xt weights are a deliberate substitute.
snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid",
    local_dir=os.path.join(WEIGHTS_DIR, "stable-video-diffusion-img2vid-xt"),
)

# AniDoc weights from the "Yhmeng1106/anidoc" repo, placed at the root of WEIGHTS_DIR.
snapshot_download(
    repo_id="Yhmeng1106/anidoc",
    local_dir=WEIGHTS_DIR,
)

# Single CoTracker checkpoint file ("cotracker2.pth") into WEIGHTS_DIR.
hf_hub_download(
    repo_id="facebook/cotracker",
    filename="cotracker2.pth",
    local_dir=WEIGHTS_DIR,
)
def generate(control_sequence, ref_image):
    """Colorize a sketch video with AniDoc, guided by a reference image.

    Runs ``scripts_infer/anidoc_inference.py`` in a child process, writing
    into a fresh per-request output directory. Returns the first ``.mp4``
    found one level below that directory.

    Args:
        control_sequence: File path of the sketch/control video from the
            ``gr.Video`` input, e.g. ``"data_test/sample4.mp4"``.
        ref_image: File path of the reference image from the ``gr.Image``
            input (``type="filepath"``), e.g. ``"data_test/sample4.png"``.

    Returns:
        Path to the generated ``.mp4`` file.

    Raises:
        gr.Error: if an input is missing, if the inference script exits with a
            non-zero status, or if no output video is found afterwards.
    """
    # Fail fast with a readable message. Otherwise the literal string "None"
    # would be handed to the inference script as a path.
    if not control_sequence:
        raise gr.Error("Please provide a control sequence video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")

    # Use a unique directory per request so concurrent jobs don't overwrite
    # each other's results. These directories are not cleaned up here.
    output_dir = f"results_{uuid.uuid4()}"

    command = [
        # Use the interpreter running this app, so the child sees the same
        # environment; a bare "python" on PATH may resolve to another one.
        sys.executable, "scripts_infer/anidoc_inference.py",
        "--all_sketch",
        "--matching",
        "--tracking",
        "--control_image", str(control_sequence),
        "--ref_image", str(ref_image),
        "--output_dir", output_dir,
        "--max_point", "10",
    ]
    try:
        # Arguments are passed as a list (no shell), so uploaded file names
        # cannot inject shell syntax.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e

    # The video is expected in a subfolder of output_dir. Sort the matches so
    # the returned file is deterministic if several are produced.
    output_videos = sorted(glob(os.path.join(output_dir, "*", "*.mp4")))
    print(output_videos)
    if not output_videos:
        raise gr.Error("Inference finished but no output video was found.")
    return output_videos[0]
# Stylesheet passed to gr.Blocks: centers the "col-container" column and
# caps its width at 982px.
css = (
    "\n"
    "div#col-container{\n"
    "margin: 0 auto;\n"
    "max-width: 982px;\n"
    "}\n"
)
# Header links (repo, project page, paper, duplicate/follow badges) shown
# under the title.
_HEADER_BADGES = """
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/yihao-meng/AniDoc">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://yihao-meng.github.io/AniDoc_demo/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/pdf/2412.14173">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
"""

# (control video, reference image) pairs for samples 1..4 in data_test/.
_EXAMPLE_PAIRS = [
    [f"data_test/sample{idx}.mp4", f"data_test/sample{idx}.png"]
    for idx in range(1, 5)
]

# UI layout: inputs and the submit button in the left column; the result
# video and the clickable examples in the right column.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown("AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.")
        gr.HTML(_HEADER_BADGES)

        with gr.Row():
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=_EXAMPLE_PAIRS,
                    inputs=[control_sequence, ref_image],
                )

    # Clicking Submit runs generate() and shows the returned video.
    submit_btn.click(
        fn=generate,
        inputs=[control_sequence, ref_image],
        outputs=[video_result],
    )
# Start the server only when this file is run as a script
# (e.g. `python gradio_app.py`), so importing the module does not launch it.
# queue() puts requests through Gradio's queue, which suits the slow
# subprocess-based inference in generate().
if __name__ == "__main__":
    demo.queue().launch(show_api=False, show_error=True)