Parsa Kzr committed on
Commit
8533d6f
1 Parent(s): 2c1dc89

feat: application file

Files changed (2)
  1. app.py +627 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,627 @@
from __future__ import annotations

import math
import os
import pickle
import re
from abc import abstractmethod
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Type, TypedDict, TypeVar

import gradio as gr
import nltk
import numpy as np
import tqdm
from scipy.sparse import csc_matrix

from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever


# ----------------- PRE SETUP ----------------- #
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords  # imported after downloading the corpus

LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


# BM25 hyperparameters used for the CSC index below
best_k1 = 0.8
best_b = 0.6

index_dir = "output/csc_bm25_index"


# ----------------- SETUP CLASSES AND FUNCTIONS ----------------- #
def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized
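
# For instance, simple_tokenize("What is the powerhouse of the cell?") yields
# ["powerhouse", "cell"]: lowercased, split on word characters, stopwords
# removed (illustrative; the exact output depends on NLTK's stopword list).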


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th posting
    tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th posting
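
# For example, if "cell" appears 3 times in doc 0 and once in doc 5, its raw
# posting list is PostingList(term="cell", docid_postings=[0, 5],
# tweight_postings=[3, 1]); BM25Index.cache_term_weights later overwrites the
# raw TFs with BM25 weights in-place.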


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index

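# A save/load round-trip sketch (hypothetical path; any writable directory works):
#   index.save("output/my_index")                         # writes output/my_index/index.pkl
#   index = InvertedIndex.from_saved("output/my_index")   # unpickles the full index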

# The output of the counting function:
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)  # first document containing this term
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )


# sciq = load_sciq()
# counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute the BM25 term weights and cache them in the posting lists in-place."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))
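
    # Together, these implement the standard BM25 term weight:
    #   w(t, d) = idf(t) * tf(t, d) / (tf(t, d) + k1 * (1 - b + b * dl(d) / avgdl))
    # For example, with tf=2, dl=avgdl, k1=0.9 the regularized tf is
    # 2 / (2 + 0.9) ≈ 0.69, which is then multiplied by the smoothed idf
    # (numbers illustrative).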

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute and cache the term weights:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble the index (saving is left to the caller):
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index
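

# Usage sketch (assumes an index saved under output/bm25_index; output illustrative):
#   retriever = BM25Retriever(index_dir="output/bm25_index")
#   retriever.retrieve("What is the powerhouse of the cell?")
#   -> {"<collection_id>": <BM25 score>, ...}  # top-10 collection ids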


@dataclass
class CSCInvertedIndex:
    posting_lists_matrix: csc_matrix  # (ndocs x vocab_size) matrix of term weights
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls, saved_dir: str) -> "CSCInvertedIndex":
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


@dataclass
class CSCBM25Index(CSCInvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        ## YOUR_CODE_STARTS_HERE
        data: List[np.float32] = []
        row_indices = []
        col_indices = []

        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                weight = regularized_tf * idf

                # Store the triplets for sparse-matrix construction
                row_indices.append(docid)
                col_indices.append(tid)
                data.append(np.float32(weight))

        # Build a CSC matrix (docs x terms) from the collected triplets
        term_weights_matrix = csc_matrix(
            (data, (row_indices, col_indices)), shape=(N, len(posting_lists))
        )

        return term_weights_matrix
        ## YOUR_CODE_ENDS_HERE
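
    # Why CSC: `retrieve` walks one term (= one column) at a time, and a
    # csc_matrix stores each column contiguously, making `matrix[:, tid]` cheap.
    # A minimal standalone sketch of the same triplet construction (output
    # formatting approximate):
    #   >>> from scipy.sparse import csc_matrix
    #   >>> csc_matrix(([0.5, 1.2], ([0, 2], [1, 1])), shape=(3, 4)).toarray()
    #   array([[0. , 0.5, 0. , 0. ],
    #          [0. , 0. , 0. , 0. ],
    #          [0. , 1.2, 0. , 0. ]])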

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute and cache the term weights:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble the index (saving is left to the caller):
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


class BaseCSCInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Retrieve term weights for a specific query and document."""
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}

        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            # Look up the weight of term tid in the target document
            weight = self.index.posting_lists_matrix[target_docid, tid]
            if weight != 0:
                term_weights[tok] = weight

        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}

        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            # Take the matrix column for tid; COO format exposes the nonzero
            # (row, value) pairs directly
            term_weights = self.index.posting_lists_matrix[:, tid].tocoo()
            for docid, tweight in zip(term_weights.row, term_weights.data):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight

        # Sort and keep the top-k results
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )

        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


# ----------------- SETUP MAIN ----------------- #

sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))  # (the builders below re-count internally)

bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")

csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]


# index_dir = "output/csc_bm25_index"


def search(query: str, index_dir: str = index_dir) -> List[Hit]:  # , topk: int = 10
    """Base search functionality for the retrieval."""
    if "csc" in index_dir.lower():
        retriever: BaseRetriever = CSCBM25Retriever(index_dir)
    else:
        retriever = BM25Retriever(index_dir)

    # Retrieve the documents
    ranking = retriever.retrieve(query)  # , topk
    # Look up the document text for a collection id
    text = lambda cid: (
        retriever.index.doc_texts[retriever.index.cid2docid[cid]]
        if retriever.index.doc_texts
        else None
    )

    hits = [Hit(cid=cid, score=score, text=text(cid)) for cid, score in ranking.items()]

    return hits
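
# For example (output illustrative), search("what is the powerhouse of the cell")
# returns up to 10 hits shaped like:
#   {"cid": "<collection_id>", "score": <BM25 score>, "text": "<document text>"}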


'''
# Function for formatted display of results
def format_hits_md(hits: List[Hit]) -> str:
    if not hits:
        return "No results found."
    formatted = []
    for idx, hit in enumerate(hits, start=1):
        formatted.append(
            f"## Result {idx}:\n"
            f"* CID: {hit['cid']}\n"
            f"* Score: {hit['score']:.2f}\n"
            f"* Text: {hit['text'] or 'No text available.'}\n"
        )
    return "\n".join(formatted)


# To return pure JSON data as a list of JSON objects
def format_hits_json(hits: List[Hit]) -> List[dict]:
    if not hits:
        return []
    formatted = []
    for hit in hits:
        formatted.append(
            {
                "cid": hit["cid"],
                "score": hit["score"],
                "text": hit["text"] or "",
            }
        )
    return formatted


# Gradio wrapper
def interface_search(query: str) -> str:  # , topk: int = 10
    """Wrapper for the Gradio interface: call search() and format the results."""
    try:
        hits = search(query)  # , topk=topk
        return format_hits_json(hits)  # [json, md]
    except Exception as e:
        return f"Error: {str(e)}"
'''

# App interface
demo = gr.Interface(
    fn=search,  # use interface_search above to format the output as Markdown or JSON instead
    inputs=[
        gr.Textbox(label="Search Query", placeholder="Type your search query"),
        # gr.Number(label="Number of Results (Top-k)", value=10),
    ],
    outputs=gr.Textbox(
        label="Search Results"
    ),  # gr.Markdown() or gr.JSON() would format better (the API testing block would need matching changes)
    title="BM25 Retrieval on allenai/sciq",
    description="Search through the allenai/sciq corpus using a BM25-based retrieval system.",
)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
+ nlp4web-codebase @ git+https://github.com/kwang2049/nlp4web-codebase.git