jupyterjazz commited on
Commit
943cec2
1 Parent(s): c55e591

feat: truncation option during init

Browse files

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

configuration_xlm_roberta.py CHANGED
@@ -32,6 +32,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
32
  torch_dtype=None,
33
  emb_pooler=None,
34
  matryoshka_dimensions=None,
 
35
  **kwargs,
36
  ):
37
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -61,6 +62,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
61
  self.use_flash_attn = use_flash_attn
62
  self.emb_pooler = emb_pooler
63
  self.matryoshka_dimensions = matryoshka_dimensions
 
64
  if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
65
  self.torch_dtype = getattr(torch, torch_dtype)
66
  else:
 
32
  torch_dtype=None,
33
  emb_pooler=None,
34
  matryoshka_dimensions=None,
35
+ truncate_dim=None,
36
  **kwargs,
37
  ):
38
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
62
  self.use_flash_attn = use_flash_attn
63
  self.emb_pooler = emb_pooler
64
  self.matryoshka_dimensions = matryoshka_dimensions
65
+ self.truncate_dim = truncate_dim
66
  if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
67
  self.torch_dtype = getattr(torch, torch_dtype)
68
  else:
modeling_xlm_roberta.py CHANGED
@@ -578,6 +578,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
578
 
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
580
 
 
581
  if truncate_dim:
582
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
583
 
 
578
 
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
580
 
581
+ truncate_dim = truncate_dim or self.config.truncate_dim
582
  if truncate_dim:
583
  all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
584