nxphi47 committed on
Commit
41936ab
·
verified ·
1 Parent(s): 6ba692d

Update multipurpose_chatbot/engines/transformers_engine.py

Browse files
multipurpose_chatbot/engines/transformers_engine.py CHANGED
@@ -429,7 +429,7 @@ class TransformersEngine(BaseEngine):
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
432
- # self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
433
  with torch.no_grad():
434
  inputs = self.tokenizer(prompt, return_tensors='pt')
435
  num_tokens = inputs.input_ids.size(1)
@@ -450,7 +450,7 @@ class TransformersEngine(BaseEngine):
450
  out_tokens.extend(token.tolist())
451
  response = self.tokenizer.decode(out_tokens)
452
  if "<|im_start|>assistant\n" in response:
453
- response = response.split("<|im_start|>assistant\n")
454
  num_tokens += 1
455
  print(f"{response}", end='\r')
456
  sys.stdout.flush()
@@ -458,7 +458,7 @@ class TransformersEngine(BaseEngine):
458
 
459
  if response is not None:
460
  if "<|im_start|>assistant\n" in response:
461
- response = response.split("<|im_start|>assistant\n")
462
  full_text = prompt + response
463
  num_tokens = len(self.tokenizer.encode(full_text))
464
  yield response, num_tokens
 
429
 
430
  # ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
431
  import sys
432
+ self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
433
  with torch.no_grad():
434
  inputs = self.tokenizer(prompt, return_tensors='pt')
435
  num_tokens = inputs.input_ids.size(1)
 
450
  out_tokens.extend(token.tolist())
451
  response = self.tokenizer.decode(out_tokens)
452
  if "<|im_start|>assistant\n" in response:
453
+ response = response.split("<|im_start|>assistant\n")[-1]
454
  num_tokens += 1
455
  print(f"{response}", end='\r')
456
  sys.stdout.flush()
 
458
 
459
  if response is not None:
460
  if "<|im_start|>assistant\n" in response:
461
+ response = response.split("<|im_start|>assistant\n")[-1]
462
  full_text = prompt + response
463
  num_tokens = len(self.tokenizer.encode(full_text))
464
  yield response, num_tokens