# learn-ai / test.py — uploaded by dh-mc
# Commit e182c41: "completed gradio app for qa" (2.57 kB)
# project/test.py
import os
import unittest
from timeit import default_timer as timer
from typing import Any
from unittest import mock

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage

from app_modules.init import app_init
from app_modules.llm_loader import LLMLoader
from app_modules.utils import get_device_types, print_llm_response
class TestLLMLoader:  # (unittest.TestCase):
    """Smoke tests that load one LLM backend via ``LLMLoader`` and send it a prompt.

    NOTE: the ``unittest.TestCase`` base is intentionally commented out, so
    ``unittest`` discovery does not pick this class up; presumably it is run
    manually or through another runner — confirm before relying on it in CI.
    The tests only print timings and the raw model output; they assert nothing.
    """

    question = "What's the capital city of Malaysia?"

    def run_test_case(self, llm_model_type, query):
        """Load the ``llm_model_type`` backend, time the load, then time one call with ``query``.

        The thread count comes from the ``NUMBER_OF_CPU_CORES`` environment
        variable (an unset or empty value falls back to 4).
        """
        thread_count = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")

        embeddings_device, pipeline_device = get_device_types()
        print(f"hf_embeddings_device_type: {embeddings_device}")
        print(f"hf_pipeline_device_type: {pipeline_device}")

        loader = LLMLoader(llm_model_type)

        load_started = timer()
        # ``n_threds`` (sic) is the keyword LLMLoader.init expects; keep the spelling.
        loader.init(n_threds=thread_count, hf_pipeline_device_type=pipeline_device)
        load_finished = timer()
        print(f"Model loaded in {load_finished - load_started:.3f}s")

        # The OpenAI backend is called with a list of chat messages; every other
        # backend receives the query unchanged.
        if llm_model_type == "openai":
            prompt = [HumanMessage(content=query)]
        else:
            prompt = query
        answer = loader.llm(prompt)
        inference_finished = timer()
        print(f"Inference completed in {inference_finished - load_finished:.3f}s")
        print(answer)

    def test_openai(self):
        """Run the smoke test against the ``openai`` backend."""
        self.run_test_case("openai", self.question)

    def test_llamacpp(self):
        """Run the smoke test against the ``llamacpp`` backend."""
        self.run_test_case("llamacpp", self.question)

    def test_gpt4all_j(self):
        """Run the smoke test against the ``gpt4all-j`` backend."""
        self.run_test_case("gpt4all-j", self.question)

    def test_huggingface(self):
        """Run the smoke test against the ``huggingface`` backend."""
        self.run_test_case("huggingface", self.question)
class TestQAChain(unittest.TestCase):
    """End-to-end smoke tests of the QA chain returned by ``app_init`` for each backend.

    Each test prints initialization/inference timings and the chain's response
    via ``print_llm_response``; it does not assert on the answer content.
    """

    # Declared for readers only; run_test_case keeps the chain in a local.
    qa_chain: Any
    question = "What's deep learning?"

    def run_test_case(self, llm_model_type, query):
        """Build the QA chain for ``llm_model_type`` and run ``query`` through it.

        ``LLM_MODEL_TYPE`` is set only for the duration of this call and the
        previous environment is restored afterwards, so one test's backend
        choice cannot leak into later tests or the rest of the process.
        """
        # app_init presumably reads LLM_MODEL_TYPE to choose the backend (it is
        # set right before the call). The whole run, including inference, stays
        # inside the patch in case the chain reads the environment lazily.
        with mock.patch.dict(os.environ, {"LLM_MODEL_TYPE": llm_model_type}):
            start = timer()
            qa_chain = app_init()
            end = timer()
            print(f"App initialized in {end - start:.3f}s")

            inputs = {"question": query, "chat_history": []}
            result = qa_chain.call_chain(inputs, None)
            end2 = timer()
            print(f"Inference completed in {end2 - end:.3f}s")
            print_llm_response(result)

    def test_openai(self):
        """Run the QA chain with the ``openai`` backend."""
        self.run_test_case("openai", self.question)

    def test_llamacpp(self):
        """Run the QA chain with the ``llamacpp`` backend."""
        self.run_test_case("llamacpp", self.question)

    def test_gpt4all_j(self):
        """Run the QA chain with the ``gpt4all-j`` backend."""
        self.run_test_case("gpt4all-j", self.question)

    def test_huggingface(self):
        """Run the QA chain with the ``huggingface`` backend."""
        self.run_test_case("huggingface", self.question)
if __name__ == "__main__":
    # Hand control to unittest's command-line runner for this module; only
    # unittest.TestCase subclasses defined here are collected.
    unittest.main(module="__main__")