File size: 4,398 Bytes
46ff99b
7e40a31
46ff99b
7e40a31
 
2879448
7e40a31
 
cfb23a7
7e40a31
 
 
 
46ff99b
7e40a31
 
 
 
 
 
 
7c3177c
7e40a31
 
9a68e0a
8575490
 
7e40a31
9a68e0a
 
7c3177c
 
 
7e40a31
7c3177c
 
9a68e0a
 
46ff99b
7e40a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46ff99b
7e40a31
 
5e5405f
7e40a31
 
 
46ff99b
 
 
7e40a31
 
 
46ff99b
 
 
 
 
 
 
7e40a31
 
 
46ff99b
 
7e40a31
 
 
 
 
 
 
 
 
 
 
 
 
46ff99b
 
 
 
 
7e40a31
 
 
 
 
 
 
 
 
 
9a68e0a
46ff99b
9a68e0a
7c3177c
46ff99b
9a68e0a
 
46ff99b
 
7e40a31
 
 
9a68e0a
7e40a31
46ff99b
 
 
7e40a31
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import tempfile
from typing import Tuple, Union

import gradio as gr
import numpy as np
import see2sound
import spaces
import torch
import yaml
from huggingface_hub import snapshot_download

model_id = "rishitdagli/see-2-sound"
# Download (or reuse the locally cached copy of) the model repository;
# `base_path` is the local snapshot directory returned by the hub client.
base_path = snapshot_download(repo_id=model_id)


def _replace_in_strings(node, old: str, new: str):
    """Return a copy of *node* with *old* replaced by *new* in every string value.

    Walks the nested dicts and lists produced by ``yaml.safe_load``. Mapping
    keys and non-string scalars (numbers, booleans, ``None``) are returned
    unchanged.
    """
    if isinstance(node, str):
        return node.replace(old, new)
    if isinstance(node, dict):
        return {key: _replace_in_strings(value, old, new) for key, value in node.items()}
    if isinstance(node, list):
        return [_replace_in_strings(item, old, new) for item in node]
    return node


# Point every "checkpoints" reference in the config at the downloaded snapshot.
# The substitution is applied to parsed string values rather than to the dumped
# YAML text. Keys are therefore never renamed, and a base path containing
# YAML-significant characters cannot corrupt the document.
# NOTE(review): config.yaml is rewritten in place, so on a later run it already
# holds an absolute path and no "checkpoints" placeholder remains to substitute.
with open("config.yaml", "r", encoding="utf-8") as file:
    data = yaml.safe_load(file)
updated_data = _replace_in_strings(data, "checkpoints", base_path)
with open("config.yaml", "w", encoding="utf-8") as file:
    yaml.safe_dump(updated_data, file)

model = see2sound.See2Sound(config_path="config.yaml")
model.setup()

# Directory holding pre-rendered outputs for the examples gallery.
CACHE_DIR = "gradio_cached_examples"

# Lookup of pre-rendered outputs kept in the local example cache.
def load_cached_example_outputs(example_index: int) -> Tuple[str, str]:
    """Return ``(image_path, audio_path)`` stored under ``CACHE_DIR/<example_index>``.

    The directory is expected to contain ``processed_image.png`` and
    ``audio.wav``.

    Raises:
        FileNotFoundError: if either of the two files is missing.
    """
    example_dir = os.path.join(CACHE_DIR, str(example_index))
    paths = (
        os.path.join(example_dir, "processed_image.png"),
        os.path.join(example_dir, "audio.wav"),
    )
    # Both files must be present; otherwise report which example is missing.
    if not all(os.path.exists(candidate) for candidate in paths):
        raise FileNotFoundError(f"Cached outputs not found for example {example_index}")
    return paths

# gr.Examples handler: accepts whatever arguments Gradio passes and ignores them.
def on_example_click(*args, **kwargs):
    """Return the cached ``(image_path, audio_path)`` outputs for an example.

    The example inputs are not used to choose the cache entry; the cache
    index is currently fixed at 1.
    """
    del args, kwargs  # intentionally unused
    cached_example = 1
    return load_cached_example_outputs(cached_example)


