SorboBot / sorbobotapp /conversation_retrieval_chain.py
Léo Bourrel
feat: install pre-commit && clean
39a3f86
raw
history blame
3.71 kB
import inspect
import json
import logging
from typing import Any, Dict, Optional

from keyword_extraction import KeywordExtractor
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain,
    _get_chat_history,
)
from langchain.schema import Document
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain with document-count guards and keyword reranking.

    Compared to the parent ``ConversationalRetrievalChain``, ``_call``:

    * short-circuits with a canned message (without calling
      ``combine_docs_chain``) when ``_get_docs`` returns zero, one, or more
      than ten documents, and
    * reorders the retrieved documents by how many keywords they share with
      the (possibly rephrased) question before combining them.
    """

    # Extracts keywords from the question; used by ``rerank_documents``.
    # NOTE(review): this default instance is created once, at import time.
    keyword_extractor: KeywordExtractor = KeywordExtractor()

    def _handle_docs(self, docs):
        """Decide whether the retrieved documents are usable for answering.

        Args:
            docs (list[Document]): Documents returned by the retriever.

        Returns:
            tuple[bool, str]: ``(True, "")`` when between 2 and 10 documents
            were retrieved, otherwise ``(False, message)`` where ``message``
            asks the user to rephrase or narrow the request.
        """
        if not docs:
            return False, "No documents found. Can you rephrase ?"
        if len(docs) == 1:
            return False, "Only one document found. Can you rephrase ?"
        if len(docs) > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
        """Rerank documents by the number of keywords they share with the question.

        Each document's ``page_content`` is parsed as JSON and its
        ``"keywords"`` entry (a comma-separated string) is compared
        case-insensitively, with surrounding whitespace ignored, against the
        keywords extracted from ``question``. The match count is written to
        ``doc.metadata["similar_keyword"]``; the documents are mutated in place.

        Args:
            question (str): Original (or rephrased) question.
            docs (list[Document]): List of documents whose ``page_content``
                is a JSON object.

        Returns:
            list[Document]: A new list of the same documents, most keyword
            matches first. Documents with equal counts keep their incoming
            (retrieval) order because the sort is stable.

        Raises:
            json.JSONDecodeError: If a document's ``page_content`` is not JSON.
        """
        logger = logging.getLogger(__name__)
        # Materialise once: the extractor's result is iterated for every doc,
        # which would silently yield nothing after the first doc if it were a
        # generator. NOTE(review): assumes it yields strings — confirm.
        question_keywords = [
            kw.lower().strip() for kw in self.keyword_extractor(question)
        ]
        for doc in docs:
            doc.metadata["similar_keyword"] = 0
            raw_keywords = json.loads(doc.page_content).get("keywords")
            if not raw_keywords:
                continue
            # Strip each item so "foo, bar" matches "bar"; a set gives O(1) lookups.
            doc_keywords = {
                item.strip()
                for item in raw_keywords.lower().split(",")
                if item.strip()
            }
            for kw in question_keywords:
                if kw in doc_keywords:
                    doc.metadata["similar_keyword"] += 1
                    logger.debug("similar keyword: %s", kw)
        # reverse=True puts the best matches first while keeping ties stable.
        return sorted(
            docs, key=lambda doc: doc.metadata["similar_keyword"], reverse=True
        )

    def _build_output(
        self, answer: Any, docs: list[Document], new_question: str
    ) -> Dict[str, Any]:
        """Assemble the chain output, including only the optional keys requested.

        Args:
            answer: Value stored under ``self.output_key``.
            docs: Documents reported under ``"source_documents"`` when
                ``self.return_source_documents`` is set.
            new_question: Question reported under ``"generated_question"``
                when ``self.return_generated_question`` is set.

        Returns:
            Dict[str, Any]: The output mapping.
        """
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Answer the question, short-circuiting when retrieval is unusable.

        1. If the chat history is non-empty, rephrase the question with
           ``question_generator``.
        2. Retrieve documents with ``_get_docs``.
        3. If ``_handle_docs`` rejects them, return its message as the answer
           without running ``combine_docs_chain``.
        4. Otherwise rerank the documents and run ``combine_docs_chain``.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Callback manager; a no-op manager is used when None.

        Returns:
            Dict[str, Any]: ``self.output_key`` mapped to the answer (or the
            rejection message), plus ``"source_documents"`` and
            ``"generated_question"`` when the corresponding
            ``return_*`` flags are set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # Older/newer _get_docs overrides differ in whether they take run_manager.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            # Build the output the same way as the normal path: the base Chain
            # checks that every key in ``output_keys`` is present, so omitting
            # "generated_question" here would raise.
            return self._build_output(message, docs, new_question)

        docs = self.rerank_documents(new_question, docs)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        return self._build_output(answer, docs, new_question)