# Source: NetsPresso_QA/tests/test_search_lucene8.py
# Uploaded by geonmin-kim using huggingface_hub (commit d6585f5).
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tarfile
import unittest
from random import randint
from typing import List, Dict
from urllib.request import urlretrieve
from pyserini.search.lucene import LuceneSearcher, JLuceneSearcherResult
class TestSearch(unittest.TestCase):
    """Regression tests for ``LuceneSearcher`` on the pre-built CACM index.

    ``setUpClass`` downloads and unpacks the index once and opens a searcher
    that is shared by the tests; ``tearDownClass`` closes it and deletes the
    downloaded files. Each test compares docids and scores (to four decimal
    places) against fixed expected values.
    """

    @classmethod
    def setUpClass(cls):
        """Download and unpack the CACM index, then open the shared searcher.

        unittest does not call ``tearDownClass`` when ``setUpClass`` raises,
        so any partially created files are removed here before re-raising.
        """
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        cls.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        cls.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        cls.index_dir = 'index{}/'.format(r)

        try:
            urlretrieve(cls.collection_url, cls.tarball_name)
            # The context manager closes the archive even if extraction fails.
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive; consider an extraction filter if the Python versions in
            # use support one.
            with tarfile.open(cls.tarball_name) as tarball:
                tarball.extractall(cls.index_dir)
            cls.searcher = LuceneSearcher(f'{cls.index_dir}lucene-index.cacm')
        except Exception:
            cls._remove_artifacts()
            raise

    @classmethod
    def tearDownClass(cls):
        """Close the shared searcher and delete the downloaded files."""
        try:
            cls.searcher.close()
        finally:
            # Remove the files even when closing the searcher fails.
            cls._remove_artifacts()

    @classmethod
    def _remove_artifacts(cls):
        """Delete the tarball and the extracted index directory if present."""
        try:
            os.remove(cls.tarball_name)
        except FileNotFoundError:
            # The download never produced a file; nothing to delete.
            pass
        if os.path.isdir(cls.index_dir):
            shutil.rmtree(cls.index_dir)

    def _open_private_searcher(self):
        """Return a fresh searcher on the CACM index, closed after the test."""
        searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm')
        self.addCleanup(searcher.close)
        return searcher

    def _assert_hit(self, hit, docid, score):
        """Assert that ``hit`` has ``docid`` and ``score`` (4 decimal places)."""
        self.assertEqual(hit.docid, docid)
        self.assertAlmostEqual(hit.score, score, places=4)

    def test_basic(self):
        """Default search: similarity, result types, fields and scores."""
        self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(hits, List)

        self.assertIsInstance(hits[0], JLuceneSearcherResult)
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.7655, places=4)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(), 'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'), 'CACM-3134')  # simpler call, same result as above
        self.assertEqual(len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')), 1532)  # simpler call, same result as above

        self.assertIsInstance(hits[9], JLuceneSearcherResult)
        self._assert_hit(hits[9], 'CACM-2516', 4.2174)

        hits = self.searcher.search('search')

        self.assertIsInstance(hits[0], JLuceneSearcherResult)
        self._assert_hit(hits[0], 'CACM-3058', 2.8576)

        self.assertIsInstance(hits[9], JLuceneSearcherResult)
        self._assert_hit(hits[9], 'CACM-3040', 2.6878)

    def test_batch(self):
        """Batch search over two queries returns per-query ranked lists."""
        results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, Dict)

        self.assertIsInstance(results['q1'], List)
        self.assertIsInstance(results['q1'][0], JLuceneSearcherResult)
        self._assert_hit(results['q1'][0], 'CACM-3134', 4.7655)

        self.assertIsInstance(results['q1'][9], JLuceneSearcherResult)
        self._assert_hit(results['q1'][9], 'CACM-2516', 4.2174)

        self.assertIsInstance(results['q2'], List)
        self.assertIsInstance(results['q2'][0], JLuceneSearcherResult)
        self._assert_hit(results['q2'][0], 'CACM-3058', 2.8576)

        self.assertIsInstance(results['q2'][9], JLuceneSearcherResult)
        self._assert_hit(results['q2'][9], 'CACM-3040', 2.6878)

    def test_basic_k(self):
        """``k`` controls the number of hits returned by ``search``."""
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(hits, List)
        self.assertIsInstance(hits[0], JLuceneSearcherResult)
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        """``k`` controls the number of hits per query in ``batch_search``."""
        results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=100, threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, Dict)
        for qid in ('q1', 'q2'):
            self.assertIsInstance(results[qid], List)
            self.assertIsInstance(results[qid][0], JLuceneSearcherResult)
            self.assertEqual(len(results[qid]), 100)

    def test_basic_fields(self):
        """``search`` accepts a ``fields`` weighting dict."""
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval', k=42, fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(hits, List)
        self.assertIsInstance(hits[0], JLuceneSearcherResult)
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        """``batch_search`` accepts a ``fields`` weighting dict."""
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=42,
                                             threads=2, fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, Dict)
        for qid in ('q1', 'q2'):
            self.assertIsInstance(results[qid], List)
            self.assertIsInstance(results[qid][0], JLuceneSearcherResult)
            self.assertEqual(len(results[qid]), 42)

    def test_different_similarity(self):
        """Switching between QLD and BM25 changes rankings and scores."""
        # This test reconfigures the searcher shared by the whole class, so put
        # the default BM25 back afterwards (even on failure) to keep the other
        # tests independent of execution order.
        self.addCleanup(self.searcher.set_bm25)

        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet'))

        hits = self.searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 3.6803)
        self._assert_hit(hits[9], 'CACM-1927', 2.5324)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 4.7655)
        self._assert_hit(hits[9], 'CACM-2516', 4.2174)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet'))

        hits = self.searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 6.3558)
        self._assert_hit(hits[9], 'CACM-2631', 5.1896)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 4.8688)
        self._assert_hit(hits[9], 'CACM-2516', 4.3332)

    def test_rm3(self):
        """Enabling, disabling and re-configuring RM3 changes the results."""
        # Use a private searcher so the RM3 settings never touch the shared one.
        searcher = self._open_private_searcher()

        searcher.set_rm3()
        self.assertTrue(searcher.is_using_rm3())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 2.1735)
        self._assert_hit(hits[9], 'CACM-2516', 1.7018)

        searcher.unset_rm3()
        self.assertFalse(searcher.is_using_rm3())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 4.7655)
        self._assert_hit(hits[9], 'CACM-2516', 4.2174)

        searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(searcher.is_using_rm3())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 2.1715)
        self._assert_hit(hits[9], 'CACM-1457', 1.4556)

    def test_rocchio(self):
        """Enabling, disabling and re-configuring Rocchio changes the results."""
        # Use a private searcher so the Rocchio settings never touch the shared one.
        searcher = self._open_private_searcher()

        searcher.set_rocchio()
        self.assertTrue(searcher.is_using_rocchio())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 7.1883)
        self._assert_hit(hits[9], 'CACM-2140', 5.5797)

        searcher.unset_rocchio()
        self.assertFalse(searcher.is_using_rocchio())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 4.7655)
        self._assert_hit(hits[9], 'CACM-2516', 4.2174)

        searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10,
                             bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1,
                             debug=False, use_negative=True)
        self.assertTrue(searcher.is_using_rocchio())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 3.6489)
        self._assert_hit(hits[9], 'CACM-1032', 2.5751)

        searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10,
                             bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1,
                             debug=False, use_negative=False)
        self.assertTrue(searcher.is_using_rocchio())

        hits = searcher.search('information retrieval')
        self._assert_hit(hits[0], 'CACM-3134', 4.0390)
        self._assert_hit(hits[9], 'CACM-1032', 2.9155)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()