diff --git a/models.py b/models.py index 21dfc51..6c3ca66 100644 --- a/models.py +++ b/models.py @@ -223,7 +223,7 @@ def retrieve(self, query, corpus, model_name, topk=1): corpus_format = CORPUS_TO_FORMAT[corpus] if "BM25" in model_name: - index = self.load_bm25_index(model_name, corpus) + index = self.load_bm25_index("BM25", corpus) docs = index.search([query], topk=topk) if corpus == "stackexchange": return [[query, corpus_format.format(text=docs[0][0]["text"])]]