# Source: Hugging Face Hub file view — uploaded by swimmiing,
# commit b20af9f ("Upload model files"), 1.44 kB.
import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from PIL import Image
def greet(name):
    """Build the greeting shown by the demo interface.

    Args:
        name: Text typed by the user.

    Returns:
        ``"Hello "`` followed by *name* and ``"!!"``,
        e.g. ``greet("Ada") == "Hello Ada!!"``.
    """
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
def main(model_conf_file='./config/model/ACL_ViT16.yaml',
         checkpoint_path='./pretrain/Param_best.pth'):
    """Build the ACL model, load its weights and prepare prompt placeholder tokens.

    The input pre-processing, audio encoding and localization stages are not
    implemented yet; the intended pipeline is kept below as commented-out
    reference code. As written, the function only sets up the model and
    returns ``None``.

    Args:
        model_conf_file: Path to the model YAML config passed to ``ACL``.
        checkpoint_path: Path to the weights file passed to ``model.load``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Get model
    model = ACL(model_conf_file, device)
    # NOTE(review): presumably switches ACL to evaluation mode, as
    # nn.Module.train(False) does — confirm against the ACL definition.
    model.train(False)
    model.load(checkpoint_path)

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Placeholder tokens for the prompt template with its '{}' slot removed;
    # they feed the (not yet enabled) audio-encoding step below.
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))

    # TODO: input pre-processing, inference and localization are not wired up.
    # Intended pipeline, kept for reference:
    # audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens,
    #                                             text_pos_at_prompt, prompt_length)
    # Localization result
    # out_dict = model(images.to(model.device), audio_driven_embedding, 352)
    # seg = out_dict['heatmap'][j:j + 1]
    # seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    # seg_image = Image.fromarray(seg_image)
    # The heatmap step is disabled as well. It depends on ``seg_image`` from the
    # stubbed lines above and on ``cv2``, which this file does not import; left
    # active, it raised NameError on every call to main().
    # heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    # overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)
if __name__ == "__main__":
    # Placeholder demo UI: one text input routed through greet(), one text output.
    iface = gr.Interface(fn=greet, outputs="text", inputs="text")
    # NOTE(review): when run as a script, Gradio's launch() blocks the main
    # thread by default until the server stops, so main() below only runs
    # after the app is shut down. Confirm this ordering is intended.
    iface.launch()
    main()