update app.py

Files changed:
- .gitignore        +4 -0
- app.py            +15 -16
- requirements.txt  +1 -0
.gitignore (ADDED)

```diff
@@ -0,0 +1,4 @@
+__pycache__
+*.swp
+hf_models/
+pretrained_models/
```
app.py (CHANGED)

```diff
@@ -6,11 +6,12 @@ CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
 import os
 import time
 from argparse import ArgumentParser
+import json
 
 import numpy as np
 import torch
 import gradio as gr
-
+import faiss
 
 from encode_with_pseudo_tokens import encode_with_pseudo_tokens_HF
 from models import build_text_encoder, Phi, PIC2WORD
@@ -19,6 +20,7 @@ import transformers
 from huggingface_hub import hf_hub_url, cached_download
 
 
+
 def parse_args():
     parser = ArgumentParser()
     parser.add_argument("--lincir_ckpt_path", default=None, type=str,
@@ -100,6 +102,7 @@ def load_models(args):
     }
 
 
+@torch.no_grad()
 def predict(images, input_text, model_name):
     start_time = time.time()
     input_images = model_dict['clip_preprocess'](images, return_tensors='pt')['pixel_values'].to(model_dict['device'])
@@ -125,18 +128,15 @@
     clip_text_time = time.time() - start_time
 
     start_time = time.time()
-    try:
-        results = client.query(embedding_input=text_embeddings[0].tolist())
-        output = ''
-    except:
-        results = []
-        output = 'The server for image retrieval is not working. Please try again later.'
-    retrieval_time = time.time() - start_time
 
+    _, results = faiss_index.search(text_embeddings.cpu().numpy(), k=10)
+
+    retrieval_time = time.time() - start_time
 
+    output = ''
 
-    for idx,
-        image_url =
+    for idx, retrieved_idx in enumerate(results[0]):
+        image_url = image_urls[retrieved_idx]
         output += f'![image]({image_url})\n'
 
     time_output = {'CLIP visual extractor': clip_image_time,
@@ -180,7 +180,7 @@ def test_fps(batch_size=1):
 if __name__ == '__main__':
     args = parse_args()
 
-    global model_dict,
+    global model_dict, faiss_index, image_urls
 
     model_dict = load_models(args)
 
@@ -189,19 +189,18 @@ if __name__ == '__main__':
         test_fps(1)
         exit()
 
+    faiss_index = faiss.read_index('./clip_large.index', faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
 
-
-        indice_name="laion5B-H-14" if args.clip_model_name == "huge" else "laion5B-L-14",
-    )
+    image_urls = json.load(open('./image_urls.json'))
 
-    title = 'Zeroshot CIR demo'
+    title = 'Zeroshot CIR demo to search high-quality AI images'
 
     md_title = f'''# {title}
 [LinCIR](https://arxiv.org/abs/2312.01998): Language-only Training of Zero-shot Composed Image Retrieval
 [SEARLE](https://arxiv.org/abs/2303.15247): Zero-shot Composed Image Retrieval with Textual Inversion
 [Pic2Word](https://arxiv.org/abs/2302.03084): Mapping Pictures to Words for Zero-shot Composed Image Retrieval
 
-K-NN index for the retrieval results are entirely trained using the
+K-NN index for the retrieval results are entirely trained using [the upscaled midjourney v5 images (444,901)](https://huggingface.co/datasets/wanng/midjourney-v5-202304-clean).
     '''
 
     with gr.Blocks(title=title) as demo:
```
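For readers unfamiliar with FAISS, here is a minimal, self-contained sketch of the lookup pattern that replaces the clip-retrieval client in predict(): index rows double as positions into image_urls, so a search result maps straight to a URL. The tiny random corpus, the 768-dimensional features, and the example URLs are placeholders, not data from this commit; only the search-and-lookup shape mirrors the diff.

```python
# Minimal sketch of the new retrieval path in predict(), run in memory.
import faiss
import numpy as np

dim = 768                                   # assumed CLIP-L/14 feature size
corpus = np.random.rand(100, dim).astype('float32')
faiss.normalize_L2(corpus)                  # cosine similarity via inner product

index = faiss.IndexFlatIP(dim)
index.add(corpus)                           # row i of the index <-> image_urls[i]
image_urls = [f'https://example.com/{i}.jpg' for i in range(len(corpus))]

# Stand-in for the CLIP text embedding computed in predict().
query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)

# As in app.py: top-10 rows; search returns (scores, row indices).
_, results = index.search(query, 10)

output = ''
for idx, retrieved_idx in enumerate(results[0]):
    output += f'![image]({image_urls[retrieved_idx]})\n'
print(output)
```

In the app itself, the index is not built in memory but read from ./clip_large.index with memory-mapped, read-only flags, and the URL list comes from ./image_urls.json, as shown in the __main__ hunk above.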
requirements.txt (CHANGED)

```diff
@@ -6,3 +6,4 @@ accelerate
 datasets
 spacy
 git+https://github.com/rom1504/clip-retrieval
+faiss
```
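The added faiss dependency is only used by app.py to read a prebuilt index from disk. As a hedged illustration (the commit ships no build script, and the index type, normalisation, and dummy features below are assumptions), this is how files like ./clip_large.index and ./image_urls.json could be produced and re-opened. Note that prebuilt FAISS wheels on PyPI are typically published as faiss-cpu or faiss-gpu.

```python
# Hypothetical offline step producing the two files app.py expects.
# Nothing here is part of the commit: IndexFlatIP, L2-normalisation, and
# the random features/URLs are illustrative assumptions.
import json

import faiss
import numpy as np

features = np.random.rand(1000, 768).astype('float32')   # placeholder CLIP image features
urls = [f'https://example.com/{i}.png' for i in range(len(features))]

faiss.normalize_L2(features)                # inner product == cosine after normalisation
index = faiss.IndexFlatIP(features.shape[1])
index.add(features)                         # row i corresponds to urls[i]

faiss.write_index(index, './clip_large.index')
with open('./image_urls.json', 'w') as f:
    json.dump(urls, f)

# Re-open the index; app.py additionally passes
# faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY to memory-map it read-only.
reopened = faiss.read_index('./clip_large.index')
print(reopened.ntotal)                      # 1000
```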