Gong Baitao
commited on
Commit
•
1ddbefc
1
Parent(s):
9575b6b
Update modeling_cpmbee.py for max_length
Browse files- modeling_cpmbee.py +3 -4
modeling_cpmbee.py
CHANGED
@@ -1729,7 +1729,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
|
|
1729 |
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
1730 |
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
1731 |
vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
|
1732 |
-
max_length = max_length if max_length is not None else self.generation_config.
|
1733 |
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
1734 |
output_attentions = (
|
1735 |
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
@@ -2093,7 +2093,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
|
|
2093 |
length_penalty=generation_config.length_penalty,
|
2094 |
do_early_stopping=generation_config.early_stopping,
|
2095 |
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
2096 |
-
max_length=generation_config.
|
2097 |
**kwargs,
|
2098 |
)
|
2099 |
# 9. interleave input_ids with `num_beams` additional sequences per batch
|
@@ -2109,6 +2109,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
|
|
2109 |
beam_scorer,
|
2110 |
repetition_penalty=repetition_penalty,
|
2111 |
logits_processor=logits_processor,
|
|
|
2112 |
pad_token_id=generation_config.pad_token_id,
|
2113 |
eos_token_id=generation_config.eos_token_id,
|
2114 |
vocab_size=kwargs.get("vocab_size", None),
|
@@ -2123,7 +2124,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
|
|
2123 |
self,
|
2124 |
data_list: Union[Dict, List[Dict]],
|
2125 |
tokenizer: CpmBeeTokenizer,
|
2126 |
-
generation_config=None,
|
2127 |
**kwargs,
|
2128 |
):
|
2129 |
"""
|
@@ -2148,7 +2148,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
|
|
2148 |
data_list = [data_list]
|
2149 |
input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
|
2150 |
input_encoded.update(kwargs)
|
2151 |
-
input_encoded["generation_config"] = generation_config
|
2152 |
input_encoded["vocab_size"] = tokenizer.vocab_size
|
2153 |
|
2154 |
decode_res = self._generate(**input_encoded)
|
|
|
1729 |
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
1730 |
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
1731 |
vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
|
1732 |
+
max_length = max_length if max_length is not None else self.generation_config.max_new_tokens
|
1733 |
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
1734 |
output_attentions = (
|
1735 |
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
|
|
2093 |
length_penalty=generation_config.length_penalty,
|
2094 |
do_early_stopping=generation_config.early_stopping,
|
2095 |
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
2096 |
+
max_length=generation_config.max_new_tokens,
|
2097 |
**kwargs,
|
2098 |
)
|
2099 |
# 9. interleave input_ids with `num_beams` additional sequences per batch
|
|
|
2109 |
beam_scorer,
|
2110 |
repetition_penalty=repetition_penalty,
|
2111 |
logits_processor=logits_processor,
|
2112 |
+
max_length=generation_config.max_new_tokens,
|
2113 |
pad_token_id=generation_config.pad_token_id,
|
2114 |
eos_token_id=generation_config.eos_token_id,
|
2115 |
vocab_size=kwargs.get("vocab_size", None),
|
|
|
2124 |
self,
|
2125 |
data_list: Union[Dict, List[Dict]],
|
2126 |
tokenizer: CpmBeeTokenizer,
|
|
|
2127 |
**kwargs,
|
2128 |
):
|
2129 |
"""
|
|
|
2148 |
data_list = [data_list]
|
2149 |
input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
|
2150 |
input_encoded.update(kwargs)
|
|
|
2151 |
input_encoded["vocab_size"] = tokenizer.vocab_size
|
2152 |
|
2153 |
decode_res = self._generate(**input_encoded)
|