Ziya-v1 / utils.py
Zimix's picture
init
a5fe767
raw
history blame
No virus
28.5 kB
import torch
from typing import Optional, Tuple, Union, List, Callable
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.beam_search import BeamSearchScorer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.utils import (
LogitsProcessorList,
StoppingCriteriaList,
GenerationConfig,
GenerationMixin,
)
from transformers import LlamaForCausalLM
import warnings
import torch.distributed as dist
from torch import nn
import copy
class SteamGenerationMixin(LlamaForCausalLM):
# support for streamly generation
# TODO: group_beam_search
@torch.no_grad()
def stream_generate(
self,
input_ids: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
**kwargs,
):
self._reorder_cache = self.base_model._reorder_cache
if is_deepspeed_zero3_enabled() and dist.world_size() > 1:
synced_gpus = True
else:
synced_gpus = False
if kwargs.get("attention_mask", None) is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(
kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
).to(kwargs["input_ids"].device)
kwargs["attention_mask"] = torch.cat(
(prefix_attention_mask, kwargs["attention_mask"]), dim=1
)
if kwargs.get("position_ids", None) is not None:
warnings.warn(
"Position ids are not supported for parameter efficient tuning. Ignoring position ids."
)
kwargs["position_ids"] = None
if kwargs.get("token_type_ids", None) is not None:
warnings.warn(
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
)
kwargs["token_type_ids"] = None
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id, pad_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
generation_config.pad_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
if generation_config.min_new_tokens is not None:
generation_config.min_length = (
generation_config.min_new_tokens + input_ids_seq_length
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = (
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
)
# 2. Set generation parameters if not already defined
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
# 7. determine generation mode
is_constraint_gen_mode = (
generation_config.constraints is not None or generation_config.force_words_ids is not None
)
is_contrastive_search_gen_mode = (
generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
)
is_greedy_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
# beam=1 and do_sample=True
is_sample_gen_mode = (
(generation_config.num_beams == 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups == 1)
and generation_config.do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
(generation_config.num_beams > 1)
and (generation_config.num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
# 9. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria
)
logits_warper = self._get_logits_warper(generation_config)
if is_greedy_gen_mode:
# 11. run greedy search
return self.greedy_search(
input_ids,
logits_processor,
stopping_criteria,
generation_config,
synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
return self.stream_sample(
generation_config,
input_ids,
logits_processor,
logits_warper,
stopping_criteria,
synced_gpus,
**model_kwargs,
)
elif is_beam_gen_mode:
return self.beam_search(
generation_config,
input_ids,
logits_processor,
stopping_criteria,
synced_gpus,
**model_kwargs,
)
elif is_beam_sample_gen_mode:
# interleave input_ids with `num_beams` additional sequences per batch
return self.beam_sample(
input_ids,
logits_processor,
logits_warper,
stopping_criteria,
generation_config,
synced_gpus,
**model_kwargs,
)
else:
raise Exception('not implement')
def stream_sample(
self,
generation_config,
input_ids,
logits_processor,
logits_warper,
stopping_criteria,
synced_gpus,
**model_kwargs,
):
bos_token_id, eos_token_id, pad_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
generation_config.pad_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
scores=()
# auto-regressive generation
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
yield input_ids
# torch.cuda.empty_cache()
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
return input_ids
def empty_cache(self):
torch.cuda.empty_cache()
def beam_sample(
self,
input_ids,
logits_processor,
logits_warper,
stopping_criteria,
generation_config,
synced_gpus,
**model_kwargs,
):
bos_token_id, eos_token_id, pad_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
generation_config.pad_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
num_beams = generation_config.num_beams
batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=input_ids.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
scores = ()
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
next_token_scores = logits_warper(input_ids, next_token_scores)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=None,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
yield input_ids
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=None,
)
yield sequence_outputs["sequences"]
def greedy_search(
self,
input_ids,
logits_processor,
stopping_criteria,
generation_config,
synced_gpus,
**model_kwargs,
):
# init values
bos_token_id, eos_token_id, pad_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
generation_config.pad_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
# init attention / hidden states / scores tuples
scores = ()
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
yield input_ids
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
yield input_ids
def beam_search(
self,
generation_config,
input_ids,
logits_processor,
stopping_criteria,
synced_gpus,
**model_kwargs,
):
# 10. go into beam search generation modes
# 11. prepare beam search scorer
bos_token_id, eos_token_id, pad_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
generation_config.pad_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
num_beams = generation_config.num_beams
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=input_ids.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# beam_search logits
batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.zeros(
(batch_size, num_beams), dtype=torch.float, device=input_ids.device
)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(
0.0 if this_peer_finished else 1.0
).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) hack: adjust tokens for Marian.
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[
:, None
].expand_as(next_token_scores)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(
batch_size, num_beams * vocab_size
)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=None,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat(
[input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
# increase cur_len
cur_len = cur_len + 1
yield input_ids
if beam_scorer.is_done or stopping_criteria(input_ids, None):
if not synced_gpus:
break
else:
this_peer_finished = True
final_result = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=None,
)
yield final_result["sequences"]