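"""Gradio demo for ACL (Audio-Grounded Contrastive Learning): zero-shot sound
source localization with CLIP, from the WACV'24 paper 'Can CLIP Help Sound
Source Localization?'. Given an image and an audio clip, the app overlays the
predicted localization heatmap on the image.
"""
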
import gradio as gr
import torch
import numpy as np
from modules.models import ACL
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
import cv2


def greet(image, audio):
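    """Localize the sound source of `audio` within `image`; return a heatmap overlay."""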
    device = torch.device('cpu')

    # Build the ACL model from its config and load the pretrained checkpoint
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)  # inference only
    model.load('./pretrain/Param_best.pth')

    # Prompt template with a placeholder slot for the audio-driven token
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input preprocessing
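    # gr.Audio (numpy mode) yields (sample_rate, int16 samples); rescale to [-1, 1]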
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0
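    # The model consumes 16 kHz audio and uses at most the first 10 seconds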
    desired_sample_rate = 16000
    set_length = 10

    audio_file = torch.from_numpy(audio)

    if len(audio_file.shape) == 2:
        # Stereo -> mono by concatenating the two channels in time (doubles duration)
        audio_file = torch.cat([audio_file[:, 0:1], audio_file[:, 1:2]], dim=0).T
    else:
        audio_file = audio_file.unsqueeze(0)

    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)

    audio_file = audio_file.squeeze(0)

    # Truncate clips longer than set_length seconds
    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # Zero-pad clips shorter than set_length seconds
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)

    audio_file = audio_file.unsqueeze(0)  # add batch dimension

    # CLIP visual preprocessing: 352x352 bicubic resize and CLIP normalization stats
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP mean/std
    ])

    image_file = image_transform(image).unsqueeze(0)  # add batch dimension

    # Inference: encode the audio into an embedding injected at the prompt's
    # placeholder position, then predict the localization heatmap
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)

        # Localization result
        out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)

    seg = out_dict['heatmap'][0:1]  # heatmap for the first (only) batch item
    seg_image = ((1 - seg.squeeze().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image)

    seg_image = seg_image.resize(image.size, Image.BICUBIC)
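    # Colorize the heatmap with JET and alpha-blend it over the input image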
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    heatmap_image = cv2.cvtColor(heatmap_image, cv2.COLOR_BGR2RGB)  # applyColorMap outputs BGR; match the RGB input
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image


title = "Audio-Grounded Contrastive Learning"

description = """<p>
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', which performs zero-shot visual sound source localization.<br><br>
To use it, upload an image and an audio clip whose sound source you want localized in the image, or pick one of the examples below, then click 'Submit'.<br><br>
Results appear within a few seconds.<br><br>
Audio with a sample rate of 16 kHz or higher is recommended; the model only uses the first 10 seconds.
</p>"""

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official GitHub repo</a></p>"

# Example (image, audio) pairs shipped under ./asset
examples = [['./asset/web_image1.jpeg', './asset/web_dog_barking.wav'],
            ['./asset/web_image1.jpeg', './asset/web_child_laugh.wav'],
            ['./asset/web_image1.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image1.jpeg', './asset/web_motorcycle_pass_by.wav'],
            ['./asset/web_image2.jpeg', './asset/web_dog_barking.wav'],
            ['./asset/web_image2.jpeg', './asset/web_female_speech.wav'],
            ['./asset/web_image2.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image3.jpeg', './asset/web_motorcycle_pass_by.wav'],
            ['./asset/web_image3.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image3.jpeg', './asset/web_wave.wav'],
            ['./asset/web_image4.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image4.jpeg', './asset/web_wave.wav'],
            ['./asset/web_image4.jpeg', './asset/web_horse.wav'],
            ]

# Assemble the Gradio interface: PIL image + audio in, heatmap overlay out
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples
)

demo.launch(debug=True)  # debug=True surfaces runtime errors in the console