import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
import cv2
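

# Gradio demo for ACL ("Can CLIP Help Sound Source Localization?", WACV'24):
# zero-shot sound source localization from an image and an audio clip.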
def greet(image, audio):
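    """Localize the sound source of `audio` in `image`.

    Args:
        image: input image as a PIL.Image (from gr.Image(type='pil')).
        audio: (sample_rate, waveform) tuple with int16 samples (from gr.Audio()).

    Returns:
        The input image blended with the predicted localization heatmap.
    """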
    device = torch.device('cpu')

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
    # Input pre-processing
    sample_rate, audio = audio
    # Gradio provides the waveform as int16 PCM; scale to float32 in [-1, 1]
    audio = audio.astype(np.float32, order='C') / 32768.0

    desired_sample_rate = 16000
    set_length = 10

    audio_file = torch.from_numpy(audio)
    if len(audio_file.shape) == 2:
        audio_file = torch.concat([audio_file[:, 0:1], audio_file[:, 1:2]], dim=0).T  # Stereo -> mono (x2 duration)
    else:
        audio_file = audio_file.unsqueeze(0)

    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)
    audio_file = audio_file.squeeze(0)

    # Trim to the first `set_length` seconds
    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # Zero padding up to `set_length` seconds
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)

    audio_file = audio_file.unsqueeze(0)
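
    # Image pre-processing: resize to 352x352 and normalize with CLIP statistics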
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])
    image_file = image_transform(image).unsqueeze(0)

    # Inference: the audio feature is injected at the placeholder position of the
    # text prompt to form the audio-driven query embedding
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                text_pos_at_prompt, prompt_length)

    # Localization result
    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]

    # Visualization: convert the heatmap to an 8-bit image, apply a JET colormap,
    # and blend it with the input image
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image)
    seg_image = seg_image.resize(image.size, Image.BICUBIC)

    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)

    return overlaid_image
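

# Gradio UI configuration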
title = "Audio-Grounded Contrastive Learning"
description = """<p>
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', which performs zero-shot visual sound source localization.<br><br>
To use it, simply upload an image and a corresponding audio clip whose sound source you want to localize in the image, or use one of the examples below and click 'Submit'.<br><br>
Results will show up in a few seconds.<br><br>
We recommend audio with a sample rate of 16 kHz or higher; the model only uses the first 10 seconds of the audio.
</p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github repo</a></p>"
examples = [
    ['./asset/web_image1.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image1.jpeg', './asset/web_child_laugh.wav'],
    ['./asset/web_image1.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image1.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image2.jpeg', './asset/web_dog_barking.wav'],
    ['./asset/web_image2.jpeg', './asset/web_female_speech.wav'],
    ['./asset/web_image2.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_motorcycle_pass_by.wav'],
    ['./asset/web_image3.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image3.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_car_horns.wav'],
    ['./asset/web_image4.jpeg', './asset/web_wave.wav'],
    ['./asset/web_image4.jpeg', './asset/web_horse.wav'],
]
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples
)
demo.launch(debug=True)