jeffreygo commited on
Commit
a2b1e3e
1 Parent(s): 9d1f20b

Upload modeling_cpmbee.py

Browse files
Files changed (1) hide show
  1. modeling_cpmbee.py +4 -2
modeling_cpmbee.py CHANGED
@@ -1634,8 +1634,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1634
  )
1635
 
1636
  # reshape for beam search
1637
- vocab_size = next_token_scores.shape[-1]
1638
- next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
1639
 
1640
  # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
1641
  next_token_scores, next_tokens = torch.topk(
@@ -1872,6 +1871,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1872
  logits_processor=logits_processor,
1873
  pad_token_id=generation_config.pad_token_id,
1874
  eos_token_id=generation_config.eos_token_id,
 
1875
  output_scores=generation_config.output_scores,
1876
  return_dict_in_generate=generation_config.return_dict_in_generate,
1877
  synced_gpus=synced_gpus,
@@ -1909,6 +1909,8 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
1909
  input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
1910
  input_encoded.update(kwargs)
1911
  input_encoded["generation_config"] = generation_config
 
 
1912
 
1913
  decode_res = self._generate(**input_encoded)
1914
 
 
1634
  )
1635
 
1636
  # reshape for beam search
1637
+ next_token_scores = next_token_scores.view(batch_size, -1)
 
1638
 
1639
  # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
1640
  next_token_scores, next_tokens = torch.topk(
 
1871
  logits_processor=logits_processor,
1872
  pad_token_id=generation_config.pad_token_id,
1873
  eos_token_id=generation_config.eos_token_id,
1874
+ vocab_size=kwargs.get("vocab_size", None),
1875
  output_scores=generation_config.output_scores,
1876
  return_dict_in_generate=generation_config.return_dict_in_generate,
1877
  synced_gpus=synced_gpus,
 
1909
  input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
1910
  input_encoded.update(kwargs)
1911
  input_encoded["generation_config"] = generation_config
1912
+ input_encoded["vocab_size"] = tokenizer.vocab_size
1913
+ print(tokenizer.vocab_size)
1914
 
1915
  decode_res = self._generate(**input_encoded)
1916