"""Abstract base for running a LangChain chain, optionally streaming tokens."""

import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Base class that lazily builds a chain and runs it against inputs.

    Subclasses implement :meth:`create_chain`; the chain is created on first
    use and then cached on the instance.
    """

    # Loader that exposes (at least) a ``streamer`` attribute.
    llm_loader: LLMLoader
    # Cached chain; ``None`` until the first ``get_chain`` call.
    chain: Chain

    def __init__(self, llm_loader):
        """Store the loader and start with no chain built."""
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        """Build and return the chain used by this inference object."""

    def get_chain(self, tracing: bool = False) -> Chain:
        """Return the cached chain, creating it on first call.

        Args:
            tracing: When true *and* the chain has not been created yet, a
                ``LangChainTracer`` is instantiated and its default session
                loaded before the chain is built. The flag has no effect once
                a chain is cached.
        """
        if self.chain is None:
            if tracing:
                # NOTE(review): the tracer is not handed to the chain here;
                # only its ``load_default_session`` call is performed.
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
    ):
        """Run the chain on ``inputs`` and post-process the result.

        Args:
            inputs: Mapping passed to the chain (also printed to stdout).
            streaming_handler: Callback handler that receives tokens; when
                ``None`` the chain is invoked directly without streaming.
            q: Passed to ``streamer.reset`` when the loader's streamer is a
                ``TextIteratorStreamer``.
            tracing: Forwarded to :meth:`get_chain`.

        Returns:
            The chain's result. If it contains ``"answer"`` that value is
            passed through ``remove_extra_spaces``. If the environment
            variable ``PDF_FILE_BASE_URL`` is non-empty, every document under
            ``"source_documents"`` gets a ``metadata["url"]`` built from the
            base URL and the URL-quoted last path segment of its
            ``metadata["source"]``.
        """
        print(inputs)

        streamer = self._text_streamer()
        if streamer is not None:
            streamer.reset(q)

        chain = self.get_chain(tracing)
        if streaming_handler is not None:
            result = self._run_qa_chain(chain, inputs, streaming_handler)
        else:
            result = chain(inputs)

        if "answer" in result:
            result["answer"] = remove_extra_spaces(result["answer"])

        base_url = os.environ.get("PDF_FILE_BASE_URL")
        if base_url:
            # Not every result carries source documents (the "answer" key is
            # guarded the same way), so tolerate a missing key.
            for doc in result.get("source_documents", []):
                source = doc.metadata["source"]
                title = source.split("/")[-1]
                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

        return result

    def _text_streamer(self):
        """Return the loader's streamer if it is a TextIteratorStreamer, else None."""
        streamer = self.llm_loader.streamer
        return streamer if isinstance(streamer, TextIteratorStreamer) else None

    def _run_qa_chain(self, qa, inputs, streaming_handler):
        """Run ``qa`` in a worker thread while relaying streamed tokens.

        Args:
            qa: Callable chain, invoked as ``qa(inputs, callbacks=[handler])``.
            inputs: Mapping passed to the chain; a non-empty
                ``"chat_history"`` makes the relay loop drain the streamer
                twice instead of once.
            streaming_handler: Receives each token via ``on_llm_new_token``.

        Returns:
            Whatever ``qa`` returned.

        Raises:
            Exception: Any exception raised by ``qa`` in the worker thread is
                re-raised here instead of leaving the caller blocked on an
                empty result queue.
        """
        que = Queue()
        errors = []

        def _invoke_chain():
            # Capture failures so the calling thread can surface them; an
            # uncaught exception here would leave ``que`` empty forever.
            try:
                que.put(qa(inputs, callbacks=[streaming_handler]))
            except Exception as exc:
                errors.append(exc)

        t = Thread(target=_invoke_chain)
        t.start()

        streamer = self._text_streamer()
        if streamer is not None:
            # A missing "chat_history" is treated like an empty one.
            count = 2 if inputs.get("chat_history") else 1

            while count > 0:
                try:
                    for token in streamer:
                        streaming_handler.on_llm_new_token(token)

                    streamer.reset()
                    count -= 1
                except Exception:
                    if errors:
                        # The chain itself failed; nothing more will arrive.
                        break
                    print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        if errors:
            raise errors[0]
        return que.get()