Gong Baitao commited on
Commit
1ddbefc
1 Parent(s): 9575b6b

Update modeling_cpmbee.py for max_length

Browse files
Files changed (1) hide show
  1. modeling_cpmbee.py +3 -4
modeling_cpmbee.py CHANGED
@@ -1729,7 +1729,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1729
  eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
1730
  bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
1731
  vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
1732
- max_length = max_length if max_length is not None else self.generation_config.max_length
1733
  output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
1734
  output_attentions = (
1735
  output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2093,7 +2093,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
2093
  length_penalty=generation_config.length_penalty,
2094
  do_early_stopping=generation_config.early_stopping,
2095
  num_beam_hyps_to_keep=generation_config.num_return_sequences,
2096
- max_length=generation_config.max_length,
2097
  **kwargs,
2098
  )
2099
  # 9. interleave input_ids with `num_beams` additional sequences per batch
@@ -2109,6 +2109,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
2109
  beam_scorer,
2110
  repetition_penalty=repetition_penalty,
2111
  logits_processor=logits_processor,
 
2112
  pad_token_id=generation_config.pad_token_id,
2113
  eos_token_id=generation_config.eos_token_id,
2114
  vocab_size=kwargs.get("vocab_size", None),
@@ -2123,7 +2124,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
2123
  self,
2124
  data_list: Union[Dict, List[Dict]],
2125
  tokenizer: CpmBeeTokenizer,
2126
- generation_config=None,
2127
  **kwargs,
2128
  ):
2129
  """
@@ -2148,7 +2148,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
2148
  data_list = [data_list]
2149
  input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
2150
  input_encoded.update(kwargs)
2151
- input_encoded["generation_config"] = generation_config
2152
  input_encoded["vocab_size"] = tokenizer.vocab_size
2153
 
2154
  decode_res = self._generate(**input_encoded)
 
1729
  eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
1730
  bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
1731
  vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
1732
+ max_length = max_length if max_length is not None else self.generation_config.max_new_tokens
1733
  output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
1734
  output_attentions = (
1735
  output_attentions if output_attentions is not None else self.generation_config.output_attentions
 
2093
  length_penalty=generation_config.length_penalty,
2094
  do_early_stopping=generation_config.early_stopping,
2095
  num_beam_hyps_to_keep=generation_config.num_return_sequences,
2096
+ max_length=generation_config.max_new_tokens,
2097
  **kwargs,
2098
  )
2099
  # 9. interleave input_ids with `num_beams` additional sequences per batch
 
2109
  beam_scorer,
2110
  repetition_penalty=repetition_penalty,
2111
  logits_processor=logits_processor,
2112
+ max_length=generation_config.max_new_tokens,
2113
  pad_token_id=generation_config.pad_token_id,
2114
  eos_token_id=generation_config.eos_token_id,
2115
  vocab_size=kwargs.get("vocab_size", None),
 
2124
  self,
2125
  data_list: Union[Dict, List[Dict]],
2126
  tokenizer: CpmBeeTokenizer,
 
2127
  **kwargs,
2128
  ):
2129
  """
 
2148
  data_list = [data_list]
2149
  input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
2150
  input_encoded.update(kwargs)
 
2151
  input_encoded["vocab_size"] = tokenizer.vocab_size
2152
 
2153
  decode_res = self._generate(**input_encoded)