onuralpszr's picture
fix: 🐛 adjust intro markdown text
47768b2 verified
raw
history blame
8.54 kB
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
from helpers.utils import create_directory, delete_directory, generate_unique_name
import os
# Directory (relative to the working directory) that holds generated videos.
VIDEO_TARGET_DIRECTORY = "tmp"

# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Supervision annotators, constructed once with their default settings and
# shared by every request.
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
BOX_ANNOTATOR = sv.BoxAnnotator()
# Markdown/HTML rendered at the top of the demo page.
# Blank lines separate the heading, the HTML badge row and each paragraph so
# that the Markdown following the closing </div> is parsed as Markdown instead
# of being treated as a continuation of the raw HTML block.
INTRO_TEXT = """
## PaliGemma 2 Detection with Supervision - Demo

<div style="display: flex; gap: 10px;">
<a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md">
<img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
</a>
<a href="https://huggingface.co/blog/paligemma">
<img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
</a>
<a href="https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab">
</a>
<a href="https://arxiv.org/abs/2412.03555">
<img src="https://img.shields.io/badge/arXiv-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
</a>
<a href="https://supervision.roboflow.com/">
<img src="https://img.shields.io/badge/Supervision-6706CE?style=flat&logo=Roboflow&logoColor=white" alt="Supervision">
</a>
</div>

PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
answering, text reading, object detection and object segmentation.

This Space shows how to use PaliGemma 2 for object detection with Supervision.
You can input an image or a video together with a text prompt.
"""
# Ensure the output directory exists before any request is served.
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

# Checkpoint identifier shared by the processor and the model.
model_id = "google/paligemma2-3b-pt-448"

processor = PaliGemmaProcessor.from_pretrained(model_id)

# Load the weights, switch to evaluation mode, then move to the chosen device.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model = model.eval().to(DEVICE)
@spaces.GPU
def paligemma_detection(input_image, input_text, max_new_tokens):
    """Run PaliGemma on a single image and return the decoded completion.

    Args:
        input_image: Image passed to the processor as ``images`` (the caller
            in this file passes a PIL image).
        input_text: Prompt passed to the processor as ``text``.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated text with the prompt tokens sliced off and special
        tokens skipped.
    """
    # Cast to the dtype the model was actually loaded with rather than a
    # hard-coded bfloat16: the model is created without an explicit dtype, so
    # assuming bfloat16 can leave inputs and weights in different precisions.
    model_inputs = processor(
        text=input_text,
        images=input_image,
        return_tensors="pt",
    ).to(model.dtype).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False) without autograd bookkeeping.
    with torch.inference_mode():
        generation = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # The output sequence starts with the prompt; keep only what was generated.
    generation = generation[0][input_len:]
    return processor.decode(generation, skip_special_tokens=True)
