"""Main entrypoint for the app."""
import json
import os
from timeit import default_timer as timer
from typing import List, Optional

from lcserve import serving
from pydantic import BaseModel

from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response

# Shared LLM loader and QA chain, created once at import time.
llm_loader, qa_chain = app_init(True)

# Client-supplied history is forwarded to the QA chain only when this
# environment variable is exactly "true".
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"

# One ChatChain per client-supplied uuid, created on first use.
# NOTE(review): nothing in this file removes entries, so the mapping grows
# with every new uuid -- confirm whether eviction is needed.
uuid_to_chat_chain_mapping = {}


class ChatResponse(BaseModel):
    """Chat response schema."""

    token: Optional[str] = None
    error: Optional[str] = None
    sourceDocs: Optional[List] = None


def _normalize_history(history: Optional[List]) -> List:
    """Convert client history into a list of ``(first, second)`` string pairs.

    Each element is indexed at positions 0 and 1; falsy values (such as
    ``None``) are replaced with ``""``. A ``None`` history is treated as empty.
    """
    return [(element[0] or "", element[1] or "") for element in history or []]


@serving(websocket=True)
def chat(
    question: str, history: Optional[List] = [], uuid: Optional[str] = None, **kwargs
) -> str:
    """Answer ``question`` and return a JSON-encoded ``ChatResponse``.

    Args:
        question: The question to send to the chain.
        history: Prior exchanges, each indexable at positions 0 and 1. Only
            used when ``uuid`` is ``None`` and chat history is enabled;
            ``None`` is treated as an empty history. The ``[]`` default is
            never mutated.
        uuid: When ``None``, the shared ``qa_chain`` is called. Otherwise the
            ``ChatChain`` cached for this uuid is called, creating it on
            first use.
        **kwargs: May carry ``streaming_handler``, which is passed through to
            the chain to stream data to the client.

    Returns:
        ``json.dumps`` of the response dict. ``sourceDocs`` holds
        ``result["source_documents"]`` on the ``qa_chain`` path and an empty
        list on the per-uuid path.
    """
    print(f"uuid: {uuid}")
    # Get the `streaming_handler` from `kwargs`. This is used to stream data to the client.
    streaming_handler = kwargs.get("streaming_handler")

    if uuid is None:
        chat_history = _normalize_history(history) if chat_history_enabled else []

        start = timer()
        result = qa_chain.call_chain(
            {"question": question, "chat_history": chat_history}, streaming_handler
        )
        end = timer()
        print(f"Completed in {end - start:.3f}s")

        print(f"qa_chain result: {result}")
        resp = ChatResponse(sourceDocs=result["source_documents"])
        return json.dumps(resp.dict())

    # Named `chat_chain` (not `chat`) so the local does not shadow this function.
    chat_chain = uuid_to_chat_chain_mapping.get(uuid)
    if chat_chain is None:
        chat_chain = ChatChain(llm_loader)
        uuid_to_chat_chain_mapping[uuid] = chat_chain

    result = chat_chain.call_chain({"question": question}, streaming_handler)
    print(f"chat result: {result}")

    resp = ChatResponse(sourceDocs=[])
    return json.dumps(resp.dict())


if __name__ == "__main__":
    print_llm_response(json.loads(chat("What's deep learning?", [])))