|
import argparse |
|
import io |
|
import logging |
|
import os |
|
|
|
import gradio as gr |
|
import requests |
|
from PIL import Image |
|
from pillow_heif import register_heif_opener |
|
from transformers import pipeline |
|
|
|
# Turn off Gradio usage analytics unless the environment already sets the variable.
# NOTE(review): gradio is imported above this line, so this relies on Gradio reading
# the variable when the app is built/launched rather than at import time — confirm
# for the pinned Gradio version.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")

# Logging level name (e.g. "DEBUG", "info"). Normalized to upper case because
# logging.basicConfig() only recognizes upper-case level names and would raise
# ValueError for "info" or " debug".
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").strip().upper()

# Maximum number of tokens the captioning pipeline may generate per image.
# The default is a string so both branches of getenv() go through int() alike.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "200"))

# Model identifier handed to transformers.pipeline() at startup.
MODEL = os.getenv("MODEL", "Salesforce/blip-image-captioning-large")
|
|
|
# Per-module logger; its records propagate to the root logger configured below.
logger = logging.getLogger(__name__)

# Attach a default handler to the root logger at the configured LOG_LEVEL.
logging.basicConfig(level=LOG_LEVEL)

# Register pillow_heif's opener so that PIL's Image.open() can also decode
# HEIC/HEIF files (both uploaded and downloaded images go through PIL).
register_heif_opener()
|
|
|
|
|
def setup_args():
    """Parse the process command line and return the resulting namespace.

    Recognized flags:
        --share  Boolean switch (False unless given); the entry point forwards
                 it to Gradio's ``launch(share=...)``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", default=False, action="store_true")
    options = cli.parse_args()
    return options
|
|
|
|
|
def load_image_from_url(url, timeout=30):
    """Download ``url`` and decode the response body as a PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait on the server, forwarded to ``requests.get``.
            Without a timeout a stalled server would block the call forever.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        Exception: If the server answers with a non-success status code.
        requests.RequestException: On connection errors or timeouts.
        PIL.UnidentifiedImageError / OSError: If the body is not a decodable image.
        Every failure is logged before being re-raised.
    """
    # NOTE(review): callers may pass user-supplied URLs (e.g. from the web UI);
    # this makes the server fetch arbitrary addresses, including internal ones.
    # Consider a host allow-list if the app is exposed publicly.
    try:
        response = requests.get(url, timeout=timeout)
        if not response.ok:
            raise Exception("Error downloading image")

        image = Image.open(io.BytesIO(response.content))
        # Image.open() is lazy: force decoding here so corrupt data fails inside
        # this try (and gets logged) instead of later in the captioning pipeline.
        image.load()
        return image
    except Exception as e:
        logger.error("Error loading image from URL: %s", e)
        raise
|
|
|
|
|
def graptioner(image, url):
    """Caption an uploaded image, or the image at ``url`` when one is given.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was uploaded.
        url: Text from the URL box; when it is non-blank it takes precedence
            over ``image`` and the image is downloaded instead.

    Returns:
        The ``generated_text`` of the first result from the module-level
        ``captioner`` pipeline.

    Raises:
        ValueError: If neither an image nor a URL was supplied, or the image
            has a non-positive width or height.
    """
    if url and url.strip():
        # A non-blank URL overrides any upload; strip stray whitespace from pasting.
        image = load_image_from_url(url.strip())

    if image is None:
        # Previously this surfaced as an opaque AttributeError on ``image.size``.
        raise ValueError("No image provided: upload an image or enter an image URL")

    width, height = image.size
    if width < 1 or height < 1:
        raise ValueError("Invalid image")
    logger.debug("Loaded image size: %sx%s", width, height)

    # ``captioner`` is a module-level global created in the ``__main__`` block.
    result = captioner(image)
    return result[0]["generated_text"]
|
|
|
|
|
if __name__ == "__main__":
    cli_args = setup_args()

    # Build the captioning pipeline once at startup. graptioner() looks it up
    # through the module-level name ``captioner``, so that name must not change.
    logger.info("Loading model...")
    captioner = pipeline("image-to-text", model=MODEL, max_new_tokens=MAX_NEW_TOKENS)
    logger.info("Done loading model.")

    # Two inputs, passed to graptioner() positionally as (image, url).
    image_input = gr.Image(type="pil", label="Upload Image")
    url_input = gr.Textbox(lines=1, placeholder="Image URL", label="Image URL")

    demo = gr.Interface(
        fn=graptioner,
        inputs=[image_input, url_input],
        outputs=["text"],
        allow_flagging="never",
    )
    demo.launch(share=cli_args.share)
|
|