|
import json |
|
import math |
|
import os |
|
import random |
|
from typing import Optional, Tuple, Union |
|
|
|
import datasets |
|
import torch |
|
import transformers |
|
from huggingface_hub import hf_hub_download |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torch.nn import CrossEntropyLoss |
|
from torch.utils.data import Subset |
|
from torchvision.io import decode_image |
|
from torchvision.transforms import v2 |
|
from transformers import PreTrainedTokenizerFast |
|
from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput |
|
from transformers.utils import check_min_version, logging |
|
|
|
from .configuration_cxrmate_ed import CXRMateEDConfig |
|
from .dataset import PriorsDataset |
|
from .prepare_dataset import prepare_dataset |
|
from .utils import compute_time_delta |
|
|
|
# Module-level logger created through the transformers logging utilities.
logger = logging.get_logger(__name__)

# Radiograph view position labels. The leading None presumably stands for a
# missing view position.
# NOTE(review): this list is not referenced in the visible part of the file and
# the meaning of its ordering is not shown here - confirm against its consumer.
VIEW_ORDER = [None, 'LPO', 'RAO', 'LAO', 'SWIMMERS', 'XTABLE LATERAL', 'LL', 'LATERAL', 'AP AXIAL', 'AP RLD', 'AP LLD', 'AP', 'PA RLD', 'PA LLD', 'PA']
|
|
|
|
|
def create_lookup_table(df, columns, start_idx):
    """
    Build a lookup table that assigns a consecutive integer index to each unique
    combination of values found in the given columns.

    Argument/s:
        df - DataFrame that contains (at least) the columns in `columns`.
        columns - list of column names whose unique value combinations are indexed.
        start_idx - index given to the first combination (after sorting by `columns`).

    Returns:
        df - one row per unique combination, sorted by `columns`, with an added 'index' column.
        last_idx - the last index that was assigned, or `start_idx - 1` if no row was indexed.
    """

    # Keep the first row of each group. The explicit copy makes sure that the
    # 'index' column below is written to an independent frame, not to a slice
    # of the caller's DataFrame.
    df = df.groupby(columns).head(1)[columns].sort_values(by=columns).copy()

    indices = range(start_idx, start_idx + len(df))
    df['index'] = indices

    # An empty range has no last element (indices[-1] would raise IndexError),
    # so report the index that precedes start_idx when nothing was assigned.
    last_idx = indices[-1] if indices else start_idx - 1
    return df, last_idx
|
|
|
|
|
class FNNEncoder(torch.nn.Module):
    """
    Bias-free two-layer feed-forward network: an up-projection, a SiLU
    non-linearity, and a down-projection to the decoder hidden size.
    """

    def __init__(self, num_features, intermediate_size, decoder_hidden_size):
        """
        Argument/s:
            num_features - size of the last dimension of the input.
            intermediate_size - size of the hidden layer.
            decoder_hidden_size - size of the last dimension of the output.
        """
        super().__init__()

        # The creation order (up, down, activation) is kept so that parameter
        # initialisation and state dict ordering are unchanged.
        self.up_proj = torch.nn.Linear(
            in_features=num_features, out_features=intermediate_size, bias=False,
        )
        self.down_proj = torch.nn.Linear(
            in_features=intermediate_size, out_features=decoder_hidden_size, bias=False,
        )
        self.act_fn = torch.nn.SiLU()

    def forward(self, x):
        """Apply up-projection, SiLU, then down-projection to the last dimension of x."""
        hidden = self.up_proj(x)
        hidden = self.act_fn(hidden)
        return self.down_proj(hidden)
