Text Generation
Transformers
PyTorch
Safetensors
English
gpt_refact
code
custom_code
Eval Results
svakhreev commited on
Commit
a16281b
1 Parent(s): fb08260

Upload GPTRefactForCausalLM

Browse files
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "GPTRefactForCausalLM"
4
+ ],
5
+ "attention_softmax_in_fp32": false,
6
+ "attn_pdrop": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_gpt_refact.GPTRefactConfig",
9
+ "AutoModelForCausalLM": "modeling_gpt_refact.GPTRefactForCausalLM"
10
+ },
11
+ "bos_token_id": -1,
12
+ "do_sample": true,
13
+ "embd_pdrop": 0.1,
14
+ "eos_token_id": 0,
15
+ "initializer_range": 0.02,
16
+ "layer_norm_epsilon": 1e-05,
17
+ "model_type": "gpt_refact",
18
+ "multi_query": true,
19
+ "n_embd": 2048,
20
+ "n_head": 32,
21
+ "n_inner": null,
22
+ "n_layer": 32,
23
+ "n_positions": 4096,
24
+ "resid_pdrop": 0.1,
25
+ "scale_attention_softmax_in_fp32": false,
26
+ "scale_attn_weights": true,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.31.0",
29
+ "use_cache": true,
30
+ "vocab_size": 49216
31
+ }
configuration_gpt_refact.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+
8
+ class GPTRefactConfig(PretrainedConfig):
9
+ model_type = "gpt_refact"
10
+ keys_to_ignore_at_inference = ["past_key_values"]
11
+ attribute_map = {
12
+ "hidden_size": "n_embd",
13
+ "max_position_embeddings": "n_positions",
14
+ "num_attention_heads": "n_head",
15
+ "num_hidden_layers": "n_layer",
16
+ }
17
+
18
+ def __init__(
19
+ self,
20
+ vocab_size: int = 49216,
21
+ n_positions: int = 4096,
22
+ n_embd: int = 1024,
23
+ n_layer: int = 32,
24
+ n_head: int = 64,
25
+ max_position_embeddings: int = 4096,
26
+ multi_query: bool = True,
27
+ layer_norm_epsilon=1e-5,
28
+ initializer_range=0.02,
29
+ scale_attn_weights=True,
30
+ use_cache=True,
31
+ bos_token_id=-1,
32
+ eos_token_id=0,
33
+ attention_softmax_in_fp32=False,
34
+ scale_attention_softmax_in_fp32=False,
35
+ resid_pdrop=0.1,
36
+ embd_pdrop=0.1,
37
+ attn_pdrop=0.1,
38
+ **kwargs,
39
+ ):
40
+ self.vocab_size = vocab_size
41
+ self.n_positions = n_positions
42
+ self.n_embd = n_embd
43
+ self.n_layer = n_layer
44
+ self.n_head = n_head
45
+ self.n_inner = None
46
+ self.resid_pdrop = resid_pdrop
47
+ self.embd_pdrop = embd_pdrop
48
+ self.attn_pdrop = attn_pdrop
49
+ self.layer_norm_epsilon = layer_norm_epsilon
50
+ self.initializer_range = initializer_range
51
+ self.scale_attn_weights = scale_attn_weights
52
+ self.use_cache = use_cache
53
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
54
+ self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
55
+
56
+ self.bos_token_id = bos_token_id
57
+ self.eos_token_id = eos_token_id
58
+
59
+ self.multi_query = multi_query
60
+ self.max_position_embeddings = max_position_embeddings
61
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": -1,
4
+ "do_sample": true,
5
+ "eos_token_id": 0,
6
+ "transformers_version": "4.31.0"
7
+ }
modeling_gpt_refact.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers.modeling_outputs import (
8
+ BaseModelOutputWithPastAndCrossAttentions,
9
+ CausalLMOutputWithCrossAttentions,
10
+ )
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.utils import (
13
+ logging,
14
+ )
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ from .configuration_gpt_refact import GPTRefactConfig
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ @torch.jit.script
23
+ def upcast_masked_softmax(
24
+ x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
25
+ ):
26
+ input_dtype = x.dtype
27
+ x = x.to(softmax_dtype) * scale
28
+ x = torch.where(mask, x, mask_value)
29
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
30
+ return x
31
+
32
+
33
+ @torch.jit.script
34
+ def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
35
+ input_dtype = x.dtype
36
+ x = x.to(softmax_dtype) * scale
37
+ x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
38
+ return x
39
+
40
+
41
+ @torch.jit.script
42
+ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
43
+ x = torch.where(mask, x, mask_value)
44
+ x = torch.nn.functional.softmax(x, dim=-1)
45
+ return x
46
+
47
+ @torch.jit.script
48
+ def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
49
+ """
50
+ ## Get head-specific slope $m$ for each head
51
+ * `n_heads` is the number of heads in the attention layer $n$
52
+ The slope for first head is
53
+ $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
54
+ The slopes for the rest of the heads are in a geometric series with a ratio same as above.
55
+ For instance when the number of heads is $8$ the slopes are
56
+ $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
57
+ """
58
+
59
+ # Get the closest power of 2 to `n_heads`.
60
+ # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
61
+ # and then add the remaining slopes.
62
+ n = 2 ** math.floor(math.log(attn_heads, 2))
63
+ # $2^{-\frac{8}{n}}$
64
+ m_0 = 2.0 ** (-8.0 / n)
65
+ # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
66
+ m = torch.pow(m_0, torch.arange(1, 1 + n, device=dev))
67
+
68
+ # If `n_heads` is not a power of 2, then we add the remaining slopes.
69
+ # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
70
+ # And pick the slopes upto `n_heads`.
71
+ if n < attn_heads:
72
+ # $2^{-\frac{8}{2n}}$
73
+ m_hat_0 = 2.0 ** (-4.0 / n)
74
+ # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
75
+ # Note that we take steps by $2$ to avoid slopes added previously.
76
+ m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
77
+ # Concatenate the slopes with the remaining slopes.
78
+ m = torch.cat([m, m_hat])
79
+
80
+ return m
81
+
82
+ @torch.jit.script
83
+ def get_alibi_biases(
84
+ B: int,
85
+ T: int,
86
+ attn_heads: int,
87
+ dev: torch.device,
88
+ dtype: torch.dtype,
89
+ causal: bool = True) -> torch.Tensor:
90
+ """
91
+ ## Calculate the attention biases matrix
92
+ * `n_heads` is the number of heads in the attention layer
93
+ * `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`
94
+ This returns a matrix of shape `[seq_len_q, seq_len_k, n_heads, ]` with ALiBi attention biases.
95
+ """
96
+
97
+ # Get slopes $m$ for each head
98
+ if causal:
99
+ mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
100
+ else:
101
+ mask = torch.ones((T, T), device=dev, dtype=torch.bool)
102
+
103
+ m = _get_slopes(attn_heads, dev)
104
+
105
+ # Calculate distances $[0, 1, \dots, N]$
106
+ # Here we calculate the distances using the mask.
107
+ #
108
+ # Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
109
+ # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
110
+ distance = mask.cumsum(dim=-1)
111
+
112
+ # Multiply them pair-wise to get the AliBi bias matrix
113
+ biases = distance[:, :, None] * m[None, None, :]
114
+ biases = biases.permute(2, 0, 1)[None, :, :T, :T]
115
+ biases = biases.repeat(B, 1, 1, 1)
116
+ return biases.to(dtype).contiguous()
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(self, config, layer_idx=None):
121
+ super().__init__()
122
+ self.mask_value = None
123
+
124
+ self.embed_dim = config.hidden_size
125
+ self.num_heads = config.num_attention_heads
126
+ self.head_dim = self.embed_dim // self.num_heads
127
+ self.kv_attn_heads = 1
128
+
129
+ self.scale = self.head_dim ** -0.5
130
+
131
+ if self.head_dim * self.num_heads != self.embed_dim:
132
+ raise ValueError(
133
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
134
+ f" {self.num_heads})."
135
+ )
136
+
137
+ self.layer_idx = layer_idx
138
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
139
+ self.scale_attention_softmax_in_fp32 = (
140
+ config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
141
+ )
142
+
143
+ self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
144
+ self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
145
+ self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
146
+ self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
147
+
148
+ def _attn(self, query, key, value, attention_mask=None, alibi=None):
149
+ dtype = query.dtype
150
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
151
+ upcast = dtype != softmax_dtype
152
+ unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
153
+
154
+ attn_weights = alibi + torch.matmul(query * self.scale, key)
155
+
156
+ if upcast:
157
+ if attention_mask is None:
158
+ attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
159
+ else:
160
+ mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
161
+ attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
162
+ else:
163
+ if attention_mask is not None:
164
+ attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)
165
+
166
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
167
+
168
+ attn_output = torch.matmul(attn_weights, value)
169
+
170
+ return attn_output, attn_weights
171
+
172
+ def _split_heads(self, tensor):
173
+ new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
174
+ tensor = tensor.view(new_shape)
175
+ return tensor.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self,
179
+ hidden_states: torch.Tensor,
180
+ layer_past: Optional[torch.Tensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ alibi: Optional[torch.Tensor] = None,
183
+ use_cache: Optional[bool] = False,
184
+ output_attentions: Optional[bool] = False,
185
+ ) -> Union[
186
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
187
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
188
+ ]:
189
+ b, t, _ = hidden_states.shape
190
+ query = self.q(hidden_states)
191
+ key = self.k(hidden_states)
192
+ value = self.v(hidden_states)
193
+ query = self._split_heads(query)
194
+ key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
195
+ value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
196
+
197
+ if layer_past is not None:
198
+ past_key, past_value = layer_past
199
+ key = torch.cat((past_key, key), dim=-2)
200
+ value = torch.cat((past_value, value), dim=-2)
201
+
202
+ if use_cache is True:
203
+ present = (key, value)
204
+ else:
205
+ present = None
206
+
207
+ attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
208
+
209
+ attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
210
+ attn_output = self.c_proj(attn_output)
211
+
212
+ outputs = (attn_output, present)
213
+ if output_attentions:
214
+ outputs += (attn_weights,)
215
+
216
+ return outputs # a, present, (attentions)
217
+
218
+
219
+ class MLP(nn.Module):
220
+ def __init__(self, intermediate_size, config, multiple_of: int = 256):
221
+ super().__init__()
222
+ embed_dim = config.hidden_size
223
+ hidden_dim = intermediate_size
224
+ hidden_dim = int(2 * hidden_dim / 3)
225
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
226
+ self.linear_1 = nn.Linear(embed_dim, hidden_dim, bias=False)
227
+ self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
228
+ self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)
229
+
230
+ def forward(self, x: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
231
+ x1 = F.silu(self.linear_1(x))
232
+ x2 = self.linear_3(x)
233
+ x = self.c_proj(x1 * x2)
234
+ return x
235
+
236
+
237
+ class LayerNormNoBias(nn.Module):
238
+
239
+ def __init__(self, shape: int, eps: float = 1e-5):
240
+ super().__init__()
241
+ self.shape = (shape,)
242
+ self.eps = eps
243
+ self.weight = nn.Parameter(torch.empty(self.shape))
244
+
245
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
246
+ return F.layer_norm(x, self.shape, self.weight, None, self.eps)
247
+
248
+
249
+ class GPTRefactBlock(nn.Module):
250
+ def __init__(self, config, layer_idx=None):
251
+ super().__init__()
252
+ hidden_size = config.hidden_size
253
+ self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
254
+
255
+ self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
256
+ self.attn = Attention(config, layer_idx=layer_idx)
257
+ self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
258
+
259
+ self.mlp = MLP(self.inner_dim, config)
260
+
261
+ def forward(
262
+ self,
263
+ hidden_states: Optional[Tuple[torch.Tensor]],
264
+ layer_past: Optional[torch.Tensor] = None,
265
+ attention_mask: Optional[torch.Tensor] = None,
266
+ alibi: Optional[torch.Tensor] = None,
267
+ use_cache: Optional[bool] = False,
268
+ output_attentions: Optional[bool] = False,
269
+ ) -> Union[
270
+ Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
271
+ ]:
272
+ hidden_states_norm = self.ln_1(hidden_states)
273
+ attn_outputs = self.attn(
274
+ hidden_states_norm,
275
+ layer_past=layer_past,
276
+ attention_mask=attention_mask,
277
+ alibi=alibi,
278
+ use_cache=use_cache,
279
+ output_attentions=output_attentions,
280
+ )
281
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
282
+ outputs = attn_outputs[1:]
283
+ # residual connection
284
+ mix = attn_output + hidden_states
285
+
286
+ norm_mix = self.ln_2(mix)
287
+ feed_forward_hidden_states = self.mlp(norm_mix)
288
+ # residual connection
289
+ hidden_states = mix + feed_forward_hidden_states
290
+
291
+ if use_cache:
292
+ outputs = (hidden_states,) + outputs
293
+ else:
294
+ outputs = (hidden_states,) + outputs[1:]
295
+
296
+ return outputs # hidden_states, present, (attentions, cross_attentions)
297
+
298
+
299
+ class GPTRefactPreTrainedModel(PreTrainedModel):
300
+ config_class = GPTRefactConfig
301
+ base_model_prefix = "transformer"
302
+ supports_gradient_checkpointing = True
303
+ _no_split_modules = ["GPTRefactBlock"]
304
+ _skip_keys_device_placement = "past_key_values"
305
+
306
+ def __init__(self, *inputs, **kwargs):
307
+ super().__init__(*inputs, **kwargs)
308
+
309
+ def _init_weights(self, module):
310
+ if isinstance(module, (MLP, Attention)):
311
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
312
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
313
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
314
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
315
+ #
316
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
317
+ module.c_proj.weight.data.normal_(
318
+ mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
319
+ )
320
+ module.c_proj._is_hf_initialized = True
321
+ elif isinstance(module, nn.Linear):
322
+ # Slightly different from the TF version which uses truncated_normal for initialization
323
+ # cf https://github.com/pytorch/pytorch/pull/5617
324
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
325
+ if module.bias is not None:
326
+ module.bias.data.zero_()
327
+ elif isinstance(module, nn.Embedding):
328
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
329
+ if module.padding_idx is not None:
330
+ module.weight.data[module.padding_idx].zero_()
331
+ elif isinstance(module, LayerNormNoBias):
332
+ module.weight.data.fill_(1.0)
333
+
334
+ def _set_gradient_checkpointing(self, module, value=False):
335
+ if isinstance(module, GPTRefactModel):
336
+ module.gradient_checkpointing = value
337
+
338
+
339
+ class GPTRefactModel(GPTRefactPreTrainedModel):
340
+ def __init__(self, config):
341
+ super().__init__(config)
342
+ self.embed_dim = config.hidden_size
343
+ self.num_heads = config.num_attention_heads
344
+ self.multi_query = config.multi_query
345
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
346
+
347
+ self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
348
+
349
+ self.max_positions = config.max_position_embeddings
350
+ self.register_buffer(
351
+ "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
352
+ persistent=False
353
+ )
354
+
355
+ self.gradient_checkpointing = False
356
+
357
+ # Initialize weights and apply final processing
358
+ self.post_init()
359
+
360
+ @staticmethod
361
+ def _make_mask(seq_len: int, past_key_values_length: int):
362
+ # prompt
363
+ if past_key_values_length == 0:
364
+ mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
365
+ mask = torch.triu(mask, 1)
366
+ else:
367
+ mask = torch.zeros((seq_len, seq_len + past_key_values_length), dtype=torch.bool)
368
+ return mask
369
+
370
+ def forward(
371
+ self,
372
+ input_ids: Optional[torch.Tensor] = None,
373
+ past_key_values: Optional[List[torch.Tensor]] = None,
374
+ attention_mask: Optional[torch.Tensor] = None,
375
+ inputs_embeds: Optional[torch.Tensor] = None,
376
+ use_cache: Optional[bool] = None,
377
+ output_attentions: Optional[bool] = None,
378
+ output_hidden_states: Optional[bool] = None,
379
+ return_dict: Optional[bool] = None,
380
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
381
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
382
+ output_hidden_states = (
383
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
384
+ )
385
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
386
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
387
+
388
+ if input_ids is not None and inputs_embeds is not None:
389
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
390
+ elif input_ids is not None:
391
+ input_shape = input_ids.size()
392
+ input_ids = input_ids.view(-1, input_shape[-1])
393
+ batch_size = input_ids.shape[0]
394
+ elif inputs_embeds is not None:
395
+ input_shape = inputs_embeds.size()[:-1]
396
+ batch_size = inputs_embeds.shape[0]
397
+ else:
398
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
399
+
400
+ if batch_size <= 0:
401
+ raise ValueError("batch_size has to be defined and > 0")
402
+
403
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
404
+
405
+ if past_key_values is None:
406
+ past_length = 0
407
+ past_key_values = tuple([None] * len(self.h))
408
+ else:
409
+ past_length = past_key_values[0][0].size(-2)
410
+
411
+ # Self-attention mask.
412
+ query_length = input_shape[-1]
413
+
414
+ seq_length_with_past = past_length + query_length
415
+ if attention_mask is None:
416
+ attention_mask = self._make_mask(query_length, past_length).to(device)
417
+ else:
418
+ attention_mask = attention_mask.to(device)
419
+
420
+ hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
421
+
422
+ alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
423
+ self.num_heads, device, self.wte.weight.dtype)[:, :, -query_length:, :]
424
+
425
+ output_shape = input_shape + (hidden_states.size(-1),)
426
+
427
+ presents = [] if use_cache else None
428
+ all_self_attentions = () if output_attentions else None
429
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
430
+ all_hidden_states = () if output_hidden_states else None
431
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
432
+ if output_hidden_states:
433
+ all_hidden_states = all_hidden_states + (hidden_states,)
434
+
435
+ if self.gradient_checkpointing and self.training:
436
+
437
+ def create_custom_forward(module):
438
+ def custom_forward(*inputs):
439
+ # None for past_key_value
440
+ return module(*inputs, use_cache, output_attentions)
441
+
442
+ return custom_forward
443
+
444
+ outputs = torch.utils.checkpoint.checkpoint(
445
+ create_custom_forward(block),
446
+ hidden_states,
447
+ None,
448
+ attention_mask,
449
+ alibi
450
+ )
451
+ else:
452
+ outputs = block(
453
+ hidden_states,
454
+ layer_past=layer_past,
455
+ attention_mask=attention_mask,
456
+ alibi=alibi,
457
+ use_cache=use_cache,
458
+ output_attentions=output_attentions,
459
+ )
460
+
461
+ hidden_states = outputs[0]
462
+ if use_cache:
463
+ presents.append(outputs[1])
464
+
465
+ if output_attentions:
466
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
467
+ if self.config.add_cross_attention:
468
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
469
+
470
+ hidden_states = hidden_states.view(output_shape)
471
+ # Add last hidden state
472
+ if output_hidden_states:
473
+ all_hidden_states = all_hidden_states + (hidden_states,)
474
+
475
+ if not return_dict:
476
+ return tuple(
477
+ v
478
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
479
+ if v is not None
480
+ )
481
+
482
+ return BaseModelOutputWithPastAndCrossAttentions(
483
+ last_hidden_state=hidden_states,
484
+ past_key_values=presents,
485
+ hidden_states=all_hidden_states,
486
+ attentions=all_self_attentions,
487
+ cross_attentions=all_cross_attentions,
488
+ )
489
+
490
+
491
+ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
492
+ _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]
493
+
494
+ def __init__(self, config):
495
+ super().__init__(config)
496
+ self.transformer = GPTRefactModel(config)
497
+ self.ln_f = LayerNormNoBias(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
498
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
499
+
500
+ # Initialize weights and apply final processing
501
+ self.post_init()
502
+
503
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
504
+ if inputs_embeds is not None and past_key_values is None:
505
+ model_inputs = {"inputs_embeds": inputs_embeds}
506
+ else:
507
+ if past_key_values is not None:
508
+ model_inputs = {"input_ids": input_ids[..., -1:]}
509
+ else:
510
+ model_inputs = {"input_ids": input_ids}
511
+
512
+ model_inputs.update(
513
+ {
514
+ "past_key_values": past_key_values,
515
+ "use_cache": kwargs.get("use_cache"),
516
+ }
517
+ )
518
+ return model_inputs
519
+
520
+ def forward(
521
+ self,
522
+ input_ids: Optional[torch.Tensor] = None,
523
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
524
+ attention_mask: Optional[torch.Tensor] = None,
525
+ inputs_embeds: Optional[torch.Tensor] = None,
526
+ labels: Optional[torch.Tensor] = None,
527
+ use_cache: Optional[bool] = None,
528
+ output_attentions: Optional[bool] = None,
529
+ output_hidden_states: Optional[bool] = None,
530
+ return_dict: Optional[bool] = None,
531
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
532
+ r"""
533
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
534
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
535
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
536
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
537
+ """
538
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
539
+
540
+ transformer_outputs = self.transformer(
541
+ input_ids,
542
+ past_key_values=past_key_values,
543
+ attention_mask=attention_mask,
544
+ inputs_embeds=inputs_embeds,
545
+ use_cache=use_cache,
546
+ output_attentions=output_attentions,
547
+ output_hidden_states=output_hidden_states,
548
+ return_dict=return_dict,
549
+ )
550
+ hidden_states = transformer_outputs[0]
551
+
552
+ x = self.ln_f(hidden_states)
553
+ lm_logits = self.lm_head(x)
554
+
555
+ loss = None
556
+ if labels is not None:
557
+ # Shift so that tokens < n predict n
558
+ shift_logits = lm_logits[..., :-1, :].contiguous()
559
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
560
+ # Flatten the tokens
561
+ loss_fct = CrossEntropyLoss()
562
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
563
+
564
+ if not return_dict:
565
+ output = (lm_logits,) + transformer_outputs[1:]
566
+ return ((loss,) + output) if loss is not None else output
567
+
568
+ return CausalLMOutputWithCrossAttentions(
569
+ loss=loss,
570
+ logits=lm_logits,
571
+ past_key_values=transformer_outputs.past_key_values,
572
+ hidden_states=transformer_outputs.hidden_states,
573
+ attentions=transformer_outputs.attentions,
574
+ cross_attentions=transformer_outputs.cross_attentions,
575
+ )
576
+
577
+ @staticmethod
578
+ def _reorder_cache(
579
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
580
+ ) -> Tuple[Tuple[torch.Tensor]]:
581
+ """
582
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
583
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
584
+ beam_idx at every generation step.
585
+ """
586
+ return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c9761aabc16466fdf738d4fe42f12ee6844a360db07bde307ca808d0bfb6b8a
3
+ size 6343461637