import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import urllib.request
import io
from pathlib import Path

from blip_vqa import blip_vqa
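
# Run inference on the GPU when available; BLIP VQA checkpoints expect 384x384 inputs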
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384


class App:
    def __init__(self):
        self.selected_model = 0

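        # Build the BLIP VQA model and load weights from a checkpoint stored next to this script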
print("Loading Blip for question answering") |
|
model_url = str(Path(__file__).parent/'blip_vqa.pth') |
|
self.qa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base') |
|
self.qa_model.eval() |
|
self.qa_model = self.qa_model.to(device) |
|
|
|
|
|
|
|
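        # Gradio UI: image input alongside a Question/Answer tab with a question box,
        # an "Ask" button and a label that displays the generated answer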
        with gr.Blocks() as demo:
            with gr.Row():
                # The preprocessing in answer_question_image resizes to image_size,
                # so the component does not need a fixed shape
                self.image_source = gr.Image(type="numpy")
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)",
                                                   value="where is the right hand?")
                        self.answer = gr.Button("Ask")
                        self.lbl_caption = gr.Label(label="Answer")
                        self.answer.click(self.answer_question_image,
                                          [self.image_source, self.question],
                                          self.lbl_caption)

        demo.launch()

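    # Gradio callback: takes the uploaded image (a numpy array) and the question text,
    # runs BLIP VQA and returns the generated answer string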
    def answer_question_image(self, img, custom_question="Describe this image"):
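        # Resize to the model's 384x384 training resolution and normalize with the
        # CLIP mean/std statistics used by BLIP's image preprocessing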
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

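        # inference='generate' makes blip_vqa decode a free-form answer; the model
        # returns a list with one string per image in the batch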
        with torch.no_grad():
            output = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        return output[0]


app = App()