import gradio as gr
import numpy as np
import torch
from PIL import Image

from infer_model import CLIPpyModel
from utils import get_similarity, get_transform, ade_palette, get_cmap_image

# Download the pretrained CLIPpy checkpoint (mapped to CPU so the demo runs without a GPU).
pretrained_ckpt = "https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt"
ckpt = torch.utils.model_zoo.load_url(pretrained_ckpt, map_location="cpu")

clippy = CLIPpyModel()
transform = get_transform((224, 224))

# strict=False tolerates key mismatches; `msg` records any missing/unexpected keys.
msg = clippy.load_state_dict(ckpt, strict=False)
clippy.eval()  # evaluation mode for inference

# Fixed RGB palette (ADE20K colours), one entry per class index.
palette = ade_palette()


def process_image(img, captions):
    # Parse the comma-separated labels and wrap each one in a prompt template.
    sample_text = [x.strip() for x in captions.split(",")]
    sample_prompts = [f"a photo of a {x}" for x in sample_text]

    image = Image.fromarray(img)
    with torch.no_grad():  # inference only; no gradients needed
        image_vector = clippy.encode_image(transform(image).unsqueeze(0), get_pos_tokens=True)
        text_vector = clippy.text.encode(sample_prompts, convert_to_tensor=True)
        # Per-pixel argmax over the label similarities yields a (224, 224) index map.
        similarity = get_similarity(image_vector, text_vector, (224, 224), do_argmax=True)[0, 0].numpy()

    # Colour each predicted class with its palette entry.
    rgb_seg = np.zeros((similarity.shape[0], similarity.shape[1], 3), dtype=np.uint8)
    for idx, _ in enumerate(sample_text):
        rgb_seg[similarity == idx] = palette[idx]

    # Overlay the segmentation on the input and build a label-to-colour legend.
    joint = Image.blend(image, Image.fromarray(rgb_seg), 0.5)
    cmap = get_cmap_image({label: tuple(palette[idx]) for idx, label in enumerate(sample_text)})

    return cmap, rgb_seg, joint
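

# Quick local sanity check (a sketch; "example.jpg" is a hypothetical file,
# resized here to the 224x224 input the demo expects):
#   img = np.array(Image.open("example.jpg").convert("RGB").resize((224, 224)))
#   legend, seg, overlay = process_image(img, "cat, grass, sky")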


title = 'CLIPpy'

description = """
Gradio demo for CLIPpy: Perceptual Grouping in Contrastive Vision Language Models. \n \n
Upload an image and type in a set of comma-separated labels (e.g. "man, woman, background").
CLIPpy will segment the image according to the set of class labels you provide.
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2210.09996' target='_blank'>
Perceptual Grouping in Contrastive Vision Language Models
</a>
|
<a href='https://github.com/kahnchana/clippy' target='_blank'>Github Repository</a></p>
"""

demo = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(shape=(224, 224)), "text"],
    outputs=[
        gr.Image(shape=(224, 224)).style(height=150),
        gr.Image(shape=(224, 224)).style(height=260),
        gr.Image(shape=(224, 224)).style(height=260),
    ],
    title=title,
    description=description,
    article=article,
)

demo.launch()
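
# Note: demo.launch(share=True) would additionally create a temporary public
# link (a standard Gradio option), handy when running from a notebook.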