Markus28 committed
Commit
8adf551
1 Parent(s): d4d5621

feat: implement task type embeddings (#1)


- feat: implemented task_type_ids (f6fcfb57d67d5db63e31189ab677e6dc549d66c7)
- fix: use task_type_embeddings.weight parameter (3573e5b516885c3e4014ab180ccdf90edf3f5870)
- fixed fill_ error, updated config (65647ba4180e123fae3875be079bde1246f77c28)
- merged remote-tracking branch origin/main (4b32a64a15ab0a673dbc1a09ad60413860b20e75)

Files changed (2)
  1. configuration_bert.py +2 -0
  2. modeling_bert.py +10 -0
configuration_bert.py CHANGED
@@ -81,6 +81,7 @@ class JinaBertConfig(PretrainedConfig):
         fused_dropout_add_ln=False,
         fused_bias_fc=False,
         pad_vocab_size_multiple=1,
+        num_tasks=0,
         use_flash_attn=True,
         **kwargs,
     ):
@@ -107,4 +108,5 @@ class JinaBertConfig(PretrainedConfig):
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.fused_bias_fc = fused_bias_fc
         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        self.num_tasks = num_tasks
         self.use_flash_attn = use_flash_attn
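
For orientation (not part of the diff), a minimal usage sketch under the assumption that JinaBertConfig keeps defaults for its other arguments: the new num_tasks field sizes the task-type embedding table added in modeling_bert.py below, and its default of 0 allocates no task-type rows, so existing configs are unaffected.

# Illustrative sketch only; assumes the config module is importable as configuration_bert.
from configuration_bert import JinaBertConfig

config = JinaBertConfig(num_tasks=5)   # reserve 5 task-type embedding rows
print(config.num_tasks)                # -> 5
print(config.use_flash_attn)           # unchanged default -> True
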
modeling_bert.py CHANGED
@@ -342,14 +342,21 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
+        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        # We now initialize the task embeddings to 0; we do not use task types during
+        # pretraining. When we start using task types during embedding training,
+        # we want the model to behave exactly as in pretraining (i.e. task types
+        # have no effect).
+        nn.init.zeros_(self.task_type_embeddings.weight)
 
     def forward(
         self,
         input_ids,
         position_ids=None,
         token_type_ids=None,
+        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
     ):
@@ -361,6 +368,9 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
+        if task_type_ids is not None:
+            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
+
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
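
To make the zero-initialization rationale concrete, here is a small self-contained sketch (toy tensor shapes, not the repository's code) showing that adding a zero-initialized task-type embedding leaves the hidden states, and hence the pretrained model's outputs, unchanged until the table is actually trained:

import torch
import torch.nn as nn

hidden_size, num_tasks = 8, 4                  # toy values for illustration
task_type_embeddings = nn.Embedding(num_tasks, hidden_size)
nn.init.zeros_(task_type_embeddings.weight)    # same zero-init as in the commit

hidden_states = torch.randn(2, 5, hidden_size)             # (batch, seq_len, hidden)
task_type_ids = torch.full((2, 5), 3, dtype=torch.long)    # every token tagged with task 3

out = hidden_states + task_type_embeddings(task_type_ids)
assert torch.equal(out, hidden_states)         # zero rows: identical to the no-task-type path

In the actual forward pass the addition only happens when task_type_ids is passed, so callers that omit the argument keep the exact pretraining behaviour.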