import transformers


class CXRMateEDConfig(transformers.PretrainedConfig):
    """Configuration class for the CXRMate-ED model."""

    model_type = 'cxrmate-ed'

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        index_value_encoder_intermediate_size: int = 2048,
        include_time_delta: bool = True,
        time_delta_monotonic_inversion: bool = True,
        add_time_deltas: bool = True,
        history: int = 0,
        tables_filter: list = None,
        prompt_report_sections_filter: list = None,
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Avoid mutable default arguments; fall back to the standard table and
        # report-section filters when none are given.
        if tables_filter is None:
            tables_filter = ['mimic_cxr_sectioned', 'triage', 'medrecon']
        if prompt_report_sections_filter is None:
            prompt_report_sections_filter = ['indication', 'history']

        # Sub-configurations for the image encoder and the text decoder.
        self.vision_config = vision_config
        self.text_config = text_config

        # Intermediate (hidden) size of the index-value encoder used to embed tabular records.
        self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size

        # Handling of time deltas for prior studies and events.
        self.include_time_delta = include_time_delta
        self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
        self.add_time_deltas = add_time_deltas

        # Amount of prior-study history to include in the prompt.
        self.history = history

        # Source tables (MIMIC-CXR sectioned reports, ED triage and medrecon tables)
        # and report sections used to build the prompt.
        self.tables_filter = tables_filter
        self.prompt_report_sections_filter = prompt_report_sections_filter

        self.pad_token_id = pad_token_id
# f"Got: {vision_feature_select_strategy}" # ) # self.vision_feature_select_strategy = vision_feature_select_strategy # self.vision_feature_layer = vision_feature_layer # if isinstance(vision_config, dict): # vision_config["model_type"] = ( # vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" # ) # vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) # elif vision_config is None: # vision_config = CONFIG_MAPPING["clip_vision_model"]( # intermediate_size=4096, # hidden_size=1024, # patch_size=14, # image_size=336, # num_hidden_layers=24, # num_attention_heads=16, # vocab_size=32000, # projection_dim=768, # ) # if isinstance(text_config, dict): # text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" # text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) # elif text_config is None: # text_config = CONFIG_MAPPING["llama"]() # super().__init__(**kwargs) # import transformers # from transformers.configuration_utils import PretrainedConfig # from transformers.utils import logging # logger = logging.get_logger(__name__) # class CXRMateEDConfig(PretrainedConfig): # model_type = "cxrmate-ed" # def __init__(self, **kwargs): # super().__init__(**kwargs) # if 'decoder' not in kwargs: # self.decoder = transformers.LlamaConfig( # vocab_size=30000, # hidden_size=768, # intermediate_size=3072, # num_attention_heads=12, # num_hidden_layers=6, # max_position_embeddings=2048, # ) # self.decoder.is_decoder = True # self.decoder.index_value_encoder_intermediate_size = 2048 # self.decoder.include_time_delta = True # self.decoder.time_delta_monotonic_inversion = True # self.decoder.add_time_deltas = True # self.decoder.history = 0 # self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"] # self.decoder.prompt_report_sections_filter = ["indication", "history"] # self.decoder.pad_token_id = 4 # else: # self.decoder = kwargs.pop("decoder") # if 'encoder' not in kwargs: # self.encoder = transformers.AutoConfig.from_pretrained( # 'aehrc/uniformer_base_tl_384', # projection_size=768, # trust_remote_code=True, # ) # else: # self.encoder = kwargs.pop("encoder") # self.is_encoder_decoder = True # @classmethod # def from_encoder_decoder_configs( # cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs # ) -> PretrainedConfig: # logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") # decoder_config.is_decoder = True # decoder_config.add_cross_attention = True # return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)