numb3r3 committed
Commit 67b96c4
1 Parent(s): e72879e

Upload folder using huggingface_hub

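The commit message above indicates the files were pushed with the huggingface_hub client. As a point of reference only, here is a minimal sketch of how such an upload is typically performed; the folder path and repo id are hypothetical placeholders, not values taken from this commit.

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./local-model-dir",   # hypothetical local copy of the files listed below
    repo_id="user/repo-name",          # hypothetical target repository on the Hub
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)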
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
README.md ADDED
@@ -0,0 +1 @@
1
+
block.py ADDED
@@ -0,0 +1,412 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+
14
+ from .stochastic_depth import StochasticDepth
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
+ class Block(nn.Module):
25
+ def __init__(
26
+ self,
27
+ dim,
28
+ mixer_cls=None,
29
+ mlp_cls=None,
30
+ norm_cls=nn.LayerNorm,
31
+ dropout_cls=nn.Dropout,
32
+ prenorm=True,
33
+ resid_dropout1=0.0,
34
+ resid_dropout2=0.0,
35
+ drop_path1=0.0,
36
+ drop_path2=0.0,
37
+ fused_dropout_add_ln=False,
38
+ return_residual=False,
39
+ residual_in_fp32=False,
40
+ sequence_parallel=False,
41
+ mark_shared_params=False,
42
+ ):
43
+ """
44
+ For prenorm=True, this Block has a slightly different structure compared to a regular
45
+ prenorm Transformer block.
46
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
47
+ [Ref: https://arxiv.org/abs/2002.04745]
48
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
49
+ the hidden_states (output of the MLP) and the residual.
50
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
51
+ The residual needs to be provided (except for the very first block).
52
+
53
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
54
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
55
+
56
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
57
+ This is for performance reasons: for post-norm architecture, returning the input allows us
58
+ to fuse the backward of nn.Linear with the residual connection.
59
+ """
60
+ super().__init__()
61
+ self.prenorm = prenorm
62
+ self.fused_dropout_add_ln = fused_dropout_add_ln
63
+ self.return_residual = return_residual
64
+ self.residual_in_fp32 = residual_in_fp32
65
+ if self.residual_in_fp32:
66
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
67
+ if mixer_cls is None:
68
+ mixer_cls = partial(MHA, num_heads=dim // 64)
69
+ if mlp_cls is None:
70
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
71
+ self.mixer = mixer_cls(dim)
72
+ self.dropout1 = dropout_cls(resid_dropout1)
73
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
74
+ self.norm1 = norm_cls(dim)
75
+ self.mlp = mlp_cls(dim)
76
+ if not isinstance(self.mlp, nn.Identity):
77
+ self.dropout2 = dropout_cls(resid_dropout2)
78
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
79
+ self.norm2 = norm_cls(dim)
80
+
81
+ if self.fused_dropout_add_ln:
82
+ assert layer_norm_fn is not None, "Triton is not installed"
83
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
84
+ self.dropout1, nn.Dropout
85
+ )
86
+
87
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
88
+ # then the input to each worker in the tensor parallel group will be different.
89
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
90
+ # For now this is not an issue because we always use sequence_parallel=True during training
91
+ # and only use sequence_parallel=False during inference.
92
+
93
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
94
+ if sequence_parallel:
95
+ for p in self.norm1.parameters():
96
+ p._sequence_parallel = True
97
+ if hasattr(self, "norm2"):
98
+ for p in self.norm2.parameters():
99
+ p._sequence_parallel = True
100
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
101
+ if mark_shared_params:
102
+ for p in self.norm1.parameters():
103
+ p._shared_params = True
104
+ if hasattr(self, "norm2"):
105
+ for p in self.norm2.parameters():
106
+ p._shared_params = True
107
+
108
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
109
+ return self.mixer.allocate_inference_cache(
110
+ batch_size, max_seqlen, dtype=dtype, **kwargs
111
+ )
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: Tensor,
116
+ residual: Optional[Tensor] = None,
117
+ mixer_subset=None,
118
+ mixer_kwargs=None,
119
+ ):
120
+ r"""Pass the input through the encoder layer.
121
+
122
+ Args:
123
+ hidden_states: the sequence to the encoder layer (required).
124
+ residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
125
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
126
+ before applying the query projection. Useful for e.g., ViT where we only care
127
+ about the CLS token in the last layer.
128
+ """
129
+ if self.prenorm:
130
+ if not self.fused_dropout_add_ln:
131
+ dropped = self.drop_path1(self.dropout1(hidden_states))
132
+ residual = (dropped + residual) if residual is not None else dropped
133
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
134
+ if self.residual_in_fp32:
135
+ residual = residual.to(torch.float32)
136
+ else:
137
+ if self.drop_path1.p == 0 or not self.training:
138
+ rowscale1 = None
139
+ else:
140
+ rowscale1 = self.drop_path1(
141
+ torch.ones(
142
+ hidden_states.shape[:-1],
143
+ device=hidden_states.device,
144
+ dtype=hidden_states.dtype,
145
+ )
146
+ )
147
+ hidden_states, residual = layer_norm_fn(
148
+ hidden_states,
149
+ self.norm1.weight,
150
+ self.norm1.bias,
151
+ residual=residual,
152
+ eps=self.norm1.eps,
153
+ dropout_p=self.dropout1.p if self.training else 0.0,
154
+ rowscale=rowscale1,
155
+ prenorm=True,
156
+ residual_in_fp32=self.residual_in_fp32,
157
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
158
+ )
159
+ if mixer_kwargs is None:
160
+ mixer_kwargs = {}
161
+ if mixer_subset is not None:
162
+ mixer_kwargs["mixer_subset"] = mixer_subset
163
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
164
+ if mixer_subset is not None:
165
+ residual = residual[:, mixer_subset]
166
+ if not isinstance(self.mlp, nn.Identity):
167
+ if not self.fused_dropout_add_ln:
168
+ dropped = self.drop_path2(self.dropout2(hidden_states))
169
+ residual = (dropped + residual) if residual is not None else dropped
170
+ hidden_states = self.norm2(
171
+ residual.to(dtype=self.norm2.weight.dtype)
172
+ )
173
+ if self.residual_in_fp32:
174
+ residual = residual.to(torch.float32)
175
+ else:
176
+ if self.drop_path2.p == 0 or not self.training:
177
+ rowscale2 = None
178
+ else:
179
+ rowscale2 = self.drop_path2(
180
+ torch.ones(
181
+ hidden_states.shape[:-1],
182
+ device=hidden_states.device,
183
+ dtype=hidden_states.dtype,
184
+ )
185
+ )
186
+ hidden_states, residual = layer_norm_fn(
187
+ hidden_states,
188
+ self.norm2.weight,
189
+ self.norm2.bias,
190
+ residual=residual,
191
+ eps=self.norm2.eps,
192
+ dropout_p=self.dropout2.p if self.training else 0.0,
193
+ rowscale=rowscale2,
194
+ prenorm=True,
195
+ residual_in_fp32=self.residual_in_fp32,
196
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
197
+ )
198
+ hidden_states = self.mlp(hidden_states)
199
+ return hidden_states, residual
200
+ else:
201
+ assert residual is None
202
+ mixer_out = self.mixer(
203
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
204
+ )
205
+ if self.return_residual: # mixer out is actually a pair here
206
+ mixer_out, hidden_states = mixer_out
207
+ if not self.fused_dropout_add_ln:
208
+ hidden_states = self.norm1(
209
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
210
+ dtype=self.norm1.weight.dtype
211
+ )
212
+ )
213
+ else:
214
+ if self.drop_path1.p == 0 or not self.training:
215
+ rowscale1 = None
216
+ else:
217
+ rowscale1 = self.drop_path1(
218
+ torch.ones(
219
+ mixer_out.shape[:-1],
220
+ device=mixer_out.device,
221
+ dtype=mixer_out.dtype,
222
+ )
223
+ )
224
+ hidden_states = layer_norm_fn(
225
+ mixer_out,
226
+ self.norm1.weight,
227
+ self.norm1.bias,
228
+ residual=hidden_states,
229
+ eps=self.norm1.eps,
230
+ dropout_p=self.dropout1.p if self.training else 0.0,
231
+ rowscale=rowscale1,
232
+ prenorm=False,
233
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
234
+ )
235
+ if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states)
237
+ if self.return_residual: # mlp out is actually a pair here
238
+ mlp_out, hidden_states = mlp_out
239
+ if not self.fused_dropout_add_ln:
240
+ hidden_states = self.norm2(
241
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
242
+ dtype=self.norm2.weight.dtype
243
+ )
244
+ )
245
+ else:
246
+ if self.drop_path2.p == 0 or not self.training:
247
+ rowscale2 = None
248
+ else:
249
+ rowscale2 = self.drop_path2(
250
+ torch.ones(
251
+ mlp_out.shape[:-1],
252
+ device=mlp_out.device,
253
+ dtype=mlp_out.dtype,
254
+ )
255
+ )
256
+ hidden_states = layer_norm_fn(
257
+ mlp_out,
258
+ self.norm2.weight,
259
+ self.norm2.bias,
260
+ residual=hidden_states,
261
+ eps=self.norm2.eps,
262
+ dropout_p=self.dropout2.p if self.training else 0.0,
263
+ rowscale=rowscale2,
264
+ prenorm=False,
265
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
266
+ )
267
+ return hidden_states
268
+
269
+
270
+ class ParallelBlock(nn.Module):
271
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
272
+ and PaLM.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ dim,
278
+ mixer_cls=None,
279
+ mlp_cls=None,
280
+ norm_cls=nn.LayerNorm,
281
+ dropout_cls=nn.Dropout,
282
+ resid_dropout1=0.0,
283
+ resid_dropout2=0.0,
284
+ tied_norm=False,
285
+ fused_dropout_add_ln=False,
286
+ residual_in_fp32=False,
287
+ sequence_parallel=False,
288
+ mark_shared_params=False,
289
+ ):
290
+ """
291
+ This Block has a slightly different structure compared to a regular
292
+ prenorm Transformer block.
293
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
294
+ [Ref: https://arxiv.org/abs/2002.04745]
295
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
296
+ the hidden_states (output1 of the MHA / MLP) and the residual.
297
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
298
+ The residual needs to be provided (except for the very first block).
299
+ """
300
+ super().__init__()
301
+ self.tied_norm = tied_norm
302
+ self.fused_dropout_add_ln = fused_dropout_add_ln
303
+ self.residual_in_fp32 = residual_in_fp32
304
+ if mixer_cls is None:
305
+ mixer_cls = partial(MHA, num_heads=dim // 64)
306
+ if mlp_cls is None:
307
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
308
+ self.mixer = mixer_cls(dim)
309
+ self.dropout1 = dropout_cls(resid_dropout1)
310
+ self.norm1 = norm_cls(dim)
311
+ self.mlp = mlp_cls(dim)
312
+ self.dropout2 = dropout_cls(resid_dropout2)
313
+ if not self.tied_norm:
314
+ self.norm2 = norm_cls(dim)
315
+
316
+ if self.fused_dropout_add_ln:
317
+ assert layer_norm_fn is not None, "Triton is not installed"
318
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
319
+ self.dropout1, nn.Dropout
320
+ )
321
+
322
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
323
+ # then the input to each worker in the tensor parallel group will be different.
324
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
325
+ # For now this is not an issue because we always use sequence_parallel=True during training
326
+ # and only use sequence_parallel=False during inference.
327
+
328
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
329
+ if sequence_parallel:
330
+ for p in self.norm1.parameters():
331
+ p._sequence_parallel = True
332
+ if hasattr(self, "norm2"):
333
+ for p in self.norm2.parameters():
334
+ p._sequence_parallel = True
335
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
336
+ if mark_shared_params:
337
+ for p in self.norm1.parameters():
338
+ p._shared_params = True
339
+ if hasattr(self, "norm2"):
340
+ for p in self.norm2.parameters():
341
+ p._shared_params = True
342
+
343
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
344
+ return self.mixer.allocate_inference_cache(
345
+ batch_size, max_seqlen, dtype=dtype, **kwargs
346
+ )
347
+
348
+ def forward(
349
+ self,
350
+ hidden_states1: Tensor,
351
+ hidden_states2: Optional[Tensor] = None,
352
+ residual: Optional[Tensor] = None,
353
+ mixer_kwargs=None,
354
+ ):
355
+ r"""Pass the input through the encoder layer.
356
+
357
+ Args:
358
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
359
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
360
+ residual: the running residual stream from the previous block (optional).
361
+ """
362
+ # TODO: Ideally we should only do the allgather / allreduce once for
363
+ # the Linear to MLP & Attention
364
+ if not self.fused_dropout_add_ln:
365
+ dropped1 = self.dropout1(hidden_states1)
366
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
367
+ if hidden_states2 is not None:
368
+ dropped2 = self.dropout2(hidden_states2)
369
+ residual = (
370
+ (residual + dropped1 + dropped2)
371
+ if residual is not None
372
+ else dropped1 + dropped2
373
+ )
374
+ else:
375
+ residual = (residual + dropped1) if residual is not None else dropped1
376
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
377
+ hidden_states2 = (
378
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
379
+ if not self.tied_norm
380
+ else hidden_states1
381
+ )
382
+ if self.residual_in_fp32:
383
+ residual = residual.to(torch.float32)
384
+ else:
385
+ weight2, bias2 = (
386
+ (self.norm2.weight, self.norm2.bias)
387
+ if not self.tied_norm
388
+ else (None, None)
389
+ )
390
+ hidden_states1, *rest, residual = layer_norm_fn(
391
+ hidden_states1,
392
+ self.norm1.weight,
393
+ self.norm1.bias,
394
+ residual=residual,
395
+ x1=hidden_states2,
396
+ weight1=weight2,
397
+ bias1=bias2,
398
+ eps=self.norm1.eps,
399
+ dropout_p=self.dropout1.p if self.training else 0.0,
400
+ prenorm=True,
401
+ residual_in_fp32=self.residual_in_fp32,
402
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
403
+ )
404
+ if self.tied_norm:
405
+ hidden_states2 = hidden_states1
406
+ else:
407
+ (hidden_states2,) = rest
408
+ if mixer_kwargs is None:
409
+ mixer_kwargs = {}
410
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
411
+ hidden_states2 = self.mlp(hidden_states2)
412
+ return hidden_states1, hidden_states2, residual
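For readers skimming the diff, here is a minimal, hedged sketch of how the Block above is typically driven. It assumes the uploaded modules are importable as a package (the name reranker_modules is a hypothetical placeholder, since block.py uses relative imports) and that flash-attn is not installed, so the pure-PyTorch attention path is taken.

import torch
from reranker_modules.block import Block  # hypothetical package name for the files in this repo

blk = Block(dim=768)                  # defaults: MHA with 768 // 64 = 12 heads, 4*dim MLP, prenorm=True
x = torch.randn(2, 16, 768)           # (batch, seqlen, hidden_dim)
hidden_states, residual = blk(x)      # a prenorm Block returns both the MLP output and the residual stream
assert hidden_states.shape == residual.shape == (2, 16, 768)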
config.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "_name_or_path": "jinaai/jina-reranker-v2-base-multilingual",
3
+ "architectures": ["XLMRobertaForSequenceClassification"],
4
+ "attention_probs_dropout_prob": 0.1,
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
7
+ "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
8
+ "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
9
+ },
10
+ "bos_token_id": 0,
11
+ "classifier_dropout": null,
12
+ "emb_pooler": null,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "gelu",
15
+ "hidden_dropout_prob": 0.1,
16
+ "hidden_size": 768,
17
+ "num_labels": 1,
18
+ "id2label": {
19
+ "0": "LABEL_0"
20
+ },
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 3072,
23
+ "label2id": {
24
+ "LABEL_0": 0
25
+ },
26
+ "layer_norm_eps": 1e-5,
27
+ "max_position_embeddings": 1026,
28
+ "num_attention_heads": 12,
29
+ "num_hidden_layers": 12,
30
+ "output_past": true,
31
+ "pad_token_id": 1,
32
+ "position_embedding_type": "absolute",
33
+ "torch_dtype": "bfloat16",
34
+ "transformers_version": "4.40.0",
35
+ "type_vocab_size": 1,
36
+ "use_cache": false,
37
+ "use_flash_attn": true,
38
+ "vocab_size": 250002
39
+ }
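Because "auto_map" routes the architecture to the custom Python files uploaded in this commit, loading the checkpoint through transformers requires trust_remote_code=True. A minimal loading sketch follows, assuming the transformers library; the repo id is taken from "_name_or_path" above and may differ from the repository this commit was actually pushed to.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "jinaai/jina-reranker-v2-base-multilingual"  # from "_name_or_path"; adjust to the actual repo
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # matches "torch_dtype" in config.json
    trust_remote_code=True,       # needed for the auto_map entries above
)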
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,69 @@
1
+ from transformers import PretrainedConfig
2
+ import torch
3
+
4
+ class XLMRobertaFlashConfig(PretrainedConfig):
5
+ def __init__(
6
+ self,
7
+ vocab_size=30522,
8
+ hidden_size=768,
9
+ num_hidden_layers=12,
10
+ num_attention_heads=12,
11
+ intermediate_size=3072,
12
+ hidden_act="gelu",
13
+ hidden_dropout_prob=0.1,
14
+ attention_probs_dropout_prob=0.1,
15
+ max_position_embeddings=512,
16
+ type_vocab_size=2,
17
+ initializer_range=0.02,
18
+ layer_norm_eps=1e-12,
19
+ pad_token_id=1,
20
+ bos_token_id=0,
21
+ eos_token_id=2,
22
+ position_embedding_type="absolute",
23
+ use_cache=True,
24
+ classifier_dropout=None,
25
+ lora_adaptations=None,
26
+ lora_rank=4,
27
+ lora_dropout_p=0.0,
28
+ lora_alpha=1,
29
+ lora_main_params_trainable=False,
30
+ load_trained_adapters=False,
31
+ use_flash_attn=True,
32
+ torch_dtype=None,
33
+ emb_pooler=None,
34
+ matryoshka_dimensions=None,
35
+ truncate_dim=None,
36
+ **kwargs,
37
+ ):
38
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
39
+
40
+
41
+ self.vocab_size = vocab_size
42
+ self.hidden_size = hidden_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_attention_heads = num_attention_heads
45
+ self.hidden_act = hidden_act
46
+ self.intermediate_size = intermediate_size
47
+ self.hidden_dropout_prob = hidden_dropout_prob
48
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.type_vocab_size = type_vocab_size
51
+ self.initializer_range = initializer_range
52
+ self.layer_norm_eps = layer_norm_eps
53
+ self.position_embedding_type = position_embedding_type
54
+ self.use_cache = use_cache
55
+ self.classifier_dropout = classifier_dropout
56
+ self.load_trained_adapters = load_trained_adapters
57
+ self.lora_adaptations = lora_adaptations
58
+ self.lora_rank = lora_rank
59
+ self.lora_dropout_p = lora_dropout_p
60
+ self.lora_alpha = lora_alpha
61
+ self.lora_main_params_trainable = lora_main_params_trainable
62
+ self.use_flash_attn = use_flash_attn
63
+ self.emb_pooler = emb_pooler
64
+ self.matryoshka_dimensions = matryoshka_dimensions
65
+ self.truncate_dim = truncate_dim
66
+ if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
67
+ self.torch_dtype = getattr(torch, torch_dtype)
68
+ else:
69
+ self.torch_dtype = torch_dtype
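As a quick illustration (not part of the upload), the config class can be instantiated directly with the values that config.json ships; this assumes configuration_xlm_roberta.py is on the import path.

from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    vocab_size=250002,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=1026,
    type_vocab_size=1,
    layer_norm_eps=1e-5,
    use_flash_attn=True,
    torch_dtype="bfloat16",
)
print(config.torch_dtype)  # torch.bfloat16 -- the string is resolved to a torch.dtype in __init__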
embedding.py ADDED
@@ -0,0 +1,62 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
2
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
+
13
+
14
+ class XLMRobertaEmbeddings(nn.Module):
15
+ def __init__(
16
+ self,
17
+ embed_dim,
18
+ vocab_size,
19
+ max_position_embeddings,
20
+ type_vocab_size,
21
+ padding_idx=None,
22
+ device=None,
23
+ dtype=None,
24
+ ):
25
+ """
26
+ If max_position_embeddings <= 0, there are no position embeddings
27
+ If type_vocab_size <= 0, there are no token type embeddings
28
+ """
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ self.word_embeddings = nn.Embedding(
32
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
33
+ )
34
+ self.max_position_embeddings = max_position_embeddings
35
+ self.type_vocab_size = type_vocab_size
36
+ if self.max_position_embeddings > 0:
37
+ self.position_embeddings = nn.Embedding(
38
+ max_position_embeddings, embed_dim, **factory_kwargs
39
+ )
40
+ if self.type_vocab_size > 0:
41
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
+
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
44
+ """
45
+ input_ids: (batch, seqlen)
46
+ position_ids: (batch, seqlen)
47
+ token_type_ids: (batch, seqlen)
48
+ """
49
+ batch_size, seqlen = input_ids.shape
50
+ embeddings = self.word_embeddings(input_ids)
51
+ if self.max_position_embeddings > 0:
52
+ if position_ids is None:
53
+ position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
54
+ # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
55
+ position_embeddings = self.position_embeddings(position_ids)
56
+ embeddings = embeddings + position_embeddings
57
+ if self.type_vocab_size > 0:
58
+ if token_type_ids is None:
59
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
60
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
61
+ embeddings = embeddings + token_type_embeddings
62
+ return embeddings
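A small sketch of running this embedding layer standalone with the dimensions from config.json; it assumes embedding.py is importable and that einops and transformers are installed. The token ids are arbitrary illustrative values (pad_token_id=1 per config.json).

import torch
from embedding import XLMRobertaEmbeddings  # assumes the file is on the import path

emb = XLMRobertaEmbeddings(
    embed_dim=768,
    vocab_size=250002,            # the word-embedding table alone is ~192M parameters
    max_position_embeddings=1026,
    type_vocab_size=1,
    padding_idx=1,
)
input_ids = torch.tensor([[0, 581, 9, 23, 2, 1, 1, 1]])  # right-padded with pad_token_id=1
out = emb(input_ids)                                      # (1, 8, 768)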
mha.py ADDED
@@ -0,0 +1,662 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
3
+
4
+ import math
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_kvpacked_func,
14
+ flash_attn_qkvpacked_func,
15
+ flash_attn_varlen_kvpacked_func,
16
+ flash_attn_varlen_qkvpacked_func,
17
+ flash_attn_with_kvcache,
18
+ )
19
+ except ImportError:
20
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
+ flash_attn_with_kvcache = None
23
+
24
+ try:
25
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
26
+ except ImportError:
27
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
+
29
+
30
+ class FlashSelfAttention(nn.Module):
31
+ """Implement the scaled dot product attention with softmax.
32
+ Arguments
33
+ ---------
34
+ softmax_scale: The temperature to use for the softmax attention.
35
+ (default: 1/sqrt(d_keys) where d_keys is computed at
36
+ runtime)
37
+ attention_dropout: The dropout rate to apply to the attention
38
+ (default: 0.0)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ causal=False,
44
+ softmax_scale=None,
45
+ attention_dropout=0.0,
46
+ window_size=(-1, -1),
47
+ deterministic=False,
48
+ ):
49
+ super().__init__()
50
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
51
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
52
+ self.causal = causal
53
+ self.softmax_scale = softmax_scale
54
+ self.drop = nn.Dropout(attention_dropout)
55
+ self.window_size = window_size
56
+ self.deterministic = deterministic
57
+
58
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
59
+ """Implements the multihead softmax attention.
60
+ Arguments
61
+ ---------
62
+ qkv: The tensor containing the query, key, and value.
63
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
64
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
65
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
66
+ causal: if passed, will override self.causal
67
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
68
+ of the sequences in the batch, used to index into qkv.
69
+ max_seqlen: int. Maximum sequence length in the batch.
70
+ Returns:
71
+ --------
72
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
73
+ else (B, S, H, D).
74
+ """
75
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
76
+ assert qkv.is_cuda
77
+ causal = self.causal if causal is None else causal
78
+ unpadded = cu_seqlens is not None
79
+
80
+ if unpadded:
81
+ assert cu_seqlens.dtype == torch.int32
82
+ assert max_seqlen is not None
83
+ assert isinstance(max_seqlen, int)
84
+ return flash_attn_varlen_qkvpacked_func(
85
+ qkv,
86
+ cu_seqlens,
87
+ max_seqlen,
88
+ self.drop.p if self.training else 0.0,
89
+ softmax_scale=self.softmax_scale,
90
+ causal=causal,
91
+ alibi_slopes=None,
92
+ window_size=self.window_size,
93
+ deterministic=self.deterministic,
94
+ )
95
+ else:
96
+ return flash_attn_qkvpacked_func(
97
+ qkv,
98
+ self.drop.p if self.training else 0.0,
99
+ softmax_scale=self.softmax_scale,
100
+ causal=causal,
101
+ alibi_slopes=None,
102
+ window_size=self.window_size,
103
+ deterministic=self.deterministic,
104
+ )
105
+
106
+
107
+ class FlashCrossAttention(nn.Module):
108
+ """Implement the scaled dot product attention with softmax.
109
+ Arguments
110
+ ---------
111
+ softmax_scale: The temperature to use for the softmax attention.
112
+ (default: 1/sqrt(d_keys) where d_keys is computed at
113
+ runtime)
114
+ attention_dropout: The dropout rate to apply to the attention
115
+ (default: 0.0)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ causal=False,
121
+ softmax_scale=None,
122
+ attention_dropout=0.0,
123
+ window_size=(-1, -1),
124
+ deterministic=False,
125
+ ):
126
+ super().__init__()
127
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
128
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.drop = nn.Dropout(attention_dropout)
132
+ self.window_size = window_size
133
+ self.deterministic = deterministic
134
+
135
+ def forward(
136
+ self,
137
+ q,
138
+ kv,
139
+ causal=None,
140
+ cu_seqlens=None,
141
+ max_seqlen=None,
142
+ cu_seqlens_k=None,
143
+ max_seqlen_k=None,
144
+ ):
145
+ """Implements the multihead softmax attention.
146
+ Arguments
147
+ ---------
148
+ q: The tensor containing the query. (B, Sq, H, D)
149
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
150
+ causal: if passed, will override self.causal
151
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
152
+ of the sequences in the batch, used to index into q.
153
+ max_seqlen: int. Maximum sequence length in the batch of q.
154
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
155
+ of the sequences in the batch, used to index into kv.
156
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
157
+ """
158
+ assert q.dtype in [torch.float16, torch.bfloat16]
159
+ assert q.is_cuda and kv.is_cuda
160
+ causal = self.causal if causal is None else causal
161
+ unpadded = cu_seqlens is not None
162
+
163
+ if unpadded:
164
+ assert cu_seqlens.dtype == torch.int32
165
+ assert max_seqlen is not None
166
+ assert isinstance(max_seqlen, int)
167
+ assert cu_seqlens_k is not None
168
+ assert cu_seqlens_k.dtype == torch.int32
169
+ assert max_seqlen_k is not None
170
+ assert isinstance(max_seqlen, int)
171
+ return flash_attn_varlen_kvpacked_func(
172
+ q,
173
+ kv,
174
+ cu_seqlens,
175
+ cu_seqlens_k,
176
+ max_seqlen,
177
+ max_seqlen_k,
178
+ self.drop.p if self.training else 0.0,
179
+ softmax_scale=self.softmax_scale,
180
+ causal=causal,
181
+ alibi_slopes=None,
182
+ window_size=self.window_size,
183
+ deterministic=self.deterministic,
184
+ )
185
+ else:
186
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
187
+ seqlen_k = kv.shape[1]
188
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
189
+ return flash_attn_kvpacked_func(
190
+ q,
191
+ kv,
192
+ self.drop.p if self.training else 0.0,
193
+ causal=causal,
194
+ softmax_scale=self.softmax_scale,
195
+ alibi_slopes=None,
196
+ window_size=self.window_size,
197
+ deterministic=self.deterministic,
198
+ )
199
+
200
+
201
+ class SelfAttention(nn.Module):
202
+ """Implement the scaled dot product attention with softmax.
203
+ Arguments
204
+ ---------
205
+ softmax_scale: The temperature to use for the softmax attention.
206
+ (default: 1/sqrt(d_keys) where d_keys is computed at
207
+ runtime)
208
+ attention_dropout: The dropout rate to apply to the attention
209
+ (default: 0.0)
210
+ """
211
+
212
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
213
+ super().__init__()
214
+ self.causal = causal
215
+ self.softmax_scale = softmax_scale
216
+ self.drop = nn.Dropout(attention_dropout)
217
+
218
+ def forward(self, qkv, causal=None, key_padding_mask=None):
219
+ """Implements the multihead softmax attention.
220
+ Arguments
221
+ ---------
222
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
223
+ causal: if passed, will override self.causal
224
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
225
+ False means to mask out. (B, S)
226
+ """
227
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
228
+ causal = self.causal if causal is None else causal
229
+ q, k, v = qkv.unbind(dim=2)
230
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
231
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
232
+ if key_padding_mask is not None:
233
+ padding_mask = torch.full(
234
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
235
+ )
236
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
237
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
238
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
239
+ if causal:
240
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
241
+ # So we have to construct the mask in float
242
+ causal_mask = torch.triu(
243
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
244
+ )
245
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
246
+ scores = scores + causal_mask.to(dtype=scores.dtype)
247
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
248
+ attention_drop = self.drop(attention)
249
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
250
+ return output
251
+
252
+
253
+ class CrossAttention(nn.Module):
254
+ """Implement the scaled dot product attention with softmax.
255
+ Arguments
256
+ ---------
257
+ softmax_scale: The temperature to use for the softmax attention.
258
+ (default: 1/sqrt(d_keys) where d_keys is computed at
259
+ runtime)
260
+ attention_dropout: The dropout rate to apply to the attention
261
+ (default: 0.0)
262
+ """
263
+
264
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
265
+ super().__init__()
266
+ self.causal = causal
267
+ self.softmax_scale = softmax_scale
268
+ self.drop = nn.Dropout(attention_dropout)
269
+
270
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
271
+ """Implements the multihead softmax attention.
272
+ Arguments
273
+ ---------
274
+ q: The tensor containing the query. (B, Sq, H, D)
275
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
276
+ causal: if passed, will override self.causal
277
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
278
+ False means to mask out. (B, Sk)
279
+ """
280
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
281
+ causal = self.causal if causal is None else causal
282
+ seqlen_k = kv.shape[1]
283
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
284
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
285
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
286
+ k, v = kv.unbind(dim=2)
287
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
288
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
289
+ if key_padding_mask is not None:
290
+ padding_mask = torch.full(
291
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
292
+ )
293
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
294
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
295
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
296
+ if causal:
297
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
298
+ row_idx = rearrange(
299
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
300
+ )
301
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
302
+ sk = (
303
+ seqlen_k
304
+ if key_padding_mask is None
305
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
306
+ )
307
+ causal_mask = col_idx > row_idx + sk - seqlen_q
308
+ scores = scores.masked_fill(causal_mask, -10000.0)
309
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
310
+ attention_drop = self.drop(attention)
311
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
312
+ return output
313
+
314
+
315
+ class LinearResidual(nn.Linear):
316
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
317
+
318
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
319
+ return super().forward(input), input
320
+
321
+
322
+ def _update_kv_cache(kv, inference_params, layer_idx):
323
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
324
+ # Pre-allocate memory for key-values for inference.
325
+ num_heads, head_dim = kv.shape[-2:]
326
+ if layer_idx not in inference_params.key_value_memory_dict:
327
+ kv_cache = torch.empty(
328
+ inference_params.max_batch_size,
329
+ inference_params.max_seqlen,
330
+ 2,
331
+ num_heads,
332
+ head_dim,
333
+ dtype=kv.dtype,
334
+ device=kv.device,
335
+ )
336
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
337
+ else:
338
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
339
+ # Adjust key and value for inference
340
+ batch_start = inference_params.batch_size_offset
341
+ batch_end = batch_start + kv.shape[0]
342
+ sequence_start = inference_params.seqlen_offset
343
+ sequence_end = sequence_start + kv.shape[1]
344
+ assert batch_end <= kv_cache.shape[0]
345
+ assert sequence_end <= kv_cache.shape[1]
346
+ assert kv_cache is not None
347
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
348
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
349
+
350
+
351
+ class MHA(nn.Module):
352
+ """Multi-head self-attention and cross-attention"""
353
+
354
+ def __init__(
355
+ self,
356
+ embed_dim,
357
+ num_heads,
358
+ num_heads_kv=None,
359
+ cross_attn=False,
360
+ qkv_proj_bias=True,
361
+ out_proj_bias=True,
362
+ dropout=0.0,
363
+ softmax_scale=None,
364
+ causal=False,
365
+ layer_idx=None,
366
+ dwconv=False,
367
+ window_size=(-1, -1),
368
+ fused_bias_fc=False,
369
+ use_flash_attn=False,
370
+ return_residual=False,
371
+ checkpointing=False,
372
+ device=None,
373
+ dtype=None,
374
+ ) -> None:
375
+ """
376
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
377
+ return_residual: whether to return the input x along with the output. This is for
378
+ performance reasons: for post-norm architecture, returning the input allows us
379
+ to fuse the backward of nn.Linear with the residual connection.
380
+ """
381
+ factory_kwargs = {"device": device, "dtype": dtype}
382
+ super().__init__()
383
+ self.embed_dim = embed_dim
384
+ self.cross_attn = cross_attn
385
+ self.causal = causal
386
+ self.layer_idx = layer_idx
387
+ self.dwconv = dwconv
388
+ self.use_flash_attn = use_flash_attn
389
+ self.return_residual = return_residual
390
+ self.checkpointing = checkpointing
391
+
392
+ if window_size != (-1, -1):
393
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
394
+
395
+ self.num_heads = num_heads
396
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
397
+ assert (
398
+ self.num_heads % self.num_heads_kv == 0
399
+ ), "num_heads must be divisible by num_heads_kv"
400
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
401
+ self.head_dim = self.embed_dim // num_heads
402
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
403
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
404
+
405
+ if fused_bias_fc and FusedDense is None:
406
+ raise ImportError("fused_dense is not installed")
407
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
408
+ linear_resid_cls = (
409
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
410
+ )
411
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
412
+ inner_attn_cls = (
413
+ partial(FlashSelfAttention, window_size=window_size)
414
+ if use_flash_attn
415
+ else SelfAttention
416
+ )
417
+ inner_cross_attn_cls = (
418
+ partial(FlashCrossAttention, window_size=window_size)
419
+ if use_flash_attn
420
+ else CrossAttention
421
+ )
422
+ if not self.cross_attn:
423
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
424
+ else:
425
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
426
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
427
+ if self.dwconv:
428
+ if self.num_heads_kv == self.num_heads:
429
+ self.dwconv_qkv = nn.Conv1d(
430
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
431
+ )
432
+ else:
433
+ self.dwconv_q = nn.Conv1d(
434
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
435
+ )
436
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
437
+ self.inner_attn = inner_attn_cls(
438
+ causal=causal,
439
+ softmax_scale=softmax_scale,
440
+ attention_dropout=dropout,
441
+ )
442
+ self.inner_cross_attn = inner_cross_attn_cls(
443
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
444
+ )
445
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
446
+
447
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
448
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
449
+ device = self.out_proj.weight.device
450
+ return torch.empty(
451
+ batch_size,
452
+ max_seqlen,
453
+ 2,
454
+ self.num_heads_kv,
455
+ self.head_dim,
456
+ dtype=dtype,
457
+ device=device,
458
+ )
459
+
460
+ def _update_kv_cache(self, kv, inference_params):
461
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
462
+ assert not self.dwconv, "Generation does not support dwconv yet"
463
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
464
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
465
+
466
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
467
+ """
468
+ Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
469
+ q: (batch_size, seqlen_q, nheads, head_dim)
470
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
471
+ """
472
+ assert inference_params is not None and inference_params.seqlen_offset > 0
473
+ assert self.use_flash_attn
474
+ batch = q.shape[0]
475
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
476
+ cache_seqlens = (
477
+ inference_params.lengths_per_sample[:batch]
478
+ if inference_params.lengths_per_sample is not None
479
+ else inference_params.seqlen_offset
480
+ )
481
+ context = flash_attn_with_kvcache(
482
+ q,
483
+ kv_cache[:, :, 0],
484
+ kv_cache[:, :, 1],
485
+ kv[:, :, 0],
486
+ kv[:, :, 1],
487
+ cache_seqlens=cache_seqlens,
488
+ softmax_scale=self.inner_cross_attn.softmax_scale,
489
+ causal=self.inner_cross_attn.causal,
490
+ rotary_interleaved=False,
491
+ alibi_slopes=None,
492
+ )
493
+ return context
494
+
495
+ def _update_kvcache_attention(self, q, kv, inference_params):
496
+ """Write kv to inference_params, then do attention"""
497
+ if (
498
+ inference_params.seqlen_offset == 0
499
+ or flash_attn_with_kvcache is None
500
+ or not self.use_flash_attn
501
+ ):
502
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
503
+ kv = self._update_kv_cache(kv, inference_params)
504
+ return self.inner_cross_attn(q, kv)
505
+ else:
506
+ batch = q.shape[0]
507
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
508
+ cache_seqlens = (
509
+ inference_params.lengths_per_sample[:batch]
510
+ if inference_params.lengths_per_sample is not None
511
+ else inference_params.seqlen_offset
512
+ )
513
+ return flash_attn_with_kvcache(
514
+ q,
515
+ kv_cache[:, :, 0],
516
+ kv_cache[:, :, 1],
517
+ kv[:, :, 0],
518
+ kv[:, :, 1],
519
+ cache_seqlens=cache_seqlens,
520
+ softmax_scale=self.inner_cross_attn.softmax_scale,
521
+ causal=self.inner_cross_attn.causal,
522
+ alibi_slopes=None,
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ x,
528
+ x_kv=None,
529
+ key_padding_mask=None,
530
+ cu_seqlens=None,
531
+ max_seqlen=None,
532
+ mixer_subset=None,
533
+ inference_params=None,
534
+ **kwargs,
535
+ ):
536
+ """
537
+ Arguments:
538
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
539
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
540
+ is the is the sum of the sequence lengths in the batch.
541
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
542
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
543
+ of the sequences in the batch, used to index into x. Only applicable when using
544
+ FlashAttention.
545
+ max_seqlen: int. Maximum sequence length in the batch.
546
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
547
+ (batch, seqlen). Only applicable when not using FlashAttention.
548
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
549
+ before applying the query projection. Useful for e.g., ViT where we only care
550
+ about the CLS token in the last layer.
551
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
552
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
553
+ """
554
+ if cu_seqlens is not None:
555
+ assert max_seqlen is not None
556
+ assert key_padding_mask is None
557
+ assert self.use_flash_attn
558
+ assert not self.dwconv
559
+ if key_padding_mask is not None:
560
+ assert cu_seqlens is None
561
+ assert max_seqlen is None
562
+ assert not self.use_flash_attn
563
+ if inference_params is not None:
564
+ assert key_padding_mask is None
565
+ assert cu_seqlens is None and max_seqlen is None
566
+ assert not self.dwconv
567
+
568
+ kwargs = (
569
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
570
+ if self.use_flash_attn
571
+ else {"key_padding_mask": key_padding_mask, **kwargs}
572
+ )
573
+ seqlen_offset = (
574
+ 0
575
+ if inference_params is None
576
+ else (
577
+ inference_params.lengths_per_sample
578
+ if inference_params.lengths_per_sample is not None
579
+ else inference_params.seqlen_offset
580
+ )
581
+ )
582
+ rotary_max_seqlen = (
583
+ inference_params.max_sequence_len if inference_params is not None else max_seqlen
584
+ )
585
+ batch, seqlen = x.shape[:2]
586
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
587
+ assert x_kv is None and mixer_subset is None
588
+ if not self.return_residual:
589
+ qkv = self.Wqkv(x)
590
+ else:
591
+ qkv, x = self.Wqkv(x)
592
+ if self.dwconv:
593
+ qkv = rearrange(
594
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
595
+ ).contiguous()
596
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
597
+ if (
598
+ inference_params is None
599
+ or inference_params.seqlen_offset == 0
600
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
601
+ or not self.use_flash_attn
602
+ ):
603
+ if inference_params is None:
604
+ if not self.checkpointing:
605
+ context = self.inner_attn(qkv, **kwargs)
606
+ else:
607
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
608
+ else:
609
+ context = self._update_kvcache_attention(
610
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
611
+ )
612
+ else:
613
+ context = self._apply_rotary_update_kvcache_attention(
614
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
615
+ )
616
+ else:
617
+ if self.cross_attn:
618
+ if not self.return_residual:
619
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
620
+ kv = self.Wkv(x_kv if x_kv is not None else x)
621
+ else:
622
+ if x_kv is not None:
623
+ kv, x_kv = self.Wkv(x_kv)
624
+ else:
625
+ kv, x = self.Wkv(x)
626
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
627
+ else:
628
+ assert self.num_heads_kv != self.num_heads
629
+ if not self.return_residual:
630
+ qkv = self.Wqkv(x)
631
+ else:
632
+ qkv, x = self.Wqkv(x)
633
+ q = qkv[..., : self.num_heads * self.head_dim]
634
+ kv = qkv[..., self.num_heads * self.head_dim :]
635
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
636
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
637
+ if self.dwconv:
638
+ q = rearrange(
639
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
640
+ ).contiguous()
641
+ kv = rearrange(
642
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
643
+ ).contiguous()
644
+ if (
645
+ inference_params is None
646
+ or inference_params.seqlen_offset == 0
647
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
648
+ or not self.use_flash_attn
649
+ ):
650
+ if inference_params is None:
651
+ if not self.checkpointing:
652
+ context = self.inner_cross_attn(q, kv, **kwargs)
653
+ else:
654
+ context = torch.utils.checkpoint.checkpoint(
655
+ self.inner_cross_attn, q, kv, **kwargs
656
+ )
657
+ else:
658
+ context = self._update_kvcache_attention(q, kv, inference_params)
659
+ else:
660
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
661
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
662
+ return out if not self.return_residual else (out, x)
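A hedged sketch of the pure-PyTorch attention path in MHA (use_flash_attn=False), which runs on CPU and does not require the flash_attn package; it assumes mha.py is importable and einops is installed.

import torch
from mha import MHA  # assumes the file is on the import path

mha = MHA(embed_dim=768, num_heads=12, use_flash_attn=False)
x = torch.randn(2, 16, 768)                  # (batch, seqlen, hidden_dim)
mask = torch.ones(2, 16, dtype=torch.bool)   # key_padding_mask: True = keep, False = mask out
out = mha(x, key_padding_mask=mask)          # (2, 16, 768)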
mlp.py ADDED
@@ -0,0 +1,194 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+
12
+ try:
13
+ from flash_attn.ops.activations import swiglu
14
+ except ImportError:
15
+ swiglu = None
16
+
17
+ try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
45
+ self.return_residual = return_residual
46
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
+ self.activation = activation
48
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
+
50
+ def forward(self, x):
51
+ y = self.fc1(x)
52
+ y = self.activation(y)
53
+ y = self.fc2(y)
54
+ return y if not self.return_residual else (y, x)
55
+
56
+
57
+ class ParallelMLP(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_features,
61
+ hidden_features=None,
62
+ out_features=None,
63
+ activation=F.gelu,
64
+ process_group: ProcessGroup = None,
65
+ sequence_parallel=True,
66
+ bias1=True,
67
+ bias2=True,
68
+ device=None,
69
+ dtype=None,
70
+ ):
71
+ factory_kwargs = {"device": device, "dtype": dtype}
72
+ super().__init__()
73
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
74
+ assert RowParallelLinear is not None, "Need to install fused_dense"
75
+ out_features = out_features if out_features is not None else in_features
76
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
77
+ self.fc1 = ColumnParallelLinear(
78
+ in_features,
79
+ hidden_features,
80
+ process_group,
81
+ bias=bias1,
82
+ sequence_parallel=sequence_parallel,
83
+ **factory_kwargs,
84
+ )
85
+ self.activation = activation
86
+ self.fc2 = RowParallelLinear(
87
+ hidden_features,
88
+ out_features,
89
+ process_group,
90
+ bias=bias2,
91
+ sequence_parallel=sequence_parallel,
92
+ **factory_kwargs,
93
+ )
94
+
95
+ def forward(self, x):
96
+ y = self.fc1(x)
97
+ y = self.activation(y)
98
+ y = self.fc2(y)
99
+ return y
100
+
101
+
102
+ class GatedMlp(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_features,
106
+ hidden_features=None,
107
+ out_features=None,
108
+ activation=F.sigmoid,
109
+ bias1=True,
110
+ bias2=True,
111
+ multiple_of=128,
112
+ return_residual=False,
113
+ device=None,
114
+ dtype=None,
115
+ ):
116
+ factory_kwargs = {"device": device, "dtype": dtype}
117
+ super().__init__()
118
+ out_features = out_features if out_features is not None else in_features
119
+ hidden_features = (
120
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
121
+ )
122
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
123
+ self.return_residual = return_residual
124
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
125
+ self.activation = activation
126
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
127
+
128
+ def forward(self, x):
129
+ y = self.fc1(x)
130
+ if self.activation == F.sigmoid: # Special case for GLU
131
+ y = F.glu(y, dim=-1)
132
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
133
+ y, gate = y.chunk(2, dim=-1)
134
+ y = swiglu(gate, y)
135
+ else:
136
+ y, gate = y.chunk(2, dim=-1)
137
+ y = y * self.activation(gate)
138
+ y = self.fc2(y)
139
+ return y if not self.return_residual else (y, x)
140
+
141
+
142
+ class ParallelGatedMlp(nn.Module):
143
+ """Parallel GatedMlp"""
144
+
145
+ def __init__(
146
+ self,
147
+ in_features,
148
+ process_group,
149
+ hidden_features=None,
150
+ out_features=None,
151
+ activation=F.sigmoid,
152
+ bias1=True,
153
+ bias2=True,
154
+ multiple_of=128,
155
+ sequence_parallel=True,
156
+ device=None,
157
+ dtype=None,
158
+ ):
159
+ factory_kwargs = {"device": device, "dtype": dtype}
160
+ super().__init__()
161
+ out_features = out_features if out_features is not None else in_features
162
+ hidden_features = (
163
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
164
+ )
165
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
166
+ if ColumnParallelLinear is None or RowParallelLinear is None:
167
+ raise ImportError("fused_dense is not installed")
168
+ self.fc1 = ColumnParallelLinear(
169
+ in_features,
170
+ 2 * hidden_features,
171
+ process_group,
172
+ bias=bias1,
173
+ sequence_parallel=sequence_parallel,
174
+ **factory_kwargs,
175
+ )
176
+ self.activation = activation
177
+ self.fc2 = RowParallelLinear(
178
+ hidden_features,
179
+ out_features,
180
+ process_group,
181
+ bias=bias2,
182
+ sequence_parallel=sequence_parallel,
183
+ **factory_kwargs,
184
+ )
185
+
186
+ def forward(self, x):
187
+ y = self.fc1(x)
188
+ if self.activation == F.sigmoid: # Special case for GLU
189
+ y = F.glu(y, dim=-1)
190
+ else:
191
+ y, gate = y.chunk(2, dim=-1)
192
+ y = y * self.activation(gate)
193
+ y = self.fc2(y)
194
+ return y
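As a reference for the gating in `GatedMlp` above: `fc1` produces `2 * hidden_features` channels, which are split into a value and a gate; with `F.silu` as the activation and no fused `swiglu` kernel available, the `else` branch computes `value * silu(gate)` (SwiGLU). A minimal sketch of that fallback path, with hypothetical sizes:

```python
import torch
import torch.nn.functional as F

# hypothetical sizes: in_features=16, hidden_features=32 (already a multiple of multiple_of)
fc1 = torch.nn.Linear(16, 2 * 32)
fc2 = torch.nn.Linear(32, 16)

x = torch.randn(4, 16)
y = fc1(x)                    # (4, 64): value and gate stacked along the last dim
y, gate = y.chunk(2, dim=-1)  # (4, 32) each
y = y * F.silu(gate)          # SwiGLU gating, as in the non-fused branch of GatedMlp.forward
out = fc2(y)                  # (4, 16)
```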
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,904 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
+ from transformers.models.bert.modeling_bert import (
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ BertForPreTrainingOutput,
32
+ )
33
+
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
+ from .embedding import XLMRobertaEmbeddings
45
+ from .mha import MHA
46
+ from .mlp import FusedMLP, Mlp
47
+
48
+ try:
49
+ from flash_attn.ops.fused_dense import FusedDense
50
+ except ImportError:
51
+ FusedDense = None
52
+
53
+ try:
54
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
55
+ except ImportError:
56
+ layer_norm_fn = None
57
+
58
+
59
+ try:
60
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
61
+ except ImportError:
62
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
+
64
+ try:
65
+ from tqdm.autonotebook import trange
66
+ except ImportError:
67
+ trange = None
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
74
+ if not getattr(config, "use_flash_attn", False):
75
+ return False
76
+ if not torch.cuda.is_available():
77
+ return False
78
+ if importlib.util.find_spec("flash_attn") is None:
79
+ logger.warning(
80
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
81
+ )
82
+ return False
83
+ return True
84
+
85
+
86
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
87
+ use_flash_attn = get_use_flash_attn(config)
88
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
+
90
+ mixer_cls = partial(
91
+ MHA,
92
+ num_heads=config.num_attention_heads,
93
+ cross_attn=cross_attn,
94
+ dropout=config.attention_probs_dropout_prob,
95
+ causal=False,
96
+ fused_bias_fc=fused_bias_fc,
97
+ use_flash_attn=use_flash_attn,
98
+ return_residual=return_residual,
99
+ )
100
+ return mixer_cls
101
+
102
+
103
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
104
+ inner_dim = config.intermediate_size
105
+ fused_mlp = getattr(config, "fused_mlp", False)
106
+ if fused_mlp:
107
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108
+ "fused_mlp only " "supports approximate gelu"
109
+ )
110
+ if not fused_mlp:
111
+ approximate = (
112
+ "tanh"
113
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114
+ else "none"
115
+ )
116
+ mlp_cls = partial(
117
+ Mlp,
118
+ hidden_features=inner_dim,
119
+ activation=partial(F.gelu, approximate=approximate),
120
+ return_residual=return_residual,
121
+ )
122
+ else:
123
+ if FusedMLP is None:
124
+ raise ImportError("fused_dense is not installed")
125
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127
+ if isinstance(mlp_checkpoint_lvl, Sequence):
128
+ assert layer_idx is not None
129
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130
+ mlp_cls = partial(
131
+ FusedMLP,
132
+ hidden_features=inner_dim,
133
+ checkpoint_lvl=mlp_checkpoint_lvl,
134
+ return_residual=return_residual,
135
+ )
136
+ return mlp_cls
137
+
138
+
139
+ def create_block(config, layer_idx=None):
140
+ last_layer_subset = getattr(config, "last_layer_subset", False)
141
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144
+ # one layer) so we just choose not to return residual in this case.
145
+ return_residual = not cross_attn
146
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149
+ block = Block(
150
+ config.hidden_size,
151
+ mixer_cls,
152
+ mlp_cls,
153
+ norm_cls=norm_cls,
154
+ prenorm=False,
155
+ resid_dropout1=config.hidden_dropout_prob,
156
+ resid_dropout2=config.hidden_dropout_prob,
157
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158
+ return_residual=return_residual,
159
+ )
160
+ return block
161
+
162
+
163
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164
+ def _init_weights(module, initializer_range=0.02):
165
+ if isinstance(module, nn.Linear):
166
+ nn.init.normal_(module.weight, std=initializer_range)
167
+ if module.bias is not None:
168
+ nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.padding_idx is not None:
172
+ nn.init.zeros_(module.weight[module.padding_idx])
173
+
174
+
175
+ class XLMRobertaEncoder(nn.Module):
176
+ def __init__(self, config: XLMRobertaFlashConfig):
177
+ super().__init__()
178
+ self.use_flash_attn = get_use_flash_attn(config)
179
+ self.layers = nn.ModuleList(
180
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181
+ )
182
+ self._grad_checkpointing = False
183
+
184
+ @property
185
+ def gradient_checkpointing(self):
186
+ return self._grad_checkpointing
187
+
188
+ @gradient_checkpointing.setter
189
+ def gradient_checkpointing(self, value):
190
+ self._grad_checkpointing = value
191
+
192
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193
+ """If subset_mask is not None, we only want output for the subset of the sequence.
194
+ This means that we only compute the last layer output for these tokens.
195
+ subset_mask: (batch, seqlen), dtype=torch.bool
196
+ """
197
+ if key_padding_mask is None or not self.use_flash_attn:
198
+ mixer_kwargs = (
199
+ {"key_padding_mask": key_padding_mask.bool()}
200
+ if key_padding_mask is not None
201
+ else None
202
+ )
203
+ for layer in self.layers:
204
+ if self._grad_checkpointing:
205
+ hidden_states = torch.utils.checkpoint.checkpoint(
206
+ layer,
207
+ hidden_states,
208
+ use_reentrant=False,
209
+ mixer_kwargs=mixer_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213
+ if subset_mask is not None:
214
+ hidden_states = hidden_states[subset_mask]
215
+ else:
216
+ batch, seqlen = hidden_states.shape[:2]
217
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218
+ hidden_states, key_padding_mask
219
+ )
220
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221
+ if subset_mask is None:
222
+ for layer in self.layers:
223
+ if self._grad_checkpointing:
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ layer,
226
+ hidden_states,
227
+ use_reentrant=False,
228
+ mixer_kwargs=mixer_kwargs,
229
+ )
230
+ else:
231
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233
+ else:
234
+ for layer in self.layers[:-1]:
235
+ if self._grad_checkpointing:
236
+ hidden_states = torch.utils.checkpoint.checkpoint(
237
+ layer,
238
+ hidden_states,
239
+ use_reentrant=False,
240
+ mixer_kwargs=mixer_kwargs,
241
+ )
242
+ else:
243
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244
+ if key_padding_mask is not None:
245
+ subset_idx = torch.nonzero(
246
+ subset_mask[key_padding_mask], as_tuple=False
247
+ ).flatten()
248
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
249
+ dim=-1, dtype=torch.int32
250
+ )
251
+ subset_cu_seqlens = F.pad(
252
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
253
+ (1, 0),
254
+ )
255
+ else:
256
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258
+ subset_cu_seqlens = F.pad(
259
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
260
+ (1, 0),
261
+ )
262
+ hidden_states_subset, hidden_states = index_first_axis_residual(
263
+ hidden_states, subset_idx
264
+ )
265
+ # It's ok to set max_seqlen_q to be much larger
266
+ mixer_kwargs = {
267
+ "x_kv": hidden_states,
268
+ "cu_seqlens": subset_cu_seqlens,
269
+ "max_seqlen": max_seqlen_in_batch,
270
+ "cu_seqlens_k": cu_seqlens,
271
+ "max_seqlen_k": max_seqlen_in_batch,
272
+ }
273
+ if self._grad_checkpointing:
274
+ hidden_states = torch.utils.checkpoint.checkpoint(
275
+ self.layers[-1],
276
+ hidden_states_subset,
277
+ use_reentrant=False,
278
+ mixer_kwargs=mixer_kwargs,
279
+ )
280
+ else:
281
+ hidden_states = self.layers[-1](
282
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
283
+ )
284
+ return hidden_states
285
+
286
+
287
+ class XLMRobertaPooler(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
291
+ if fused_bias_fc and FusedDense is None:
292
+ raise ImportError("fused_dense is not installed")
293
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
295
+ self.activation = nn.Tanh()
296
+
297
+ def forward(self, hidden_states, pool=True):
298
+ # We "pool" the model by simply taking the hidden state corresponding
299
+ # to the first token.
300
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301
+ pooled_output = self.dense(first_token_tensor)
302
+ pooled_output = self.activation(pooled_output)
303
+ return pooled_output
304
+
305
+
306
+ class XLMRobertaPredictionHeadTransform(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
310
+ if fused_bias_fc and FusedDense is None:
311
+ raise ImportError("fused_dense is not installed")
312
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
314
+ raise ImportError("Triton is not installed")
315
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
317
+ approximate = (
318
+ "tanh"
319
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320
+ else "none"
321
+ )
322
+ self.transform_act_fn = nn.GELU(approximate=approximate)
323
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+
325
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.transform_act_fn(hidden_states)
328
+ if not self.fused_dropout_add_ln:
329
+ hidden_states = self.layer_norm(hidden_states)
330
+ else:
331
+ hidden_states = layer_norm_fn(
332
+ hidden_states,
333
+ self.layer_norm.weight,
334
+ self.layer_norm.bias,
335
+ eps=self.layer_norm.eps,
336
+ )
337
+ return hidden_states
338
+
339
+
340
+ class XLMRobertaLMPredictionHead(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
344
+ if fused_bias_fc and FusedDense is None:
345
+ raise ImportError("fused_dense is not installed")
346
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
+
348
+ self.transform = XLMRobertaPredictionHeadTransform(config)
349
+
350
+ # The output weights are the same as the input embeddings, but there is
351
+ # an output-only bias for each token.
352
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
+
354
+ def forward(self, hidden_states):
355
+ hidden_states = self.transform(hidden_states)
356
+ hidden_states = self.decoder(hidden_states)
357
+ return hidden_states
358
+
359
+
360
+ class XLMRobertaPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super().__init__()
363
+ self.predictions = XLMRobertaLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
373
+ """An abstract class to handle weights initialization and
374
+ a simple interface for downloading and loading pretrained models.
375
+ """
376
+
377
+ config_class = XLMRobertaFlashConfig
378
+ base_model_prefix = "roberta"
379
+ supports_gradient_checkpointing = True
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, XLMRobertaEncoder):
383
+ module.gradient_checkpointing = value
384
+
385
+ @classmethod
386
+ def from_pretrained(
387
+ cls,
388
+ *args,
389
+ **kwargs,
390
+ ):
391
+ if 'torch_dtype' not in kwargs:
392
+ kwargs['torch_dtype'] = 'auto'
393
+ return super().from_pretrained(*args, **kwargs)
394
+
395
+
396
+
397
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
398
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399
+ super().__init__(config)
400
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
401
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
402
+ config.vocab_size += self.pad_vocab_size_multiple - (
403
+ config.vocab_size % self.pad_vocab_size_multiple
404
+ )
405
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
407
+ raise ImportError("Triton is not installed")
408
+ assert config.hidden_act in [
409
+ "gelu",
410
+ "gelu_new",
411
+ "gelu_fast",
412
+ "gelu_pytorch_tanh",
413
+ ]
414
+
415
+ self.embeddings = XLMRobertaEmbeddings(
416
+ config.hidden_size,
417
+ config.vocab_size,
418
+ config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419
+ config.type_vocab_size,
420
+ padding_idx=config.pad_token_id,
421
+ )
422
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.encoder = XLMRobertaEncoder(config)
425
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
+
427
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
+
429
+
430
+ @torch.inference_mode()
431
+ def encode(
432
+ self: 'XLMRobertaModel',
433
+ sentences: Union[str, List[str]],
434
+ batch_size: int = 32,
435
+ show_progress_bar: Optional[bool] = None,
436
+ output_value: str = 'sentence_embedding',
437
+ convert_to_numpy: bool = True,
438
+ convert_to_tensor: bool = False,
439
+ device: Optional[torch.device] = None,
440
+ normalize_embeddings: bool = False,
441
+ truncate_dim: Optional[int] = None,
442
+ **tokenizer_kwargs,
443
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444
+ """
445
+ Computes sentence embeddings
446
+ Args:
447
+ sentences(`str` or `List[str]`):
448
+ Sentence or sentences to be encoded
449
+ batch_size(`int`, *optional*, defaults to 32):
450
+ Batch size for the computation
451
+ show_progress_bar(`bool`, *optional*, defaults to None):
452
+ Show a progress bar when encoding sentences.
453
+ If set to None, progress bar is only shown when
454
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456
+ Default sentence_embedding, to get sentence embeddings.
457
+ Can be set to token_embeddings to get wordpiece token embeddings.
458
+ Set to None, to get all output values
459
+ convert_to_numpy(`bool`, *optional*, defaults to True):
460
+ If true, the output is a list of numpy vectors.
461
+ Else, it is a list of pytorch tensors.
462
+ convert_to_tensor(`bool`, *optional*, defaults to False):
463
+ If true, you get one large tensor as return.
464
+ Overwrites any setting from convert_to_numpy
465
+ device(`torch.device`, *optional*, defaults to None):
466
+ Which torch.device to use for the computation
467
+ normalize_embeddings(`bool`, *optional*, defaults to False):
468
+ If set to true, returned vectors will have length 1. In that case, the
469
+ faster dot-product (util.dot_score) instead of cosine similarity can
470
+ be used.
471
+ truncate_dim(`int`, *optional*, defaults to None):
472
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
473
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474
+ Keyword arguments for the tokenizer
475
+ Returns:
476
+ By default, a list of tensors is returned.
477
+ If convert_to_tensor, a stacked tensor is returned.
478
+ If convert_to_numpy, a numpy matrix is returned.
479
+ """
480
+ from transformers import AutoTokenizer
481
+
482
+ self.tokenizer = AutoTokenizer.from_pretrained(
483
+ self.name_or_path, trust_remote_code=True
484
+ )
485
+
486
+ is_training = self.training
487
+ self.eval()
488
+
489
+ if show_progress_bar is None:
490
+ show_progress_bar = (
491
+ logger.getEffectiveLevel() == logging.INFO
492
+ or logger.getEffectiveLevel() == logging.DEBUG
493
+ )
494
+
495
+ if convert_to_tensor:
496
+ convert_to_numpy = False
497
+
498
+ if output_value != 'sentence_embedding':
499
+ convert_to_tensor = False
500
+ convert_to_numpy = False
501
+
502
+ input_was_string = False
503
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504
+ sentences = [sentences]
505
+ input_was_string = True
506
+
507
+ if device is not None:
508
+ self.to(device)
509
+
510
+ permutation = np.argsort([-len(i) for i in sentences])
511
+ inverse_permutation = np.argsort(permutation)
512
+ sentences = [sentences[idx] for idx in permutation]
513
+
514
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517
+ )
518
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
+
520
+ all_embeddings = []
521
+
522
+ if trange is not None:
523
+ range_iter = trange(
524
+ 0,
525
+ len(sentences),
526
+ batch_size,
527
+ desc="Encoding",
528
+ disable=not show_progress_bar,
529
+ )
530
+ else:
531
+ range_iter = range(0, len(sentences), batch_size)
532
+
533
+ for i in range_iter:
534
+ encoded_input = self.tokenizer(
535
+ sentences[i : i + batch_size],
536
+ return_tensors='pt',
537
+ **tokenizer_kwargs,
538
+ ).to(self.device)
539
+ token_embs = self.forward(**encoded_input)[0]
540
+
541
+ # Accumulate in fp32 to avoid overflow
542
+ token_embs = token_embs.float()
543
+
544
+ if output_value == 'token_embeddings':
545
+ raise NotImplementedError
546
+ elif output_value is None:
547
+ raise NotImplementedError
548
+ else:
549
+ if self.config.emb_pooler == 'cls':
550
+ embeddings = self.cls_pooling(
551
+ token_embs, encoded_input['attention_mask']
552
+ )
553
+ else:
554
+ embeddings = self.mean_pooling(
555
+ token_embs, encoded_input['attention_mask']
556
+ )
557
+
558
+ if normalize_embeddings:
559
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
+
561
+ if convert_to_numpy:
562
+ embeddings = embeddings.cpu()
563
+ all_embeddings.extend(embeddings)
564
+
565
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
+
567
+ truncate_dim = truncate_dim or self.config.truncate_dim
568
+ if truncate_dim:
569
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
+
571
+ if convert_to_tensor:
572
+ all_embeddings = torch.stack(all_embeddings)
573
+ elif convert_to_numpy:
574
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
+
576
+ if input_was_string:
577
+ all_embeddings = all_embeddings[0]
578
+
579
+ self.train(is_training)
580
+ return all_embeddings
581
+
582
+
583
+ def truncate_embeddings(self, embeddings, truncate_dim):
584
+ if not self.config.matryoshka_dimensions:
585
+ logger.warning(
586
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587
+ )
588
+ return embeddings
589
+ elif truncate_dim in self.config.matryoshka_dimensions:
590
+ return [tensor[:truncate_dim] for tensor in embeddings]
591
+ else:
592
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
+
595
+ def mean_pooling(
596
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597
+ ):
598
+ input_mask_expanded = (
599
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600
+ )
601
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602
+ input_mask_expanded.sum(1), min=1e-9
603
+ )
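A minimal worked example of the masked mean pooling above, with illustrative values (the padded position is excluded from the average):

```python
import torch

token_embeddings = torch.tensor([[[2.0, 4.0], [6.0, 8.0], [100.0, 100.0]]])  # (1, 3, 2)
attention_mask = torch.tensor([[1, 1, 0]])  # last token is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
mean = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(mean)  # tensor([[4., 6.]])
```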
604
+
605
+
606
+ def cls_pooling(
607
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608
+ ):
609
+ return token_embeddings[:, 0]
610
+
611
+
612
+ def forward(
613
+ self,
614
+ input_ids,
615
+ position_ids=None,
616
+ token_type_ids=None,
617
+ attention_mask=None,
618
+ masked_tokens_mask=None,
619
+ return_dict=None,
620
+ **kwargs,
621
+ ):
622
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623
+ we only want the output for the masked tokens. This means that we only compute the last
624
+ layer output for these tokens.
625
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626
+ """
627
+
628
+ if kwargs:
629
+ for key, value in kwargs.items():
630
+ if value is not None:
631
+ logger.warning(
632
+ 'Flash attention implementation does not support kwargs: %s',
633
+ key,
634
+ )
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ hidden_states = self.embeddings(
641
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642
+ )
643
+ # TD [2022-12-18]: Don't need to force residual in fp32
644
+ # BERT puts embedding LayerNorm before embedding dropout.
645
+ if not self.fused_dropout_add_ln:
646
+ hidden_states = self.emb_ln(hidden_states)
647
+ else:
648
+ hidden_states = layer_norm_fn(
649
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650
+ )
651
+ hidden_states = self.emb_drop(hidden_states)
652
+
653
+ if masked_tokens_mask is not None:
654
+ batch_size, seqlen = input_ids.shape[:2]
655
+ # We also need the first column for the CLS token
656
+ first_col_mask = torch.zeros(
657
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658
+ )
659
+ first_col_mask[:, 0] = True
660
+ subset_mask = masked_tokens_mask | first_col_mask
661
+ else:
662
+ subset_mask = None
663
+
664
+ sequence_output = self.encoder(
665
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666
+ )
667
+
668
+ if masked_tokens_mask is None:
669
+ pooled_output = (
670
+ self.pooler(sequence_output) if self.pooler is not None else None
671
+ )
672
+ else:
673
+ # TD [2022-03-01]: the indexing here is very tricky.
674
+ if attention_mask is not None:
675
+ subset_idx = subset_mask[attention_mask]
676
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677
+ sequence_output = sequence_output[
678
+ masked_tokens_mask[attention_mask][subset_idx]
679
+ ]
680
+ else:
681
+ pool_input = sequence_output[first_col_mask[subset_mask]]
682
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683
+ pooled_output = (
684
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
685
+ )
686
+
687
+ if not return_dict:
688
+ return sequence_output, pooled_output
689
+
690
+ return BaseModelOutputWithPoolingAndCrossAttentions(
691
+ last_hidden_state=sequence_output,
692
+ pooler_output=pooled_output,
693
+ )
694
+
695
+
696
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
+
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ if config.is_decoder:
703
+ logger.warning(
704
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705
+ "bi-directional self-attention."
706
+ )
707
+
708
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709
+ self.lm_head = XLMRobertaLMHead(config)
710
+
711
+ # Initialize weights and apply final processing
712
+ self.post_init()
713
+
714
+ def get_input_embeddings(self):
715
+ return self.roberta.embeddings.word_embeddings
716
+
717
+ def get_output_embeddings(self):
718
+ return self.lm_head.decoder
719
+
720
+ def set_output_embeddings(self, new_embeddings):
721
+ self.lm_head.decoder = new_embeddings
722
+
723
+ def forward(
724
+ self,
725
+ input_ids: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.FloatTensor] = None,
727
+ token_type_ids: Optional[torch.LongTensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ head_mask: Optional[torch.FloatTensor] = None,
730
+ inputs_embeds: Optional[torch.FloatTensor] = None,
731
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
732
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
733
+ labels: Optional[torch.LongTensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738
+ r"""
739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
743
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
744
+ Used to hide legacy arguments that have been deprecated.
745
+ """
746
+ return_dict = (
747
+ return_dict if return_dict is not None else self.config.use_return_dict
748
+ )
749
+
750
+ outputs = self.roberta(
751
+ input_ids,
752
+ attention_mask=attention_mask,
753
+ token_type_ids=token_type_ids,
754
+ position_ids=position_ids,
755
+ head_mask=head_mask,
756
+ inputs_embeds=inputs_embeds,
757
+ encoder_hidden_states=encoder_hidden_states,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ output_attentions=output_attentions,
760
+ output_hidden_states=output_hidden_states,
761
+ return_dict=return_dict,
762
+ )
763
+ sequence_output = outputs[0]
764
+ prediction_scores = self.lm_head(sequence_output)
765
+
766
+ masked_lm_loss = None
767
+ if labels is not None:
768
+ # move labels to correct device to enable model parallelism
769
+ labels = labels.to(prediction_scores.device)
770
+ loss_fct = CrossEntropyLoss()
771
+ masked_lm_loss = loss_fct(
772
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773
+ )
774
+
775
+ if not return_dict:
776
+ output = (prediction_scores,) + outputs[2:]
777
+ return (
778
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779
+ )
780
+
781
+ return MaskedLMOutput(
782
+ loss=masked_lm_loss,
783
+ logits=prediction_scores,
784
+ hidden_states=outputs.hidden_states,
785
+ attentions=outputs.attentions,
786
+ )
787
+
788
+
789
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790
+ class XLMRobertaClassificationHead(nn.Module):
791
+ """Head for sentence-level classification tasks."""
792
+
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
796
+ if fused_bias_fc and FusedDense is None:
797
+ raise ImportError("fused_dense is not installed")
798
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
800
+ classifier_dropout = (
801
+ config.classifier_dropout
802
+ if config.classifier_dropout is not None
803
+ else config.hidden_dropout_prob
804
+ )
805
+ self.dropout = nn.Dropout(classifier_dropout)
806
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
+
808
+ def forward(self, features, **kwargs):
809
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810
+ x = self.dropout(x)
811
+ x = self.dense(x)
812
+ x = torch.tanh(x)
813
+ x = self.dropout(x)
814
+ x = self.out_proj(x)
815
+ return x
816
+
817
+
818
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.num_labels = config.num_labels
823
+ self.config = config
824
+
825
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826
+ self.classifier = XLMRobertaClassificationHead(config)
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def forward(
832
+ self,
833
+ input_ids: Optional[torch.LongTensor] = None,
834
+ attention_mask: Optional[torch.FloatTensor] = None,
835
+ token_type_ids: Optional[torch.LongTensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ head_mask: Optional[torch.FloatTensor] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ labels: Optional[torch.LongTensor] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844
+ r"""
845
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
+ """
850
+ return_dict = (
851
+ return_dict if return_dict is not None else self.config.use_return_dict
852
+ )
853
+
854
+ outputs = self.roberta(
855
+ input_ids,
856
+ attention_mask=attention_mask,
857
+ token_type_ids=token_type_ids,
858
+ position_ids=position_ids,
859
+ head_mask=head_mask,
860
+ inputs_embeds=inputs_embeds,
861
+ output_attentions=output_attentions,
862
+ output_hidden_states=output_hidden_states,
863
+ return_dict=return_dict,
864
+ )
865
+ sequence_output = outputs[0]
866
+ logits = self.classifier(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # move labels to correct device to enable model parallelism
871
+ labels = labels.to(logits.device)
872
+ if self.config.problem_type is None:
873
+ if self.num_labels == 1:
874
+ self.config.problem_type = "regression"
875
+ elif self.num_labels > 1 and (
876
+ labels.dtype == torch.long or labels.dtype == torch.int
877
+ ):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ loss_fct = MSELoss()
884
+ if self.num_labels == 1:
885
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
886
+ else:
887
+ loss = loss_fct(logits, labels)
888
+ elif self.config.problem_type == "single_label_classification":
889
+ loss_fct = CrossEntropyLoss()
890
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891
+ elif self.config.problem_type == "multi_label_classification":
892
+ loss_fct = BCEWithLogitsLoss()
893
+ loss = loss_fct(logits, labels)
894
+
895
+ if not return_dict:
896
+ output = (logits,) + outputs[2:]
897
+ return ((loss,) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutput(
900
+ loss=loss,
901
+ logits=logits,
902
+ hidden_states=outputs.hidden_states,
903
+ attentions=outputs.attentions,
904
+ )
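For orientation, a minimal usage sketch of the `encode` helper defined on `XLMRobertaModel` above. The checkpoint path is a placeholder, and the call assumes the repository (custom code, config, and tokenizer) has been uploaded so that `trust_remote_code=True` resolves to this implementation:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("path/to/checkpoint", trust_remote_code=True)  # placeholder path

embeddings = model.encode(
    ["A first sentence.", "Une deuxième phrase."],
    batch_size=2,
    convert_to_tensor=True,     # return one stacked torch.Tensor instead of numpy
    normalize_embeddings=True,  # unit-length vectors, so dot product == cosine similarity
)
print(embeddings.shape)  # (2, hidden_size), or (2, truncate_dim) if truncation is configured
```

Internally, `encode` sorts the inputs by length before batching (to reduce padding), pools with either the CLS token or a masked mean depending on `config.emb_pooler`, and restores the original order before returning.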
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:318b11c3ce6d8d34e5034d001166a857934c0811c4fc5fb4a40328477ccaaaf9
3
+ size 561622266
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a56def25aa40facc030ea8b0b87f3688e4b3c39eb8b45d5702b3a1300fe2a20
3
+ size 17082734
tokenizer_config.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 1026,
50
+ "pad_token": "<pad>",
51
+ "sep_token": "</s>",
52
+ "tokenizer_class": "XLMRobertaTokenizer",
53
+ "unk_token": "<unk>"
54
+ }
xlm_padding.py ADDED
@@ -0,0 +1,218 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
22
+ ).reshape(-1, *other_shape)
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ (indices,) = ctx.saved_tensors
27
+ assert grad_output.ndim >= 2
28
+ other_shape = grad_output.shape[1:]
29
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
30
+ grad_input = torch.zeros(
31
+ [ctx.first_axis_dim, grad_output.shape[1]],
32
+ device=grad_output.device,
33
+ dtype=grad_output.dtype,
34
+ )
35
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
+ # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
38
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
+
40
+
41
+ index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
+ class IndexPutFirstAxis(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, values, indices, first_axis_dim):
47
+ ctx.save_for_backward(indices)
48
+ assert indices.ndim == 1
49
+ assert values.ndim >= 2
50
+ output = torch.zeros(
51
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
93
+ indices = indices.expand_as(grad_output)
94
+ grad_input.scatter_add_(0, indices, grad_output)
95
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
96
+
97
+
98
+ index_first_axis_residual = IndexFirstAxisResidual.apply
99
+
100
+
101
+ def unpad_input(hidden_states, attention_mask):
102
+ """
103
+ Arguments:
104
+ hidden_states: (batch, seqlen, ...)
105
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
106
+ Return:
107
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
108
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
109
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
110
+ max_seqlen_in_batch: int
111
+ """
112
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
116
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
+ # bool mask, then call nonzero to get the indices, then index with those. The index tensor is @dim
118
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
119
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
120
+ # so we write custom forward and backward to make it a bit faster.
121
+ return (
122
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
123
+ indices,
124
+ cu_seqlens,
125
+ max_seqlen_in_batch,
126
+ )
127
+
128
+
129
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
130
+ """
131
+ Supports concatenating short samples into one sequence. The attention_mask_in_length is used to mask out the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task for large language models).
132
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
133
+
134
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
135
+ ```
136
+ [
137
+ [2, 3, 0, 0, 0, 0],
138
+ [3, 2, 0, 0, 0, 0],
139
+ [6, 0, 0, 0, 0, 0]
140
+ ]
141
+ ```
142
+ , which refers to the 3D-attention mask:
143
+ ```
144
+ [
145
+ [
146
+ [1, 0, 0, 0, 0, 0],
147
+ [1, 1, 0, 0, 0, 0],
148
+ [0, 0, 1, 0, 0, 0],
149
+ [0, 0, 1, 1, 0, 0],
150
+ [0, 0, 1, 1, 1, 0],
151
+ [0, 0, 0, 0, 0, 1]
152
+ ],
153
+ [
154
+ [1, 0, 0, 0, 0, 0],
155
+ [1, 1, 0, 0, 0, 0],
156
+ [1, 1, 1, 0, 0, 0],
157
+ [0, 0, 0, 1, 0, 0],
158
+ [0, 0, 0, 1, 1, 0],
159
+ [0, 0, 0, 0, 0, 1]
160
+ ],
161
+ [
162
+ [1, 0, 0, 0, 0, 0],
163
+ [1, 1, 0, 0, 0, 0],
164
+ [1, 1, 1, 0, 0, 0],
165
+ [1, 1, 1, 1, 0, 0],
166
+ [1, 1, 1, 1, 1, 0],
167
+ [1, 1, 1, 1, 1, 1]
168
+ ]
169
+ ]
170
+ ```.
171
+
172
+ Arguments:
173
+ hidden_states: (batch, seqlen, ...)
174
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
175
+ Return:
176
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
177
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
178
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
179
+ max_seqlen_in_batch: int
180
+ """
181
+ length = attention_mask_in_length.sum(dim=-1)
182
+ seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(
185
+ seqlen, device=length.device, dtype=length.dtype
186
+ ).expand(len(length), seqlen) < length.unsqueeze(1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The index tensor is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
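To make the round trip concrete, here is a minimal sketch of `unpad_input` followed by `pad_input` on a toy batch. It assumes the two functions above are in scope (e.g. `from .xlm_padding import unpad_input, pad_input` inside the package); the values are illustrative:

```python
import torch

# toy batch: 2 sequences, padded length 3, hidden size 4
hidden_states = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
attention_mask = torch.tensor([[1, 1, 0],
                               [1, 1, 1]])

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # torch.Size([5, 4]): only the 5 real tokens are kept
print(cu_seqlens)      # tensor([0, 2, 5], dtype=torch.int32): per-sequence offsets for varlen attention
print(max_seqlen)      # 3

# pad_input restores the (batch, seqlen, hidden) layout, with zeros at the padded positions
repadded = pad_input(unpadded, indices, batch=2, seqlen=3)
print(repadded.shape)  # torch.Size([2, 3, 4])
```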