from qa.chains import conversational_retrieval_qa, retrieval_qa
from qa.loader import youtube_doc_loader
from qa.model import load_llm
from qa.split import split_document
from qa.vector_store import create_vector_store


class YoutubeQA:
    """Question answering over a YouTube video's transcript."""

    def __init__(self):
        # Default to single-turn question answering.
        self.CHAT_MODE = 'normal'

    def change_chat_mode(self, mode: str) -> None:
        # Switch between 'normal' and 'conversational' chat modes.
        self.CHAT_MODE = mode

    def load_model(self) -> None:
        # Load the underlying language model.
        self.llm = load_llm()

    def load_vector_store(self, url: str) -> None:
        # Fetch the video transcript, split it into chunks, and index the
        # chunks; create_vector_store returns the retriever used by the chains.
        data = youtube_doc_loader(url=url)
        docs = split_document(data=data)
        self.retriever = create_vector_store(docs=docs)

    def load_retriever(self) -> None:
        # Build the QA chain that matches the current chat mode.
        if self.CHAT_MODE == 'normal':
            self.retrieval_qa = retrieval_qa(llm=self.llm, retriever=self.retriever)
        elif self.CHAT_MODE == 'conversational':
            self.retrieval_qa = conversational_retrieval_qa(llm=self.llm, retriever=self.retriever)
        else:
            raise ValueError('Chat Mode not implemented')

    def run(self, question: str) -> str:
        # Answer a question with the chain built in load_retriever().
        if self.CHAT_MODE == 'normal':
            return self.retrieval_qa.run(question)
        elif self.CHAT_MODE == 'conversational':
            return self.retrieval_qa({'question': question})['answer']
        else:
            raise ValueError('Chat Mode not implemented')
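

# Usage sketch (illustrative, not part of the original module): assumes the
# qa package above is importable and that a real video URL replaces the
# placeholder below.
if __name__ == '__main__':
    qa = YoutubeQA()
    qa.load_model()
    qa.load_vector_store(url='https://www.youtube.com/watch?v=<VIDEO_ID>')
    qa.load_retriever()
    print(qa.run('What is the video about?'))

    # Conversational mode uses a different chain, so rebuild it after switching.
    qa.change_chat_mode('conversational')
    qa.load_retriever()
    print(qa.run('Can you summarise that in one sentence?'))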