import gradio as gr
import torch
import numpy as np
from modules.models import *  # NOTE(review): presumably provides ACL; prefer an explicit import
from util import get_prompt_template
from PIL import Image


def greet(name: str) -> str:
    """Return a greeting for *name*.

    Placeholder callback for the Gradio text-in/text-out interface below.
    """
    return "Hello " + name + "!!"


def main() -> None:
    """Load the ACL model and run the (partially stubbed) localization pipeline.

    Picks ``cuda:0`` when CUDA is available, otherwise ``cpu``. Builds the model
    from its YAML config and loads the pretrained checkpoint. It then obtains
    the prompt template and computes the placeholder tokens. Audio encoding,
    localization and heatmap rendering are still commented out, so nothing is
    returned or displayed yet.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Get model
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    # Presumably switches the model to inference mode (torch nn.Module.train(False)) — confirm ACL is a Module.
    model.train(False)
    model.load('./pretrain/Param_best.pth')

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    # TODO: build `audios`, `images` and `original_image` from user input.

    # Inference
    # Prompt template with its '{}' slot removed, tokenized by the model.
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    # audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
    #                                             prompt_length)

    # Localization result
    # out_dict = model(images.to(model.device), audio_driven_embedding, 352)
    # seg = out_dict['heatmap'][j:j + 1]
    # seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    # seg_image = Image.fromarray(seg_image)
    # TODO: this step needs `import cv2` (OpenCV) and the `seg_image` computed above; it was a live
    # statement referencing both undefined names, which made main() always raise NameError.
    # heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    # overlaid_image = cv2.addWeighted(np.array(original_image), 0.5, heatmap_image, 0.5, 0)


if __name__ == "__main__":
    iface = gr.Interface(fn=greet, inputs="text", outputs="text")
    # NOTE(review): launch() blocks the main thread by default, so main() presumably runs only
    # after the server shuts down — confirm whether the model should be loaded before launching.
    iface.launch()
    main()