Megatron17 commited on
Commit
7ac78dd
1 Parent(s): 5c630b6
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection, AutoFeatureExtractor, AutoModelForObjectDetection
5
+ import torch
6
+
7
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
8
+ dmodel = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50")
9
+
10
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
11
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
+
13
+ def extract_image(image, text, prob, num=1):
14
+
15
+ inputs = feature_extractor(images=image, return_tensors="pt")
16
+ outputs = dmodel(**inputs)
17
+
18
+ # model predicts bounding boxes and corresponding COCO classes
19
+ logits = outputs.logits
20
+ bboxes = outputs.pred_boxes
21
+ probas = outputs.logits.softmax(-1)[0, :, :-1] #removing no class as detr maps
22
+
23
+ keep = probas.max(-1).values > prob
24
+ outs = feature_extractor.post_process(outputs, torch.tensor(image.size[::-1]).unsqueeze(0))
25
+ bboxes_scaled = outs[0]['boxes'][keep].detach().numpy()
26
+ labels = outs[0]['labels'][keep].detach().numpy()
27
+ scores = outs[0]['scores'][keep].detach().numpy()
28
+
29
+ images_list = []
30
+ for i,j in enumerate(bboxes_scaled):
31
+
32
+ xmin = int(j[0])
33
+ ymin = int(j[1])
34
+ xmax = int(j[2])
35
+ ymax = int(j[3])
36
+
37
+ im_arr = np.array(image)
38
+ roi = im_arr[ymin:ymax, xmin:xmax]
39
+ roi_im = Image.fromarray(roi)
40
+
41
+ images_list.append(roi_im)
42
+
43
+ inpu = processor(text = [text], images=images_list , return_tensors="pt", padding=True)
44
+ output = model(**inpu)
45
+ logits_per_image = output.logits_per_text
46
+ # print("Logits:",logits_per_image)
47
+ probs = logits_per_image.softmax(-1)
48
+ # print("Probability:",probs)
49
+ l_idx = np.argsort(probs[-1].detach().numpy())[::-1][0:num]
50
+ # print("Index:",l_idx)
51
+
52
+ final_ims = []
53
+ for i,j in enumerate(images_list):
54
+ json_dict = {}
55
+ if i in l_idx:
56
+ json_dict['image'] = images_list[i]
57
+ json_dict['score'] = probs[-1].detach().numpy()[i]
58
+
59
+ final_ims.append(json_dict)
60
+
61
+ fi = sorted(final_ims, key=lambda item: item.get("score"), reverse=True)
62
+ return fi[0]['image'], fi[0]['score']
63
+ def zero_shot_classification(image, labels):
64
+ labels = labels.split(',')
65
+ text = [f"a photo of a {c}" for c in labels]
66
+ inpu = processor(text = text, images=image , return_tensors="pt", padding=True)
67
+ output = model(**inpu)
68
+ logits_per_image = output.logits_per_image
69
+ probs = logits_per_image.softmax(dim=1)
70
+ return {k: float(v) for k, v in zip(labels, probs[0])}
71
+
72
+ with gr.Blocks() as demo:
73
+ with gr.Tab("Clip and Crop"):
74
+ i1 = gr.Image(type="pil", label="Input image")
75
+ i2 = gr.Textbox(label="Input text")
76
+ i3 = gr.Number(default=0.96, label="Threshold percentage score")
77
+ o1 = gr.Image(type="pil", label="Cropped part")
78
+ o2 = gr.Textbox(label="Similarity score")
79
+ title = "Cliping and Cropping"
80
+ description = "<p style= 'color:white'>Extract sections of images from your image by using OpenAI's CLIP and Facebooks Detr implemented on HuggingFace Transformers, if the similarity score is not so much, then please consider the prediction to be void.</p>"
81
+ examples=[['ex3.jpg', 'black bag', 0.96],['ex2.jpg', 'man in red dress', 0.85]]
82
+ article = "<p style= 'color:white; text-align:center;'><a href='https://github.com/Vishnunkumar/clipcrop' target='_blank'>clipcrop</a></p>"
83
+ gr.Interface(fn=extract_image, inputs=[i1, i2, i3], outputs=[o1, o2], title=title, description=description, article=article, examples=examples, enable_queue=True)
84
+ with gr.Tab("Zero Shot Image Classification"):
85
+ i1 = gr.Image(label="Image to classify.", type="pil")
86
+ i2 = gr.Textbox(lines=1, label="Comma separated classes", placeholder="Enter your classes separated by ','",)
87
+ title = "Zero Shot Image Classification"
88
+ description = "<p style= 'color:white'>Use clip models embedding to identify the closest class it belongs form its pretrianed data from the given list</p>"
89
+ examples=[['ex3.jpg', 'black bag', 0.96],['ex2.jpg', 'man in red dress', 0.85]]
90
+ gr.Interface(fn=zero_shot_classification,inputs=[i1,i2],outputs="label",title=title,description="Zero Shot Image classification..")
91
+ demo.launch(debug = False)