jadechoghari committed
Commit
7c3177c
1 Parent(s): e090f2e

Update app.py

Files changed (1)
  1. app.py +41 -13
app.py CHANGED
@@ -1,16 +1,18 @@
 from typing import Tuple, Union
-
 import gradio as gr
 import numpy as np
 import see2sound
 import spaces
 import torch
 import yaml
+import os
 from huggingface_hub import snapshot_download
+from PIL import Image
 
 model_id = "rishitdagli/see-2-sound"
 base_path = snapshot_download(repo_id=model_id)
 
+# load and update the configuration
 with open("config.yaml", "r") as file:
     data = yaml.safe_load(file)
 data_str = yaml.dump(data)
@@ -22,20 +24,43 @@ with open("config.yaml", "w") as file:
 model = see2sound.See2Sound(config_path="config.yaml")
 model.setup()
 
+CACHE_DIR = "gradio_cached_examples"
+
+# function to create cached output directory
+def create_cache_dir(image_path):
+    image_name = os.path.basename(image_path).split('.')[0]
+    cached_dir = os.path.join(CACHE_DIR, image_name)
+    os.makedirs(cached_dir, exist_ok=True)
+    return cached_dir
 
+# fn to process image and cache outputs
 @spaces.GPU(duration=280)
 @torch.no_grad()
 def process_image(
     image: str, num_audios: int, prompt: Union[str, None], steps: Union[int, None]
 ) -> Tuple[str, str]:
+    cached_dir = create_cache_dir(image)
+    cached_image_path = os.path.join(cached_dir, "processed_image.png")
+    cached_audio_path = os.path.join(cached_dir, "audio.wav")
+
+    # check if cached outputs exist, if yes, return them
+    if os.path.exists(cached_image_path) and os.path.exists(cached_audio_path):
+        return cached_image_path, cached_audio_path
+
+    # run the model if outputs are not cached
     model.run(
         path=image,
-        output_path="audio.wav",
+        output_path=cached_audio_path,  # Save audio in cache directory
        num_audios=num_audios,
         prompt=prompt,
         steps=steps,
     )
-    return image, "audio.wav"
+
+    # save the processed image to the cache directory (use original image or any transformations)
+    processed_image = Image.open(image)  # Assuming image is a file path
+    processed_image.save(cached_image_path)
+
+    return cached_image_path, cached_audio_path
 
 
 description_text = """# SEE-2-SOUND πŸ”Š Demo
@@ -43,8 +68,6 @@ description_text = """# SEE-2-SOUND πŸ”Š Demo
 Official demo for *SEE-2-SOUND πŸ”Š: Zero-Shot Spatial Environment-to-Spatial Sound*.
 Please refer to our [paper](https://arxiv.org/abs/2406.06612), [project page](https://see2sound.github.io/), or [github](https://github.com/see2sound/see2sound) for more details.
 > Note: You should make sure that your hardware supports spatial audio.
-
-This demo allows you to generate spatial audio given an image. Upload an image (with an optional text prompt in the advanced settings) to geenrate spatial audio to accompany the image.
 """
 
 css = """
@@ -92,18 +115,23 @@ with gr.Blocks(css=css) as demo:
         ),
     )
 
+    # load examples with manually cached outputs
     gr.Examples(
-        examples=[[f"examples/{i}.png"] for i in range(1, 10)],
-        inputs=[image],
+        examples=[
+            ["examples/1.png", 3, "A scenic mountain view", 500],
+            ["examples/2.png", 2, "A forest with birds", 500],
+            ["examples/3.png", 1, "A crowded city", 500]
+        ],
+        inputs=[image, num_audios, prompt, steps],
         outputs=[processed_image, generated_audio],
-        cache_examples="lazy"
+        cache_examples="lazy",  # Cache outputs as users interact
+        fn=process_image
     )
 
-    gr.on(
-        triggers=[submit_button.click],
-        fn=process_image,
-        inputs=[image, num_audios, prompt, steps],
-        outputs=[processed_image, generated_audio],
+    submit_button.click(
+        process_image,
+        inputs=[image, num_audios, prompt, steps],
+        outputs=[processed_image, generated_audio]
    )
 
 if __name__ == "__main__":
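
For reference, a minimal standalone sketch of the cache layout this commit introduces; the example path is illustrative, and the helper mirrors the create_cache_dir added above:

import os

CACHE_DIR = "gradio_cached_examples"

def create_cache_dir(image_path: str) -> str:
    # same logic as the helper in this commit: one cache folder per file-name stem
    image_name = os.path.basename(image_path).split(".")[0]
    cached_dir = os.path.join(CACHE_DIR, image_name)
    os.makedirs(cached_dir, exist_ok=True)
    return cached_dir

# "examples/1.png" maps to "gradio_cached_examples/1"; process_image then
# writes "processed_image.png" and "audio.wav" into that folder and returns
# the cached copies on repeat runs with the same file name.
print(create_cache_dir("examples/1.png"))  # gradio_cached_examples/1

Note that keying the cache on the basename stem alone means two inputs with the same file name in different directories would share one cache entry.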