# cxrmate-ed / configuration_cxrmate_ed.py
import transformers
from transformers.models.auto import CONFIG_MAPPING


class CXRMateEDConfig(transformers.PretrainedConfig):
    """Configuration for CXRMate-ED: holds the vision encoder and text decoder
    sub-configurations along with the table-filter and prompt settings."""

    model_type = 'cxrmate-ed'

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        index_value_encoder_intermediate_size: int = 2048,
        include_time_delta: bool = True,
        time_delta_monotonic_inversion: bool = True,
        add_time_deltas: bool = True,
        history: int = 0,
        tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
        prompt_report_sections_filter: list = ['indication', 'history'],
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vision_config = vision_config
        self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
        self.include_time_delta = include_time_delta
        self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
        self.add_time_deltas = add_time_deltas
        self.history = history
        self.tables_filter = tables_filter
        self.prompt_report_sections_filter = prompt_report_sections_filter
        self.pad_token_id = pad_token_id

        # A serialised vision config arrives as a dict; rebuild it on top of the
        # UniFormer base configuration hosted at aehrc/uniformer_base_tl_384.
        if isinstance(vision_config, dict):
            vision_config = transformers.AutoConfig.from_pretrained(
                'aehrc/uniformer_base_tl_384',
                trust_remote_code=True,
                **vision_config,
            )
        self.vision_config = vision_config

        # A serialised text config arrives as a dict; default its model_type to
        # 'llama' and rebuild it through the auto config mapping.
        if isinstance(text_config, dict):
            text_config['model_type'] = text_config.get('model_type', 'llama')
            text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
        self.text_config = text_config
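

# Optional registration sketch (an assumption, not part of the original file): mapping
# the custom model_type onto this class lets AutoConfig resolve 'cxrmate-ed' in a local
# checkout without trust_remote_code. AutoConfig.register is the standard transformers
# hook for custom configuration classes.
transformers.AutoConfig.register('cxrmate-ed', CXRMateEDConfig)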
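

# A minimal smoke-test sketch (assumptions: no checkpoint is downloaded, and the hidden
# sizes below are illustrative rather than the released model's values). Passing
# text_config as a dict exercises the CONFIG_MAPPING path above, which rebuilds it as a
# LlamaConfig; vision_config is left as None to avoid the remote from_pretrained call.
if __name__ == '__main__':
    config = CXRMateEDConfig(
        text_config={'hidden_size': 768, 'num_hidden_layers': 6, 'num_attention_heads': 12},
        history=1,
    )
    print(type(config.text_config).__name__)  # LlamaConfig
    print(config.tables_filter)               # ['mimic_cxr_sectioned', 'triage', 'medrecon']
    print(config.pad_token_id)                # 4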