File size: 3,714 Bytes
b9e3c29
9f90955
b9e3c29
 
9f90955
b9e3c29
b7f6a3a
39a3f86
 
 
9f90955
b9e3c29
 
 
9f90955
 
b9e3c29
 
 
 
 
 
 
 
 
9f90955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9e3c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f90955
b9e3c29
 
 
 
 
 
 
9f90955
 
 
b9e3c29
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import inspect
import json
from typing import Any, Dict, Optional

from keyword_extraction import KeywordExtractor
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.conversational_retrieval.base import (
    ConversationalRetrievalChain,
    _get_chat_history,
)
from langchain.schema import Document


class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain with a result-count guard and keyword reranking.

    Compared to the parent chain's flow, ``_call`` adds two steps after
    retrieval:

    1. ``_handle_docs`` rejects result sets that are empty, contain a single
       document, or contain more than 10 documents, and answers with a
       canned message instead of calling the combine-documents chain.
    2. ``rerank_documents`` orders the remaining documents by how many of
       the question's keywords they share.
    """

    # Callable invoked with the question text; its result is iterated as
    # keyword strings. NOTE(review): this default instance is created once,
    # at class-definition time.
    keyword_extractor: KeywordExtractor = KeywordExtractor()

    def _handle_docs(self, docs):
        """Check whether the number of retrieved documents is usable.

        Args:
            docs: Retrieved documents (only ``len(docs)`` is inspected).

        Returns:
            A ``(is_valid, message)`` tuple. ``message`` is the user-facing
            explanation when ``is_valid`` is False and ``""`` otherwise.
        """
        count = len(docs)
        if count == 0:
            return False, "No documents found. Can you rephrase ?"
        if count == 1:
            return False, "Only one document found. Can you rephrase ?"
        if count > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    @staticmethod
    def _parse_doc_keywords(page_content: str) -> set[str]:
        """Extract the normalized keyword set stored in a document's JSON payload.

        The payload is expected to be a JSON object whose ``"keywords"`` entry
        is a comma-separated string (a list of strings is accepted as well).
        Keywords are lower-cased and stripped of surrounding whitespace so
        that ``"foo, bar"`` yields ``{"foo", "bar"}``.

        Args:
            page_content: Raw document content.

        Returns:
            The set of keywords, or an empty set when the content is not
            valid JSON, is not an object, or has no usable ``"keywords"``.
        """
        try:
            payload = json.loads(page_content)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError subclass; a document that
            # cannot be parsed simply contributes no keywords.
            return set()

        raw = payload.get("keywords") if isinstance(payload, dict) else None
        if isinstance(raw, str):
            parts = raw.split(",")
        elif isinstance(raw, (list, tuple)):
            parts = [str(part) for part in raw]
        else:
            return set()
        return {part.strip().lower() for part in parts if part.strip()}

    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
        """Rerank documents based on the number of similar keywords

        Each document's ``metadata["similar_keyword"]`` is overwritten with
        the number of question keywords found in the document's own keyword
        list (case-insensitive comparison).

        Args:
            question (str): Original question
            docs (list[Document]): List of documents

        Returns:
            list[Document]: New list of the same documents, best match first.
            Documents with equal scores keep their incoming relative order.
        """
        keywords = self.keyword_extractor(question)

        for doc in docs:
            doc.metadata["similar_keyword"] = 0
            doc_keywords = self._parse_doc_keywords(doc.page_content)
            if not doc_keywords:
                continue

            for kw in keywords:
                if kw.lower() in doc_keywords:
                    doc.metadata["similar_keyword"] += 1
                    print("similar keyword : ", kw)

        # reverse=True puts the highest score first; sorted() is stable, so
        # ties preserve the retriever's ordering.
        return sorted(
            docs, key=lambda d: d.metadata["similar_keyword"], reverse=True
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain: rewrite the question, retrieve, validate, rerank, answer.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``; any
                other entries are forwarded to the combine-documents chain.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict holding the answer under ``self.output_key``, plus
            ``"source_documents"`` / ``"generated_question"`` when the
            corresponding ``return_*`` flags are set. When the retrieved
            documents are rejected by ``_handle_docs``, the answer is the
            rejection message and ``"source_documents"`` is always included.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        if chat_history_str:
            # With prior turns available, let the question generator rewrite
            # the question using the chat history.
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # _get_docs implementations may or may not take a run_manager keyword;
        # inspect the signature and call accordingly.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            rejected: Dict[str, Any] = {
                self.output_key: message,
                "source_documents": docs,
            }
            # Chain output validation requires every declared output key, so
            # the generated question must be present here too when requested.
            if self.return_generated_question:
                rejected["generated_question"] = new_question
            return rejected

        # Add reranking
        docs = self.rerank_documents(new_question, docs)

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output