Spaces:
Running
on
L40S
Running
on
L40S
File size: 4,533 Bytes
3aca8ee 6bfca8e 3aca8ee 4489869 3aca8ee 70b4176 3aca8ee 6bfca8e 3aca8ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import sys
import shutil
import uuid
import subprocess
import gradio as gr
import shutil
from glob import glob
from huggingface_hub import snapshot_download, hf_hub_download
# Download model weights at startup. All weights live under a single root
# directory so every path below is derived from the same constant.
WEIGHTS_DIR = "pretrained_weights"
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Subdirectories to pre-create inside WEIGHTS_DIR.
subfolders = [
    "stable-video-diffusion-img2vid-xt",
]
for subfolder in subfolders:
    os.makedirs(os.path.join(WEIGHTS_DIR, subfolder), exist_ok=True)

# Base Stable Video Diffusion weights.
# NOTE(review): this downloads "stabilityai/stable-video-diffusion-img2vid"
# (the non-"xt" variant) into a folder named "...-img2vid-xt". Confirm which
# variant scripts_infer/anidoc_inference.py actually expects.
snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid",
    local_dir=os.path.join(WEIGHTS_DIR, "stable-video-diffusion-img2vid-xt"),
)

# AniDoc checkpoints, placed directly in WEIGHTS_DIR.
snapshot_download(
    repo_id="Yhmeng1106/anidoc",
    local_dir=WEIGHTS_DIR,
)

# CoTracker: only the single checkpoint file is fetched, not the whole repo.
hf_hub_download(
    repo_id="facebook/cotracker",
    filename="cotracker2.pth",
    local_dir=WEIGHTS_DIR,
)
def generate(control_sequence, ref_image):
control_image = control_sequence # "data_test/sample4.mp4"
ref_image = ref_image # "data_test/sample4.png"
unique_id = str(uuid.uuid4())
output_dir = f"results_{unique_id}"
try:
# Run the inference command
subprocess.run(
[
"python", "scripts_infer/anidoc_inference.py",
"--all_sketch",
"--matching",
"--tracking",
"--control_image", f"{control_image}",
"--ref_image", f"{ref_image}",
"--output_dir", f"{output_dir}",
"--max_point", "10",
],
check=True
)
# Search for the mp4 file in a subfolder of output_dir
output_video = glob(os.path.join(output_dir,"*.mp4"))
print(output_video)
if output_video:
output_video_path = output_video[0] # Get the first match
else:
output_video_path = None
print(output_video_path)
return output_video_path
except subprocess.CalledProcessError as e:
raise gr.Error(f"Error during inference: {str(e)}")
# Page-level CSS: centre the main column and cap its width.
css = (
    "\n"
    "div#col-container{\n"
    "margin: 0 auto;\n"
    "max-width: 982px;\n"
    "}\n"
)

# Badge row shown under the title (GitHub, project page, arXiv, HF links).
badges_html = """
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/yihao-meng/AniDoc">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://yihao-meng.github.io/AniDoc_demo/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/pdf/2412.14173">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
"""

# Bundled (sketch video, reference image) pairs: sample1 .. sample4.
example_pairs = [
    [f"data_test/sample{idx}.mp4", f"data_test/sample{idx}.png"]
    for idx in range(1, 5)
]

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown("AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.")
        gr.HTML(badges_html)

        with gr.Row():
            # Left: inputs and the trigger button.
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence", format="mp4")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit")
            # Right: generated video and clickable examples.
            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=example_pairs,
                    inputs=[control_sequence, ref_image],
                )

    submit_btn.click(
        fn=generate,
        inputs=[control_sequence, ref_image],
        outputs=[video_result],
    )

demo.queue()
demo.launch(show_api=False, show_error=True)
|