""" PyTorch ParlerTTS model.""" |
|
import copy |
|
import inspect |
|
import math |
|
import random |
|
from dataclasses import dataclass |
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import CrossEntropyLoss |
|
from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding |
|
from transformers.activations import ACT2FN |
|
from transformers.generation.configuration_utils import GenerationConfig |
|
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList |
|
from transformers.generation.stopping_criteria import StoppingCriteriaList |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutput, |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
CausalLMOutputWithCrossAttentions, |
|
ModelOutput, |
|
Seq2SeqLMOutput, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
|
|
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig |
|
from .dac_wrapper import DACConfig, DACModel |
|
|
|
|
AutoConfig.register("dac", DACConfig) |
|
AutoModel.register(DACConfig, DACModel) |
|
|
|
if TYPE_CHECKING: |
|
from transformers.generation.streamers import BaseStreamer |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "ParlerTTSConfig" |
|
_CHECKPOINT_FOR_DOC = "facebook/parler_tts-small" |
|
|
|
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [ |
|
"facebook/parler_tts-small", |
|
|
|
] |
|
|
|
|
|
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
    """Apply a delay pattern mask to the decoder input ids: positions where the mask is -1 keep their value,
    while every other position is overwritten with the corresponding value from the mask."""
|
seq_len = input_ids.shape[-1] |
|
decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] |
|
input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) |
|
return input_ids |
|
|
|
|
|
def build_delay_pattern_mask( |
|
input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int |
|
):
    """Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
    one step, giving a delayed pattern at the start and end of the sequence. Take the example where there
|
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, |
|
seq_len)`: |
|
- [B, -1, -1, -1, -1, P, P, P] |
|
- [B, B, -1, -1, -1, -1, P, P] |
|
- [B, B, B, -1, -1, -1, -1, P] |
|
- [B, B, B, B, -1, -1, -1, -1] |
|
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include |
|
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the |
|
mask is set to the value in the prompt: |
|
- [B, a, b, -1, -1, P, P, P] |
|
- [B, B, c, d, -1, -1, P, P] |
|
- [B, B, B, e, f, -1, -1, P] |
|
- [B, B, B, B, g, h, -1, -1] |
|
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 |
|
tokens in our prediction. |
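
    Example (illustrative sketch; the token ids 2049/2048 used for bos/pad are arbitrary, not checkpoint values):

    ```python
    >>> import torch
    >>> codes = torch.full((1 * 2, 1), 2049, dtype=torch.long)  # (bsz * num_codebooks, seq_len)
    >>> ids, mask = build_delay_pattern_mask(
    ...     codes, bos_token_id=2049, pad_token_id=2048, max_length=6, num_codebooks=2
    ... )
    >>> # `mask` holds the forced bos/pad pattern, with -1 marking positions still open for prediction.
    >>> # `apply_delay_pattern_mask` overwrites every non -1 position of a candidate sequence with that pattern:
    >>> forced = apply_delay_pattern_mask(torch.zeros_like(mask), mask)
    ```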
|
""" |
|
|
|
input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1]) |
|
bsz, num_codebooks, seq_len = input_ids.shape |
|
|
|
input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 |
|
|
|
|
|
if max_length < 2 * num_codebooks - 1: |
|
return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) |
|
|
|
|
|
for codebook in range(num_codebooks): |
|
|
|
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] |
|
|
|
|
|
|
|
eos_delay_pattern = torch.triu( |
|
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1 |
|
) |
|
|
|
bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool)) |
|
|
|
bos_mask = ~(bos_delay_pattern).to(input_ids.device) |
|
eos_mask = ~(eos_delay_pattern).to(input_ids.device) |
|
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device) |
|
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id |
|
|
|
|
|
|
|
    # find the first position in the first codebook that is still open for prediction (-1); everything before it
    # is forced by the bos delay pattern
    first_codebook_ids = input_ids[:, 0, :]
    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
|
if len(start_ids) > 0: |
|
first_start_id = min(start_ids) |
|
else: |
|
|
|
first_start_id = seq_len |
|
|
|
|
|
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) |
|
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) |
|
return input_ids, pattern_mask |
|
|
|
|
|
@dataclass |
|
class ParlerTTSUnconditionalInput(ModelOutput): |
|
""" |
|
Args: |
|
encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the text encoder model. |
|
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0, |
|
1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**. |
|
guidance_scale (`float`, *optional*): |
|
Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted |
|
from the prompts) and the unconditional logits (predicted without prompts). |
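            During generation this is applied by `ClassifierFreeGuidanceLogitsProcessor`, which combines the two
            sets of logits roughly as `scores = uncond_logits + guidance_scale * (cond_logits - uncond_logits)`.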
|
""" |
|
|
|
encoder_outputs: Tuple[torch.FloatTensor] = None |
|
attention_mask: torch.LongTensor = None |
|
guidance_scale: float = None |
|
|
|
|
|
|
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): |
|
""" |
|
Shift input ids one token to the right. |
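
    Example (illustrative values):

    ```python
    >>> import torch
    >>> labels = torch.tensor([[5, 6, 7, -100]])
    >>> shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
    tensor([[2, 5, 6, 7]])
    ```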
|
""" |
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape) |
|
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() |
|
if decoder_start_token_id is None: |
|
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") |
|
shifted_input_ids[:, 0] = decoder_start_token_id |
|
|
|
if pad_token_id is None: |
|
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") |
|
|
|
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) |
|
|
|
return shifted_input_ids |
|
|
|
|
|
|
|
class ParlerTTSSinusoidalPositionalEmbedding(nn.Module): |
|
"""This module produces sinusoidal positional embeddings of any length.""" |
|
|
|
def __init__(self, num_positions: int, embedding_dim: int): |
|
super().__init__() |
|
self.embedding_dim = embedding_dim |
|
self.make_weights(num_positions, embedding_dim) |
|
|
|
def make_weights(self, num_embeddings: int, embedding_dim: int): |
|
emb_weights = self.get_embedding(num_embeddings, embedding_dim) |
|
if hasattr(self, "weights"): |
|
|
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) |
|
|
|
self.weights = nn.Parameter(emb_weights) |
|
self.weights.requires_grad = False |
|
self.weights.detach_() |
|
|
|
@staticmethod |
|
def get_embedding(num_embeddings: int, embedding_dim: int): |
|
""" |
|
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the |
|
description in Section 3.5 of "Attention Is All You Need". |
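
        Example (shape check only):

        ```python
        >>> emb = ParlerTTSSinusoidalPositionalEmbedding.get_embedding(num_embeddings=8, embedding_dim=4)
        >>> emb.shape
        torch.Size([8, 4])
        ```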
|
""" |
|
half_dim = embedding_dim // 2 |
|
emb = math.log(10000) / (half_dim - 1) |
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) |
|
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) |
|
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1) |
|
if embedding_dim % 2 == 1: |
|
|
|
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) |
|
return emb.to(torch.get_default_dtype()) |
|
|
|
@torch.no_grad() |
|
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): |
|
bsz, seq_len, _ = input_ids.size() |
|
|
|
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device) |
|
|
|
        # expand the embedding table on the fly if the requested positions exceed the pre-computed weights
        if seq_len + past_key_values_length > self.weights.size(0):
            self.make_weights(seq_len + past_key_values_length, self.embedding_dim)
|
return self.weights.index_select(0, position_ids.view(-1)).detach() |
|
|
|
|
|
class ParlerTTSRotaryEmbedding(nn.Module): |
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
super().__init__() |
|
self.scaling_factor = scaling_factor |
|
self.dim = dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
self.max_seq_len_cached = max_position_embeddings |
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) |
|
t = t / self.scaling_factor |
|
freqs = torch.outer(t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) |
|
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) |
|
|
|
@torch.no_grad() |
|
def forward(self, x, position_ids): |
|
|
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
|
|
|
device_type = x.device.type |
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" |
|
with torch.autocast(device_type=device_type, enabled=False): |
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
cos = emb.cos() |
|
sin = emb.sin() |
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the input tensor.
|
|
|
Args: |
|
x (`torch.Tensor`): The tensor over which to apply the rope embeddings |
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Specifies the dimension along which to unsqueeze `cos` and `sin` so that they can be properly
            broadcast to the dimensions of `x`. For example, `cos` and `sin` have the shape
            [batch_size, seq_len, head_dim]. If `x` has the shape [batch_size, heads, seq_len, head_dim],
            setting unsqueeze_dim=1 makes `cos` and `sin` broadcastable to `x`. Similarly, if `x` has the shape
            [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
    Returns:
        `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
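
    Example (shapes chosen for illustration):

    ```python
    >>> import torch
    >>> rotary = ParlerTTSRotaryEmbedding(dim=64)
    >>> x = torch.randn(1, 8, 10, 64)  # (batch, heads, seq_len, head_dim)
    >>> position_ids = torch.arange(10).unsqueeze(0)
    >>> cos, sin = rotary(x, position_ids)
    >>> x_rot = apply_rotary_pos_emb(x, cos, sin)  # same shape as `x`
    ```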
|
""" |
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
x_embed = (x * cos) + (rotate_half(x) * sin) |
|
return x_embed |
|
|
|
class ParlerTTSAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
dropout: float = 0.0, |
|
is_decoder: bool = False, |
|
bias: bool = True, |
|
is_causal: bool = False, |
|
config: Optional[ParlerTTSDecoderConfig] = None, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.dropout = dropout |
|
self.head_dim = embed_dim // num_heads |
|
self.config = config |
|
|
|
if (self.head_dim * num_heads) != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
|
f" and `num_heads`: {num_heads})." |
|
) |
|
self.scaling = self.head_dim**-0.5 |
|
self.is_decoder = is_decoder |
|
self.is_causal = is_causal |
|
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) |
|
|
|
self.rope_embeddings = config.rope_embeddings |
|
if config.rope_embeddings: |
|
self.rotary_emb = ParlerTTSRotaryEmbedding( |
|
self.head_dim, |
|
max_position_embeddings=config.max_position_embeddings, |
|
base=config.rope_theta, |
|
) |
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
key_value_states: Optional[torch.Tensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
layer_head_mask: Optional[torch.Tensor] = None, |
|
output_attentions: bool = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
|
|
|
|
is_cross_attention = key_value_states is not None |
|
|
|
bsz, tgt_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) * self.scaling |
|
query_states = self._shape(query_states, tgt_len, bsz) |
|
|
|
if self.rope_embeddings: |
|
cos, sin = self.rotary_emb(query_states, position_ids) |
|
query_states = apply_rotary_pos_emb(query_states, cos, sin) |
|
|
|
|
|
|
|
|
|
|
|
if ( |
|
is_cross_attention |
|
and past_key_value is not None |
|
and past_key_value[0].shape[2] == key_value_states.shape[1] |
|
): |
|
|
|
key_states = past_key_value[0] |
|
value_states = past_key_value[1] |
|
elif is_cross_attention: |
|
|
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz) |
|
elif past_key_value is not None: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
|
|
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states |
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
else: |
|
|
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
|
key_states = apply_rotary_pos_emb(key_states, cos, sin) if self.rope_embeddings else key_states |
|
if self.is_decoder: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # if cross-attention, cache all cross-attention key/value states; if uni-directional self-attention
            # (decoder), cache all previous decoder key/value states so they can be reused at the next step
            past_key_value = (key_states, value_states)
|
|
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim) |
|
query_states = query_states.reshape(*proj_shape) |
|
key_states = key_states.reshape(*proj_shape) |
|
value_states = value_states.reshape(*proj_shape) |
|
|
|
src_len = key_states.size(1) |
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) |
|
|
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" |
|
) |
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1) |
|
|
|
if layer_head_mask is not None: |
|
if layer_head_mask.size() != (self.num_heads,): |
|
raise ValueError( |
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" |
|
f" {layer_head_mask.size()}" |
|
) |
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
|
if output_attentions: |
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) |
|
else: |
|
attn_weights_reshaped = None |
|
|
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
|
|
|
attn_output = torch.bmm(attn_probs, value_states) |
|
|
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) |
|
attn_output = attn_output.transpose(1, 2) |
|
|
|
|
|
|
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) |
|
|
|
attn_output = self.out_proj(attn_output) |
|
|
|
return attn_output, attn_weights_reshaped, past_key_value |
|
|
|
|
|
class ParlerTTSDecoderLayer(nn.Module): |
|
def __init__(self, config: ParlerTTSDecoderConfig): |
|
super().__init__() |
|
self.embed_dim = config.hidden_size |
|
|
|
self.self_attn = ParlerTTSAttention( |
|
embed_dim=self.embed_dim, |
|
num_heads=config.num_attention_heads, |
|
dropout=config.attention_dropout, |
|
is_decoder=True, |
|
bias=False, |
|
config=config, |
|
) |
|
self.dropout = config.dropout |
|
self.activation_fn = ACT2FN[config.activation_function] |
|
self.activation_dropout = config.activation_dropout |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
self.encoder_attn = ParlerTTSAttention( |
|
self.embed_dim, |
|
config.num_attention_heads, |
|
dropout=config.attention_dropout, |
|
is_decoder=True, |
|
bias=False, |
|
config=config, |
|
) |
|
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) |
|
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) |
|
self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
layer_head_mask: Optional[torch.Tensor] = None, |
|
cross_attn_layer_head_mask: Optional[torch.Tensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = True, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
attention_mask (`torch.FloatTensor`): attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
|
config.n_positions - 1]`. |
|
encoder_hidden_states (`torch.FloatTensor`): |
|
cross attention input to the layer of shape `(batch, seq_len, embed_dim)` |
|
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size |
|
`(encoder_attention_heads,)`. |
|
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of |
|
size `(decoder_attention_heads,)`. |
|
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
""" |
|
residual = hidden_states |
|
hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
|
|
|
|
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None |
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
past_key_value=self_attn_past_key_value, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
layer_head_mask=layer_head_mask, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
cross_attn_present_key_value = None |
|
cross_attn_weights = None |
|
if encoder_hidden_states is not None: |
|
residual = hidden_states |
|
hidden_states = self.encoder_attn_layer_norm(hidden_states) |
|
|
|
|
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None |
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( |
|
hidden_states=hidden_states, |
|
key_value_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
position_ids=position_ids, |
|
layer_head_mask=cross_attn_layer_head_mask, |
|
past_key_value=cross_attn_past_key_value, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
present_key_value = present_key_value + cross_attn_present_key_value |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.final_layer_norm(hidden_states) |
|
hidden_states = self.activation_fn(self.fc1(hidden_states)) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) |
|
hidden_states = self.fc2(hidden_states) |
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights, cross_attn_weights) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
|
|
|
|
class ParlerTTSPreTrainedModel(PreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
models. |
|
""" |
|
|
|
config_class = ParlerTTSDecoderConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["ParlerTTSDecoderLayer", "ParlerTTSAttention"] |
|
|
|
def _init_weights(self, module): |
|
std = self.config.initializer_factor |
|
if isinstance(module, (nn.Linear, nn.Conv1d)): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
MUSICGEN_START_DOCSTRING = r""" |
|
|
|
    The ParlerTTS model builds on the decoder architecture proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284)
    (MusicGen) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre
    Défossez. It is an encoder-decoder transformer trained on the task of conditional text-to-speech generation.
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`ParlerTTSConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
MUSICGEN_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): |
|
Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. |
|
|
|
Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, |
|
such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. |
|
|
|
[What are decoder input IDs?](../glossary#decoder-input-ids) |
|
|
|
<Tip warning={true}> |
|
|
|
The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, |
|
target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If |
|
you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of |
|
frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, |
|
target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as |
|
`decoder_input_ids`. |
|
|
|
</Tip> |
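
            For illustration (assuming `audio_codes` is the tensor of codes returned by the audio encoder):

            ```python
            >>> frames, bsz, num_codebooks, seq_len = audio_codes.shape
            >>> decoder_input_ids = audio_codes[0, ...].reshape(bsz * num_codebooks, seq_len)
            ```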
|
|
|
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): |
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
|
be used by default. |
|
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, |
|
1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): |
|
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) |
|
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of |
|
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. |
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape |
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape |
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. |
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
|
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. |
|
|
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that |
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all |
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`. |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
|
than the model's internal embedding lookup matrix. |
|
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded |
|
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be |
|
input (see `past_key_values`). This is useful if you want more control over how to convert |
|
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. |
|
|
|
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value |
|
of `inputs_embeds`. |
|
prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
|
it. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation. |
|
This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors |
|
than the model's internal embedding lookup matrix. |
|
use_cache (`bool`, *optional*): |
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
|
`past_key_values`). |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
MUSICGEN_DECODER_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. |
|
|
|
Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, |
|
such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
|
|
<Tip warning={true}> |
|
|
|
The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, |
|
target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If |
|
you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of |
|
frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, |
|
target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as |
|
`input_ids`. |
|
|
|
</Tip> |
|
|
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): |
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of |
|
the decoder. |
|
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): |
|
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values |
|
selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): |
|
Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input embeds. |
|
prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): |
|
Mask to avoid performing cross-attention on padding tokens indices of prompt input_ids. Mask values |
|
selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing |
|
cross-attention on hidden heads. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape |
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape |
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. |
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
|
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. |
|
|
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that |
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all |
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`. |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
|
than the model's internal embedding lookup matrix. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
class ParlerTTSDecoder(ParlerTTSPreTrainedModel): |
|
""" |
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ParlerTTSDecoderLayer`] |
|
""" |
|
|
|
def __init__(self, config: ParlerTTSDecoderConfig): |
|
super().__init__(config) |
|
self.dropout = config.dropout |
|
self.layerdrop = config.layerdrop |
|
self.max_target_positions = config.max_position_embeddings |
|
self.d_model = config.hidden_size |
|
self.num_codebooks = config.num_codebooks |
|
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 |
|
|
|
|
|
        # one embedding table per codebook; vocab_size + 1 leaves room for the special bos/pad token id
        embed_dim = config.vocab_size + 1
|
self.embed_tokens = nn.ModuleList( |
|
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] |
|
) |
|
|
|
self.rope_embeddings = config.rope_embeddings |
|
if not config.rope_embeddings: |
|
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( |
|
config.max_position_embeddings, |
|
config.hidden_size, |
|
) |
|
|
|
self.layers = nn.ModuleList([ParlerTTSDecoderLayer(config) for _ in range(config.num_hidden_layers)]) |
|
self.layer_norm = nn.LayerNorm(config.hidden_size) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.embed_tokens = value |
|
|
|
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.LongTensor] = None, |
|
prompt_hidden_states: Optional[torch.FloatTensor] = None, |
|
prompt_attention_mask: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
cross_attn_head_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") |
|
elif input_ids is not None: |
|
|
|
input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) |
|
bsz, num_codebooks, seq_len = input.shape |
|
input_shape = (bsz, seq_len) |
|
elif inputs_embeds is not None: |
|
input_shape = inputs_embeds.size()[:-1] |
|
input = inputs_embeds[:, :, -1:] |
|
else: |
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") |
|
|
|
|
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) |
|
|
|
|
|
if prompt_hidden_states is not None: |
|
            # prepend the prompt hidden states to the audio-code embeddings along the time dimension
            inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
|
|
|
|
|
if prompt_attention_mask is not None and attention_mask is not None: |
|
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1) |
|
elif prompt_attention_mask is not None: |
|
logger.warning_once( |
|
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." |
|
) |
|
if past_key_values is None: |
|
attention_mask = torch.cat( |
|
[ |
|
prompt_attention_mask, |
|
torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype), |
|
], |
|
dim=1, |
|
) |
|
else: |
|
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1 |
|
attention_mask = torch.cat( |
|
[ |
|
prompt_attention_mask, |
|
torch.ones( |
|
(input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype |
|
), |
|
], |
|
dim=1, |
|
) |
|
|
|
input_shape = inputs_embeds.size()[:-1] |
|
|
|
if not self.rope_embeddings: |
|
|
|
|
|
|
|
positions = self.embed_positions(inputs_embeds, past_key_values_length) |
|
hidden_states = inputs_embeds + positions.to(inputs_embeds.device) |
|
else: |
|
hidden_states = inputs_embeds |
|
|
|
if position_ids is None: |
|
if attention_mask is not None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
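                # e.g. attention_mask [[0, 1, 1, 1]] -> cumsum - 1 = [[-1, 0, 1, 2]] -> position_ids [[1, 0, 1, 2]]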
|
else: |
|
position_ids = torch.arange( |
|
past_key_values_length, input_shape[1] + past_key_values_length, |
|
dtype=torch.long, |
|
device=inputs_embeds.device |
|
) |
|
position_ids = position_ids.unsqueeze(0) |
|
|
|
|
|
if position_ids.shape[1] > input_shape[1]: |
|
position_ids = position_ids[:, -input_shape[1]:] |
|
|
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
|
|
|
attention_mask = _prepare_4d_causal_attention_mask( |
|
attention_mask, input_shape, inputs_embeds, past_key_values_length |
|
) |
|
|
|
|
|
if encoder_hidden_states is not None and encoder_attention_mask is not None: |
|
|
|
encoder_attention_mask = _prepare_4d_attention_mask( |
|
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] |
|
) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
if use_cache: |
|
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
|
use_cache = False |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None |
|
next_decoder_cache = () if use_cache else None |
|
|
|
|
|
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): |
|
if attn_mask is not None: |
|
if attn_mask.size()[0] != len(self.layers): |
|
raise ValueError( |
|
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" |
|
f" {attn_mask.size()[0]}." |
|
) |
|
for idx, decoder_layer in enumerate(self.layers): |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
            # LayerDrop (see https://arxiv.org/abs/1909.11556): randomly skip entire decoder layers during training
            dropout_probability = random.uniform(0, 1)
|
if self.training and (dropout_probability < self.layerdrop): |
|
continue |
|
|
|
past_key_value = past_key_values[idx] if past_key_values is not None else None |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.forward, |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
head_mask[idx] if head_mask is not None else None, |
|
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, |
|
None, |
|
output_attentions, |
|
use_cache, |
|
) |
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None), |
|
cross_attn_layer_head_mask=( |
|
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None |
|
), |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
) |
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
if encoder_hidden_states is not None: |
|
all_cross_attentions += (layer_outputs[2],) |
|
|
|
hidden_states = self.layer_norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
if not return_dict: |
|
return tuple( |
|
v |
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] |
|
if v is not None |
|
) |
|
return BaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
cross_attentions=all_cross_attentions, |
|
) |
|
|
|
|
|
@add_start_docstrings( |
|
"The bare ParlerTTS decoder model outputting raw hidden-states without any specific head on top.", |
|
MUSICGEN_START_DOCSTRING, |
|
) |
|
|
|
class ParlerTTSModel(ParlerTTSPreTrainedModel): |
|
def __init__(self, config: ParlerTTSDecoderConfig): |
|
super().__init__(config) |
|
self.decoder = ParlerTTSDecoder(config) |
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.decoder.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.decoder.embed_tokens = value |
|
|
|
def get_decoder(self): |
|
return self.decoder |
|
|
|
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.LongTensor] = None, |
|
prompt_hidden_states: Optional[torch.FloatTensor] = None, |
|
prompt_attention_mask: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
cross_attn_head_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_attention_mask=encoder_attention_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
prompt_hidden_states=prompt_hidden_states, |
|
prompt_attention_mask=prompt_attention_mask, |
|
head_mask=head_mask, |
|
cross_attn_head_mask=cross_attn_head_mask, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs |
|
|
|
return BaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=decoder_outputs.last_hidden_state, |
|
past_key_values=decoder_outputs.past_key_values, |
|
hidden_states=decoder_outputs.hidden_states, |
|
attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
) |
|
|
|
|
|
@add_start_docstrings( |
|
"The Parler-TTS decoder model with a language modelling head on top.", |
|
MUSICGEN_START_DOCSTRING, |
|
) |
|
class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel): |
|
def __init__(self, config: ParlerTTSDecoderConfig): |
|
super().__init__(config) |
|
|
|
self.model = ParlerTTSModel(config) |
|
|
|
self.num_codebooks = config.num_codebooks |
|
self.lm_heads = nn.ModuleList( |
|
[nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.model.decoder.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.model.decoder.embed_tokens = value |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_heads |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.lm_heads = new_embeddings |
|
|
|
def set_decoder(self, decoder): |
|
self.model.decoder = decoder |
|
|
|
def get_decoder(self): |
|
return self.model.decoder |
|
|
|
@add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.LongTensor] = None, |
|
prompt_hidden_states: Optional[torch.FloatTensor] = None, |
|
prompt_attention_mask: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
cross_attn_head_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: |
|
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*): |
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set |
|
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
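            The loss is a cross-entropy computed independently for each codebook and averaged over the
            `num_codebooks` codebooks; label positions equal to `bos_token_id`, and positions whose input token is
            `eos_token_id`, are excluded from the loss.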
|
Returns: |
|
""" |
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.model( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
prompt_hidden_states=prompt_hidden_states, |
|
prompt_attention_mask=prompt_attention_mask, |
|
head_mask=head_mask, |
|
cross_attn_head_mask=cross_attn_head_mask, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
|
|
lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
            # the prompt hidden states are prepended to the decoder inputs, so only the last `labels.shape[1]`
            # time steps of the logits line up with the label positions
            logits = lm_logits[:, :, -labels.shape[1] :]
|
|
|
loss_fct = CrossEntropyLoss() |
|
loss = torch.zeros([], device=self.device) |
|
|
|
|
|
            # don't count the bos token in the loss
            labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
|
|
|
|
|
            # keep only positions whose input token is not eos and whose label is valid (not -100)
            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & (labels != -100)
|
|
|
|
|
for codebook in range(self.config.num_codebooks): |
|
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1]) |
|
codebook_mask = mask[..., codebook].contiguous().view(-1) |
|
codebook_labels = labels[..., codebook].contiguous().view(-1) |
|
|
|
codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask]) |
|
loss += codebook_loss |
|
|
|
loss = loss / self.config.num_codebooks |
|
|
|
|
|
        # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
        lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
|
|
|
if not return_dict: |
|
output = (lm_logits,) + outputs[1:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return CausalLMOutputWithCrossAttentions( |
|
loss=loss, |
|
logits=lm_logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids, |
|
attention_mask=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
prompt_hidden_states=None, |
|
prompt_attention_mask=None, |
|
head_mask=None, |
|
cross_attn_head_mask=None, |
|
past_key_values=None, |
|
use_cache=True, |
|
delay_pattern_mask=None, |
|
guidance_scale=None, |
|
**kwargs, |
|
): |
|
if delay_pattern_mask is None: |
|
input_ids, delay_pattern_mask = self.build_delay_pattern_mask( |
|
input_ids, |
|
bos_token_id=self.generation_config.bos_token_id, |
|
pad_token_id=self.generation_config.pad_token_id, |
|
max_length=self.generation_config.max_length, |
|
) |
|
|
|
|
|
        # apply the delay pattern mask to enforce the bos/pad pattern built above
        input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
|
|
|
if guidance_scale is not None and guidance_scale > 1: |
|
|
|
|
|
            # for classifier-free guidance we replicate the decoder args across the batch dimension
            # (the conditional / unconditional halves are split again before sampling)
            input_ids = input_ids.repeat((2, 1))
|
if attention_mask is not None: |
|
attention_mask = attention_mask.repeat((2, 1)) |
|
|
|
if prompt_hidden_states is not None: |
|
prompt_hidden_states = torch.concatenate( |
|
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0 |
|
) |
|
|
|
if prompt_attention_mask is not None: |
|
prompt_attention_mask = torch.concatenate( |
|
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0 |
|
) |
|
|
|
position_ids = kwargs.get("position_ids", None) |
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
|
|
if past_key_values is not None: |
|
            # with cached key/value states, only the last generated token needs to be forwarded
            input_ids = input_ids[:, -1:]
|
if position_ids is not None: |
|
position_ids = position_ids[:, -input_ids.shape[1]:] |
|
|
|
|
|
            # the prompt hidden states are only needed for the first forward pass; afterwards their
            # contribution is already contained in the cached key/value states
            prompt_hidden_states = None
|
|
|
return { |
|
"input_ids": input_ids, |
|
"attention_mask": attention_mask, |
|
"position_ids": position_ids, |
|
"encoder_hidden_states": encoder_hidden_states, |
|
"encoder_attention_mask": encoder_attention_mask, |
|
"prompt_hidden_states": prompt_hidden_states, |
|
"prompt_attention_mask": prompt_attention_mask, |
|
"head_mask": head_mask, |
|
"cross_attn_head_mask": cross_attn_head_mask, |
|
"past_key_values": past_key_values, |
|
"use_cache": use_cache, |
|
} |
|
|
|
|
|
def build_delay_pattern_mask( |
|
self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None |
|
    ):
        """Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
        one step, giving a delayed pattern at the start and end of the sequence. Take the example where there
|
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, |
|
seq_len)`: |
|
- [B, -1, -1, -1, -1, P, P, P] |
|
- [B, B, -1, -1, -1, -1, P, P] |
|
- [B, B, B, -1, -1, -1, -1, P] |
|
- [B, B, B, B, -1, -1, -1, -1] |
|
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include |
|
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the |
|
mask is set to the value in the prompt: |
|
- [B, a, b, -1, -1, P, P, P] |
|
- [B, B, c, d, -1, -1, P, P] |
|
- [B, B, B, e, f, -1, -1, P] |
|
- [B, B, B, B, g, h, -1, -1] |
|
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 |
|
tokens in our prediction. |
|
""" |
|
max_length = max_length if max_length is not None else self.generation_config.max_length |
|
return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks) |
|
|
|
@staticmethod |
|
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): |
|
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where |
|
the mask is set to -1, and otherwise setting to the value detailed in the mask.""" |
|
return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
generation_config: Optional[GenerationConfig] = None, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
synced_gpus: Optional[bool] = None, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**kwargs, |
|
): |
|
""" |
|
Generates sequences of token ids for models with a language modeling head. |
|
|
|
<Tip warning={true}> |
|
|
|
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the |
|
model's default generation configuration. You can override any `generation_config` by passing the corresponding |
|
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. |
|
|
|
For an overview of generation strategies and code examples, check out the [following |
|
guide](./generation_strategies). |
|
|
|
</Tip> |
|
|
|
Parameters: |
|
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): |
|
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the |
|
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` |
|
should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of |
|
`input_ids`, `input_values`, `input_features`, or `pixel_values`. |
|
generation_config (`~generation.GenerationConfig`, *optional*): |
|
The generation configuration to be used as base parametrization for the generation call. `**kwargs` |
|
passed to generate matching the attributes of `generation_config` will override them. If |
|
                `generation_config` is not provided, the default will be used, which has the following loading
|
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model |
|
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s |
|
default values, whose documentation should be checked to parameterize generation. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
Custom logits processors that complement the default logits processors built from arguments and |
|
generation config. If a logit processor is passed that is already created with the arguments or a |
|
generation config an error is thrown. This feature is intended for advanced users. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
Custom stopping criteria that complement the default stopping criteria built from arguments and a |
|
                generation config. If a stopping criterion is passed that is already created with the arguments or a
|
generation config an error is thrown. This feature is intended for advanced users. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
streamer (`BaseStreamer`, *optional*): |
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed |
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing. |
|
kwargs (`Dict[str, Any]`, *optional*): |
|
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder |
|
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. |
|
|
|
Return: |
|
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` |
|
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
|
|
|
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation.GenerateDecoderOnlyOutput`], |
|
- [`~generation.GenerateBeamDecoderOnlyOutput`] |
|
|
|
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation.GenerateEncoderDecoderOutput`], |
|
- [`~generation.GenerateBeamEncoderDecoderOutput`] |
|
""" |
|
|
|
if generation_config is None: |
|
generation_config = self.generation_config |
|
|
|
generation_config = copy.deepcopy(generation_config) |
|
model_kwargs = generation_config.update(**kwargs) |
|
generation_config.validate() |
|
self._validate_model_kwargs(model_kwargs.copy()) |
|
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
|
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: |
|
if model_kwargs.get("attention_mask", None) is None: |
|
logger.warning( |
|
"The attention mask and the pad token id were not set. As a consequence, you may observe " |
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." |
|
) |
|
eos_token_id = generation_config.eos_token_id |
|
if isinstance(eos_token_id, list): |
|
eos_token_id = eos_token_id[0] |
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") |
|
generation_config.pad_token_id = eos_token_id |
|
|
|
|
|
|
|
|
|
|
|
|
|
input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( |
|
inputs, generation_config.bos_token_id, model_kwargs |
|
) |
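        # input ids are stacked as (bsz * num_codebooks, seq_len); recover the true batch size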
|
batch_size = input_ids.shape[0] // self.num_codebooks |
|
|
|
|
|
model_kwargs["use_cache"] = generation_config.use_cache |
|
model_kwargs["guidance_scale"] = generation_config.guidance_scale |
|
|
|
requires_attention_mask = "encoder_outputs" not in model_kwargs |
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: |
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( |
|
input_ids, generation_config.pad_token_id, generation_config.eos_token_id |
|
) |
|
|
|
|
|
input_ids_seq_length = input_ids.shape[-1] |
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None |
|
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: |
|
logger.warning( |
|
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " |
|
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation." |
|
) |
|
elif generation_config.max_new_tokens is not None: |
|
if not has_default_max_length: |
|
logger.warning( |
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" |
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " |
|
"Please refer to the documentation for more information. " |
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" |
|
) |
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length |
|
|
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: |
|
raise ValueError( |
|
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" |
|
f" the maximum length ({generation_config.max_length})" |
|
) |
|
if input_ids_seq_length >= generation_config.max_length: |
|
logger.warning( |
|
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" |
|
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" |
|
" increasing `max_new_tokens`." |
|
) |
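        # build the delay pattern mask: each codebook is shifted by one step and positions still to be generated are marked with -1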
|
|
|
|
|
|
|
input_ids, delay_pattern_mask = self.build_delay_pattern_mask( |
|
input_ids, |
|
bos_token_id=generation_config.bos_token_id, |
|
pad_token_id=generation_config.pad_token_id, |
|
max_length=generation_config.max_length, |
|
) |
|
|
|
if streamer is not None: |
|
streamer.put(input_ids.cpu()) |
|
|
|
|
|
model_kwargs["delay_pattern_mask"] = delay_pattern_mask |
|
|
|
|
|
is_greedy_gen_mode = ( |
|
(generation_config.num_beams == 1) |
|
and (generation_config.num_beam_groups == 1) |
|
and generation_config.do_sample is False |
|
) |
|
is_sample_gen_mode = ( |
|
(generation_config.num_beams == 1) |
|
and (generation_config.num_beam_groups == 1) |
|
and generation_config.do_sample is True |
|
) |
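        # classifier-free guidance is applied through a logits processor; guidance_scale is then reset so it is not added twice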
|
|
|
|
|
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: |
|
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) |
|
generation_config.guidance_scale = None |
|
|
|
|
|
logits_processor = self._get_logits_processor( |
|
generation_config=generation_config, |
|
input_ids_seq_length=input_ids_seq_length, |
|
encoder_input_ids=input_ids, |
|
prefix_allowed_tokens_fn=None, |
|
logits_processor=logits_processor, |
|
) |
|
|
|
|
|
stopping_criteria = self._get_stopping_criteria( |
|
generation_config=generation_config, stopping_criteria=stopping_criteria |
|
) |
|
|
|
if is_greedy_gen_mode: |
|
if generation_config.num_return_sequences > 1: |
|
raise ValueError( |
|
"num_return_sequences has to be 1 when doing greedy search, " |
|
f"but is {generation_config.num_return_sequences}." |
|
) |
|
|
|
|
|
outputs = self._greedy_search( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
generation_config=generation_config, |
|
synced_gpus=synced_gpus, |
|
streamer=streamer, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_sample_gen_mode: |
|
|
|
logits_warper = self._get_logits_warper(generation_config) |
|
|
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids=input_ids, |
|
expand_size=generation_config.num_return_sequences, |
|
**model_kwargs, |
|
) |
|
|
|
|
|
outputs = self._sample( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
logits_warper=logits_warper, |
|
stopping_criteria=stopping_criteria, |
|
generation_config=generation_config, |
|
synced_gpus=synced_gpus, |
|
streamer=streamer, |
|
**model_kwargs, |
|
) |
|
|
|
else: |
|
raise ValueError( |
|
"Got incompatible mode for generation, should be one of greedy or sampling. " |
|
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." |
|
) |
|
|
|
if generation_config.return_dict_in_generate: |
|
output_ids = outputs.sequences |
|
else: |
|
output_ids = outputs |
|
|
|
|
|
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) |
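        # rebuild the pattern mask at the generated length, then drop BOS/PAD positions and reshape to (bsz, num_codebooks, seq_len)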
|
|
|
|
|
_, mask = self.build_delay_pattern_mask( |
|
input_ids, |
|
bos_token_id=generation_config.bos_token_id, |
|
pad_token_id=generation_config.pad_token_id, |
|
max_length=output_ids.shape[1], |
|
) |
|
|
|
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id) |
|
output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1) |
|
|
|
if generation_config.return_dict_in_generate: |
|
outputs.sequences = output_ids |
|
return outputs |
|
else: |
|
return output_ids |
|
|
|
|
|
@add_start_docstrings( |
|
"The composite Parler-TTS model with a text encoder, audio encoder and ParlerTTS decoder, " |
|
"for music generation tasks with one or both of text and audio prompts.", |
|
MUSICGEN_START_DOCSTRING, |
|
) |
|
class ParlerTTSForConditionalGeneration(PreTrainedModel): |
|
config_class = ParlerTTSConfig |
|
base_model_prefix = "encoder_decoder" |
|
main_input_name = "input_ids" |
|
supports_gradient_checkpointing = True |
|
|
|
def __init__( |
|
self, |
|
config: Optional[ParlerTTSConfig] = None, |
|
text_encoder: Optional[PreTrainedModel] = None, |
|
audio_encoder: Optional[PreTrainedModel] = None, |
|
decoder: Optional[ParlerTTSForCausalLM] = None, |
|
): |
|
if config is None and (text_encoder is None or audio_encoder is None or decoder is None): |
|
raise ValueError( |
|
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder." |
|
) |
|
if config is None: |
|
config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config) |
|
else: |
|
if not isinstance(config, self.config_class): |
|
raise ValueError(f"Config: {config} has to be of type {self.config_class}") |
|
|
|
if config.decoder.cross_attention_hidden_size is not None: |
|
if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size: |
|
raise ValueError( |
|
"If `cross_attention_hidden_size` is specified in the Parler-TTS decoder's configuration, it has to be equal" |
|
f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" |
|
f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for" |
|
" `config.text_encoder.hidden_size`." |
|
) |
|
|
|
|
|
super().__init__(config) |
|
|
|
if text_encoder is None: |
|
from transformers.models.auto.modeling_auto import AutoModelForTextEncoding |
|
|
|
text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) |
|
|
|
if audio_encoder is None: |
|
from transformers.models.auto.modeling_auto import AutoModel |
|
|
|
audio_encoder = AutoModel.from_config(config.audio_encoder) |
|
|
|
if decoder is None: |
|
decoder = ParlerTTSForCausalLM(config.decoder) |
|
|
|
self.text_encoder = text_encoder |
|
self.audio_encoder = audio_encoder |
|
self.decoder = decoder |
|
|
|
if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict(): |
|
logger.warning( |
|
f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:" |
|
f" {self.config.text_encoder}" |
|
) |
|
if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): |
|
logger.warning( |
|
f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" |
|
f" {self.config.audio_encoder}" |
|
) |
|
if self.decoder.config.to_dict() != self.config.decoder.to_dict(): |
|
logger.warning( |
|
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" |
|
f" {self.config.decoder}" |
|
) |
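        # make sure the individual sub-model configs point to the shared config so later updates stay in sync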
|
|
|
|
|
|
|
self.text_encoder.config = self.config.text_encoder |
|
self.audio_encoder.config = self.config.audio_encoder |
|
self.decoder.config = self.config.decoder |
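        # project the text encoder hidden states to the decoder's hidden size when they differ and no dedicated cross-attention size is configured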
|
|
|
|
|
if ( |
|
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size |
|
and self.decoder.config.cross_attention_hidden_size is None |
|
): |
|
self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) |
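        # embedding table for the transcript (prompt) token ids consumed by the decoder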
|
|
|
|
|
self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size) |
|
|
|
self.prompt_cross_attention = config.prompt_cross_attention |
|
if config.prompt_cross_attention: |
|
self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding( |
|
config.decoder.max_position_embeddings, |
|
config.decoder.hidden_size, |
|
) |
|
|
|
if self.text_encoder.get_output_embeddings() is not None: |
|
raise ValueError( |
|
f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head" |
|
) |
|
|
|
decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) |
|
if "encoder_hidden_states" not in decoder_signature: |
|
raise ValueError( |
|
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " |
|
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def _init_weights(self, module): |
|
std = self.decoder.config.initializer_factor |
|
if isinstance(module, (nn.Linear, nn.Conv1d)): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
def tie_weights(self): |
|
|
|
if self.config.tie_encoder_decoder: |
|
|
|
decoder_base_model_prefix = self.decoder.base_model_prefix |
|
self._tie_encoder_decoder_weights( |
|
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix |
|
) |
|
|
|
def get_audio_encoder(self): |
|
return self.audio_encoder |
|
|
|
def get_text_encoder(self): |
|
return self.text_encoder |
|
|
|
def get_encoder(self): |
|
|
|
return self.get_text_encoder() |
|
|
|
def get_decoder(self): |
|
return self.decoder |
|
|
|
def get_input_embeddings(self): |
|
return self.text_encoder.get_input_embeddings() |
|
|
|
def get_output_embeddings(self): |
|
return self.decoder.get_output_embeddings() |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
return self.decoder.set_output_embeddings(new_embeddings) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): |
|
r""" |
|
Example: |
|
|
|
```python |
|
>>> from parler_tts import ParlerTTSForConditionalGeneration |
|
|
|
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small") |
|
```""" |
|
|
|
|
|
if kwargs.get("_fast_init", False): |
|
logger.warning( |
|
"Fast initialization is currently not supported for ParlerTTSForConditionalGeneration. " |
|
"Falling back to slow initialization..." |
|
) |
|
kwargs["_fast_init"] = False |
|
|
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
|
|
|
@classmethod |
|
def from_sub_models_pretrained( |
|
cls, |
|
text_encoder_pretrained_model_name_or_path: str = None, |
|
audio_encoder_pretrained_model_name_or_path: str = None, |
|
decoder_pretrained_model_name_or_path: str = None, |
|
*model_args, |
|
**kwargs, |
|
) -> PreTrainedModel: |
|
r""" |
|
Instantiate a text encoder, an audio encoder, and a Parler-TTS decoder from one, two or three base classes of the |
|
library from pretrained model checkpoints. |
|
|
|
|
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train |
|
the model, you need to first set it back in training mode with `model.train()`. |
|
|
|
Params: |
|
text_encoder_pretrained_model_name_or_path (`str`, *optional*): |
|
Information necessary to initiate the text encoder. Can be either: |
|
|
|
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
|
Valid model ids can be located at the root-level, like `t5-base`, or namespaced under a user or |
|
                  organization name, like `google/flan-t5-base`.
|
- A path to a *directory* containing model weights saved using |
|
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
|
|
|
audio_encoder_pretrained_model_name_or_path (`str`, *optional*): |
|
Information necessary to initiate the audio encoder. Can be either: |
|
|
|
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
|
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a |
|
user or organization name, like `facebook/encodec_24khz`. |
|
- A path to a *directory* containing model weights saved using |
|
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
|
|
|
decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): |
|
Information necessary to initiate the decoder. Can be either: |
|
|
|
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
|
Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or |
|
organization name, like `facebook/parler_tts-small`. |
|
- A path to a *directory* containing model weights saved using |
|
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
|
|
|
model_args (remaining positional arguments, *optional*): |
|
All remaining positional arguments will be passed to the underlying model's `__init__` method. |
|
|
|
kwargs (remaining dictionary of keyword arguments, *optional*): |
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., |
|
`output_attentions=True`). |
|
|
|
- To update the text encoder configuration, use the prefix *text_encoder_* for each configuration |
|
parameter. |
|
- To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration |
|
parameter. |
|
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. |
|
- To update the parent model configuration, do not use a prefix for each configuration parameter. |
|
|
|
Behaves differently depending on whether a `config` is provided or automatically loaded. |
|
|
|
Example: |
|
|
|
```python |
|
>>> from parler_tts import ParlerTTSForConditionalGeneration |
|
|
|
>>> # initialize a parler_tts model from a t5 text encoder, encodec audio encoder, and parler_tts decoder |
|
>>> model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained( |
|
... text_encoder_pretrained_model_name_or_path="t5-base", |
|
... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", |
|
... decoder_pretrained_model_name_or_path="facebook/parler_tts-small", |
|
... ) |
|
>>> # saving model after fine-tuning |
|
>>> model.save_pretrained("./parler_tts-ft") |
|
>>> # load fine-tuned model |
|
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("./parler_tts-ft") |
|
```""" |
|
|
|
kwargs_text_encoder = { |
|
argument[len("text_encoder_") :]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("text_encoder_") |
|
} |
|
|
|
kwargs_audio_encoder = { |
|
argument[len("audio_encoder_") :]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("audio_encoder_") |
|
} |
|
|
|
kwargs_decoder = { |
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") |
|
} |
|
|
|
|
|
for key in kwargs_text_encoder.keys(): |
|
del kwargs["text_encoder_" + key] |
|
for key in kwargs_audio_encoder.keys(): |
|
del kwargs["audio_encoder_" + key] |
|
for key in kwargs_decoder.keys(): |
|
del kwargs["decoder_" + key] |
|
|
|
|
|
|
|
|
|
text_encoder = kwargs_text_encoder.pop("model", None) |
|
if text_encoder is None: |
|
if text_encoder_pretrained_model_name_or_path is None: |
|
raise ValueError( |
|
"If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " |
|
"to be defined." |
|
) |
|
|
|
if "config" not in kwargs_text_encoder: |
|
encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( |
|
text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True |
|
) |
|
|
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: |
|
logger.info( |
|
f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " |
|
"from a decoder model. Cross-attention and casual mask are disabled." |
|
) |
|
encoder_config.is_decoder = False |
|
encoder_config.add_cross_attention = False |
|
|
|
kwargs_text_encoder["config"] = encoder_config |
|
|
|
text_encoder = AutoModelForTextEncoding.from_pretrained( |
|
text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder |
|
) |
|
|
|
audio_encoder = kwargs_audio_encoder.pop("model", None) |
|
if audio_encoder is None: |
|
if audio_encoder_pretrained_model_name_or_path is None: |
|
raise ValueError( |
|
"If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " |
|
"to be defined." |
|
) |
|
|
|
if "config" not in kwargs_audio_encoder: |
|
encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( |
|
audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True |
|
) |
|
|
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: |
|
logger.info( |
|
f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " |
|
"from a decoder model. Cross-attention and casual mask are disabled." |
|
) |
|
encoder_config.is_decoder = False |
|
encoder_config.add_cross_attention = False |
|
|
|
kwargs_audio_encoder["config"] = encoder_config |
|
|
|
audio_encoder = AutoModel.from_pretrained( |
|
audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder |
|
) |
|
|
|
decoder = kwargs_decoder.pop("model", None) |
|
if decoder is None: |
|
if decoder_pretrained_model_name_or_path is None: |
|
raise ValueError( |
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " |
|
"to be defined." |
|
) |
|
|
|
if "config" not in kwargs_decoder: |
|
decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained( |
|
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True |
|
) |
|
|
|
if isinstance(decoder_config, ParlerTTSConfig): |
|
decoder_config = decoder_config.decoder |
|
|
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: |
|
logger.info( |
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" |
|
f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" |
|
f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." |
|
) |
|
decoder_config.is_decoder = True |
|
decoder_config.add_cross_attention = True |
|
|
|
kwargs_decoder["config"] = decoder_config |
|
|
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: |
|
logger.warning( |
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " |
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " |
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " |
|
"passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " |
|
"`decoder_config` to `.from_sub_models_pretrained(...)`" |
|
) |
|
|
|
decoder = ParlerTTSForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) |
|
|
|
|
|
config = ParlerTTSConfig.from_sub_models_config( |
|
text_encoder.config, audio_encoder.config, decoder.config, **kwargs |
|
) |
|
return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) |
|
|
|
@add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) |
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.BoolTensor] = None, |
|
input_values: Optional[torch.FloatTensor] = None, |
|
padding_mask: Optional[torch.BoolTensor] = None, |
|
decoder_input_ids: Optional[torch.LongTensor] = None, |
|
decoder_attention_mask: Optional[torch.BoolTensor] = None, |
|
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, |
|
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None, |
|
prompt_input_ids: Optional[torch.FloatTensor] = None, |
|
prompt_attention_mask: Optional[torch.LongTensor] = None, |
|
prompt_hidden_states: Optional[torch.FloatTensor] = None, |
|
decoder_position_ids: Optional[torch.LongTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**kwargs, |
|
) -> Union[Tuple, Seq2SeqLMOutput]: |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
```python |
|
>>> from transformers import AutoProcessor, ParlerTTSForConditionalGeneration |
|
>>> import torch |
|
|
|
>>> processor = AutoProcessor.from_pretrained("facebook/parler_tts-small") |
|
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small") |
|
|
|
>>> inputs = processor( |
|
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], |
|
... padding=True, |
|
... return_tensors="pt", |
|
... ) |
|
|
|
>>> pad_token_id = model.generation_config.pad_token_id |
|
>>> decoder_input_ids = ( |
|
... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) |
|
... * pad_token_id |
|
... ) |
|
|
|
>>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits |
|
>>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) |
|
torch.Size([8, 1, 2048]) |
|
```""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
kwargs_text_encoder = { |
|
argument[len("text_encoder_")]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("text_encoder_") |
|
} |
|
|
|
kwargs_audio_encoder = { |
|
argument[len("audio_encoder_")]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("audio_encoder_") |
|
} |
|
|
|
kwargs_decoder = { |
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") |
|
} |
|
|
|
if encoder_outputs is None: |
|
encoder_outputs = self.text_encoder( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
**kwargs_text_encoder, |
|
) |
|
elif isinstance(encoder_outputs, tuple): |
|
encoder_outputs = BaseModelOutput(*encoder_outputs) |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
|
|
|
|
if ( |
|
self.text_encoder.config.hidden_size != self.decoder.config.hidden_size |
|
and self.decoder.config.cross_attention_hidden_size is None |
|
): |
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) |
|
|
|
if attention_mask is not None: |
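            # zero out the encoder hidden states of padded description tokens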
|
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None] |
|
|
|
if prompt_hidden_states is None: |
|
if prompt_input_ids is not None: |
|
prompt_hidden_states = self.embed_prompts(prompt_input_ids) |
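        # with prompt cross-attention enabled, add sinusoidal positions to the prompt embeddings and append them to the
        # encoder hidden states (the attention mask is extended to match)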
|
|
|
if prompt_hidden_states is not None and self.prompt_cross_attention: |
|
|
|
positions = self.embed_positions(prompt_hidden_states, 0) |
|
prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device) |
|
|
|
|
|
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1) |
|
if prompt_attention_mask is not None: |
|
if attention_mask is None: |
|
                    attention_mask = torch.ones(
                        encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype
                    )
|
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1) |
|
|
|
prompt_hidden_states = None |
|
prompt_attention_mask = None |
|
|
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): |
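            # during training, derive decoder_input_ids by right-shifting the labels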
|
decoder_input_ids = shift_tokens_right( |
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id |
|
).transpose(1, 2) |
|
|
|
elif decoder_input_ids is None and decoder_inputs_embeds is None: |
|
audio_encoder_outputs = self.audio_encoder( |
|
input_values=input_values, |
|
padding_mask=padding_mask, |
|
**kwargs_audio_encoder, |
|
) |
|
audio_codes = audio_encoder_outputs.audio_codes |
|
frames, bsz, codebooks, seq_len = audio_codes.shape |
|
if frames != 1: |
|
raise ValueError( |
|
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " |
|
"disabled by setting `chunk_length=None` in the audio encoder." |
|
) |
|
|
|
if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2: |
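                # the audio encoder returned mono codebooks; duplicate them so each stereo channel has its own set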
|
|
|
audio_codes = audio_codes.repeat_interleave(2, dim=2) |
|
|
|
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) |
|
|
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
position_ids=decoder_position_ids, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=attention_mask, |
|
prompt_hidden_states=prompt_hidden_states, |
|
prompt_attention_mask=prompt_attention_mask, |
|
inputs_embeds=decoder_inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
use_cache=use_cache, |
|
past_key_values=past_key_values, |
|
return_dict=return_dict, |
|
labels=labels, |
|
**kwargs_decoder, |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs + (encoder_hidden_states,) |
|
|
|
return Seq2SeqLMOutput( |
|
loss=decoder_outputs.loss, |
|
logits=decoder_outputs.logits, |
|
past_key_values=decoder_outputs.past_key_values, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
decoder_input_ids, |
|
past_key_values=None, |
|
attention_mask=None, |
|
head_mask=None, |
|
decoder_attention_mask=None, |
|
decoder_head_mask=None, |
|
prompt_hidden_states=None, |
|
prompt_attention_mask=None, |
|
cross_attn_head_mask=None, |
|
use_cache=None, |
|
encoder_outputs=None, |
|
decoder_delay_pattern_mask=None, |
|
guidance_scale=None, |
|
**kwargs, |
|
): |
|
if decoder_delay_pattern_mask is None: |
|
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( |
|
decoder_input_ids, |
|
bos_token_id=self.generation_config.bos_token_id, |
|
pad_token_id=self.generation_config.pad_token_id, |
|
max_length=self.generation_config.max_length, |
|
) |
|
|
|
|
|
decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) |
|
|
|
if guidance_scale is not None and guidance_scale > 1: |
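            # for classifier-free guidance, replicate the decoder inputs across the batch dimension (conditional / unconditional halves)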
|
|
|
|
|
decoder_input_ids = decoder_input_ids.repeat((2, 1)) |
|
if decoder_attention_mask is not None: |
|
decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) |
|
if prompt_hidden_states is not None: |
|
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1)) |
|
if prompt_attention_mask is not None: |
|
prompt_attention_mask = prompt_attention_mask.repeat((2, 1)) |
|
|
|
if past_key_values is not None: |
|
past_length = past_key_values[0][0].shape[2] |
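            # keep only the decoder tokens that have not been processed by the cache yet (typically just the last one)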
|
|
|
|
|
if decoder_input_ids.shape[1] > past_length: |
|
remove_prefix_length = past_length |
|
else: |
|
|
|
remove_prefix_length = decoder_input_ids.shape[1] - 1 |
|
|
|
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] |
|
|
|
|
|
prompt_hidden_states = prompt_hidden_states if self.prompt_cross_attention else None |
|
|
|
return { |
|
"input_ids": None, |
|
"encoder_outputs": encoder_outputs, |
|
"past_key_values": past_key_values, |
|
"decoder_input_ids": decoder_input_ids, |
|
"attention_mask": attention_mask, |
|
"decoder_attention_mask": decoder_attention_mask, |
|
"head_mask": head_mask, |
|
"decoder_head_mask": decoder_head_mask, |
|
"cross_attn_head_mask": cross_attn_head_mask, |
|
"prompt_hidden_states": prompt_hidden_states, |
|
"prompt_attention_mask": prompt_attention_mask, |
|
"use_cache": use_cache, |
|
} |
|
|
|
def _prepare_decoder_input_ids_for_generation( |
|
self, |
|
batch_size: int, |
|
model_input_name: str, |
|
model_kwargs: Dict[str, torch.Tensor], |
|
decoder_start_token_id: int = None, |
|
bos_token_id: int = None, |
|
device: torch.device = None, |
|
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: |
|
"""Prepares `decoder_input_ids` for generation with encoder-decoder models""" |
|
|
|
|
|
|
|
if model_kwargs is not None and "decoder_input_ids" in model_kwargs: |
|
decoder_input_ids = model_kwargs.pop("decoder_input_ids") |
|
elif "input_ids" in model_kwargs and model_input_name != "input_ids": |
|
decoder_input_ids = model_kwargs.pop("input_ids") |
|
else: |
|
decoder_input_ids = None |
|
|
|
|
|
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) |
|
if device is None: |
|
device = self.device |
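        # every codebook gets its own row of start tokens, hence batch_size * num_codebooks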
|
decoder_input_ids_start = ( |
|
torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) |
|
* decoder_start_token_id |
|
) |
|
|
|
|
|
if decoder_input_ids is None: |
|
decoder_input_ids = decoder_input_ids_start |
|
|
|
|
|
|
|
elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): |
|
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) |
|
if "decoder_attention_mask" in model_kwargs: |
|
decoder_attention_mask = model_kwargs["decoder_attention_mask"] |
|
decoder_attention_mask = torch.cat( |
|
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), |
|
dim=-1, |
|
) |
|
model_kwargs["decoder_attention_mask"] = decoder_attention_mask |
|
|
|
return decoder_input_ids, model_kwargs |
|
|
|
def _prepare_text_encoder_kwargs_for_generation( |
|
self, |
|
inputs_tensor: torch.Tensor, |
|
model_kwargs, |
|
model_input_name: Optional[str], |
|
generation_config: GenerationConfig, |
|
) -> Dict[str, Any]: |
|
|
|
encoder = self.get_text_encoder() |
|
|
|
|
|
if hasattr(encoder, "_hf_hook"): |
|
encoder._hf_hook.io_same_device = True |
|
|
|
|
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] |
|
encoder_kwargs = { |
|
argument: value |
|
for argument, value in model_kwargs.items() |
|
if not any(argument.startswith(p) for p in irrelevant_prefix) |
|
} |
|
encoder_signature = set(inspect.signature(encoder.forward).parameters) |
|
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature |
|
if not encoder_accepts_wildcard: |
|
encoder_kwargs = { |
|
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature |
|
} |
|
encoder_kwargs["output_attentions"] = generation_config.output_attentions |
|
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states |
|
guidance_scale = generation_config.guidance_scale |
|
|
|
|
|
model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name |
|
encoder_kwargs["return_dict"] = True |
|
encoder_kwargs[model_input_name] = inputs_tensor |
|
last_hidden_state = encoder(**encoder_kwargs).last_hidden_state |
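        # for classifier-free guidance, append an all-zero "null" text encoding (and attention mask) as the unconditional half of the batch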
|
|
|
|
|
if guidance_scale is not None and guidance_scale > 1: |
|
last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0) |
|
if "attention_mask" in model_kwargs: |
|
model_kwargs["attention_mask"] = torch.concatenate( |
|
[model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 |
|
) |
|
|
|
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) |
|
|
|
return model_kwargs |
|
|
|
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs): |
|
model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids) |
|
return model_kwargs |
|
|
|
def _prepare_audio_encoder_kwargs_for_generation( |
|
self, input_values, model_kwargs, model_input_name: Optional[str] = None |
|
): |
|
|
|
encoder = self.get_audio_encoder() |
|
|
|
|
|
if hasattr(encoder, "_hf_hook"): |
|
encoder._hf_hook.io_same_device = True |
|
|
|
|
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] |
|
encoder_kwargs = { |
|
argument: value |
|
for argument, value in model_kwargs.items() |
|
if not any(argument.startswith(p) for p in irrelevant_prefix) |
|
} |
|
encoder_signature = set(inspect.signature(encoder.forward).parameters) |
|
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature |
|
if not encoder_accepts_wildcard: |
|
encoder_kwargs = { |
|
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature |
|
} |
|
|
|
|
|
model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name |
|
encoder_kwargs["return_dict"] = True |
|
|
|
encoder_kwargs[model_input_name] = input_values |
|
audio_encoder_outputs = encoder.encode(**encoder_kwargs) |
|
audio_codes = audio_encoder_outputs.audio_codes |
|
audio_scales = audio_encoder_outputs.audio_scales |
|
|
|
frames, bsz, codebooks, seq_len = audio_codes.shape |
|
|
|
if frames != 1: |
|
raise ValueError( |
|
f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " |
|
"disabled by setting `chunk_length=None` in the audio encoder." |
|
) |
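        # drop the frame dimension and flatten codebooks into the batch: (bsz, num_codebooks, seq_len) -> (bsz * num_codebooks, seq_len)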
|
|
|
decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) |
|
|
|
model_kwargs["decoder_input_ids"] = decoder_input_ids |
|
model_kwargs["audio_scales"] = audio_scales |
|
return model_kwargs |
|
|
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): |
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2) |
|
|
|
def resize_token_embeddings(self, *args, **kwargs): |
|
raise NotImplementedError( |
|
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" |
|
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" |
|
" model.decoder.resize_token_embeddings(...))" |
|
) |
|
|
|
def _maybe_initialize_input_ids_for_generation( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
bos_token_id: Optional[int] = None, |
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> torch.LongTensor: |
|
"""Initializes input ids for generation, if necessary.""" |
|
if inputs is not None: |
|
return inputs |
|
|
|
encoder_outputs = model_kwargs.get("encoder_outputs") |
|
if encoder_outputs is not None: |
|
|
|
shape = encoder_outputs[0].size()[:-1] |
|
return torch.ones(shape, dtype=torch.long, device=self.device) * -100 |
|
|
|
if bos_token_id is None: |
|
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") |
|
|
|
|
|
|
|
batch_size = 1 |
|
for value in model_kwargs.values(): |
|
if isinstance(value, torch.Tensor): |
|
batch_size = value.shape[0] |
|
break |
|
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id |
|
|
|
def freeze_encoders(self, freeze_text_encoder=True): |
|
if freeze_text_encoder: |
|
for param in self.text_encoder.parameters(): |
|
param.requires_grad = False |
|
self.text_encoder._requires_grad = False |
|
|
|
for param in self.audio_encoder.parameters(): |
|
param.requires_grad = False |
|
self.audio_encoder._requires_grad = False |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
generation_config: Optional[GenerationConfig] = None, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
synced_gpus: Optional[bool] = None, |
|
streamer: Optional["BaseStreamer"] = None, |
|
**kwargs, |
|
): |
|
""" |
|
|
|
Generates sequences of token ids for models with a language modeling head. |
|
|
|
<Tip warning={true}> |
|
|
|
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the |
|
model's default generation configuration. You can override any `generation_config` by passing the corresponding |
|
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. |
|
|
|
For an overview of generation strategies and code examples, check out the [following |
|
guide](./generation_strategies). |
|
|
|
</Tip> |
|
|
|
Parameters: |
|
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): |
|
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the |
|
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` |
|
should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of |
|
`input_ids`, `input_values`, `input_features`, or `pixel_values`. |
|
generation_config (`~generation.GenerationConfig`, *optional*): |
|
The generation configuration to be used as base parametrization for the generation call. `**kwargs` |
|
passed to generate matching the attributes of `generation_config` will override them. If |
|
                `generation_config` is not provided, the default will be used, which has the following loading
|
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model |
|
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s |
|
default values, whose documentation should be checked to parameterize generation. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
Custom logits processors that complement the default logits processors built from arguments and |
|
generation config. If a logit processor is passed that is already created with the arguments or a |
|
generation config an error is thrown. This feature is intended for advanced users. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
Custom stopping criteria that complement the default stopping criteria built from arguments and a |
|
                generation config. If a stopping criterion is passed that is already created with the arguments or a
|
generation config an error is thrown. This feature is intended for advanced users. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
streamer (`BaseStreamer`, *optional*): |
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed |
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing. |
|
kwargs (`Dict[str, Any]`, *optional*): |
|
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder |
|
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. |
|
|
|
Return: |
|
            [`~utils.ModelOutput`] or `torch.FloatTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. |
|
|
|
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation.GenerateDecoderOnlyOutput`], |
|
- [`~generation.GenerateBeamDecoderOnlyOutput`] |
|
|
|
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation.GenerateEncoderDecoderOutput`], |
|
- [`~generation.GenerateBeamEncoderDecoderOutput`] |
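
        Example (an illustrative sketch; the checkpoint name and the use of `AutoTokenizer` here are assumptions for
        demonstration, not requirements of this class):

        ```python
        >>> from transformers import AutoTokenizer
        >>> from parler_tts import ParlerTTSForConditionalGeneration

        >>> # checkpoint id shown for illustration only
        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/parler_tts-small")

        >>> # `input_ids` describe the voice, `prompt_input_ids` carry the transcript to be spoken
        >>> input_ids = tokenizer("A calm female voice with studio-quality recording.", return_tensors="pt").input_ids
        >>> prompt_input_ids = tokenizer("Hey, how are you doing today?", return_tensors="pt").input_ids

        >>> audio_values = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
        ```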
|
""" |
|
|
|
if generation_config is None: |
|
generation_config = self.generation_config |
|
|
|
generation_config = copy.deepcopy(generation_config) |
|
model_kwargs = generation_config.update(**kwargs) |
|
generation_config.validate() |
|
self._validate_model_kwargs(model_kwargs.copy()) |
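        # wrap legacy tuple-style encoder outputs into a `BaseModelOutput`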
|
|
|
if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple: |
|
|
|
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) |
|
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
|
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: |
|
if model_kwargs.get("attention_mask", None) is None: |
|
logger.warning( |
|
"The attention mask and the pad token id were not set. As a consequence, you may observe " |
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." |
|
) |
|
eos_token_id = generation_config.eos_token_id |
|
if isinstance(eos_token_id, list): |
|
eos_token_id = eos_token_id[0] |
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") |
|
generation_config.pad_token_id = eos_token_id |
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( |
|
inputs, generation_config.bos_token_id, model_kwargs |
|
) |
|
batch_size = inputs_tensor.shape[0] |
|
|
|
|
|
model_kwargs["use_cache"] = generation_config.use_cache |
|
model_kwargs["guidance_scale"] = generation_config.guidance_scale |
|
|
|
requires_attention_mask = "encoder_outputs" not in model_kwargs |
|
|
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: |
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( |
|
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id |
|
) |
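        # encode the text description once (unless encoder outputs were passed) and cache the result in model_kwargs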
|
|
|
if "encoder_outputs" not in model_kwargs: |
|
|
|
model_kwargs = self._prepare_text_encoder_kwargs_for_generation( |
|
inputs_tensor, |
|
model_kwargs, |
|
model_input_name, |
|
generation_config, |
|
) |
|
|
|
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs: |
|
|
|
model_kwargs = self._prepare_prompt_kwargs_for_generation( |
|
model_kwargs["prompt_input_ids"], |
|
model_kwargs, |
|
) |
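        # if raw audio is passed, encode it into codebook tokens that will be used as the decoder prompt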
|
|
|
if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: |
|
model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( |
|
model_kwargs["input_values"], |
|
model_kwargs, |
|
) |
|
|
|
|
|
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( |
|
batch_size=batch_size, |
|
model_input_name=model_input_name, |
|
model_kwargs=model_kwargs, |
|
decoder_start_token_id=generation_config.decoder_start_token_id, |
|
bos_token_id=generation_config.bos_token_id, |
|
device=inputs_tensor.device, |
|
) |
|
|
|
|
|
input_ids_seq_length = input_ids.shape[-1] |
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None |
|
if has_default_max_length and generation_config.max_new_tokens is None: |
|
logger.warning( |
|
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " |
|
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation." |
|
) |
|
elif generation_config.max_new_tokens is not None: |
|
if not has_default_max_length: |
|
logger.warning( |
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" |
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " |
|
"Please refer to the documentation for more information. " |
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" |
|
) |
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length |
|
|
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: |
|
raise ValueError( |
|
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" |
|
f" the maximum length ({generation_config.max_length})" |
|
) |
|
if input_ids_seq_length >= generation_config.max_length: |
|
logger.warning( |
|
f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" |
|
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" |
|
" increasing `max_new_tokens`." |
|
) |
|
|
|
|
|
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( |
|
input_ids, |
|
bos_token_id=generation_config.bos_token_id, |
|
pad_token_id=generation_config.pad_token_id, |
|
max_length=generation_config.max_length, |
|
) |
|
|
|
model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask |
|
|
|
|
|
if streamer is not None: |
|
streamer.put(input_ids.cpu()) |
|
|
|
|
|
is_greedy_gen_mode = ( |
|
(generation_config.num_beams == 1) |
|
and (generation_config.num_beam_groups == 1) |
|
and generation_config.do_sample is False |
|
) |
|
is_sample_gen_mode = ( |
|
(generation_config.num_beams == 1) |
|
and (generation_config.num_beam_groups == 1) |
|
and generation_config.do_sample is True |
|
) |
|
|
|
|
|
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: |
|
logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) |
|
generation_config.guidance_scale = None |
|
|
|
|
|
logits_processor = self._get_logits_processor( |
|
generation_config=generation_config, |
|
input_ids_seq_length=input_ids_seq_length, |
|
encoder_input_ids=inputs_tensor, |
|
prefix_allowed_tokens_fn=None, |
|
logits_processor=logits_processor, |
|
) |
|
|
|
|
|
stopping_criteria = self._get_stopping_criteria( |
|
generation_config=generation_config, stopping_criteria=stopping_criteria |
|
) |
|
|
|
if is_greedy_gen_mode: |
|
if generation_config.num_return_sequences > 1: |
|
raise ValueError( |
|
"num_return_sequences has to be 1 when doing greedy search, " |
|
f"but is {generation_config.num_return_sequences}." |
|
) |
|
|
|
|
|
outputs = self._greedy_search( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
generation_config=generation_config, |
|
synced_gpus=synced_gpus, |
|
streamer=streamer, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_sample_gen_mode: |
|
|
|
logits_warper = self._get_logits_warper(generation_config) |
|
|
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids=input_ids, |
|
expand_size=generation_config.num_return_sequences, |
|
is_encoder_decoder=self.config.is_encoder_decoder, |
|
**model_kwargs, |
|
) |
|
|
|
|
|
outputs = self._sample( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
logits_warper=logits_warper, |
|
stopping_criteria=stopping_criteria, |
|
generation_config=generation_config, |
|
synced_gpus=synced_gpus, |
|
streamer=streamer, |
|
**model_kwargs, |
|
) |
|
|
|
else: |
|
raise ValueError( |
|
"Got incompatible mode for generation, should be one of greedy or sampling. " |
|
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." |
|
) |
|
|
|
if generation_config.return_dict_in_generate: |
|
output_ids = outputs.sequences |
|
else: |
|
output_ids = outputs |
|
|
|
|
|
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) |
|
|
|
|
|
_, mask = self.decoder.build_delay_pattern_mask( |
|
input_ids, |
|
bos_token_id=generation_config.bos_token_id, |
|
pad_token_id=generation_config.pad_token_id, |
|
max_length=output_ids.shape[1], |
|
) |
|
|
|
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id) |
|
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1) |
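        # append the frame dimension expected by the audio codec: (bsz, num_codebooks, seq_len) -> (1, bsz, num_codebooks, seq_len)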
|
|
|
|
|
output_ids = output_ids[None, ...] |
|
|
|
audio_scales = model_kwargs.get("audio_scales") |
|
if audio_scales is None: |
|
audio_scales = [None] * batch_size |
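        # samples that still contain BOS/PAD/EOS tokens have ragged lengths and must be decoded one at a time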
|
|
|
decode_sequentially = ( |
|
generation_config.bos_token_id in output_ids |
|
or generation_config.pad_token_id in output_ids |
|
or generation_config.eos_token_id in output_ids |
|
) |
|
if not decode_sequentially: |
|
output_values = self.audio_encoder.decode( |
|
output_ids, |
|
audio_scales=audio_scales, |
|
).audio_values.squeeze(1) |
|
else: |
|
output_values = [] |
|
for sample_id in range(batch_size): |
|
sample = output_ids[:, sample_id] |
|
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0 |
|
if sample_mask.sum() > 0: |
|
sample = sample[:, :, sample_mask] |
|
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values |
|
output_values.append(sample.transpose(0, 2)) |
|
else: |
|
output_values.append(torch.zeros((1, 1, 1)).to(self.device)) |
|
|
|
output_values = ( |
|
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0) |
|
.squeeze(-1) |
|
.squeeze(-1) |
|
) |
|
|
|
if generation_config.return_dict_in_generate: |
|
outputs.sequences = output_values |
|
return outputs |
|
else: |
|
return output_values |
|
|