# cxrmate-ed / configuration_cxrmate_ed.py
# Uploaded by anicolson ("Upload model", commit 9691248, verified); 1.47 kB.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities; used below to report the flags that
# `from_encoder_decoder_configs` sets on the decoder configuration.
logger = logging.get_logger(__name__)
class EncoderDecoderConfig(PretrainedConfig):
    r"""
    Composite configuration holding an encoder and a decoder sub-configuration.

    Both sub-configurations are required keyword arguments. They are stored on the instance exactly as passed;
    no conversion to configuration objects is performed here. The resulting configuration always has
    `is_encoder_decoder=True`.

    Args:
        kwargs:
            Must contain `encoder` and `decoder`. All keyword arguments (including those two) are forwarded to
            [`PretrainedConfig.__init__`].

    Raises:
        ValueError: If `encoder` or `decoder` is missing from `kwargs`.
    """

    model_type = "encoder-decoder"
    # Class-level flag consumed by `transformers`; presumably marks this config as composed of sub-configs
    # (as for the upstream EncoderDecoderConfig) -- confirm against the installed transformers version.
    is_composition = True

    def __init__(self, **kwargs):
        # Validate before the base class runs, so a missing sub-configuration fails fast with a clear error.
        missing = [key for key in ("encoder", "decoder") if key not in kwargs]
        if missing:
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be instantiated because both `encoder` and "
                f"`decoder` sub-configurations must be passed; missing {missing}, got keyword arguments "
                f"{sorted(kwargs)}"
            )
        super().__init__(**kwargs)
        # Store the sub-configurations as given (dicts or config objects alike).
        self.encoder = kwargs["encoder"]
        self.decoder = kwargs["decoder"]
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`EncoderDecoderConfig`] (or a derived class) from an encoder configuration and a decoder
        configuration.

        Note that `decoder_config` is modified in place: its `is_decoder` and `add_cross_attention` attributes
        are set to `True` before it is stored on the new configuration.

        Args:
            encoder_config: Configuration of the encoder.
            decoder_config: Configuration of the decoder (mutated in place, see above).
            kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            [`EncoderDecoderConfig`]: An instance of `cls` holding both sub-configurations.
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True
        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)