from typing import Any, Optional

import transformers
class CXRMateEDConfig(transformers.PretrainedConfig):
    """Configuration for the ``cxrmate-ed`` model.

    Every argument is stored unchanged as an attribute of the same name. How
    each value is consumed is decided by the modelling code, which is not
    part of this class.

    Args:
        vision_config: Configuration of the vision component, stored as-is.
        text_config: Configuration of the text component, stored as-is.
        index_value_encoder_intermediate_size: Intermediate size of the
            index/value encoder (presumably the hidden width of an MLP --
            confirm against the model code).
        include_time_delta: Flag stored for the model code.
        time_delta_monotonic_inversion: Flag stored for the model code.
        add_time_deltas: Flag stored for the model code.
        history: Integer stored for the model code.
        tables_filter: Names of the tables to use. ``None`` (the default)
            selects ``['mimic_cxr_sectioned', 'triage', 'medrecon']``.
        prompt_report_sections_filter: Report section names to use in the
            prompt. ``None`` (the default) selects
            ``['indication', 'history']``.
        pad_token_id: Padding token id, forwarded to ``PretrainedConfig``.
        **kwargs: Forwarded to ``transformers.PretrainedConfig.__init__``.
    """

    model_type = 'cxrmate-ed'

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        index_value_encoder_intermediate_size: int = 2048,
        include_time_delta: bool = True,
        time_delta_monotonic_inversion: bool = True,
        add_time_deltas: bool = True,
        history: int = 0,
        tables_filter: Optional[list] = None,
        prompt_report_sections_filter: Optional[list] = None,
        pad_token_id: int = 4,
        **kwargs: Any,
    ) -> None:
        # Let the base class handle pad_token_id together with the generic
        # config kwargs instead of overwriting it afterwards.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # NOTE(review): vision_config / text_config are stored verbatim. When
        # the config is reloaded from JSON they will be plain dicts, not
        # PretrainedConfig instances -- confirm the model code accepts both.
        self.vision_config = vision_config
        self.text_config = text_config

        self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
        self.include_time_delta = include_time_delta
        self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
        self.add_time_deltas = add_time_deltas
        self.history = history

        # Build fresh default lists per instance: list literals used as
        # default arguments are created once and shared by every call, so
        # mutating one config's list would leak into all later configs.
        if tables_filter is None:
            tables_filter = ['mimic_cxr_sectioned', 'triage', 'medrecon']
        if prompt_report_sections_filter is None:
            prompt_report_sections_filter = ['indication', 'history']
        self.tables_filter = tables_filter
        self.prompt_report_sections_filter = prompt_report_sections_filter
# import transformers | |
# from transformers.configuration_utils import PretrainedConfig | |
# from transformers.utils import logging | |
# logger = logging.get_logger(__name__) | |
# class CXRMateEDConfig(PretrainedConfig): | |
# model_type = "cxrmate-ed" | |
# def __init__(self, **kwargs): | |
# super().__init__(**kwargs) | |
# if 'decoder' not in kwargs: | |
# self.decoder = transformers.LlamaConfig( | |
# vocab_size=30000, | |
# hidden_size=768, | |
# intermediate_size=3072, | |
# num_attention_heads=12, | |
# num_hidden_layers=6, | |
# max_position_embeddings=2048, | |
# ) | |
# self.decoder.is_decoder = True | |
# self.decoder.index_value_encoder_intermediate_size = 2048 | |
# self.decoder.include_time_delta = True | |
# self.decoder.time_delta_monotonic_inversion = True | |
# self.decoder.add_time_deltas = True | |
# self.decoder.history = 0 | |
# self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"] | |
# self.decoder.prompt_report_sections_filter = ["indication", "history"] | |
# self.decoder.pad_token_id = 4 | |
# else: | |
# self.decoder = kwargs.pop("decoder") | |
# if 'encoder' not in kwargs: | |
# self.encoder = transformers.AutoConfig.from_pretrained( | |
# 'aehrc/uniformer_base_tl_384', | |
# projection_size=768, | |
# trust_remote_code=True, | |
# ) | |
# else: | |
# self.encoder = kwargs.pop("encoder") | |
# self.is_encoder_decoder = True | |
# @classmethod | |
# def from_encoder_decoder_configs( | |
# cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs | |
# ) -> PretrainedConfig: | |
# logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") | |
# decoder_config.is_decoder = True | |
# decoder_config.add_cross_attention = True | |
# return cls(encoder=encoder_config, decoder=decoder_config, **kwargs) |