Update app.py
app.py
CHANGED
@@ -3,14 +3,12 @@ import torch
 import numpy as np
 from modules.models import *
 from util import get_prompt_template
+from torchvision import transforms as vt
+import torchaudio
 from PIL import Image
 
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-def main():
+def greet(audio, image):
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
     # Get model
@@ -23,23 +21,65 @@ def main():
     prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
 
     # Input pre processing
+    sample_rate, audio = audio
+    audio = audio.astype(np.float32, order='C') / 32768.0
+    desired_sample_rate = 16000
+    set_length = 10
+
+    audio_file = torch.from_numpy(audio)
+
+    if desired_sample_rate != sample_rate:
+        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)
+
+    if audio_file.shape[0] == 2:
+        audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0)  # Stereo -> mono (x2 duration)
+
+    audio_file.squeeze(0)
+
+    if audio_file.shape[0] > (desired_sample_rate * set_length):
+        audio_file = audio_file[:desired_sample_rate * set_length]
+
+    # zero padding
+    if audio_file.shape[0] < (desired_sample_rate * set_length):
+        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
+        pad_val = torch.zeros(pad_len)
+        audio_file = torch.cat((audio_file, pad_val), dim=0)
+
+    audio_file = audio_file.unsqueeze(0)
+
+    image_transform = vt.Compose([
+        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
+        vt.ToTensor(),
+        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
+    ])
+
+    image_file = image_transform(image)
 
     # Inference
     placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
-
-
+    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt,
+                                                prompt_length)
 
     # Localization result
-
-
-
-
-
-
+    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
+    seg = out_dict['heatmap'][j:j + 1]
+    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
+    seg_image = Image.fromarray(seg_image)
+    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
+    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)
+
+    return overlaid_image
 
 
 if __name__ == "__main__":
-
-
+    description = 'hello world'
+
+    demo = gr.Interface(
+        fn=greet,
+        inputs=[gr.Image(type='pil'), gr.Audio()],
+        outputs=gr.Image(type="pil"),
+        title='AudioToken',
+        description=description,
+    )
 
-
+    demo.launch()
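The new preprocessing (resample to 16 kHz, fold stereo into a double-length mono track, then trim or zero-pad to a fixed 10 s window) can be exercised on its own. Below is a minimal sketch of the same steps, not the committed code verbatim: it drops the bare `audio_file.squeeze(0)` (Tensor.squeeze returns a new tensor, so that line is a no-op as committed), and it folds channels before resampling, since Gradio delivers stereo as (samples, channels) while torchaudio resamples along the last axis. `preprocess_audio` is a hypothetical helper name.

import numpy as np
import torch
import torchaudio


def preprocess_audio(sample_rate, audio, desired_sample_rate=16000, set_length=10):
    # int16 PCM from gr.Audio -> float32 in [-1, 1]
    audio = audio.astype(np.float32, order='C') / 32768.0
    audio_file = torch.from_numpy(audio)

    # Gradio stereo arrives as (samples, channels); concatenate the two
    # channels in time, as the commit does (mono at twice the duration)
    if audio_file.dim() == 2:
        audio_file = torch.cat([audio_file[:, 0], audio_file[:, 1]], dim=0)

    if sample_rate != desired_sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)

    # Trim or zero-pad to exactly set_length seconds
    target_len = desired_sample_rate * set_length
    audio_file = audio_file[:target_len]
    if audio_file.shape[0] < target_len:
        audio_file = torch.cat((audio_file, torch.zeros(target_len - audio_file.shape[0])), dim=0)

    return audio_file.unsqueeze(0)  # (1, samples) batch for encode_audio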
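Two things in the new localization block would fail at runtime as committed: `j` is never defined inside `greet` (the slice presumably means the first heatmap of the single-image batch), and `cv2` does not appear among the imports the diff shows. A hedged sketch of that step under those assumptions, with the heatmap resized to the raw input so addWeighted sees matching shapes; `overlay_heatmap` is a hypothetical helper, and `out_dict['heatmap']` is assumed to be a (1, 1, H, W) tensor, as the slicing and squeeze suggest.

import cv2  # assumed available; not imported in the hunks shown
import numpy as np


def overlay_heatmap(out_dict, image):
    seg = out_dict['heatmap'][0:1]  # first (only) item; replaces the undefined `j`
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    heatmap_image = cv2.applyColorMap(seg_image, cv2.COLORMAP_JET)
    # Match the raw image size before blending; applyColorMap returns BGR,
    # so convert for blending with the RGB PIL input
    heatmap_image = cv2.cvtColor(cv2.resize(heatmap_image, (image.width, image.height)),
                                 cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(np.array(image.convert('RGB')), 0.5, heatmap_image, 0.5, 0)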
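One wiring detail in the new `__main__` block is worth flagging: `gr.Interface` passes inputs to `fn` positionally, in the order of the `inputs` list, so `inputs=[gr.Image(type='pil'), gr.Audio()]` hands the image to `greet`'s `audio` parameter and vice versa. A sketch with the order matched to `greet(audio, image)`, assuming `import gradio as gr` sits on the lines above the first hunk, which the diff does not show:

import gradio as gr  # assumed top-of-file import, outside the hunks shown

demo = gr.Interface(
    fn=greet,
    inputs=[gr.Audio(), gr.Image(type='pil')],  # positional order matches greet(audio, image)
    outputs=gr.Image(type='pil'),
    title='AudioToken',
    description='hello world',
)
demo.launch()

With the default type="numpy", gr.Audio hands `fn` a (sample_rate, data) tuple, which is exactly what the new `sample_rate, audio = audio` unpacking expects.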