import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import urllib.request
import io
from pathlib import Path
from blip_vqa import blip_vqa
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
class App:
    def __init__(self):
        self.selected_model = 0

        # Load BLIP for visual question answering
        print("Loading BLIP for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)

        # Build the Gradio UI
        with gr.Blocks() as demo:
            with gr.Row():
                self.image_source = gr.Image(type="numpy")
            with gr.Tabs():
                with gr.Tab("Question/Answer"):
                    self.question = gr.Textbox(label="Custom question (if applicable)", value="where is the right hand?")
                    self.answer = gr.Button("Ask")
                    self.lbl_caption = gr.Label(label="Caption")
                    self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)

        # Launch the interface
        demo.launch()
    def answer_question_image(self, img, custom_question="Describe this image"):
        # Preprocess the image to the size and normalization expected by BLIP
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Generate an answer with the model
        with torch.no_grad():
            answer = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        # Return the predicted answer as a string
        return answer[0]
app = App()
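
# A minimal usage sketch: querying the model programmatically instead of through
# the Gradio UI. 'sample.jpg' is a hypothetical local image, and the call is left
# commented out because App() above already launches the web interface.
#
# import numpy as np
# sample = np.array(Image.open('sample.jpg').convert('RGB'))
# print(app.answer_question_image(sample, "what is in the picture?"))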