from typing import Dict, Optional

import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla

from transformers.generation_flax_utils import FlaxGenerationMixin
from transformers.utils import logging


logger = logging.get_logger(__name__)


class VaeFlaxGenerationMixin(FlaxGenerationMixin):
    def generate(
        self,
        latent_codes: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences from latent codes for models with a language modeling head. The method currently
        supports greedy decoding, multinomial sampling, and beam search.

        Apart from :obj:`latent_codes`, all the arguments below will default to the value of the attribute of the
        same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated
        are the defaults of those config attributes.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:
            latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_dim)`):
                The latent codes used to condition the generation.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            prng_key (:obj:`jax_xla.DeviceArray`, `optional`):
                The PRNG key used for sampling; defaults to :obj:`jax.random.PRNGKey(0)`.
            temperature (:obj:`float`, `optional`, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (:obj:`int`, `optional`, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (:obj:`float`, `optional`, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p`
                or higher are kept for generation.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            bos_token_id (:obj:`int`, `optional`):
                The id of the `beginning-of-sequence` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            num_beams (:obj:`int`, `optional`, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            decoder_start_token_id (:obj:`int`, `optional`):
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will
                lead to a considerably slower runtime.
            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

        Return:
            :class:`~transformers.file_utils.ModelOutput`.
        Examples::

            >>> import jax.numpy as jnp

            >>> # `model` is assumed to be a model that mixes in `VaeFlaxGenerationMixin`, with a
            >>> # matching `tokenizer`; the latent shape below is illustrative.
            >>> latent_codes = jnp.zeros((1, 32, 64))  # (batch_size, n_latent_tokens, latent_token_dim)
            >>> # generate candidates using sampling
            >>> outputs = model.generate(latent_codes, max_length=20, top_k=30, do_sample=True)
            >>> print("Generated:", tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
        """
        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

        model_kwargs["latent_codes"] = latent_codes

        if self.config.is_encoder_decoder:
            # NOTE: Don't prepare encoder outputs; instead rely on `latent_codes`, which the decoder
            # consumes through `model_kwargs`.
            # model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = jnp.ones((latent_codes.shape[0], 1), dtype="i4") * decoder_start_token_id
        else:
            # Without the encoder-decoder setup above, `input_ids` would be undefined below; fail early instead.
            raise ValueError("`VaeFlaxGenerationMixin` currently only supports encoder-decoder models.")

        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs to num_beams
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
            )

            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")
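

# ---------------------------------------------------------------------------
# Usage sketch (all names below are illustrative assumptions, not part of this
# module): a Flax encoder-decoder model adopts the mixin by inheriting from it
# ahead of its base class, after which decoding is conditioned on the latent
# codes rather than on prepared encoder outputs.
#
#     from transformers import FlaxT5ForConditionalGeneration
#
#     class FlaxVaeT5(VaeFlaxGenerationMixin, FlaxT5ForConditionalGeneration):
#         """Hypothetical T5 variant whose decoder attends to VAE latents."""
#
#     model = FlaxVaeT5.from_pretrained("t5-small")  # assumed compatible checkpoint
#     latent_codes = jnp.zeros((1, 32, 64))  # (batch_size, n_latent_tokens, latent_token_dim)
#     sequences = model.generate(latent_codes, max_length=20).sequences
# ---------------------------------------------------------------------------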