import gradio as gr
import numpy as np
import torch
from PIL import Image

from infer_model import CLIPpyModel
from utils import get_similarity, get_transform, ade_palette, get_cmap_image

# Download the pretrained CLIPpy checkpoint and load it with strict=False,
# so minor key mismatches between the checkpoint and the model are tolerated.
pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt)
clippy = CLIPpyModel()
transform = get_transform((224, 224))
msg = clippy.load_state_dict(ckpt, strict=False)
palette = ade_palette()


def process_image(img, captions):
    """Segment `img` according to the comma-separated class labels in `captions`."""
    sample_text = [x.strip() for x in captions.split(",")]
    sample_prompts = [f"a photo of a {x}" for x in sample_text]
    image = Image.fromarray(img)

    # Per-patch image embeddings and per-prompt text embeddings.
    image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
    text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)

    # Patch-to-prompt similarity, argmaxed into a (224, 224) map of class indices.
    similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)[0, 0].numpy()

    # Color each class region with its palette entry.
    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx, _ in enumerate(sample_text):
        rgb_seg[similarity == idx] = palette[idx]

    # Legend image, raw segmentation, and a 50/50 overlay of image and segmentation.
    joint = Image.blend(image, Image.fromarray(rgb_seg), 0.5)
    cmap = get_cmap_image({label: tuple(palette[idx]) for idx, label in enumerate(sample_text)})
    return cmap, rgb_seg, joint
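
# `get_similarity` is imported from the repo's `utils` module, which is not shown
# here. As a rough, hypothetical sketch of what a function with this call signature
# could compute (an assumption based on how it is called above, not the repo's
# actual implementation): cosine similarity between every patch token and every
# class prompt, reshaped to the patch grid, upsampled to the requested size, and
# optionally argmaxed over classes.
def _sketch_get_similarity(image_tokens, text_embeds, output_size, do_argmax=False):
    """Illustrative only. image_tokens: [B, N, D] patch tokens; text_embeds: [C, D]."""
    import torch.nn.functional as F  # local import: sketch-only helper

    image_tokens = F.normalize(image_tokens, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    sim = image_tokens @ text_embeds.T              # [B, N, C] cosine similarities
    b, n, c = sim.shape
    side = int(n ** 0.5)                            # assumes a square patch grid
    sim = sim.permute(0, 2, 1).reshape(b, c, side, side)
    sim = F.interpolate(sim, size=output_size, mode="bilinear", align_corners=False)
    if do_argmax:
        sim = sim.argmax(dim=1, keepdim=True)       # [B, 1, H, W] class-index map
    return sim
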

title = "CLIPpy"
description = """
Gradio Demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma-separated labels (e.g. "man, woman, background").
CLIPpy will segment the image according to the class labels you provide.
"""
article = """
Perceptual Grouping in Contrastive Vision Language Models |
<a href='https://github.com/kahnchana/clippy' target='_blank'>Github Repository</a>

""" demo = gr.Interface( fn=process_image, inputs=[gr.Image(shape=(224, 224)), "text"], outputs=[gr.Image(shape=(224, 224)).style(height=150), gr.Image(shape=(224, 224)).style(height=260), gr.Image(shape=(224, 224)).style(height=260)], title=title, description=description, article=article, ) demo.launch()