Spaces:
Runtime error
Runtime error
from PIL import Image | |
import requests | |
import torch | |
from torchvision import transforms | |
from torchvision.transforms.functional import InterpolationMode | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
import numpy as np | |
from transformers import pipeline | |
import gradio as gr | |
from models.blip import blip_decoder | |
# Side length, in pixels, of the square image handed to the captioning model.
image_size = 384

# Preprocessing applied to every input image: bicubic resize to a square of
# image_size pixels, conversion to a tensor, then per-channel normalization.
_resize = transforms.Resize(
    (image_size, image_size),
    interpolation=InterpolationMode.BICUBIC,
)
_normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073),  # per-channel mean
    (0.26862954, 0.26130258, 0.27577711),  # per-channel std
)
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

# Pretrained large BLIP captioning checkpoint, fetched by blip_decoder.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'

# Build the captioning model at the same resolution the transform produces,
# switch it to evaluation mode, and move it to the selected device.
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='large')
model.eval()
model = model.to(device)
# Maps each supported UI language to the Helsinki-NLP English->target
# translation checkpoint. None means the caption is already in that language
# and no translation step is needed.
_TRANSLATION_MODELS = {
    'English': None,
    'German': "Helsinki-NLP/opus-mt-en-de",
    'French': "Helsinki-NLP/opus-mt-en-fr",
    'Spanish': "Helsinki-NLP/opus-mt-en-es",
    'Chinese': "Helsinki-NLP/opus-mt-en-zh",
    # 'Ukranian' (sic) is kept because existing callers pass this spelling;
    # the correctly spelled name is accepted as an alias.
    'Ukranian': "Helsinki-NLP/opus-mt-en-uk",
    'Ukrainian': "Helsinki-NLP/opus-mt-en-uk",
    'Swedish': "Helsinki-NLP/opus-mt-en-sv",
    'Arabic': "Helsinki-NLP/opus-mt-en-ar",
    'Italian': "Helsinki-NLP/opus-mt-en-it",
    'Hindi': "Helsinki-NLP/opus-mt-en-hi",
}


def getModelPath(language):
    """Return the translation checkpoint name for *language*.

    Args:
        language: Display name of the target language, e.g. ``'German'``.

    Returns:
        The model identifier string for the English->*language* translator,
        or ``None`` when no translation is required (``'English'``) or the
        language is not recognised. Previously an unrecognised language
        raised ``UnboundLocalError``; it now falls back to ``None`` so the
        caller can return the untranslated caption.
    """
    return _TRANSLATION_MODELS.get(language)
# Translation pipelines already built, keyed by checkpoint name. Building a
# pipeline loads the whole translation model, so each one is created at most
# once per process instead of on every request. The cache is bounded by the
# number of distinct checkpoints requested; each entry keeps its model in
# memory for the life of the process.
_translator_cache = {}


def _get_translator(model_path):
    """Return the cached translation pipeline for *model_path*, building it on first use."""
    if model_path not in _translator_cache:
        _translator_cache[model_path] = pipeline("translation", model=model_path)
    return _translator_cache[model_path]


def inference(input_img, strategy, language):
    """Caption an image and optionally translate the caption.

    Args:
        input_img: The image to caption, in a form accepted by the
            module-level ``transform`` (the UI supplies a PIL image).
        strategy: ``"Beam search"`` selects beam-search decoding with three
            beams; any other value selects nucleus sampling with top_p=0.9.
        language: Target language name, resolved through ``getModelPath``.

    Returns:
        The caption as a string, translated when ``getModelPath`` yields a
        translation checkpoint and left as generated otherwise.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        # Fail with a clear message instead of an opaque error from inside
        # the preprocessing transform.
        raise ValueError("No input image was provided.")

    image = transform(input_img).unsqueeze(0).to(device)

    # Gradients are never needed for caption generation.
    with torch.no_grad():
        if strategy == "Beam search":
            cap = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
        else:
            cap = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)

    # cap is indexed as a sequence; its first element is used as the caption
    # (presumably a list of caption strings - confirm against model.generate).
    caption = cap[0]

    modelpath = getModelPath(language)
    if not modelpath:
        return str(caption)

    translated = _get_translator(modelpath)(caption)
    return str(translated[0]['translation_text'])
print("HI")

# Text shown on the demo page.
description = (
    "A pipeline of BLIP image captioning and Helsinki translation in order to "
    "generate image captions in a language of your choice either with beam "
    "search (deterministic) or nucleus sampling (stochastic). Enjoy! Is the "
    "language you want to use missing? Let me know and I'll integrate it."
)

# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces; newer Gradio releases may not provide them - confirm against the
# pinned Gradio version before upgrading.
_image_input = gr.inputs.Image(type='pil', label="Input Image")
_strategy_input = gr.inputs.Radio(
    choices=['Beam search', 'Nucleus sampling'],
    type="value",
    default="Nucleus sampling",
    label="Mode",
)
_language_input = gr.inputs.Radio(
    choices=[
        'English',
        'German',
        'French',
        'Spanish',
        'Chinese',
        'Ukranian',
        'Swedish',
        'Arabic',
        'Italian',
        'Hindi',
    ],
    type="value",
    default='German',
    label="Language",
)

# Widget order must match the positional parameters of inference():
# image, decoding strategy, target language.
inputs_ = [_image_input, _strategy_input, _language_input]
outputs_ = gr.outputs.Textbox(label="Output")

iface = gr.Interface(inference, inputs_, outputs_, description=description)
iface.launch(debug=True, show_error=True)