Spaces:
Running
Running
import gradio as gr | |
import torch | |
import numpy as np | |
from modules.models import * | |
from util import get_prompt_template | |
from torchvision import transforms as vt | |
import torchaudio | |
from PIL import Image | |
import cv2 | |
# Model configuration file and pretrained weights used by the demo.
_MODEL_CONF_FILE = './config/model/ACL_ViT16.yaml'
_MODEL_WEIGHTS_FILE = './pretrain/Param_best.pth'
# Audio is resampled to this rate and cropped / zero-padded to this many seconds.
_DESIRED_SAMPLE_RATE = 16000
_SET_LENGTH = 10
# Side length (pixels) the input image is resized to; also passed to the model call.
_IMAGE_SIZE = 352

# Resize + CLIP mean/std normalization applied to the input image.
_IMAGE_TRANSFORM = vt.Compose([
    vt.Resize((_IMAGE_SIZE, _IMAGE_SIZE), vt.InterpolationMode.BICUBIC),
    vt.ToTensor(),
    vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
])

# Per-process cache so the model is built and its weights loaded only once,
# instead of on every request.
_MODEL_CACHE = {}


def _get_model():
    """Return the CPU ACL model in eval mode with pretrained weights, building it on first use."""
    model = _MODEL_CACHE.get('model')
    if model is None:
        model = ACL(_MODEL_CONF_FILE, torch.device('cpu'))
        model.train(False)
        model.load(_MODEL_WEIGHTS_FILE)
        _MODEL_CACHE['model'] = model
    return model


def _audio_to_float32(samples):
    """Convert a NumPy waveform to C-contiguous float32, scaling integer PCM by its dtype range.

    Signed integers are divided by ``-iinfo(dtype).min`` (32768 for int16, the value
    previously hard-coded), so e.g. int32 from 24/32-bit files is not blown out of range.
    Unsigned integers are treated as offset-binary and re-centred around zero.
    Floating-point input is assumed to be already normalized and is only cast.
    """
    dtype = samples.dtype
    if np.issubdtype(dtype, np.signedinteger):
        return samples.astype(np.float32, order='C') / float(-np.iinfo(dtype).min)
    if np.issubdtype(dtype, np.unsignedinteger):
        mid = (np.iinfo(dtype).max + 1) / 2.0
        return (samples.astype(np.float32, order='C') - mid) / mid
    return samples.astype(np.float32, order='C')


def _prepare_audio(sample_rate, samples):
    """Turn raw samples into a ``(1, _DESIRED_SAMPLE_RATE * _SET_LENGTH)`` float32 tensor."""
    waveform = torch.from_numpy(_audio_to_float32(samples))
    if waveform.dim() == 2:
        # Axis 1 is treated as channels: the first two channels are concatenated end
        # to end along time (left, then right), doubling the duration -- this matches
        # the original demo's "Stereo -> mono (x2 duration)" behaviour.
        waveform = torch.cat([waveform[:, 0:1], waveform[:, 1:2]], dim=0).T
    else:
        waveform = waveform.unsqueeze(0)
    if sample_rate != _DESIRED_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _DESIRED_SAMPLE_RATE)
    waveform = waveform.squeeze(0)
    target_len = _DESIRED_SAMPLE_RATE * _SET_LENGTH
    # Keep only the first _SET_LENGTH seconds, then zero-pad anything shorter.
    waveform = waveform[:target_len]
    if waveform.shape[0] < target_len:
        pad = torch.zeros(target_len - waveform.shape[0], dtype=waveform.dtype)
        waveform = torch.cat((waveform, pad), dim=0)
    return waveform.unsqueeze(0)


def _overlay_heatmap(image, heatmap):
    """Blend a JET-coloured version of ``heatmap`` 50/50 over the RGB PIL ``image``."""
    seg = heatmap.squeeze().detach().cpu().numpy()
    # NOTE(review): the heatmap is presumably in [0, 1]; clip so out-of-range values
    # cannot wrap around when cast to uint8.
    seg = np.clip(seg, 0.0, 1.0)
    seg_image = Image.fromarray(((1 - seg) * 255).astype(np.uint8))
    seg_image = seg_image.resize(image.size, Image.BICUBIC)
    # cv2.applyColorMap produces BGR, while np.array(image) is RGB. The `1 - seg`
    # inversion above presumably compensates for the swapped red/blue channels so
    # that strong responses render in warm colours -- change both together.
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    return cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)


def greet(image, audio):
    """Localize the source of ``audio`` inside ``image`` and return a heatmap overlay.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input. It is converted
            to RGB, so RGBA/grayscale uploads work.
        audio: ``(sample_rate, samples)`` tuple from the ``gr.Audio()`` input.

    Returns:
        uint8 NumPy array: the RGB image blended 50/50 with the coloured heatmap.

    Raises:
        gr.Error: if either input is missing.
    """
    if image is None or audio is None:
        raise gr.Error('Please provide both an image and an audio clip.')
    model = _get_model()
    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
    # Input pre-processing
    sample_rate, samples = audio
    audio_tensor = _prepare_audio(sample_rate, samples)
    image = image.convert('RGB')
    image_tensor = _IMAGE_TRANSFORM(image).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_tensor.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)
        out_dict = model(image_tensor.to(model.device), audio_driven_embedding, _IMAGE_SIZE)
    # Localization result
    return _overlay_heatmap(image, out_dict['heatmap'][0:1])
# Text shown by the Gradio interface: page title, HTML description above the
# inputs, and HTML article (links) below the interface.
title = "Audio-Grounded Contrastive Learning"
description = """<p>
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', zero-shot visual sound localization.<br><br>
To use it simply upload an image and corresponding audio to mask (identify in the image), or use one of the examples below and click ‘submit’.<br><br>
Results will show up in a few seconds. <br><br>
It is recommended to use audio sources with a sample rate of 16 kHz or higher, and the model does not utilize audio beyond the initial 10 seconds.
</p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official GitHub repo</a></p>"
# Example [image_path, audio_path] pairs offered below the interface.
examples = [['./asset/web_image1.jpeg', './asset/web_dog_barking.wav'],
            ['./asset/web_image1.jpeg', './asset/web_child_laugh.wav'],
            ['./asset/web_image1.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image1.jpeg', './asset/web_motorcycle_pass_by.wav'],
            ['./asset/web_image2.jpeg', './asset/web_dog_barking.wav'],
            ['./asset/web_image2.jpeg', './asset/web_female_speech.wav'],
            ['./asset/web_image2.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image3.jpeg', './asset/web_motorcycle_pass_by.wav'],
            ['./asset/web_image3.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image3.jpeg', './asset/web_wave.wav'],
            ['./asset/web_image4.jpeg', './asset/web_car_horns.wav'],
            ['./asset/web_image4.jpeg', './asset/web_wave.wav'],
            ['./asset/web_image4.jpeg', './asset/web_horse.wav'],
            ]
# Gradio UI: a PIL image plus an audio clip in, the heatmap overlay out.
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start, and block on, the web server.
if __name__ == "__main__":
    demo.launch(debug=True)