Spaces:
Runtime error
Runtime error
jamescalam
commited on
Commit
β’
2118bfe
1
Parent(s):
f58d975
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
|
3 |
+
import torch
|
4 |
+
import io
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
from cryptography.fernet import Fernet
|
8 |
+
from google.cloud import storage
|
9 |
+
import pinecone
|
10 |
+
import json
|
11 |
+
|
12 |
+
# decrypt Storage Cloud credentials
|
13 |
+
fernet = Fernet(os.environ['DECRYPTION_KEY'])
|
14 |
+
|
15 |
+
with open('cloud-storage.encrypted', 'rb') as fp:
|
16 |
+
encrypted = fp.read()
|
17 |
+
creds = json.loads(fernet.decrypt(encrypted).decode())
|
18 |
+
|
19 |
+
# then save creds to file
|
20 |
+
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
|
21 |
+
fp.write(json.dumps(creds, indent=4))
|
22 |
+
|
23 |
+
# connect to Cloud Storage
|
24 |
+
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
|
25 |
+
storage_client = storage.Client()
|
26 |
+
bucket = storage_client.get_bucket('diffusion-search')
|
27 |
+
|
28 |
+
# get api key for pinecone auth
|
29 |
+
PINECONE_KEY = os.environ['PINECONE_KEY']
|
30 |
+
|
31 |
+
index_id = "diffusion-search"
|
32 |
+
|
33 |
+
# init connection to pinecone
|
34 |
+
pinecone.init(
|
35 |
+
api_key=PINECONE_KEY,
|
36 |
+
environment="us-west1-gcp"
|
37 |
+
)
|
38 |
+
if index_id not in pinecone.list_indexes():
|
39 |
+
raise ValueError(f"Index '{index_id}' not found")
|
40 |
+
|
41 |
+
index = pinecone.Index(index_id)
|
42 |
+
|
43 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
44 |
+
print(f"Using '{device}' device...")
|
45 |
+
|
46 |
+
# init all of the models and move them to a given GPU
|
47 |
+
|
48 |
+
# if you have CUDA or MPS, set it to the active device like this
|
49 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
50 |
+
model_id = "openai/clip-vit-base-patch32"
|
51 |
+
|
52 |
+
# we initialize a tokenizer, image processor, and the model itself
|
53 |
+
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
|
54 |
+
model = CLIPModel.from_pretrained(model_id).to(device)
|
55 |
+
|
56 |
+
missing_im = Image.open('missing.png')
|
57 |
+
threshold = 0.85
|
58 |
+
|
59 |
+
def encode_text(text: str):
|
60 |
+
# create transformer-readable tokens
|
61 |
+
inputs = tokenizer(text, return_tensors="pt").to(device)
|
62 |
+
text_emb = model.get_text_features(**inputs).cpu().detach().tolist()
|
63 |
+
return text_emb
|
64 |
+
|
65 |
+
def prompt_query(text: str):
|
66 |
+
print(f"Running prompt_query('{text}')")
|
67 |
+
embeds = encode_text(text)
|
68 |
+
try:
|
69 |
+
print("Try query pinecone")
|
70 |
+
xc = index.query(embeds, top_k=30, include_metadata=True)
|
71 |
+
print("query successful")
|
72 |
+
except Exception as e:
|
73 |
+
print(f"Error during query: {e}")
|
74 |
+
# reinitialize connection
|
75 |
+
print("Try reinitialize Pinecone connection")
|
76 |
+
pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
|
77 |
+
index2 = pinecone.Index(index_id)
|
78 |
+
try:
|
79 |
+
print("Now try querying pinecone again")
|
80 |
+
xc = index2.query(embeds, top_k=30, include_metadata=True)
|
81 |
+
print("query successful")
|
82 |
+
except Exception as e:
|
83 |
+
raise ValueError(e)
|
84 |
+
scores = [round(match['score'], 2) for match in xc['matches']]
|
85 |
+
ids = [match['id'] for match in xc['matches']]
|
86 |
+
return ids
|
87 |
+
|
88 |
+
def get_image(url: str):
|
89 |
+
blob = bucket.blob(url).download_as_string()
|
90 |
+
blob_bytes = io.BytesIO(blob)
|
91 |
+
im = Image.open(blob_bytes)
|
92 |
+
return im
|
93 |
+
|
94 |
+
def test_image(_id, image):
|
95 |
+
try:
|
96 |
+
image.save('tmp.png')
|
97 |
+
return True
|
98 |
+
except OSError:
|
99 |
+
# delete corrupted file from pinecone and cloud
|
100 |
+
index.delete(ids=[_id])
|
101 |
+
bucket.blob(f"images/{_id}.png").delete()
|
102 |
+
print(f"DELETED '{_id}'")
|
103 |
+
return False
|
104 |
+
|
105 |
+
def prompt_image(text: str):
|
106 |
+
print(f"prompt_image('{text}')")
|
107 |
+
embeds = encode_text(text)
|
108 |
+
try:
|
109 |
+
print("try query pinecone")
|
110 |
+
xc = index.query(
|
111 |
+
embeds, top_k=9, include_metadata=True,
|
112 |
+
filter={"image_nsfw": {"$lt": 0.5}}
|
113 |
+
)
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Error during query: {e}")
|
116 |
+
# reinitialize connection
|
117 |
+
pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
|
118 |
+
index2 = pinecone.Index(index_id)
|
119 |
+
try:
|
120 |
+
print("try query pinecone after reinit")
|
121 |
+
xc = index2.query(
|
122 |
+
embeds, top_k=9, include_metadata=True,
|
123 |
+
filter={"image_nsfw": {"$lt": 0.5}}
|
124 |
+
)
|
125 |
+
except Exception as e:
|
126 |
+
raise ValueError(e)
|
127 |
+
scores = [match['score'] for match in xc['matches']]
|
128 |
+
ids = [match['id'] for match in xc['matches']]
|
129 |
+
images = []
|
130 |
+
print("Begin looping through (ids, image_urls)")
|
131 |
+
for _id in ids:
|
132 |
+
try:
|
133 |
+
image_url = f"images/{_id}.png"
|
134 |
+
print("download_as_string from GCP")
|
135 |
+
blob = bucket.blob(image_url).download_as_string()
|
136 |
+
print("downloaded successfully")
|
137 |
+
blob_bytes = io.BytesIO(blob)
|
138 |
+
im = Image.open(blob_bytes)
|
139 |
+
print("image opened successfully")
|
140 |
+
if test_image(_id, im):
|
141 |
+
images.append(im)
|
142 |
+
print("image accessible")
|
143 |
+
else:
|
144 |
+
images.append(missing_im)
|
145 |
+
print("image NOT accessible")
|
146 |
+
except ValueError:
|
147 |
+
print(f"ValueError: '{image_url}'")
|
148 |
+
return images, scores
|
149 |
+
|
150 |
+
# __APP FUNCTIONS__
|
151 |
+
|
152 |
+
def set_suggestion(text: str):
|
153 |
+
return gr.TextArea.update(value=text[0])
|
154 |
+
|
155 |
+
def set_images(text: str):
|
156 |
+
images, scores = prompt_image(text)
|
157 |
+
return gr.Gallery.update(value=images)
|
158 |
+
|
159 |
+
# __CREATE APP__
|
160 |
+
demo = gr.Blocks()
|
161 |
+
|
162 |
+
with demo:
|
163 |
+
gr.HTML(
|
164 |
+
"""
|
165 |
+
<img src="https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png" />
|
166 |
+
<style>
|
167 |
+
.parallax {
|
168 |
+
/* The image used */
|
169 |
+
background-image: url("https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png");
|
170 |
+
/* Create the parallax scrolling effect */
|
171 |
+
background-attachment: fixed;
|
172 |
+
background-position: center;
|
173 |
+
background-repeat: no-repeat;
|
174 |
+
background-size: cover;
|
175 |
+
}
|
176 |
+
</style>
|
177 |
+
|
178 |
+
<!-- Container element -->
|
179 |
+
<div class="parallax"></div>
|
180 |
+
"""
|
181 |
+
)
|
182 |
+
with gr.Row():
|
183 |
+
with gr.Column():
|
184 |
+
prompt = gr.TextArea(
|
185 |
+
value="space dogs",
|
186 |
+
placeholder="Something cool to search for...",
|
187 |
+
interactive=True
|
188 |
+
)
|
189 |
+
search = gr.Button(value="Search!")
|
190 |
+
gr.Markdown(
|
191 |
+
"""
|
192 |
+
#### Search through 10K images generated by AI
|
193 |
+
|
194 |
+
This app demonstrates the idea of text-to-image search. The search process
|
195 |
+
uses an AI model that understands the *meaning* of text and images to identify
|
196 |
+
images that best align to a search prompt.
|
197 |
+
|
198 |
+
πͺ [*Built with the OP Stack*](https://gkogan.notion.site/gkogan/The-OP-Stack-aafcab0005e3445a8ad8491aac80446c)
|
199 |
+
"""
|
200 |
+
)
|
201 |
+
|
202 |
+
# results column
|
203 |
+
with gr.Column():
|
204 |
+
pics = gr.Gallery()
|
205 |
+
pics.style(grid=3)
|
206 |
+
# search event listening
|
207 |
+
try:
|
208 |
+
search.click(set_images, prompt, pics)
|
209 |
+
except OSError:
|
210 |
+
print("OSError")
|
211 |
+
|
212 |
+
demo.launch()
|