from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline


class HuggingFaceQuestionAnswering:
    def __init__(self, retriever) -> None:
        self.retriever = retriever
        # Local HF text-generation pipeline wrapped as a LangChain LLM.
        self.llm = HuggingFacePipeline.from_model_id(
            # model_id="bigscience/bloom-1b7",
            model_id="bigscience/bloomz-1b1",
            task="text-generation",
            device=1,
            # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        self.chain = None

    def initialize(self):
        # Prompt template kept for reference; the chain construction that would
        # use it is commented out, so `answer_question` builds its own chain.
        template = """Use the information contained in the following text: {context}. Complete the phrase: {question} """
        prompt_template = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        # self.chain = RetrievalQA.from_chain_type(self.llm, retriever=self.retriever.retriever, chain_type_kwargs={"prompt": prompt_template})

    def answer_question(self, question: str, filter_dict):
        # Restrict retrieval to documents matching the metadata filter.
        retriever = self.retriever.vector_store.db.as_retriever(search_kwargs={"filter": filter_dict, "fetch_k": 150})
        try:
            self.chain = RetrievalQA.from_chain_type(self.llm, retriever=retriever, return_source_documents=True)
            result = self.chain({"query": question})
            # Log a short summary (title + snippet) of each retrieved document.
            docs = "\n".join(
                x.metadata["paper_title"][:40] + " - " + x.page_content[:40].replace("\n", " ") + "..."
                for x in result["source_documents"]
            )
            print(f"""
Retrieved Documents:
{docs if docs != "" else "No documents found."}""")
            return result
        except Exception:
            return {"result": "Error generating answer."}