Donghao Huang committed
Commit 6b469d2
1 Parent(s): 571afe2

fixed bug on llama-2

app_modules/llm_inference.py CHANGED
@@ -35,7 +35,12 @@ class LLMInference(metaclass=abc.ABCMeta):
         return self.chain
 
     def call_chain(
-        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
+        self,
+        inputs,
+        streaming_handler,
+        q: Queue = None,
+        tracing: bool = False,
+        testing: bool = False,
     ):
         print(inputs)
         if self.llm_loader.streamer.for_huggingface:
@@ -46,11 +51,7 @@ class LLMInference(metaclass=abc.ABCMeta):
 
         chain = self.get_chain(tracing)
         result = (
-            self._run_chain(
-                chain,
-                inputs,
-                streaming_handler,
-            )
+            self._run_chain(chain, inputs, streaming_handler, testing)
             if streaming_handler is not None
             else chain(inputs)
         )
@@ -74,7 +75,7 @@ class LLMInference(metaclass=abc.ABCMeta):
     def _execute_chain(self, chain, inputs, q, sh):
         q.put(chain(inputs, callbacks=[sh]))
 
-    def _run_chain(self, chain, inputs, streaming_handler):
+    def _run_chain(self, chain, inputs, streaming_handler, testing):
         que = Queue()
 
         t = Thread(
@@ -83,7 +84,7 @@ class LLMInference(metaclass=abc.ABCMeta):
         )
         t.start()
 
-        if self.llm_loader.streamer.for_huggingface:
+        if self.llm_loader.streamer.for_huggingface and not testing:
             count = (
                 2
                 if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
app_modules/llm_loader.py CHANGED
@@ -227,6 +227,7 @@ class LLMLoader:
             if "gpt4all-j" in MODEL_NAME_OR_PATH
             or "dolly" in MODEL_NAME_OR_PATH
             or "Qwen" in MODEL_NAME_OR_PATH
+            or "Llama-2" in MODEL_NAME_OR_PATH
             else 0
         )
         use_fast = (
@@ -452,7 +453,6 @@ class LLMLoader:
                 top_p=0.95,
                 top_k=0, # select from top 0 tokens (because zero, relies on top_p)
                 repetition_penalty=1.115,
-                use_auth_token=token,
                 token=token,
             )
         )
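
In llm_loader.py, the first hunk adds Llama-2 to the model-name matches and the second drops the redundant use_auth_token argument. This is likely the "bug on llama-2" from the commit message: Llama-2 checkpoints are gated and need an access token, and recent transformers releases deprecate use_auth_token in favor of token, rejecting calls that pass both. A hedged illustration, using a placeholder model name and token rather than values from this repository:

    from transformers import AutoTokenizer

    token = "hf_..."  # hypothetical access token for a gated repo

    # Passing both use_auth_token and token raises an error in recent
    # transformers releases; keep only the supported token argument.
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # placeholder gated checkpoint
        token=token,
    )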
test.py CHANGED
@@ -69,7 +69,11 @@ while True:
 
     start = timer()
     result = qa_chain.call_chain(
-        {"question": query, "chat_history": chat_history}, custom_handler
+        {"question": query, "chat_history": chat_history},
+        custom_handler,
+        None,
+        False,
+        True,
     )
     end = timer()
     print(f"Completed in {end - start:.3f}s")