import inspect
import json
import logging
from typing import Any, Dict, Optional

from keyword_extraction import KeywordExtractor
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain,
    _get_chat_history,
)
from langchain.schema import Document

logger = logging.getLogger(__name__)


class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """ConversationalRetrievalChain with a document-count gate and keyword reranking.

    After retrieval, the chain refuses to answer when too few or too many
    documents were found (returning a "please rephrase" message instead), and
    otherwise reorders the documents by how many question keywords they share
    before passing them to ``combine_docs_chain``.
    """

    # Extractor applied to the (possibly rephrased) question; it is called
    # with the question string and must yield keyword strings.
    # NOTE(review): this default instance is created once at class-definition
    # time and used as the field default for every chain instance.
    keyword_extractor: KeywordExtractor = KeywordExtractor()

    def _handle_docs(self, docs):
        """Decide whether the retrieved documents are usable.

        Args:
            docs: Documents returned by ``_get_docs``.

        Returns:
            tuple[bool, str]: ``(True, "")`` when 2 to 10 documents (inclusive)
            were retrieved, otherwise ``(False, message)`` where ``message``
            asks the user to rephrase or narrow the request.
        """
        if not docs:
            return False, "No documents found. Can you rephrase ?"
        if len(docs) == 1:
            return False, "Only one document found. Can you rephrase ?"
        if len(docs) > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    @staticmethod
    def _parse_doc_keywords(page_content: str) -> set:
        """Return the normalized keyword set stored in a document's JSON content.

        The content is expected to be a JSON object whose ``"keywords"`` entry is
        a comma-separated string (a JSON list of strings is also accepted).
        Keywords are stripped and lower-cased, so ``"foo, Bar"`` yields
        ``{"foo", "bar"}``. Content that is not a JSON object, or that has no /
        a null ``"keywords"`` entry, yields an empty set so the document simply
        scores zero instead of aborting the whole chain.
        """
        try:
            payload = json.loads(page_content)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            logger.debug("Document content is not JSON; no keywords used for reranking")
            return set()
        if not isinstance(payload, dict):
            return set()

        raw = payload.get("keywords")
        if isinstance(raw, str):
            items = raw.split(",")
        elif isinstance(raw, (list, tuple)):
            items = raw
        else:  # None / missing / unsupported type
            return set()

        # Strip whitespace: "a, b".split(",") yields " b", which would never match.
        return {
            item.strip().lower()
            for item in items
            if isinstance(item, str) and item.strip()
        }

    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
        """Rerank documents based on the number of similar keywords.

        Each document's ``metadata["similar_keyword"]`` is set in place to the
        number of question keywords that appear in that document's keyword list.

        Args:
            question (str): Original (or rephrased) question
            docs (list[Document]): List of documents

        Returns:
            list[Document]: New list with the documents sorted by
            ``similar_keyword`` in descending order (most shared keywords
            first); documents with equal scores keep their retrieval order.
        """
        keywords = [kw.strip().lower() for kw in self.keyword_extractor(question)]
        for doc in docs:
            doc_keywords = self._parse_doc_keywords(doc.page_content)
            matches = [kw for kw in keywords if kw in doc_keywords]
            doc.metadata["similar_keyword"] = len(matches)
            for kw in matches:
                logger.debug("similar keyword : %s", kw)
        # reverse=True keeps sort stability, so ties preserve retriever ranking.
        return sorted(docs, key=lambda d: d.metadata["similar_keyword"], reverse=True)

    def _build_output(self, answer: Any, docs: list[Document], new_question: str) -> Dict[str, Any]:
        """Assemble the chain output honoring the ``return_*`` flags.

        Both the normal and the early-return paths use this, so the produced keys
        always match ``self.output_keys`` (LangChain validates that every output
        key is present, and single-output memories reject extra keys).
        """
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Rephrase the question, retrieve and gate documents, rerank, then answer.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Optional callback manager; a no-op manager is used if None.

        Returns:
            Dict with ``self.output_key`` (the answer, or a rephrase message when
            the document count is rejected by ``_handle_docs``), plus
            ``"source_documents"`` / ``"generated_question"`` when the matching
            ``return_*`` flags are set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        # Only rephrase into a standalone question when there is history.
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # Subclasses may implement _get_docs with or without run_manager.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            return self._build_output(message, docs, new_question)

        docs = self.rerank_documents(new_question, docs)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        return self._build_output(answer, docs, new_question)