# Credits to IDEA Research for the model:
# https://huggingface.co/IDEA-Research/grounding-dino-tiny
from base64 import b64decode
from io import BytesIO

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the zero-shot object detection model once at startup.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

@spaces.GPU  # request a GPU for this call (assumption: ZeroGPU Space, suggested by the otherwise unused `spaces` import)
def predict(image_b64: str, queries: str, box_threshold: float, text_threshold: float):
    # Decode the base64-encoded image and force RGB so the processor
    # does not fail on palette or RGBA inputs.
    decoded_img = b64decode(image_b64)
    image = Image.open(BytesIO(decoded_img)).convert("RGB")

    # Grounding DINO expects lowercase text queries separated by full stops.
    inputs = processor(images=image, text=queries, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert raw model outputs into thresholded detections in absolute pixel coordinates.
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image.size[::-1]]  # PIL gives (width, height); the processor wants (height, width)
    )
    # Cast tensors to plain Python types so the result is JSON-serialisable.
    fmt_results = {
        "scores": [float(s) for s in results[0]["scores"]],
        "labels": results[0]["labels"],
        "boxes": [[float(x) for x in box] for box in results[0]["boxes"]]
    }
    print(fmt_results)
    return fmt_results

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Text(label="Image (B64)"),
        gr.Text(label="Queries, in lowercase, separated by full stops", placeholder="a bird. a blue bird."),
        gr.Number(label="box_threshold", value=0.4),
        gr.Number(label="text_threshold", value=0.3)
    ],
    outputs=gr.JSON(label="Predictions"),
)

demo.launch()
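
# Example client call (a sketch, not part of the original app): it shows how a
# caller could base64-encode a local image and query this Space through the
# official `gradio_client` package. The Space id "Ermond/grounding-dino-tiny"
# and the file "bird.jpg" are hypothetical placeholders.
#
#   from base64 import b64encode
#   from gradio_client import Client
#
#   with open("bird.jpg", "rb") as f:
#       image_b64 = b64encode(f.read()).decode("utf-8")
#
#   client = Client("Ermond/grounding-dino-tiny")  # hypothetical Space id
#   detections = client.predict(image_b64, "a bird. a blue bird.", 0.4, 0.3)
#   print(detections)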