ZwwWayne commited on
Commit
744e137
1 Parent(s): 6ee6bdb

fix: add eoa into eos_token_id in chat to accelerate chat interface

Browse files
Files changed (1) hide show
  1. modeling_internlm2.py +3 -0
modeling_internlm2.py CHANGED
@@ -1049,6 +1049,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1049
  ):
1050
  inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1051
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
 
 
1052
  outputs = self.generate(
1053
  **inputs,
1054
  streamer=streamer,
@@ -1056,6 +1058,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1056
  do_sample=do_sample,
1057
  temperature=temperature,
1058
  top_p=top_p,
 
1059
  **kwargs,
1060
  )
1061
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
 
1049
  ):
1050
  inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1051
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1052
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1053
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0]]
1054
  outputs = self.generate(
1055
  **inputs,
1056
  streamer=streamer,
 
1058
  do_sample=do_sample,
1059
  temperature=temperature,
1060
  top_p=top_p,
1061
+ eos_token_id=eos_token_id,
1062
  **kwargs,
1063
  )
1064
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]