Spaces:
Sleeping
Sleeping
import inspect | |
import json | |
from typing import Any, Dict, Optional | |
from keyword_extraction import KeywordExtractor | |
from langchain.callbacks.manager import CallbackManagerForChainRun | |
from langchain.chains.conversational_retrieval.base import ( | |
ConversationalRetrievalChain, | |
_get_chat_history, | |
) | |
from langchain.schema import Document | |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain with a document-count check and keyword reranking.

    ``_call`` follows these steps: condense the question with the chat history
    (when there is one), retrieve documents, reject the request with a canned
    message when the number of documents is unusable (see ``_handle_docs``),
    rerank the remaining documents by keyword overlap with the question (see
    ``rerank_documents``), and finally run ``combine_docs_chain`` on them.
    """

    # Called with the question text in ``rerank_documents``; its result is
    # iterated as a collection of keyword strings.
    keyword_extractor: KeywordExtractor = KeywordExtractor()

    def _handle_docs(self, docs):
        """Check whether the number of retrieved documents is usable.

        Args:
            docs: Retrieved documents (any object supporting ``len``).

        Returns:
            A ``(is_valid, message)`` tuple. ``is_valid`` is ``True`` only when
            there are between 2 and 10 documents inclusive, in which case
            ``message`` is the empty string. Otherwise ``message`` is the text
            to send back to the user.
        """
        count = len(docs)
        if count == 0:
            return False, "No documents found. Can you rephrase ?"
        if count == 1:
            return False, "Only one document found. Can you rephrase ?"
        if count > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
        """Rerank documents by the number of keywords they share with the question.

        Each document's ``page_content`` is parsed as JSON and its ``"keywords"``
        entry is read as a comma-separated string. Matching is case-insensitive
        and ignores whitespace around each comma-separated keyword.

        As a side effect, ``doc.metadata["similar_keyword"]`` is set on every
        document to its number of matching keywords.

        Args:
            question (str): Original question.
            docs (list[Document]): List of documents.

        Returns:
            list[Document]: New list of the same documents, those with the most
            matching keywords first. Documents with equal counts keep their
            incoming relative order.

        Raises:
            json.JSONDecodeError: If a document's ``page_content`` is not valid JSON.
        """
        # Materialise once: the keywords are re-iterated for every document.
        keywords = list(self.keyword_extractor(question))

        for doc in docs:
            doc.metadata["similar_keyword"] = 0
            # A missing "keywords" field is treated like an explicit null.
            raw_keywords = json.loads(doc.page_content).get("keywords")
            if not raw_keywords:
                continue
            # Strip each entry so that "a, b" matches as well as "a,b".
            doc_keywords = {part.strip() for part in raw_keywords.lower().split(",")}
            for kw in keywords:
                if kw.lower() in doc_keywords:
                    doc.metadata["similar_keyword"] += 1
                    print("similar keyword : ", kw)

        # Best match first; ``sorted`` is stable, so ties keep the incoming order.
        return sorted(
            docs, key=lambda doc: doc.metadata["similar_keyword"], reverse=True
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain: condense question, retrieve, validate, rerank, answer.

        Args:
            inputs: Chain inputs; ``"question"`` and ``"chat_history"`` are read.
            run_manager: Optional callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict holding the answer under ``self.output_key``. When the
            retrieved documents are rejected by ``_handle_docs``, the answer is
            the rejection message and ``"source_documents"`` is always included.
            Otherwise ``"source_documents"`` is included when
            ``self.return_source_documents`` is set. On both paths
            ``"generated_question"`` is included when
            ``self.return_generated_question`` is set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        # Only condense the question when there is history to condense it with.
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # ``_get_docs`` implementations may or may not accept ``run_manager``.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            # Skip answer generation and hand the explanation back as the answer.
            rejected: Dict[str, Any] = {
                self.output_key: message,
                "source_documents": docs,
            }
            # Same flag and key as the success path below, so both paths expose
            # the generated question consistently.
            if self.return_generated_question:
                rejected["generated_question"] = new_question
            return rejected

        # Add reranking
        docs = self.rerank_documents(new_question, docs)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output