Fix demo rescoring; add approximate index support; add sliders
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import time
 import gradio as gr
 from datasets import load_dataset
@@ -14,6 +13,7 @@ title_text_dataset = load_dataset("mixedbread-ai/wikipedia-data-en-2023-11", spl
 # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
 int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
 binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_50m.index")
+binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_ubinary_ivf_faiss_50m.index")
 
 # Load the SentenceTransformer model for embedding the queries
 model = SentenceTransformer(
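The new `binary_ivf` index is loaded with the same `faiss.read_index_binary` call as the flat index. For readers who want to reproduce one: a minimal sketch of how such a binary IVF index can be built with FAISS. This is not the script that produced `wikipedia_ubinary_ivf_faiss_50m.index`; the corpus file, `nlist`, and `nprobe` values here are assumptions to be tuned.

```python
import faiss
import numpy as np

d = 1024      # embedding dimensionality in bits (1024-dim binary embeddings)
nlist = 4096  # number of IVF clusters; an assumption, tune for the corpus size

# Hypothetical file holding the corpus as packed binary embeddings:
# shape (num_docs, d // 8), dtype uint8
packed = np.load("wikipedia_ubinary_embeddings.npy")

quantizer = faiss.IndexBinaryFlat(d)              # coarse quantizer over Hamming space
index = faiss.IndexBinaryIVF(quantizer, d, nlist)
index.train(packed)                               # cluster a sample of the corpus
index.add(packed)
index.nprobe = 16  # clusters probed per query: higher = more accurate, slower
faiss.write_index_binary(index, "wikipedia_ubinary_ivf_faiss_50m.index")
```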
@@ -25,7 +25,7 @@ model = SentenceTransformer(
 )
 
 
-def search(query, top_k: int = 10, rescore_multiplier: int = 4):
+def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: bool = False):
     # 1. Embed the query as float32
     start_time = time.time()
     query_embedding = model.encode(query)
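The `search` signature gains a `use_approx` flag and changes its defaults from `top_k=10, rescore_multiplier=4` to `top_k=100, rescore_multiplier=1`. A hypothetical call with the new parameters (the query string is just an example):

```python
# Exact flat search: retrieve 100 documents, rescoring exactly 100 candidates
df, timings = search("How do solar panels work?", top_k=100, rescore_multiplier=1, use_approx=False)

# Approximate IVF search: faster, retrieves 3 * 50 candidates for rescoring
df, timings = search("How do solar panels work?", top_k=50, rescore_multiplier=3, use_approx=True)
```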
@@ -36,9 +36,10 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
     query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
     quantize_time = time.time() - start_time
 
-    # 3. Search the binary index
+    # 3. Search the binary index (either exact or approximate)
+    index = binary_ivf if use_approx else binary_index
     start_time = time.time()
-    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
+    _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
     binary_ids = binary_ids[0]
     search_time = time.time() - start_time
 
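For context on step 2 above: `quantize_embeddings(..., "ubinary")` thresholds each float dimension at zero and packs the bits, so one 1024-dimensional float32 query becomes 128 `uint8` values, which the binary index then ranks by Hamming distance. A sketch of the equivalent packing, using a random vector in place of a real query embedding:

```python
import numpy as np

query_embedding = np.random.randn(1, 1024).astype(np.float32)  # stand-in for model.encode(query)
query_embedding_ubinary = np.packbits(query_embedding > 0, axis=-1)  # shape (1, 128), dtype uint8

# Both faiss.IndexBinaryFlat and faiss.IndexBinaryIVF accept this packed format:
# _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
```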
@@ -54,11 +55,15 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
 
     # 6. Sort the scores and return the top_k
     start_time = time.time()
-    indices = scores.argsort()[:top_k]
+    indices = scores.argsort()[::-1][:top_k]
     top_k_indices = binary_ids[indices]
     top_k_scores = scores[indices]
-    top_k_titles, top_k_texts = zip(*[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()])
-    df = pd.DataFrame({"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts})
+    top_k_titles, top_k_texts = zip(
+        *[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()]
+    )
+    df = pd.DataFrame(
+        {"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
+    )
     sort_time = time.time() - start_time
 
     return df, {
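This hunk is the rescoring fix from the commit title: `numpy.argsort` sorts ascending, so slicing the first `top_k` positions selected the *lowest*-scoring documents after rescoring. Reversing the order first selects the highest-scoring ones:

```python
import numpy as np

scores = np.array([0.2, 0.9, 0.5])
print(scores.argsort()[:2])        # [0 2] -> the two WORST documents (old behavior)
print(scores.argsort()[::-1][:2])  # [1 2] -> the two BEST documents (fixed behavior)
```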
@@ -68,14 +73,15 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
         "Load Time": f"{load_time:.4f} s",
         "Rescore Time": f"{rescore_time:.4f} s",
         "Sort Time": f"{sort_time:.4f} s",
-        "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s"
+        "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
     }
 
+
 with gr.Blocks(title="Quantized Retrieval") as demo:
     gr.Markdown(
-    """
+        """
 ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
-This demo showcases
+This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization). The corpus consists of 41 million texts from Wikipedia articles.
 
 <details><summary>Click to learn about the retrieval process</summary>
 
@@ -94,11 +100,47 @@ we need `1024 / 8 * num_docs` bytes for the binary index and `1024 * num_docs` b
 This is notably cheaper than doing the same process with float32 embeddings, which would require `4 * 1024 * num_docs` bytes of memory/disk space for the float32 index, i.e. 32x as much memory and 4x as much disk space.
 Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient. In conclusion, this process allows for fast, scalable, cheap, and memory-efficient retrieval.
 
-Feel free to check out the [code for this demo](https://huggingface.co/spaces/to
+Feel free to check out the [code for this demo](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py) to learn more about how to apply this in practice.
+
+Notes:
+- The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data. A better IVF index will be released soon.
 
 </details>
-""")
-    query = gr.Textbox(label="Query for Wikipedia articles", placeholder="Enter a query to search for relevant texts from Wikipedia.")
+"""
+    )
+    with gr.Row():
+        with gr.Column(scale=75):
+            query = gr.Textbox(
+                label="Query for Wikipedia articles",
+                placeholder="Enter a query to search for relevant texts from Wikipedia.",
+            )
+        with gr.Column(scale=25):
+            use_approx = gr.Radio(
+                choices=[("Exact Search", False), ("Approximate Search", True)],
+                value=True,
+                label="Search Index",
+            )
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            top_k = gr.Slider(
+                minimum=10,
+                maximum=1000,
+                step=5,
+                value=100,
+                label="Number of documents to retrieve",
+                info="Number of documents to retrieve from the binary search",
+            )
+        with gr.Column(scale=2):
+            rescore_multiplier = gr.Slider(
+                minimum=1,
+                maximum=10,
+                step=1,
+                value=1,
+                label="Rescore multiplier",
+                info="Search for `rescore_multiplier` as many documents to rescore",
+            )
+
     search_button = gr.Button(value="Search")
 
     with gr.Row():
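The memory arithmetic quoted in the hunk above works out as follows for this demo's 41 million Wikipedia texts and 1024-dimensional embeddings:

```python
num_docs = 41_000_000
dim = 1024

binary_bytes = dim // 8 * num_docs  # 1 bit per dimension, held in memory
int8_bytes = dim * num_docs         # 1 byte per dimension, kept on disk (usearch view)
float32_bytes = 4 * dim * num_docs  # the float32 baseline for comparison

print(f"binary:  {binary_bytes / 1e9:.1f} GB")   # 5.2 GB  (32x less memory than float32)
print(f"int8:    {int8_bytes / 1e9:.1f} GB")     # 42.0 GB (4x less disk than float32)
print(f"float32: {float32_bytes / 1e9:.1f} GB")  # 167.9 GB
```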
@@ -107,8 +149,8 @@ Feel free to check out the [code for this demo](https://huggingface.co/spaces/to
         with gr.Column(scale=1):
             json = gr.JSON()
 
-    query.submit(search, inputs=[query], outputs=[output, json])
-    search_button.click(search, inputs=[query], outputs=[output, json])
+    query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
+    search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
 
 demo.queue()
-demo.launch(
+demo.launch()
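Steps 4 and 5 of `search` (loading the candidates' int8 embeddings and rescoring them) are untouched by this commit, which is why they do not appear in the diff. For completeness, a sketch of what that stage does, assuming the usearch view supports indexed reads as used here:

```python
# 4. Load the int8 embeddings of the binary-search candidates from the memory-mapped view
int8_embeddings = int8_view[binary_ids].astype(int)

# 5. Rescore: dot product between the float32 query embedding and the int8 document embeddings
scores = query_embedding @ int8_embeddings.T
```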