import abc
import os
import time
import urllib.parse
from queue import Queue
from threading import Thread
from typing import Optional

from langchain.callbacks.tracers import LangChainTracer
from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces


class LLMInference(metaclass=abc.ABCMeta):
    """Abstract base class that wraps a LangChain chain and handles
    locking, token streaming, and post-processing of results."""

    llm_loader: LLMLoader
    chain: Chain

    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self, tracing: bool = False) -> Chain:
        # Lazily create the chain; optionally start a LangChain tracing session first.
        if self.chain is None:
            if tracing:
                tracer = LangChainTracer()
                tracer.load_default_session()

            self.chain = self.create_chain()

        return self.chain

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        tracing: bool = False,
    ):
        print(inputs)
        # Serialize access to the shared LLM so concurrent requests don't
        # interleave their streamed tokens.
        with self.llm_loader.lock:
            self.llm_loader.streamer.reset(q)

            chain = self.get_chain(tracing)
            result = (
                self._run_chain(chain, inputs, streaming_handler)
                if streaming_handler is not None
                and self.llm_loader.streamer.for_huggingface
                else chain(inputs)
            )

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

            # Attach a download URL to each source document when a base URL
            # for the PDF files is configured.
            base_url = os.environ.get("PDF_FILE_BASE_URL")
            if base_url is not None and len(base_url) > 0:
                documents = result["source_documents"]
                for doc in documents:
                    source = doc.metadata["source"]
                    title = source.split("/")[-1]
                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"

            return result

    def _execute_chain(self, chain, inputs, q, sh):
        q.put(chain(inputs, callbacks=[sh]))

    def _run_chain(self, chain, inputs, streaming_handler):
        # Run the chain in a worker thread and relay streamed tokens to the
        # streaming handler from this thread.
        que = Queue()
        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        # Expect two streamed generations when chat history is present
        # (question rewriting followed by answering), otherwise one.
        count = (
            2
            if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
            else 1
        )

        while count > 0:
            try:
                for token in self.llm_loader.streamer:
                    streaming_handler.on_llm_new_token(token)

                self.llm_loader.streamer.reset()
                count -= 1
            except Exception:
                print("nothing generated yet - retry in 0.5s")
                time.sleep(0.5)

        t.join()
        return que.get()