swimmiing committed
Commit 334681f
Parent: a5ed3da

Update app.py

Files changed (1):
  1. app.py +56 -16
app.py CHANGED
@@ -3,14 +3,12 @@ import torch
 import numpy as np
 from modules.models import *
 from util import get_prompt_template
+from torchvision import transforms as vt
+import torchaudio
 from PIL import Image
 
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-def main():
+def greet(audio, image):
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
     # Get model
@@ -23,23 +21,65 @@ def main():
     prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
 
     # Input pre processing
+    sample_rate, audio = audio
+    audio = audio.astype(np.float32, order='C') / 32768.0  # int16 PCM -> float32 in [-1, 1]
+    desired_sample_rate = 16000
+    set_length = 10  # seconds
+
+    audio_file = torch.from_numpy(audio)
+
+    if desired_sample_rate != sample_rate:
+        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)
+
+    if audio_file.shape[0] == 2:
+        audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0)  # Stereo -> mono (x2 duration)
+
+    audio_file = audio_file.squeeze(0)
+
+    if audio_file.shape[0] > (desired_sample_rate * set_length):
+        audio_file = audio_file[:desired_sample_rate * set_length]
+
+    # zero padding
+    if audio_file.shape[0] < (desired_sample_rate * set_length):
+        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
+        pad_val = torch.zeros(pad_len)
+        audio_file = torch.cat((audio_file, pad_val), dim=0)
+
+    audio_file = audio_file.unsqueeze(0)
+
+    image_transform = vt.Compose([
+        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
+        vt.ToTensor(),
+        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
+    ])
+
+    image_file = image_transform(image)
 
     # Inference
     placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
-    # audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
-    #                                             prompt_length)
+    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt,
+                                                prompt_length)
 
     # Localization result
-    # out_dict = model(images.to(model.device), audio_driven_embedding, 352)
-    # seg = out_dict['heatmap'][j:j + 1]
-    # seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
-    # seg_image = Image.fromarray(seg_image)
-    # heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
-    # overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)
+    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
+    seg = out_dict['heatmap'][0:1]  # single input, so take the first (and only) batch entry
+    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
+    seg_image = Image.fromarray(seg_image)
+    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
+    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)
+
+    return overlaid_image
 
 
 if __name__ == "__main__":
-    iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-    iface.launch()
+    description = 'hello world'
+
+    demo = gr.Interface(
+        fn=greet,
+        inputs=[gr.Audio(), gr.Image(type='pil')],  # order matches the greet(audio, image) signature
+        outputs=gr.Image(type="pil"),
+        title='AudioToken',
+        description=description,
+    )
 
-    main()
+    demo.launch()
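
With this change, greet() expects exactly what the Gradio components hand it: gr.Audio() (numpy mode) passes a (sample_rate, int16 samples) tuple, gr.Image(type='pil') passes a PIL image, and the function returns the heatmap overlay as a numpy array. Below is a minimal local smoke test sketched under the assumption that soundfile is available and that example.wav / example.jpg are stand-in paths; none of this is part of the commit.

# Local smoke test for the updated greet(); example.wav and example.jpg are
# hypothetical paths, and soundfile is an assumed extra dependency.
import numpy as np
import soundfile as sf
from PIL import Image

from app import greet  # the __main__ guard keeps demo.launch() from running on import

# gr.Audio() in numpy mode delivers (sample_rate, int16 samples) to the callback.
samples, sample_rate = sf.read("example.wav", dtype="int16")
if samples.ndim == 2:        # soundfile returns (frames, channels) for stereo files
    samples = samples[:, 0]  # keep a single channel for simplicity

image = Image.open("example.jpg").convert("RGB")

overlaid = greet((sample_rate, samples), image)  # numpy array: heatmap blended over the input image
Image.fromarray(overlaid.astype(np.uint8)).save("overlay.png")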