jamescalam committed on
Commit
2118bfe
•
1 Parent(s): f58d975

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
3
+ import torch
4
+ import io
5
+ from PIL import Image
6
+ import os
7
+ from cryptography.fernet import Fernet
8
+ from google.cloud import storage
9
+ import pinecone
10
+ import json
11
+
12
# --- Google Cloud Storage credentials --------------------------------------
# The service-account JSON is stored encrypted next to the app; the Fernet key
# used to decrypt it comes from the DECRYPTION_KEY environment variable.
fernet = Fernet(os.environ['DECRYPTION_KEY'])

with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# Persist the decrypted credentials: GOOGLE_APPLICATION_CREDENTIALS is set to
# this path right before storage.Client() is created below.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))

# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('diffusion-search')

# --- Pinecone ----------------------------------------------------------------
# get api key for pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']

index_id = "diffusion-search"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
# Fail fast at start-up when the expected index does not exist.
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

index = pinecone.Index(index_id)

# --- CLIP model --------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using '{device}' device...")

model_id = "openai/clip-vit-base-patch32"

# Only the tokenizer and the model are created here; the model is moved to
# `device` so it matches the tensors built in encode_text().
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id).to(device)

# Placeholder picture shown in place of a stored image that cannot be used.
missing_im = Image.open('missing.png')
# NOTE(review): `threshold` is not referenced anywhere in the visible code.
threshold = 0.85
58
+
59
def encode_text(text: str):
    """Embed a text prompt with the CLIP text encoder.

    Args:
        text: The prompt to embed.

    Returns:
        The output of ``model.get_text_features`` moved to the CPU and
        converted to (nested) Python lists via ``Tensor.tolist()``.
    """
    # The prompt comes from a free-form text box. CLIP's text encoder has a
    # fixed context length, so over-long prompts are truncated by the
    # tokenizer instead of failing inside the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    # Inference only: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    return features.cpu().tolist()
64
+
65
def prompt_query(text: str):
    """Return the ids of the 30 index entries closest to a text prompt.

    The prompt is embedded with ``encode_text`` and sent to the module-level
    Pinecone ``index``. If that query raises, the Pinecone connection is
    re-initialised and the query is attempted once more on a fresh handle.

    Args:
        text: The search prompt.

    Returns:
        A list with the ``'id'`` of every match, in the order returned by
        Pinecone.

    Raises:
        ValueError: If the second query attempt fails as well; the original
            error is attached as ``__cause__``.
    """
    print(f"Running prompt_query('{text}')")
    embeds = encode_text(text)
    try:
        print("Try query pinecone")
        xc = index.query(embeds, top_k=30, include_metadata=True)
        print("query successful")
    except Exception as e:
        print(f"Error during query: {e}")
        # reinitialize connection
        print("Try reinitialize Pinecone connection")
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            print("Now try querying pinecone again")
            xc = index2.query(embeds, top_k=30, include_metadata=True)
            print("query successful")
        except Exception as retry_error:
            # Keep the ValueError type, but chain the underlying error so
            # its traceback is not lost.
            raise ValueError(retry_error) from retry_error
    return [match['id'] for match in xc['matches']]
87
+
88
def get_image(url: str):
    """Download one object from the Cloud Storage bucket and open it.

    Args:
        url: Name handed to ``bucket.blob`` to locate the object.

    Returns:
        The image object produced by ``Image.open`` on the downloaded bytes.
    """
    raw = bucket.blob(url).download_as_string()
    # Image.open needs a file-like object, so wrap the bytes in memory.
    return Image.open(io.BytesIO(raw))
93
+
94
def test_image(_id, image):
    """Check that an image can be written out, purging it when it cannot.

    The image is saved to ``tmp.png``. If saving raises ``OSError`` the entry
    is treated as corrupted: its vector is deleted from the Pinecone index and
    its ``images/<id>.png`` blob is deleted from the bucket.

    Args:
        _id: Identifier of the image in the index and in the bucket.
        image: Object exposing ``save(path)``, such as an opened PIL image.

    Returns:
        ``True`` when the save succeeded, ``False`` when the entry was purged.
    """
    try:
        image.save('tmp.png')
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        corrupted_blob = bucket.blob(f"images/{_id}.png")
        corrupted_blob.delete()
        print(f"DELETED '{_id}'")
        return False
    else:
        return True
