from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import supervision as sv
import cv2
import numpy as np
from PIL import Image
import gradio as gr
import spaces
from helpers.file_utils import create_directory, delete_directory, generate_unique_name
from helpers.segment_utils import parse_segmentation, extract_objs
import os
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VIDEO_TARGET_DIRECTORY = "tmp"
VAE_MODEL = "vae-oid.npz"
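# "vae-oid.npz" is the mask-decoder checkpoint shipped with the big_vision PaliGemma demo;
# it is presumably loaded by helpers.segment_utils to decode the model's <seg###> tokens
# into binary masks.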
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
INTRO_TEXT = """ | |
## PaliGemma 2 Detection/Segmentation with Supervision - Demo | |
<div style="display: flex; gap: 10px;"> | |
<a href="https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md"> | |
<img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github"> | |
</a> | |
<a href="https://huggingface.co/blog/paligemma"> | |
<img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface"> | |
</a> | |
<a href="https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb"> | |
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab"> | |
</a> | |
<a href="https://arxiv.org/abs/2412.03555"> | |
<img src="https://img.shields.io/badge/Arvix-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper"> | |
</a> | |
<a href="https://supervision.roboflow.com/"> | |
<img src="https://img.shields.io/badge/Supervision-6706CE?style=flat&logo=Roboflow&logoColor=white" alt="Supervision"> | |
</a> | |
</div> | |
PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and | |
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) | |
vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile | |
model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question | |
answering, text reading, object detection and object segmentation. | |
This space show how to use PaliGemma 2 for object detection with supervision. | |
You can input an image and a text prompt | |
""" | |
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

model_id = "google/paligemma2-3b-pt-448"
# Load the model in bfloat16 so its dtype matches the bfloat16 inputs prepared below.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
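# The "pt" checkpoints are the pretrained (not mix/fine-tuned) PaliGemma 2 weights at
# 448x448 input resolution; they respond to the "detect ..." and "segment ..." prompt
# prefixes used throughout this demo.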
def parse_class_names(prompt):
    if not prompt.lower().startswith('detect '):
        return []
    classes_text = prompt[7:].strip()
    return [cls.strip() for cls in classes_text.split(';') if cls.strip()]
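# Example: parse_class_names("detect person;dog;building") -> ["person", "dog", "building"]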
def parse_prompt_type(prompt):
    """Determine if the prompt is for detection or segmentation."""
    if prompt.lower().startswith('detect '):
        return 'detection', prompt[7:].strip()
    elif prompt.lower().startswith('segment '):
        return 'segmentation', prompt[8:].strip()
    return None, prompt
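# Example: parse_prompt_type("segment person;car") -> ("segmentation", "person;car")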
def paligemma_detection(input_image, input_text, max_new_tokens):
    model_inputs = processor(
        text=input_text,
        images=input_image,
        return_tensors="pt"
    ).to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)
    return result
def annotate_image(result, resolution_wh, prompt, cv_image):
    class_names = parse_class_names(prompt)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return cv_image
    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=resolution_wh,
        classes=class_names
    )
    annotated_image = BOX_ANNOTATOR.annotate(
        scene=cv_image.copy(),
        detections=detections
    )
    annotated_image = LABEL_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = MASK_ANNOTATOR.annotate(
        scene=annotated_image,
        detections=detections
    )
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image
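# The Space runs on ZeroGPU (hence the `spaces` import above); the decorator below is
# assumed to be how a GPU is requested for the duration of each call.
@spaces.GPU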
def process_image(input_image, input_text, max_new_tokens):
    cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    prompt_type, cleaned_prompt = parse_prompt_type(input_text)
    if prompt_type == 'detection':
        # Detection: parse the model output into supervision Detections and annotate
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        class_names = [cls.strip() for cls in cleaned_prompt.split(';') if cls.strip()]
        detections = sv.Detections.from_lmm(
            sv.LMM.PALIGEMMA,
            result,
            resolution_wh=(input_image.width, input_image.height),
            classes=class_names
        )
        annotated_image = BOX_ANNOTATOR.annotate(scene=cv_image.copy(), detections=detections)
        annotated_image = LABEL_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
        annotated_image = MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)
    elif prompt_type == 'segmentation':
        # Segmentation: decode masks from the model output and overlay them manually
        result = paligemma_detection(input_image, input_text, max_new_tokens)
        objs = extract_objs(result.lstrip("\n"), input_image.width, input_image.height, unique_labels=True)
        annotated_image = cv_image.copy()
        for obj in objs:
            # Pick a per-class color (the hex values are RGB; OpenCV images are BGR)
            color_idx = hash(obj['name']) % len(COLORS)
            rgb = tuple(int(COLORS[color_idx].lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))
            color = rgb[::-1]
            if 'mask' in obj and obj['mask'] is not None:
                mask = obj['mask']
                # Create a colored mask and blend it with the image
                colored_mask = np.zeros_like(cv_image)
                colored_mask[mask > 0] = color
                alpha = 0.5
                annotated_image = cv2.addWeighted(annotated_image, 1, colored_mask, alpha, 0)
            # Add the class label above the bounding box
            if 'xyxy' in obj:
                x1, y1, x2, y2 = map(int, obj['xyxy'])
                cv2.putText(annotated_image, obj['name'], (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    else:
        gr.Warning("Invalid prompt format. Please use 'detect' or 'segment' followed by class names")
        return input_image, "Invalid prompt format"
    # Convert back to RGB for display
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image = Image.fromarray(annotated_image)
    return annotated_image, result
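# ZeroGPU: same assumption as above, a GPU is requested only while the video is processed.
@spaces.GPU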
def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
    if not input_video:
        gr.Info("Please upload a video.")
        return None, None
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None
    class_names = parse_class_names(input_text)
    if not class_names:
        gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
        return None, None
    name = generate_unique_name()
    frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
    create_directory(frame_directory_path)
    video_info = sv.VideoInfo.from_video_path(input_video)
    frame_generator = sv.get_video_frames_generator(input_video)
    video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
    results = []
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for frame in progress.tqdm(frame_generator, desc="Processing video"):
            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            model_inputs = processor(
                text=input_text,
                images=pil_frame,
                return_tensors="pt"
            ).to(torch.bfloat16).to(model.device)
            input_len = model_inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=False)
                generation = generation[0][input_len:]
                result = processor.decode(generation, skip_special_tokens=True)
            detections = sv.Detections.from_lmm(
                sv.LMM.PALIGEMMA,
                result,
                resolution_wh=(video_info.width, video_info.height),
                classes=class_names
            )
            annotated_frame = BOX_ANNOTATOR.annotate(
                scene=frame.copy(),
                detections=detections
            )
            annotated_frame = LABEL_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=annotated_frame,
                detections=detections
            )
            results.append(result)
            sink.write_frame(annotated_frame)
    delete_directory(frame_directory_path)
    return video_path, "\n".join(results)
with gr.Blocks() as app:
    gr.Markdown(INTRO_TEXT)

    with gr.Tab("Image Detection/Segmentation"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="e.g. 'detect person;dog;building' or 'segment person;dog;building'",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Increase for longer generations.")
            with gr.Column():
                annotated_image = gr.Image(type="pil", label="Annotated Image")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Submit").click(
            fn=process_image,
            inputs=[input_image, input_text, max_new_tokens],
            outputs=[annotated_image, detection_result]
        )

    with gr.Tab("Video Detection"):
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                input_text = gr.Textbox(
                    lines=2,
                    placeholder="e.g. 'detect person;dog;building' (videos support detection only)",
                    label="Enter detection prompt"
                )
                max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Increase for longer generations.")
            with gr.Column():
                output_video = gr.Video(label="Annotated Video")
                detection_result = gr.Textbox(label="Detection Result")
        gr.Button("Process Video").click(
            fn=process_video,
            inputs=[input_video, input_text, max_new_tokens],
            outputs=[output_video, detection_result]
        )

if __name__ == "__main__":
    app.launch(ssr_mode=False)