|
|
|
|
|
class ProjectionHead(torch.nn.Module):
    """
    Layer normalisation over the input feature dimension followed by a bias-free
    linear projection to `hidden_size`.
    """

    def __init__(self, input_size, hidden_size) -> None:
        """
        Argument/s:
            input_size - size of the last dimension of the input.
            hidden_size - size of the last dimension of the output.
        """
        super().__init__()

        # Normalise the incoming features before they are projected:
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=input_size, eps=1e-6)

        # Map the normalised features to the target hidden size:
        self.projection = torch.nn.Linear(
            in_features=input_size, out_features=hidden_size, bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise x over its last dimension and project it to `hidden_size`."""
        return self.projection(self.layer_norm(x))
|
|
|
|
|
class CXRStudyImagesEncoder(torch.nn.Module):
    """
    Wraps an image encoder so that all images of a study are encoded together
    and projected to the decoder hidden size with a ProjectionHead adapter.
    """

    def __init__(self, encoder, decoder_config):
        """
        Argument/s:
            encoder - image encoder; its `config.embed_dim[-1]` is used as the adapter input size.
            decoder_config - decoder configuration; its `hidden_size` is used as the adapter output size.
        """
        super().__init__()

        self.encoder = encoder
        self.config = encoder.config

        encoder_dim = self.config.embed_dim[-1]
        self.adapter = ProjectionHead(encoder_dim, decoder_config.hidden_size)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutput]:
        """
        Encode the images of each study and build the matching attention mask.

        Argument/s:
            pixel_values - 5D tensor (B, S, C, H, W), where S is the max number of images for a study in the batch.
            output_hidden_states - accepted but not used by this method.
            return_dict - falls back to `self.config.use_return_dict` when None.

        Returns:
            ModelOutput with `last_hidden_state` and `attention_mask` when return_dict is truthy.
            Otherwise only the `last_hidden_state` tensor is returned.
        """

        if return_dict is None:
            return_dict = self.config.use_return_dict

        assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
        batch_size, max_images = pixel_values.shape[0], pixel_values.shape[1]

        # Fold the study dimension into the batch dimension for the encoder:
        features = self.encoder(pixel_values.view(-1, *pixel_values.shape[2:])).last_hidden_state

        # Outputs with more than three dimensions are flattened from dim 2 onwards.
        # NOTE(review): the feature layout expected by the adapter after this
        # flatten is not visible here - confirm for encoders with 4D outputs.
        if features.dim() > 3:
            features = torch.flatten(features, 2)

        # Project the encoder features to the decoder hidden size:
        features = self.adapter(features)

        # Concatenate the features of all images of a study along the sequence dimension:
        features = features.view(batch_size, -1, features.shape[-1])

        # An image slot is treated as padding when its first pixel value is exactly zero:
        image_present = pixel_values[:, :, 0, 0, 0] != 0.0
        tokens_per_image = features.shape[1] // max_images
        attention_mask = torch.ones(
            [features.shape[0], max_images, tokens_per_image],
            dtype=torch.long,
            device=image_present.device,
        )
        attention_mask = attention_mask * image_present[:, :, None]
        attention_mask = attention_mask.view(attention_mask.shape[0], -1)

        # NOTE(review): this path returns the bare tensor (no tuple and no
        # attention mask), unlike the usual Hugging Face tuple convention.
        if not return_dict:
            return features

        return ModelOutput(last_hidden_state=features, attention_mask=attention_mask)
|
|
|
|
|
class CXRMateEDModel(transformers.LlavaForConditionalGeneration): |
|
|
|
config_class = CXRMateEDConfig |
|
|
|
def __init__(self, config: CXRMateEDConfig):
    """
    Build the image encoder, the language model, the table and time delta
    encoders, the token type embeddings, and the image transforms.

    Argument/s:
        config - model configuration. The attributes read here are vision_config,
            text_config, pad_token_id, tables_filter, prompt_report_sections_filter,
            index_value_encoder_intermediate_size, and add_time_deltas.
    """

    check_min_version("4.46.0.dev0")

    # The initialiser lookup starts after LlavaPreTrainedModel in the MRO, so
    # the Llava initialisers are presumably bypassed on purpose:
    super(transformers.LlavaPreTrainedModel, self).__init__(config)

    self.config = config
    self.vocab_size = config.text_config.vocab_size

    # The raw encoder is assigned to the attribute first and wrapped after the
    # language model is created; this keeps the module registration order.
    self.image_encoder = transformers.AutoModel.from_config(self.config.vision_config, trust_remote_code=True)
    self.language_model = transformers.AutoModelForCausalLM.from_config(
        config.text_config,
        attn_implementation=config._attn_implementation,
    )
    self.image_encoder = CXRStudyImagesEncoder(self.image_encoder, config.text_config)

    self.pad_token_id = -1 if self.config.pad_token_id is None else self.config.pad_token_id

    module_dir = os.path.dirname(os.path.abspath(__file__))

    def read_json(filename):
        # Prefer the copy next to this module; otherwise download it from the hub.
        json_path = os.path.join(module_dir, filename)
        if not os.path.exists(json_path):
            json_path = hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=filename)
        with open(json_path, 'r') as f:
            return json.load(f)

    self.tables = read_json('tables.json')
    self.luts = read_json('lookup_tables.json')
    self.token_type_to_token_type_id = read_json('token_type_ids.json')

    # Keep only the configured tables and set the report sections used in the prompt:
    self.tables = {name: self.tables[name] for name in self.config.tables_filter}
    self.tables['mimic_cxr_sectioned']['text_columns'] = self.config.prompt_report_sections_filter

    intermediate_size = self.config.index_value_encoder_intermediate_size
    hidden_size = self.config.text_config.hidden_size

    # One encoder per table that has a non-empty lookup table:
    for table_name in self.tables:
        num_features = self.luts[table_name]['total']
        if num_features > 0:
            table_encoder = FNNEncoder(
                num_features=num_features,
                intermediate_size=intermediate_size,
                decoder_hidden_size=hidden_size,
            )
            setattr(self, f'{table_name}_index_value_encoder', table_encoder)

    if self.config.add_time_deltas:
        self.time_delta_encoder = FNNEncoder(
            num_features=1,
            intermediate_size=intermediate_size,
            decoder_hidden_size=hidden_size,
        )

    num_token_types = max(self.token_type_to_token_type_id.values()) + 1
    self.token_type_embeddings = torch.nn.Embedding(num_token_types, hidden_size)

    # Time deltas are mapped to 1 / sqrt(x + 1), i.e. 1.0 for a delta of zero
    # and 0.0 for an infinite delta:
    self.time_delta_map = lambda x: 1 / math.sqrt(x + 1)
    self.zero_time_delta_value = self.time_delta_map(0)
    self.inf_time_delta_value = self.time_delta_map(float('inf'))

    image_size = self.config.vision_config.image_size

    # Training transforms: random crop and small random rotation.
    train_steps = [
        v2.Grayscale(num_output_channels=3),
        v2.Resize(
            size=image_size,
            antialias=True,
            interpolation=v2.InterpolationMode.BICUBIC,
        ),
        v2.RandomCrop(
            size=[image_size, image_size],
            pad_if_needed=True,
        ),
        v2.RandomRotation(degrees=5),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    self.train_transforms = v2.Compose(train_steps)

    # Test transforms: deterministic centre crop, no rotation.
    test_steps = [
        v2.Grayscale(num_output_channels=3),
        v2.Resize(
            size=image_size,
            antialias=True,
            interpolation=v2.InterpolationMode.BICUBIC,
        ),
        v2.CenterCrop(size=[image_size, image_size]),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    self.test_transforms = v2.Compose(test_steps)

    self.post_init()
|
|
|
def forward(
    self,
    decoder_position_ids: torch.LongTensor,
    decoder_attention_mask: torch.FloatTensor,
    decoder_token_type_ids: torch.LongTensor,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
    """
    Run the language model on the prepared decoder inputs and optionally compute
    the cross-entropy loss.

    Argument/s:
        decoder_position_ids - position identifiers passed to the language model.
        decoder_attention_mask - attention mask; a 4D mask must have the dtype of
            the input embeddings, any other mask must be torch.long.
        decoder_token_type_ids - token type identifiers; their embeddings are added to the input embeddings.
        decoder_input_ids - token identifiers, embedded when decoder_inputs_embeds is None.
        encoder_outputs - optional tuple that is appended to the tuple output (return_dict=False only).
        past_key_values - cached key/values passed to the language model.
        decoder_inputs_embeds - precomputed input embeddings (left unmodified).
        labels - target token identifiers for the loss.
        use_cache, output_attentions, output_hidden_states - passed to the language model.
        return_dict - falls back to `self.config.use_return_dict` when None.
        kwargs - entries prefixed with 'decoder_' are passed to the language model without the prefix.

    Returns:
        Seq2SeqLMOutput (loss, logits, past_key_values, decoder_hidden_states, decoder_attentions)
        when return_dict is truthy, otherwise a tuple of the (optional) loss, the
        language model outputs, and the encoder outputs (if given).
    """

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    kwargs_decoder = {
        argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
    }

    if decoder_inputs_embeds is None:
        decoder_inputs_embeds = self.language_model.get_input_embeddings()(decoder_input_ids)

    # Out-of-place addition: an in-place `+=` would modify a tensor supplied by
    # the caller (and fails for leaf tensors that require grad).
    decoder_inputs_embeds = decoder_inputs_embeds + self.token_type_embeddings(decoder_token_type_ids)

    # The messages name the argument instead of interpolating the whole mask tensor:
    if decoder_attention_mask.dim() == 4:
        assert decoder_attention_mask.dtype == decoder_inputs_embeds.dtype, f'The dtype for decoder_attention_mask was {decoder_attention_mask.dtype}. It should be {decoder_inputs_embeds.dtype}'
    else:
        assert decoder_attention_mask.dtype == torch.long, f'The dtype for decoder_attention_mask was {decoder_attention_mask.dtype}. It should be torch.long'

    # Forward pass of the language model:
    decoder_outputs = self.language_model(
        inputs_embeds=decoder_inputs_embeds,
        attention_mask=decoder_attention_mask,
        position_ids=decoder_position_ids,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
        **kwargs_decoder,
    )

    # Cross-entropy loss over the flattened logits and labels:
    loss = None
    if labels is not None:
        logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.reshape(-1, self.vocab_size), labels.reshape(-1))

    if not return_dict:
        # encoder_outputs defaults to None, and `tuple + None` raises TypeError:
        encoder_tuple = encoder_outputs if encoder_outputs is not None else ()
        if loss is not None:
            return (loss,) + decoder_outputs + encoder_tuple
        else:
            return decoder_outputs + encoder_tuple

    return Seq2SeqLMOutput(
        loss=loss,
        logits=decoder_outputs.logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
    )
|
|
|
def prepare_inputs_for_generation(
    self,
    input_ids,
    special_token_ids,
    prompt_attention_mask,
    prompt_position_ids,
    past_key_values=None,
    use_cache=None,
    **kwargs,
):
    """
    Prepare the decoder inputs for a generation step.

    Modification of:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660

    Argument/s:
        input_ids - report token identifiers generated so far.
        special_token_ids - special token identifiers that separate the report sections.
        prompt_attention_mask - attention mask of the prompt.
        prompt_position_ids - position identifiers of the prompt.
        past_key_values - cache; None or an empty cache selects the first-step branch.
        use_cache - forwarded unchanged.
        kwargs - must contain 'decoder_inputs_embeds' (prompt embeddings) and, for
            the first step, 'decoder_token_type_ids' (prompt token type identifiers).

    Returns:
        input_dict - keyword arguments for forward().
    """

    # Mask the pad tokens of the report generated so far:
    report_attention_mask = (input_ids != self.config.pad_token_id).long()

    section_token_type_ids = [
        self.token_type_to_token_type_id['findings'],
        self.token_type_to_token_type_id['impression'],
    ]

    # past_key_values defaults to None, for which len() would raise TypeError;
    # treat a missing cache like an empty one.
    if past_key_values is None or len(past_key_values) == 0:

        # 4D attention mask over the prompt and the report:
        decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
            prompt_attention_mask, report_attention_mask, dtype=kwargs['decoder_inputs_embeds'].dtype,
        )

        # Report positions continue from the largest prompt position; masked positions are set to 1:
        report_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
        report_position_ids.masked_fill_(report_attention_mask == 0, 1)
        decoder_position_ids = torch.cat([prompt_position_ids, report_position_ids], dim=1)

        # Prompt embeddings followed by the embeddings of the report tokens:
        inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.language_model.get_input_embeddings()(input_ids)], dim=1)

        decoder_token_type_ids = self.token_ids_to_token_type_ids(
            input_ids, special_token_ids, section_token_type_ids,
        )
        decoder_token_type_ids = torch.cat(
            [
                kwargs['decoder_token_type_ids'],
                decoder_token_type_ids,
            ],
            dim=1,
        )

        input_dict = {
            'decoder_input_ids': input_ids,
            'decoder_inputs_embeds': inputs_embeds,
            'decoder_token_type_ids': decoder_token_type_ids,
        }
    else:

        # 4D attention mask for a step that uses the cache:
        decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(
            prompt_attention_mask, report_attention_mask, dtype=kwargs['decoder_inputs_embeds'].dtype,
        )

        # Only the position of the most recent token is needed:
        decoder_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
        decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)

        # Token type of the most recent token (all token identifiers are required here):
        decoder_token_type_ids = self.token_ids_to_token_type_ids_past_key_values(
            input_ids,
            special_token_ids,
            section_token_type_ids,
        )
        decoder_position_ids = decoder_position_ids[:, -1:]

        past_length = past_key_values[0][0].shape[2]

        # Drop the tokens that are already covered by the cache; by default only
        # the final token identifier is kept:
        if input_ids.shape[1] > past_length:
            remove_prefix_length = past_length
        else:
            remove_prefix_length = input_ids.shape[1] - 1

        input_ids = input_ids[:, remove_prefix_length:]

        input_dict = {'decoder_input_ids': input_ids, 'decoder_token_type_ids': decoder_token_type_ids}

    input_dict.update(
        {
            'decoder_attention_mask': decoder_attention_mask,
            'decoder_position_ids': decoder_position_ids,
            'past_key_values': past_key_values,
            'use_cache': use_cache,
        }
    )
    return input_dict
|
|
|
def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections):
    """
    Extract token type identifiers from the token identifiers.

    Argument/s:
        token_ids - token identifiers, a 2D tensor (batch, sequence).
        special_token_ids - special token identifiers that indicate the separation between sections.
        token_type_id_sections - token type identifier for each section. If None,
            0, 1, ..., len(special_token_ids) are used.

    Returns:
        token_type_ids - token type identifiers, with the same shape as token_ids.
    """

    token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))

    _, seq_len = token_ids.shape
    token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

    # Column index of every position, broadcast against the per-row section starts:
    positions = torch.arange(seq_len, device=token_ids.device)[None, :]

    for i, j in enumerate(special_token_ids):

        # Index of the first occurrence of the special token, plus one, so that
        # the next section starts after the special token:
        section_start = (token_ids == j).int().argmax(dim=1) + 1

        # argmax returns 0 both when the special token is absent and when it is
        # at position 0, so a start of 1 is treated as "not found". A start equal
        # to seq_len means the special token is the last token: nothing follows it.
        has_section = torch.logical_and(section_start != 1, section_start < seq_len)

        # Vectorised replacement for building every (row, column) index pair in
        # Python: mark all positions from the section start onwards.
        section_mask = torch.logical_and(has_section[:, None], positions >= section_start[:, None])
        token_type_ids[section_mask] = token_type_id_sections[i + 1]

    return token_type_ids
|
|
|
def token_ids_to_token_type_ids_past_key_values(self, token_ids, special_token_ids, token_type_id_sections):
    """
    Token type identifier of the most recent token when a cache is used. All of
    the token identifiers must be given (i.e., not only the last token, as is
    passed to the model after prepare_inputs_for_generation removes the prefix).

    Argument/s:
        token_ids - token identifiers, a 2D tensor (batch, sequence).
        special_token_ids - special token identifiers that indicate the separation between sections.
        token_type_id_sections - token type identifier for each section. If None,
            0, 1, ..., len(special_token_ids) are used.

    Returns:
        token_type_ids - token type identifiers with shape (batch, 1).
    """

    if token_type_id_sections is None:
        token_type_id_sections = list(range(len(special_token_ids) + 1))

    num_rows = token_ids.shape[0]
    token_type_ids = torch.full([num_rows, 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)

    # The final token is excluded: a special token in the last position still
    # belongs to the section that it terminates.
    preceding_ids = token_ids[:, :-1]

    for section_idx, special_token_id in enumerate(special_token_ids, start=1):

        # Rows in which the special token has already appeared move to the next section:
        seen = (preceding_ids == special_token_id).any(dim=1, keepdim=True)
        token_type_ids.masked_fill_(seen, token_type_id_sections[section_idx])

    return token_type_ids
|
|
|
def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer: PreTrainedTokenizerFast, max_len: int):
    """
    Tokenize the reports and create the decoder inputs and targets for teacher forcing.

    Argument/s:
        findings - findings sections (iterated pairwise with the impression sections).
        impression - impression sections.
        tokenizer - Hugging Face tokenizer.
        max_len - maximum number of tokens.

    Returns:
        batch_dict - dictionary with:
            label_ids - the label token identifiers for the decoder.
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
    """

    # Each report is "[BOS]findings[SEP]impression[EOS]":
    reports = [
        f'{tokenizer.bos_token}{findings_section}{tokenizer.sep_token}{impression_section}{tokenizer.eos_token}'
        for findings_section, impression_section in zip(findings, impression)
    ]

    # One extra token is allowed, as the sequence is shifted below to form the
    # inputs and the labels:
    tokenized = tokenizer(
        reports,
        padding='longest',
        truncation=True,
        max_length=max_len + 1,
        return_tensors='pt',
        return_token_type_ids=False,
        add_special_tokens=False,
    ).to(self.device)

    report_ids = tokenized['input_ids']
    report_mask = tokenized['attention_mask']

    # Labels drop the first token and inputs drop the last token, so that each
    # input position is paired with the token that follows it. The attention
    # mask is shifted in the same way as the labels.
    batch_dict = {
        'label_ids': report_ids[:, 1:].detach().clone(),
        'decoder_input_ids': report_ids[:, :-1],
        'decoder_attention_mask': report_mask[:, 1:],
    }

    return batch_dict
|
|
|
def tokenize_report_teacher_forcing_rev_a(self, tokenizer: PreTrainedTokenizerFast, max_len: int, findings: Optional[str] = None, impression: Optional[str] = None, reports: Optional[str] = None):
    """
    Tokenize the reports and create the decoder inputs and targets for teacher forcing.

    Argument/s:
        tokenizer - Hugging Face tokenizer.
        max_len - maximum number of tokens.
        findings - findings sections (only used if reports is None).
        impression - impression sections (only used if reports is None).
        reports - prepared reports, with special tokens and report sections.

    Returns:
        batch_dict - dictionary with:
            label_ids - the label token identifiers for the decoder.
            decoder_input_ids - the token identifiers for the input of the decoder.
            decoder_attention_mask - the attention mask for the decoder_input_ids.
    """

    # Build "[BOS]findings[SEP]impression[EOS]" reports when none are provided:
    if reports is None:
        assert findings and impression, "If 'reports' is not defined, 'findings' and 'impression' need to be defined."
        reports = []
        for findings_section, impression_section in zip(findings, impression):
            reports.append(
                f'{tokenizer.bos_token}{findings_section}{tokenizer.sep_token}{impression_section}{tokenizer.eos_token}'
            )

    # One extra token is allowed, as the sequence is shifted below to form the
    # inputs and the labels:
    tokenized = tokenizer(
        reports,
        padding='longest',
        truncation=True,
        max_length=max_len + 1,
        return_tensors='pt',
        return_token_type_ids=False,
        add_special_tokens=False,
    ).to(self.device)

    report_ids = tokenized['input_ids']
    report_mask = tokenized['attention_mask']

    # Labels drop the first token and inputs drop the last token; the attention
    # mask is shifted in the same way as the labels.
    batch_dict = {
        'label_ids': report_ids[:, 1:].detach().clone(),
        'decoder_input_ids': report_ids[:, :-1],
        'decoder_attention_mask': report_mask[:, 1:],
    }

    return batch_dict
|
|
|
def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
    """
    Split the token identifiers into sections, then convert the token identifiers into strings.

    Argument/s:
        token_ids - token identifiers, a 2D tensor (batch, sequence).
        special_token_ids - special token identifiers that indicate the end of each section.
        tokenizer - Hugging Face tokenizer.

    Returns:
        A tuple with one list per special token identifier; each list holds the
        decoded string of that section for every row of token_ids.
    """

    _, seq_len = token_ids.shape

    # One list of decoded strings per section:
    sections = [[] for _ in special_token_ids]

    for row in token_ids:
        section_start = 0
        for section_idx, special_token_id in enumerate(special_token_ids):

            # The previous section consumed the rest of the sequence:
            if section_start >= seq_len:
                sections[section_idx].append('')
                continue

            # Index of the first occurrence of the special token:
            section_end = (row == special_token_id).int().argmax().item()

            # argmax is 0 when the special token is absent (or is the first
            # token); the section then runs to the end of the sequence:
            if section_end == 0:
                section_end = seq_len

            decoded = tokenizer.decode(row[section_start:section_end], skip_special_tokens=True)
            sections[section_idx].append(decoded)
            section_start = section_end

    return tuple(sections)
|
|
|
def prepare_inputs(
    self,
    images,
    tokenizer: PreTrainedTokenizerFast,
    tokenized_report=None,
    sep_token_id=None,
    **batch,
):
    """
    Assemble the prompt (table features, prompt text, and images) for the
    language model. If a tokenized report is given (training/teacher forcing),
    it is appended to the prompt.

    Argument/s:
        images - images; passed to the image encoder (shape[0] is used as the batch size
            and shape[1] as the number of images per study).
        tokenizer - Hugging Face tokenizer.
        tokenized_report - if training/teacher forcing, input the tokenized_report dict to include it in the prepared inputs.
        sep_token_id - separator token identifier (only used if tokenized_report is given).
        batch - remaining batch entries. Read here: '{table}_index_value_feats',
            '{table}_index_value_token_type_ids', '{table}_index_value_mask',
            '{table}_index_value_time_delta', and 'image_time_deltas'. All entries
            are also passed on to tokenize_text_prompt.

    Returns:
        inputs_embeds - input embeddings.
        attention_mask - attention mask (4D if tokenized_report is given, otherwise 2D).
        token_type_ids - token type identifiers.
        position_ids - position identifiers.
        bos_token_ids - bos_token_ids for generation (None if tokenized_report is given).
    """

    input_ids = []
    inputs_embeds = []
    token_type_ids = []
    attention_mask = []
    time_delta = []
    position_ids = None
    bos_token_ids = None

    # Index and value columns of the tables; missing entries are replaced with
    # empty (zero-length) tensors so that the concatenations below still work:
    batch_size = images.shape[0]
    for k, v in self.tables.items():
        if 'index_columns' in v or 'value_columns' in v:
            if f'{k}_index_value_feats' not in batch:
                batch[f'{k}_index_value_feats'] = torch.empty(batch_size, 0, self.luts[k]['total'], device=self.device)
            inputs_embeds.append(
                getattr(self, f'{k}_index_value_encoder')(batch[f'{k}_index_value_feats'])
            )
            token_type_ids.append(batch[f'{k}_index_value_token_type_ids'] if f'{k}_index_value_token_type_ids' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
            attention_mask.append(batch[f'{k}_index_value_mask'] if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
            if f'{k}_index_value_time_delta' in batch:
                time_delta.append(batch[f'{k}_index_value_time_delta'])
            else:
                # No time deltas given: use zeros shaped like the mask (or an empty tensor):
                time_delta_index_value = torch.zeros(*batch[f'{k}_index_value_mask'].shape, 1, device=self.device) if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, 1, device=self.device)
                time_delta.append(time_delta_index_value)

    # Tokenize the text columns of the prompt. Only the token identifiers are
    # stored here; they are embedded after the images (see the note below):
    tokenized = self.tokenize_text_prompt(tokenizer, **batch)
    input_ids.append(tokenized['input_ids'])
    token_type_ids.append(tokenized['token_type_ids'])
    attention_mask.append(tokenized['attention_mask'])
    time_delta.append(tokenized['time_delta'])

    # Image encoder (index 0: features, index 1: attention mask):
    encoder_outputs = self.image_encoder(images)
    inputs_embeds.append(encoder_outputs[0])

    # Repeat the time delta of each image for every feature of that image. Images
    # with a zero or infinite time delta get the 'image' token type, all others
    # get the 'prior_image' token type:
    inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]
    time_delta_image_features = torch.tensor(batch['image_time_deltas'], device=self.device).repeat_interleave(inputs_per_image, dim=1)
    token_type_ids.append(
        torch.where(
            torch.logical_or(
                time_delta_image_features == self.zero_time_delta_value,
                time_delta_image_features == self.inf_time_delta_value,
            ),
            self.token_type_to_token_type_id['image'],
            self.token_type_to_token_type_id['prior_image'],
        ),
    )
    attention_mask.append(encoder_outputs[1])
    time_delta.append(time_delta_image_features[:, :, None])

    # Compute the embeddings from the token identifiers.
    # NOTE(review): the text embeddings are appended AFTER the image features
    # (order: tables, images, text), whereas token_type_ids, attention_mask and
    # time_delta were appended in the order tables, text, images. Confirm that
    # this is intended before changing it; released weights were presumably
    # trained with this behaviour.
    input_ids = torch.cat(input_ids, dim=1)
    inputs_embeds.append(self.language_model.get_input_embeddings()(input_ids))

    # Concatenate the time deltas and the input embeddings:
    time_delta = torch.cat(time_delta, dim=1)
    inputs_embeds = torch.cat(inputs_embeds, dim=1)

    # Add the time delta embeddings to the input embeddings:
    if time_delta.shape[1] > 0 and self.config.add_time_deltas:
        time_delta = time_delta.to(dtype=inputs_embeds.dtype)
        inputs_embeds += self.time_delta_encoder(time_delta)

    # Concatenate the attention masks:
    attention_mask = torch.cat(attention_mask, dim=1)

    # Position identifiers of the prompt are derived from the time deltas:
    position_ids = self.position_ids_from_time_deltas_and_attention_mask(time_delta, attention_mask)

    # Teacher forcing: append the report to the prompt.
    if tokenized_report is not None:
        inputs_embeds = torch.cat([inputs_embeds, self.language_model.get_input_embeddings()(tokenized_report['decoder_input_ids'])], dim=1)

        report_token_type_ids = self.token_ids_to_token_type_ids(
            token_ids=tokenized_report['decoder_input_ids'],
            special_token_ids=[sep_token_id],
            token_type_id_sections=[self.token_type_to_token_type_id['findings'], self.token_type_to_token_type_id['impression']],
        )
        token_type_ids.append(report_token_type_ids)

        # Report positions continue from the largest prompt position; masked positions are set to 1:
        report_position_ids = tokenized_report['decoder_attention_mask'].cumsum(-1) + position_ids.max(dim=1).values[:, None]
        report_position_ids.masked_fill_(tokenized_report['decoder_attention_mask'] == 0, 1)
        position_ids = torch.cat([position_ids, report_position_ids], dim=1)

        # 4D attention mask over the prompt and the report:
        attention_mask = self.create_4d_attention_mask_mixed_causality(attention_mask, tokenized_report['decoder_attention_mask'], dtype=inputs_embeds.dtype)

    else:

        # Generation: one BOS token identifier per example to start the report:
        bos_token_ids = torch.full((encoder_outputs[0].shape[0], 1), tokenizer.bos_token_id, dtype=torch.long, device=self.device)

    # Concatenate the token type identifiers:
    token_type_ids = torch.cat(token_type_ids, dim=1)

    assert inputs_embeds.shape[1] == attention_mask.shape[-1]
    assert inputs_embeds.shape[1] == token_type_ids.shape[1]

    return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids
|
|
|
def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
    """
    Tokenize the text columns of the prompt (the text columns of the tables, as
    well as 'prior_findings' and 'prior_impression'). Time deltas for the
    input_ids are also prepared here.

    Argument/s:
        tokenizer - Hugging Face tokenizer.
        kwargs - batch entries; 'study_id' gives the batch size. For each text
            column present, the entry holds, per example, a list of texts (or
            None). '{column}_time_delta' optionally holds the matching time
            deltas; the zero time delta value is used if it is missing.

    Returns:
        tokenized - dictionary containing the padded input_ids, token_type_ids,
            time_delta, and attention_mask of the prompt text.
    """

    batch_size = len(kwargs['study_id'])
    device = self.device

    # Per-example lists of (1, length) tensors, one tensor per text:
    ids_per_example = {idx: [] for idx in range(batch_size)}
    types_per_example = {idx: [] for idx in range(batch_size)}
    deltas_per_example = {idx: [] for idx in range(batch_size)}

    # Names of the text columns; columns of 'mimic_cxr_sectioned' are not prefixed with the table name:
    prompt_text_columns = []
    for table_name, table in self.tables.items():
        if 'text_columns' not in table:
            continue
        text_columns = table['text_columns']
        if not isinstance(text_columns, list):
            text_columns = [text_columns]
        for text_column in text_columns:
            if table_name != 'mimic_cxr_sectioned':
                prompt_text_columns.append(f'{table_name}_{text_column}')
            else:
                prompt_text_columns.append(text_column)
    prompt_text_columns.extend(['prior_findings', 'prior_impression'])

    for column in prompt_text_columns:
        if column not in kwargs:
            continue

        # Default to the zero time delta for every text of the column:
        delta_key = f'{column}_time_delta'
        if delta_key not in kwargs:
            kwargs[delta_key] = [
                None if texts is None else [self.zero_time_delta_value for _ in texts]
                for texts in kwargs[column]
            ]

        for example_idx, (texts, deltas) in enumerate(zip(kwargs[column], kwargs[delta_key])):
            if texts is None:
                continue
            assert isinstance(texts, list)
            assert isinstance(deltas, list)
            for text, time_delta in zip(texts, deltas):
                if text is None:
                    continue
                token_ids = tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=device)
                ids_per_example[example_idx].append(token_ids)

                # Every token of the text shares the token type of the column and the time delta of the text:
                per_token_shape = (1, token_ids.shape[-1])
                types_per_example[example_idx].append(
                    torch.full(
                        per_token_shape,
                        self.token_type_to_token_type_id[column],
                        dtype=torch.long,
                        device=device,
                    )
                )
                deltas_per_example[example_idx].append(
                    torch.full(
                        per_token_shape,
                        time_delta,
                        dtype=torch.float32,
                        device=device,
                    )
                )

    def as_column(chunks, **empty_kwargs):
        # Join the texts of an example into a (length, 1) tensor; examples without text give (0, 1).
        if chunks:
            return torch.cat(chunks, dim=1).T
        return torch.empty(0, 1, device=device, **empty_kwargs)

    input_ids = [as_column(chunks, dtype=torch.long) for chunks in ids_per_example.values()]
    token_type_ids = [as_column(chunks, dtype=torch.long) for chunks in types_per_example.values()]
    time_delta = [as_column(chunks) for chunks in deltas_per_example.values()]

    # Pad to the longest example of the batch:
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    input_ids = pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    )[:, :, 0]
    token_type_ids = pad_sequence(
        token_type_ids, batch_first=True, padding_value=0,
    )[:, :, 0]

    # Positions that hold the pad token identifier are masked:
    attention_mask = (input_ids != tokenizer.pad_token_id).int()

    # The time deltas keep their trailing dimension of size one:
    time_delta = pad_sequence(
        time_delta, batch_first=True, padding_value=0,
    )

    return {
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'time_delta': time_delta,
        'attention_mask': attention_mask,
    }
|
|
|
def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
    """
    Derive the position identifiers of the prompt by ranking the positions of
    each row by their time delta.

    Positions are ranked in descending order of time delta, or in ascending
    order if `self.config.time_delta_monotonic_inversion` is set. Masked
    positions are ranked last and then have their position identifier set to 1.

    Argument/s:
        time_deltas - time deltas, a 3D tensor whose last dimension is indexed at 0 (batch, sequence, 1).
        attention_mask - 2D attention mask (batch, sequence); 1 marks a valid position.

    Returns:
        position_ids - position identifiers (batch, sequence), dtype torch.long.
    """

    monotonic_inversion = self.config.time_delta_monotonic_inversion

    # Masked positions get the extreme value that sorts them to the end:
    value_range = torch.finfo(time_deltas.dtype)
    mask_value = value_range.max if monotonic_inversion else value_range.min
    masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)

    # Many positions share a time delta (e.g., all tokens of one text). A stable
    # sort guarantees that such positions keep their original relative order;
    # the default sort gives no such guarantee.
    _, col_indices = torch.sort(masked_time_deltas, descending=not monotonic_inversion, stable=True)

    # col_indices[r, rank] is the column with that rank; the position identifier
    # of a column is its rank, i.e. the inverse permutation of each row:
    position_ids = torch.argsort(col_indices, dim=1)
    position_ids.masked_fill_(attention_mask == 0, 1)

    return position_ids
|
|
|
def prepare_index_value_feats(self, table, batch):
    """
    Build the index/value feature tensors of a table for every example of the batch.

    Note that the index columns of `batch` are replaced in place with their
    lookup table indices.

    Argument/s:
        table - name of the table (a key of self.tables and self.luts).
        batch - dictionary of batch columns. Read here: 'study_id',
            'latest_study_datetime', the index/value columns of the table, and
            the 'groupby' and 'time_column' columns named in self.tables[table].

    Returns:
        batch_index_value_feats - padded features, one row per group of an example.
        batch_token_type_ids - padded token type identifiers of the rows.
        batch_time_deltas - padded time deltas of the rows.
        batch_mask - 1 for rows that have at least one element that is not the padding value (-1), otherwise 0.
    """

    # Column names in the batch; columns of 'mimic_cxr_2_0_0_metadata' are not prefixed with the table name:
    index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
    index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns

    # Map the values of the index columns to their lookup table indices. The `k`
    # inside the inner comprehension is a separate variable (a column value); it
    # does not affect the outer `k` (the column name) used to index the batch:
    if 'index_columns' in self.tables[table]:
        for i in self.tables[table]['index_columns']:
            k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
            batch[k] = [
                [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
            ]

    batch_index_value_feats_list = []
    batch_token_type_ids_list = []
    batch_time_deltas_list = []

    for batch_idx in range(len(batch['study_id'])):

        # Only examples with a non-empty entry for at least one of the columns have features:
        if any([batch[k][batch_idx] for k in index_value_columns]):

            # All columns of the example must have the same number of rows:
            num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
            assert all(x == num_rows[0] for x in num_rows)
            num_rows = num_rows[0]

            # Row index of the feature tensor for each table row. With a groupby
            # list, rows that share a groupby value share a feature row (numbered
            # in order of first appearance), and `datetime` holds the unique
            # values of the time column in order of first appearance:
            if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
                y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
                datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
                assert len(set(y_indices)) == len(datetime)
            else:
                y_indices = [0] * num_rows
                datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]

            # One time delta per feature row, relative to the latest study:
            time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]

            tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])

            # Index columns: set the feature at the lookup table index to 1.0:
            if 'index_columns' in self.tables[table]:

                for i in self.tables[table]['index_columns']:
                    k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
                    y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
                    x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]

                    tensor[y_indices_column, x_indices_column] = 1.0

            # Value columns: write the value to the feature index of the column:
            if 'value_columns' in self.tables[table]:
                for i in self.tables[table]['value_columns']:

                    k = f'{table}_{i}' if not table == 'mimic_cxr_2_0_0_metadata' else i
                    y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
                    x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
                    values = [value for value in batch[k][batch_idx] if value is not None]

                    tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
                    assert not torch.isnan(tensor).any()
        else:
            # No rows for this example:
            tensor = torch.empty(0, self.luts[table]['total'])
            time_deltas = torch.empty(0, 1)

        batch_index_value_feats_list.append(tensor)
        batch_token_type_ids_list.append(torch.full(
            [tensor.shape[0]],
            self.token_type_to_token_type_id[table],
            dtype=torch.long,
        )
        )
        batch_time_deltas_list.append(time_deltas)

        assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
        assert tensor.shape[0] == time_deltas.shape[0]

    # Pad to the largest number of rows in the batch. The features are padded
    # with -1 so that the padded rows can be identified for the mask below:
    batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1)
    batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
    batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)

    # A row is kept if any of its elements differs from the padding value:
    batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()

    return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
|
|
|
def prepare_text_prompt(self, table, column, batch):
    """Gather one text column per example, joined per group, with a time delta per group.

    For every example, rows of the column that share a group (according to the
    table's ``groupby`` column) are joined with ``', '`` and terminated with
    ``'.'``. One time delta per group is computed with ``compute_time_delta``
    relative to the example's ``latest_study_datetime``.

    Args:
        table: Table name; used to look up ``self.tables[table]`` and to build
            the batch key (``f'{table}_{column}'``, or just ``column`` for
            ``'mimic_cxr_sectioned'``).
        column: Name of the text column.
        batch: Dict of per-example lists; must contain ``'study_id'``,
            ``'latest_study_datetime'``, the text key and the table's
            ``groupby`` (and, when configured, ``time_column``) entries.

    Returns:
        ``(batch_text_list, batch_time_deltas_list)``: two lists with one entry
        per example; each entry is a list with one element per group (empty
        when the example has no text).
    """
    key = column if table == 'mimic_cxr_sectioned' else f'{table}_{column}'

    batch_text_list = []
    batch_time_deltas_list = []

    for batch_idx in range(len(batch['study_id'])):
        rows = batch[key][batch_idx]

        # Nothing recorded for this example (None, '' or an empty list):
        if not rows:
            batch_text_list.append([])
            batch_time_deltas_list.append([])
            continue

        table_config = self.tables[table]
        text_rows = rows if isinstance(rows, list) else [rows]
        groups = batch[table_config['groupby']][batch_idx]

        if isinstance(groups, list):
            # Number the groups 0, 1, 2, ... in order of first appearance:
            first_seen = {}
            y_indices = [first_seen.setdefault(group, len(first_seen)) for group in groups]
            # Unique timestamps in order of first appearance (assumes hashable timestamps):
            group_times = list(dict.fromkeys(batch[table_config['time_column']][batch_idx]))
            assert len(set(y_indices)) == len(group_times)
        else:
            # A single group holding every row:
            y_indices = [0] * len(text_rows)
            if 'time_column' in table_config:
                group_times = batch[table_config['time_column']][batch_idx]
            else:
                group_times = [batch['latest_study_datetime'][batch_idx]]

        # Remove rows without text, keeping group indices aligned with the text:
        y_indices = [y for y, text in zip(y_indices, text_rows) if text is not None]
        text_rows = [text for text in text_rows if text is not None]

        # Sorted so that groups come out in a deterministic order (order of first
        # appearance) instead of relying on the iteration order of a set:
        group_ids = sorted(set(y_indices))
        group_times = [group_times[group_id] for group_id in group_ids]

        if text_rows:
            # Rows with the same group index are joined into the same string:
            texts_by_group = {group_id: [] for group_id in group_ids}
            for group_id, text in zip(y_indices, text_rows):
                texts_by_group[group_id].append(text)

            batch_text_list.append([', '.join(texts_by_group[group_id]) + '.' for group_id in group_ids])
            batch_time_deltas_list.append(
                [
                    compute_time_delta(
                        group_time,
                        batch['latest_study_datetime'][batch_idx],
                        self.time_delta_map,
                        to_tensor=False,
                    )
                    for group_time in group_times
                ]
            )

            assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
        else:
            batch_text_list.append([])
            batch_time_deltas_list.append([])

    return batch_text_list, batch_time_deltas_list
|
|
|
@staticmethod |
|
def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype): |
|
|
|
prompt_seq_len = non_causal_2d_attention_mask.shape[-1] |
|
report_seq_len = causal_2d_attention_mask.shape[-1] |
|
|
|
non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :] |
|
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :] |
|
|
|
|
|
upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1) |
|
upper_left = upper_left * non_causal_2d_attention_mask |
|
upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2) |
|
|
|
causal_mask = torch.tril( |
|
torch.ones( |
|
( |
|
report_seq_len, |
|
report_seq_len, |
|
), |
|
dtype=torch.long, |
|
device=causal_2d_attention_mask.device, |
|
), |
|
) |
|
|
|
|
|
lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1) |
|
lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2) |
|
lower_right = lower_right * causal_mask |
|
|
|
|
|
upper_right = torch.zeros( |
|
causal_2d_attention_mask.shape[0], |
|
1, |
|
prompt_seq_len, |
|
report_seq_len, |
|
dtype=torch.long, |
|
device=causal_2d_attention_mask.device, |
|
) |
|
|
|
|
|
lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1) |
|
lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2) |
|
|
|
left = torch.cat((upper_left, lower_left), dim=2) |
|
right = torch.cat((upper_right, lower_right), dim=2) |
|
|
|
mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1) |
|
|
|
mixed_causality_4d_attention_mask = mixed_causality_4d_attention_mask.to(dtype=dtype) |
|
mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 0] = torch.finfo(mixed_causality_4d_attention_mask.dtype).min |
|
mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0 |
|
|
|
return mixed_causality_4d_attention_mask |
|
|
|
@staticmethod |
|
def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype): |
|
|
|
non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :] |
|
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :] |
|
|
|
mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1) |
|
|
|
mixed_causality_4d_attention_mask = mixed_causality_4d_attention_mask.to(dtype=dtype) |
|
mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 0] = torch.finfo(mixed_causality_4d_attention_mask.dtype).min |
|
mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0 |
|
|
|
return mixed_causality_4d_attention_mask |
|
|
|
@staticmethod |
|
def collate_fn(batch): |
|
keys = set().union(*(d.keys() for d in batch)) |
|
batch = {j: [i.setdefault(j, None) for i in batch] for j in keys} |
|
batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()} |
|
return batch |
|
|
|
@staticmethod
def prepare_dataset(physionet_dir: str, database_dir: str):
    """Call the ``prepare_dataset`` function imported from ``.prepare_dataset``.

    Both arguments are forwarded unchanged as keyword arguments; the result
    of the call is discarded.

    Args:
        physionet_dir: Forwarded as ``physionet_dir`` (presumably the location
            of the PhysioNet source files; confirm against ``.prepare_dataset``).
        database_dir: Forwarded as ``database_dir`` (presumably where the
            prepared dataset is written; confirm against ``.prepare_dataset``).
    """
    prepare_dataset(
        database_dir=database_dir,
        physionet_dir=physionet_dir,
    )
|
|
|
def get_dataset(self, database_dir, max_train_images_per_study=None, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
    """Load the saved dataset and return its train/validate/test subsets.

    Each split is wrapped in ``PriorsDataset``, given a batch transform and
    restricted with ``torch.utils.data.Subset`` to the study IDs listed in
    ``f'{study_id_split}_{split}_study_ids.json'``. That file is read from the
    directory of this module if present, otherwise downloaded from the
    ``aehrc/cxrmate-ed`` hub repository.

    Args:
        database_dir: Directory containing ``mimic_iv_ed_mimic_cxr_jpg_dataset``.
        max_train_images_per_study: Studies with more images than this are
            randomly subsampled to this many images during training. Required
            unless ``test_set_only`` is true.
        study_id_split: Prefix of the study ID JSON files.
        test_set_only: If true, only the test subset is built and returned.

    Returns:
        ``(train_set, val_set, test_set)``, or just ``test_set`` when
        ``test_set_only`` is true.
    """
    assert max_train_images_per_study is not None or test_set_only, 'max_train_images_per_study must be defined if training.'

    dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')

    def finalise_batch(batch, image_transform):
        """Decode, order, transform and pad the images, then add table and text features."""
        batch['images'] = [
            [decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i]
            for i in batch['images']
        ]

        # Sort each study's images by the position of their view in VIEW_ORDER:
        batch['images'] = [
            list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0]
            for i, v in zip(batch['images'], batch['ViewPosition'])
        ]
        batch['images'] = [torch.stack([image_transform(j) for j in i]) for i in batch['images']]
        max_size = max(i.shape[0] for i in batch['images'])

        # Real images get the zero time delta; padded image slots get the 'inf' one:
        batch['image_time_deltas'] = [
            [self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)]
            for i in batch['images']
        ]
        batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)

        for k, v in self.tables.items():
            if 'index_columns' in v or 'value_columns' in v:
                (
                    batch[f'{k}_index_value_feats'],
                    batch[f'{k}_index_value_token_type_ids'],
                    batch[f'{k}_index_value_time_delta'],
                    batch[f'{k}_index_value_mask'],
                ) = self.prepare_index_value_feats(k, batch)

        for k, v in self.tables.items():
            if 'text_columns' in v:
                for i in v['text_columns']:
                    key = f'{k}_{i}' if not k == 'mimic_cxr_sectioned' else i
                    batch[key], batch[f'{key}_time_delta'] = self.prepare_text_prompt(k, i, batch)

        return batch

    def train_set_transform(batch):
        """Randomly subsample over-long studies, then apply the shared processing."""
        # Columns with one entry per image; they must be subsampled together:
        keys = ['images', 'dicom_id']
        if 'mimic_cxr_2_0_0_metadata' in self.tables:
            keys = keys + self.tables['mimic_cxr_2_0_0_metadata']['index_columns']
        # The view-position sort in finalise_batch pairs images with
        # 'ViewPosition', so it has to follow the same sampling as the images:
        if 'ViewPosition' not in keys:
            keys.append('ViewPosition')

        for i, images in enumerate(batch['images']):
            if len(images) > max_train_images_per_study:
                paired = list(zip(*(batch[key][i] for key in keys)))
                sampled_pairs = random.sample(paired, max_train_images_per_study)
                for key, values in zip(keys, zip(*sampled_pairs)):
                    batch[key][i] = list(values)

        return finalise_batch(batch, self.train_transforms)

    def test_set_transform(batch):
        """Apply the shared processing with the test-time image transforms."""
        return finalise_batch(batch, self.test_transforms)

    def load_split(split, transform):
        """Return the subset of ``dataset[split]`` selected by the split's study ID file."""
        filename = f'{study_id_split}_{split}_study_ids.json'
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)
        if not os.path.exists(path):
            path = hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=filename)
        with open(path, 'r', encoding='utf-8') as f:
            study_ids = json.load(f)

        split_set = dataset[split]
        index_map = {study_id: idx for idx, study_id in enumerate(split_set['study_id'])}
        # Study IDs that are absent from the split are ignored:
        indices = sorted(index_map[study_id] for study_id in study_ids if study_id in index_map)

        split_set = PriorsDataset(split_set, self.config.history, self.time_delta_map)
        split_set.set_transform(transform)
        return Subset(split_set, indices)

    dataset = datasets.load_from_disk(dataset_path)

    if test_set_only:
        return load_split('test', test_set_transform)

    train_set = load_split('train', train_set_transform)
    val_set = load_split('validate', test_set_transform)
    test_set = load_split('test', test_set_transform)

    return train_set, val_set, test_set
