# LLM4SciLit/src/demo.py
import logging
from pathlib import Path
import cmd
import shlex

import hydra
from omegaconf import DictConfig, OmegaConf
from art import tprint

import utils

log = logging.getLogger(__name__)


class CLIApp(cmd.Cmd):
    class CleanExit:
        """Context manager that turns Ctrl-C (KeyboardInterrupt) into a clean shell exit."""

        def __init__(self, cliapp):
            self.cliapp = cliapp

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            if exc_type is KeyboardInterrupt:
                print("\n", end="")
                self.cliapp.do_exit(None)
                return True
            return exc_type is None

    prompt = '> '
    intro = """Running in interactive mode:
Welcome to the LLM4SciLit shell. Type help or ? to list commands.\n"""

    def __init__(self, app, cfg: DictConfig) -> None:
        super().__init__()
        self.app = app
        self.cfg = cfg

    def do_exit(self, _):
        """Exit the shell."""
        # self.app.vector_store.save(self.cfg.storage_path.vector_store)
        print("\nLLM4SciLit: Bye!\n")
        self.app.exit()
        return True

    do_EOF = do_exit

    def do_ask_paper(self, line):
        """Ask a question about a paper: ask_paper "<paper title>" <question>."""
        paper, *rest = shlex.split(line)
        question = " ".join(rest)
        filter_dict = {"paper_title": paper}
        print(f"\nLLM4SciLit: {self.app.qa_model.answer_question(question, filter_dict)['result']}\n")

    def default(self, line):
        """Treat any other input as a free-form question over the whole corpus."""
        # print(f"\nLLM4SciLit: a bunch of nonsense\n")
        print(f"\nLLM4SciLit: {self.app.qa_model.answer_question(line, {})['result']}\n")


class App:
    def __init__(self, cfg: DictConfig) -> None:
        self.cfg = cfg

        log.info("Loading: Document Loader")
        self.loader = hydra.utils.instantiate(cfg.document_loader)

        log.info("Loading: Text Splitter")
        self.splitter = hydra.utils.instantiate(cfg.text_splitter)

        log.info("Loading: Text Embedding Model")
        self.text_embedding_model = hydra.utils.instantiate(cfg.text_embedding)

        log.info("Loading: Vector Store")
        self.vector_store = hydra.utils.instantiate(cfg.vector_store, self.text_embedding_model)

        log.info("Loading: Document Retriever")
        self.retriever = hydra.utils.instantiate(cfg.document_retriever, self.vector_store)

        log.info("Loading: Question Answering Model")
        self.qa_model = hydra.utils.instantiate(cfg.question_answering, self.retriever)
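
    # Configuration expected by this class (a sketch based only on the keys accessed in
    # this file; the actual YAML lives under ../config and is not shown here):
    #   cfg.document_loader, cfg.text_splitter, cfg.text_embedding, cfg.vector_store,
    #   cfg.document_retriever, cfg.question_answering  -> Hydra-instantiated components
    #   cfg.storage_path.{documents, documents_processed, vector_store}
    #   cfg.debug.{is_debug, force_rebuild_storage}
    #   cfg.mode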
    def _bootstrap(self) -> None:
        # if the vector store does not exist on disk (or a rebuild is forced), build it
        # from the raw documents; otherwise load the existing one
        if not Path(self.cfg.storage_path.vector_store).exists() or self.cfg.debug.force_rebuild_storage:
            message = (
                "Vector store not found at %s. Building storage from scratch"
                if not self.cfg.debug.force_rebuild_storage
                else "Forced to rebuild storage at %s. Building storage from scratch"
            )
            log.info(message, self.cfg.storage_path.vector_store)

            docs = self.loader.load_documents(self.cfg.storage_path.documents)
            docs = self.splitter.split_documents(docs)
            utils.save_docs_to_jsonl(docs, self.cfg.storage_path.documents_processed)

            self.vector_store.initialize_from_documents(docs)
            self.vector_store.save(self.cfg.storage_path.vector_store)
        else:
            log.info("Vector store found at %s. Loading existing storage", self.cfg.storage_path.vector_store)
            self.vector_store.initialize_from_file(self.cfg.storage_path.vector_store)

        self.retriever.initialize()
        self.qa_model.initialize()

        print("Ready to answer your questions 🔥🔥\n")
    ##################################################################################################
    # App functionalities

    def ask_paper(self, line):
        """Ask a question about a paper: the first shlex token is the paper title, the rest is the question."""
        paper, *rest = shlex.split(line)
        question = " ".join(rest)
        filter_dict = {"paper_title": paper}
        print(f"\nLLM4SciLit: {self.qa_model.answer_question(question, filter_dict)['result']}\n")

    def ask(self, line):
        """Ask a free-form question over the whole corpus and print the answer."""
        # print(f"\nLLM4SciLit: a bunch of nonsense\n")
        print(f"\nLLM4SciLit: {self.qa_model.answer_question(line, {})['result']}\n")

    def ask_chat(self, line, history):
        """Chat-style entry point: returns the answer instead of printing it."""
        # print(f"\nLLM4SciLit: a bunch of nonsense\n")
        return self.qa_model.answer_question(line, {})['result']
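
    # Note: the (message, history) signature of ask_chat matches the callback shape used by
    # chat UIs such as gradio.ChatInterface; this is an assumption about intended use, since
    # no UI is wired up in this file.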
    ##################################################################################################
    # App modes

    def run_interactive(self) -> None:
        self._bootstrap()
        cli = CLIApp(self, self.cfg)
        with CLIApp.CleanExit(cli):
            cli.cmdloop()

    def exit(self):
        """
        Do any cleanup here
        """


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(cfg: DictConfig) -> None:
    tprint("LLM4SciLit")

    if cfg.debug.is_debug:
        print("Running with config:")
        print(OmegaConf.to_yaml(cfg))

    app = App(cfg)

    match cfg.mode:
        case "interactive":
            app.run_interactive()
        case _:
            raise ValueError(f"Unknown mode: {cfg.mode}")


if __name__ == "__main__":
    main()  # pylint: disable=E1120:no-value-for-parameter
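
# Example invocation (a sketch; assumes the Hydra config under ../config defines `mode`,
# `debug`, `storage_path`, and the component groups referenced above):
#   python src/demo.py mode=interactive
#   python src/demo.py mode=interactive debug.force_rebuild_storage=true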