104
+
105
def prompt_image(text: str):
    """Find the stored images that best match a text prompt.

    The prompt is embedded with ``encode_text`` and used to query Pinecone for
    the top 9 matches whose ``image_nsfw`` metadata is below 0.5. If the query
    raises, the connection is re-initialised and the query is retried once.
    Each match is then downloaded from the bucket as ``images/<id>.png``.

    Args:
        text: The search prompt.

    Returns:
        A tuple ``(images, scores)``. Both lists follow the order of the
        Pinecone matches and have the same length: an image that cannot be
        downloaded, opened or saved is replaced by ``missing_im`` rather than
        dropped, so position ``i`` in one list matches position ``i`` in the
        other.

    Raises:
        ValueError: If the second query attempt fails as well; the original
            error is attached as ``__cause__``.
    """
    print(f"prompt_image('{text}')")
    embeds = encode_text(text)
    try:
        print("try query pinecone")
        xc = index.query(
            embeds, top_k=9, include_metadata=True,
            filter={"image_nsfw": {"$lt": 0.5}}
        )
    except Exception as e:
        print(f"Error during query: {e}")
        # reinitialize connection
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            print("try query pinecone after reinit")
            xc = index2.query(
                embeds, top_k=9, include_metadata=True,
                filter={"image_nsfw": {"$lt": 0.5}}
            )
        except Exception as retry_error:
            # Keep the ValueError type, but chain the underlying error so
            # its traceback is not lost.
            raise ValueError(retry_error) from retry_error
    scores = [match['score'] for match in xc['matches']]
    ids = [match['id'] for match in xc['matches']]
    images = []
    print("Begin looping through (ids, image_urls)")
    for _id in ids:
        image_url = f"images/{_id}.png"
        try:
            print("download_as_string from GCP")
            blob = bucket.blob(image_url).download_as_string()
            print("downloaded successfully")
            blob_bytes = io.BytesIO(blob)
            im = Image.open(blob_bytes)
            print("image opened successfully")
            usable = test_image(_id, im)
        except (ValueError, OSError) as error:
            # Image.open reports unreadable data with an OSError subclass.
            # One bad file must not abort the whole search or shift the
            # remaining images relative to their scores.
            print(f"{type(error).__name__}: '{image_url}'")
            images.append(missing_im)
            continue
        if usable:
            images.append(im)
            print("image accessible")
        else:
            images.append(missing_im)
            print("image NOT accessible")
    return images, scores
149
+
150
+ # __APP FUNCTIONS__
151
+
152
def set_suggestion(text: str):
    """Build a TextArea update whose value is the first element of ``text``.

    NOTE(review): with the annotated ``str`` type, ``text[0]`` is the first
    character only; the indexing suggests a sequence of suggestions may have
    been intended. Confirm against the caller (none is visible in this file).
    """
    first_item = text[0]
    return gr.TextArea.update(value=first_item)
154
+
155
def set_images(text: str):
    """Run an image search for ``text`` and return a Gallery update.

    Only the images from ``prompt_image`` are shown; the scores it returns
    alongside them are not used here.
    """
    gallery_items, _ = prompt_image(text)
    return gr.Gallery.update(value=gallery_items)
158
+
159
# __CREATE APP__
demo = gr.Blocks()

with demo:
    # Banner image plus the CSS for a `.parallax` container.
    gr.HTML(
        """
        <img src="https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png" />
        <style>
        .parallax {
        /* The image used */
        background-image: url("https://huggingface.co/spaces/pinecone/diffusion-image-search/resolve/main/pine-trees-collage.png");
        /* Create the parallax scrolling effect */
        background-attachment: fixed;
        background-position: center;
        background-repeat: no-repeat;
        background-size: cover;
        }
        </style>

        <!-- Container element -->
        <div class="parallax"></div>
        """
    )
    with gr.Row():
        # query column
        with gr.Column():
            prompt = gr.TextArea(
                value="space dogs",
                placeholder="Something cool to search for...",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # The Markdown text is kept flush-left so that it is never
            # interpreted as an indented code block.
            gr.Markdown(
                """
#### Search through 10K images generated by AI

This app demonstrates the idea of text-to-image search. The search process
uses an AI model that understands the *meaning* of text and images to identify
images that best align to a search prompt.

🪄 [*Built with the OP Stack*](https://gkogan.notion.site/gkogan/The-OP-Stack-aafcab0005e3445a8ad8491aac80446c)
"""
            )

        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # search event listening
    # Registering the listener does not run set_images, so wrapping this call
    # in try/except cannot catch errors raised later by the callback itself.
    search.click(set_images, prompt, pics)

demo.launch()