#!/usr/bin/env python
from __future__ import annotations

import os
from typing import Union

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

DESCRIPTION = "# Image Captioning with LongCap"

# Prefer a GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the LongCap BLIP captioning model and its paired processor.
model_id = "unography/blip-long-cap"
processor = AutoProcessor.from_pretrained(model_id)
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
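
# Fetch the sample images used in the Examples gallery.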
torch.hub.download_url_to_file("https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg", "demo.jpg")
torch.hub.download_url_to_file(
    "https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png", "stop_sign.png"
)
torch.hub.download_url_to_file(
    "https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg", "astronaut.jpg"
)
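
# On Hugging Face ZeroGPU Spaces, `spaces.GPU` allocates a GPU for the
# duration of each call; outside of Spaces the decorator is a no-op.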
@spaces.GPU()
def run(image: Union[str, PIL.Image.Image]) -> str:
    if isinstance(image, str):
        image = PIL.Image.open(image)
    inputs = processor(images=image, return_tensors="pt").to(device)
    # Beam search with a repetition penalty keeps long captions coherent.
    out = model.generate(pixel_values=inputs.pixel_values, num_beams=3, repetition_penalty=2.5, max_length=300)
    generated_caption = processor.decode(out[0], skip_special_tokens=True)
    return generated_caption

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    input_image = gr.Image(type="pil")
    run_button = gr.Button("Caption")
    output = gr.Textbox(label="Result")
    gr.Examples(
        examples=[
            "demo.jpg",
            "stop_sign.png",
            "astronaut.jpg",
        ],
        inputs=input_image,
        outputs=output,
        fn=run,
        # Cache example outputs only when explicitly enabled via the environment.
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=output,
        api_name="caption",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
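
# Usage sketch (not part of the app): querying the "caption" endpoint from a
# client. This assumes the app is running locally on Gradio's default port and
# a recent `gradio_client` that provides `handle_file`:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   caption = client.predict(handle_file("demo.jpg"), api_name="/caption")
#   print(caption)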