feat: implement task type embeddings (#1)
- feat: implemented task_type_ids (f6fcfb57d67d5db63e31189ab677e6dc549d66c7)
- fix: use task_type_embeddings.weight parameter (3573e5b516885c3e4014ab180ccdf90edf3f5870)
- fixed fill_ error, updated config (65647ba4180e123fae3875be079bde1246f77c28)
- merged remote-tracking branch origin/main (4b32a64a15ab0a673dbc1a09ad60413860b20e75)

Files changed:
- configuration_bert.py (+2 -0)
- modeling_bert.py (+10 -0)
configuration_bert.py CHANGED

```diff
@@ -81,6 +81,7 @@ class JinaBertConfig(PretrainedConfig):
         fused_dropout_add_ln=False,
         fused_bias_fc=False,
         pad_vocab_size_multiple=1,
+        num_tasks=0,
         use_flash_attn=True,
         **kwargs,
     ):
@@ -107,4 +108,5 @@ class JinaBertConfig(PretrainedConfig):
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.fused_bias_fc = fused_bias_fc
         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        self.num_tasks = num_tasks
         self.use_flash_attn = use_flash_attn
```
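For orientation, a minimal sketch of the new option in use. This assumes `configuration_bert.py` from this commit is importable as a local module, and the task count of 5 is purely illustrative:

```python
# Sketch only: assumes configuration_bert.py from this commit is on the path.
from configuration_bert import JinaBertConfig

# num_tasks sizes the new task-type embedding table added in modeling_bert.py;
# the default of 0 keeps the table empty, so existing configs are unaffected.
config = JinaBertConfig(num_tasks=5)
print(config.num_tasks)  # -> 5
```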
modeling_bert.py CHANGED

```diff
@@ -342,14 +342,21 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
+        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)

         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        # We now initialize the task embeddings to 0; We do not use task types during
+        # pretraining. When we start using task types during embedding training,
+        # we want the model to behave exactly as in pretraining (i.e. task types
+        # have no effect).
+        nn.init.zeros_(self.task_type_embeddings.weight)

     def forward(
         self,
         input_ids,
         position_ids=None,
         token_type_ids=None,
+        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
     ):
@@ -361,6 +368,9 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
+        if task_type_ids is not None:
+            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
+
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
```
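To make the zero-init rationale concrete, here is a self-contained sketch of the mechanism (the hidden size, batch shape, and task count are illustrative, not taken from the commit). Because the embedding table starts at zero, adding a task embedding leaves the hidden states unchanged until the table is trained:

```python
import torch
import torch.nn as nn

# Standalone sketch of the change above: a task-type embedding table,
# zero-initialized so that passing task_type_ids is a no-op at first.
hidden_size, num_tasks = 768, 5
task_type_embeddings = nn.Embedding(num_tasks, hidden_size)
nn.init.zeros_(task_type_embeddings.weight)

hidden_states = torch.randn(2, 16, hidden_size)           # (batch, seq, hidden)
task_type_ids = torch.full((2, 16), 3, dtype=torch.long)  # every token tagged with task 3

out = hidden_states + task_type_embeddings(task_type_ids)
assert torch.equal(out, hidden_states)  # zero init => identical to pretraining behavior
```

In the patched `BertModel.forward`, this same addition happens immediately after the token/position/segment embeddings, and omitting `task_type_ids` skips it entirely, so callers that never pass the new argument see no change in behavior.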