# NOTE(review): removed non-Python web-scrape residue ("Spaces: / Running / Running")
# that preceded the source; it was Hugging Face Spaces page chrome, not code.
import gradio as gr
import numpy as np
import torch
from PIL import Image

from modules.models import *
from util import get_prompt_template
def greet(name):
    """Return a cheerful greeting for *name*."""
    return f"Hello {name}!!"
def main():
    """Load the pretrained ACL model and tokenize the prompt template.

    NOTE(review): the audio-driven inference and heatmap-visualization steps
    are commented out in the original source, so this function currently only
    loads the model weights and prepares the text-prompt placeholder tokens.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'  # plain string; it had no f-string placeholders
    model = ACL(model_conf_file, device)
    model.train(False)  # inference mode
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    # Inference
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    # audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
    #                                             prompt_length)

    # Localization result
    # out_dict = model(images.to(model.device), audio_driven_embedding, 352)
    # seg = out_dict['heatmap'][j:j + 1]
    # seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    # seg_image = Image.fromarray(seg_image)
    # BUG FIX: the next line was live code but referenced `seg_image` (defined
    # only in the commented-out lines above) and `cv2` (never imported), so it
    # always raised NameError. Disabled to match the rest of the commented-out
    # visualization pipeline.
    # heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    # overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)
if __name__ == "__main__": | |
iface = gr.Interface(fn=greet, inputs="text", outputs="text") | |
iface.launch() | |
main() | |