# Blip_QA / app.py (ParisNeo) - BLIP visual question answering Gradio demo
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path

from blip_vqa import blip_vqa

# Run on GPU when available; BLIP-VQA expects 384x384 inputs for its ViT backbone
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
class App:
    def __init__(self):
        self.selected_model = 0

        # Load BLIP for visual question answering
        print("Loading Blip for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)

        # Build the Gradio interface: an image input plus a Question/Answer tab
        with gr.Blocks() as demo:
            with gr.Row():
                self.image_source = gr.inputs.Image(shape=(224, 224))
            with gr.Tabs():
                with gr.Tab("Question/Answer"):
                    self.question = gr.inputs.Textbox(label="Custom question (if applicable)", default="where is the right hand?")
                    self.answer = gr.Button("Ask")
                    self.lbl_caption = gr.outputs.Label(label="Caption")
                    self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)

        # Launch the interface
        demo.launch()
    def answer_question_image(self, img, custom_question="Describe this image"):
        # Preprocess the incoming numpy image into the tensor format BLIP expects
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Generate an answer to the question about the image
        with torch.no_grad():
            answer = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        # The model returns a list of answers; return the first one as a string
        return answer[0]
app = App()
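
# Usage sketch (not part of the original file): with the BLIP repository's blip_vqa
# module importable and a blip_vqa.pth checkpoint placed next to this script, the
# demo can presumably be started locally with `python app.py`; Gradio then serves
# the interface on http://127.0.0.1:7860 by default.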