# ChatData/backend/chains/stuff_documents.py
# lqhl's picture
# Synced repo using 'sync_with_huggingface' Github Action
# e931b70 verified
from typing import Any, List, Tuple
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema.prompt_template import format_document
class CustomStuffDocumentChain(StuffDocumentsChain):
    """Stuff-documents chain that numbers each document before prompting.

    Every document passed in receives a ``ref_id`` metadata entry holding its
    zero-based position in the input list, and has the newlines in its text
    replaced by spaces, before being rendered with ``self.document_prompt``.

    NOTE(review): the original summary describes this as combining arxiv
    documents with a PDF reference number; presumably ``ref_id`` is read by
    the caller afterwards to map citations back to documents -- confirm
    against the calling chain.
    """

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the input mapping handed to the underlying LLM chain.

        Each document is formatted with ``self.document_prompt``; the
        formatted pieces are joined with ``self.document_separator`` and
        stored under ``self.document_variable_name``. Any keyword argument
        whose name appears in the LLM chain prompt's input variables is
        forwarded as well.

        Side effects:
            The ``Document`` objects in ``docs`` are modified in place:
            ``metadata['ref_id']`` is set to the document's index in ``docs``
            and ``page_content`` has every ``'\\n'`` replaced by a space.

        Args:
            docs: Documents to format and then join into a single input.
            **kwargs: Additional chain inputs; only those required by the
                LLM chain prompt are kept.

        Returns:
            Dictionary of inputs for the LLM chain.
        """
        rendered: List[str] = []
        for position, document in enumerate(docs):
            # Tag the document with its position so the prompt template (and
            # whoever inspects the documents afterwards) can refer to it.
            document.metadata['ref_id'] = position
            # Flatten the text onto one line before formatting.
            document.page_content = document.page_content.replace('\n', ' ')
            rendered.append(format_document(document, self.document_prompt))

        # Forward only the keyword arguments the prompt actually declares.
        prompt_inputs: dict = {}
        for key, value in kwargs.items():
            if key in self.llm_chain.prompt.input_variables:
                prompt_inputs[key] = value

        # The combined document text is written last, so it takes precedence
        # over a same-named keyword argument.
        combined_text = self.document_separator.join(rendered)
        prompt_inputs[self.document_variable_name] = combined_text
        return prompt_inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Put all documents into a single prompt and run the LLM chain on it.

        Args:
            docs: Documents to join together into one prompt variable.
            callbacks: Optional callbacks passed through to the LLM chain.
            **kwargs: Additional parameters used to build the LLM chain
                inputs (see ``_get_inputs``).

        Returns:
            A 2-tuple: the string produced by the LLM chain, followed by a
            dictionary of extra output keys, which is always empty here.
        """
        prompt_inputs = self._get_inputs(docs, **kwargs)
        # Run the prediction on the LLM chain with the assembled inputs.
        answer = self.llm_chain.predict(callbacks=callbacks, **prompt_inputs)
        extra_outputs: dict = {}
        return answer, extra_outputs

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "custom_stuff_document_chain"