import os
from typing import Tuple, Union

import gradio as gr
import numpy as np
import see2sound
import spaces
import torch
import yaml
from huggingface_hub import snapshot_download
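
# Download the SEE-2-SOUND checkpoints from the Hugging Face Hub.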
model_id = "rishitdagli/see-2-sound"
base_path = snapshot_download(repo_id=model_id)
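
# Rewrite config.yaml so every "checkpoints" path points at the downloaded snapshot.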
with open("config.yaml", "r") as file:
    data = yaml.safe_load(file)
data_str = yaml.dump(data)
updated_data_str = data_str.replace("checkpoints", base_path)
updated_data = yaml.safe_load(updated_data_str)
with open("config.yaml", "w") as file:
    yaml.safe_dump(updated_data, file)
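
# Initialise SEE-2-SOUND from the patched config; setup() loads the model weights.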
model = see2sound.See2Sound(config_path="config.yaml")
model.setup()

CACHE_DIR = "gradio_cached_examples"  # local cache of pre-computed example outputs


def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
    """Return the cached (image, audio) output paths for a pre-computed example."""
    cached_dir = os.path.join(CACHE_DIR, str(example_index))
    cached_image_path = os.path.join(cached_dir, "processed_image.png")
    cached_audio_path = os.path.join(cached_dir, "audio.wav")
    if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
        return cached_image_path, cached_audio_path
    raise FileNotFoundError(f"Cached outputs not found for example {example_index}")


def on_example_click(index: int, *args, **kwargs) -> Tuple[str, str]:
    """Handle an example click by serving that example's cached outputs."""
    return load_cached_example_outputs(index)
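

# This Space runs on ZeroGPU, which is why `spaces` is imported above. A minimal
# sketch of the usual pattern: decorate the GPU-bound function with `spaces.GPU`
# so a GPU is attached for the duration of the call. The original app may pass an
# explicit `duration=` argument; none is assumed here.
@spaces.GPU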
def process_image(
    image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
) -> Tuple[str, str]:
    """Run SEE-2-SOUND on the image and return it alongside the generated audio."""
    model.run(
        path=image,
        output_path="audio.wav",
        num_audios=num_audios,
        prompt=prompt,
        steps=steps,
    )
    return image, "audio.wav"


description_text = """# SEE-2-SOUND Demo

Official demo for *SEE-2-SOUND: Zero-Shot Spatial Environment-to-Spatial Sound*.

Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [GitHub](https://github.com/see2sound/see2sound) for more details.

> Note: make sure that your playback hardware supports spatial audio.

This demo generates spatial audio for an image. Upload an image (with an optional text prompt under Advanced Settings) to generate spatial audio that accompanies it.

This forked Space contains cached examples.

Author: [@Rishit-dagli](https://github.com/Rishit-dagli) (University of Toronto) *et al.*

cc [@jadechoghari](https://github.com/jadechoghari) for HF/Gradio issues.
"""
css = """ | |
h1 { | |
text-align: center; | |
} | |
""" | |

with gr.Blocks(css=css) as demo:
    gr.Markdown(description_text)
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label="Select an image", sources=["upload", "webcam"], type="filepath"
            )
            with gr.Accordion("Advanced Settings", open=False):
                steps = gr.Slider(
                    label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500
                )
                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=True,
                )
                num_audios = gr.Slider(
                    label="Number of Audios", minimum=1, maximum=10, step=1, value=3
                )
            submit_button = gr.Button("Submit")
        with gr.Column():
            processed_image = gr.Image(label="Processed Image")
            processed_video = gr.Video(label="Processed Video", visible=False)  # hidden placeholder; not wired to any event
            generated_audio = gr.Audio(
                label="Generated Audio",
                show_download_button=True,
                show_share_button=True,
                waveform_options=gr.WaveformOptions(
                    waveform_color="#01C6FF",
                    waveform_progress_color="#0066B4",
                    show_controls=True,
                ),
            )

    # Example images; their outputs are pre-computed and cached.
    examples = [
        ["examples/1.png"],
        ["examples/2.png"],
        ["examples/3.png"],
        ["examples/4.png"],
        ["examples/5.png"],
        ["examples/6.png"],
        ["examples/7.png"],
        ["examples/8.png"],
        ["examples/9.png"],
    ]
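
    # Assumed cache layout, matching load_cached_example_outputs above: each
    # examples/<n>.png has its outputs stored under
    # gradio_cached_examples/<n>/processed_image.png and gradio_cached_examples/<n>/audio.wav.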

    gr.Examples(
        examples=examples,
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
        cache_examples=True,  # serve cached outputs instead of running the model
        # Derive the example index from the image filename ("examples/3.png" -> 3).
        fn=lambda *args: on_example_click(
            int(os.path.splitext(os.path.basename(args[0]))[0])
        ),
    )

    gr.on(
        triggers=[submit_button.click],
        fn=process_image,
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
    )

if __name__ == "__main__":
    demo.launch(debug=True)