michael-guenther committed
Commit 77af1c7
1 Parent(s): 1c61b96

add stochastic_depth

Files changed (3):
  1. block.py +26 -14
  2. modeling_xlm_roberta.py +121 -61
  3. stochastic_depth.py +97 -0
block.py CHANGED

@@ -10,8 +10,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torchvision.ops import StochasticDepth
 
+from .stochastic_depth import StochasticDepth
 from .mha import MHA
 from .mlp import Mlp
 
@@ -106,7 +106,9 @@ class Block(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -152,7 +154,7 @@ class Block(nn.Module):
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
-                    is_rms_norm=isinstance(self.norm1, RMSNorm)
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -165,7 +167,9 @@ class Block(nn.Module):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
                     residual = (dropped + residual) if residual is not None else dropped
-                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                    hidden_states = self.norm2(
+                        residual.to(dtype=self.norm2.weight.dtype)
+                    )
                     if self.residual_in_fp32:
                         residual = residual.to(torch.float32)
                 else:
@@ -189,7 +193,7 @@ class Block(nn.Module):
                         rowscale=rowscale2,
                         prenorm=True,
                         residual_in_fp32=self.residual_in_fp32,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
                 hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
@@ -212,7 +216,9 @@ class Block(nn.Module):
                 else:
                     rowscale1 = self.drop_path1(
                         torch.ones(
-                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
+                            mixer_out.shape[:-1],
+                            device=mixer_out.device,
+                            dtype=mixer_out.dtype,
                         )
                     )
                 hidden_states = layer_norm_fn(
@@ -224,7 +230,7 @@ class Block(nn.Module):
                     dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=False,
-                    is_rms_norm=isinstance(self.norm1, RMSNorm)
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -242,7 +248,9 @@ class Block(nn.Module):
                     else:
                         rowscale2 = self.drop_path2(
                             torch.ones(
-                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
+                                mlp_out.shape[:-1],
+                                device=mlp_out.device,
+                                dtype=mlp_out.dtype,
                             )
                         )
                     hidden_states = layer_norm_fn(
@@ -254,7 +262,7 @@ class Block(nn.Module):
                         dropout_p=self.dropout2.p if self.training else 0.0,
                         rowscale=rowscale2,
                         prenorm=False,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
             return hidden_states
 
@@ -333,7 +341,9 @@ class ParallelBlock(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -373,7 +383,9 @@ class ParallelBlock(nn.Module):
                 residual = residual.to(torch.float32)
         else:
             weight2, bias2 = (
-                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
+                (self.norm2.weight, self.norm2.bias)
+                if not self.tied_norm
+                else (None, None)
             )
             hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
@@ -387,14 +399,14 @@ class ParallelBlock(nn.Module):
                 dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
-                is_rms_norm=isinstance(self.norm1, RMSNorm)
+                is_rms_norm=isinstance(self.norm1, RMSNorm),
             )
         if self.tied_norm:
             hidden_states2 = hidden_states1
         else:
-            hidden_states2, = rest
+            (hidden_states2,) = rest
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
         hidden_states2 = self.mlp(hidden_states2)
-        return hidden_states1, hidden_states2, residual
+        return hidden_states1, hidden_states2, residual
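A note on the drop-path hunks above: in the fused path, Block never calls drop_path on the branch output directly. It passes a tensor of ones through the StochasticDepth module and hands the result to layer_norm_fn as rowscale, so each row is multiplied by 0 or 1 / (1 - p). The following minimal sketch is not part of the commit; it assumes the stochastic_depth.py file added below is importable as stochastic_depth, and only illustrates why the rowscale trick is equivalent to applying drop path to the branch itself.

import torch

from stochastic_depth import StochasticDepth  # module added in this commit

torch.manual_seed(0)
drop_path = StochasticDepth(p=0.25, mode="row")
drop_path.train()

branch_out = torch.randn(4, 3, 8)                        # (batch, seqlen, hidden)
rowscale = drop_path(torch.ones(branch_out.shape[:-1]))  # per row: 0 or 1 / (1 - p)

# Multiplying the branch output by rowscale (what layer_norm_fn does with the
# rowscale argument) has the same effect as applying drop path to the branch.
scaled_branch = branch_out * rowscale.unsqueeze(-1)
print(rowscale)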
modeling_xlm_roberta.py CHANGED

@@ -42,6 +42,7 @@ from .block import Block
 from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
+from .stochastic_depth import StochasticDepth
 
 
 try:
@@ -69,10 +70,16 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
-        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_dim"] = getattr(
+            config, "rotary_emb_dim", config.hidden_size
+        )
         rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
-        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
-        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(
+            config, "rotary_emb_scale_base", None
+        )
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(
+            config, "rotary_emb_interleaved", False
+        )
     mixer_cls = partial(
         MHA,
         num_heads=config.num_attention_heads,
@@ -183,7 +190,9 @@ class XLMRobertaEncoder(nn.Module):
         """
         if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
-                {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
+                {"key_padding_mask": key_padding_mask.bool()}
+                if key_padding_mask is not None
+                else None
             )
             for layer in self.layers:
                 if self._grad_checkpointing:
@@ -191,7 +200,7 @@ class XLMRobertaEncoder(nn.Module):
                         layer,
                         hidden_states,
                         use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                     )
                 else:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -210,7 +219,7 @@ class XLMRobertaEncoder(nn.Module):
                             layer,
                             hidden_states,
                             use_reentrant=False,
-                            mixer_kwargs=mixer_kwargs
+                            mixer_kwargs=mixer_kwargs,
                         )
                     else:
                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -222,7 +231,7 @@ class XLMRobertaEncoder(nn.Module):
                             layer,
                             hidden_states,
                             use_reentrant=False,
-                            mixer_kwargs=mixer_kwargs
+                            mixer_kwargs=mixer_kwargs,
                         )
                     else:
                         hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -230,15 +239,19 @@ class XLMRobertaEncoder(nn.Module):
                     subset_idx = torch.nonzero(
                         subset_mask[key_padding_mask], as_tuple=False
                     ).flatten()
-                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
+                    subset_seqlens = (subset_mask & key_padding_mask).sum(
+                        dim=-1, dtype=torch.int32
+                    )
                     subset_cu_seqlens = F.pad(
-                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                        (1, 0),
                     )
                 else:
                     subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                     subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                     subset_cu_seqlens = F.pad(
-                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                        (1, 0),
                     )
                 hidden_states_subset, hidden_states = index_first_axis_residual(
                     hidden_states, subset_idx
@@ -256,10 +269,12 @@ class XLMRobertaEncoder(nn.Module):
                         self.layers[-1],
                         hidden_states_subset,
                         use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                     )
                 else:
-                    hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                    hidden_states = self.layers[-1](
+                        hidden_states_subset, mixer_kwargs=mixer_kwargs
+                    )
         return hidden_states
 
 
@@ -308,7 +323,10 @@ class XLMRobertaPredictionHeadTransform(nn.Module):
             hidden_states = self.layer_norm(hidden_states)
         else:
             hidden_states = layer_norm_fn(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
+                hidden_states,
+                self.layer_norm.weight,
+                self.layer_norm.bias,
+                eps=self.layer_norm.eps,
             )
         return hidden_states
 
@@ -349,6 +367,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
     """
+
     config_class = XLMRobertaFlashConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
@@ -358,7 +377,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
             module.gradient_checkpointing = value
 
 
-
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
         super().__init__(config)
@@ -370,7 +388,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln and layer_norm_fn is None:
             raise ImportError("Triton is not installed")
-        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+        assert config.hidden_act in [
+            "gelu",
+            "gelu_new",
+            "gelu_fast",
+            "gelu_pytorch_tanh",
+        ]
 
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
@@ -386,7 +409,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
-
     def forward(
         self,
         input_ids,
@@ -406,9 +428,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
-                    logger.warning('Flash attention implementation does not support kwargs: %s', key)
+                    logger.warning(
+                        'Flash attention implementation does not support kwargs: %s',
+                        key,
+                    )
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
@@ -439,17 +466,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         if masked_tokens_mask is None:
-            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(sequence_output) if self.pooler is not None else None
+            )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
             if attention_mask is not None:
                 subset_idx = subset_mask[attention_mask]
                 pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
-                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
+                sequence_output = sequence_output[
+                    masked_tokens_mask[attention_mask][subset_idx]
+                ]
             else:
                 pool_input = sequence_output[first_col_mask[subset_mask]]
                 sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
-            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            )
 
         if not return_dict:
             return sequence_output, pooled_output
@@ -487,7 +520,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -511,7 +543,9 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
             Used to hide legacy arguments that have been deprecated.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.roberta(
             input_ids,
@@ -534,11 +568,15 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
             # move labels to correct device to enable model parallelism
             labels = labels.to(prediction_scores.device)
             loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
 
         return MaskedLMOutput(
             loss=masked_lm_loss,
@@ -656,7 +694,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
         key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
         return key
 
-    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
+    )
 
     # Layers
     def key_mapping_layers(key):
@@ -715,12 +755,18 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
+                [bq, bk, bv], dim=0
+            )
        else:
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
+                [Wk, Wv], dim=0
+            )
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
+                [bk, bv], dim=0
+            )
 
     def key_mapping_attn(key):
         return re.sub(
@@ -734,7 +780,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     def key_mapping_decoder_bias(key):
         return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
 
-    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
+    )
 
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
@@ -774,51 +822,59 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
         : config.orig_vocab_size, :
     ]
-    state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
-    state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
+    state_dict["cls.predictions.decoder.weight"] = decoder_weight[
+        : config.orig_vocab_size, :
+    ]
+    state_dict["cls.predictions.decoder.bias"] = decoder_bias[
+        : config.orig_vocab_size
+    ]
 
     for d in range(config.num_hidden_layers):
         last_layer_subset = getattr(config, "last_layer_subset", False)
         if not last_layer_subset or d != (config.num_hidden_layers - 1):
             Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
             Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
-                : Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wqkv_weights[
                 Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
             ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
-                2 * Wqkv_weights.shape[0] // 3 :, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
-                : Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
-                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
-                2 * Wqkv_biases.shape[0] // 3 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.bias"
+            ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.bias"
+            ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
         else:
             Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
             Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
             Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
             Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
-                : Wkv_weights.shape[0] // 2, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
-                Wkv_weights.shape[0] // 2 :, :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wq_weight
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
             state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
             state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                 : Wkv_biases.shape[0] // 2
             ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
-                Wkv_biases.shape[0] // 2 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
 
     def inv_key_mapping_ln(key):
         key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -870,14 +926,18 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     def inv_key_mapping_decoder_bias(key):
         return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
 
-    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_ln(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
     )
     state_dict = OrderedDict(
         (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
     )
-    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_mlp(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
     )
@@ -885,4 +945,4 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
         (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
     )
 
-    return state_dict
+    return state_dict
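A small worked example for the subset_cu_seqlens reflow above (a sketch, not part of the commit): F.pad(torch.cumsum(...), (1, 0)) turns per-sequence token counts into cumulative offsets with a leading zero, the layout the flash-attention variable-length path expects.

import torch
import torch.nn.functional as F

subset_seqlens = torch.tensor([3, 1, 2], dtype=torch.int32)   # tokens kept per sequence
subset_cu_seqlens = F.pad(
    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
print(subset_cu_seqlens)  # tensor([0, 3, 4, 6], dtype=torch.int32)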
stochastic_depth.py ADDED

@@ -0,0 +1,97 @@
+# Implementation modified from torchvision:
+# https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
+#
+# License:
+# BSD 3-Clause License
+#
+# Copyright (c) Soumith Chintala 2016,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import torch.fx
+from torch import nn, Tensor
+
+
+def stochastic_depth(
+    input: Tensor, p: float, mode: str, training: bool = True
+) -> Tensor:
+    """
+    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
+    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
+    branches of residual architectures.
+
+    Args:
+        input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
+            being its batch i.e. a batch with ``N`` rows.
+        p (float): probability of the input to be zeroed.
+        mode (str): ``"batch"`` or ``"row"``.
+            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
+            randomly selected rows from the batch.
+        training: apply stochastic depth if is ``True``. Default: ``True``
+
+    Returns:
+        Tensor[N, ...]: The randomly zeroed tensor.
+    """
+    if p < 0.0 or p > 1.0:
+        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
+    if mode not in ["batch", "row"]:
+        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
+    if not training or p == 0.0:
+        return input
+
+    survival_rate = 1.0 - p
+    if mode == "row":
+        size = [input.shape[0]] + [1] * (input.ndim - 1)
+    else:
+        size = [1] * input.ndim
+    noise = torch.empty(size, dtype=input.dtype, device=input.device)
+    noise = noise.bernoulli_(survival_rate)
+    if survival_rate > 0.0:
+        noise.div_(survival_rate)
+    return input * noise
+
+
+torch.fx.wrap("stochastic_depth")
+
+
+class StochasticDepth(nn.Module):
+    """
+    See :func:`stochastic_depth`.
+    """
+
+    def __init__(self, p: float, mode: str) -> None:
+        super().__init__()
+        self.p = p
+        self.mode = mode
+
+    def forward(self, input: Tensor) -> Tensor:
+        return stochastic_depth(input, self.p, self.mode, self.training)
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
+        return s
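A minimal usage sketch (not part of the commit, assuming stochastic_depth.py is importable): with mode="row", StochasticDepth implements per-sample drop path, the behavior block.py relies on for its drop_path modules. During training each sample's branch output is zeroed with probability p and the survivors are rescaled by 1 / (1 - p); in eval mode the module is the identity.

import torch

from stochastic_depth import StochasticDepth

drop_path = StochasticDepth(p=0.1, mode="row")

x = torch.randn(8, 16)        # residual stream, batch of 8
branch = torch.randn(8, 16)   # output of an attention/MLP branch

drop_path.train()
y_train = x + drop_path(branch)   # some rows of `branch` dropped, the rest upscaled

drop_path.eval()
y_eval = x + drop_path(branch)    # identity: equal to x + branch
assert torch.equal(y_eval, x + branch)
print(drop_path)                  # StochasticDepth(p=0.1, mode=row)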