|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import logging |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class EncoderDecoderConfig(PretrainedConfig):
    r"""
    Configuration class for a generic composite encoder-decoder model.

    Stores the two sub-configurations (`encoder` and `decoder`) that together define an
    encoder-decoder model. Both sub-configurations are required keyword arguments.

    Args:
        kwargs (*optional*):
            Dictionary of keyword arguments. Must contain:

                - **encoder** ([`PretrainedConfig`]) -- configuration object of the encoder model.
                - **decoder** ([`PretrainedConfig`]) -- configuration object of the decoder model.

    Raises:
        ValueError: If either the `encoder` or the `decoder` sub-configuration is missing.
    """

    model_type = "encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Both sub-configurations are mandatory; fail loudly if either is absent.
        if "encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError(
                # Fixed typo in the user-facing message: "configuraton" -> "configuration".
                f"A configuration of type {self.model_type} cannot be instantiated because "
                f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}"
            )

        self.encoder = kwargs.pop("encoder")
        self.decoder = kwargs.pop("decoder")
        # Mark this composite config so downstream code can detect encoder-decoder models.
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
        decoder model configuration.

        Args:
            encoder_config ([`PretrainedConfig`]): Configuration of the encoder.
            decoder_config ([`PretrainedConfig`]): Configuration of the decoder. NOTE: mutated in
                place — `is_decoder` and `add_cross_attention` are both set to `True`.
            kwargs: Additional keyword arguments forwarded to the constructor.

        Returns:
            [`EncoderDecoderConfig`]: An instance of a configuration object
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        # The decoder must run causally and attend over the encoder's hidden states.
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)