@spaces.GPU(duration=280)
@torch.no_grad()
def process_image(
    image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
) -> Tuple[str, str]:
    """Generate spatial audio for *image* with the SEE-2-SOUND model.

    Args:
        image: Filesystem path of the input image (the UI's ``gr.Image`` uses
            ``type="filepath"``).
        num_audios: Value of the "Number of Audios" control, forwarded to
            ``model.run``.
        prompt: Optional text prompt, forwarded to ``model.run``.
        steps: Diffusion step count, forwarded to ``model.run``.

    Returns:
        ``(image, audio_path)``: the unchanged input image path (displayed as
        the "processed" image) and the path of the generated ``audio.wav``.
    """
    # Each call writes into its own fresh directory instead of a shared
    # "audio.wav" in the working directory. Overlapping requests therefore
    # cannot overwrite each other's result before Gradio reads it.
    # The directory is not removed here because Gradio reads the file after
    # this function returns.
    output_dir = tempfile.mkdtemp(prefix="see2sound_")
    output_path = os.path.join(output_dir, "audio.wav")
    model.run(
        path=image,
        output_path=output_path,
        num_audios=num_audios,
        prompt=prompt,
        steps=steps,
    )
    return image, output_path


# Markdown rendered at the top of the page. The blank lines around the note
# stop the following paragraph from being absorbed into the blockquote
# (Markdown "lazy continuation").
description_text = """# SEE-2-SOUND 🔊 Demo
Official demo for *SEE-2-SOUND 🔊: Zero-Shot Spatial Environment-to-Spatial Sound*.
Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [github](https://github.com/see2sound/see2sound) for more details.

> Note: You should make sure that your hardware supports spatial audio.

This demo allows you to generate spatial audio given an image. Upload an image (with an optional text prompt in the advanced settings) to generate spatial audio to accompany the image.
"""

# Page stylesheet: centre <h1> headings (the Markdown title above renders as one).
css = "\nh1 {\n    text-align: center;\n}\n"

# Build the Gradio UI: inputs in the left column, outputs in the right column.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description_text)

    with gr.Row():
        # Left column: input image plus optional generation settings.
        with gr.Column():
            # type="filepath" means handlers receive the image as a path string.
            image = gr.Image(
                label="Select an image", sources=["upload", "webcam"], type="filepath"
            )

            with gr.Accordion("Advanced Settings", open=False):
                # Forwarded to process_image as `steps`.
                steps = gr.Slider(
                    label="Diffusion Steps", minimum=1, maximum=1000, step=1, value=500
                )
                # Optional free-text prompt, forwarded as `prompt`.
                prompt = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=True,
                )
                # Forwarded as `num_audios`.
                num_audios = gr.Slider(
                    label="Number of Audios", minimum=1, maximum=10, step=1, value=3
                )

            submit_button = gr.Button("Submit")

        # Right column: the image returned by the handler and the generated audio.
        with gr.Column():
            processed_image = gr.Image(label="Processed Image")
            generated_audio = gr.Audio(
                label="Generated Audio",
                show_download_button=True,
                show_share_button=True,
                waveform_options=gr.WaveformOptions(
                    waveform_color="#01C6FF",
                    waveform_progress_color="#0066B4",
                    show_controls=True,
                ),
            )

    # Examples gallery. Outputs come from on_example_click, which loads files
    # from CACHE_DIR rather than calling the model.
    gr.Examples(
        examples=[["examples/1.png", 3, "A scenic mountain view", 500]],  # Example input, ordered like `inputs`
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
        cache_examples=True,  # Cache examples to avoid running the model
        fn=on_example_click  # Load the cached output when the example is clicked
    )

    # Run the model on Submit; the input order matches process_image's parameters.
    gr.on(
        triggers=[submit_button.click],
        fn=process_image,
        inputs=[image, num_audios, prompt, steps],
        outputs=[processed_image, generated_audio],
    )

if __name__ == "__main__":
    # Start the Gradio server when this file is run as a script.
    demo.launch()