tomaarsen HF staff commited on
Commit
608ef2c
1 Parent(s): 5efeece

Fix demo rescoring; add approximate index support; add sliders

Browse files
Files changed (1) hide show
  1. app.py +58 -16
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import time
3
  import gradio as gr
4
  from datasets import load_dataset
@@ -14,6 +13,7 @@ title_text_dataset = load_dataset("mixedbread-ai/wikipedia-data-en-2023-11", spl
14
  # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
15
  int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
16
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_50m.index")
 
17
 
18
  # Load the SentenceTransformer model for embedding the queries
19
  model = SentenceTransformer(
@@ -25,7 +25,7 @@ model = SentenceTransformer(
25
  )
26
 
27
 
28
- def search(query, top_k: int = 10, rescore_multiplier: int = 4):
29
  # 1. Embed the query as float32
30
  start_time = time.time()
31
  query_embedding = model.encode(query)
@@ -36,9 +36,10 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
36
  query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
37
  quantize_time = time.time() - start_time
38
 
39
- # 3. Search the binary index
 
40
  start_time = time.time()
41
- _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rescore_multiplier)
42
  binary_ids = binary_ids[0]
43
  search_time = time.time() - start_time
44
 
@@ -54,11 +55,15 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
54
 
55
  # 6. Sort the scores and return the top_k
56
  start_time = time.time()
57
- indices = scores.argsort()[:top_k]
58
  top_k_indices = binary_ids[indices]
59
  top_k_scores = scores[indices]
60
- top_k_titles, top_k_texts = zip(*[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()])
61
- df = pd.DataFrame({"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts})
 
 
 
 
62
  sort_time = time.time() - start_time
63
 
64
  return df, {
@@ -68,14 +73,15 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):
68
  "Load Time": f"{load_time:.4f} s",
69
  "Rescore Time": f"{rescore_time:.4f} s",
70
  "Sort Time": f"{sort_time:.4f} s",
71
- "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s"
72
  }
73
 
 
74
  with gr.Blocks(title="Quantized Retrieval") as demo:
75
  gr.Markdown(
76
- """
77
  ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
78
- This demo showcases exact retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization). The corpus consists of 41 million texts from Wikipedia articles.
79
 
80
  <details><summary>Click to learn about the retrieval process</summary>
81
 
@@ -94,11 +100,47 @@ we need `1024 / 8 * num_docs` bytes for the binary index and `1024 * num_docs` b
94
  This is notably cheaper than doing the same process with float32 embeddings, which would require `4 * 1024 * num_docs` bytes of memory/disk space for the float32 index, i.e. 32x as much memory and 4x as much disk space.
95
  Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient. In conclusion, this process allows for fast, scalable, cheap, and memory-efficient retrieval.
96
 
97
- Feel free to check out the [code for this demo](https://huggingface.co/spaces/tomaarsen/quantized_retrieval/blob/main/app.py) to learn more about how to apply this in practice.
 
 
 
98
 
99
  </details>
100
- """)
101
- query = gr.Textbox(label="Query for Wikipedia articles", placeholder="Enter a query to search for relevant texts from Wikipedia.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  search_button = gr.Button(value="Search")
103
 
104
  with gr.Row():
@@ -107,8 +149,8 @@ Feel free to check out the [code for this demo](https://huggingface.co/spaces/to
107
  with gr.Column(scale=1):
108
  json = gr.JSON()
109
 
110
- query.submit(search, inputs=[query], outputs=[output, json])
111
- search_button.click(search, inputs=[query], outputs=[output, json])
112
 
113
  demo.queue()
114
- demo.launch(debug=True)
 
 
1
  import time
2
  import gradio as gr
3
  from datasets import load_dataset
 
13
  # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
14
  int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
15
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_50m.index")
16
+ binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_ubinary_ivf_faiss_50m.index")
17
 
18
  # Load the SentenceTransformer model for embedding the queries
19
  model = SentenceTransformer(
 
25
  )
26
 
27
 
28
+ def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: bool = False):
29
  # 1. Embed the query as float32
30
  start_time = time.time()
31
  query_embedding = model.encode(query)
 
36
  query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
37
  quantize_time = time.time() - start_time
38
 
39
+ # 3. Search the binary index (either exact or approximate)
40
+ index = binary_ivf if use_approx else binary_index
41
  start_time = time.time()
42
+ _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
43
  binary_ids = binary_ids[0]
44
  search_time = time.time() - start_time
45
 
 
55
 
56
  # 6. Sort the scores and return the top_k
57
  start_time = time.time()
58
+ indices = scores.argsort()[::-1][:top_k]
59
  top_k_indices = binary_ids[indices]
60
  top_k_scores = scores[indices]
61
+ top_k_titles, top_k_texts = zip(
62
+ *[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()]
63
+ )
64
+ df = pd.DataFrame(
65
+ {"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
66
+ )
67
  sort_time = time.time() - start_time
68
 
69
  return df, {
 
73
  "Load Time": f"{load_time:.4f} s",
74
  "Rescore Time": f"{rescore_time:.4f} s",
75
  "Sort Time": f"{sort_time:.4f} s",
76
+ "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
77
  }
78
 
79
+
80
  with gr.Blocks(title="Quantized Retrieval") as demo:
81
  gr.Markdown(
82
+ """
83
  ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
84
+ This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization). The corpus consists of 41 million texts from Wikipedia articles.
85
 
86
  <details><summary>Click to learn about the retrieval process</summary>
87
 
 
100
  This is notably cheaper than doing the same process with float32 embeddings, which would require `4 * 1024 * num_docs` bytes of memory/disk space for the float32 index, i.e. 32x as much memory and 4x as much disk space.
101
  Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient. In conclusion, this process allows for fast, scalable, cheap, and memory-efficient retrieval.
102
 
103
+ Feel free to check out the [code for this demo](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py) to learn more about how to apply this in practice.
104
+
105
+ Notes:
106
+ - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data. A better IVF index will be released soon.
107
 
108
  </details>
109
+ """
110
+ )
111
+ with gr.Row():
112
+ with gr.Column(scale=75):
113
+ query = gr.Textbox(
114
+ label="Query for Wikipedia articles",
115
+ placeholder="Enter a query to search for relevant texts from Wikipedia.",
116
+ )
117
+ with gr.Column(scale=25):
118
+ use_approx = gr.Radio(
119
+ choices=[("Exact Search", False), ("Approximate Search", True)],
120
+ value=True,
121
+ label="Search Index",
122
+ )
123
+
124
+ with gr.Row():
125
+ with gr.Column(scale=2):
126
+ top_k = gr.Slider(
127
+ minimum=10,
128
+ maximum=1000,
129
+ step=5,
130
+ value=100,
131
+ label="Number of documents to retrieve",
132
+ info="Number of documents to retrieve from the binary search",
133
+ )
134
+ with gr.Column(scale=2):
135
+ rescore_multiplier = gr.Slider(
136
+ minimum=1,
137
+ maximum=10,
138
+ step=1,
139
+ value=1,
140
+ label="Rescore multiplier",
141
+ info="Search for `rescore_multiplier` as many documents to rescore",
142
+ )
143
+
144
  search_button = gr.Button(value="Search")
145
 
146
  with gr.Row():
 
149
  with gr.Column(scale=1):
150
  json = gr.JSON()
151
 
152
+ query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
153
+ search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
154
 
155
  demo.queue()
156
+ demo.launch()