from typing import Optional, Tuple

import flax
import jaxlib.xla_extension as jax_xla

from transformers.file_utils import ModelOutput


@flax.struct.dataclass
class TransformerVaeOutput(ModelOutput):
    """
    Base class for a Transformer-VAE's outputs.

    Args:
        latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
            Latent codes representing encoded sequences.
        remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
            Reconstructed encoder hidden states representing sequences.

    (standard Seq2Seq) Args:
        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
            tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
            tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
        last_hidden_state (:obj:`jax_xla.DeviceArray`, `optional`):
            Last hidden state of the model.
        decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """
    logits: jax_xla.DeviceArray = None
    latent_codes: jax_xla.DeviceArray = None
    remade_encoder_hidden_state: jax_xla.DeviceArray = None
    # seq2seq
    past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
    decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
    decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
    cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
    last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
    encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
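

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original class): it builds a
# TransformerVaeOutput from dummy jax.numpy arrays and walks through the
# standard ModelOutput access patterns (attribute, key, and tuple). All
# shapes and sizes below (batch_size=2, n_layers=2, etc.) are illustrative
# assumptions, not values taken from any real model config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import jax.numpy as jnp

    batch_size, seq_len, enc_seq_len, vocab_size = 2, 10, 12, 16
    n_latent_tokens, latent_token_size = 4, 8
    n_layers, num_heads, head_size = 2, 2, 4

    # Per the docstring above, past_key_values holds one tuple per layer:
    # the 2 self-attention key/value states followed by the 2
    # cross-attention key/value states.
    self_kv = tuple(
        jnp.zeros((batch_size, num_heads, seq_len, head_size)) for _ in range(2)
    )
    cross_kv = tuple(
        jnp.zeros((batch_size, num_heads, enc_seq_len, head_size)) for _ in range(2)
    )
    past_key_values = tuple(self_kv + cross_kv for _ in range(n_layers))

    outputs = TransformerVaeOutput(
        logits=jnp.zeros((batch_size, seq_len, vocab_size)),
        latent_codes=jnp.zeros((batch_size, n_latent_tokens, latent_token_size)),
        past_key_values=past_key_values,
    )

    print(outputs.logits.shape)           # attribute access -> (2, 10, 16)
    print(outputs["latent_codes"].shape)  # dict-style access -> (2, 4, 8)
    print(len(outputs.to_tuple()))        # only non-None fields are kept -> 3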