|
|
|
def get_stage_1_dataset(self, database_dir, max_train_images_per_study):
    """Load the saved dataset and return image-only train/validate/test subsets.

    Each split is wrapped in ``PriorsDataset``, given a batch transform that
    only prepares the images, and restricted with ``torch.utils.data.Subset``
    to the study IDs listed in ``f'mimic_cxr_jpg_{split}_study_ids.json'``.
    That file is read from the directory of this module if present, otherwise
    downloaded from the ``aehrc/cxrmate-ed`` hub repository.

    Args:
        database_dir: Directory containing ``mimic_iv_ed_mimic_cxr_jpg_dataset``.
        max_train_images_per_study: Studies with more images than this are
            randomly subsampled to this many images during training.

    Returns:
        ``(train_set, val_set, test_set)``.
    """
    dataset_path = os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset')

    def finalise_batch(batch, image_transform):
        """Decode, order, transform and pad the images of every study."""
        batch['images'] = [
            [decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i]
            for i in batch['images']
        ]

        # Sort each study's images by the position of their view in VIEW_ORDER:
        batch['images'] = [
            list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0]
            for i, v in zip(batch['images'], batch['ViewPosition'])
        ]
        batch['images'] = [torch.stack([image_transform(j) for j in i]) for i in batch['images']]
        max_size = max(i.shape[0] for i in batch['images'])

        # Real images get the zero time delta; padded image slots get the 'inf' one:
        batch['image_time_deltas'] = [
            [self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)]
            for i in batch['images']
        ]
        batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)

        return batch

    def train_set_transform(batch):
        """Randomly subsample over-long studies, then apply the shared processing."""
        for i, images in enumerate(batch['images']):
            if len(images) > max_train_images_per_study:
                # Sample images together with their view positions so they stay aligned:
                paired = list(zip(images, batch['ViewPosition'][i]))
                sampled_pairs = random.sample(paired, max_train_images_per_study)
                batch['images'][i], batch['ViewPosition'][i] = zip(*sampled_pairs)

        return finalise_batch(batch, self.train_transforms)

    def test_set_transform(batch):
        """Apply the shared processing with the test-time image transforms."""
        return finalise_batch(batch, self.test_transforms)

    def load_split(split, transform):
        """Return the subset of ``dataset[split]`` selected by the split's study ID file."""
        filename = f'mimic_cxr_jpg_{split}_study_ids.json'
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)
        if not os.path.exists(path):
            path = hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=filename)
        with open(path, 'r', encoding='utf-8') as f:
            study_ids = json.load(f)

        split_set = dataset[split]
        index_map = {study_id: idx for idx, study_id in enumerate(split_set['study_id'])}
        # Study IDs that are absent from the split are ignored:
        indices = sorted(index_map[study_id] for study_id in study_ids if study_id in index_map)

        split_set = PriorsDataset(split_set, self.config.history, self.time_delta_map)
        split_set.set_transform(transform)
        return Subset(split_set, indices)

    dataset = datasets.load_from_disk(dataset_path)

    train_set = load_split('train', train_set_transform)
    val_set = load_split('validate', test_set_transform)
    test_set = load_split('test', test_set_transform)

    return train_set, val_set, test_set
|
|