kazemnejad committed on
Commit cb16e0e
1 Parent(s): 4d9169a

Upload CustomDecoderOnlyT5

config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "CustomDecoderOnlyT5"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_custom_t5.CustomT5Config",
+     "AutoModelForCausalLM": "modeling_custom_t5.CustomDecoderOnlyT5"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 16384,
+   "d_kv": 128,
+   "d_model": 1024,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 1.0,
+   "is_decoder": true,
+   "is_encoder_decoder": false,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "custom_decoder_only_t5",
+   "n_positions": 1024,
+   "num_decoder_layers": 24,
+   "num_heads": 32,
+   "num_layers": 24,
+   "output_past": true,
+   "pad_token_id": 0,
+   "position_encoding_type": "alibi",
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.31.0",
+   "use_cache": true,
+   "vocab_size": 49152
+ }
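
The `auto_map` entries point `AutoConfig` and `AutoModelForCausalLM` at the Python files shipped in this repository, so the checkpoint can only be loaded with `trust_remote_code=True`. A minimal loading sketch, assuming a placeholder repo id (substitute the actual Hub path of this model):

```python
# Minimal loading sketch (untested). The repo id is a placeholder; replace it
# with the actual Hub path. trust_remote_code=True is required because
# "auto_map" above points at the custom classes in this repository.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<user>/<this-repo>"  # placeholder, not a real model id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.position_encoding_type)  # "alibi" for this checkpoint

model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
```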
configuration_custom_t5.py ADDED
@@ -0,0 +1,40 @@
+ from transformers import T5Config
+
+ POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
+ POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"
+ POSITION_ENCODING_ROTARY = "rotary"
+ POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"
+ POSITION_ENCODING_ROTARY_NEW = "new_rotary"
+ POSITION_ENCODING_ABS_LEARNED = "abs_learned"
+ POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"
+ POSITION_ENCODING_ALiBi = "alibi"
+ POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"
+ POSITION_ENCODING_NONE = "none"
+ POSITION_ENCODING_NONE_WINDOW = "none_window"
+
+
+ class CustomT5Config(T5Config):
+     model_type = "custom_decoder_only_t5"
+
+     def __init__(
+         self,
+         position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
+         **kwargs,
+     ):
+         if position_encoding_type not in [
+             POSITION_ENCODING_ALiBi,
+             POSITION_ENCODING_ALiBi_LEARNED,
+             POSITION_ENCODING_ABS_LEARNED,
+             POSITION_ENCODING_ABS_SINUSOID,
+             POSITION_ENCODING_REL_T5_BIAS,
+             POSITION_ENCODING_REL_TRANSFORMER_XL,
+             POSITION_ENCODING_ROTARY,
+             POSITION_ENCODING_ROTARY_NEW,
+             POSITION_ENCODING_NONE,
+             POSITION_ENCODING_NONE_WINDOW,
+         ]:
+             raise ValueError(
+                 f"Invalid position_encoding_type: {position_encoding_type}"
+             )
+         self.position_encoding_type = position_encoding_type
+         super().__init__(**kwargs)
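
The constructor above only validates `position_encoding_type` against the supported set and stores it before delegating everything else to `T5Config`. A small sketch of that behaviour, assuming the file is importable as `configuration_custom_t5`; note that `"rotary_rerun"` is defined as a constant but is absent from the accepted list, so it is rejected:

```python
# Sketch of the validation above (assumes this file is importable as
# configuration_custom_t5).
from configuration_custom_t5 import CustomT5Config, POSITION_ENCODING_ALiBi

cfg = CustomT5Config(position_encoding_type=POSITION_ENCODING_ALiBi, d_model=1024)
assert cfg.position_encoding_type == "alibi"

try:
    CustomT5Config(position_encoding_type="rotary_rerun")  # constant exists, but not whitelisted
except ValueError as err:
    print(err)  # Invalid position_encoding_type: rotary_rerun
```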
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.31.0"
+ }
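
These defaults only pin the special token ids; no sampling or beam parameters are set, so `generate()` falls back to greedy decoding unless overridden at call time. A tiny sketch of the equivalent `GenerationConfig`, built in code from the values above:

```python
# Equivalent GenerationConfig, with the values from this file. No sampling or
# beam options are set, so generation defaults to greedy decoding.
from transformers import GenerationConfig

gen_cfg = GenerationConfig(decoder_start_token_id=0, eos_token_id=1, pad_token_id=0)
print(gen_cfg.do_sample, gen_cfg.num_beams)  # False 1
```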
modeling_custom_t5.py ADDED
@@ -0,0 +1,1416 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import CrossEntropyLoss
11
+ from torch.utils.checkpoint import checkpoint
12
+ from transformers import T5Config
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPastAndCrossAttentions,
15
+ )
16
+ from transformers.utils import ModelOutput
17
+ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
18
+
19
+ from .configuration_custom_t5 import (
20
+ POSITION_ENCODING_REL_T5_BIAS,
21
+ POSITION_ENCODING_REL_TRANSFORMER_XL,
22
+ POSITION_ENCODING_ROTARY,
23
+ POSITION_ENCODING_ROTARY_NEW,
24
+ POSITION_ENCODING_ABS_LEARNED,
25
+ POSITION_ENCODING_ABS_SINUSOID,
26
+ POSITION_ENCODING_ALiBi,
27
+ POSITION_ENCODING_ALiBi_LEARNED,
28
+ POSITION_ENCODING_NONE,
29
+ POSITION_ENCODING_NONE_WINDOW,
30
+ CustomT5Config,
31
+ )
32
+ from .modeling_t5 import (
33
+ T5Stack,
34
+ T5PreTrainedModel,
35
+ T5Block,
36
+ T5LayerNorm,
37
+ T5LayerFF,
38
+ T5LayerSelfAttention,
39
+ T5Attention,
40
+ T5LayerCrossAttention,
41
+ )
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ @dataclass
47
+ class CausalLMOutputWithPastAndLoss(ModelOutput):
48
+ """
49
+ Base class for causal language model (or autoregressive) outputs.
50
+
51
+ Args:
52
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
53
+ Language modeling loss (for next-token prediction).
54
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
55
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
56
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
57
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
58
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
59
+
60
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
61
+ `past_key_values` input) to speed up sequential decoding.
62
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
63
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
64
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
65
+
66
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
67
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
68
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
69
+ sequence_length)`.
70
+
71
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
72
+ heads.
73
+ """
74
+
75
+ loss: Optional[torch.FloatTensor] = None
76
+ logits: torch.FloatTensor = None
77
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
78
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
79
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
80
+ non_reduced_loss: Optional[torch.FloatTensor] = None
81
+
82
+
83
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
84
+ dim = x.shape[-1]
85
+ if seq_len is None:
86
+ seq_len = x.shape[seq_dim]
87
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
88
+ sinusoid_inp = (
89
+ torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq)
90
+ .to(x.device)
91
+ .float()
92
+ )
93
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
94
+
95
+
96
+ def rotate_every_two(x):
97
+ """
98
+ Example: [a, b, c, d] -> [-b, a, -d, c]
99
+ """
100
+ x1 = x[:, :, :, ::2]
101
+ x2 = x[:, :, :, 1::2]
102
+ x = torch.stack((-x2, x1), axis=-1)
103
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
104
+
105
+
106
+ def apply_rotary_pos_emb(x, sincos, offset=0):
107
+ sin, cos = map(
108
+ lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(
109
+ 2, 3
110
+ ),
111
+ sincos,
112
+ )
113
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
114
+ return (x * cos) + (rotate_every_two(x) * sin)
115
+
116
+
117
+ def apply_rotary_pos_emb_new(x, sincos, offset=0):
118
+ sin, cos = map(
119
+ lambda t: t[:, :, None, :].repeat_interleave(2, 3),
120
+ sincos,
121
+ )
122
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
123
+ return (x * cos) + (rotate_every_two(x) * sin)
124
+
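
`rotate_every_two` maps each adjacent channel pair `(a, b)` to `(-b, a)`, and `apply_rotary_pos_emb` combines that with the per-position sin/cos table to rotate every pair by a position-dependent angle. A standalone numeric sketch (the helpers are restated in a slightly condensed form so the snippet runs on its own), checking that the rotation changes directions but preserves per-vector norms:

```python
# Standalone check of the rotary helpers above, for a (batch, seq, heads, head_dim) tensor.
import torch


def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)  # [a, b, c, d] -> [-b, a, -d, c]


def fixed_pos_embedding(seq_len, dim):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    sinusoid_inp = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


def apply_rotary_pos_emb(x, sincos, offset=0):
    sin, cos = (t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3) for t in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)


x = torch.randn(1, 5, 2, 8)  # (batch, seq, heads, head_dim)
x_rot = apply_rotary_pos_emb(x, fixed_pos_embedding(seq_len=5, dim=8))
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # True
```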
125
+
126
+ class PositionalEmbedding(nn.Module):
127
+ def __init__(self, demb):
128
+ super().__init__()
129
+
130
+ self.demb = demb
131
+
132
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
133
+ self.register_buffer("inv_freq", inv_freq)
134
+
135
+ def forward(self, pos_seq, bsz=None):
136
+ sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
137
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
138
+
139
+ if bsz is not None:
140
+ return pos_emb[None, :, :].expand(bsz, -1, -1)
141
+ else:
142
+ return pos_emb[None, :, :]
143
+
144
+
145
+ class FixedAbsolutePositionalEmbedding(nn.Module):
146
+ def __init__(self, dim):
147
+ super().__init__()
148
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
149
+ t = torch.arange(16384).type_as(inv_freq)
150
+ sinusoid_inp = torch.einsum("i , j -> i j", t, inv_freq)
151
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
152
+ self.embed = nn.Embedding.from_pretrained(emb, freeze=True)
153
+
154
+ def forward(self, position_ids: torch.Tensor):
155
+ return self.embed(position_ids.long())
156
+
157
+
158
+ class FixedRotaryPositionalEmbedding(nn.Module):
159
+ def __init__(
160
+ self, rotary_dim: int, rotary_base: int = 10000, max_position: int = 16384
161
+ ):
162
+ super().__init__()
163
+ # This is an inverse frequency tensor
164
+ # Each dimension has a higher denominator than the previous one
165
+ # So, the frequency will be lower for higher dimensions
166
+ inv_freq = 1.0 / (
167
+ rotary_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
168
+ ) # [rotary_dim/2]
169
+
170
+ # Now, we create frequencies for each position
171
+ t = torch.arange(max_position, device=inv_freq.device, dtype=inv_freq.dtype)
172
+ freqs = torch.einsum("i,j->ij", t, inv_freq) # [max_position, rotary_dim/2]
173
+
174
+ sins = torch.sin(freqs)
175
+ coss = torch.cos(freqs)
176
+
177
+ emb = torch.cat([sins, coss], dim=-1) # [max_position, rotary_dim]
178
+ self.embed = nn.Embedding.from_pretrained(emb, freeze=True)
179
+
180
+ def forward(self, position_ids: torch.Tensor):
181
+ return self.embed(position_ids.long())
182
+
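
`FixedRotaryPositionalEmbedding` precomputes a frozen `[sin | cos]` lookup table indexed by position id; the `new_rotary` attention branch later splits that table in half to recover `sin` and `cos`. A standalone sketch of the same construction, using `rotary_dim = d_kv // 4 = 32` as implied by this checkpoint's config:

```python
# Standalone sketch of the frozen [sin | cos] table built above and of the
# split performed later in the "new_rotary" attention branch.
import torch

rotary_dim, max_position = 32, 4096  # 128 // 4 = 32 for this config
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(max_position, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)           # (max_position, rotary_dim / 2)
table = torch.cat([freqs.sin(), freqs.cos()], dim=-1)  # (max_position, rotary_dim)

position_ids = torch.arange(10)[None, :]               # (batch=1, seq=10)
sincos = table[position_ids]                           # what forward() returns
sin = sincos[:, :, : rotary_dim // 2]                  # first half: sines
cos = sincos[:, :, rotary_dim // 2 :]                  # second half: cosines
print(sincos.shape, sin.shape, cos.shape)              # (1, 10, 32) (1, 10, 16) (1, 10, 16)
```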
183
+
184
+ class CustomT5Attention(T5Attention):
185
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
186
+ super(T5Attention, self).__init__()
187
+ self.is_decoder = config.is_decoder
188
+ self.has_relative_attention_bias = has_relative_attention_bias
189
+
190
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
191
+ self.d_model = config.d_model
192
+ self.key_value_proj_dim = config.d_kv
193
+ self.d_head = config.d_kv
194
+ self.n_heads = config.num_heads
195
+ self.dropout = config.dropout_rate
196
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
197
+
198
+ # Mesh TensorFlow initialization to avoid scaling before softmax
199
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
200
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
201
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
202
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
203
+
204
+ self.position_encoding_type = getattr(
205
+ config, "position_encoding_type", POSITION_ENCODING_REL_T5_BIAS
206
+ )
207
+
208
+ if self.has_relative_attention_bias:
209
+ self.relative_attention_bias = nn.Embedding(
210
+ self.relative_attention_num_buckets, self.n_heads
211
+ )
212
+
213
+ if self.position_encoding_type == POSITION_ENCODING_REL_TRANSFORMER_XL:
214
+ self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head))
215
+ nn.init.normal_(
216
+ self.r_r_bias, mean=0.0, std=config.initializer_factor * 0.2
217
+ )
218
+ self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head))
219
+ nn.init.normal_(
220
+ self.r_w_bias, mean=0.0, std=config.initializer_factor * 0.2
221
+ )
222
+ self.r = nn.Linear(self.d_model, self.n_heads * self.d_head, bias=False)
223
+ self.r.weight.data.normal_(
224
+ mean=0.0, std=config.initializer_factor * (self.d_model**-0.5)
225
+ )
226
+ self.pos_emb = PositionalEmbedding(self.d_model)
227
+ self.clamp_length = 1000
228
+
229
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY:
230
+ self.rotary_dim = None
231
+ if getattr(config, "rotary_dim", None) is not None:
232
+ self.rotary_dim = config.rotary_dim
233
+ self.rotary_dim = int(0.25 * self.d_head)
234
+
235
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
236
+ # We hardcode the rotary dim to 25 percent of the head dim
237
+ self.rotary_dim = self.d_head // 4
238
+
239
+ self.pruned_heads = set()
240
+ self.gradient_checkpointing = False
241
+
242
+ def _rel_shift(self, x):
243
+ zero_pad_shape = x.size()[:2] + (x.size(2), 1)
244
+ zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
245
+ x_padded = torch.cat([zero_pad, x], dim=3)
246
+ x_padded_shape = x.size()[:2] + (x.size(3) + 1, x.size(2))
247
+ x_padded = x_padded.view(*x_padded_shape)
248
+ x = x_padded[:, :, 1:, :].view_as(x)
249
+ return x
250
+
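
`_rel_shift` is the Transformer-XL zero-pad/reshape trick: the input is indexed by (query position, relative position) with the most distant relative position first, and after the shift entry `[i, j]` (for `j <= i`) holds the score of query `i` attending to key `j`; the upper-triangular leftovers are discarded by the causal mask. A small standalone check:

```python
# Standalone numeric check of the zero-pad/reshape trick above.
import torch


def rel_shift(x):
    b, n, q, k = x.shape
    zero_pad = torch.zeros((b, n, q, 1), dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=3).view(b, n, k + 1, q)
    return x_padded[:, :, 1:, :].view_as(x)


scores = torch.arange(1, 10, dtype=torch.float32).view(1, 1, 3, 3)
print(rel_shift(scores)[0, 0])
# tensor([[3., 0., 4.],
#         [5., 6., 0.],
#         [7., 8., 9.]])
# e.g. for the last query (row 2) the scores [7, 8, 9] now line up with
# keys 0, 1, 2 (relative distances 2, 1, 0); entries above the diagonal
# are junk that the causal mask removes.
```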
251
+ def forward(
252
+ self,
253
+ hidden_states,
254
+ mask=None,
255
+ position_bias=None,
256
+ key_value_states=None,
257
+ past_key_value=None,
258
+ layer_head_mask=None,
259
+ query_length=None,
260
+ use_cache=False,
261
+ output_attentions=False,
262
+ ):
263
+ """
264
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
265
+ """
266
+ # Input is (batch_size, seq_length, dim)
267
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
268
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
269
+ batch_size, seq_length = hidden_states.shape[:2]
270
+
271
+ real_seq_length = seq_length
272
+
273
+ if past_key_value is not None:
274
+ assert (
275
+ len(past_key_value) == 2
276
+ ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
277
+ real_seq_length += (
278
+ past_key_value[0].shape[2] if query_length is None else query_length
279
+ )
280
+
281
+ key_length = (
282
+ real_seq_length if key_value_states is None else key_value_states.shape[1]
283
+ )
284
+
285
+ def shape(states):
286
+ """projection"""
287
+ return states.view(
288
+ batch_size, -1, self.n_heads, self.key_value_proj_dim
289
+ ).transpose(1, 2)
290
+
291
+ def unshape(states):
292
+ """reshape"""
293
+ return (
294
+ states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
295
+ )
296
+
297
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
298
+ """projects hidden states correctly to key/query states"""
299
+ if key_value_states is None:
300
+ # self-attn
301
+ # (batch_size, n_heads, seq_length, dim_per_head)
302
+ hidden_states = shape(proj_layer(hidden_states))
303
+ elif past_key_value is None:
304
+ # cross-attn
305
+ # (batch_size, n_heads, seq_length, dim_per_head)
306
+ hidden_states = shape(proj_layer(key_value_states))
307
+
308
+ if past_key_value is not None:
309
+ if key_value_states is None:
310
+ # self-attn
311
+ # (batch_size, n_heads, key_length, dim_per_head)
312
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
313
+ else:
314
+ # cross-attn
315
+ hidden_states = past_key_value
316
+ return hidden_states
317
+
318
+ # get query states
319
+ query_states = shape(
320
+ self.q(hidden_states)
321
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
322
+
323
+ if self.position_encoding_type in [
324
+ POSITION_ENCODING_ROTARY,
325
+ POSITION_ENCODING_ROTARY_NEW,
326
+ ]:
327
+ key_states = shape(self.k(hidden_states))
328
+ else:
329
+ # get key/value states
330
+ key_states = project(
331
+ hidden_states,
332
+ self.k,
333
+ key_value_states,
334
+ past_key_value[0] if past_key_value is not None else None,
335
+ )
336
+
337
+ value_states = project(
338
+ hidden_states,
339
+ self.v,
340
+ key_value_states,
341
+ past_key_value[1] if past_key_value is not None else None,
342
+ )
343
+
344
+ attention_output_dict = {}
345
+
346
+ if self.position_encoding_type == POSITION_ENCODING_REL_T5_BIAS:
347
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
348
+ attention_output_dict["scores_before"] = scores
349
+ if position_bias is None:
350
+ if not self.has_relative_attention_bias:
351
+ position_bias = torch.zeros(
352
+ (1, self.n_heads, real_seq_length, key_length),
353
+ device=scores.device,
354
+ dtype=scores.dtype,
355
+ )
356
+ if self.gradient_checkpointing and self.training:
357
+ position_bias.requires_grad = True
358
+ else:
359
+ position_bias = self.compute_bias(real_seq_length, key_length)
360
+
361
+ # if key and values are already calculated
362
+ # we want only the last query position bias
363
+ if past_key_value is not None:
364
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
365
+
366
+ if mask is not None:
367
+ position_bias = (
368
+ position_bias + mask
369
+ ) # (batch_size, n_heads, seq_length, key_length)
370
+
371
+ scores += position_bias
372
+ elif self.position_encoding_type == POSITION_ENCODING_REL_TRANSFORMER_XL:
373
+ if position_bias is None:
374
+ pos_seq = torch.arange(
375
+ real_seq_length - 1,
376
+ -1,
377
+ -1.0,
378
+ device=hidden_states.device,
379
+ dtype=hidden_states.dtype,
380
+ )
381
+ if self.clamp_length > 0:
382
+ pos_seq = pos_seq.clamp_(max=self.clamp_length)
383
+ position_bias = self.pos_emb(pos_seq)
384
+ position_bias = nn.functional.dropout(
385
+ position_bias, p=self.dropout, training=self.training
386
+ )
387
+
388
+ position_embeds = position_bias # position embeds: [1, seq_len, d_model]
389
+
390
+ r_head_k = self.r(position_embeds) # [1, seq_len, n_head*d_head]
391
+ r_head_k = r_head_k.view(
392
+ position_embeds.shape[1], self.n_heads, self.d_head
393
+ ) # [seq_len, n_head, d_head]
394
+
395
+ rw_head_q = query_states + self.r_w_bias[None, :, None, :]
396
+ AC = torch.einsum("bnqd,bnkd->bnqk", (rw_head_q, key_states))
397
+
398
+ rr_head_q = query_states + self.r_r_bias[None, :, None, :]
399
+ BD = torch.einsum("bnid,jnd->bnij", (rr_head_q, r_head_k))
400
+ BD = self._rel_shift(BD)
401
+
402
+ scores = AC + BD
403
+
404
+ if mask is not None:
405
+ scores += mask
406
+ elif self.position_encoding_type == POSITION_ENCODING_ROTARY:
407
+ r_seq_len = hidden_states.shape[1]
408
+ r_offset = 0
409
+
410
+ if past_key_value is not None:
411
+ r_offset = past_key_value[0].shape[2]
412
+ r_seq_len += r_offset
413
+
414
+ query_states = query_states.permute(0, 2, 1, 3)
415
+ key_states = key_states.permute(0, 2, 1, 3)
416
+
417
+ if self.rotary_dim is not None:
418
+ k_rot = key_states[:, :, :, : self.rotary_dim]
419
+ k_pass = key_states[:, :, :, self.rotary_dim :]
420
+
421
+ q_rot = query_states[:, :, :, : self.rotary_dim]
422
+ q_pass = query_states[:, :, :, self.rotary_dim :]
423
+
424
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=r_seq_len)
425
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=r_offset)
426
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=r_offset)
427
+
428
+ if output_attentions:
429
+ scores_pass = torch.matmul(
430
+ q_pass.permute(0, 2, 1, 3),
431
+ k_pass.permute(0, 2, 1, 3).transpose(3, 2),
432
+ )
433
+ attention_output_dict["scores_pass"] = scores_pass
434
+
435
+ scores_rot = torch.matmul(
436
+ q_rot.permute(0, 2, 1, 3),
437
+ k_rot.permute(0, 2, 1, 3).transpose(3, 2),
438
+ )
439
+ attention_output_dict["scores_rot"] = scores_rot
440
+
441
+ key_states = torch.cat([k_rot, k_pass], dim=-1)
442
+ query_states = torch.cat([q_rot, q_pass], dim=-1)
443
+ else:
444
+ sincos = fixed_pos_embedding(key_states, 1, seq_len=r_seq_len)
445
+ key_states = apply_rotary_pos_emb(key_states, sincos, offset=r_offset)
446
+ query_states = apply_rotary_pos_emb(
447
+ query_states, sincos, offset=r_offset
448
+ )
449
+
450
+ query_states = query_states.permute(0, 2, 1, 3)
451
+ key_states = key_states.permute(0, 2, 1, 3)
452
+
453
+ if past_key_value is not None:
454
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
455
+
456
+ scores = torch.matmul(
457
+ query_states, key_states.transpose(3, 2)
458
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
459
+ if mask is not None:
460
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
461
+
462
+ elif self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
463
+ r_seq_len = hidden_states.shape[1]
464
+ r_offset = 0
465
+
466
+ if past_key_value is not None:
467
+ r_offset = past_key_value[0].shape[2]
468
+ r_seq_len += r_offset
469
+
470
+ query_states = query_states.permute(0, 2, 1, 3)
471
+ key_states = key_states.permute(0, 2, 1, 3)
472
+
473
+ if self.rotary_dim is not None:
474
+ k_rot = key_states[:, :, :, : self.rotary_dim]
475
+ k_pass = key_states[:, :, :, self.rotary_dim :]
476
+
477
+ q_rot = query_states[:, :, :, : self.rotary_dim]
478
+ q_pass = query_states[:, :, :, self.rotary_dim :]
479
+
480
+ sincos = position_bias
481
+ # sincos is just a vector created by torch.cat([sin, cos], dim=-1)
482
+ # so we can just split it in half
483
+ sin = sincos[:, :, : self.rotary_dim // 2]
484
+ cos = sincos[:, :, self.rotary_dim // 2 :]
485
+
486
+ # We don't need to pass offset here, because we already used
487
+ # position_ids to retrieve correct sin and cos vectors
488
+ k_rot = apply_rotary_pos_emb_new(k_rot, (sin, cos))
489
+ q_rot = apply_rotary_pos_emb_new(q_rot, (sin, cos))
490
+
491
+ key_states = torch.cat([k_rot, k_pass], dim=-1)
492
+ query_states = torch.cat([q_rot, q_pass], dim=-1)
493
+ else:
494
+ raise ValueError("rotary_dim is None")
495
+
496
+ query_states = query_states.permute(0, 2, 1, 3)
497
+ key_states = key_states.permute(0, 2, 1, 3)
498
+
499
+ if past_key_value is not None:
500
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
501
+
502
+ scores = torch.matmul(
503
+ query_states, key_states.transpose(3, 2)
504
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
505
+ if mask is not None:
506
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
507
+ elif self.position_encoding_type == POSITION_ENCODING_ALiBi:
508
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
509
+ attention_output_dict["scores_before"] = scores
510
+
511
+ alibi = position_bias
512
+ alibi = alibi.view(batch_size, self.n_heads, 1, key_length)
513
+
514
+ # if key and values are already calculated
515
+ # we want only the last query position bias
516
+ if past_key_value is not None:
517
+ alibi = alibi[:, :, -hidden_states.size(1) :, :]
518
+
519
+ if mask is not None:
520
+ alibi = alibi + mask # (batch_size, n_heads, seq_length, key_length)
521
+
522
+ scores += alibi
523
+ else:
524
+ assert (
525
+ self.position_encoding_type == POSITION_ENCODING_NONE
526
+ ), f"Unknown position encoding type: {self.position_encoding_type}"
527
+ scores = torch.matmul(
528
+ query_states, key_states.transpose(3, 2)
529
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
530
+ if mask is not None:
531
+ scores += mask # (batch_size, n_heads, seq_length, key_length)
532
+
533
+ attention_output_dict["scores"] = scores
534
+
535
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
536
+ scores
537
+ ) # (batch_size, n_heads, seq_length, key_length)
538
+ attn_weights = nn.functional.dropout(
539
+ attn_weights, p=self.dropout, training=self.training
540
+ ) # (batch_size, n_heads, seq_length, key_length)
541
+
542
+ # Mask heads if we want to
543
+ if layer_head_mask is not None:
544
+ attn_weights = attn_weights * layer_head_mask
545
+
546
+ attention_output_dict["probs"] = attn_weights
547
+
548
+ attn_output = unshape(
549
+ torch.matmul(attn_weights, value_states)
550
+ ) # (batch_size, seq_length, dim)
551
+ attn_output = self.o(attn_output)
552
+
553
+ present_key_value_state = (
554
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
555
+ )
556
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
557
+
558
+ if output_attentions:
559
+ outputs = outputs + (attention_output_dict,)
560
+ return outputs
561
+
562
+
563
+ class CustomT5LayerSelfAttention(T5LayerSelfAttention):
564
+ def __init__(self, config, has_relative_attention_bias=False):
565
+ super(T5LayerSelfAttention, self).__init__()
566
+ self.SelfAttention = CustomT5Attention(
567
+ config, has_relative_attention_bias=has_relative_attention_bias
568
+ )
569
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
570
+ self.dropout = nn.Dropout(config.dropout_rate)
571
+
572
+
573
+ class CustomT5Block(T5Block):
574
+ def __init__(self, config, has_relative_attention_bias=False):
575
+ super(T5Block, self).__init__()
576
+ self.is_decoder = config.is_decoder
577
+ assert self.is_decoder
578
+ self.layer = nn.ModuleList()
579
+ self.layer.append(
580
+ CustomT5LayerSelfAttention(
581
+ config, has_relative_attention_bias=has_relative_attention_bias
582
+ )
583
+ )
584
+ if self.is_decoder:
585
+ self.layer.append(T5LayerCrossAttention(config))
586
+
587
+ self.layer.append(T5LayerFF(config))
588
+
589
+
590
+ def _make_causal_mask(
591
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
592
+ ) -> torch.BoolTensor:
593
+ """
594
+ Make causal mask used for self-attention.
595
+ """
596
+ batch_size, target_length = input_ids_shape
597
+ mask = torch.empty(
598
+ (target_length, target_length + past_key_values_length),
599
+ dtype=torch.bool,
600
+ device=device,
601
+ )
602
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
603
+ seq_ids = torch.arange(target_length, device=device)
604
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
605
+
606
+ if past_key_values_length > 0:
607
+ mask[:, :past_key_values_length] = False
608
+
609
+ expanded_mask = mask[None, None, :, :].expand(
610
+ batch_size, 1, target_length, target_length + past_key_values_length
611
+ )
612
+ return expanded_mask
613
+
614
+
615
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
616
+ """
617
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
618
+ """
619
+ batch_size, src_length = mask.shape
620
+ tgt_length = tgt_length if tgt_length is not None else src_length
621
+
622
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
623
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
624
+
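
`_make_causal_mask` and `_expand_mask` both use the boolean convention that `True` means "blocked"; `_alibi_prepare_attn_mask` further down ORs the causal part with the padding part. A standalone sketch of that combination:

```python
# Standalone check of the boolean convention used above: True means "masked out".
import torch

seq_ids = torch.arange(4)
causal = seq_ids[None, None, :, None] < seq_ids[None, None, None, :]  # (1, 1, 4, 4)

attention_mask = torch.tensor([[1, 1, 1, 0]])       # last token is padding
padding = ~attention_mask[:, None, None, :].bool()  # (1, 1, 1, 4)

combined = causal | padding                         # broadcasts to (1, 1, 4, 4)
print(combined[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 1]], dtype=torch.int32)
```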
625
+
626
+ def build_alibi_tensor(
627
+ attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
628
+ ) -> torch.Tensor:
629
+ """
630
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
631
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
632
+ `softmax(l+a) = softmax(l)`. Based on
633
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
634
+ TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
635
+ Args:
636
+ Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
637
+ attention_mask (`torch.Tensor`):
638
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
639
+ num_heads (`int`, *required*):
640
+ number of heads
641
+ dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
642
+ dtype of the output tensor
643
+ """
644
+ if len(attention_mask.shape) == 2:
645
+ batch_size, seq_length = attention_mask.shape
646
+ elif len(attention_mask.shape) == 3:
647
+ batch_size, _, seq_length = attention_mask.shape
648
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
649
+ base = torch.tensor(
650
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
651
+ device=attention_mask.device,
652
+ dtype=torch.float32,
653
+ )
654
+ powers = torch.arange(
655
+ 1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
656
+ )
657
+ slopes = torch.pow(base, powers)
658
+
659
+ if closest_power_of_2 != num_heads:
660
+ extra_base = torch.tensor(
661
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
662
+ device=attention_mask.device,
663
+ dtype=torch.float32,
664
+ )
665
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
666
+ extra_powers = torch.arange(
667
+ 1,
668
+ 1 + 2 * num_remaining_heads,
669
+ 2,
670
+ device=attention_mask.device,
671
+ dtype=torch.int32,
672
+ )
673
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
674
+
675
+ # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
676
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
677
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
678
+ # => the query_length dimension will then be broadcasted correctly
679
+ # This is more or less identical to T5's relative position bias:
680
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
681
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
682
+ alibi = slopes[..., None] * arange_tensor
683
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
684
+
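
`build_alibi_tensor` assigns each head a slope from a geometric sequence (as in the ALiBi paper) and multiplies it by the mask-aware token index. A standalone sketch printing the slopes for a power-of-two head count (8 heads for a compact printout; the config above uses 32):

```python
# Standalone sketch of the per-head ALiBi slopes computed above.
import math
import torch

num_heads = 8
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))  # 0.5 for 8 heads
slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
print(slopes)  # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

attention_mask = torch.ones(1, 5)  # batch 1, seq_len 5, no padding
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor  # (1, num_heads, seq_len) before the final reshape
print(alibi.shape)                         # torch.Size([1, 8, 5])
```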
685
+
686
+ class CustomT5Stack(T5Stack):
687
+ def __init__(self, config, embed_tokens=None):
688
+ super(T5Stack, self).__init__(config)
689
+
690
+ self.embed_tokens = embed_tokens
691
+ self.is_decoder = config.is_decoder
692
+ self.position_encoding_type = getattr(
693
+ config, "position_encoding_type", POSITION_ENCODING_REL_T5_BIAS
694
+ )
695
+
696
+ logger.info(f"position_encoding_type: {self.position_encoding_type}")
697
+
698
+ self.block = nn.ModuleList(
699
+ [
700
+ CustomT5Block(config, has_relative_attention_bias=bool(i == 0))
701
+ for i in range(config.num_layers)
702
+ ]
703
+ )
704
+ self.final_layer_norm = T5LayerNorm(
705
+ config.d_model, eps=config.layer_norm_epsilon
706
+ )
707
+ self.dropout = nn.Dropout(config.dropout_rate)
708
+
709
+ if self.position_encoding_type == POSITION_ENCODING_ABS_LEARNED:
710
+ self.wpe = nn.Embedding(2048, config.d_model)
711
+ parent_dir = Path(os.path.dirname(os.path.abspath(__file__)))
712
+ learned_embed_file = parent_dir / "gpt_neo_125m_pos_embed.npy"
713
+ if learned_embed_file.exists():
714
+ logger.info(
715
+ "Loading position embedding from {}".format(learned_embed_file)
716
+ )
717
+ import numpy as np
718
+
719
+ weight = np.load(str(learned_embed_file))
720
+ self.wpe.weight.data.copy_(torch.from_numpy(weight))
721
+ self.wpe.weight.requires_grad = False
722
+ else:
723
+ self.wpe.weight.data.normal_(
724
+ mean=0.0, std=config.initializer_factor * 1.0
725
+ )
726
+
727
+ if self.position_encoding_type == POSITION_ENCODING_ABS_SINUSOID:
728
+ self.wpe = FixedAbsolutePositionalEmbedding(config.d_model)
729
+
730
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
731
+ # Rotary dim is X percentage of d_head
732
+ # Right now, we just hardcode X here following:
733
+ # https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/models/gpt_neox/configuration_gpt_neox.py
734
+ rotary_dim = int(config.d_kv * 0.25)
735
+ self.fixed_rotary_embedding = FixedRotaryPositionalEmbedding(
736
+ rotary_dim, max_position=4096
737
+ )
738
+
739
+ if self.position_encoding_type in [
740
+ POSITION_ENCODING_ALiBi,
741
+ POSITION_ENCODING_ALiBi_LEARNED,
742
+ ]:
743
+ maxpos = 2048
744
+ attn_heads = config.num_heads
745
+ if self.position_encoding_type == POSITION_ENCODING_ALiBi_LEARNED:
746
+ self.learned_logslopes = nn.Parameter(
747
+ torch.log(torch.Tensor(self.get_slopes(attn_heads)))
748
+ )
749
+ else:
750
+ slopes = torch.Tensor(self.get_slopes(attn_heads))
751
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(
752
+ maxpos
753
+ ).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
754
+ alibi = alibi.view(attn_heads, 1, maxpos)
755
+ self.register_buffer("alibi", alibi)
756
+
757
+ # Initialize weights and apply final processing
758
+ self.post_init()
759
+ # Model parallel
760
+ self.model_parallel = False
761
+ self.device_map = None
762
+ self.gradient_checkpointing = False
763
+
764
+ self.window_size = 80  # only used for the none_window position encoding
765
+
766
+ def _alibi_prepare_attn_mask(
767
+ self,
768
+ attention_mask: torch.Tensor,
769
+ input_shape: Tuple[int, int],
770
+ past_key_values_length: int,
771
+ ) -> torch.BoolTensor:
772
+ # create causal mask
773
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
774
+ combined_attention_mask = None
775
+ device = attention_mask.device
776
+ _, src_length = input_shape
777
+
778
+ if src_length > 1:
779
+ combined_attention_mask = _make_causal_mask(
780
+ input_shape,
781
+ device=device,
782
+ past_key_values_length=past_key_values_length,
783
+ )
784
+
785
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
786
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
787
+ combined_attention_mask = (
788
+ expanded_attn_mask
789
+ if combined_attention_mask is None
790
+ else expanded_attn_mask | combined_attention_mask
791
+ )
792
+
793
+ return combined_attention_mask
794
+
795
+ def get_slopes(self, n):
796
+ def get_slopes_power_of_2(n):
797
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
798
+ ratio = start
799
+ return [start * ratio**i for i in range(n)]
800
+
801
+ if math.log2(n).is_integer():
802
+ return get_slopes_power_of_2(
803
+ n
804
+ ) # In the paper, we only train models that have 2^a heads for some a. This function has
805
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
806
+ closest_power_of_2 = 2 ** math.floor(
807
+ math.log2(n)
808
+ ) # when the number of heads is not a power of 2, we use this workaround.
809
+ return (
810
+ get_slopes_power_of_2(closest_power_of_2)
811
+ + self.get_slopes(2 * closest_power_of_2)[0::2][
812
+ : n - closest_power_of_2
813
+ ]
814
+ )
815
+
816
+ def forward(
817
+ self,
818
+ input_ids=None,
819
+ attention_mask=None,
820
+ encoder_hidden_states=None,
821
+ encoder_attention_mask=None,
822
+ inputs_embeds=None,
823
+ head_mask=None,
824
+ cross_attn_head_mask=None,
825
+ past_key_values=None,
826
+ use_cache=None,
827
+ output_attentions=None,
828
+ output_hidden_states=None,
829
+ position_ids=None,
830
+ return_dict=None,
831
+ ):
832
+ # Model parallel
833
+ if self.model_parallel:
834
+ torch.cuda.set_device(self.first_device)
835
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
836
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
837
+ output_attentions = (
838
+ output_attentions
839
+ if output_attentions is not None
840
+ else self.config.output_attentions
841
+ )
842
+ output_hidden_states = (
843
+ output_hidden_states
844
+ if output_hidden_states is not None
845
+ else self.config.output_hidden_states
846
+ )
847
+ return_dict = (
848
+ return_dict if return_dict is not None else self.config.use_return_dict
849
+ )
850
+
851
+ if input_ids is not None and inputs_embeds is not None:
852
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
853
+ raise ValueError(
854
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
855
+ )
856
+ elif input_ids is not None:
857
+ input_shape = input_ids.size()
858
+ input_ids = input_ids.view(-1, input_shape[-1])
859
+ elif inputs_embeds is not None:
860
+ input_shape = inputs_embeds.size()[:-1]
861
+ else:
862
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
863
+ raise ValueError(
864
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
865
+ )
866
+
867
+ if inputs_embeds is None:
868
+ assert (
869
+ self.embed_tokens is not None
870
+ ), "You have to initialize the model with valid token embeddings"
871
+ inputs_embeds = self.embed_tokens(input_ids)
872
+
873
+ if self.position_encoding_type in [
874
+ POSITION_ENCODING_ABS_LEARNED,
875
+ POSITION_ENCODING_ABS_SINUSOID,
876
+ ]:
877
+ if position_ids is not None:
878
+ position_ids = position_ids.view(-1, input_shape[-1])
879
+
880
+ if past_key_values is None:
881
+ past_length = 0
882
+ else:
883
+ past_length = past_key_values[0][0].size(-2)
884
+
885
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
886
+ if position_ids is None:
887
+ position_ids = torch.arange(
888
+ past_length,
889
+ input_shape[-1] + past_length,
890
+ dtype=torch.long,
891
+ device=device,
892
+ )
893
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
894
+
895
+ position_embeds = self.wpe(position_ids)
896
+ inputs_embeds += position_embeds
897
+
898
+ batch_size, seq_length = input_shape
899
+
900
+ # `position_bias` is just a tensor that is passed to all attention layers
901
+ position_bias = None
902
+
903
+ # required mask seq length can be calculated via length of past
904
+ mask_seq_length = (
905
+ past_key_values[0][0].shape[2] + seq_length
906
+ if past_key_values is not None
907
+ else seq_length
908
+ )
909
+
910
+ if use_cache is True:
911
+ assert (
912
+ self.is_decoder
913
+ ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"
914
+
915
+ if attention_mask is None:
916
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(
917
+ inputs_embeds.device
918
+ )
919
+ if (
920
+ self.is_decoder
921
+ and encoder_attention_mask is None
922
+ and encoder_hidden_states is not None
923
+ ):
924
+ encoder_seq_length = encoder_hidden_states.shape[1]
925
+ encoder_attention_mask = torch.ones(
926
+ batch_size,
927
+ encoder_seq_length,
928
+ device=inputs_embeds.device,
929
+ dtype=torch.long,
930
+ )
931
+
932
+ if self.position_encoding_type == POSITION_ENCODING_ROTARY_NEW:
933
+ if position_ids is not None:
934
+ position_ids = position_ids.view(-1, input_shape[-1])
935
+
936
+ if past_key_values is None:
937
+ past_length = 0
938
+ else:
939
+ past_length = past_key_values[0][0].size(-2)
940
+
941
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
942
+ if position_ids is None:
943
+ position_ids = torch.arange(
944
+ past_length,
945
+ input_shape[-1] + past_length,
946
+ dtype=torch.long,
947
+ device=device,
948
+ )
949
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
950
+
951
+ sinusoidal_pos = self.fixed_rotary_embedding(position_ids)
952
+ position_bias = sinusoidal_pos
953
+
954
+ # initialize past_key_values with `None` if past does not exist
955
+ if past_key_values is None:
956
+ past_key_values = [None] * len(self.block)
957
+
958
+ if self.position_encoding_type == POSITION_ENCODING_NONE_WINDOW:
959
+ indices = torch.arange(seq_length, device=inputs_embeds.device)
960
+ causal_mask = indices[:, None] >= indices
961
+ window_mask = (
962
+ (indices.unsqueeze(0) - indices.unsqueeze(0).T)
963
+ .abs()
964
+ .less(self.window_size)
965
+ )
966
+ causal_mask = causal_mask & window_mask
967
+ attention_mask = causal_mask.int()
968
+
969
+ # Repeat the mask for each sample in the batch
970
+ attention_mask = attention_mask[None, :, :].expand(
971
+ batch_size, seq_length, seq_length
972
+ )
973
+
974
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
975
+ # ourselves in which case we just need to make it broadcastable to all heads.
976
+ extended_attention_mask = self.get_extended_attention_mask(
977
+ attention_mask, input_shape, inputs_embeds.device
978
+ )
979
+
980
+ if self.position_encoding_type == POSITION_ENCODING_ALiBi:
981
+ num_heads = self.config.num_heads
982
+ if len(attention_mask.shape) == 3:
983
+ # We need to make a default attention mask
984
+ alibi_attention_mask = torch.ones(batch_size, mask_seq_length).to(
985
+ inputs_embeds.device
986
+ )
987
+ else:
988
+ alibi_attention_mask = attention_mask
989
+
990
+ alibi = build_alibi_tensor(
991
+ alibi_attention_mask, num_heads, dtype=inputs_embeds.dtype
992
+ )
993
+ position_bias = alibi
994
+ del alibi_attention_mask
995
+
996
+ if self.position_encoding_type in [POSITION_ENCODING_ALiBi_LEARNED]:
997
+ if not hasattr(self, "alibi"):
998
+ maxpos = 2048
999
+ attn_heads = self.config.num_heads
1000
+ slopes = self.learned_logslopes.exp()
1001
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(
1002
+ maxpos, device=slopes.device
1003
+ ).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
1004
+ alibi = alibi.view(attn_heads, 1, maxpos)
1005
+ else:
1006
+ alibi = self.alibi
1007
+
1008
+ alibi = alibi.unsqueeze(0).repeat(batch_size, 1, 1, 1)
1009
+ alibi = alibi[:, :, :, : attention_mask.shape[-1]]
1010
+ alibi = alibi.repeat(1, 1, extended_attention_mask.shape[2], 1)
1011
+ extended_attention_mask = torch.where(
1012
+ extended_attention_mask == 0,
1013
+ alibi,
1014
+ extended_attention_mask.repeat(1, self.config.num_heads, 1, 1),
1015
+ )
1016
+
1017
+ # If a 2D or 3D attention mask is provided for the cross-attention
1018
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1019
+ if self.is_decoder and encoder_hidden_states is not None:
1020
+ (
1021
+ encoder_batch_size,
1022
+ encoder_sequence_length,
1023
+ _,
1024
+ ) = encoder_hidden_states.size()
1025
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1026
+ if encoder_attention_mask is None:
1027
+ encoder_attention_mask = torch.ones(
1028
+ encoder_hidden_shape, device=inputs_embeds.device
1029
+ )
1030
+ encoder_extended_attention_mask = self.invert_attention_mask(
1031
+ encoder_attention_mask
1032
+ )
1033
+ else:
1034
+ encoder_extended_attention_mask = None
1035
+
1036
+ # Prepare head mask if needed
1037
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1038
+ cross_attn_head_mask = self.get_head_mask(
1039
+ cross_attn_head_mask, self.config.num_layers
1040
+ )
1041
+ present_key_value_states = () if use_cache else None
1042
+ all_hidden_states = () if output_hidden_states else None
1043
+ all_attentions = () if output_attentions else None
1044
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
1045
+ # position_bias = None
1046
+ encoder_decoder_position_bias = None
1047
+
1048
+ hidden_states = self.dropout(inputs_embeds)
1049
+
1050
+ for i, (layer_module, past_key_value) in enumerate(
1051
+ zip(self.block, past_key_values)
1052
+ ):
1053
+ layer_head_mask = head_mask[i]
1054
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
1055
+ # Model parallel
1056
+ if self.model_parallel:
1057
+ torch.cuda.set_device(hidden_states.device)
1058
+ # Ensure that attention_mask is always on the same device as hidden_states
1059
+ if attention_mask is not None:
1060
+ attention_mask = attention_mask.to(hidden_states.device)
1061
+ if position_bias is not None:
1062
+ position_bias = position_bias.to(hidden_states.device)
1063
+ if encoder_hidden_states is not None:
1064
+ encoder_hidden_states = encoder_hidden_states.to(
1065
+ hidden_states.device
1066
+ )
1067
+ if encoder_extended_attention_mask is not None:
1068
+ encoder_extended_attention_mask = (
1069
+ encoder_extended_attention_mask.to(hidden_states.device)
1070
+ )
1071
+ if encoder_decoder_position_bias is not None:
1072
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(
1073
+ hidden_states.device
1074
+ )
1075
+ if layer_head_mask is not None:
1076
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1077
+ if cross_attn_layer_head_mask is not None:
1078
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
1079
+ hidden_states.device
1080
+ )
1081
+ if output_hidden_states:
1082
+ all_hidden_states = all_hidden_states + (hidden_states,)
1083
+
1084
+ if self.gradient_checkpointing and self.training:
1085
+ if use_cache:
1086
+ logger.warning(
1087
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1088
+ )
1089
+ use_cache = False
1090
+
1091
+ def create_custom_forward(module):
1092
+ def custom_forward(*inputs):
1093
+ return tuple(module(*inputs, use_cache, output_attentions))
1094
+
1095
+ return custom_forward
1096
+
1097
+ layer_outputs = checkpoint(
1098
+ create_custom_forward(layer_module),
1099
+ hidden_states,
1100
+ extended_attention_mask,
1101
+ position_bias,
1102
+ encoder_hidden_states,
1103
+ encoder_extended_attention_mask,
1104
+ encoder_decoder_position_bias,
1105
+ layer_head_mask,
1106
+ cross_attn_layer_head_mask,
1107
+ None, # past_key_value is always None with gradient checkpointing
1108
+ )
1109
+ else:
1110
+ layer_outputs = layer_module(
1111
+ hidden_states,
1112
+ attention_mask=extended_attention_mask,
1113
+ position_bias=position_bias,
1114
+ encoder_hidden_states=encoder_hidden_states,
1115
+ encoder_attention_mask=encoder_extended_attention_mask,
1116
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1117
+ layer_head_mask=layer_head_mask,
1118
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1119
+ past_key_value=past_key_value,
1120
+ use_cache=use_cache,
1121
+ output_attentions=output_attentions,
1122
+ )
1123
+
1124
+ # layer_outputs is a tuple with:
1125
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1126
+ if use_cache is False:
1127
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1128
+
1129
+ hidden_states, present_key_value_state = layer_outputs[:2]
1130
+
1131
+ # We share the position biases between the layers - the first layer stores them
1132
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1133
+ # (cross-attention position bias), (cross-attention weights)
1134
+ position_bias = layer_outputs[2]
1135
+ if self.is_decoder and encoder_hidden_states is not None:
1136
+ encoder_decoder_position_bias = layer_outputs[
1137
+ 4 if output_attentions else 3
1138
+ ]
1139
+ # append next layer key value states
1140
+ if use_cache:
1141
+ present_key_value_states = present_key_value_states + (
1142
+ present_key_value_state,
1143
+ )
1144
+
1145
+ if output_attentions:
1146
+ all_attentions = all_attentions + (layer_outputs[3],)
1147
+ if self.is_decoder:
1148
+ all_cross_attentions = all_cross_attentions + (None,)
1149
+
1150
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1151
+ if self.model_parallel:
1152
+ for k, v in self.device_map.items():
1153
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1154
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1155
+
1156
+ hidden_states = self.final_layer_norm(hidden_states)
1157
+ hidden_states = self.dropout(hidden_states)
1158
+
1159
+ # Add last layer
1160
+ if output_hidden_states:
1161
+ all_hidden_states = all_hidden_states + (hidden_states,)
1162
+
1163
+ if not return_dict:
1164
+ return tuple(
1165
+ v
1166
+ for v in [
1167
+ hidden_states,
1168
+ present_key_value_states,
1169
+ all_hidden_states,
1170
+ all_attentions,
1171
+ all_cross_attentions,
1172
+ ]
1173
+ if v is not None
1174
+ )
1175
+ return BaseModelOutputWithPastAndCrossAttentions(
1176
+ last_hidden_state=hidden_states,
1177
+ past_key_values=present_key_value_states,
1178
+ hidden_states=all_hidden_states,
1179
+ attentions=all_attentions,
1180
+ cross_attentions=all_cross_attentions,
1181
+ )
1182
+
1183
+
1184
+ class CustomDecoderOnlyT5(T5PreTrainedModel):
1185
+ config_class = CustomT5Config
1186
+ _keys_to_ignore_on_load_missing = [
1187
+ r"decoder\.embed_tokens\.weight",
1188
+ r"encoder",
1189
+ r"lm_head\.weight",
1190
+ ]
1191
+ _keys_to_ignore_on_load_unexpected = [
1192
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1193
+ ]
1194
+
1195
+ def __init__(
1196
+ self,
1197
+ config=None,
1198
+ output_non_reduced_loss: bool = False,
1199
+ **kwargs,
1200
+ ):
1201
+ assert config is not None
1202
+ config.is_decoder = True
1203
+ config.is_encoder_decoder = False
1204
+
1205
+ assert (
1206
+ config.position_encoding_type is not None
1207
+ ), "Position encoding type must be set"
1208
+
1209
+ self.output_non_reduced_loss = output_non_reduced_loss
1210
+ self.main_input_name = "input_ids"
1211
+
1212
+ super().__init__(config)
1213
+
1214
+ self.model_dim = config.d_model
1215
+
1216
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1217
+ self.decoder = CustomT5Stack(config, self.shared)
1218
+
1219
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ # Model parallel
1225
+ self.model_parallel = False
1226
+ self.device_map = None
1227
+ #
1228
+ cross_attention_params = [
1229
+ p
1230
+ for n, p in self.decoder.named_parameters()
1231
+ if n.startswith("block.") and ".layer.1." in n
1232
+ ]
1233
+ for param in cross_attention_params:
1234
+ param.requires_grad = False
1235
+
1236
+ # self.handle_tokenizer(tokenizer)
1237
+
1238
+ def get_decoder(self):
1239
+ return self.decoder
1240
+
1241
+ def parallelize(self, device_map=None):
1242
+ self.device_map = (
1243
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1244
+ if device_map is None
1245
+ else device_map
1246
+ )
1247
+ assert_device_map(self.device_map, len(self.encoder.block))
1248
+ self.encoder.parallelize(self.device_map)
1249
+ self.decoder.parallelize(self.device_map)
1250
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1251
+ self.model_parallel = True
1252
+
1253
+ def deparallelize(self):
1254
+ self.encoder.deparallelize()
1255
+ self.decoder.deparallelize()
1256
+ self.encoder = self.encoder.to("cpu")
1257
+ self.decoder = self.decoder.to("cpu")
1258
+ self.lm_head = self.lm_head.to("cpu")
1259
+ self.model_parallel = False
1260
+ self.device_map = None
1261
+ torch.cuda.empty_cache()
1262
+
1263
+ def get_input_embeddings(self):
1264
+ return self.shared
1265
+
1266
+ def set_input_embeddings(self, new_embeddings):
1267
+ self.shared = new_embeddings
1268
+ self.decoder.set_input_embeddings(new_embeddings)
1269
+
1270
+ def set_output_embeddings(self, new_embeddings):
1271
+ self.lm_head = new_embeddings
1272
+
1273
+ def get_output_embeddings(self):
1274
+ return self.lm_head
1275
+
1276
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1277
+ token_type_ids = kwargs.get("token_type_ids", None)
1278
+ # only last token for inputs_ids if past is defined in kwargs
1279
+ if past:
1280
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1281
+ if token_type_ids is not None:
1282
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1283
+
1284
+ attention_mask = kwargs.get("attention_mask", None)
1285
+ position_ids = kwargs.get("position_ids", None)
1286
+
1287
+ if attention_mask is not None and position_ids is None:
1288
+ # create position_ids on the fly for batch generation
1289
+ position_ids = attention_mask.long().cumsum(-1) - 1
1290
+ position_ids.masked_fill_(attention_mask == 0, 1)
1291
+ if past:
1292
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1293
+ else:
1294
+ position_ids = None
1295
+
1296
+ return {
1297
+ "input_ids": input_ids,
1298
+ "past_key_values": past,
1299
+ "use_cache": kwargs.get("use_cache"),
1300
+ "attention_mask": attention_mask,
1301
+ "token_type_ids": token_type_ids,
1302
+ "position_ids": position_ids,
1303
+ }
1304
+
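
The `position_ids` construction above (cumulative sum of the attention mask, minus one, with padded slots filled with a dummy value) keeps positions aligned with real tokens during batched generation with left padding. A standalone illustration:

```python
# Standalone illustration of the position_ids construction used above.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # dummy value on padding
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```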
1305
+ def forward(
1306
+ self,
1307
+ input_ids=None,
1308
+ past_key_values=None,
1309
+ attention_mask=None,
1310
+ token_type_ids=None,
1311
+ position_ids=None,
1312
+ head_mask=None,
1313
+ inputs_embeds=None,
1314
+ labels=None,
1315
+ use_cache=None,
1316
+ output_attentions=None,
1317
+ output_hidden_states=None,
1318
+ return_dict=None,
1319
+ ):
1320
+ return_dict = (
1321
+ return_dict if return_dict is not None else self.config.use_return_dict
1322
+ )
1323
+
1324
+ if self.model_parallel:
1325
+ torch.cuda.set_device(self.decoder.first_device)
1326
+
1327
+ if self.model_parallel:
1328
+ torch.cuda.set_device(self.decoder.first_device)
1329
+ if input_ids is not None:
1330
+ input_ids = input_ids.to(self.decoder.first_device)
1331
+ if attention_mask is not None:
1332
+ attention_mask = attention_mask.to(self.decoder.first_device)
1333
+
1334
+ transformer_outputs = self.decoder(
1335
+ input_ids=input_ids,
1336
+ attention_mask=attention_mask,
1337
+ inputs_embeds=inputs_embeds,
1338
+ past_key_values=past_key_values,
1339
+ position_ids=position_ids,
1340
+ encoder_hidden_states=None,
1341
+ encoder_attention_mask=None,
1342
+ head_mask=head_mask,
1343
+ cross_attn_head_mask=None,
1344
+ use_cache=use_cache,
1345
+ output_attentions=output_attentions,
1346
+ output_hidden_states=output_hidden_states,
1347
+ return_dict=return_dict,
1348
+ )
1349
+ hidden_states = transformer_outputs[0]
1350
+
1351
+ if self.config.tie_word_embeddings:
1352
+ # Rescale output before projecting on vocab
1353
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1354
+ hidden_states = hidden_states * (self.model_dim**-0.5)
1355
+
1356
+ lm_logits = self.lm_head(hidden_states)
1357
+
1358
+ loss = None
1359
+ non_reduced_loss = None
1360
+ if labels is not None:
1361
+ # Compute loss in fp32 to match with mesh-tf version
1362
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
1363
+ lm_logits = lm_logits.to(torch.float32)
1364
+
1365
+ # Shift so that tokens < n predict n
1366
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1367
+ shift_labels = labels[..., 1:].contiguous()
1368
+ # Flatten the tokens
1369
+ loss_fct = CrossEntropyLoss()
1370
+ loss = loss_fct(
1371
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1372
+ )
1373
+
1374
+ lm_logits = lm_logits.to(hidden_states.dtype)
1375
+ loss = loss.to(hidden_states.dtype)
1376
+
1377
+ if self.output_non_reduced_loss:
1378
+ loss_fct = CrossEntropyLoss(reduction="none")
1379
+ non_reduced_loss = loss_fct(
1380
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1381
+ )
1382
+
1383
+ # Reshape to [batch_size, seq_length - 1] and keep only the loss at the last position
1384
+ non_reduced_loss = non_reduced_loss.view(
1385
+ shift_labels.shape[0], shift_labels.shape[1]
1386
+ )[:, -1].view(-1, 1)
1387
+
1388
+ if not return_dict:
1389
+ output = (lm_logits,) + transformer_outputs[1:]
1390
+ return ((loss,) + output) if loss is not None else output
1391
+
1392
+ return CausalLMOutputWithPastAndLoss(
1393
+ loss=loss,
1394
+ logits=lm_logits,
1395
+ past_key_values=transformer_outputs.past_key_values,
1396
+ hidden_states=transformer_outputs.hidden_states,
1397
+ attentions=transformer_outputs.attentions,
1398
+ non_reduced_loss=non_reduced_loss,
1399
+ )
1400
+
1401
+ @staticmethod
1402
+ def _reorder_cache(
1403
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1404
+ ) -> Tuple[Tuple[torch.Tensor]]:
1405
+ """
1406
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1407
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1408
+ beam_idx at every generation step.
1409
+ """
1410
+ return tuple(
1411
+ tuple(
1412
+ past_state.index_select(0, beam_idx.to(past_state.device))
1413
+ for past_state in layer_past
1414
+ )
1415
+ for layer_past in past
1416
+ )
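The `forward` pass above follows the standard causal-LM convention: the logits at position t are scored against the token at position t + 1. A minimal, self-contained sketch of that label shift and loss computation on dummy tensors (shapes and values are illustrative only, not tied to this checkpoint):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 6, 32                     # illustrative sizes only
lm_logits = torch.randn(batch_size, seq_len, vocab_size)       # stand-in for the lm_head output
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # stand-in for the target ids

# Shift so that tokens < n predict n, exactly as in the forward pass above.
shift_logits = lm_logits[..., :-1, :].contiguous()   # (batch, seq_len - 1, vocab)
shift_labels = labels[..., 1:].contiguous()          # (batch, seq_len - 1)

loss = CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(loss)
```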
modeling_t5.py ADDED
@@ -0,0 +1,1821 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch T5 model."""
16
+
17
+
18
+ import copy
19
+ import math
20
+ import os
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.utils.checkpoint import checkpoint
26
+ from transformers import T5Config
27
+ from transformers.activations import ACT2FN
28
+ from transformers.file_utils import (
29
+ DUMMY_INPUTS,
30
+ DUMMY_MASK,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ is_torch_fx_proxy,
34
+ replace_return_docstrings,
35
+ )
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutput,
38
+ BaseModelOutputWithPastAndCrossAttentions,
39
+ Seq2SeqLMOutput,
40
+ Seq2SeqModelOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
43
+ from transformers.utils import logging
44
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "T5Config"
49
+ _TOKENIZER_FOR_DOC = "T5Tokenizer"
50
+ _CHECKPOINT_FOR_DOC = "t5-small"
51
+
52
+ ####################################################
53
+ # This dict contains ids and associated url
54
+ # for the pretrained weights provided with the models
55
+ ####################################################
56
+ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "t5-small",
58
+ "t5-base",
59
+ "t5-large",
60
+ "t5-3b",
61
+ "t5-11b",
62
+ # See all T5 models at https://huggingface.co/models?filter=t5
63
+ ]
64
+
65
+
66
+ ####################################################
67
+ # This is a conversion method from TF 1.0 to PyTorch
68
+ # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
69
+ ####################################################
70
+ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
71
+ """Load tf checkpoints in a pytorch model."""
72
+ try:
73
+ import re
74
+
75
+ import numpy as np
76
+ import tensorflow as tf
77
+ except ImportError:
78
+ logger.error(
79
+ "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
80
+ "https://www.tensorflow.org/install/ for installation instructions."
81
+ )
82
+ raise
83
+ tf_path = os.path.abspath(tf_checkpoint_path)
84
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
85
+ # Load weights from TF model
86
+ init_vars = tf.train.list_variables(tf_path)
87
+ names = []
88
+ tf_weights = {}
89
+ for name, shape in init_vars:
90
+ logger.info(f"Loading TF weight {name} with shape {shape}")
91
+ array = tf.train.load_variable(tf_path, name)
92
+ names.append(name)
93
+ tf_weights[name] = array
94
+
95
+ for txt_name in names:
96
+ name = txt_name.split("/")
97
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
98
+ # which are not required for using pretrained model
99
+ if any(
100
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
101
+ for n in name
102
+ ):
103
+ logger.info(f"Skipping {'/'.join(name)}")
104
+ tf_weights.pop(txt_name, None)
105
+ continue
106
+ if "_slot_" in name[-1]:
107
+ logger.info(f"Skipping {'/'.join(name)}")
108
+ tf_weights.pop(txt_name, None)
109
+ continue
110
+ pointer = model
111
+ array = tf_weights[txt_name]
112
+
113
+ for m_name in name:
114
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
115
+ scope_names = re.split(r"_(\d+)", m_name)
116
+ else:
117
+ scope_names = [m_name]
118
+ if scope_names[0] in ["kernel", "scale", "embedding"]:
119
+ pointer = getattr(pointer, "weight")
120
+ elif scope_names[0] == "self_attention":
121
+ pointer = getattr(pointer, "layer")
122
+ pointer = pointer[0]
123
+ elif scope_names[0] == "enc_dec_attention":
124
+ pointer = getattr(pointer, "layer")
125
+ pointer = pointer[1]
126
+ elif scope_names[0] == "dense_relu_dense":
127
+ pointer = getattr(pointer, "layer")
128
+ pointer = pointer[2]
129
+ elif scope_names[0] == "rms_norm":
130
+ if hasattr(pointer, "layer_norm"):
131
+ pointer = getattr(pointer, "layer_norm")
132
+ elif hasattr(pointer, "final_layer_norm"):
133
+ pointer = getattr(pointer, "final_layer_norm")
134
+ elif scope_names[0] == "scale":
135
+ pointer = getattr(pointer, "weight")
136
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
137
+ pointer = getattr(pointer, "bias")
138
+ elif scope_names[0] == "squad":
139
+ pointer = getattr(pointer, "classifier")
140
+ elif scope_names[0] == "decoder" and name[1] == "logits":
141
+ continue
142
+ elif scope_names[0] == "logits":
143
+ pointer = getattr(pointer, "lm_head")
144
+ elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
145
+ pointer = getattr(pointer, f"wi_{scope_names[1]}")
146
+ continue
147
+ else:
148
+ try:
149
+ pointer = getattr(pointer, scope_names[0])
150
+ except AttributeError:
151
+ logger.info(f"Skipping {'/'.join(name)}")
152
+ continue
153
+ if len(scope_names) >= 2:
154
+ num = int(scope_names[1])
155
+ pointer = pointer[num]
156
+ if scope_names[0] not in ["kernel", "scale", "embedding"]:
157
+ pointer = getattr(pointer, "weight")
158
+ if scope_names[0] != "embedding":
159
+ logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
160
+ array = np.transpose(array)
161
+ try:
162
+ assert (
163
+ pointer.shape == array.shape
164
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
165
+ except AssertionError as e:
166
+ e.args += (pointer.shape, array.shape)
167
+ raise
168
+ logger.info(f"Initialize PyTorch weight {name}")
169
+ pointer.data = torch.from_numpy(array.astype(np.float32))
170
+ tf_weights.pop(txt_name, None)
171
+
172
+ logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
173
+ return model
174
+
175
+
176
+ ####################################################
177
+ # PyTorch Models are constructed by sub-classing
178
+ # - torch.nn.Module for the layers and
179
+ # - PreTrainedModel for the models (it-self a sub-class of nn.Module)
180
+ ####################################################
181
+ PARALLELIZE_DOCSTRING = r"""
182
+ This is an experimental feature and is subject to change at a moment's notice.
183
+
184
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
185
+ it will evenly distribute blocks across all devices.
186
+
187
+ Args:
188
+ device_map (`Dict[int, list]`, optional, defaults to None):
189
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
190
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
191
+ have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
192
+ following number of attention modules:
193
+
194
+ - t5-small: 6
195
+ - t5-base: 12
196
+ - t5-large: 24
197
+ - t5-3b: 24
198
+ - t5-11b: 24
199
+
200
+ Example:
201
+
202
+ ```python
203
+ # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
204
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
205
+ device_map = {
206
+ 0: [0, 1, 2],
207
+ 1: [3, 4, 5, 6, 7, 8, 9],
208
+ 2: [10, 11, 12, 13, 14, 15, 16],
209
+ 3: [17, 18, 19, 20, 21, 22, 23],
210
+ }
211
+ model.parallelize(device_map)
212
+ ```
213
+ """
214
+ DEPARALLELIZE_DOCSTRING = r"""
215
+ Moves the model to cpu from a model parallel state.
216
+
217
+ Example:
218
+
219
+ ```python
220
+ # On a 4 GPU machine with t5-3b:
221
+ model = T5ForConditionalGeneration.from_pretrained("t5-3b")
222
+ device_map = {
223
+ 0: [0, 1, 2],
224
+ 1: [3, 4, 5, 6, 7, 8, 9],
225
+ 2: [10, 11, 12, 13, 14, 15, 16],
226
+ 3: [17, 18, 19, 20, 21, 22, 23],
227
+ }
228
+ model.parallelize(device_map) # Splits the model across several devices
229
+ model.deparallelize() # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
230
+ ```
231
+ """
232
+
233
+
234
+ class T5LayerNorm(nn.Module):
235
+ def __init__(self, hidden_size, eps=1e-6):
236
+ """
237
+ Construct a layernorm module in the T5 style. No bias and no subtraction of the mean.
238
+ """
239
+ super().__init__()
240
+ self.weight = nn.Parameter(torch.ones(hidden_size))
241
+ self.variance_epsilon = eps
242
+
243
+ def forward(self, hidden_states):
244
+ # layer norm should always be calculated in float32
245
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
246
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
247
+
248
+ # convert into half-precision if necessary
249
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
250
+ hidden_states = hidden_states.to(self.weight.dtype)
251
+
252
+ return self.weight * hidden_states
253
+
254
+
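A quick standalone check of the RMS-style normalization implemented by `T5LayerNorm` above, on dummy tensors (the `1e-6` epsilon matches the default above; with a unit weight the per-position mean square comes out close to 1):

```python
import torch

hidden = torch.randn(2, 4, 8)
weight = torch.ones(8)

# Same computation as T5LayerNorm.forward: scale by the root mean square,
# no mean subtraction and no bias, with the variance computed in float32.
variance = hidden.to(torch.float32).pow(2).mean(-1, keepdim=True)
normed = weight * (hidden * torch.rsqrt(variance + 1e-6))

print(normed.pow(2).mean(-1))  # roughly 1.0 at every position
```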
255
+ class T5DenseReluDense(nn.Module):
256
+ def __init__(self, config):
257
+ super().__init__()
258
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
259
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
260
+ self.dropout = nn.Dropout(config.dropout_rate)
261
+
262
+ def forward(self, hidden_states):
263
+ hidden_states = self.wi(hidden_states)
264
+ hidden_states = nn.functional.relu(hidden_states)
265
+ hidden_states = self.dropout(hidden_states)
266
+ hidden_states = self.wo(hidden_states)
267
+ return hidden_states
268
+
269
+
270
+ class T5DenseGatedGeluDense(nn.Module):
271
+ def __init__(self, config):
272
+ super().__init__()
273
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
274
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
275
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
276
+ self.dropout = nn.Dropout(config.dropout_rate)
277
+ self.gelu_act = ACT2FN["gelu_new"]
278
+
279
+ def forward(self, hidden_states):
280
+ hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
281
+ hidden_linear = self.wi_1(hidden_states)
282
+ hidden_states = hidden_gelu * hidden_linear
283
+ hidden_states = self.dropout(hidden_states)
284
+ hidden_states = self.wo(hidden_states)
285
+ return hidden_states
286
+
287
+
288
+ class T5LayerFF(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ if config.feed_forward_proj == "relu":
292
+ self.DenseReluDense = T5DenseReluDense(config)
293
+ elif config.feed_forward_proj == "gated-gelu":
294
+ self.DenseReluDense = T5DenseGatedGeluDense(config)
295
+ else:
296
+ raise ValueError(
297
+ f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
298
+ )
299
+
300
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
301
+ self.dropout = nn.Dropout(config.dropout_rate)
302
+
303
+ def forward(self, hidden_states):
304
+ forwarded_states = self.layer_norm(hidden_states)
305
+ forwarded_states = self.DenseReluDense(forwarded_states)
306
+ hidden_states = hidden_states + self.dropout(forwarded_states)
307
+ return hidden_states
308
+
309
+
310
+ class T5Attention(nn.Module):
311
+ def __init__(self, config: T5Config, has_relative_attention_bias=False):
312
+ super().__init__()
313
+ self.is_decoder = config.is_decoder
314
+ self.has_relative_attention_bias = has_relative_attention_bias
315
+
316
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
317
+ self.d_model = config.d_model
318
+ self.key_value_proj_dim = config.d_kv
319
+ self.n_heads = config.num_heads
320
+ self.dropout = config.dropout_rate
321
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
322
+
323
+ # Mesh TensorFlow initialization to avoid scaling before softmax
324
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
325
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
326
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
327
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
328
+
329
+ if self.has_relative_attention_bias:
330
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
331
+ self.pruned_heads = set()
332
+ self.gradient_checkpointing = False
333
+
334
+ def prune_heads(self, heads):
335
+ if len(heads) == 0:
336
+ return
337
+ heads, index = find_pruneable_heads_and_indices(
338
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
339
+ )
340
+ # Prune linear layers
341
+ self.q = prune_linear_layer(self.q, index)
342
+ self.k = prune_linear_layer(self.k, index)
343
+ self.v = prune_linear_layer(self.v, index)
344
+ self.o = prune_linear_layer(self.o, index, dim=1)
345
+ # Update hyper params
346
+ self.n_heads = self.n_heads - len(heads)
347
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
348
+ self.pruned_heads = self.pruned_heads.union(heads)
349
+
350
+ @staticmethod
351
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
352
+ """
353
+ Adapted from Mesh Tensorflow:
354
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
355
+
356
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
357
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
358
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
359
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
360
+ positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
361
+ This should allow for more graceful generalization to longer sequences than the model has been trained on.
362
+
363
+ Args:
364
+ relative_position: an int32 Tensor
365
+ bidirectional: a boolean - whether the attention is bidirectional
366
+ num_buckets: an integer
367
+ max_distance: an integer
368
+
369
+ Returns:
370
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
371
+ """
372
+ relative_buckets = 0
373
+ if bidirectional:
374
+ num_buckets //= 2
375
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
376
+ relative_position = torch.abs(relative_position)
377
+ else:
378
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
379
+ # now relative_position is in the range [0, inf)
380
+
381
+ # half of the buckets are for exact increments in positions
382
+ max_exact = num_buckets // 2
383
+ is_small = relative_position < max_exact
384
+
385
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
386
+ relative_position_if_large = max_exact + (
387
+ torch.log(relative_position.float() / max_exact)
388
+ / math.log(max_distance / max_exact)
389
+ * (num_buckets - max_exact)
390
+ ).to(torch.long)
391
+ relative_position_if_large = torch.min(
392
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
393
+ )
394
+
395
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
396
+ return relative_buckets
397
+
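To get a feel for the bucketing described in the docstring above, the same helper exists with an identical signature in the upstream `transformers` T5 implementation and can be probed directly (causal decoder setting, default 32 buckets and `max_distance=128`; the offsets below are illustrative only):

```python
import torch
from transformers.models.t5.modeling_t5 import T5Attention

# Distances from a query position to earlier key positions (memory_position - query_position).
relative_position = torch.tensor([[0, -1, -2, -15, -40, -500]])

buckets = T5Attention._relative_position_bucket(
    relative_position, bidirectional=False, num_buckets=32, max_distance=128
)
print(buckets)
# Nearby offsets get their own buckets, larger offsets share log-spaced buckets,
# and offsets at or beyond max_distance all collapse into the last bucket (num_buckets - 1).
```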
398
+ def compute_bias(self, query_length, key_length):
399
+ """Compute binned relative position bias"""
400
+ context_position = torch.arange(
401
+ query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
402
+ )[:, None]
403
+ memory_position = torch.arange(
404
+ key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
405
+ )[None, :]
406
+ relative_position = memory_position - context_position # shape (query_length, key_length)
407
+ relative_position_bucket = self._relative_position_bucket(
408
+ relative_position, # shape (query_length, key_length)
409
+ bidirectional=(not self.is_decoder),
410
+ num_buckets=self.relative_attention_num_buckets,
411
+ )
412
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
413
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
414
+ return values
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states,
419
+ mask=None,
420
+ key_value_states=None,
421
+ position_bias=None,
422
+ past_key_value=None,
423
+ layer_head_mask=None,
424
+ query_length=None,
425
+ use_cache=False,
426
+ output_attentions=False,
427
+ ):
428
+ """
429
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
430
+ """
431
+ # Input is (batch_size, seq_length, dim)
432
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
433
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
434
+ batch_size, seq_length = hidden_states.shape[:2]
435
+
436
+ real_seq_length = seq_length
437
+
438
+ if past_key_value is not None:
439
+ assert (
440
+ len(past_key_value) == 2
441
+ ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
442
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
443
+
444
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
445
+
446
+ def shape(states):
447
+ """projection"""
448
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
449
+
450
+ def unshape(states):
451
+ """reshape"""
452
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
453
+
454
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
455
+ """projects hidden states correctly to key/value states"""
456
+ if key_value_states is None:
457
+ # self-attn
458
+ # (batch_size, n_heads, seq_length, dim_per_head)
459
+ hidden_states = shape(proj_layer(hidden_states))
460
+ elif past_key_value is None:
461
+ # cross-attn
462
+ # (batch_size, n_heads, seq_length, dim_per_head)
463
+ hidden_states = shape(proj_layer(key_value_states))
464
+
465
+ if past_key_value is not None:
466
+ if key_value_states is None:
467
+ # self-attn
468
+ # (batch_size, n_heads, key_length, dim_per_head)
469
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
470
+ else:
471
+ # cross-attn
472
+ hidden_states = past_key_value
473
+ return hidden_states
474
+
475
+ # get query states
476
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
477
+
478
+ # get key/value states
479
+ key_states = project(
480
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
481
+ )
482
+ value_states = project(
483
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
484
+ )
485
+
486
+ # compute scores
487
+ scores = torch.matmul(
488
+ query_states, key_states.transpose(3, 2)
489
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
490
+
491
+ if position_bias is None:
492
+ if not self.has_relative_attention_bias:
493
+ position_bias = torch.zeros(
494
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
495
+ )
496
+ if self.gradient_checkpointing and self.training:
497
+ position_bias.requires_grad = True
498
+ else:
499
+ position_bias = self.compute_bias(real_seq_length, key_length)
500
+
501
+ # if key and values are already calculated
502
+ # we want only the last query position bias
503
+ if past_key_value is not None:
504
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
505
+
506
+ if mask is not None:
507
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
508
+
509
+ scores += position_bias
510
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
511
+ scores
512
+ ) # (batch_size, n_heads, seq_length, key_length)
513
+ attn_weights = nn.functional.dropout(
514
+ attn_weights, p=self.dropout, training=self.training
515
+ ) # (batch_size, n_heads, seq_length, key_length)
516
+
517
+ # Mask heads if we want to
518
+ if layer_head_mask is not None:
519
+ attn_weights = attn_weights * layer_head_mask
520
+
521
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
522
+ attn_output = self.o(attn_output)
523
+
524
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
525
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
526
+
527
+ if output_attentions:
528
+ outputs = outputs + (attn_weights,)
529
+ return outputs
530
+
531
+
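As the initialization comment in `T5Attention.__init__` notes, the projections are scaled so that no 1/sqrt(d_head) factor is applied to the attention scores. A minimal sketch of the score computation with an additive position bias, using random tensors with illustrative shapes only:

```python
import torch

# Shapes: (batch, n_heads, q_len, d_head) for q and (batch, n_heads, k_len, d_head) for k and v.
q = torch.randn(1, 2, 3, 4)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)
position_bias = torch.zeros(1, 2, 3, 5)  # additive bias, e.g. the learned relative-attention bias

# As in T5Attention.forward above: raw dot products, no 1/sqrt(d_head) scaling,
# bias (and any attention mask folded into it) added before the softmax.
scores = torch.matmul(q, k.transpose(3, 2)) + position_bias
attn_weights = torch.softmax(scores.float(), dim=-1).type_as(scores)
attn_output = torch.matmul(attn_weights, v)  # (batch, n_heads, q_len, d_head)
```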
532
+ class T5LayerSelfAttention(nn.Module):
533
+ def __init__(self, config, has_relative_attention_bias=False):
534
+ super().__init__()
535
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
536
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
537
+ self.dropout = nn.Dropout(config.dropout_rate)
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ attention_mask=None,
543
+ position_bias=None,
544
+ layer_head_mask=None,
545
+ past_key_value=None,
546
+ use_cache=False,
547
+ output_attentions=False,
548
+ ):
549
+ normed_hidden_states = self.layer_norm(hidden_states)
550
+ attention_output = self.SelfAttention(
551
+ normed_hidden_states,
552
+ mask=attention_mask,
553
+ position_bias=position_bias,
554
+ layer_head_mask=layer_head_mask,
555
+ past_key_value=past_key_value,
556
+ use_cache=use_cache,
557
+ output_attentions=output_attentions,
558
+ )
559
+ hidden_states = hidden_states + self.dropout(attention_output[0])
560
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
561
+ return outputs
562
+
563
+
564
+ class T5LayerCrossAttention(nn.Module):
565
+ def __init__(self, config):
566
+ super().__init__()
567
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
568
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
569
+ self.dropout = nn.Dropout(config.dropout_rate)
570
+
571
+ def forward(
572
+ self,
573
+ hidden_states,
574
+ key_value_states,
575
+ attention_mask=None,
576
+ position_bias=None,
577
+ layer_head_mask=None,
578
+ past_key_value=None,
579
+ use_cache=False,
580
+ query_length=None,
581
+ output_attentions=False,
582
+ ):
583
+ normed_hidden_states = self.layer_norm(hidden_states)
584
+ attention_output = self.EncDecAttention(
585
+ normed_hidden_states,
586
+ mask=attention_mask,
587
+ key_value_states=key_value_states,
588
+ position_bias=position_bias,
589
+ layer_head_mask=layer_head_mask,
590
+ past_key_value=past_key_value,
591
+ use_cache=use_cache,
592
+ query_length=query_length,
593
+ output_attentions=output_attentions,
594
+ )
595
+ layer_output = hidden_states + self.dropout(attention_output[0])
596
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
597
+ return outputs
598
+
599
+
600
+ class T5Block(nn.Module):
601
+ def __init__(self, config, has_relative_attention_bias=False):
602
+ super().__init__()
603
+ self.is_decoder = config.is_decoder
604
+ self.layer = nn.ModuleList()
605
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
606
+ if self.is_decoder:
607
+ self.layer.append(T5LayerCrossAttention(config))
608
+
609
+ self.layer.append(T5LayerFF(config))
610
+
611
+ def forward(
612
+ self,
613
+ hidden_states,
614
+ attention_mask=None,
615
+ position_bias=None,
616
+ encoder_hidden_states=None,
617
+ encoder_attention_mask=None,
618
+ encoder_decoder_position_bias=None,
619
+ layer_head_mask=None,
620
+ cross_attn_layer_head_mask=None,
621
+ past_key_value=None,
622
+ use_cache=False,
623
+ output_attentions=False,
624
+ return_dict=True,
625
+ ):
626
+
627
+ if past_key_value is not None:
628
+ assert self.is_decoder, "Only decoder can use `past_key_values`"
629
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
630
+
631
+ if len(past_key_value) != expected_num_past_key_values:
632
+ raise ValueError(
633
+ f"There should be {expected_num_past_key_values} past states. "
634
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
635
+ f"Got {len(past_key_value)} past key / value states"
636
+ )
637
+
638
+ self_attn_past_key_value = past_key_value[:2]
639
+ cross_attn_past_key_value = past_key_value[2:]
640
+ else:
641
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
642
+
643
+ self_attention_outputs = self.layer[0](
644
+ hidden_states,
645
+ attention_mask=attention_mask,
646
+ position_bias=position_bias,
647
+ layer_head_mask=layer_head_mask,
648
+ past_key_value=self_attn_past_key_value,
649
+ use_cache=use_cache,
650
+ output_attentions=output_attentions,
651
+ )
652
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
653
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
654
+
655
+ # clamp inf values to enable fp16 training
656
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
657
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
658
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
659
+
660
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
661
+ if do_cross_attention:
662
+ # the actual query length is unknown for cross attention
663
+ # if using past key value states. Need to inject it here
664
+ if present_key_value_state is not None:
665
+ query_length = present_key_value_state[0].shape[2]
666
+ else:
667
+ query_length = None
668
+
669
+ cross_attention_outputs = self.layer[1](
670
+ hidden_states,
671
+ key_value_states=encoder_hidden_states,
672
+ attention_mask=encoder_attention_mask,
673
+ position_bias=encoder_decoder_position_bias,
674
+ layer_head_mask=cross_attn_layer_head_mask,
675
+ past_key_value=cross_attn_past_key_value,
676
+ query_length=query_length,
677
+ use_cache=use_cache,
678
+ output_attentions=output_attentions,
679
+ )
680
+ hidden_states = cross_attention_outputs[0]
681
+
682
+ # clamp inf values to enable fp16 training
683
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
684
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
685
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
686
+
687
+ # Combine self attn and cross attn key value states
688
+ if present_key_value_state is not None:
689
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
690
+
691
+ # Keep cross-attention outputs and relative position weights
692
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
693
+
694
+ # Apply Feed Forward layer
695
+ hidden_states = self.layer[-1](hidden_states)
696
+
697
+ # clamp inf values to enable fp16 training
698
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
699
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
700
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
701
+
702
+ outputs = (hidden_states,)
703
+
704
+ if use_cache:
705
+ outputs = outputs + (present_key_value_state,) + attention_outputs
706
+ else:
707
+ outputs = outputs + attention_outputs
708
+
709
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
710
+
711
+
712
+ class T5PreTrainedModel(PreTrainedModel):
713
+ """
714
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
715
+ models.
716
+ """
717
+
718
+ config_class = T5Config
719
+ load_tf_weights = load_tf_weights_in_t5
720
+ base_model_prefix = "transformer"
721
+ is_parallelizable = True
722
+ supports_gradient_checkpointing = True
723
+
724
+ @property
725
+ def dummy_inputs(self):
726
+ input_ids = torch.tensor(DUMMY_INPUTS)
727
+ input_mask = torch.tensor(DUMMY_MASK)
728
+ dummy_inputs = {
729
+ "decoder_input_ids": input_ids,
730
+ "input_ids": input_ids,
731
+ "decoder_attention_mask": input_mask,
732
+ }
733
+ return dummy_inputs
734
+
735
+ def _init_weights(self, module):
736
+ """Initialize the weights"""
737
+ factor = self.config.initializer_factor # Used for testing weights initialization
738
+ if isinstance(module, T5LayerNorm):
739
+ module.weight.data.fill_(factor * 1.0)
740
+ elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
741
+ # Mesh TensorFlow embeddings initialization
742
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
743
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
744
+ elif isinstance(module, T5DenseReluDense):
745
+ # Mesh TensorFlow FF initialization
746
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
747
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
748
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
749
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
750
+ module.wi.bias.data.zero_()
751
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
752
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
753
+ module.wo.bias.data.zero_()
754
+ elif isinstance(module, T5DenseGatedGeluDense):
755
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
756
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
757
+ module.wi_0.bias.data.zero_()
758
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
759
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
760
+ module.wi_1.bias.data.zero_()
761
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
762
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
763
+ module.wo.bias.data.zero_()
764
+ elif isinstance(module, T5Attention):
765
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
766
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
767
+ d_model = self.config.d_model
768
+ key_value_proj_dim = self.config.d_kv
769
+ n_heads = self.config.num_heads
770
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
771
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
772
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
773
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
774
+ if module.has_relative_attention_bias:
775
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
776
+
777
+ def _set_gradient_checkpointing(self, module, value=False):
778
+ if isinstance(module, (T5Attention, T5Stack)):
779
+ module.gradient_checkpointing = value
780
+
781
+ def _shift_right(self, input_ids):
782
+ decoder_start_token_id = self.config.decoder_start_token_id
783
+ pad_token_id = self.config.pad_token_id
784
+
785
+ assert (
786
+ decoder_start_token_id is not None
787
+ ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
788
+
789
+ # shift inputs to the right
790
+ if is_torch_fx_proxy(input_ids):
791
+ # Item assignment is not supported natively for proxies.
792
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
793
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
794
+ else:
795
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
796
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
797
+ shifted_input_ids[..., 0] = decoder_start_token_id
798
+
799
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
800
+ # replace possible -100 values in labels by `pad_token_id`
801
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
802
+
803
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
804
+
805
+ return shifted_input_ids
806
+
807
+
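For illustration, the `_shift_right` logic above can be exercised on a toy label tensor. The start and pad token ids below are illustrative (both 0, following the usual T5 convention of starting decoding with the pad token):

```python
import torch

decoder_start_token_id = 0  # illustrative; T5 usually reuses the pad token here
pad_token_id = 0

labels = torch.tensor([[37, 423, 1, -100, -100]])  # -100 marks ignored label positions

# Same steps as _shift_right above: prepend the start token, drop the last position,
# then replace any -100 sentinels with the pad token.
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)

print(shifted)  # tensor([[  0,  37, 423,   1,   0]])
```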
808
+ class T5Stack(T5PreTrainedModel):
809
+ def __init__(self, config, embed_tokens=None):
810
+ super().__init__(config)
811
+
812
+ self.embed_tokens = embed_tokens
813
+ self.is_decoder = config.is_decoder
814
+
815
+ self.block = nn.ModuleList(
816
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
817
+ )
818
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
819
+ self.dropout = nn.Dropout(config.dropout_rate)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+ # Model parallel
824
+ self.model_parallel = False
825
+ self.device_map = None
826
+ self.gradient_checkpointing = False
827
+
828
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
829
+ def parallelize(self, device_map=None):
830
+ # Check validity of device_map
831
+ self.device_map = (
832
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
833
+ )
834
+ assert_device_map(self.device_map, len(self.block))
835
+ self.model_parallel = True
836
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
837
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
838
+ # Load onto devices
839
+ for k, v in self.device_map.items():
840
+ for layer in v:
841
+ cuda_device = "cuda:" + str(k)
842
+ self.block[layer] = self.block[layer].to(cuda_device)
843
+
844
+ # Set embed_tokens to first layer
845
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
846
+ # Set final layer norm to last device
847
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
848
+
849
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
850
+ def deparallelize(self):
851
+ self.model_parallel = False
852
+ self.device_map = None
853
+ self.first_device = "cpu"
854
+ self.last_device = "cpu"
855
+ for i in range(len(self.block)):
856
+ self.block[i] = self.block[i].to("cpu")
857
+ self.embed_tokens = self.embed_tokens.to("cpu")
858
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
859
+ torch.cuda.empty_cache()
860
+
861
+ def get_input_embeddings(self):
862
+ return self.embed_tokens
863
+
864
+ def set_input_embeddings(self, new_embeddings):
865
+ self.embed_tokens = new_embeddings
866
+
867
+ def forward(
868
+ self,
869
+ input_ids=None,
870
+ attention_mask=None,
871
+ encoder_hidden_states=None,
872
+ encoder_attention_mask=None,
873
+ inputs_embeds=None,
874
+ head_mask=None,
875
+ cross_attn_head_mask=None,
876
+ past_key_values=None,
877
+ use_cache=None,
878
+ output_attentions=None,
879
+ output_hidden_states=None,
880
+ return_dict=None,
881
+ ):
882
+ # Model parallel
883
+ if self.model_parallel:
884
+ torch.cuda.set_device(self.first_device)
885
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
886
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ if input_ids is not None and inputs_embeds is not None:
894
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
895
+ raise ValueError(
896
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
897
+ )
898
+ elif input_ids is not None:
899
+ input_shape = input_ids.size()
900
+ input_ids = input_ids.view(-1, input_shape[-1])
901
+ elif inputs_embeds is not None:
902
+ input_shape = inputs_embeds.size()[:-1]
903
+ else:
904
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
905
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
906
+
907
+ if inputs_embeds is None:
908
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
909
+ inputs_embeds = self.embed_tokens(input_ids)
910
+
911
+ batch_size, seq_length = input_shape
912
+
913
+ # required mask seq length can be calculated via length of past
914
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
915
+
916
+ if use_cache is True:
917
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
918
+
919
+ if attention_mask is None:
920
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
921
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
922
+ encoder_seq_length = encoder_hidden_states.shape[1]
923
+ encoder_attention_mask = torch.ones(
924
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
925
+ )
926
+
927
+ # initialize past_key_values with `None` if past does not exist
928
+ if past_key_values is None:
929
+ past_key_values = [None] * len(self.block)
930
+
931
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
932
+ # ourselves in which case we just need to make it broadcastable to all heads.
933
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
934
+
935
+ # If a 2D or 3D attention mask is provided for the cross-attention
936
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
937
+ if self.is_decoder and encoder_hidden_states is not None:
938
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
939
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
940
+ if encoder_attention_mask is None:
941
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
942
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
943
+ else:
944
+ encoder_extended_attention_mask = None
945
+
946
+ # Prepare head mask if needed
947
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
948
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
949
+ present_key_value_states = () if use_cache else None
950
+ all_hidden_states = () if output_hidden_states else None
951
+ all_attentions = () if output_attentions else None
952
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
953
+ position_bias = None
954
+ encoder_decoder_position_bias = None
955
+
956
+ hidden_states = self.dropout(inputs_embeds)
957
+
958
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
959
+ layer_head_mask = head_mask[i]
960
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
961
+ # Model parallel
962
+ if self.model_parallel:
963
+ torch.cuda.set_device(hidden_states.device)
964
+ # Ensure that attention_mask is always on the same device as hidden_states
965
+ if attention_mask is not None:
966
+ attention_mask = attention_mask.to(hidden_states.device)
967
+ if position_bias is not None:
968
+ position_bias = position_bias.to(hidden_states.device)
969
+ if encoder_hidden_states is not None:
970
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
971
+ if encoder_extended_attention_mask is not None:
972
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
973
+ if encoder_decoder_position_bias is not None:
974
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
975
+ if layer_head_mask is not None:
976
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
977
+ if cross_attn_layer_head_mask is not None:
978
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
979
+ if output_hidden_states:
980
+ all_hidden_states = all_hidden_states + (hidden_states,)
981
+
982
+ if self.gradient_checkpointing and self.training:
983
+ if use_cache:
984
+ logger.warning(
985
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
986
+ )
987
+ use_cache = False
988
+
989
+ def create_custom_forward(module):
990
+ def custom_forward(*inputs):
991
+ return tuple(module(*inputs, use_cache, output_attentions))
992
+
993
+ return custom_forward
994
+
995
+ layer_outputs = checkpoint(
996
+ create_custom_forward(layer_module),
997
+ hidden_states,
998
+ extended_attention_mask,
999
+ position_bias,
1000
+ encoder_hidden_states,
1001
+ encoder_extended_attention_mask,
1002
+ encoder_decoder_position_bias,
1003
+ layer_head_mask,
1004
+ cross_attn_layer_head_mask,
1005
+ None, # past_key_value is always None with gradient checkpointing
1006
+ )
1007
+ else:
1008
+ layer_outputs = layer_module(
1009
+ hidden_states,
1010
+ attention_mask=extended_attention_mask,
1011
+ position_bias=position_bias,
1012
+ encoder_hidden_states=encoder_hidden_states,
1013
+ encoder_attention_mask=encoder_extended_attention_mask,
1014
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1015
+ layer_head_mask=layer_head_mask,
1016
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1017
+ past_key_value=past_key_value,
1018
+ use_cache=use_cache,
1019
+ output_attentions=output_attentions,
1020
+ )
1021
+
1022
+ # layer_outputs is a tuple with:
1023
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1024
+ if use_cache is False:
1025
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1026
+
1027
+ hidden_states, present_key_value_state = layer_outputs[:2]
1028
+
1029
+ # We share the position biases between the layers - the first layer stores them
1030
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1031
+ # (cross-attention position bias), (cross-attention weights)
1032
+ position_bias = layer_outputs[2]
1033
+ if self.is_decoder and encoder_hidden_states is not None:
1034
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1035
+ # append next layer key value states
1036
+ if use_cache:
1037
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
1038
+
1039
+ if output_attentions:
1040
+ all_attentions = all_attentions + (layer_outputs[3],)
1041
+ if self.is_decoder:
1042
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1043
+
1044
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1045
+ if self.model_parallel:
1046
+ for k, v in self.device_map.items():
1047
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1048
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1049
+
1050
+ hidden_states = self.final_layer_norm(hidden_states)
1051
+ hidden_states = self.dropout(hidden_states)
1052
+
1053
+ # Add last layer
1054
+ if output_hidden_states:
1055
+ all_hidden_states = all_hidden_states + (hidden_states,)
1056
+
1057
+ if not return_dict:
1058
+ return tuple(
1059
+ v
1060
+ for v in [
1061
+ hidden_states,
1062
+ present_key_value_states,
1063
+ all_hidden_states,
1064
+ all_attentions,
1065
+ all_cross_attentions,
1066
+ ]
1067
+ if v is not None
1068
+ )
1069
+ return BaseModelOutputWithPastAndCrossAttentions(
1070
+ last_hidden_state=hidden_states,
1071
+ past_key_values=present_key_value_states,
1072
+ hidden_states=all_hidden_states,
1073
+ attentions=all_attentions,
1074
+ cross_attentions=all_cross_attentions,
1075
+ )
1076
+
1077
+
1078
+ T5_START_DOCSTRING = r"""
1079
+
1080
+ The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1081
+ Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1082
+ Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
1083
+ text-to-text denoising generative setting.
1084
+
1085
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1086
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1087
+ etc.)
1088
+
1089
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1090
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1091
+ and behavior.
1092
+
1093
+ Parameters:
1094
+ config ([`T5Config`]): Model configuration class with all the parameters of the model.
1095
+ Initializing with a config file does not load the weights associated with the model, only the
1096
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1097
+ """
1098
+
1099
+ T5_INPUTS_DOCSTRING = r"""
1100
+ Args:
1101
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1102
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1103
+ should be able to pad the inputs on both the right and the left.
1104
+
1105
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1106
+ [`PreTrainedTokenizer.__call__`] for details.
1107
+
1108
+ [What are input IDs?](../glossary#input-ids)
1109
+
1110
+ To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
1111
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1112
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1113
+
1114
+ - 1 for tokens that are **not masked**,
1115
+ - 0 for tokens that are **masked**.
1116
+
1117
+ [What are attention masks?](../glossary#attention-mask)
1118
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1119
+ Indices of decoder input sequence tokens in the vocabulary.
1120
+
1121
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1122
+ [`PreTrainedTokenizer.__call__`] for details.
1123
+
1124
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
1125
+
1126
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1127
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1128
+
1129
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
1130
+ Training](./t5#training).
1131
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1132
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
1133
+ be used by default.
1134
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1135
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1136
+ 1]`:
1137
+
1138
+ - 1 indicates the head is **not masked**,
1139
+ - 0 indicates the head is **masked**.
1140
+
1141
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1142
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1143
+ 1]`:
1144
+
1145
+ - 1 indicates the head is **not masked**,
1146
+ - 0 indicates the head is **masked**.
1147
+
1148
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1149
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
1150
+ `[0, 1]`:
1151
+
1152
+ - 1 indicates the head is **not masked**,
1153
+ - 0 indicates the head is **masked**.
1154
+
1155
+ encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
1156
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
1157
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
1158
+ the output of the last layer of the encoder. Used in the cross-attention of the decoder.
1159
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1160
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1161
+
1162
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1163
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1164
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1165
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1166
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1167
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1168
+ model's internal embedding lookup matrix.
1169
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
1170
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
1171
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
1172
+ input (see `past_key_values`). This is useful if you want more control over how to convert
1173
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
1174
+
1175
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
1176
+ of `inputs_embeds`.
1177
+
1178
+ use_cache (`bool`, *optional*):
1179
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1180
+ `past_key_values`).
1181
+
1182
+ output_attentions (`bool`, *optional*):
1183
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1184
+ tensors for more detail.
1185
+ output_hidden_states (`bool`, *optional*):
1186
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1187
+ more detail.
1188
+ return_dict (`bool`, *optional*):
1189
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1190
+ """
1191
+
1192
+ T5_ENCODER_INPUTS_DOCSTRING = r"""
1193
+ Args:
1194
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1195
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1196
+ should be able to pad the inputs on both the right and the left.
1197
+
1198
+ Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
1199
+ [`PreTrainedTokenizer.__call__`] for details.
1200
+
1201
+ To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
1202
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1203
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1204
+
1205
+ - 1 for tokens that are **not masked**,
1206
+ - 0 for tokens that are **masked**.
1207
+
1208
+ [What are attention masks?](../glossary#attention-mask)
1209
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1210
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1211
+
1212
+ - 1 indicates the head is **not masked**,
1213
+ - 0 indicates the head is **masked**.
1214
+
1215
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1216
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1217
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1218
+ model's internal embedding lookup matrix.
1219
+ output_attentions (`bool`, *optional*):
1220
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1221
+ tensors for more detail.
1222
+ output_hidden_states (`bool`, *optional*):
1223
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1224
+ more detail.
1225
+ return_dict (`bool`, *optional*):
1226
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1227
+ """
1228
+
1229
+ # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1230
+ __HEAD_MASK_WARNING_MSG = """
1231
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
1232
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
1233
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
1234
+ num_heads)`.
1235
+ """
1236
+
1237
+
1238
+ @add_start_docstrings(
1239
+ "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
1240
+ T5_START_DOCSTRING,
1241
+ )
1242
+ class T5Model(T5PreTrainedModel):
1243
+ _keys_to_ignore_on_load_missing = [
1244
+ r"encoder\.embed_tokens\.weight",
1245
+ r"decoder\.embed_tokens\.weight",
1246
+ ]
1247
+ _keys_to_ignore_on_load_unexpected = [
1248
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1249
+ ]
1250
+
1251
+ def __init__(self, config: T5Config):
1252
+ super().__init__(config)
1253
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1254
+
1255
+ encoder_config = copy.deepcopy(config)
1256
+ encoder_config.is_decoder = False
1257
+ encoder_config.use_cache = False
1258
+ encoder_config.is_encoder_decoder = False
1259
+ self.encoder = T5Stack(encoder_config, self.shared)
1260
+
1261
+ decoder_config = copy.deepcopy(config)
1262
+ decoder_config.is_decoder = True
1263
+ decoder_config.is_encoder_decoder = False
1264
+ decoder_config.num_layers = config.num_decoder_layers
1265
+ self.decoder = T5Stack(decoder_config, self.shared)
1266
+
1267
+ # Initialize weights and apply final processing
1268
+ self.post_init()
1269
+
1270
+ # Model parallel
1271
+ self.model_parallel = False
1272
+ self.device_map = None
1273
+
1274
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1275
+ def parallelize(self, device_map=None):
1276
+ self.device_map = (
1277
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1278
+ if device_map is None
1279
+ else device_map
1280
+ )
1281
+ assert_device_map(self.device_map, len(self.encoder.block))
1282
+ self.encoder.parallelize(self.device_map)
1283
+ self.decoder.parallelize(self.device_map)
1284
+ self.model_parallel = True
1285
+
1286
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1287
+ def deparallelize(self):
1288
+ self.encoder.deparallelize()
1289
+ self.decoder.deparallelize()
1290
+ self.encoder = self.encoder.to("cpu")
1291
+ self.decoder = self.decoder.to("cpu")
1292
+ self.model_parallel = False
1293
+ self.device_map = None
1294
+ torch.cuda.empty_cache()
1295
+
1296
+ def get_input_embeddings(self):
1297
+ return self.shared
1298
+
1299
+ def set_input_embeddings(self, new_embeddings):
1300
+ self.shared = new_embeddings
1301
+ self.encoder.set_input_embeddings(new_embeddings)
1302
+ self.decoder.set_input_embeddings(new_embeddings)
1303
+
1304
+ def get_encoder(self):
1305
+ return self.encoder
1306
+
1307
+ def get_decoder(self):
1308
+ return self.decoder
1309
+
1310
+ def _prune_heads(self, heads_to_prune):
1311
+ """
1312
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1313
+ class PreTrainedModel
1314
+ """
1315
+ for layer, heads in heads_to_prune.items():
1316
+ self.encoder.layer[layer].attention.prune_heads(heads)
1317
+
1318
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1319
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
1320
+ def forward(
1321
+ self,
1322
+ input_ids=None,
1323
+ attention_mask=None,
1324
+ decoder_input_ids=None,
1325
+ decoder_attention_mask=None,
1326
+ head_mask=None,
1327
+ decoder_head_mask=None,
1328
+ cross_attn_head_mask=None,
1329
+ encoder_outputs=None,
1330
+ past_key_values=None,
1331
+ inputs_embeds=None,
1332
+ decoder_inputs_embeds=None,
1333
+ use_cache=None,
1334
+ output_attentions=None,
1335
+ output_hidden_states=None,
1336
+ return_dict=None,
1337
+ ):
1338
+ r"""
1339
+ Returns:
1340
+
1341
+ Example:
1342
+
1343
+ ```python
1344
+ >>> from transformers import T5Tokenizer, T5Model
1345
+
1346
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1347
+ >>> model = T5Model.from_pretrained("t5-small")
1348
+
1349
+ >>> input_ids = tokenizer(
1350
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1351
+ ... ).input_ids # Batch size 1
1352
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1353
+
1354
+ >>> # forward pass
1355
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1356
+ >>> last_hidden_states = outputs.last_hidden_state
1357
+ ```"""
1358
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1359
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360
+
1361
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1362
+ if head_mask is not None and decoder_head_mask is None:
1363
+ if self.config.num_layers == self.config.num_decoder_layers:
1364
+ decoder_head_mask = head_mask
1365
+
1366
+ # Encode if needed (training, first prediction pass)
1367
+ if encoder_outputs is None:
1368
+ encoder_outputs = self.encoder(
1369
+ input_ids=input_ids,
1370
+ attention_mask=attention_mask,
1371
+ inputs_embeds=inputs_embeds,
1372
+ head_mask=head_mask,
1373
+ output_attentions=output_attentions,
1374
+ output_hidden_states=output_hidden_states,
1375
+ return_dict=return_dict,
1376
+ )
1377
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1378
+ encoder_outputs = BaseModelOutput(
1379
+ last_hidden_state=encoder_outputs[0],
1380
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1381
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1382
+ )
1383
+
1384
+ hidden_states = encoder_outputs[0]
1385
+ if self.model_parallel:
1386
+ torch.cuda.set_device(self.decoder.first_device)
1387
+ # Set device for model parallelism
1388
+ if self.model_parallel:
1389
+ torch.cuda.set_device(self.decoder.first_device)
1390
+ hidden_states = hidden_states.to(self.decoder.first_device)
1391
+ if decoder_input_ids is not None:
1392
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1393
+ if attention_mask is not None:
1394
+ attention_mask = attention_mask.to(self.decoder.first_device)
1395
+ if decoder_attention_mask is not None:
1396
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1397
+
1398
+ # Decode
1399
+ decoder_outputs = self.decoder(
1400
+ input_ids=decoder_input_ids,
1401
+ attention_mask=decoder_attention_mask,
1402
+ inputs_embeds=decoder_inputs_embeds,
1403
+ past_key_values=past_key_values,
1404
+ encoder_hidden_states=hidden_states,
1405
+ encoder_attention_mask=attention_mask,
1406
+ head_mask=decoder_head_mask,
1407
+ cross_attn_head_mask=cross_attn_head_mask,
1408
+ use_cache=use_cache,
1409
+ output_attentions=output_attentions,
1410
+ output_hidden_states=output_hidden_states,
1411
+ return_dict=return_dict,
1412
+ )
1413
+
1414
+ if not return_dict:
1415
+ return decoder_outputs + encoder_outputs
1416
+
1417
+ return Seq2SeqModelOutput(
1418
+ last_hidden_state=decoder_outputs.last_hidden_state,
1419
+ past_key_values=decoder_outputs.past_key_values,
1420
+ decoder_hidden_states=decoder_outputs.hidden_states,
1421
+ decoder_attentions=decoder_outputs.attentions,
1422
+ cross_attentions=decoder_outputs.cross_attentions,
1423
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1424
+ encoder_hidden_states=encoder_outputs.hidden_states,
1425
+ encoder_attentions=encoder_outputs.attentions,
1426
+ )
1427
+
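Because `forward` above only runs the encoder when `encoder_outputs` is not supplied, the encoder pass can be computed once and reused across several decoder calls. A minimal sketch with the stock `t5-small` checkpoint (an illustration, not something shipped in this commit):

```python
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
# Run the encoder once and keep its output.
encoder_outputs = model.get_encoder()(input_ids=enc.input_ids, attention_mask=enc.attention_mask)

# Reuse the cached encoder states; only the decoder runs in each call below.
for prefix in ["Studies show", "Studies show that owning a dog"]:
    decoder_input_ids = tokenizer(prefix, return_tensors="pt").input_ids
    outputs = model(
        encoder_outputs=encoder_outputs,
        attention_mask=enc.attention_mask,
        decoder_input_ids=decoder_input_ids,
    )
    print(outputs.last_hidden_state.shape)
```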
1428
+
1429
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
1430
+ class T5ForConditionalGeneration(T5PreTrainedModel):
1431
+ _keys_to_ignore_on_load_missing = [
1432
+ r"encoder\.embed_tokens\.weight",
1433
+ r"decoder\.embed_tokens\.weight",
1434
+ r"lm_head\.weight",
1435
+ ]
1436
+ _keys_to_ignore_on_load_unexpected = [
1437
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1438
+ ]
1439
+
1440
+ def __init__(self, config):
1441
+ super().__init__(config)
1442
+ self.model_dim = config.d_model
1443
+
1444
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1445
+
1446
+ encoder_config = copy.deepcopy(config)
1447
+ encoder_config.is_decoder = False
1448
+ encoder_config.use_cache = False
1449
+ encoder_config.is_encoder_decoder = False
1450
+ self.encoder = T5Stack(encoder_config, self.shared)
1451
+
1452
+ decoder_config = copy.deepcopy(config)
1453
+ decoder_config.is_decoder = True
1454
+ decoder_config.is_encoder_decoder = False
1455
+ decoder_config.num_layers = config.num_decoder_layers
1456
+ self.decoder = T5Stack(decoder_config, self.shared)
1457
+
1458
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1459
+
1460
+ # Initialize weights and apply final processing
1461
+ self.post_init()
1462
+
1463
+ # Model parallel
1464
+ self.model_parallel = False
1465
+ self.device_map = None
1466
+
1467
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1468
+ def parallelize(self, device_map=None):
1469
+ self.device_map = (
1470
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1471
+ if device_map is None
1472
+ else device_map
1473
+ )
1474
+ assert_device_map(self.device_map, len(self.encoder.block))
1475
+ self.encoder.parallelize(self.device_map)
1476
+ self.decoder.parallelize(self.device_map)
1477
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1478
+ self.model_parallel = True
1479
+
1480
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1481
+ def deparallelize(self):
1482
+ self.encoder.deparallelize()
1483
+ self.decoder.deparallelize()
1484
+ self.encoder = self.encoder.to("cpu")
1485
+ self.decoder = self.decoder.to("cpu")
1486
+ self.lm_head = self.lm_head.to("cpu")
1487
+ self.model_parallel = False
1488
+ self.device_map = None
1489
+ torch.cuda.empty_cache()
1490
+
1491
+ def get_input_embeddings(self):
1492
+ return self.shared
1493
+
1494
+ def set_input_embeddings(self, new_embeddings):
1495
+ self.shared = new_embeddings
1496
+ self.encoder.set_input_embeddings(new_embeddings)
1497
+ self.decoder.set_input_embeddings(new_embeddings)
1498
+
1499
+ def set_output_embeddings(self, new_embeddings):
1500
+ self.lm_head = new_embeddings
1501
+
1502
+ def get_output_embeddings(self):
1503
+ return self.lm_head
1504
+
1505
+ def get_encoder(self):
1506
+ return self.encoder
1507
+
1508
+ def get_decoder(self):
1509
+ return self.decoder
1510
+
1511
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
1512
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1513
+ def forward(
1514
+ self,
1515
+ input_ids=None,
1516
+ attention_mask=None,
1517
+ decoder_input_ids=None,
1518
+ decoder_attention_mask=None,
1519
+ head_mask=None,
1520
+ decoder_head_mask=None,
1521
+ cross_attn_head_mask=None,
1522
+ encoder_outputs=None,
1523
+ past_key_values=None,
1524
+ inputs_embeds=None,
1525
+ decoder_inputs_embeds=None,
1526
+ labels=None,
1527
+ use_cache=None,
1528
+ output_attentions=None,
1529
+ output_hidden_states=None,
1530
+ return_dict=None,
1531
+ ):
1532
+ r"""
1533
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1534
+ Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
1535
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
1536
+ labels in `[0, ..., config.vocab_size - 1]`.
1537
+
1538
+ Returns:
1539
+
1540
+ Examples:
1541
+
1542
+ ```python
1543
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1544
+
1545
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1546
+ >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
1547
+
1548
+ >>> # training
1549
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1550
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1551
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1552
+ >>> loss = outputs.loss
1553
+ >>> logits = outputs.logits
1554
+
1555
+ >>> # inference
1556
+ >>> input_ids = tokenizer(
1557
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1558
+ ... ).input_ids # Batch size 1
1559
+ >>> outputs = model.generate(input_ids)
1560
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1561
+ >>> # studies have shown that owning a dog is good for you.
1562
+ ```"""
1563
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1564
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1565
+
1566
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1567
+ if head_mask is not None and decoder_head_mask is None:
1568
+ if self.config.num_layers == self.config.num_decoder_layers:
1569
+ decoder_head_mask = head_mask
1570
+
1571
+ # Encode if needed (training, first prediction pass)
1572
+ if encoder_outputs is None:
1573
+ # Convert encoder inputs in embeddings if needed
1574
+ encoder_outputs = self.encoder(
1575
+ input_ids=input_ids,
1576
+ attention_mask=attention_mask,
1577
+ inputs_embeds=inputs_embeds,
1578
+ head_mask=head_mask,
1579
+ output_attentions=output_attentions,
1580
+ output_hidden_states=output_hidden_states,
1581
+ return_dict=return_dict,
1582
+ )
1583
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1584
+ encoder_outputs = BaseModelOutput(
1585
+ last_hidden_state=encoder_outputs[0],
1586
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1587
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1588
+ )
1589
+
1590
+ hidden_states = encoder_outputs[0]
1591
+
1592
+ if self.model_parallel:
1593
+ torch.cuda.set_device(self.decoder.first_device)
1594
+
1595
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1596
+ # get decoder inputs from shifting lm labels to the right
1597
+ decoder_input_ids = self._shift_right(labels)
1598
+
1599
+ # Set device for model parallelism
1600
+ if self.model_parallel:
1601
+ torch.cuda.set_device(self.decoder.first_device)
1602
+ hidden_states = hidden_states.to(self.decoder.first_device)
1603
+ if decoder_input_ids is not None:
1604
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1605
+ if attention_mask is not None:
1606
+ attention_mask = attention_mask.to(self.decoder.first_device)
1607
+ if decoder_attention_mask is not None:
1608
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1609
+
1610
+ # Decode
1611
+ decoder_outputs = self.decoder(
1612
+ input_ids=decoder_input_ids,
1613
+ attention_mask=decoder_attention_mask,
1614
+ inputs_embeds=decoder_inputs_embeds,
1615
+ past_key_values=past_key_values,
1616
+ encoder_hidden_states=hidden_states,
1617
+ encoder_attention_mask=attention_mask,
1618
+ head_mask=decoder_head_mask,
1619
+ cross_attn_head_mask=cross_attn_head_mask,
1620
+ use_cache=use_cache,
1621
+ output_attentions=output_attentions,
1622
+ output_hidden_states=output_hidden_states,
1623
+ return_dict=return_dict,
1624
+ )
1625
+
1626
+ sequence_output = decoder_outputs[0]
1627
+
1628
+ # Set device for model parallelism
1629
+ if self.model_parallel:
1630
+ torch.cuda.set_device(self.encoder.first_device)
1631
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1632
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1633
+
1634
+ if self.config.tie_word_embeddings:
1635
+ # Rescale output before projecting on vocab
1636
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1637
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
1638
+
1639
+ lm_logits = self.lm_head(sequence_output)
1640
+
1641
+ loss = None
1642
+ if labels is not None:
1643
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1644
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1645
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1646
+
1647
+ if not return_dict:
1648
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1649
+ return ((loss,) + output) if loss is not None else output
1650
+
1651
+ return Seq2SeqLMOutput(
1652
+ loss=loss,
1653
+ logits=lm_logits,
1654
+ past_key_values=decoder_outputs.past_key_values,
1655
+ decoder_hidden_states=decoder_outputs.hidden_states,
1656
+ decoder_attentions=decoder_outputs.attentions,
1657
+ cross_attentions=decoder_outputs.cross_attentions,
1658
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1659
+ encoder_hidden_states=encoder_outputs.hidden_states,
1660
+ encoder_attentions=encoder_outputs.attentions,
1661
+ )
1662
+
1663
+ def prepare_inputs_for_generation(
1664
+ self,
1665
+ input_ids,
1666
+ past=None,
1667
+ attention_mask=None,
1668
+ head_mask=None,
1669
+ decoder_head_mask=None,
1670
+ cross_attn_head_mask=None,
1671
+ use_cache=None,
1672
+ encoder_outputs=None,
1673
+ **kwargs
1674
+ ):
1675
+
1676
+ # cut decoder_input_ids if past is used
1677
+ if past is not None:
1678
+ input_ids = input_ids[:, -1:]
1679
+
1680
+ return {
1681
+ "decoder_input_ids": input_ids,
1682
+ "past_key_values": past,
1683
+ "encoder_outputs": encoder_outputs,
1684
+ "attention_mask": attention_mask,
1685
+ "head_mask": head_mask,
1686
+ "decoder_head_mask": decoder_head_mask,
1687
+ "cross_attn_head_mask": cross_attn_head_mask,
1688
+ "use_cache": use_cache,
1689
+ }
1690
+
1691
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1692
+ return self._shift_right(labels)
1693
+
1694
+ def _reorder_cache(self, past, beam_idx):
1695
+ # if decoder past is not included in output
1696
+ # speedy decoding is disabled and no need to reorder
1697
+ if past is None:
1698
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1699
+ return past
1700
+
1701
+ reordered_decoder_past = ()
1702
+ for layer_past_states in past:
1703
+ # get the correct batch idx from layer past batch dim
1704
+ # the cached key / value states have shape (batch_size, num_heads, seq_len, head_dim), so batch is dim 0
1705
+ reordered_layer_past_states = ()
1706
+ for layer_past_state in layer_past_states:
1707
+ # need to set correct `past` for each of the four key / value states
1708
+ reordered_layer_past_states = reordered_layer_past_states + (
1709
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1710
+ )
1711
+
1712
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1713
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1714
+
1715
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1716
+ return reordered_decoder_past
1717
+
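The `-100` convention described in the `labels` docstring above lets padded target positions be dropped from the loss. A short sketch of preparing such labels by hand, assuming the stock `t5-small` checkpoint (illustration only, not part of the committed file):

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer(
    ["translate English to German: Hello.", "translate English to German: How are you today?"],
    padding=True,
    return_tensors="pt",
)
targets = tokenizer(["Hallo.", "Wie geht es dir heute?"], padding=True, return_tensors="pt")

# Replace pad positions in the targets with -100 so CrossEntropyLoss ignores them.
labels = targets.input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100

outputs = model(**inputs, labels=labels)
print(outputs.loss)
```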
1718
+
1719
+ @add_start_docstrings(
1720
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
1721
+ T5_START_DOCSTRING,
1722
+ )
1723
+ class T5EncoderModel(T5PreTrainedModel):
1724
+ authorized_missing_keys = [
1725
+ r"encoder\.embed_tokens\.weight",
1726
+ ]
1727
+
1728
+ def __init__(self, config: T5Config):
1729
+ super().__init__(config)
1730
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1731
+
1732
+ encoder_config = copy.deepcopy(config)
1733
+ encoder_config.use_cache = False
1734
+ encoder_config.is_encoder_decoder = False
1735
+ self.encoder = T5Stack(encoder_config, self.shared)
1736
+
1737
+ # Initialize weights and apply final processing
1738
+ self.post_init()
1739
+
1740
+ # Model parallel
1741
+ self.model_parallel = False
1742
+ self.device_map = None
1743
+
1744
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1745
+ def parallelize(self, device_map=None):
1746
+ self.device_map = (
1747
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1748
+ if device_map is None
1749
+ else device_map
1750
+ )
1751
+ assert_device_map(self.device_map, len(self.encoder.block))
1752
+ self.encoder.parallelize(self.device_map)
1753
+ self.model_parallel = True
1754
+
1755
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1756
+ def deparallelize(self):
1757
+ self.encoder.deparallelize()
1758
+ self.encoder = self.encoder.to("cpu")
1759
+ self.model_parallel = False
1760
+ self.device_map = None
1761
+ torch.cuda.empty_cache()
1762
+
1763
+ def get_input_embeddings(self):
1764
+ return self.shared
1765
+
1766
+ def set_input_embeddings(self, new_embeddings):
1767
+ self.shared = new_embeddings
1768
+ self.encoder.set_input_embeddings(new_embeddings)
1769
+
1770
+ def get_encoder(self):
1771
+ return self.encoder
1772
+
1773
+ def _prune_heads(self, heads_to_prune):
1774
+ """
1775
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1776
+ class PreTrainedModel
1777
+ """
1778
+ for layer, heads in heads_to_prune.items():
1779
+ self.encoder.layer[layer].attention.prune_heads(heads)
1780
+
1781
+ @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
1782
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
1783
+ def forward(
1784
+ self,
1785
+ input_ids=None,
1786
+ attention_mask=None,
1787
+ head_mask=None,
1788
+ inputs_embeds=None,
1789
+ output_attentions=None,
1790
+ output_hidden_states=None,
1791
+ return_dict=None,
1792
+ ):
1793
+ r"""
1794
+ Returns:
1795
+
1796
+ Example:
1797
+
1798
+ ```python
1799
+ >>> from transformers import T5Tokenizer, T5EncoderModel
1800
+
1801
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1802
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
1803
+ >>> input_ids = tokenizer(
1804
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1805
+ ... ).input_ids # Batch size 1
1806
+ >>> outputs = model(input_ids=input_ids)
1807
+ >>> last_hidden_states = outputs.last_hidden_state
1808
+ ```"""
1809
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1810
+
1811
+ encoder_outputs = self.encoder(
1812
+ input_ids=input_ids,
1813
+ attention_mask=attention_mask,
1814
+ inputs_embeds=inputs_embeds,
1815
+ head_mask=head_mask,
1816
+ output_attentions=output_attentions,
1817
+ output_hidden_states=output_hidden_states,
1818
+ return_dict=return_dict,
1819
+ )
1820
+
1821
+ return encoder_outputs
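All three classes above expose the same legacy `parallelize`/`deparallelize` pair driven by a `device_map` from GPU id to block indices. The sketch below assumes a machine with at least two CUDA devices and the stock `t5-small` checkpoint (which has 6 encoder blocks); it is an illustration, not part of the committed file.

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Spread the 6 blocks over two GPUs: blocks 0-2 on cuda:0, blocks 3-5 on cuda:1.
device_map = {0: [0, 1, 2], 1: [3, 4, 5]}
model.parallelize(device_map)

input_ids = tokenizer("summarize: a naive model parallelism demo", return_tensors="pt").input_ids
input_ids = input_ids.to(model.encoder.first_device)
generated = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.decode(generated[0], skip_special_tokens=True))

model.deparallelize()  # moves everything back to CPU and empties the CUDA cache
```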
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e04b0874e38072331eb99efe6cdd4759268b8e516a23fce6bac21aa7687b1887
3
+ size 6845775809