from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline


class HuggingFaceQuestionAnswering:
    def __init__(self, retriever) -> None:
        self.retriever = retriever
        # Load a local Hugging Face text-generation pipeline as the LLM.
        self.llm = HuggingFacePipeline.from_model_id(
            # model_id="bigscience/bloom-1b7",
            model_id="bigscience/bloomz-1b1",
            task="text-generation",
            device=1,  # run on the second GPU (cuda:1)
            # model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 4, "top_p": 0.95, "repetition_penalty": 1.25, "length_penalty": 1.2},
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            # pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        self.chain = None

    def initialize(self):
        template = """Use the information contained in the following text: {context}.

Complete the phrase: {question}
"""
        prompt_template = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        # Chain construction with the custom prompt is currently disabled;
        # answer_question builds the chain with the default prompt instead.
        # self.chain = RetrievalQA.from_chain_type(self.llm, retriever=self.retriever.retriever, chain_type_kwargs={"prompt": prompt_template})

    def answer_question(self, question: str, filter_dict):
        # Build a retriever that restricts the similarity search with a
        # metadata filter and widens the candidate pool via fetch_k.
        retriever = self.retriever.vector_store.db.as_retriever(
            search_kwargs={"filter": filter_dict, "fetch_k": 150}
        )
        try:
            self.chain = RetrievalQA.from_chain_type(
                self.llm, retriever=retriever, return_source_documents=True
            )
            result = self.chain({"query": question})
            # Summarize each retrieved document as "title - snippet..." for logging.
            docs = "\n".join(
                x.metadata["paper_title"][:40]
                + " - "
                + x.page_content[:40].replace("\n", " ")
                + "..."
                for x in result["source_documents"]
            )
            print(f"""
Retrieved Documents:
{docs if docs != '' else 'No documents found.'}""")
            return result
        except Exception:
            return {"result": "Error generating answer."}
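

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a hypothetical `retriever` wrapper that exposes `.vector_store.db`
# (e.g. a FAISS or Chroma store already populated with document chunks whose
# metadata includes a "paper_title" key, as answer_question expects).
#
# qa = HuggingFaceQuestionAnswering(retriever)
# qa.initialize()
# result = qa.answer_question(
#     "The attention mechanism in transformers works by",
#     filter_dict={"paper_title": "Attention Is All You Need"},
# )
# print(result["result"])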