def annotate_image(result, resolution_wh, class_names, cv_image):
    """Draw the detections encoded in a PaliGemma result onto an image.

    Args:
        result: Raw text produced by the model, parsed by
            ``sv.Detections.from_lmm``.
        resolution_wh: ``(width, height)`` forwarded to the parser as the
            image resolution.
        class_names: Comma-separated class names. Whitespace around each name
            is ignored and empty entries are dropped.
        cv_image: Image array in BGR channel order; it is copied, not
            modified.

    Returns:
        A PIL image in RGB order with boxes, labels and masks drawn.
    """
    # Users naturally type "person, dog"; without stripping, " dog" would
    # never equal the label parsed from the model output and those detections
    # would be silently discarded.
    classes = [
        name.strip()
        for name in (class_names or "").split(",")
        if name.strip()
    ]

    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=classes,
    )

    # Apply the annotators in the same order as before: boxes, labels, masks.
    annotated_image = cv_image.copy()
    for annotator in (BOX_ANNOTATOR, LABEL_ANNOTATOR, MASK_ANNOTATOR):
        annotated_image = annotator.annotate(
            scene=annotated_image,
            detections=detections,
        )

    # PIL expects RGB while the working array is BGR.
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_image)
def process_image(input_image, input_text, class_names, max_new_tokens):
    """Detect objects in an uploaded image and return the annotated result.

    Args:
        input_image: PIL image from the UI, or ``None`` when nothing was
            uploaded.
        input_text: Prompt sent to the model.
        class_names: Comma-separated class names forwarded to
            ``annotate_image``.
        max_new_tokens: Generation length limit forwarded to
            ``paligemma_detection``.

    Returns:
        ``(annotated_image, result)`` where ``result`` is the raw model text,
        or ``(None, None)`` when a required input is missing.
    """
    # Mirror the validation done for videos: without these guards a missing
    # image reaches cv2.cvtColor and fails with an opaque error. Two values
    # are returned because the click handler is wired to two outputs.
    if input_image is None:
        gr.Info("Please upload an image.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # The annotators work on BGR arrays, the model on the original PIL image.
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    result = paligemma_detection(input_image, input_text, max_new_tokens)
    annotated_image = annotate_image(
        result,
        (input_image.width, input_image.height),
        class_names,
        cv_image,
    )
    return annotated_image, result
@spaces.GPU
def process_video(input_video, input_text, class_names, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        input_video: Path of the uploaded video; falsy when nothing was
            uploaded.
        input_text: Prompt sent to the model for every frame.
        class_names: Comma-separated class names. Whitespace around each name
            is ignored and empty entries are dropped.
        max_new_tokens: Upper bound on the number of tokens generated per
            frame.
        progress: Gradio progress tracker used to wrap the frame iterator.

    Returns:
        ``(video_path, results)`` where ``video_path`` is the annotated video
        under ``VIDEO_TARGET_DIRECTORY`` and ``results`` is the list of raw
        model outputs, one per frame; ``(None, None)`` when a required input
        is missing.
    """
    # The click handler is wired to two outputs, so every exit path returns
    # two values (a lone ``None`` would not provide one value per output).
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # Parsed once, outside the frame loop. Stripping matters because a name
    # such as " dog" (from "person, dog") never matches a decoded label.
    classes = [
        label.strip()
        for label in (class_names or "").split(",")
        if label.strip()
    ]

    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
    results = []

    # NOTE(review): nothing in this function writes into the per-request
    # directory; it is kept for compatibility and always removed in `finally`.
    create_directory(frame_directory_path)
    try:
        video_info = sv.VideoInfo.from_video_path(input_video)
        resolution_wh = (video_info.width, video_info.height)
        frame_generator = sv.get_video_frames_generator(input_video)

        with sv.VideoSink(video_path, video_info=video_info) as sink:
            for frame in progress.tqdm(frame_generator, desc="Processing video"):
                # Frames are converted from BGR to RGB for the model.
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

                # Cast to the dtype the model was loaded with instead of a
                # hard-coded bfloat16 so inputs and weights always agree.
                model_inputs = processor(
                    text=input_text,
                    images=pil_frame,
                    return_tensors="pt",
                ).to(model.dtype).to(model.device)
                input_len = model_inputs["input_ids"].shape[-1]

                with torch.inference_mode():
                    generation = model.generate(
                        **model_inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                    )
                # Drop the prompt tokens; keep only the generated part.
                generation = generation[0][input_len:]
                result = processor.decode(generation, skip_special_tokens=True)

                detections = sv.Detections.from_lmm(
                    sv.LMM.PALIGEMMA,
                    result,
                    resolution_wh=resolution_wh,
                    classes=classes,
                )

                # Same annotator order as for still images: boxes, labels, masks.
                annotated_frame = frame.copy()
                for annotator in (BOX_ANNOTATOR, LABEL_ANNOTATOR, MASK_ANNOTATOR):
                    annotated_frame = annotator.annotate(
                        scene=annotated_frame,
                        detections=detections,
                    )

                results.append(result)
                sink.write_frame(annotated_frame)
    finally:
        # Clean up even when decoding or inference raises part-way through.
        delete_directory(frame_directory_path)

    return video_path, results
# Gradio UI: one tab for still images and one for videos. Each tab has an
# input column (media, prompt, class names, token budget) and an output column
# (annotated media plus the raw model text).
with gr.Blocks() as app:
    gr.Markdown(INTRO_TEXT)

    with gr.Tab("Image Detection"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                # Label fixed: the example prompt previously had no closing quote.
                input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt, for example 'detect person;dog'")
                class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names")
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
            with gr.Column():
                annotated_image = gr.Image(type="pil", label="Annotated Image")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Submit").click(
            fn=process_image,
            inputs=[input_image, input_text, class_names, max_new_tokens],
            outputs=[annotated_image, detection_result]
        )

    with gr.Tab("Video Detection"):
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                # These names are intentionally re-bound to the video tab's
                # own widgets; the image tab's handler is already wired above.
                input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt, for example 'detect person;dog'")
                class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names")
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
            with gr.Column():
                output_video = gr.Video(label="Annotated Video")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Process Video").click(
            fn=process_video,
            inputs=[input_video, input_text, class_names, max_new_tokens],
            outputs=[output_video, detection_result]
        )

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch(ssr_mode=False)