jeffreygo committed
Commit
9d1f20b
1 Parent(s): 7ef4e43

Update modeling_cpmbee.py

Files changed (1)
  1. modeling_cpmbee.py +2 -1
modeling_cpmbee.py CHANGED
@@ -1472,6 +1472,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
         bos_token_id: Optional[Union[int, List[int]]] = None,
+        vocab_size: Optional[int] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -1487,6 +1488,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+        vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
         max_length = max_length if max_length is not None else self.generation_config.max_length
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
@@ -1589,7 +1591,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
                     break
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-                vocab_size = next_token_logits.shape[-1]
                 next_token_logits = self.adjust_logits_during_generation(
                     next_token_logits, batch_size, num_beams, vocab_size, ext_table_ids_cpu, **model_kwargs
                 )
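
In effect, the commit stops recomputing vocab_size from next_token_logits.shape[-1] inside the beam-search loop and instead accepts it as an explicit argument, falling back to self.generation_config.vocab_size in the same way the pad/eos/bos token ids are resolved a few lines above. The sketch below illustrates only that fallback pattern; the GenerationConfig defaults and the stub class are hypothetical stand-ins, not the actual CpmBee values or API.

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class GenerationConfig:
    # Placeholder defaults; the real values live in the model's generation_config.
    pad_token_id: Optional[int] = 0
    bos_token_id: Optional[Union[int, List[int]]] = 1
    eos_token_id: Optional[Union[int, List[int]]] = 2
    vocab_size: Optional[int] = 32000  # hypothetical base vocabulary size


class BeamSearchArgsStub:
    """Simplified stand-in showing only the argument-resolution step from the diff."""

    def __init__(self, generation_config: GenerationConfig):
        self.generation_config = generation_config

    def resolve(
        self,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        bos_token_id: Optional[Union[int, List[int]]] = None,
        vocab_size: Optional[int] = None,
    ) -> dict:
        # An explicitly passed value wins; otherwise fall back to generation_config,
        # which is exactly the pattern the patch adds for vocab_size.
        return {
            "pad_token_id": pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id,
            "eos_token_id": eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id,
            "bos_token_id": bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id,
            "vocab_size": vocab_size if vocab_size is not None else self.generation_config.vocab_size,
        }


if __name__ == "__main__":
    stub = BeamSearchArgsStub(GenerationConfig())
    print(stub.resolve())                  # falls back to generation_config.vocab_size
    print(stub.resolve(vocab_size=50000))  # an explicit value overrides the config

One plausible reading, inferred from the diff rather than stated in the commit message, is that the logits tensor passed to adjust_logits_during_generation can be wider than the base vocabulary when extension-table ids (ext_table_ids_cpu) are in play, so deriving vocab_size from its last dimension would not reliably equal the base vocabulary size.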