Update modeling_nort5.py
Browse files- modeling_nort5.py +18 -1
modeling_nort5.py
CHANGED
@@ -387,7 +387,24 @@ class NorT5Model(NorT5PreTrainedModel):
|
|
387 |
self.embedding.word_embedding = value
|
388 |
|
389 |
def get_encoder(self):
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
|
392 |
def get_decoder(self):
|
393 |
return self.get_decoder_output
|
|
|
387 |
self.embedding.word_embedding = value
|
388 |
|
389 |
def get_encoder(self):
|
390 |
+
class EncoderWrapper:
|
391 |
+
def __call__(cls, *args, **kwargs):
|
392 |
+
return cls.forward(*args, **kwargs)
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
cls,
|
396 |
+
input_ids: Optional[torch.Tensor] = None,
|
397 |
+
attention_mask: Optional[torch.Tensor] = None,
|
398 |
+
output_hidden_states: Optional[bool] = None,
|
399 |
+
output_attentions: Optional[bool] = None,
|
400 |
+
return_dict: Optional[bool] = None,
|
401 |
+
):
|
402 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
403 |
+
|
404 |
+
return self.get_encoder_output(
|
405 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
|
406 |
+
)
|
407 |
+
return EncoderWrapper()
|
408 |
|
409 |
def get_decoder(self):
|
410 |
return self.get_decoder_output
|