sayanbanerjee32 committed
Commit 3b2349b
1 Parent(s): 8337d2a
Upload folder using huggingface_hub
Browse files
- app.py +83 -0
- requirements.txt +4 -0
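
The commit message indicates the files were pushed with huggingface_hub. For reference, a minimal sketch of such a folder upload; the repo_id below is an assumed placeholder, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login`
api.upload_folder(
    folder_path=".",                         # local folder with app.py and requirements.txt
    repo_id="sayanbanerjee32/image-search",  # assumed Space id, for illustration only
    repo_type="space",
)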
app.py ADDED
@@ -0,0 +1,83 @@
import os
import numpy as np
import torch

import skimage
from PIL import Image

import open_clip

import gradio as gr

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')

target_labels = ["page", "chelsea", "astronaut", "rocket",
                 "motorcycle_right", "camera", "horse", "coffee",
                 "logo"]

original_images = []
images = []
file_names = []

# Collect the scikit-image sample images whose names match the target labels.
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in target_labels:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    original_images.append(image)
    images.append(preprocess(image))
    file_names.append(filename)

# Encode the gallery once up front; each search then only needs a text encoding.
image_input = torch.tensor(np.stack(images))
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image_input).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)


def identify_image(input_description):
    if input_description is None:
        return None
    text_tokens = tokenizer([input_description])
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_probs = 100.0 * image_features @ text_features.T
    top_probs, _ = text_probs.cpu().topk(1, dim=-1)
    return original_images[top_probs.argmax().item()]


with gr.Blocks() as demo:
    gr.HTML("<h1 align='center'> Image Search </h1>")
    gr.HTML("<h4 align='center'> Identify the most suitable image for the description provided.</h4>")

    gr.Gallery(value=original_images,
               label="Images to search from", show_label=True, elem_id="gallery",
               columns=[3], rows=[3], object_fit="contain", height="auto")

    content = gr.Textbox(label="Enter search text here")
    inputs = [
        content,
    ]
    gr.Examples(["Page of text about segmentation",
                 "Facial photo of a tabby cat",
                 "Portrait of an astronaut with the American flag",
                 "Rocket standing on a launchpad",
                 "Red motorcycle standing in a garage",
                 "Person looking at a camera on a tripod",
                 "Black-and-white silhouette of a horse",
                 "Cup of coffee on a saucer",
                 "A snake in the background"],
                inputs=inputs)

    generate_btn = gr.Button(value='Identify')
    outputs = [gr.Image(label="Is this the image you are referring to?",
                        height=512, width=512)]
    generate_btn.click(fn=identify_image, inputs=inputs, outputs=outputs)

## for Colab
# demo.launch(debug=True)

if __name__ == '__main__':
    demo.launch()
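
A note on why the retrieval above works: both feature matrices are L2-normalized before the matrix product, so text_probs is a scaled cosine-similarity matrix and argmax returns the index of the closest gallery image. A self-contained sketch with stand-in tensors:

import torch

# Stand-in features: 9 gallery images and one text query, 512-dim as with ViT-B-32.
image_features = torch.randn(9, 512)
text_features = torch.randn(1, 512)

# After L2 normalization, dot products equal cosine similarities.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

scores = 100.0 * image_features @ text_features.T  # shape (9, 1)
print("best image index:", scores.argmax().item())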
requirements.txt ADDED
@@ -0,0 +1,4 @@
scikit-image
open_clip_torch
pillow
torch
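
Once the Space is up, the Identify handler can also be called programmatically. A minimal sketch with gradio_client; the Space id and fn_index=0 are assumptions, since the click event in app.py sets no explicit api_name:

from gradio_client import Client

client = Client("sayanbanerjee32/image-search")  # assumed Space id
# identify_image is the only click handler, so fn_index=0 is assumed here.
result = client.predict("Cup of coffee on a saucer", fn_index=0)
print(result)  # local path of the image file returned by the Space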