Fraser's picture
cope that submodules not allowed
1d30073
raw
history blame
No virus
8.94 kB
from typing import Dict, Optional
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers.generation_flax_utils import FlaxGenerationMixin
from transformers.utils import logging
# Module-level logger from transformers' logging utilities; not currently referenced by
# VaeFlaxGenerationMixin below.
logger = logging.get_logger(__name__)
class VaeFlaxGenerationMixin(FlaxGenerationMixin):
    """Flax generation mixin that decodes from VAE latent codes instead of a tokenised prompt.

    ``generate`` stores the latent codes in ``model_kwargs["latent_codes"]`` so they are forwarded to
    the model, builds a one-token start sequence per batch element, and dispatches to the greedy,
    sampling or beam-search loops inherited from ``FlaxGenerationMixin``.
    """

    def generate(
        self,
        latent_codes: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences conditioned on latent codes for models with a language modeling head. The method
        currently supports greedy decoding, multinomial sampling and beam search (beam *sampling* is not
        implemented).

        Unless stated otherwise below, every argument except :obj:`latent_codes` defaults to the value of the
        attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default
        values indicated are the default values of that config.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:
            latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_dim)`):
                The latent codes to decode (required). One sequence is generated per row, i.e. the output batch
                size is ``latent_codes.shape[0]``. They are forwarded to the model as the ``latent_codes`` kwarg.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to use sampling; use greedy decoding (or beam search) otherwise.
            prng_key (:obj:`jax_xla.DeviceArray`, `optional`):
                PRNG key used for sampling. Not read from the config: defaults to ``jax.random.PRNGKey(0)``, so
                repeated calls without an explicit key produce identical samples.
            temperature (:obj:`float`, `optional`, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (:obj:`int`, `optional`, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (:obj:`float`, `optional`, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
                higher are kept for generation.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            bos_token_id (:obj:`int`, `optional`):
                The id of the `beginning-of-sequence` token. Used as the start token for decoder-only models.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            num_beams (:obj:`int`, `optional`, defaults to 1):
                Number of beams for beam search. 1 means no beam search. Must be >= 1.
            decoder_start_token_id (:obj:`int`, `optional`):
                For encoder-decoder models, the token decoding starts from. Must be defined (here or in the
                config) for encoder-decoder generation.
            no_repeat_ngram_size, min_length, forced_bos_token_id, forced_eos_token_id (`optional`):
                Passed to ``self._get_logits_processor`` to build the logits processors used by every strategy.
            length_penalty (:obj:`float`, `optional`), early_stopping (:obj:`bool`, `optional`):
                Passed to beam search only.
            trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
                a considerably slower runtime.
            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

        Return:
            :class:`~transformers.file_utils.ModelOutput` produced by the selected search method.

        Raises:
            ValueError: if no start token can be determined (``decoder_start_token_id`` for encoder-decoder
                models, ``bos_token_id`` otherwise), or if ``num_beams < 1``.
            NotImplementedError: if ``do_sample=True`` and ``num_beams > 1`` (beam sampling).

        Examples::
            >>> import jax
            >>> # `model` is a Flax VAE model using this mixin and `latent_codes` an array of shape
            >>> # (batch_size, n_latent_tokens, latent_token_dim).
            >>> outputs = model.generate(latent_codes, max_length=20, top_k=30, do_sample=True, prng_key=jax.random.PRNGKey(0))
            >>> print("Generated:", tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
        """
        config = self.config

        # Resolve settings, falling back to the model config. `is not None` is used everywhere so that
        # legitimate falsy values (e.g. token id 0, do_sample=False) are not replaced by the config value.
        max_length = max_length if max_length is not None else config.max_length
        bos_token_id = bos_token_id if bos_token_id is not None else config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
        )
        do_sample = do_sample if do_sample is not None else config.do_sample
        num_beams = num_beams if num_beams is not None else config.num_beams
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        # Reject unsupported configurations before doing any work.
        if decoder_start_token_id is None and config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
        if num_beams < 1:
            raise ValueError(f"`num_beams` has to be a positive integer, but is {num_beams}.")
        if do_sample and num_beams > 1:
            raise NotImplementedError("Beam sampling is currently not implemented.")

        # NOTE: encoder outputs are deliberately not prepared (no
        # `_prepare_encoder_decoder_kwargs_for_generation` call); the decoder is conditioned on
        # `latent_codes`, which reach the model through model_kwargs.
        model_kwargs["latent_codes"] = latent_codes

        # Every sequence starts from a single start token: the decoder start token for encoder-decoder
        # models, otherwise the BOS token.
        start_token_id = decoder_start_token_id if config.is_encoder_decoder else bos_token_id
        if start_token_id is None:
            raise ValueError("`bos_token_id` has to be defined for decoder-only generation.")
        input_ids = jnp.full((latent_codes.shape[0], 1), start_token_id, dtype="i4")

        # The same logits processors are used by every decoding strategy.
        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
        )

        if num_beams == 1:
            if do_sample:
                logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
                return self._sample(
                    input_ids,
                    max_length,
                    pad_token_id,
                    eos_token_id,
                    prng_key,
                    logits_warper=logits_warper,
                    logits_processor=logits_processor,
                    trace=trace,
                    params=params,
                    model_kwargs=model_kwargs,
                )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )

        # Beam search (do_sample is False here). Broadcast inputs to (batch_size, num_beams, ...);
        # `_beam_search` flattens input_ids, encoder_outputs and attention_mask back to
        # batch_size * num_beams rows.
        input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = self._expand_to_num_beams(
                model_kwargs["attention_mask"], num_beams=num_beams
            )
        # `_beam_search` does not flatten `latent_codes`, so repeat them here to batch_size * num_beams
        # rows. This is the same row order as expand-then-flatten: each batch row is repeated num_beams
        # times consecutively.
        # NOTE(review): assumes the model consumes latent_codes along the flattened batch axis -- verify.
        model_kwargs["latent_codes"] = jnp.repeat(latent_codes, num_beams, axis=0)
        return self._beam_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )