"""Streamlit demo app for zero-shot, text-conditioned object detection with OWL-ViT."""
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import uuid
import os
import io

# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # pin inference to the second GPU
warnings.filterwarnings('ignore')
st.set_page_config()


class OwlVit:
    """Runs OWL-ViT on a single image and renders the detections."""

    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text            # list of text queries, e.g. ["batter", "umpire"]
        self.threshold = threshold  # minimum confidence for a box to be drawn

    def process(self, processor, model):
        image = Image.open(self.image_path)
        if image.mode != "RGB":  # handles grayscale, palette, and RGBA inputs
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # target_sizes rescales the predicted boxes to the original image resolution
        target_sizes = torch.tensor([[image.height, image.width]])
        # post_process_object_detection supersedes the deprecated post_process;
        # threshold=0.0 keeps every box so result_image() can filter by self.threshold
        self.results = processor.post_process_object_detection(
            outputs=outputs, threshold=0.0, target_sizes=target_sizes)
        self.image = image
        return self.result_image()

    def result_image(self):
        boxes = self.results[0]["boxes"]
        scores = self.results[0]["scores"]
        labels = self.results[0]["labels"]
        fig = plt.figure()  # fresh figure so Streamlit reruns don't stack patches
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.detach().numpy()
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1],
                                           fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}",
                        fontsize=15, color=color)
        plt.tight_layout()
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close(fig)
        return Image.open(img_buf)


@st.cache_resource  # load the model once instead of on every Streamlit rerun
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model


def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)


def process(upload, text, threshold):
    # save upload to file (a uuid name avoids collisions once cleanup starts deleting files)
    filetype = upload.name.split('.')[-1]
    os.makedirs('images', exist_ok=True)
    file_path = os.path.join('images', f'{uuid.uuid4().hex}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())

    # predict detections and show results
    detector = OwlVit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)

    # clean up - if over 1000 images in the folder, delete the oldest one
    # (getctime needs the full path, not the bare filename from listdir)
    if len(os.listdir('images')) > 1000:
        oldest = min(os.listdir('images'),
                     key=lambda f: os.path.getctime(os.path.join('images', f)))
        os.remove(os.path.join('images', oldest))


def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)

    # title and project description
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot, text-conditioned object detection model. It uses CLIP as its \
        multi-modal backbone, with a ViT-like Transformer to extract visual features and a causal language \
        model to embed the text queries. To use CLIP for detection, OWL-ViT removes the final token pooling \
        layer of the vision model and attaches a lightweight classification head and box head to each \
        transformer output token. Open-vocabulary classification is enabled by replacing the fixed \
        classification-layer weights with the class-name embeddings obtained from the text model. The \
        authors first train CLIP from scratch, then fine-tune it end-to-end with the classification and \
        box heads on standard detection datasets using a bipartite matching loss. \
        One or more text queries per image can be used to perform zero-shot, text-conditioned object \
        detection.", unsafe_allow_html=True)

    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = OwlVit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")  # any button press reruns the script, which clears the example output

    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma, separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',')]

    # process
    if upload is not None and text is not None:
        filetype = upload.name.split('.')[-1].lower()
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting Objects in Uploaded Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')


if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)
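
# ---------------------------------------------------------------------------
# Usage sketch (a minimal example, assuming this file is saved as app.py and
# that the refs/ directory holds the splash and example images referenced
# above). Launch the app with:
#
#   streamlit run app.py
#
# The detector can also be driven without the UI; the snippet below is an
# illustration of the class's API, not part of the app itself:
#
#   processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
#   model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
#   detector = OwlVit("refs/baseball.jpg", ["batter", "umpire", "catcher"], 0.5)
#   annotated = detector.process(processor, model)  # returns a PIL.Image
#   annotated.save("detections.png")
# ---------------------------------------------------------------------------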