from transformers import PretrainedConfig


class HLMEncoderConfig(PretrainedConfig):
    """Configuration for a single HLM encoder stack (intra-word or inter-word)."""

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        layer_norm_eps=1e-7,
        sandwich=False,
        sandwich_size=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        # If `sandwich` is enabled, derive the sandwich size from the layer count;
        # otherwise use the explicitly provided `sandwich_size`.
        if sandwich:
            self.sandwich_size = num_hidden_layers // 6
        else:
            self.sandwich_size = sandwich_size


class HLMConfig(PretrainedConfig):
    """Top-level HLM configuration holding the nested intra-/inter-word encoder configs."""

    model_type = "hlm"

    def __init__(
        self,
        vocab_size=512,
        type_vocab_size=2,
        embedding_size=-1,
        max_seq_length=256,
        max_word_length=16,
        initializer_range=0.02,
        pad_token_id=0,
        intra_word_encoder={},
        inter_word_encoder={},
        residual_word_embedding=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.type_vocab_size = type_vocab_size
        self.embedding_size = embedding_size
        self.initializer_range = initializer_range
        self.max_seq_length = max_seq_length
        self.max_word_length = max_word_length
        self.pad_token_id = pad_token_id
        # Nested encoder configs, built from plain dicts so they round-trip through JSON:
        # the intra-word encoder runs within each word, the inter-word encoder across words.
        self.intra_word_encoder = HLMEncoderConfig(**intra_word_encoder)
        self.inter_word_encoder = HLMEncoderConfig(**inter_word_encoder)
        # The model-level hidden size mirrors the inter-word encoder's hidden size.
        self.hidden_size = self.inter_word_encoder.hidden_size
        self.residual_word_embedding = residual_word_embedding
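

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the hidden sizes and layer counts
    # below are assumptions, not values prescribed elsewhere in this file).
    # It shows how the nested encoder configs are passed as plain dicts and
    # expanded into HLMEncoderConfig instances by HLMConfig.__init__.
    demo_config = HLMConfig(
        vocab_size=512,
        max_word_length=16,
        intra_word_encoder={
            "hidden_size": 256,
            "num_hidden_layers": 4,
            "num_attention_heads": 4,
            "intermediate_size": 1024,
        },
        inter_word_encoder={
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
        },
    )
    print(demo_config.hidden_size)                      # 768, taken from the inter-word encoder
    print(demo_config.intra_word_encoder.hidden_size)   # 256
    print(demo_config.intra_word_encoder.sandwich_size)  # 0, since sandwich was not enabled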