Megatron17 committed • 7ac78dd
Parent(s): 5c630b6
app.py ADDED
@@ -0,0 +1,91 @@
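"""Gradio Space: detect objects with DETR, crop them, and return the region that
best matches a text query according to CLIP; a second tab does CLIP zero-shot
image classification."""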
import gradio as gr
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoFeatureExtractor, AutoModelForObjectDetection
import torch

# DETR provides the candidate object boxes to crop.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
dmodel = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# CLIP ranks the crops against the text query.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def extract_image(image, text, prob, num=1):
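    """Run DETR on `image`, crop every detection whose confidence exceeds `prob`,
    score the crops against `text` with CLIP, and return the best crop together
    with its CLIP probability. `num` controls how many top CLIP matches are kept
    before the final sort."""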
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = dmodel(**inputs)

    # DETR predicts bounding boxes and corresponding COCO classes.
    logits = outputs.logits
    bboxes = outputs.pred_boxes
    probas = outputs.logits.softmax(-1)[0, :, :-1]  # drop the trailing "no object" class

    # Keep detections above the confidence threshold and rescale the boxes
    # from normalized coordinates to the original image size.
    keep = probas.max(-1).values > prob
    outs = feature_extractor.post_process(outputs, torch.tensor(image.size[::-1]).unsqueeze(0))
    bboxes_scaled = outs[0]['boxes'][keep].detach().numpy()
    labels = outs[0]['labels'][keep].detach().numpy()
    scores = outs[0]['scores'][keep].detach().numpy()

    # Crop each kept detection out of the original image.
    images_list = []
    im_arr = np.array(image)
    for i, j in enumerate(bboxes_scaled):
        xmin = int(j[0])
        ymin = int(j[1])
        xmax = int(j[2])
        ymax = int(j[3])

        roi = im_arr[ymin:ymax, xmin:xmax]
        roi_im = Image.fromarray(roi)

        images_list.append(roi_im)

    # Score every crop against the text query with CLIP.
    inpu = processor(text=[text], images=images_list, return_tensors="pt", padding=True)
    output = model(**inpu)
    logits_per_text = output.logits_per_text  # shape: (1, num_crops)
    probs = logits_per_text.softmax(-1)
    # Indices of the `num` best-matching crops, highest similarity first.
    l_idx = np.argsort(probs[-1].detach().numpy())[::-1][0:num]

    # Collect the selected crops with their CLIP probabilities.
    final_ims = []
    for i, j in enumerate(images_list):
        if i in l_idx:
            json_dict = {}
            json_dict['image'] = images_list[i]
            json_dict['score'] = probs[-1].detach().numpy()[i]
            final_ims.append(json_dict)

    # Return the highest-scoring crop and its score.
    fi = sorted(final_ims, key=lambda item: item.get("score"), reverse=True)
    return fi[0]['image'], fi[0]['score']

def zero_shot_classification(image, labels):
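    """CLIP zero-shot classification: score `image` against the comma-separated
    class names in `labels` and return a {class: probability} dict for gr.Label."""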
    # One "a photo of a {class}" prompt per comma-separated label.
    labels = [c.strip() for c in labels.split(',')]
    text = [f"a photo of a {c}" for c in labels]
    inpu = processor(text=text, images=image, return_tensors="pt", padding=True)
    output = model(**inpu)
    logits_per_image = output.logits_per_image  # shape: (1, num_classes)
    probs = logits_per_image.softmax(dim=1)
    return {k: float(v) for k, v in zip(labels, probs[0])}

with gr.Blocks() as demo:
    with gr.Tab("Clip and Crop"):
        i1 = gr.Image(type="pil", label="Input image")
        i2 = gr.Textbox(label="Input text")
        i3 = gr.Number(value=0.96, label="Threshold percentage score")
        o1 = gr.Image(type="pil", label="Cropped part")
        o2 = gr.Textbox(label="Similarity score")
        title = "Clipping and Cropping"
        description = "<p style='color:white'>Extract the section of an image that matches your text query, using OpenAI's CLIP and Facebook's DETR from Hugging Face Transformers. If the similarity score is low, consider the prediction void.</p>"
        examples = [['ex3.jpg', 'black bag', 0.96], ['ex2.jpg', 'man in red dress', 0.85]]
        article = "<p style='color:white; text-align:center;'><a href='https://github.com/Vishnunkumar/clipcrop' target='_blank'>clipcrop</a></p>"
        gr.Interface(fn=extract_image, inputs=[i1, i2, i3], outputs=[o1, o2], title=title, description=description, article=article, examples=examples, enable_queue=True)
    with gr.Tab("Zero Shot Image Classification"):
        i1 = gr.Image(label="Image to classify", type="pil")
        i2 = gr.Textbox(lines=1, label="Comma separated classes", placeholder="Enter your classes separated by ','")
        title = "Zero Shot Image Classification"
        description = "<p style='color:white'>Use CLIP's pretrained image and text embeddings to pick, from the given list, the class the image is closest to.</p>"
        examples = [['ex3.jpg', 'black bag', 0.96], ['ex2.jpg', 'man in red dress', 0.85]]
        gr.Interface(fn=zero_shot_classification, inputs=[i1, i2], outputs="label", title=title, description=description)

demo.launch(debug=False)
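For a quick local check without the web UI, the two functions can be called directly; a minimal, hypothetical sketch (assumes the definitions above are already loaded in a session, e.g. before demo.launch() is reached, and that an example image such as ex3.jpg is present in the Space):

img = Image.open("ex3.jpg")
crop, score = extract_image(img, "black bag", prob=0.96)   # best-matching crop and its CLIP probability
crop.save("cropped.jpg")
print("similarity:", float(score))
print(zero_shot_classification(img, "bag, person, car"))    # {class: probability}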