dh-mc committed
Commit 7f9d16c
1 Parent(s): ab49330

refactor code

Files changed (4):
  1. .env.example +5 -5
  2. Makefile +1 -1
  3. app_modules/llm_loader.py +553 -0
  4. test.py +31 -127
.env.example CHANGED
@@ -8,7 +8,7 @@ LLM_MODEL_TYPE=huggingface
 
 OPENAI_API_KEY=
 
-# if unset, default to "gpt-4"
+# if unset, default to "gpt-3.5-turbo"
 OPENAI_MODEL_NAME=
 
 # cpu, mps or cuda:0 - if unset, use whatever detected
@@ -54,14 +54,14 @@ MOSAICML_MODEL_NAME_OR_PATH="mosaicml/mpt-7b-instruct"
 
 FALCON_MODEL_NAME_OR_PATH="tiiuae/falcon-7b-instruct"
 
-GPT4ALL_J_MODEL_PATH="./models/ggml-gpt4all-j-v1.3-groovy.bin"
-GPT4ALL_J_DOWNLOAD_LINK=https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin
+GPT4ALL_J_MODEL_PATH="../models/llama-2-7b-chat.ggmlv3.q4_0.bin"
+GPT4ALL_J_DOWNLOAD_LINK=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_0.bin
 
 GPT4ALL_MODEL_PATH="./models/ggml-nous-gpt4-vicuna-13b.bin"
 GPT4ALL_DOWNLOAD_LINK=https://gpt4all.io/models/ggml-nous-gpt4-vicuna-13b.bin
 
-LLAMACPP_MODEL_PATH="./models/wizardLM-7B.ggmlv3.q4_1.bin"
-LLAMACPP_DOWNLOAD_LINK=https://huggingface.co/TheBloke/wizardLM-7B-GGML/resolve/main/wizardLM-7B.ggmlv3.q4_1.bin
+LLAMACPP_MODEL_PATH="./models/llama-2-7b-chat.ggmlv3.q4_K_M.bin"
+LLAMACPP_DOWNLOAD_LINK=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_M.bin
 
 # Index for AI Books PDF files - chunk_size=1024 chunk_overlap=512
 # CHROMADB_INDEX_PATH="./data/chromadb_1024_512/"
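
Note: the updated GPT4ALL_J_* and LLAMACPP_* defaults now point at Llama-2-7B-Chat GGML weights from TheBloke's Hugging Face repo. These variables are resolved by app_modules.utils.ensure_model_is_downloaded (imported in llm_loader.py below), whose body is not part of this commit; the snippet below is only a rough sketch of that pattern, under the assumption that the helper downloads the file when the configured path is missing.

import os
import urllib.request


def resolve_llamacpp_model() -> str:
    # Illustrative sketch only -- not the repo's actual helper.
    # Returns the local GGML path from .env, downloading the file if it is missing.
    model_path = os.environ.get(
        "LLAMACPP_MODEL_PATH", "./models/llama-2-7b-chat.ggmlv3.q4_K_M.bin"
    )
    download_link = os.environ.get("LLAMACPP_DOWNLOAD_LINK")
    if not os.path.isfile(model_path) and download_link:
        os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
        print(f"downloading {download_link} to {model_path}")
        urllib.request.urlretrieve(download_link, model_path)
    return model_path
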
Makefile CHANGED
@@ -10,7 +10,7 @@ else
 endif
 
 test:
-	PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 TRANSFORMERS_OFFLINE=1 python test.py
+	PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python test.py
 
 chat:
 	python test.py chat
app_modules/llm_loader.py ADDED
@@ -0,0 +1,553 @@
import os
import sys
import time
import urllib
from queue import Queue
from threading import Thread
from typing import Any, Optional

import torch
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.tracers import LangChainTracer
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import GPT4All, HuggingFacePipeline, LlamaCpp
from langchain.schema import LLMResult
from langchain.vectorstores import VectorStore
from langchain.vectorstores.base import VectorStore
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    T5Tokenizer,
    TextStreamer,
    pipeline,
)

from app_modules.instruct_pipeline import InstructionTextGenerationPipeline
from app_modules.utils import ensure_model_is_downloaded, remove_extra_spaces


class TextIteratorStreamer(TextStreamer, StreamingStdOutCallbackHandler):
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        super().on_finalized_text(text, stream_end=stream_end)

        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            print("\n")
            self.text_queue.put("\n", timeout=self.timeout)
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        sys.stdout.write(token)
        sys.stdout.flush()
        self.text_queue.put(token, timeout=self.timeout)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        print("\n")
        self.text_queue.put("\n", timeout=self.timeout)
        self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value

    def reset(self, q: Queue = None):
        # print("resetting TextIteratorStreamer")
        self.text_queue = q if q is not None else Queue()

    def empty(self):
        return self.text_queue.empty()


class LLMLoader:
    llm_model_type: str
    llm: any
    streamer: any

    def __init__(self, llm_model_type):
        self.llm_model_type = llm_model_type
        self.llm = None
        self.streamer = TextIteratorStreamer("")
        self.max_tokens_limit = 2048
        self.search_kwargs = {"k": 4}

    def _init_streamer(self, tokenizer, custom_handler):
        self.streamer = (
            TextIteratorStreamer(
                tokenizer,
                timeout=10.0,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            if custom_handler is None
            else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        )

    def init(
        self,
        custom_handler: Optional[BaseCallbackHandler] = None,
        n_threds: int = 4,
        hf_pipeline_device_type: str = None,
    ):
        print("initializing LLM: " + self.llm_model_type)

        if hf_pipeline_device_type is None:
            hf_pipeline_device_type = "cpu"

        using_cuda = hf_pipeline_device_type.startswith("cuda")
        torch_dtype = torch.float16 if using_cuda else torch.float32
        if os.environ.get("USING_TORCH_BFLOAT16") == "true":
            torch_dtype = torch.bfloat16
        load_quantized_model = os.environ.get("LOAD_QUANTIZED_MODEL")

        print(f" hf_pipeline_device_type: {hf_pipeline_device_type}")
        print(f" load_quantized_model: {load_quantized_model}")
        print(f" torch_dtype: {torch_dtype}")
        print(f" n_threds: {n_threds}")

        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=load_quantized_model == "4bit",
            bnb_4bit_use_double_quant=load_quantized_model == "4bit",
            load_in_8bit=load_quantized_model == "8bit",
            bnb_8bit_use_double_quant=load_quantized_model == "8bit",
        )

        callbacks = [self.streamer]
        if custom_handler is not None:
            callbacks.append(custom_handler)

        if self.llm is None:
            if self.llm_model_type == "openai":
                MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME") or "gpt-3.5-turbo"
                print(f" using model: {MODEL_NAME}")
                self.llm = ChatOpenAI(
                    model_name=MODEL_NAME,
                    streaming=True,
                    callbacks=callbacks,
                    verbose=True,
                    temperature=0,
                )
            elif self.llm_model_type.startswith("gpt4all"):
                MODEL_PATH = ensure_model_is_downloaded(self.llm_model_type)
                self.llm = GPT4All(
                    model=MODEL_PATH,
                    max_tokens=2048,
                    n_threads=n_threds,
                    backend="gptj" if self.llm_model_type == "gpt4all-j" else "llama",
                    callbacks=callbacks,
                    verbose=True,
                    use_mlock=True,
                )
            elif self.llm_model_type == "llamacpp":
                MODEL_PATH = ensure_model_is_downloaded(self.llm_model_type)
                self.llm = LlamaCpp(
                    model_path=MODEL_PATH,
                    n_ctx=8192,
                    n_threads=n_threds,
                    seed=0,
                    temperature=0,
                    max_tokens=2048,
                    callbacks=callbacks,
                    verbose=True,
                    use_mlock=True,
                )
            elif self.llm_model_type.startswith("huggingface"):
                MODEL_NAME_OR_PATH = os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
                print(f" loading model: {MODEL_NAME_OR_PATH}")

                hf_auth_token = os.environ.get("HUGGINGFACE_AUTH_TOKEN")
                transformers_offline = os.environ.get("TRANSFORMERS_OFFLINE") == "1"
                token = (
                    hf_auth_token
                    if hf_auth_token is not None
                    and len(hf_auth_token) > 0
                    and not transformers_offline
                    else None
                )
                print(f" HF auth token: {str(token)[-5:]}")

                is_t5 = "t5" in MODEL_NAME_OR_PATH
                temperature = (
                    0.01
                    if "gpt4all-j" in MODEL_NAME_OR_PATH
                    or "dolly" in MODEL_NAME_OR_PATH
                    else 0
                )
                use_fast = (
                    "stable" in MODEL_NAME_OR_PATH
                    or "RedPajama" in MODEL_NAME_OR_PATH
                    or "dolly" in MODEL_NAME_OR_PATH
                )
                padding_side = "left"  # if "dolly" in MODEL_NAME_OR_PATH else None

                config = AutoConfig.from_pretrained(
                    MODEL_NAME_OR_PATH,
                    trust_remote_code=True,
                    token=token,
                )
                # config.attn_config["attn_impl"] = "triton"
                # config.max_seq_len = 4096
                config.init_device = hf_pipeline_device_type

                tokenizer = (
                    T5Tokenizer.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        token=token,
                    )
                    if is_t5
                    else AutoTokenizer.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        use_fast=use_fast,
                        trust_remote_code=True,
                        padding_side=padding_side,
                        token=token,
                    )
                )

                self._init_streamer(tokenizer, custom_handler)

                task = "text2text-generation" if is_t5 else "text-generation"

                return_full_text = True if "dolly" in MODEL_NAME_OR_PATH else None

                repetition_penalty = (
                    1.15
                    if "falcon" in MODEL_NAME_OR_PATH
                    else (1.25 if "dolly" in MODEL_NAME_OR_PATH else 1.1)
                )

                if load_quantized_model is not None:
                    model = (
                        AutoModelForSeq2SeqLM.from_pretrained(
                            MODEL_NAME_OR_PATH,
                            config=config,
                            quantization_config=double_quant_config,
                            trust_remote_code=True,
                            token=token,
                        )
                        if is_t5
                        else AutoModelForCausalLM.from_pretrained(
                            MODEL_NAME_OR_PATH,
                            config=config,
                            quantization_config=double_quant_config,
                            trust_remote_code=True,
                            token=token,
                        )
                    )

                    print(f"Model memory footprint: {model.get_memory_footprint()}")

                    eos_token_id = -1
                    # starchat-beta uses a special <|end|> token with ID 49155 to denote ends of a turn
                    if "starchat" in MODEL_NAME_OR_PATH:
                        eos_token_id = 49155
                    pad_token_id = eos_token_id

                    pipe = (
                        InstructionTextGenerationPipeline(
                            task=task,
                            model=model,
                            tokenizer=tokenizer,
                            streamer=self.streamer,
                            max_new_tokens=2048,
                            temperature=temperature,
                            return_full_text=return_full_text,  # langchain expects the full text
                            repetition_penalty=repetition_penalty,
                        )
                        if "dolly" in MODEL_NAME_OR_PATH
                        else (
                            pipeline(
                                task,
                                model=model,
                                tokenizer=tokenizer,
                                eos_token_id=eos_token_id,
                                pad_token_id=pad_token_id,
                                streamer=self.streamer,
                                return_full_text=return_full_text,  # langchain expects the full text
                                device_map="auto",
                                trust_remote_code=True,
                                max_new_tokens=2048,
                                do_sample=True,
                                temperature=0.01,
                                top_p=0.95,
                                top_k=50,
                                repetition_penalty=repetition_penalty,
                            )
                            if eos_token_id != -1
                            else pipeline(
                                task,
                                model=model,
                                tokenizer=tokenizer,
                                streamer=self.streamer,
                                return_full_text=return_full_text,  # langchain expects the full text
                                device_map="auto",
                                trust_remote_code=True,
                                max_new_tokens=2048,
                                # verbose=True,
                                temperature=temperature,
                                top_p=0.95,
                                top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                                repetition_penalty=repetition_penalty,
                            )
                        )
                    )
                elif "dolly" in MODEL_NAME_OR_PATH:
                    model = AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        device_map=hf_pipeline_device_type,
                        torch_dtype=torch_dtype,
                    )

                    pipe = InstructionTextGenerationPipeline(
                        task=task,
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        max_new_tokens=2048,
                        temperature=temperature,
                        return_full_text=True,
                        repetition_penalty=repetition_penalty,
                        token=token,
                    )
                else:
                    if os.environ.get("DISABLE_MODEL_PRELOADING") != "true":
                        use_auth_token = None
                        model = (
                            AutoModelForSeq2SeqLM.from_pretrained(
                                MODEL_NAME_OR_PATH,
                                config=config,
                                trust_remote_code=True,
                                token=token,
                            )
                            if is_t5
                            else AutoModelForCausalLM.from_pretrained(
                                MODEL_NAME_OR_PATH,
                                config=config,
                                trust_remote_code=True,
                                token=token,
                            )
                        )
                        print(f"Model memory footprint: {model.get_memory_footprint()}")
                    else:
                        use_auth_token = token
                        model = MODEL_NAME_OR_PATH

                    pipe = pipeline(
                        task,
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        return_full_text=return_full_text,  # langchain expects the full text
                        device=hf_pipeline_device_type,
                        torch_dtype=torch_dtype,
                        max_new_tokens=2048,
                        trust_remote_code=True,
                        temperature=temperature,
                        top_p=0.95,
                        top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                        repetition_penalty=1.115,
                        token=use_auth_token,
                    )

                self.llm = HuggingFacePipeline(pipeline=pipe, callbacks=callbacks)
            elif self.llm_model_type == "mosaicml":
                MODEL_NAME_OR_PATH = os.environ.get("MOSAICML_MODEL_NAME_OR_PATH")
                print(f" loading model: {MODEL_NAME_OR_PATH}")

                config = AutoConfig.from_pretrained(
                    MODEL_NAME_OR_PATH, trust_remote_code=True
                )
                # config.attn_config["attn_impl"] = "triton"
                config.max_seq_len = 16384 if "30b" in MODEL_NAME_OR_PATH else 4096
                config.init_device = hf_pipeline_device_type

                model = (
                    AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        config=config,
                        quantization_config=double_quant_config,
                        trust_remote_code=True,
                    )
                    if load_quantized_model is not None
                    else AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        config=config,
                        torch_dtype=torch_dtype,
                        trust_remote_code=True,
                    )
                )

                print(f"Model loaded on {config.init_device}")
                print(f"Model memory footprint: {model.get_memory_footprint()}")

                tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
                self._init_streamer(tokenizer, custom_handler)

                # mpt-7b is trained to add "<|endoftext|>" at the end of generations
                stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

                # define custom stopping criteria object
                class StopOnTokens(StoppingCriteria):
                    def __call__(
                        self,
                        input_ids: torch.LongTensor,
                        scores: torch.FloatTensor,
                        **kwargs,
                    ) -> bool:
                        for stop_id in stop_token_ids:
                            if input_ids[0][-1] == stop_id:
                                return True
                        return False

                stopping_criteria = StoppingCriteriaList([StopOnTokens()])

                max_new_tokens = 8192 if "30b" in MODEL_NAME_OR_PATH else 2048
                self.max_tokens_limit = max_new_tokens
                self.search_kwargs = (
                    {"k": 8} if "30b" in MODEL_NAME_OR_PATH else self.search_kwargs
                )
                repetition_penalty = 1.05 if "30b" in MODEL_NAME_OR_PATH else 1.02

                pipe = (
                    pipeline(
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        return_full_text=True,  # langchain expects the full text
                        task="text-generation",
                        device_map="auto",
                        # we pass model parameters here too
                        stopping_criteria=stopping_criteria,  # without this model will ramble
                        temperature=0,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
                        top_p=0.95,  # select from top tokens whose probability adds up to 95%
                        top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                        max_new_tokens=max_new_tokens,  # max number of tokens to generate in the output
                        repetition_penalty=repetition_penalty,  # without this output begins repeating
                    )
                    if load_quantized_model is not None
                    else pipeline(
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        return_full_text=True,  # langchain expects the full text
                        task="text-generation",
                        device=config.init_device,
                        # we pass model parameters here too
                        stopping_criteria=stopping_criteria,  # without this model will ramble
                        temperature=0,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
                        top_p=0.95,  # select from top tokens whose probability adds up to 95%
                        top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                        max_new_tokens=max_new_tokens,  # max number of tokens to generate in the output
                        repetition_penalty=repetition_penalty,  # without this output begins repeating
                    )
                )
                self.llm = HuggingFacePipeline(pipeline=pipe, callbacks=callbacks)
            elif self.llm_model_type == "stablelm":
                MODEL_NAME_OR_PATH = os.environ.get("STABLELM_MODEL_NAME_OR_PATH")
                print(f" loading model: {MODEL_NAME_OR_PATH}")

                config = AutoConfig.from_pretrained(
                    MODEL_NAME_OR_PATH, trust_remote_code=True
                )
                # config.attn_config["attn_impl"] = "triton"
                # config.max_seq_len = 4096
                config.init_device = hf_pipeline_device_type

                model = (
                    AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        config=config,
                        quantization_config=double_quant_config,
                        trust_remote_code=True,
                    )
                    if load_quantized_model is not None
                    else AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME_OR_PATH,
                        config=config,
                        torch_dtype=torch_dtype,
                        trust_remote_code=True,
                    )
                )

                print(f"Model loaded on {config.init_device}")
                print(f"Model memory footprint: {model.get_memory_footprint()}")

                tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
                self._init_streamer(tokenizer, custom_handler)

                class StopOnTokens(StoppingCriteria):
                    def __call__(
                        self,
                        input_ids: torch.LongTensor,
                        scores: torch.FloatTensor,
                        **kwargs,
                    ) -> bool:
                        stop_ids = [50278, 50279, 50277, 1, 0]
                        for stop_id in stop_ids:
                            if input_ids[0][-1] == stop_id:
                                return True
                        return False

                stopping_criteria = StoppingCriteriaList([StopOnTokens()])

                pipe = (
                    pipeline(
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        return_full_text=True,  # langchain expects the full text
                        task="text-generation",
                        device_map="auto",
                        # we pass model parameters here too
                        stopping_criteria=stopping_criteria,  # without this model will ramble
                        temperature=0,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
                        top_p=0.95,  # select from top tokens whose probability adds up to 95%
                        top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                        max_new_tokens=2048,  # max number of tokens to generate in the output
                        repetition_penalty=1.25,  # without this output begins repeating
                    )
                    if load_quantized_model is not None
                    else pipeline(
                        model=model,
                        tokenizer=tokenizer,
                        streamer=self.streamer,
                        return_full_text=True,  # langchain expects the full text
                        task="text-generation",
                        device=config.init_device,
                        # we pass model parameters here too
                        stopping_criteria=stopping_criteria,  # without this model will ramble
                        temperature=0,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
                        top_p=0.95,  # select from top tokens whose probability adds up to 95%
                        top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
                        max_new_tokens=2048,  # max number of tokens to generate in the output
                        repetition_penalty=1.05,  # without this output begins repeating
                    )
                )
                self.llm = HuggingFacePipeline(pipeline=pipe, callbacks=callbacks)

        print("initialization complete")
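
A minimal usage sketch of the new loader (the rewritten test.py below exercises the same path). It assumes the .env variables for the chosen backend are set and the model file has already been downloaded; the streamer registered as a callback echoes tokens to stdout while the call runs.

from app_modules.llm_loader import LLMLoader

llm_loader = LLMLoader("llamacpp")  # or "openai", "gpt4all-j", "huggingface", ...
llm_loader.init(n_threds=8, hf_pipeline_device_type="cpu")

# LangChain LLMs are callable; the TextIteratorStreamer callback streams
# tokens to stdout during generation, then the full answer is returned.
answer = llm_loader.llm("What's the capital city of Malaysia?")
print(answer)
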
test.py CHANGED
@@ -1,45 +1,14 @@
-import os
-import sys
-from timeit import default_timer as timer
-from typing import List
-
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.chains import ConversationalRetrievalChain
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import GPT4All
-from langchain.schema import LLMResult
-from langchain.vectorstores.chroma import Chroma
-from langchain.vectorstores.faiss import FAISS
-
-from app_modules.qa_chain import *
-from app_modules.utils import *
-
-# Constants
-init_settings()
-
-# https://github.com/huggingface/transformers/issues/17611
-os.environ["CURL_CA_BUNDLE"] = ""
-
-hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
-print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
-print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
-
-hf_embeddings_model_name = (
-    os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
-)
-n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
-faiss_index_path = os.environ.get("FAISS_INDEX_PATH") or ""
-using_faiss = len(faiss_index_path) > 0
-index_path = faiss_index_path if using_faiss else os.environ.get("CHROMADB_INDEX_PATH")
-llm_model_type = os.environ.get("LLM_MODEL_TYPE")
-chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
-questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
-chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
-
-## utility functions
-
-import os
+# project/test.py
+
+import unittest
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import HumanMessage
+
+from app_modules.llm_loader import LLMLoader
+from timeit import default_timer as timer
+
+USER_QUESTION = "What's the capital city of Malaysia?"
 
 
 class MyCustomHandler(BaseCallbackHandler):
@@ -52,105 +21,40 @@ class MyCustomHandler(BaseCallbackHandler):
     def get_standalone_question(self) -> str:
         return self.texts[0].strip() if len(self.texts) > 0 else None
 
-    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+    def on_llm_end(self, response, **kwargs) -> None:
         """Run when chain ends running."""
         print("\non_llm_end - response:")
         print(response)
         self.texts.append(response.generations[0][0].text)
 
 
-start = timer()
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name=hf_embeddings_model_name,
-    model_kwargs={"device": hf_embeddings_device_type},
-)
-end = timer()
-
-print(f"Completed in {end - start:.3f}s")
-
-start = timer()
-
-print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
-
-if not os.path.isdir(index_path):
-    raise ValueError(f"{index_path} does not exist!")
-elif using_faiss:
-    vectorstore = FAISS.load_local(index_path, embeddings)
-else:
-    vectorstore = Chroma(embedding_function=embeddings, persist_directory=index_path)
-
-end = timer()
-
-print(f"Completed in {end - start:.3f}s")
-
-start = timer()
-qa_chain = QAChain(vectorstore, llm_model_type)
-custom_handler = MyCustomHandler()
-qa_chain.init(
-    custom_handler, n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type
-)
-end = timer()
-print(f"Completed in {end - start:.3f}s")
-
-# input("Press Enter to continue...")
-# exit()
-
-# Chatbot loop
-chat_history = []
-print("Welcome to the ChatPDF! Type 'exit' to stop.")
-
-# Open the file for reading
-file = open(questions_file_path, "r")
-
-# Read the contents of the file into a list of strings
-queue = file.readlines()
-for i in range(len(queue)):
-    queue[i] = queue[i].strip()
-
-# Close the file
-file.close()
-
-queue.append("exit")
-
-chat_start = timer()
-
-while True:
-    if chatting:
-        query = input("Please enter your question: ")
-    else:
-        query = queue.pop(0)
-
-    query = query.strip()
-    if query.lower() == "exit":
-        break
-
-    print("\nQuestion: " + query)
-    custom_handler.reset()
-
-    start = timer()
-    result = qa_chain.call({"question": query, "chat_history": chat_history}, None)
-    end = timer()
-    print(f"Completed in {end - start:.3f}s")
-
-    print_llm_response(result)
-
-    if len(chat_history) == 0:
-        standalone_question = query
-    else:
-        standalone_question = custom_handler.get_standalone_question()
-
-    if standalone_question is not None:
-        print(f"Load relevant documents for standalone question: {standalone_question}")
-        start = timer()
-        qa = qa_chain.get_chain()
-        docs = qa.retriever.get_relevant_documents(standalone_question)
-        end = timer()
-
-        # print(docs)
-        print(f"Completed in {end - start:.3f}s")
-
-    if chat_history_enabled == "true":
-        chat_history.append((query, result["answer"]))
-
-chat_end = timer()
-print(f"Total time used: {chat_end - chat_start:.3f}s")
+
+class TestLLMLoader(unittest.TestCase):
+    def run_test_case(self, llm_model_type, query):
+        llm_loader = LLMLoader(llm_model_type)
+        start = timer()
+        llm_loader.init(n_threds=8, hf_pipeline_device_type="cpu")
+        end = timer()
+        print(f"Model loaded in {end - start:.3f}s")
+
+        result = llm_loader.llm(
+            [HumanMessage(content=query)] if llm_model_type == "openai" else query
+        )
+        end2 = timer()
+        print(f"Inference completed in {end2 - end:.3f}s")
+        print(result)
+
+    def xtest_openai(self):
+        self.run_test_case("openai", USER_QUESTION)
+
+    def xtest_llamacpp(self):
+        self.run_test_case("llamacpp", USER_QUESTION)
+
+    def xtest_gpt4all_j(self):
+        self.run_test_case("gpt4all-j", USER_QUESTION)
+
+    def test_huggingface(self):
+        self.run_test_case("huggingface", USER_QUESTION)
+
+
+if __name__ == "__main__":
+    unittest.main()