File size: 2,032 Bytes
7f9d16c
815128e
c2cb992
7f9d16c
c2cb992
815128e
7f9d16c
 
815128e
7f9d16c
c2cb992
815128e
c2cb992
 
 
 
 
 
815128e
 
 
 
 
 
 
 
 
 
 
 
7f9d16c
815128e
 
 
 
 
 
7f9d16c
 
 
 
c2cb992
 
 
7f9d16c
 
815128e
7f9d16c
 
 
 
 
 
815128e
c2cb992
 
815128e
c2cb992
 
815128e
c2cb992
 
815128e
7f9d16c
c2cb992
815128e
 
7f9d16c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# project/test.py

import os
import unittest
from timeit import default_timer as timer
from typing import Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import HumanMessage

from app_modules.llm_loader import LLMLoader
from app_modules.utils import *

# Prompt submitted to every model under test.
user_question = "What's the capital city of Malaysia?"
# Thread count; an unset or empty NUMBER_OF_CPU_CORES falls back to 4.
# (The spelling "n_threds" matches the keyword passed to LLMLoader.init below.)
n_threds = int(os.getenv("NUMBER_OF_CPU_CORES") or "4")

# Look up the device types once at import time and echo them to stdout.
(hf_embeddings_device_type, hf_pipeline_device_type) = get_device_types()
print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")


class MyCustomHandler(BaseCallbackHandler):
    """Callback handler that records the text of each finished LLM call.

    Captured texts accumulate in ``self.texts`` in the order the calls
    finish, until ``reset`` clears them.
    """

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Discard every text captured so far."""
        self.texts = []

    def get_standalone_question(self) -> Optional[str]:
        """Return the first captured text with surrounding whitespace removed.

        Returns:
            The stripped first entry of ``self.texts``, or ``None`` when
            nothing has been captured since the last reset.
        """
        return self.texts[0].strip() if self.texts else None

    def on_llm_end(self, response, **kwargs) -> None:
        """Run when an LLM call ends: print the response and record its text.

        Only ``response.generations[0][0].text`` is appended to
        ``self.texts``; any further generations are ignored.
        """
        print("\non_llm_end - response:")
        print(response)
        self.texts.append(response.generations[0][0].text)


class TestLLMLoader(unittest.TestCase):
    """Smoke tests that load each LLM model type and run a single query.

    Nothing is asserted about the model output; a test fails only when
    loading or inference raises.
    """

    def run_test_case(self, llm_model_type, query):
        """Load the given model type, send it ``query`` and print timings.

        Prints the load time, the inference time and the raw result.
        """
        loader = LLMLoader(llm_model_type)

        load_started = timer()
        loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
        load_finished = timer()
        print(f"Model loaded in {load_finished - load_started:.3f}s")

        # The "openai" model type is given a list holding one HumanMessage;
        # every other model type receives the plain query string.
        if llm_model_type == "openai":
            prompt = [HumanMessage(content=query)]
        else:
            prompt = query

        answer = loader.llm(prompt)
        inference_finished = timer()
        print(f"Inference completed in {inference_finished - load_finished:.3f}s")
        print(answer)

    def test_openai(self):
        """Run the shared question against the "openai" model type."""
        self.run_test_case("openai", user_question)

    def test_llamacpp(self):
        """Run the shared question against the "llamacpp" model type."""
        self.run_test_case("llamacpp", user_question)

    def test_gpt4all_j(self):
        """Run the shared question against the "gpt4all-j" model type."""
        self.run_test_case("gpt4all-j", user_question)

    def test_huggingface(self):
        """Run the shared question against the "huggingface" model type."""
        self.run_test_case("huggingface", user_question)


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()