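"""Gradio demo for visual question answering (VQA) with BLIP.

Loads a local BLIP-VQA checkpoint (`blip_vqa.pth`, expected next to this
file) and serves a small Blocks UI: upload an image, type a question,
and get the model's generated answer.
"""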
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path

from blip_vqa import blip_vqa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
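# Input resolution used for both model construction and image preprocessing.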
image_size = 384

class App:
    def __init__(self):
        self.selected_model = 0

        # Load BLIP for visual question answering from a local checkpoint.
        print("Loading BLIP for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)

        # Build the UI: image input on the left, question/answer controls on the right.
        with gr.Blocks() as demo:
            with gr.Row():
                self.image_source = gr.Image(type="numpy")
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)", value="where is the right hand?")
                        self.answer = gr.Button("Ask")
                        self.lbl_caption = gr.Label(label="Answer")
                        self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)
        # Launch the interface
        demo.launch()

    def answer_question_image(self, img, custom_question="Describe this image"):
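        """Answer `custom_question` about `img` (an HxWx3 numpy array) with BLIP-VQA."""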
        # Preprocess the image the way BLIP expects: resize to the model
        # resolution, convert to a tensor, and normalize with the CLIP
        # mean/std that BLIP uses.
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Generate an answer (no gradients needed at inference time).
        with torch.no_grad():
            answers = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        # The model returns a list of answers; return the first one as a string.
        return answers[0]

if __name__ == "__main__":
    app = App()
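# Run this file directly (e.g. `python app.py`; filename assumed) to start the
# demo. Gradio prints a local URL where the interface is served.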