anicolson committed on
Commit 9691248
1 Parent(s): 1c38939

Upload model

config.json CHANGED
@@ -21,11 +21,6 @@
     "diversity_penalty": 0.0,
     "do_sample": false,
     "early_stopping": false,
-    "ed_module_columns": [
-      "triage_chiefcomplaint",
-      "triage_pain",
-      "vitalsign_pain"
-    ],
     "encoder_no_repeat_ngram_size": 0,
     "eos_token_id": 2,
     "exponential_decay_length_penalty": null,
@@ -34,16 +29,12 @@
     "forced_eos_token_id": null,
     "hidden_act": "silu",
     "hidden_size": 768,
+    "history": 0,
     "id2label": {
       "0": "LABEL_0",
       "1": "LABEL_1"
     },
     "include_time_delta": true,
-    "index_value_encoder_config": {
-      "edstays": 40,
-      "triage": 7,
-      "vitalsign": 1177
-    },
     "index_value_encoder_intermediate_size": 2048,
     "initializer_range": 0.02,
     "intermediate_size": 3072,
@@ -56,10 +47,6 @@
     "length_penalty": 1.0,
     "max_length": 20,
     "max_position_embeddings": 2048,
-    "mimic_cxr_columns": [
-      "indication",
-      "history"
-    ],
     "min_length": 0,
     "model_type": "llama",
     "no_repeat_ngram_size": 0,
@@ -69,7 +56,6 @@
     "num_hidden_layers": 6,
     "num_key_value_heads": 12,
     "num_return_sequences": 1,
-    "num_token_types": 19,
     "output_attentions": false,
     "output_hidden_states": false,
     "output_scores": false,
@@ -77,6 +63,10 @@
     "prefix": null,
     "pretraining_tp": 1,
     "problem_type": null,
+    "prompt_report_sections_filter": [
+      "indication",
+      "history"
+    ],
     "pruned_heads": {},
     "remove_invalid_values": false,
     "repetition_penalty": 1.0,
@@ -85,39 +75,19 @@
     "rms_norm_eps": 1e-06,
     "rope_scaling": null,
     "rope_theta": 10000.0,
-    "section_ids": [
-      12,
-      13
-    ],
     "sep_token_id": null,
     "suppress_tokens": null,
+    "tables_filter": [
+      "mimic_cxr_sectioned",
+      "triage",
+      "medrecon"
+    ],
     "task_specific_params": null,
     "temperature": 1.0,
     "tf_legacy_loss": false,
     "tie_encoder_decoder": false,
     "tie_word_embeddings": false,
     "time_delta_monotonic_inversion": true,
-    "token_type_to_token_type_id": {
-      "comparison": 15,
-      "edstays": 1,
-      "findings": 12,
-      "history": 11,
-      "image": 14,
-      "impression": 13,
-      "indication": 10,
-      "medrecon": 0,
-      "medrecon_name": 6,
-      "mimic_cxr_2_0_0_metadata": 5,
-      "previous_findings": 16,
-      "previous_image": 18,
-      "previous_impression": 17,
-      "pyxis": 4,
-      "triage": 2,
-      "triage_chiefcomplaint": 7,
-      "triage_pain": 8,
-      "vitalsign": 3,
-      "vitalsign_pain": 9
-    },
     "tokenizer_class": null,
     "top_k": 50,
     "top_p": 1.0,
@@ -126,14 +96,12 @@
     "typical_p": 1.0,
     "use_bfloat16": false,
     "use_cache": true,
-    "vocab_size": 30000,
-    "zero_time_delta_value": 1.0
+    "vocab_size": 30000
   },
   "encoder": {
     "_name_or_path": "",
     "add_cross_attention": false,
     "architectures": null,
-    "attention_probs_dropout_prob": 0.0,
     "attn_drop_rate": 0.0,
     "bad_words_ids": null,
     "begin_suppress_tokens": null,
@@ -160,24 +128,18 @@
       512
     ],
     "encoder_no_repeat_ngram_size": 0,
-    "encoder_stride": 16,
     "eos_token_id": null,
    "exponential_decay_length_penalty": null,
     "finetuning_task": null,
     "forced_bos_token_id": null,
     "forced_eos_token_id": null,
     "head_dim": 64,
-    "hidden_act": "gelu",
-    "hidden_dropout_prob": 0.0,
-    "hidden_size": 768,
     "id2label": {
       "0": "LABEL_0",
       "1": "LABEL_1"
     },
     "image_size": 384,
     "in_chans": 3,
-    "initializer_range": 0.02,
-    "intermediate_size": 3072,
     "is_decoder": false,
     "is_encoder_decoder": false,
     "label2id": {
@@ -189,14 +151,11 @@
     "max_length": 20,
     "min_length": 0,
     "mlp_ratio": 4,
-    "model_type": "vit",
+    "model_type": "uniformer",
     "no_repeat_ngram_size": 0,
-    "num_attention_heads": 12,
     "num_beam_groups": 1,
     "num_beams": 1,
-    "num_channels": 3,
     "num_classes": 1000,
-    "num_hidden_layers": 12,
     "num_return_sequences": 1,
     "output_attentions": false,
     "output_hidden_states": false,
@@ -234,8 +193,8 @@
     "typical_p": 1.0,
     "use_bfloat16": false
   },
-  "model_type": "vision-encoder-decoder",
+  "model_type": "encoder-decoder",
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.40.2"
+  "transformers_version": "4.39.3"
 }
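The removed decoder fields (`ed_module_columns`, `index_value_encoder_config`, `num_token_types`, `section_ids`, `token_type_to_token_type_id`, `zero_time_delta_value`) are superseded by the two filter lists above plus JSON lookup tables shipped with the model code (see modelling_cxrmate_ed.py below). A minimal sketch (not part of this commit) of inspecting the new keys; the model identifier is a placeholder:

    from transformers import AutoConfig

    # Placeholder repository id; substitute the repository this commit belongs to.
    config = AutoConfig.from_pretrained('user/cxrmate-ed', trust_remote_code=True)

    print(config.decoder.tables_filter)                  # ['mimic_cxr_sectioned', 'triage', 'medrecon']
    print(config.decoder.prompt_report_sections_filter)  # ['indication', 'history']
    print(config.decoder.history)                        # 0, i.e. no prior studies added to the prompt (see dataset.py below)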
configuration_cxrmate_ed.py ADDED
@@ -0,0 +1,39 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class EncoderDecoderConfig(PretrainedConfig):
+
+    model_type = "encoder-decoder"
+    is_composition = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError(
+                f"A configuration of type {self.model_type} cannot be instantiated because "
+                f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}"
+            )
+
+        self.encoder = kwargs.pop("encoder")
+        self.decoder = kwargs.pop("decoder")
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_encoder_decoder_configs(
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+    ) -> PretrainedConfig:
+        r"""
+        Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
+        decoder model configuration.
+
+        Returns:
+            [`EncoderDecoderConfig`]: An instance of a configuration object
+        """
+        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+        decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
+
+        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
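A minimal usage sketch for this class; the sub-configuration values below are illustrative assumptions, not the checkpoint's actual ones:

    from transformers import LlamaConfig

    from configuration_cxrmate_ed import EncoderDecoderConfig
    from configuration_uniformer import UniFormerWithProjectionHeadConfig

    # Hypothetical sub-configurations; any two PretrainedConfig instances work:
    encoder_config = UniFormerWithProjectionHeadConfig(projection_size=768)
    decoder_config = LlamaConfig(hidden_size=768, num_hidden_layers=6)

    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    assert config.model_type == 'encoder-decoder'
    assert config.decoder.is_decoder and config.decoder.add_cross_attention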
configuration_uniformer.py ADDED
@@ -0,0 +1,51 @@
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class UniFormerWithProjectionHeadConfig(PretrainedConfig):
+
+    model_type = 'uniformer'
+
+    def __init__(
+        self,
+        projection_size=None,
+        embed_dim=[64, 128, 320, 512],
+        image_size=384,
+        in_chans=3,
+        depth=[5, 8, 20, 7],
+        patch_size=[4, 2, 2, 2],
+        head_dim=64,
+        mlp_ratio=4,
+        qkv_bias=True,
+        num_classes=1000,
+        qk_scale=None,
+        representation_size=None,
+        drop_rate=0.0,
+        drop_path_rate=0.3,
+        attn_drop_rate=0.0,
+        conv_stem=False,
+        layer_norm_eps=1e-6,
+        **kwargs,
+    ):
+        super().__init__(
+            layer_norm_eps=layer_norm_eps,
+            image_size=image_size,
+            qkv_bias=qkv_bias,
+            **kwargs,
+        )
+        self.projection_size = projection_size
+        self.embed_dim = embed_dim
+        self.in_chans = in_chans
+        self.depth = depth
+        self.patch_size = patch_size
+        self.head_dim = head_dim
+        self.mlp_ratio = mlp_ratio
+        self.num_classes = num_classes
+        self.qk_scale = qk_scale
+        self.representation_size = representation_size
+        self.drop_rate = drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.conv_stem = conv_stem
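Instantiating the configuration with its defaults reproduces the encoder values visible in config.json above; a quick sanity sketch:

    from configuration_uniformer import UniFormerWithProjectionHeadConfig

    config = UniFormerWithProjectionHeadConfig()
    assert config.model_type == 'uniformer'
    assert config.image_size == 384 and config.in_chans == 3 and config.head_dim == 64
    assert config.embed_dim == [64, 128, 320, 512]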
create_section_files.py CHANGED
@@ -4,8 +4,10 @@ from pathlib import Path
 
 from tqdm import tqdm
 
-# local folder import
-from .section_parser import custom_mimic_cxr_rules, section_text
+try:
+    from .section_parser import custom_mimic_cxr_rules, section_text
+except ImportError:
+    from section_parser import custom_mimic_cxr_rules, section_text
 
 
 def list_rindex(l, s):
@@ -98,7 +100,7 @@ def create_section_files(reports_path, output_path, no_split):
         # exist the radiologist has usually written the report
         # in the comparison section
         idx = -1
-        for sn in ('impression', 'findings', 'indication', 'history', 'last_paragraph', 'comparison'):
+        for sn in ('impression', 'findings', 'indication', 'history', 'technique', 'last_paragraph', 'comparison'):
             if sn in section_names:
                 idx = list_rindex(section_names, sn)
                 break
@@ -112,7 +114,7 @@
         patient_studies.append([s_stem, sections[idx].strip()])
 
         study_sectioned = [s_stem]
-        for sn in ('impression', 'findings', 'indication', 'history', 'last_paragraph', 'comparison'):
+        for sn in ('impression', 'findings', 'indication', 'history', 'technique', 'last_paragraph', 'comparison'):
            if sn in section_names:
                idx = list_rindex(section_names, sn)
                study_sectioned.append(sections[idx].strip())
@@ -125,7 +127,7 @@
     with open(output_path / 'mimic_cxr_sectioned.csv', 'w') as fp:
         csvwriter = csv.writer(fp)
         # write header
-        csvwriter.writerow(['study', 'impression', 'findings', 'indication', 'history', 'last_paragraph', 'comparison'])
+        csvwriter.writerow(['study', 'impression', 'findings', 'indication', 'history', 'technique', 'last_paragraph', 'comparison'])
         for row in study_sections:
             csvwriter.writerow(row)
 
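With 'technique' added, consumers of mimic_cxr_sectioned.csv should now expect an eight-column header. A minimal check, assuming the file sits in the current directory:

    import csv

    with open('mimic_cxr_sectioned.csv') as fp:
        header = next(csv.reader(fp))

    assert header == ['study', 'impression', 'findings', 'indication', 'history', 'technique', 'last_paragraph', 'comparison']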
 
dataset.py CHANGED
@@ -1,253 +1,82 @@
-import os
+import itertools
+from typing import List
 
-import lmdb
-import pandas as pd
 import torch
-from torch.utils.data import Dataset
-from torchvision.io import decode_image, read_image
 
-# Ordered by oblique, lateral, AP, and then PA views so that PA views are closest in position to the generated tokens (and oblique is furthest).
-VIEW_ORDER = ['LPO', 'RAO', 'LAO', 'SWIMMERS', 'XTABLE LATERAL', 'LL', 'LATERAL', 'AP AXIAL', 'AP RLD', 'AP LLD', 'AP', 'PA RLD', 'PA LLD', 'PA']
+from .utils import compute_time_delta
 
 
-def mimic_cxr_image_path(dir, subject_id, study_id, dicom_id, ext='dcm'):
-    return os.path.join(dir, 'p' + str(subject_id)[:2], 'p' + str(subject_id),
-                        's' + str(study_id), str(dicom_id) + '.' + ext)
-
-
-class StudyIDEDStayIDSubset(Dataset):
-    """
-    Study ID & ED stay ID subset. Examples are indexed by the study identifier.
-    Information from the ED module is added by finding the study_id that is within
-    the timespan of the stay_id for the subject_id. The history and indication
-    sections are also included.
-    """
-    def __init__(
-        self,
-        split,
-        records,
-        mimic_cxr_jpg_lmdb_path=None,
-        mimic_cxr_dir=None,
-        max_images_per_study=None,
-        transforms=None,
-        images=True,
-        columns='study_id, dicom_id, subject_id, findings, impression',
-        and_condition='',
-        study_id_inclusion_list=None,
-        return_images=True,
-        ed_module=True,
-        extension='jpg',
-    ):
-        """
-        Argument/s:
-            split - 'train', 'validate', or 'test'.
-            records - MIMIC-CXR & MIMIC-IV-ED records class instance.
-            mimic_cxr_jpg_lmdb_path - JPG database for MIMIC-CXR-JPG.
-            mimic_cxr_dir - Path to the MIMIC-CXR directory containing the patient study subdirectories with the JPG or DCM images.
-            max_images_per_study - the maximum number of images per study.
-            transforms - torchvision transformations.
-            colour_space - PIL target colour space.
-            images - flag to return processed images.
-            columns - which columns to query on.
-            and_condition - AND condition to add to the SQL query.
-            study_id_inclusion_list - studies not in this list are excluded.
-            return_images - return CXR images for the study as tensors.
-            ed_module - use the ED module.
-            extension - 'jpg' or 'dcm'.
-        """
-        super(StudyIDEDStayIDSubset, self).__init__()
-        self.split = split
-        self.mimic_cxr_jpg_lmdb_path = mimic_cxr_jpg_lmdb_path
-        self.mimic_cxr_dir = mimic_cxr_dir
-        self.records = records
-        self.max_images_per_study = max_images_per_study
-        self.transforms = transforms
-        self.images = images
-        self.columns = columns
-        self.and_condition = and_condition
-        self.return_images = return_images
-        self.ed_module = ed_module
-        self.extension = extension
 
-        # If max images per study is not set:
-        self.max_images_per_study = float('inf') if self.max_images_per_study is None else self.max_images_per_study
-
-        assert self.extension == 'jpg' or self.extension == 'dcm', '"extension" can only be either "jpg" or "dcm".'
-        assert (mimic_cxr_jpg_lmdb_path is None) != (mimic_cxr_dir is None), 'Either "mimic_cxr_jpg_lmdb_path" or "mimic_cxr_dir" can be set.'
-
-        if self.mimic_cxr_dir is not None and self.mimic_cxr_jpg_lmdb_path is None:
-            if self.extension == 'jpg':
-                if 'physionet.org/files/mimic-cxr-jpg/2.0.0/files' not in self.mimic_cxr_dir:
-                    self.mimic_cxr_dir = os.path.join(self.mimic_cxr_dir, 'physionet.org/files/mimic-cxr-jpg/2.0.0/files')
-            elif self.extension == 'dcm':
-                if 'physionet.org/files/mimic-cxr/2.0.0/files' not in self.mimic_cxr_dir:
-                    self.mimic_cxr_dir = os.path.join(self.mimic_cxr_dir, 'physionet.org/files/mimic-cxr/2.0.0/files')
-
-        query = f"""
-        SELECT {columns}
-        FROM mimic_cxr
-        WHERE split = '{split}'
-        {and_condition}
-        ORDER BY study_id
-        """
-
-        # For multi-image, the study identifiers make up the training examples:
-        df = self.records.connect.sql(query).df()
-
-        # Drop studies that don't have a findings or impression section:
-        df = df.dropna(subset=['findings', 'impression'], how='any')
-
-        # This study has two rows in edstays (removed as it causes issues):
-        if self.ed_module:
-            df = df[df['study_id'] != 59128861]
-
-        # Exclude studies not in list:
-        if study_id_inclusion_list is not None:
-            df = df[df['study_id'].isin(study_id_inclusion_list)]
-
-        # Example study identifiers for the subset:
-        self.examples = df['study_id'].unique().tolist()
-
-        # Record statistics:
-        self.num_study_ids = len(self.examples)
-        self.num_dicom_ids = len(df['dicom_id'].unique().tolist())
-        self.num_subject_ids = len(df['subject_id'].unique().tolist())
-
-        # Prepare the LMDB .jpg database:
-        if self.mimic_cxr_jpg_lmdb_path is not None:
 
-            print('Loading images using LMDB.')
-
-            # Map size:
-            map_size = int(0.65 * (1024 ** 4))
-            assert isinstance(map_size, int)
 
-            self.env = lmdb.open(self.mimic_cxr_jpg_lmdb_path, map_size=map_size, lock=False, readonly=True)
-            self.txn = self.env.begin(write=False)
-
-    def __len__(self):
-        return self.num_study_ids
-
-    def __getitem__(self, index):
-
-        study_id = self.examples[index]
-
-        # Get the study:
-        study = self.records.connect.sql(
-            f"""
-            SELECT dicom_id, study_id, subject_id, study_datetime, ViewPosition
-            FROM mimic_cxr
-            WHERE (study_id = {study_id});
             """
-        ).df()
-        subject_id = study.iloc[0, study.columns.get_loc('subject_id')]
-        study_id = study.iloc[0, study.columns.get_loc('study_id')]
-        study_datetime = study['study_datetime'].max()
-
-        example_dict = {
-            'study_ids': study_id,
-            'subject_id': subject_id,
-            'index': index,
-        }
-
-        example_dict.update(self.records.return_mimic_cxr_features(study_id))
-
-        if self.ed_module:
-            edstays = self.records.connect.sql(
-                f"""
-                SELECT stay_id, intime, outtime
-                FROM edstays
-                WHERE (subject_id = {subject_id})
-                AND intime < '{study_datetime}'
-                AND outtime > '{study_datetime}';
-                """
-            ).df()
-
-            assert len(edstays) <= 1
-            stay_id = edstays.iloc[0, edstays.columns.get_loc('stay_id')] if not edstays.empty else None
-            self.records.clear_start_end_times()
-            example_dict.update(self.records.return_ed_module_features(stay_id, study_datetime))
-
-            example_dict['stay_ids'] = stay_id
-
-        if self.return_images:
-            example_dict['images'], example_dict['image_time_deltas'] = self.get_images(study, study_datetime)
-
-        return example_dict
-
-    def get_images(self, example, reference_time):
-        """
-        Get the image/s for a given example.
-
-        Argument/s:
-            example - dataframe for the example.
-            reference_time - reference_time for time delta.
-
-        Returns:
-            The image/s for the example
-        """
-
-        # Sample if over max_images_per_study. Only allowed during training:
-        if len(example) > self.max_images_per_study:
-            assert self.split == 'train'
-            example = example.sample(n=self.max_images_per_study, axis=0)
-
-        # Order by ViewPosition:
-        example['ViewPosition'] = example['ViewPosition'].astype(pd.CategoricalDtype(categories=VIEW_ORDER, ordered=True))
-
-        # Sort the DataFrame based on the categorical column
-        example = example.sort_values(by=['study_datetime', 'ViewPosition'])
-
-        # Load and pre-process each CXR:
-        images, time_deltas = [], []
-        for _, row in example.iterrows():
-            images.append(
-                self.load_and_preprocess_image(
-                    row['subject_id'],
-                    row['study_id'],
-                    row['dicom_id'],
-                ),
-            )
-            time_deltas.append(self.records.compute_time_delta(row['study_datetime'], reference_time, to_tensor=False))
-
-        if self.transforms is not None:
-            images = torch.stack(images, 0)
-        return images, time_deltas
-
-    def load_and_preprocess_image(self, subject_id, study_id, dicom_id):
-        """
-        Load and preprocess an image using torchvision.transforms.v2:
-        https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py
-
-        Argument/s:
-            subject_id - subject identifier.
-            study_id - study identifier.
-            dicom_id - DICOM identifier.
-
-        Returns:
-            image - Tensor of the CXR.
-        """
-
-        if self.extension == 'jpg':
 
-            if self.mimic_cxr_jpg_lmdb_path is not None:
-
-                # Convert to bytes:
-                key = bytes(dicom_id, 'utf-8')
-
-                # Retrieve image:
-                image = bytearray(self.txn.get(key))
-                image = torch.frombuffer(image, dtype=torch.uint8)
-                image = decode_image(image)
 
-            else:
-                image_file_path = mimic_cxr_image_path(self.mimic_cxr_dir, subject_id, study_id, dicom_id, self.extension)
-                image = read_image(image_file_path)
-
-        elif self.extension == 'dcm':
-            raise NotImplementedError
-
-        if self.transforms is not None:
-            image = self.transforms(image)
 
-        return image
+class PriorsDataset:
+    def __init__(self, dataset, history, time_delta_map):
+        self.dataset = dataset
+        self.history = history
+        self.study_id_to_index = dict(zip(dataset['study_id'], range(len(dataset))))
+        self.time_delta_map = time_delta_map
+        self.inf_time_delta_value = time_delta_map(float('inf'))
+
+    def __getitem__(self, idx):
+        batch = self.dataset[idx]
+
+        if self.history:
+
+            # Prior studies:
+            prior_study_indices = [
+                None if i is None else [self.study_id_to_index[j] for j in i[:self.history]] for i in batch['prior_study_ids']
+            ]
+            prior_studies = [None if i is None else [self.dataset[j] for j in i] for i in prior_study_indices]
+
+            # Prior time deltas:
+            time_deltas = [
+                None if i is None else [compute_time_delta(k['latest_study_datetime'], j, self.time_delta_map, to_tensor=False) for k in i] for i, j in zip(prior_studies, batch['latest_study_datetime'])
+            ]
+
+            # Prior findings and impressions:
+            batch['prior_findings'] = [
+                None if i is None else [j['findings'] for j in i] for i in prior_studies
+            ]
+            batch['prior_impression'] = [
+                None if i is None else [j['impression'] for j in i] for i in prior_studies
+            ]
+            batch['prior_findings_time_delta'] = time_deltas.copy()
+            batch['prior_impression_time_delta'] = time_deltas.copy()
+
+            # Prior images:
+            """
+            Note:
+
+            Random selection of max_train_images_per_study from the study if the number of images for a study exceeds max_train_images_per_study is performed in train_set_transform and test_set_transform.
+
+            Sorting the images based on the view is done in test_set_transform.
+
+            No need to do it here.
+            """
+            prior_images = [
+                torch.cat(
+                    [
+                        torch.empty(0, *batch['images'].shape[-3:])
+                    ] if i is None else [j['images'] for j in i]
+                ) for i in prior_studies
+            ]
+            prior_images = torch.nn.utils.rnn.pad_sequence(prior_images, batch_first=True, padding_value=0.0)
+            batch['images'] = torch.cat([batch['images'], prior_images], dim=1)
+            prior_image_time_deltas = [
+                None if i is None else list(itertools.chain.from_iterable([y] * x['images'].shape[0] for x, y in zip(i, j)))
+                for i, j in zip(prior_studies, time_deltas)
+            ]
+            max_len = max((len(item) for item in prior_image_time_deltas if item is not None), default=0)
+            prior_image_time_deltas = [i + [self.inf_time_delta_value] * (max_len - len(i)) if i else [self.inf_time_delta_value] * max_len for i in prior_image_time_deltas]
+            batch['image_time_deltas'] = [i + j for i, j in zip(batch['image_time_deltas'], prior_image_time_deltas)]
+
+        return batch
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getattr__(self, name):
+        return getattr(self.dataset, name)
+
+    def __getitems__(self, keys: List):
+        batch = self.__getitem__(keys)
+        n_examples = len(batch[next(iter(batch))])
+        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
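A minimal sketch of how PriorsDataset might be used; the wrapped dataset and its columns (study_id, prior_study_ids, latest_study_datetime, images, image_time_deltas, findings, impression) are assumptions inferred from the code above, not taken from this commit:

    import math

    from dataset import PriorsDataset

    time_delta_map = lambda x: 1 / math.sqrt(x + 1)  # same mapping defined in modelling_cxrmate_ed.py below

    # `dataset` is assumed to be a Hugging Face datasets split with the columns listed above.
    priors_dataset = PriorsDataset(dataset, history=1, time_delta_map=time_delta_map)

    batch = priors_dataset[[0, 1]]  # batched indexing: each study gains up to one prior study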
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 1,
   "eos_token_id": 2,
   "pad_token_id": 4,
-  "transformers_version": "4.40.2"
+  "transformers_version": "4.39.3"
 }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e4b1ed2a5298bb8999cb91a9b905ace6733e5c66ebdef9702baa4d421428fad3
-size 644854104
+oid sha256:ffbf3e699a139ad98f20f8e057cd085586aea444b4b015471d697b43b440c14e
+size 789958760
modelling_cxrmate_ed.py CHANGED
@@ -1,33 +1,32 @@
 
1
  import math
2
  import os
3
- from glob import glob
4
- from pathlib import Path
5
  from typing import Optional, Tuple, Union
6
 
7
- import duckdb
8
- import pandas as pd
9
  import torch
10
  import transformers
11
  from torch.nn import CrossEntropyLoss
12
- from tqdm import tqdm
 
13
  from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
14
  from transformers.configuration_utils import PretrainedConfig
15
  from transformers.modeling_outputs import Seq2SeqLMOutput
16
  from transformers.modeling_utils import PreTrainedModel
17
- from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
18
- VisionEncoderDecoderConfig,
19
- )
20
  from transformers.utils import logging
21
 
22
- from .create_section_files import create_section_files
23
- from .dataset import StudyIDEDStayIDSubset
24
- from .lmdb_jpg import prepare_mimic_cxr_jpg_lmdb
25
  from .modelling_uniformer import MultiUniFormerWithProjectionHead
26
- from .records import EDCXRSubjectRecords
27
- from .tables import ed_module_tables, mimic_cxr_tables
28
 
29
  logger = logging.get_logger(__name__)
30
 
 
 
 
31
 
32
  def create_lookup_table(df, columns, start_idx):
33
  df = df.groupby(columns).head(1)[columns].sort_values(by=columns)
@@ -49,12 +48,12 @@ class FNNEncoder(torch.nn.Module):
49
 
50
  class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
51
 
52
- config_class = VisionEncoderDecoderConfig
53
  base_model_prefix = "vision_encoder_decoder"
54
  main_input_name = "input_ids"
55
  supports_gradient_checkpointing = True
56
 
57
- def __init__(
58
  self,
59
  config: Optional[PretrainedConfig] = None,
60
  encoder: Optional[PreTrainedModel] = None,
@@ -70,7 +69,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
70
  if config is None and (encoder is None or decoder is None):
71
  raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
72
  if config is None:
73
- config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
74
  else:
75
  if not isinstance(config, self.config_class):
76
  raise ValueError(f"Config: {config} has to be of type {self.config_class}")
@@ -111,29 +110,50 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
111
  assert not config.decoder.is_encoder_decoder
112
  assert 'pad_token_id' in self.decoder.config.__dict__
113
  assert 'time_delta_monotonic_inversion' in self.decoder.config.__dict__
114
- assert 'zero_time_delta_value' in self.decoder.config.__dict__
115
  assert 'add_time_deltas' in self.decoder.config.__dict__
 
 
 
116
 
117
  assert isinstance(self.decoder.config.time_delta_monotonic_inversion, bool)
118
- assert isinstance(self.decoder.config.zero_time_delta_value, float)
119
-
120
- for k, v in self.decoder.config.index_value_encoder_config.items():
121
- setattr(
122
- self,
123
- f'{k}_index_value_encoder',
124
- FNNEncoder(
125
- num_features=v,
126
- intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
127
- decoder_hidden_size=self.decoder.config.hidden_size,
128
- ),
129
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if self.decoder.config.add_time_deltas:
131
  self.time_delta_encoder = FNNEncoder(
132
  num_features=1,
133
  intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
134
  decoder_hidden_size=self.decoder.config.hidden_size,
135
  )
136
- self.token_type_embeddings = torch.nn.Embedding(self.decoder.config.num_token_types, self.decoder.config.hidden_size)
 
 
 
 
 
 
137
 
138
  @classmethod
139
  def from_encoder_decoder_pretrained(
@@ -281,7 +301,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
281
  decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
282
 
283
  # instantiate config with corresponding kwargs
284
- config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
285
 
286
  # make sure input & output embeddings is not tied
287
  config.tie_word_embeddings = False
@@ -292,13 +312,13 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
292
 
293
  def forward(
294
  self,
 
 
 
295
  decoder_input_ids: Optional[torch.LongTensor] = None,
296
- decoder_attention_mask: Optional[torch.FloatTensor] = None,
297
- decoder_token_type_ids: Optional[torch.LongTensor] = None,
298
  encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
299
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
300
  decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
301
- decoder_position_ids: Optional[torch.LongTensor] = None,
302
  labels: Optional[torch.LongTensor] = None,
303
  use_cache: Optional[bool] = None,
304
  output_attentions: Optional[bool] = None,
@@ -313,10 +333,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
313
  argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
314
  }
315
 
316
- assert decoder_position_ids is not None
317
- assert decoder_attention_mask is not None
318
  assert decoder_attention_mask.dtype == torch.long, f'The dtype for {decoder_attention_mask} was {decoder_attention_mask.dtype}. It should be torch.long'
319
- assert decoder_token_type_ids is not None
320
 
321
  if decoder_inputs_embeds is None:
322
  decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)
@@ -362,7 +379,6 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
362
  special_token_ids,
363
  prompt_attention_mask,
364
  prompt_position_ids,
365
- token_type_id_sections=None,
366
  past_key_values=None,
367
  use_cache=None,
368
  **kwargs,
@@ -387,7 +403,10 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
387
  # `inputs_embeds` are only to be used in the 1st generation step:
388
  inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.decoder.get_input_embeddings()(input_ids)], dim=1)
389
 
390
- decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids, token_type_id_sections)
 
 
 
391
  decoder_token_type_ids = torch.cat(
392
  [
393
  kwargs['decoder_token_type_ids'],
@@ -411,7 +430,11 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
411
  decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
412
 
413
  # Always place token_ids_to_token_type_ids_past_key_values before input_ids = input_ids[:, remove_prefix_length:]:
414
- decoder_token_type_ids = self.token_ids_to_token_type_ids_past_key_values(input_ids, special_token_ids, token_type_id_sections)
 
 
 
 
415
  decoder_position_ids = decoder_position_ids[:, -1:]
416
 
417
  past_length = past_key_values[0][0].shape[2]
@@ -437,7 +460,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
437
  )
438
  return input_dict
439
 
440
- def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
441
  """
442
  Extract token type identifiers from the token identifiers.
443
 
@@ -480,7 +503,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
480
 
481
  return token_type_ids
482
 
483
- def token_ids_to_token_type_ids_past_key_values(self, token_ids, special_token_ids, token_type_id_sections=None):
484
  """
485
  Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
486
  token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
@@ -649,7 +672,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
649
 
650
  return tuple(sections.values())
651
 
652
- def tokenize_text_columns(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
653
  """
654
  Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
655
  Time deltas for the input_ids are also prepared here.
@@ -662,7 +685,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
662
  cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
663
  """
664
 
665
- batch_size = len(kwargs['index'])
666
 
667
  tokenized = {
668
  'input_ids': {i: [] for i in range(batch_size)},
@@ -671,34 +694,37 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
671
  'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
672
  }
673
 
674
- for i in self.decoder.config.ed_module_columns + self.decoder.config.mimic_cxr_columns + ['previous_findings', 'previous_impression']:
 
 
675
  if i in kwargs:
676
  if f'{i}_time_delta' not in kwargs:
677
- kwargs[f'{i}_time_delta'] = [[self.decoder.config.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
678
  for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
679
  if y is not None:
680
  assert isinstance(y, list)
681
  assert isinstance(z, list)
682
  for text, time_delta in zip(y, z):
683
- tokenized['input_ids'][x].append(
684
- tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
685
- )
686
- tokenized['token_type_ids'][x].append(
687
- torch.full(
688
- (1, tokenized['input_ids'][x][-1].shape[-1]),
689
- self.decoder.config.token_type_to_token_type_id[i],
690
- dtype=torch.long,
691
- device=self.device,
692
  )
693
- )
694
- tokenized['time_delta'][x].append(
695
- torch.full(
696
- (1, tokenized['input_ids'][x][-1].shape[-1]),
697
- time_delta,
698
- dtype=torch.float32,
699
- device=self.device,
 
 
 
 
 
 
 
 
700
  )
701
- )
702
 
703
  tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
704
  tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
@@ -725,7 +751,6 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
725
  tokenizer: PreTrainedTokenizerFast,
726
  tokenized_report=None,
727
  sep_token_id=None,
728
- section_ids=None,
729
  **batch,
730
  ):
731
  """
@@ -736,8 +761,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
736
  tokenizer - Hugging Face tokenizer.
737
  tokenized_report - if training/teacher forcing, input the tokenized_report dict to include it in the prepared inputs.
738
  separator_token_id - separator token identifier.
739
- section_ids - section identifiers for the findings and impression sections.
740
-
741
  Returns:
742
  inputs_embeds - input embeddings.
743
  attention_mask - attention mask.
@@ -755,23 +779,24 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
755
  bos_token_ids = None
756
 
757
  # Index and value columns:
758
- batch_size = len(batch['index'])
759
- for k in self.decoder.config.index_value_encoder_config.keys():
760
- if f'{k}_index_value_feats' not in batch:
761
- batch[f'{k}_index_value_feats'] = torch.empty(batch_size, 0, self.decoder.config.index_value_encoder_config[k], device=self.device)
762
- inputs_embeds.append(
763
- getattr(self, f'{k}_index_value_encoder')(batch[f'{k}_index_value_feats'])
764
- )
765
- token_type_ids.append(batch[f'{k}_index_value_token_type_ids'] if f'{k}_index_value_token_type_ids' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
766
- attention_mask.append(batch[f'{k}_index_value_mask'] if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
767
- if f'{k}_time_delta' in batch:
768
- time_delta.append(batch[f'{k}_time_delta'])
769
- else:
770
- time_delta_index_value = torch.zeros(*batch[f'{k}_index_value_mask'].shape, 1, device=self.device) if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, 1, device=self.device)
771
- time_delta.append(time_delta_index_value)
 
772
 
773
  # Tokenize text columns for prompt:
774
- tokenized = self.tokenize_text_columns(tokenizer, **batch)
775
  input_ids.append(tokenized['input_ids'])
776
  token_type_ids.append(tokenized['token_type_ids'])
777
  attention_mask.append(tokenized['attention_mask'])
@@ -780,14 +805,17 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
780
  # Image encoder:
781
  encoder_outputs = self.encoder(images)
782
  inputs_embeds.append(encoder_outputs[0])
 
783
  inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]
784
- padded_image_time_deltas = [i + [self.decoder.config.zero_time_delta_value] * (images.shape[1] - len(i)) for i in batch['image_time_deltas']]
785
- time_delta_image_features = torch.tensor(padded_image_time_deltas, device=self.device).repeat_interleave(inputs_per_image, dim=1)
786
  token_type_ids.append(
787
  torch.where(
788
- time_delta_image_features == self.decoder.config.zero_time_delta_value,
789
- self.decoder.config.token_type_to_token_type_id['image'],
790
- self.decoder.config.token_type_to_token_type_id['previous_image'],
 
 
 
791
  ),
792
  )
793
  attention_mask.append(encoder_outputs[1])
@@ -819,7 +847,7 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
819
  report_token_type_ids = self.token_ids_to_token_type_ids(
820
  token_ids=tokenized_report['decoder_input_ids'],
821
  special_token_ids=[sep_token_id],
822
- token_type_id_sections=section_ids,
823
  )
824
  token_type_ids.append(report_token_type_ids)
825
 
@@ -906,8 +934,11 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
906
  return mixed_causality_4d_attention_mask
907
 
908
  def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
909
- _, col_indices = torch.sort(torch.where(attention_mask == 1, time_deltas[:, :, 0], torch.finfo(time_deltas.dtype).min), descending=not self.decoder.config.time_delta_monotonic_inversion)
910
-
 
 
 
911
  num_rows, num_cols, _ = time_deltas.shape
912
 
913
  row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
@@ -917,272 +948,316 @@ class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
917
 
918
  return position_ids
919
 
920
- @staticmethod
921
- def prepare_data(physionet_dir, database_dir):
922
 
923
- Path(database_dir).mkdir(parents=True, exist_ok=True)
924
-
925
- mimic_iv_duckdb_path = os.path.join(database_dir, 'mimic_iv_duckdb.db')
926
- mimic_cxr_jpg_lmdb_path = os.path.join(database_dir, 'mimic_cxr_jpg_lmdb.db')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
927
 
928
- sectioned_dir = os.path.join(database_dir, 'mimic_cxr_sectioned')
929
 
930
- mimic_cxr_sectioned_path = os.path.join(sectioned_dir, 'mimic_cxr_sectioned.csv')
931
- if not os.path.exists(mimic_cxr_sectioned_path):
932
- print(f'{mimic_cxr_sectioned_path} does not exist, creating...')
933
 
934
- # Check if reports exist. Reports for the first and last patients are checked only for speed, this comprimises comprehensiveness for speed:
935
- report_paths = [
936
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s50414267.txt'),
937
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53189527.txt'),
938
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53911762.txt'),
939
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s56699142.txt'),
940
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s55368167.txt'),
941
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58621812.txt'),
942
- os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58971208.txt'),
943
- ]
944
- assert all([os.path.isfile(i) for i in report_paths]), f"""The reports do not exist with the following regex: {os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p1*/p1*/s*.txt')}.
945
- "Please download them using wget -r -N -c -np --reject dcm --user <username> --ask-password https://physionet.org/files/mimic-cxr/2.0.0/"""
946
-
947
- print('Extracting sections from reports...')
948
- create_section_files(
949
- reports_path=os.path.join(physionet_dir, 'mimic-cxr', '2.0.0', 'files'),
950
- output_path=sectioned_dir,
951
- no_split=True,
952
- )
953
 
954
- if not os.path.exists(mimic_iv_duckdb_path):
 
 
955
 
956
- connect = duckdb.connect(mimic_iv_duckdb_path)
957
-
958
- csv_paths = []
959
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'edstays.csv.gz'))[0])
960
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'medrecon.csv.gz'))[0])
961
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'pyxis.csv.gz'))[0])
962
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'triage.csv.gz'))[0])
963
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'vitalsign.csv.gz'))[0])
964
 
965
- base_names = [os.path.basename(i) for i in csv_paths]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
966
 
967
- for i in ['edstays.csv.gz', 'medrecon.csv.gz', 'pyxis.csv.gz', 'triage.csv.gz', 'vitalsign.csv.gz']:
968
- assert i in base_names, f"""Table {i} is missing from MIMIC-IV-ED.
969
- Please download the tables from https://physionet.org/content/mimic-iv-ed. Do not decompress them."""
970
-
971
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-metadata.csv.gz'))[0])
972
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-chexpert.csv.gz'))[0])
973
- csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-split.csv.gz'))[0])
974
-
975
- base_names = [os.path.basename(i) for i in csv_paths[-3:]]
976
-
977
- for i in ['mimic-cxr-2.0.0-metadata.csv.gz', 'mimic-cxr-2.0.0-chexpert.csv.gz', 'mimic-cxr-2.0.0-split.csv.gz']:
978
- assert i in base_names, f"""CSV file {i} is missing from MIMIC-IV-ED.
979
- Please download the tables from https://physionet.org/content/mimic-cxr-jpg. Do not decompress them."""
980
-
981
- for i in csv_paths:
982
- name = Path(i).stem.replace('.csv', '').replace('.gz', '').replace('-', '_').replace('.', '_')
983
- print(f'Copying {name} into database...')
984
- connect.sql(f"CREATE OR REPLACE TABLE {name} AS FROM '{i}';")
985
-
986
- # MIMIC-CXR report sections:
987
- print(f'Copying mimic_cxr_sectioned into database...')
988
- connect.sql(f"CREATE OR REPLACE TABLE mimic_cxr_sectioned AS FROM '{mimic_cxr_sectioned_path}';")
989
- columns = list(connect.sql('FROM mimic_cxr_sectioned LIMIT 1').df().columns)
990
- if 'column0' in columns: # If the column headers are not read correctly:
991
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column0 TO study;")
992
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column1 TO impression;")
993
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column2 TO findings;")
994
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column3 TO indication;")
995
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column4 TO history;")
996
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column5 TO last_paragraph;")
997
- connect.sql("ALTER TABLE mimic_cxr_sectioned RENAME COLUMN column6 TO comparison;")
998
- connect.sql("DELETE FROM mimic_cxr_sectioned WHERE study='study';")
999
-
1000
- splits = connect.sql("FROM mimic_cxr_2_0_0_split").df()
1001
- reports = connect.sql("FROM mimic_cxr_sectioned").df()
1002
- metadata = connect.sql("FROM mimic_cxr_2_0_0_metadata").df()
1003
- chexpert = connect.sql("FROM mimic_cxr_2_0_0_chexpert").df()
1004
-
1005
- # Create datetime column:
1006
- metadata['StudyTime'] = metadata['StudyTime'].astype(int)
1007
- metadata['study_datetime'] = pd.to_datetime(
1008
- metadata.apply(lambda x: f'{x["StudyDate"]} {x["StudyTime"]:06}', axis=1),
1009
- format='%Y%m%d %H%M%S',
1010
- )
1011
- reports.rename(columns={'study': 'study_id'}, inplace=True)
1012
- reports.study_id = reports.study_id.str[1:].astype('int32')
1013
- df = pd.merge(splits, reports, on='study_id')
1014
- df = pd.merge(df, metadata, on=['dicom_id', 'study_id', 'subject_id'])
1015
- df = pd.merge(df, chexpert, on=['study_id', 'subject_id'])
1016
-
1017
- connect.sql(f"CREATE OR REPLACE TABLE mimic_cxr AS SELECT * FROM df")
1018
 
1019
- # Create lookup tables:
1020
- for k, v in (ed_module_tables | mimic_cxr_tables).items():
1021
- if v.load and v.index_columns:
1022
- start_idx = 0
1023
- for i in v.index_columns_source:
1024
- lut_name = f'{k}_{i}_lut'
1025
- table = k
1026
- lut, end_idx = create_lookup_table(connect.sql(f"SELECT {i} FROM {table}").df(), [i], start_idx)
1027
- start_idx = end_idx + 1
1028
- lut = lut.rename(columns={'index': f'{i}_index'})
1029
-
1030
- print(f'Creating {lut_name}...')
1031
-
1032
- connect.sql(f"CREATE OR REPLACE TABLE {lut_name} AS SELECT * FROM lut")
1033
-
1034
- if f'{i}_index' in connect.sql(f"FROM {k} LIMIT 0").df().columns:
1035
- connect.sql(
1036
- f"""
1037
- ALTER TABLE {k}
1038
- DROP COLUMN {i}_index;
1039
- """
1040
- )
1041
-
1042
- connect.sql(
1043
- f"""
1044
- CREATE OR REPLACE TABLE {k} AS
1045
- SELECT {k}.*, {lut_name}.{i}_index
1046
- FROM {k} LEFT JOIN {lut_name}
1047
- ON {k}.{i} = {lut_name}.{i}
1048
- """
1049
- )
1050
-
1051
- connect.sql(
1052
- f"""
1053
- CREATE TABLE IF NOT EXISTS lut_info (table_name VARCHAR PRIMARY KEY, start_index INT, end_index INT);
1054
- INSERT OR REPLACE INTO lut_info VALUES ('{k}', {0}, {end_idx});
1055
- """
1056
- )
1057
-
1058
- table_studies = {
1059
- 'edstays': [],
1060
- 'triage': [],
1061
- 'medrecon': [],
1062
- 'vitalsign': [],
1063
- 'pyxis': [],
1064
- }
1065
- stay_id_tables = ['triage']
1066
- stay_id_charttime_tables = ['medrecon', 'vitalsign', 'pyxis']
1067
-
1068
- df = connect.sql(f"FROM mimic_cxr").df()
1069
 
1070
- # DICOM identifiers can have different datetimes, so use most recent datetime for the study:
1071
- df = df.sort_values(by='study_datetime', ascending=False)
1072
- df = df.groupby('study_id').first().reset_index()
1073
-
1074
- print('Searching for studies associated with an ED stay...')
1075
- for _, row in tqdm(df.iterrows(), total=df.shape[0]):
1076
- edstays = connect.sql(
1077
- f"""
1078
- SELECT stay_id, intime, outtime
1079
- FROM edstays
1080
- WHERE (subject_id = {row['subject_id']})
1081
- AND intime < '{row['study_datetime']}'
1082
- AND outtime > '{row['study_datetime']}';
1083
- """
1084
- ).df()
1085
-
1086
- if len(edstays) > 0:
1087
-
1088
- for i in edstays['stay_id'].to_list():
1089
- table_studies['edstays'].append({'study_id': row['study_id'], 'stay_id': i})
1090
- for j in stay_id_tables:
1091
- table = connect.sql(
1092
- f"""
1093
- SELECT stay_id
1094
- FROM {j}
1095
- WHERE (stay_id = {i});
1096
- """
1097
- ).df()
1098
-
1099
- for k in table['stay_id'].to_list():
1100
- table_studies[j].append({'study_id': row['study_id'], 'stay_id': k})
1101
-
1102
- for j in stay_id_charttime_tables:
1103
- table = connect.sql(
1104
- f"""
1105
- SELECT stay_id
1106
- FROM {j}
1107
- WHERE (stay_id = {i})
1108
- AND charttime < '{row['study_datetime']}';
1109
- """
1110
- ).df()
1111
-
1112
- for k in table['stay_id'].to_list():
1113
- table_studies[j].append({'study_id': row['study_id'], 'stay_id': k})
1114
-
1115
- for k, v in table_studies.items():
1116
- df = pd.DataFrame(v)
1117
- df = df.drop_duplicates(subset=['study_id', 'stay_id'])
1118
- connect.sql(f"CREATE TABLE {k}_study_ids AS SELECT * FROM df")
1119
 
1120
- connect.close()
1121
-
1122
- if not os.path.exists(mimic_cxr_jpg_lmdb_path):
1123
- print('Preparing MIMIC-CXR-JPG LMDB database...')
1124
- pattern = os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'files')
1125
- mimic_cxr_jpg_dir = glob(pattern)
1126
- assert len(mimic_cxr_jpg_dir), f'Multiple directories matched the pattern {pattern}: {mimic_cxr_jpg_dir}. Only one is required.'
1127
- prepare_mimic_cxr_jpg_lmdb(
1128
- mimic_iv_duckdb_path=mimic_iv_duckdb_path,
1129
- mimic_cxr_jpg_dir=mimic_cxr_jpg_dir[0],
1130
- mimic_cxr_jpg_lmdb_path=mimic_cxr_jpg_lmdb_path,
1131
- map_size_tb=0.65
1132
- )
1133
 
1134
- @staticmethod
1135
- def get_dataset(split, transforms, database_dir, max_images_per_study=5, mimic_cxr_jpg_dir=None, records=None):
1136
-
1137
- mimic_iv_duckdb_path = os.path.join(database_dir, 'mimic_iv_duckdb.db')
1138
- mimic_cxr_jpg_lmdb_path = os.path.join(database_dir, 'mimic_cxr_jpg_lmdb.db') if mimic_cxr_jpg_dir is None else None
1139
-
1140
- if records is None:
1141
 
1142
- # This is the setup for CXRs + all effective inputs - medicine reconciliation:
1143
- records = EDCXRSubjectRecords(database_path=mimic_iv_duckdb_path, time_delta_map=lambda x: 1 / math.sqrt(x + 1))
 
 
 
 
1144
 
1145
- records.ed_module_tables = {k: records.ed_module_tables[k] for k in ['edstays', 'triage', 'vitalsign']}
1146
- records.mimic_cxr_tables = {k: records.mimic_cxr_tables[k] for k in ['mimic_cxr_sectioned']}
1147
- records.mimic_cxr_tables['mimic_cxr_sectioned'].text_columns = ['indication', 'history']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1148
 
1149
- dataset = StudyIDEDStayIDSubset(
1150
- mimic_cxr_jpg_lmdb_path=mimic_cxr_jpg_lmdb_path,
1151
- mimic_cxr_dir=mimic_cxr_jpg_dir,
1152
- transforms=transforms,
1153
- split=split,
1154
- max_images_per_study=max_images_per_study,
1155
- records=records,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1156
  )
1157
- print(f'No. of examples: {dataset.__len__()}.')
1158
- print(
1159
- f'No. of training dicom_ids, study_ids, & subject_ids: {dataset.num_dicom_ids},',
1160
- f'{dataset.num_study_ids}, & {dataset.num_subject_ids}.',
1161
- )
1162
- return dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1163
 
1164
  @staticmethod
1165
  def collate_fn(batch):
1166
  keys = set().union(*(d.keys() for d in batch))
1167
  batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
1168
- batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
1169
-
1170
- for k in keys:
1171
- if 'index_value_feats' in k:
1172
-
1173
- total_indices = next(i for i in batch[k] if i is not None).shape[-1]
1174
- batch[k] = [i if i is not None else torch.empty(0, total_indices) for i in batch[k]]
1175
- batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True, padding_value=-1) # Pad value of -1 is not ideal. Need to use something else.
1176
- token_type_id_name = k.replace('_feats', '_token_type_ids')
1177
- batch[token_type_id_name] = [i if i is not None else torch.empty(0, dtype=torch.long) for i in batch[token_type_id_name]]
1178
- batch[token_type_id_name] = torch.nn.utils.rnn.pad_sequence(
1179
- batch[token_type_id_name], batch_first=True, padding_value=0,
1180
- )
1181
- mask_name = k.replace('_feats', '_mask')
1182
- batch[mask_name] = (batch[k] != -1).any(dim=-1).int()
1183
-
1184
- if 'time_delta' in k and 'index_value' in k:
1185
- batch[k] = [i if i is not None else torch.empty(0, 1) for i in batch[k]]
1186
- batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True, padding_value=0)
1187
-
1188
- return batch
 
1
+ import json
2
  import math
3
  import os
4
+ import random
 
5
  from typing import Optional, Tuple, Union
6
 
7
+ import datasets
 
8
  import torch
9
  import transformers
10
  from torch.nn import CrossEntropyLoss
11
+ from torch.utils.data import Subset
12
+ from torchvision.io import decode_image
13
  from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
14
  from transformers.configuration_utils import PretrainedConfig
15
  from transformers.modeling_outputs import Seq2SeqLMOutput
16
  from transformers.modeling_utils import PreTrainedModel
 
 
 
17
  from transformers.utils import logging
18
 
19
+ from .configuration_cxrmate_ed import EncoderDecoderConfig
20
+ from .dataset import PriorsDataset
 
21
  from .modelling_uniformer import MultiUniFormerWithProjectionHead
22
+ from .prepare_dataset import prepare_dataset
23
+ from .utils import compute_time_delta
24
 
25
  logger = logging.get_logger(__name__)
26
 
27
+ # Ordered by oblique, lateral, AP, and then PA views so that PA views are closest in position to the generated tokens (and oblique is furtherest).
28
+ VIEW_ORDER = [None, 'LPO', 'RAO', 'LAO', 'SWIMMERS', 'XTABLE LATERAL', 'LL', 'LATERAL', 'AP AXIAL', 'AP RLD', 'AP LLD', 'AP', 'PA RLD', 'PA LLD', 'PA']
29
+
30
 
31
  def create_lookup_table(df, columns, start_idx):
32
  df = df.groupby(columns).head(1)[columns].sort_values(by=columns)
 
48
 
49
  class MIMICIVEDCXRMultimodalModel(VisionEncoderDecoderModel):
50
 
51
+ config_class = EncoderDecoderConfig
52
  base_model_prefix = "vision_encoder_decoder"
53
  main_input_name = "input_ids"
54
  supports_gradient_checkpointing = True
55
 
56
+ def __init__(
57
  self,
58
  config: Optional[PretrainedConfig] = None,
59
  encoder: Optional[PreTrainedModel] = None,
 
69
  if config is None and (encoder is None or decoder is None):
70
  raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
71
  if config is None:
72
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
73
  else:
74
  if not isinstance(config, self.config_class):
75
  raise ValueError(f"Config: {config} has to be of type {self.config_class}")
 
110
  assert not config.decoder.is_encoder_decoder
111
  assert 'pad_token_id' in self.decoder.config.__dict__
112
  assert 'time_delta_monotonic_inversion' in self.decoder.config.__dict__
 
113
  assert 'add_time_deltas' in self.decoder.config.__dict__
114
+ assert 'history' in self.decoder.config.__dict__
115
+ assert 'tables_filter' in self.decoder.config.__dict__
116
+ assert 'prompt_report_sections_filter' in self.decoder.config.__dict__
117
 
118
  assert isinstance(self.decoder.config.time_delta_monotonic_inversion, bool)
119
+
120
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json'), 'r') as f:
121
+ self.tables = json.load(f)
122
+
123
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lookup_tables.json'), 'r') as f:
124
+ self.luts = json.load(f)
125
+
126
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json'), 'r') as f:
127
+ self.token_type_to_token_type_id = json.load(f)
128
+
129
+ self.tables = {k: self.tables[k] for k in self.decoder.config.tables_filter}
130
+ self.tables['mimic_cxr_sectioned']['text_columns'] = self.decoder.config.prompt_report_sections_filter
131
+
132
+ for k in self.tables.keys():
133
+ if self.luts[k]['total'] > 0:
134
+ setattr(
135
+ self,
136
+ f'{k}_index_value_encoder',
137
+ FNNEncoder(
138
+ num_features=self.luts[k]['total'],
139
+ intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
140
+ decoder_hidden_size=self.decoder.config.hidden_size,
141
+ ),
142
+ )
143
+
144
  if self.decoder.config.add_time_deltas:
145
  self.time_delta_encoder = FNNEncoder(
146
  num_features=1,
147
  intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
148
  decoder_hidden_size=self.decoder.config.hidden_size,
149
  )
150
+
151
+ self.token_type_embeddings = torch.nn.Embedding(max(self.token_type_to_token_type_id.values()) + 1, self.decoder.config.hidden_size)
152
+
153
+ self.time_delta_map = lambda x: 1 / math.sqrt(x + 1)
154
+ self.zero_time_delta_value = self.time_delta_map(0)
155
+
156
+ self.inf_time_delta_value = self.time_delta_map(float('inf'))
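# With time_delta_map = 1 / sqrt(x + 1), the current study maps to
# zero_time_delta_value = 1.0 and an infinitely old event maps to
# inf_time_delta_value = 0.0, so more recent inputs receive larger values.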
157
 
158
  @classmethod
159
  def from_encoder_decoder_pretrained(
 
301
  decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
302
 
303
  # instantiate config with corresponding kwargs
304
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
305
 
306
  # make sure input & output embeddings are not tied
307
  config.tie_word_embeddings = False
 
312
 
313
  def forward(
314
  self,
315
+ decoder_position_ids: torch.LongTensor,
316
+ decoder_attention_mask: torch.LongTensor,
317
+ decoder_token_type_ids: torch.LongTensor,
318
  decoder_input_ids: Optional[torch.LongTensor] = None,
 
 
319
  encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
320
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
321
  decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
 
322
  labels: Optional[torch.LongTensor] = None,
323
  use_cache: Optional[bool] = None,
324
  output_attentions: Optional[bool] = None,
 
333
  argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
334
  }
335
 
 
 
336
  assert decoder_attention_mask.dtype == torch.long, f'decoder_attention_mask has dtype {decoder_attention_mask.dtype}; it should be torch.long.'
 
337
 
338
  if decoder_inputs_embeds is None:
339
  decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)
 
379
  special_token_ids,
380
  prompt_attention_mask,
381
  prompt_position_ids,
 
382
  past_key_values=None,
383
  use_cache=None,
384
  **kwargs,
 
403
  # `inputs_embeds` are only to be used in the 1st generation step:
404
  inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.decoder.get_input_embeddings()(input_ids)], dim=1)
405
 
406
+ decoder_token_type_ids = self.token_ids_to_token_type_ids(
407
+ input_ids, special_token_ids,
408
+ [self.token_type_to_token_type_id['findings'], self.token_type_to_token_type_id['impression']],
409
+ )
410
  decoder_token_type_ids = torch.cat(
411
  [
412
  kwargs['decoder_token_type_ids'],
 
430
  decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
431
 
432
  # Always place token_ids_to_token_type_ids_past_key_values before input_ids = input_ids[:, remove_prefix_length:]:
433
+ decoder_token_type_ids = self.token_ids_to_token_type_ids_past_key_values(
434
+ input_ids,
435
+ special_token_ids,
436
+ [self.token_type_to_token_type_id['findings'], self.token_type_to_token_type_id['impression']],
437
+ )
438
  decoder_position_ids = decoder_position_ids[:, -1:]
439
 
440
  past_length = past_key_values[0][0].shape[2]
 
460
  )
461
  return input_dict
462
 
463
+ def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections):
464
  """
465
  Extract token type identifiers from the token identifiers.
466
 
 
503
 
504
  return token_type_ids
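# Assumed behaviour (illustrative ids only): with special_token_ids=[sep_id] and
# token_type_id_sections=[findings_id, impression_id], a report such as
# [w1, w2, sep_id, w3, w4] is split at the separator, so w1 and w2 take the
# findings token type id and the tokens after the separator take the impression id.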
505
 
506
+ def token_ids_to_token_type_ids_past_key_values(self, token_ids, special_token_ids, token_type_id_sections):
507
  """
508
  Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
509
  token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
 
672
 
673
  return tuple(sections.values())
674
 
675
+ def tokenize_text_prompt(self, tokenizer: PreTrainedTokenizerFast, **kwargs):
676
  """
677
  Tokenize the text columns from MIMIC-IV ED and MIMIC-CXR (excluding the findings and impression sections).
678
  Time deltas for the input_ids are also prepared here.
 
685
  cxr - dictionary containing the input_ids, token_type_ids, and attention_mask for MIMIC-CXR columns.
686
  """
687
 
688
+ batch_size = len(kwargs['study_id'])
689
 
690
  tokenized = {
691
  'input_ids': {i: [] for i in range(batch_size)},
 
694
  'attention_mask': torch.empty(batch_size, 0, 1, device=self.device),
695
  }
696
 
697
+ prompt_text_columns = [
+ f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j
+ for k, v in self.tables.items() if 'text_columns' in v
+ for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])
+ ] + ['prior_findings', 'prior_impression']
698
+
699
+ for i in prompt_text_columns:
700
  if i in kwargs:
701
  if f'{i}_time_delta' not in kwargs:
702
+ kwargs[f'{i}_time_delta'] = [[self.zero_time_delta_value for _ in j] if j is not None else None for j in kwargs[i]]
703
  for x, (y, z) in enumerate(zip(kwargs[i], kwargs[f'{i}_time_delta'])):
704
  if y is not None:
705
  assert isinstance(y, list)
706
  assert isinstance(z, list)
707
  for text, time_delta in zip(y, z):
708
+ if text is not None:
709
+ tokenized['input_ids'][x].append(
710
+ tokenizer(text, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device=self.device)
711
  )
712
+ tokenized['token_type_ids'][x].append(
713
+ torch.full(
714
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
715
+ self.token_type_to_token_type_id[i],
716
+ dtype=torch.long,
717
+ device=self.device,
718
+ )
719
+ )
720
+ tokenized['time_delta'][x].append(
721
+ torch.full(
722
+ (1, tokenized['input_ids'][x][-1].shape[-1]),
723
+ time_delta,
724
+ dtype=torch.float32,
725
+ device=self.device,
726
+ )
727
  )
 
728
 
729
  tokenized['input_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['input_ids'].values()]
730
  tokenized['token_type_ids'] = [torch.cat(j, dim=1).T if j else torch.empty(0, 1, dtype=torch.long, device=self.device) for j in tokenized['token_type_ids'].values()]
 
751
  tokenizer: PreTrainedTokenizerFast,
752
  tokenized_report=None,
753
  sep_token_id=None,
 
754
  **batch,
755
  ):
756
  """
 
761
  tokenizer - Hugging Face tokenizer.
762
  tokenized_report - if training/teacher forcing, input the tokenized_report dict to include it in the prepared inputs.
763
sep_token_id - separator token identifier.
764
+
 
765
  Returns:
766
  inputs_embeds - input embeddings.
767
  attention_mask - attention mask.
 
779
  bos_token_ids = None
780
 
781
  # Index and value columns:
782
+ batch_size = images.shape[0]
783
+ for k, v in self.tables.items():
784
+ if 'index_columns' in v or 'value_columns' in v:
785
+ if f'{k}_index_value_feats' not in batch:
786
+ batch[f'{k}_index_value_feats'] = torch.empty(batch_size, 0, self.luts[k]['total'], device=self.device)
787
+ inputs_embeds.append(
788
+ getattr(self, f'{k}_index_value_encoder')(batch[f'{k}_index_value_feats'])
789
+ )
790
+ token_type_ids.append(batch[f'{k}_index_value_token_type_ids'] if f'{k}_index_value_token_type_ids' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
791
+ attention_mask.append(batch[f'{k}_index_value_mask'] if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, dtype=torch.long, device=self.device))
792
+ if f'{k}_index_value_time_delta' in batch:
793
+ time_delta.append(batch[f'{k}_index_value_time_delta'])
794
+ else:
795
+ time_delta_index_value = torch.zeros(*batch[f'{k}_index_value_mask'].shape, 1, device=self.device) if f'{k}_index_value_mask' in batch else torch.empty(batch_size, 0, 1, device=self.device)
796
+ time_delta.append(time_delta_index_value)
797
 
798
  # Tokenize text columns for prompt:
799
+ tokenized = self.tokenize_text_prompt(tokenizer, **batch)
800
  input_ids.append(tokenized['input_ids'])
801
  token_type_ids.append(tokenized['token_type_ids'])
802
  attention_mask.append(tokenized['attention_mask'])
 
805
  # Image encoder:
806
  encoder_outputs = self.encoder(images)
807
  inputs_embeds.append(encoder_outputs[0])
808
+
809
  inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]
810
+ time_delta_image_features = torch.tensor(batch['image_time_deltas'], device=self.device).repeat_interleave(inputs_per_image, dim=1)
 
811
  token_type_ids.append(
812
  torch.where(
813
+ torch.logical_or(
814
+ time_delta_image_features == self.zero_time_delta_value,
815
+ time_delta_image_features == self.inf_time_delta_value,
816
+ ),
817
+ self.token_type_to_token_type_id['image'],
818
+ self.token_type_to_token_type_id['prior_image'],
819
  ),
820
  )
821
  attention_mask.append(encoder_outputs[1])
 
847
  report_token_type_ids = self.token_ids_to_token_type_ids(
848
  token_ids=tokenized_report['decoder_input_ids'],
849
  special_token_ids=[sep_token_id],
850
+ token_type_id_sections=[self.token_type_to_token_type_id['findings'], self.token_type_to_token_type_id['impression']],
851
  )
852
  token_type_ids.append(report_token_type_ids)
853
 
 
934
  return mixed_causality_4d_attention_mask
935
 
936
  def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
937
+ mask_value = torch.finfo(time_deltas.dtype).max if self.decoder.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min
938
+
939
+ masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
940
+ _, col_indices = torch.sort(masked_time_deltas, descending=not self.decoder.config.time_delta_monotonic_inversion)
941
+
942
  num_rows, num_cols, _ = time_deltas.shape
943
 
944
  row_indices = torch.arange(num_rows, device=time_deltas.device).view(-1, 1).repeat(1, num_cols).view(-1)
 
948
 
949
  return position_ids
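# Intent sketch (assumed): attended inputs are ranked by their mapped time
# deltas (larger = more recent under 1 / sqrt(x + 1)), so position ids grow
# toward the most recent inputs and, ultimately, the generated report tokens.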
950
 
951
+ def get_dataset(self, dataset_path, train_transforms, test_transforms, max_train_images_per_study, study_id_split='mimic_iv_ed_mimic_cxr_jpg', test_set_only=False):
 
952
 
953
+ def train_set_transform(batch):
954
+
955
+ # Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
956
+ keys = ['images', 'dicom_id']
957
+ if 'mimic_cxr_2_0_0_metadata' in self.tables:
+ keys = keys + self.tables['mimic_cxr_2_0_0_metadata']['index_columns']
958
+ for i in range(len(batch['images'])):
959
+ if len(batch['images'][i]) > max_train_images_per_study:
960
+ paired = list(zip(*(batch[key][i] for key in keys)))
961
+ sampled_pairs = random.sample(paired, max_train_images_per_study)
962
+ unzipped_samples = zip(*sampled_pairs)
963
+ for key, values in zip(keys, unzipped_samples):
964
+ batch[key][i] = list(values)
965
+
966
+ batch['images'] = [[decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i] for i in batch['images']]
967
+
968
+ # Sort based on ViewPosition:
969
+ batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
970
+ batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
971
+ max_size = max(i.shape[0] for i in batch['images'])
972
+
973
+ batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
974
+ batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
975
+
976
+ for k, v in self.tables.items():
977
+ if 'index_columns' in v or 'value_columns' in v:
978
+ batch[f'{k}_index_value_feats'], batch[f'{k}_index_value_token_type_ids'], batch[f'{k}_index_value_time_delta'], batch[f'{k}_index_value_mask'] = self.prepare_index_value_feats(k, batch)
979
+
980
+ for k, v in self.tables.items():
981
+ if 'text_columns' in v:
982
+ for i in v['text_columns']:
983
+ key = f'{k}_{i}' if k != 'mimic_cxr_sectioned' else i
984
+ batch[key], batch[f'{key}_time_delta'] = self.prepare_text_prompt(k, i, batch)
985
 
986
+ return batch
987
 
988
+ def test_set_transform(batch):
989
+ batch['images'] = [[decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i] for i in batch['images']]
 
990
 
991
+ # Sort based on ViewPosition:
992
+ batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
993
+ batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
994
+ max_size = max(i.shape[0] for i in batch['images'])
995
+ batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
996
+ batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
 
997
 
998
+ for k, v in self.tables.items():
999
+ if 'index_columns' in v or 'value_columns' in v:
1000
+ batch[f'{k}_index_value_feats'], batch[f'{k}_index_value_token_type_ids'], batch[f'{k}_index_value_time_delta'], batch[f'{k}_index_value_mask'] = self.prepare_index_value_feats(k, batch)
1001
 
1002
+ for k, v in self.tables.items():
1003
+ if 'text_columns' in v:
1004
+ for i in v['text_columns']:
1005
+ key = f'{k}_{i}' if k != 'mimic_cxr_sectioned' else i
1006
+ batch[key], batch[f'{key}_time_delta'] = self.prepare_text_prompt(k, i, batch)
1007
 
1008
+ return batch
1009
+
1010
+ dataset = datasets.load_from_disk(dataset_path)
1011
+
1012
+ # Train set:
1013
+ if not test_set_only:
1014
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_train_study_ids.json'), 'r') as f:
1015
+ study_ids = json.load(f)
1016
+ train_set = dataset['train']
1017
+ train_set_study_ids = train_set['study_id']
1018
+ index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
1019
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1020
+ indices.sort()
1021
+ train_set = PriorsDataset(train_set, self.decoder.config.history, self.time_delta_map)
1022
+ train_set.set_transform(train_set_transform)
1023
+ train_set = Subset(train_set, indices)
1024
+ else:
1025
+ train_set = None
1026
+
1027
+ # Validation set:
1028
+ if not test_set_only:
1029
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_validate_study_ids.json'), 'r') as f:
1030
+ study_ids = json.load(f)
1031
+ val_set = dataset['validate']
1032
+ val_set_study_ids = val_set['study_id']
1033
+ index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
1034
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1035
+ indices.sort()
1036
+ val_set = PriorsDataset(val_set, self.decoder.config.history, self.time_delta_map)
1037
+ val_set.set_transform(test_set_transform)
1038
+ val_set = Subset(val_set, indices)
1039
+ else:
1040
+ val_set = None
1041
+
1042
+ # Test set:
1043
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_test_study_ids.json'), 'r') as f:
1044
+ study_ids = json.load(f)
1045
+ test_set = dataset['test']
1046
+ test_set_study_ids = test_set['study_id']
1047
+ index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
1048
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1049
+ indices.sort()
1050
+ test_set = PriorsDataset(test_set, self.decoder.config.history, self.time_delta_map)
1051
+ test_set.set_transform(test_set_transform)
1052
+ test_set = Subset(test_set, indices)
1053
+
1054
+ return train_set, val_set, test_set
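# Example usage (hypothetical path and transforms):
# train_set, val_set, test_set = model.get_dataset(
#     dataset_path='/path/to/mimic_iv_ed_mimic_cxr_jpg_dataset',
#     train_transforms=train_transforms,
#     test_transforms=test_transforms,
#     max_train_images_per_study=5,
# )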
1055
 
1056
+ def get_stage_1_dataset(self, dataset_path, train_transforms, test_transforms, max_train_images_per_study):
 
1057
 
1058
+ def train_set_transform(batch):
 
1059
 
1060
+ # Randomly select max_train_images_per_study if the number of images for a study exceeds max_train_images_per_study.
1061
+ for i in range(len(batch['images'])):
1062
+ if len(batch['images'][i]) > max_train_images_per_study:
1063
+ paired = list(zip(batch['images'][i], batch['ViewPosition'][i]))
1064
+ sampled_pairs = random.sample(paired, max_train_images_per_study)
1065
+ batch['images'][i], batch['ViewPosition'][i] = zip(*sampled_pairs)
1066
+
1067
+ batch['images'] = [[decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i] for i in batch['images']]
 
1068
 
1069
+ # Sort based on ViewPosition:
1070
+ batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1071
+ batch['images'] = [torch.stack([train_transforms(j) for j in i]) for i in batch['images']]
1072
+ max_size = max(i.shape[0] for i in batch['images'])
1073
+ batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1074
+ batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
1075
+
1076
+ return batch
 
1077
 
1078
+ def test_set_transform(batch):
1079
+ batch['images'] = [[decode_image(torch.frombuffer(bytearray(j), dtype=torch.uint8)) for j in i] for i in batch['images']]
 
1080
 
1081
+ # Sort based on ViewPosition:
1082
+ batch['images'] = [list(zip(*sorted(zip(i, v), key=lambda x: VIEW_ORDER.index(x[1]))))[0] for i, v in zip(batch['images'], batch['ViewPosition'])]
1083
+ batch['images'] = [torch.stack([test_transforms(j) for j in i]) for i in batch['images']]
1084
+ max_size = max(i.shape[0] for i in batch['images'])
1085
+ batch['image_time_deltas'] = [[self.zero_time_delta_value if j < i.shape[0] else self.inf_time_delta_value for j in range(max_size)] for i in batch['images']]
1086
+ batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
1087
 
1088
+ return batch
1089
+
1090
+ dataset = datasets.load_from_disk(dataset_path)
1091
+
1092
+ # Train set:
1093
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mimic_cxr_jpg_train_study_ids.json'), 'r') as f:
1094
+ study_ids = json.load(f)
1095
+ train_set = dataset['train']
1096
+ train_set_study_ids = train_set['study_id']
1097
+ index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
1098
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1099
+ indices.sort()
1100
+ train_set = PriorsDataset(train_set, self.decoder.config.history, self.time_delta_map)
1101
+ train_set.set_transform(train_set_transform)
1102
+ train_set = Subset(train_set, indices)
1103
+
1104
+ # Validation set:
1105
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mimic_cxr_jpg_validate_study_ids.json'), 'r') as f:
1106
+ study_ids = json.load(f)
1107
+ val_set = dataset['validate']
1108
+ val_set_study_ids = val_set['study_id']
1109
+ index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
1110
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1111
+ indices.sort()
1112
+ val_set = PriorsDataset(val_set, self.decoder.config.history, self.time_delta_map)
1113
+ val_set.set_transform(test_set_transform)
1114
+ val_set = Subset(val_set, indices)
1115
+
1116
+ # Test set:
1117
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'mimic_cxr_jpg_test_study_ids.json'), 'r') as f:
1118
+ study_ids = json.load(f)
1119
+ test_set = dataset['test']
1120
+ test_set_study_ids = test_set['study_id']
1121
+ index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
1122
+ indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
1123
+ indices.sort()
1124
+ test_set = PriorsDataset(test_set, self.decoder.config.history, self.time_delta_map)
1125
+ test_set.set_transform(test_set_transform)
1126
+ test_set = Subset(test_set, indices)
1127
+
1128
+ return train_set, val_set, test_set
1129
+
1130
+ def prepare_index_value_feats(self, table, batch):
1131
+
1132
+ index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
1133
+ index_value_columns = [f'{table}_{i}' for i in index_value_columns] if table != 'mimic_cxr_2_0_0_metadata' else index_value_columns
1134
+
1135
+ # Map to indices with lookup table:
1136
+ if 'index_columns' in self.tables[table]:
1137
+ for i in self.tables[table]['index_columns']:
1138
+ k = f'{table}_{i}' if table != 'mimic_cxr_2_0_0_metadata' else i
1139
+ batch[k] = [
1140
+ [self.luts[table][i][str(k)] if k is not None else None for k in j] if j is not None else None for j in batch[k]
1141
+ ]
1142
+
1143
+ batch_index_value_feats_list = []
1144
+ batch_token_type_ids_list = []
1145
+ batch_time_deltas_list = []
1146
 
1147
+ for batch_idx in range(len(batch['study_id'])):
1148
+
1149
+ if any([batch[k][batch_idx] for k in index_value_columns]):
1150
+
1151
+ num_rows = [len(batch[i][batch_idx]) for i in index_value_columns]
1152
+ assert all(x == num_rows[0] for x in num_rows)
1153
+ num_rows = num_rows[0]
1154
+
1155
+ # The y-index and the datetime for each group:
1156
+ if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
1157
+ y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
1158
+ datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
1159
+ assert len(set(y_indices)) == len(datetime)
1160
+ else:
1161
+ y_indices = [0] * num_rows
1162
+ datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
1163
+
1164
+ time_deltas = torch.tensor([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])[:, None]
1165
+
1166
+ tensor = torch.zeros(max(y_indices) + 1, self.luts[table]['total'])
1167
+
1168
+ # Index columns to feats:
1169
+ if 'index_columns' in self.tables[table]:
1170
+
1171
+ for i in self.tables[table]['index_columns']:
1172
+ k = f'{table}_{i}' if table != 'mimic_cxr_2_0_0_metadata' else i
1173
+ y_indices_column = [y_idx for y_idx, x_idx in zip(y_indices, batch[k][batch_idx]) if x_idx is not None]
1174
+ x_indices_column = [x_idx for x_idx in batch[k][batch_idx] if x_idx is not None]
1175
+
1176
+ tensor[y_indices_column, x_indices_column] = 1.0
1177
+
1178
+ if 'value_columns' in self.tables[table]:
1179
+ for i in self.tables[table]['value_columns']:
1180
+
1181
+ k = f'{table}_{i}' if table != 'mimic_cxr_2_0_0_metadata' else i
1182
+ y_indices_column = [y_idx for y_idx, value in zip(y_indices, batch[k][batch_idx]) if value is not None]
1183
+ x_indices_column = [self.luts[table][i] for value in batch[k][batch_idx] if value is not None]
1184
+ values = [value for value in batch[k][batch_idx] if value is not None]
1185
+
1186
+ tensor[y_indices_column, x_indices_column] = torch.tensor(values, dtype=tensor.dtype)
1187
+ assert not torch.isnan(tensor).any()
1188
+ else:
1189
+ tensor = torch.empty(0, self.luts[table]['total'])
1190
+ time_deltas = torch.empty(0, 1)
1191
+
1192
+ batch_index_value_feats_list.append(tensor)
1193
+ batch_token_type_ids_list.append(torch.full(
1194
+ [tensor.shape[0]],
1195
+ self.token_type_to_token_type_id[table],
1196
+ dtype=torch.long,
1197
+ )
1198
  )
1199
+ batch_time_deltas_list.append(time_deltas)
1200
+
1201
+ assert tensor.shape[0] == batch_token_type_ids_list[-1].shape[0]
1202
+ assert tensor.shape[0] == time_deltas.shape[0]
1203
+
1204
+ batch_index_value_feats = torch.nn.utils.rnn.pad_sequence(batch_index_value_feats_list, batch_first=True, padding_value=-1) # Pad value of -1 is not ideal. Need to use something else.
1205
+ batch_token_type_ids = torch.nn.utils.rnn.pad_sequence(batch_token_type_ids_list, batch_first=True, padding_value=0)
1206
+ batch_time_deltas = torch.nn.utils.rnn.pad_sequence(batch_time_deltas_list, batch_first=True, padding_value=0)
1207
+
1208
+ batch_mask = (batch_index_value_feats != -1).any(dim=-1).int()
1209
+
1210
+ return batch_index_value_feats, batch_token_type_ids, batch_time_deltas, batch_mask
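# Shape sketch (illustrative): for a batch of B studies and a table whose
# lookup covers T = luts[table]['total'] features, this returns feats of shape
# [B, R, T] (R = most rows in any study; one-hot entries for index columns,
# raw values at their lookup positions for value columns), token type ids of
# shape [B, R], time deltas of shape [B, R, 1], and a [B, R] mask that is 0
# for padded rows.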
1211
+
1212
+ def prepare_text_prompt(self, table, column, batch):
1213
+
1214
+ key = f'{table}_{column}' if table != 'mimic_cxr_sectioned' else column
1215
+
1216
+ batch_text_list = []
1217
+ batch_time_deltas_list = []
1218
+
1219
+ for batch_idx in range(len(batch['study_id'])):
1220
+ if batch[key][batch_idx]:
1221
+
1222
+ num_rows = len(batch[key][batch_idx])
1223
+
1224
+ # The y-index and the datetime for each group:
1225
+ if isinstance(batch[self.tables[table]['groupby']][batch_idx], list):
1226
+ y_indices = [d.setdefault(x, len(d)) for d in [{}] for x in batch[self.tables[table]['groupby']][batch_idx]]
1227
+ datetime = [j for i, j in enumerate(batch[self.tables[table]['time_column']][batch_idx]) if j not in batch[self.tables[table]['time_column']][batch_idx][:i]]
1228
+ assert len(set(y_indices)) == len(datetime)
1229
+ else:
1230
+ y_indices = [0] * num_rows
1231
+ datetime = batch[self.tables[table]['time_column']][batch_idx] if 'time_column' in self.tables[table] else [batch['latest_study_datetime'][batch_idx]]
1232
+
1233
+ # Remove None values:
1234
+ text_rows = batch[key][batch_idx] if isinstance(batch[key][batch_idx], list) else [batch[key][batch_idx]]
1235
+ y_indices = [i for i, j in zip(y_indices, text_rows) if j is not None]
1236
+ text_rows = [i for i in text_rows if i is not None]
1237
+ datetime = [datetime[i] for i in set(y_indices)]
1238
+ if text_rows:
1239
+
1240
+ # Those in the same group (or those with the same y-index) get joined as the same string:
1241
+ batch_text_list.append([', '.join([text_rows[j] for j in range(len(y_indices)) if y_indices[j] == k]) + '.' for k in set(y_indices)])
1242
+ batch_time_deltas_list.append([compute_time_delta(i, batch['latest_study_datetime'][batch_idx], self.time_delta_map, to_tensor=False) for i in datetime])
1243
+
1244
+ assert len(batch_time_deltas_list[-1]) == len(batch_text_list[-1])
1245
+ else:
1246
+ batch_text_list.append([])
1247
+ batch_time_deltas_list.append([])
1248
+ else:
1249
+ batch_text_list.append([])
1250
+ batch_time_deltas_list.append([])
1251
+
1252
+ return batch_text_list, batch_time_deltas_list
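# Illustrative example: rows in the same group share a y-index, so
# text_rows = ['chest pain', 'shortness of breath'] with y_indices = [0, 0]
# become the single prompt string 'chest pain, shortness of breath.' with one
# shared time delta.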
1253
 
1254
  @staticmethod
1255
  def collate_fn(batch):
1256
  keys = set().union(*(d.keys() for d in batch))
1257
  batch = {j: [i.setdefault(j, None) for i in batch] for j in keys}
1258
+ batch = {k: torch.stack(v) if isinstance(v[0], torch.Tensor) else v for k, v in batch.items()}
1259
+ return batch
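# Example (assumed wiring): use as the DataLoader collate function so stacked
# tensors and list-valued fields coexist in one batch:
# loader = torch.utils.data.DataLoader(
#     train_set, batch_size=4, collate_fn=MIMICIVEDCXRMultimodalModel.collate_fn,
# )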
1260
+
1261
+ @staticmethod
1262
+ def prepare_dataset(physionet_dir: str, database_dir: str):
1263
+ prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
 
modelling_uniformer.py CHANGED
@@ -1,16 +1,17 @@
1
  from collections import OrderedDict
2
  from functools import partial
3
- from typing import Optional, Tuple, Union
4
  from math import isqrt
 
5
 
6
  import torch
7
  import torch.nn as nn
8
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
- from transformers import ViTConfig
10
  from transformers.modeling_outputs import ModelOutput
11
  from transformers.modeling_utils import PreTrainedModel
12
  from transformers.utils import logging
13
 
 
 
14
  logger = logging.get_logger(__name__)
15
 
16
 
@@ -293,8 +294,7 @@ class UniFormerPreTrainedModel(PreTrainedModel):
293
  models.
294
  """
295
 
296
- config_class = ViTConfig
297
- base_model_prefix = "vit"
298
  main_input_name = "pixel_values"
299
 
300
  def _init_weights(self, m):
 
1
  from collections import OrderedDict
2
  from functools import partial
 
3
  from math import isqrt
4
+ from typing import Optional, Tuple, Union
5
 
6
  import torch
7
  import torch.nn as nn
8
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
9
  from transformers.modeling_outputs import ModelOutput
10
  from transformers.modeling_utils import PreTrainedModel
11
  from transformers.utils import logging
12
 
13
+ from .configuration_uniformer import UniFormerWithProjectionHeadConfig
14
+
15
  logger = logging.get_logger(__name__)
16
 
17
 
 
294
  models.
295
  """
296
 
297
+ config_class = UniFormerWithProjectionHeadConfig
 
298
  main_input_name = "pixel_values"
299
 
300
  def _init_weights(self, m):
prepare_dataset.py ADDED
@@ -0,0 +1,558 @@
1
+ import json
2
+ import multiprocessing
3
+ import os
4
+ import re
5
+ import shutil
6
+ from glob import glob
7
+ from pathlib import Path
8
+
9
+ import datasets
10
+ import duckdb
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ try:
15
+ from .create_section_files import create_section_files
16
+ except ImportError:
17
+ from create_section_files import create_section_files
18
+
19
+
20
+ def mimic_cxr_image_path(dir, subject_id, study_id, dicom_id, ext='dcm'):
21
+ return os.path.join(dir, 'p' + str(subject_id)[:2], 'p' + str(subject_id),
22
+ 's' + str(study_id), str(dicom_id) + '.' + ext)
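# e.g., mimic_cxr_image_path('files', 10000032, 50414267, 'abc123', 'jpg')
# -> 'files/p10/p10000032/s50414267/abc123.jpg'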
23
+
24
+
25
+ def format(text):
26
+ # Remove newline, tab, repeated whitespaces, and leading and trailing whitespaces:
27
+ def remove(text):
28
+ text = re.sub(r'\n|\t', ' ', text)
29
+ text = re.sub(r'\s+', ' ', text)
30
+ return text.strip()
31
+
32
+ if isinstance(text, np.ndarray) or isinstance(text, list):
33
+ return [remove(t) if not pd.isna(t) else t for t in text]
34
+ else:
35
+ if pd.isna(text):
36
+ return text
37
+ return remove(text)
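# e.g., format('FINDINGS:\n  No  acute\tprocess ') -> 'FINDINGS: No acute process'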
38
+
39
+
40
+ def create_lookup_table(df, columns, start_idx):
41
+ df = df.groupby(columns).head(1)[columns].sort_values(by=columns)
42
+ indices = range(start_idx, start_idx + len(df))
43
+ df['index'] = indices
44
+ return df, indices[-1]
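# Illustrative example: a column with unique values ['F', 'M'] and start_idx=0
# yields the lookup rows F -> 0 and M -> 1 and returns end_idx = 1, so the
# next column's indices continue from 2.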
45
+
46
+
47
+ def lookup_tables(con, tables):
48
+ luts_dict = {}
49
+ for k, v in tables.items():
50
+ luts_dict[k] = {}
51
+ start_idx = 0
52
+ if 'index_columns' in v:
53
+ for i in v['index_columns']:
54
+ lut, end_idx = create_lookup_table(con.sql(f"SELECT {i} FROM {k}").df(), [i], start_idx)
55
+ start_idx = end_idx + 1
56
+ luts_dict[k][i] = {str(row[i]): int(row['index']) for _, row in lut.iterrows()}
57
+ if 'value_columns' in v:
58
+ for i in v['value_columns']:
59
+ luts_dict[k][i] = start_idx
60
+ start_idx += 1
61
+
62
+ luts_dict[k]['total'] = start_idx
63
+
64
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lookup_tables.json'), 'w') as file:
65
+ json.dump(luts_dict, file)
66
+
67
+
68
+ def prepare_dataset(physionet_dir, database_dir, num_workers=None):
69
+
70
+ num_workers = num_workers if num_workers is not None else multiprocessing.cpu_count()
71
+
72
+ Path(database_dir).mkdir(parents=True, exist_ok=True)
73
+
74
+ sectioned_dir = os.path.join(database_dir, 'mimic_cxr_sectioned')
75
+ mimic_cxr_sectioned_path = os.path.join(sectioned_dir, 'mimic_cxr_sectioned.csv')
76
+ if not os.path.exists(mimic_cxr_sectioned_path):
77
+ print(f'{mimic_cxr_sectioned_path} does not exist, creating...')
78
+
79
+ # Check that the reports exist. Only the reports for the first and last patients are checked, trading comprehensiveness for speed:
80
+ report_paths = [
81
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s50414267.txt'),
82
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53189527.txt'),
83
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s53911762.txt'),
84
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p10/p10000032/s56699142.txt'),
85
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s55368167.txt'),
86
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58621812.txt'),
87
+ os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p19/p19999987/s58971208.txt'),
88
+ ]
89
+ assert all([os.path.isfile(i) for i in report_paths]), f"""The reports do not exist at the following glob pattern: {os.path.join(physionet_dir, 'mimic-cxr/2.0.0/files/p1*/p1*/s*.txt')}.
90
+ "Please download them using wget -r -N -c -np --reject dcm --user <username> --ask-password https://physionet.org/files/mimic-cxr/2.0.0/"""
91
+
92
+ print('Extracting sections from reports...')
93
+ create_section_files(
94
+ reports_path=os.path.join(physionet_dir, 'mimic-cxr', '2.0.0', 'files'),
95
+ output_path=sectioned_dir,
96
+ no_split=True,
97
+ )
98
+
99
+ csv_paths = []
100
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'edstays.csv.gz'))[0])
101
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'medrecon.csv.gz'))[0])
102
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'pyxis.csv.gz'))[0])
103
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'triage.csv.gz'))[0])
104
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-iv-ed', '*', 'ed', 'vitalsign.csv.gz'))[0])
105
+
106
+ base_names = [os.path.basename(i) for i in csv_paths]
107
+
108
+ for i in ['edstays.csv.gz', 'medrecon.csv.gz', 'pyxis.csv.gz', 'triage.csv.gz', 'vitalsign.csv.gz']:
109
+ assert i in base_names, f"""Table {i} is missing from MIMIC-IV-ED.
110
+ Please download the tables from https://physionet.org/content/mimic-iv-ed. Do not decompress them."""
111
+
112
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-metadata.csv.gz'))[0])
113
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-chexpert.csv.gz'))[0])
114
+ csv_paths.append(glob(os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'mimic-cxr-2.0.0-split.csv.gz'))[0])
115
+
116
+ base_names = [os.path.basename(i) for i in csv_paths[-3:]]
117
+
118
+ for i in ['mimic-cxr-2.0.0-metadata.csv.gz', 'mimic-cxr-2.0.0-chexpert.csv.gz', 'mimic-cxr-2.0.0-split.csv.gz']:
119
+ assert i in base_names, f"""CSV file {i} is missing from MIMIC-CXR-JPG.
120
+ Please download the tables from https://physionet.org/content/mimic-cxr-jpg. Do not decompress them."""
121
+
122
+ con = duckdb.connect(':memory:')
123
+ for i in csv_paths:
124
+ name = Path(i).stem.replace('.csv', '').replace('.gz', '').replace('-', '_').replace('.', '_')
125
+ print(f'Copying {name} into database...')
126
+ con.sql(f"CREATE OR REPLACE TABLE {name} AS FROM '{i}';")
127
+
128
+ # DuckDB has trouble reading the sectioned .csv file, so read it with pandas instead:
129
+ sections = pd.read_csv(mimic_cxr_sectioned_path)
130
+
131
+ # Remove the first character from the study column and rename it to study_id:
132
+ con.sql(
133
+ """
134
+ CREATE OR REPLACE TABLE mimic_cxr_sectioned AS
135
+ SELECT *, CAST(SUBSTR(study, 2) AS INT32) AS study_id
136
+ FROM sections;
137
+ """
138
+ )
139
+
140
+ # Combine StudyDate and StudyTime into a single column and create the studies table:
141
+ con.sql(
142
+ """
143
+ CREATE OR REPLACE TABLE studies AS
144
+ SELECT *,
145
+ strptime(
146
+ CAST(StudyDate AS VARCHAR) || ' ' || lpad(split_part(CAST(StudyTime AS VARCHAR), '.', 1), 6, '0'),
147
+ '%Y%m%d %H%M%S'
148
+ ) AS study_datetime
149
+ FROM mimic_cxr_2_0_0_metadata;
150
+ """
151
+ )
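# e.g., StudyDate=20180509 with StudyTime=952.0 becomes '20180509 000952'
# (lpad restores the leading zeros), which parses to 2018-05-09 00:09:52.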
152
+
153
+ # Load the table configuration:
154
+ with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json'), 'r') as file:
155
+ tables = json.load(file)
156
+
157
+ # Create lookup tables:
158
+ lookup_tables(con, tables)
159
+
160
+ # Collapse to one row per study, aggregating each study's columns into lists:
161
+ con.sql(
162
+ """
163
+ CREATE OR REPLACE TABLE studies AS
164
+ SELECT
165
+ LIST(dicom_id) AS dicom_id,
166
+ FIRST(subject_id) AS subject_id,
167
+ study_id,
168
+ LIST(PerformedProcedureStepDescription) AS PerformedProcedureStepDescription,
169
+ LIST(ViewPosition) AS ViewPosition,
170
+ LIST(Rows) AS Rows,
171
+ LIST(Columns) AS Columns,
172
+ LIST(StudyDate) AS StudyDate,
173
+ LIST(StudyTime) AS StudyTime,
174
+ LIST(ProcedureCodeSequence_CodeMeaning) AS ProcedureCodeSequence_CodeMeaning,
175
+ LIST(ViewCodeSequence_CodeMeaning) AS ViewCodeSequence_CodeMeaning,
176
+ LIST(PatientOrientationCodeSequence_CodeMeaning) AS PatientOrientationCodeSequence_CodeMeaning,
177
+ LIST(study_datetime) AS study_datetime,
178
+ MAX(study_datetime) AS latest_study_datetime,
179
+ FROM studies
180
+ GROUP BY study_id;
181
+ """
182
+ )
183
+
184
+ # Join and filter the studies that overlap with ED stays:
185
+ con.sql(
186
+ """
187
+ CREATE OR REPLACE TABLE studies AS
188
+ SELECT
189
+ s.*,
190
+ e.hadm_id,
191
+ e.stay_id,
192
+ e.intime,
193
+ e.outtime,
194
+ FROM studies s
195
+ LEFT JOIN edstays e
196
+ ON s.subject_id = e.subject_id
197
+ AND e.intime < s.latest_study_datetime
198
+ AND e.outtime > s.latest_study_datetime
199
+ AND s.study_id != 59128861;
200
+ """
201
+ ) # Don't join study 59128861 as it overlaps with two ED stays
202
+
203
+
204
+ # Aggregate and add the edstays table:
205
+ con.sql(
206
+ """
207
+ CREATE OR REPLACE TABLE edstays_aggregated AS
208
+ SELECT
209
+ FIRST(subject_id) AS subject_id,
210
+ stay_id,
211
+ LIST(intime) AS intime,
212
+ LIST(outtime) AS outtime,
213
+ LIST(gender) AS gender,
214
+ LIST(race) AS race,
215
+ LIST(arrival_transport) AS arrival_transport,
216
+ LIST(disposition) AS disposition,
217
+ FROM edstays
218
+ GROUP BY stay_id;
219
+ """
220
+ )
221
+ con.sql(
222
+ """
223
+ CREATE OR REPLACE TABLE studies AS
224
+ SELECT
225
+ s.*,
226
+ e.intime AS edstays_intime,
227
+ e.outtime AS edstays_outtime,
228
+ e.gender AS edstays_gender,
229
+ e.race AS edstays_race,
230
+ e.arrival_transport AS edstays_arrival_transport,
231
+ e.disposition AS edstays_disposition,
232
+ FROM studies s
233
+ LEFT JOIN edstays_aggregated e
234
+ ON s.stay_id = e.stay_id;
235
+ """
236
+ )
237
+
238
+ # Aggregate and add the triage table:
239
+ con.sql(
240
+ """
241
+ CREATE OR REPLACE TABLE triage_aggregated AS
242
+ SELECT
243
+ FIRST(subject_id) AS subject_id,
244
+ stay_id,
245
+ LIST(temperature) as temperature,
246
+ LIST(heartrate) AS heartrate,
247
+ LIST(resprate) AS resprate,
248
+ LIST(o2sat) AS o2sat,
249
+ LIST(sbp) AS sbp,
250
+ LIST(dbp) AS dbp,
251
+ LIST(pain) AS pain,
252
+ LIST(acuity) AS acuity,
253
+ LIST(chiefcomplaint) AS chiefcomplaint,
254
+ FROM triage
255
+ GROUP BY stay_id;
256
+ """
257
+ )
258
+ con.sql(
259
+ """
260
+ CREATE OR REPLACE TABLE studies AS
261
+ SELECT
262
+ s.*,
263
+ t.temperature AS triage_temperature,
264
+ t.heartrate AS triage_heartrate,
265
+ t.resprate AS triage_resprate,
266
+ t.o2sat AS triage_o2sat,
267
+ t.sbp AS triage_sbp,
268
+ t.dbp AS triage_dbp,
269
+ t.pain AS triage_pain,
270
+ t.acuity AS triage_acuity,
271
+ t.chiefcomplaint AS triage_chiefcomplaint,
272
+ FROM studies s
273
+ LEFT JOIN triage_aggregated t
274
+ ON s.stay_id = t.stay_id;
275
+ """
276
+ )
277
+
278
+ # Aggregate and then add the vitalsign table (ensuring no rows with a charttime after the latest study_datetime):
279
+ con.sql(
280
+ """
281
+ CREATE OR REPLACE TABLE vitalsign_causal AS
282
+ SELECT v.*, s.latest_study_datetime, s.study_id,
283
+ FROM vitalsign v
284
+ JOIN studies s ON v.stay_id = s.stay_id
285
+ WHERE v.charttime < s.latest_study_datetime;
286
+ """
287
+ ) # This duplicates the rows for stay_ids that cover multiple study_ids. Hence, the following joins must be on study_id, not stay_id.
288
+ con.sql(
289
+ """
290
+ CREATE OR REPLACE TABLE vitalsign_aggregated AS
291
+ SELECT
292
+ study_id,
293
+ FIRST(subject_id) AS subject_id,
294
+ FIRST(stay_id) as stay_id,
295
+ LIST(charttime) AS charttime,
296
+ LIST(temperature) as temperature,
297
+ LIST(heartrate) AS heartrate,
298
+ LIST(resprate) AS resprate,
299
+ LIST(o2sat) AS o2sat,
300
+ LIST(sbp) AS sbp,
301
+ LIST(dbp) AS dbp,
302
+ LIST(rhythm) AS rhythm,
303
+ LIST(pain) AS pain,
304
+ FROM vitalsign_causal
305
+ GROUP BY study_id;
306
+ """
307
+ )
308
+ con.sql(
309
+ """
310
+ CREATE OR REPLACE TABLE studies AS
311
+ SELECT
312
+ s.*,
313
+ v.charttime AS vitalsign_charttime,
314
+ v.temperature AS vitalsign_temperature,
315
+ v.heartrate AS vitalsign_heartrate,
316
+ v.resprate AS vitalsign_resprate,
317
+ v.o2sat AS vitalsign_o2sat,
318
+ v.sbp AS vitalsign_sbp,
319
+ v.dbp AS vitalsign_dbp,
320
+ v.rhythm AS vitalsign_rhythm,
321
+ v.pain AS vitalsign_pain,
322
+ FROM studies s
323
+ LEFT JOIN vitalsign_aggregated v
324
+ ON s.study_id = v.study_id;
325
+ """
326
+ )
327
+
328
+ # Aggregate and then add the medrecon table:
329
+ con.sql(
330
+ """
331
+ CREATE OR REPLACE TABLE medrecon_aggregated AS
332
+ SELECT
333
+ FIRST(subject_id) AS subject_id,
334
+ stay_id,
335
+ LIST(charttime) AS charttime,
336
+ LIST(name) as name,
337
+ LIST(gsn) AS gsn,
338
+ LIST(ndc) AS ndc,
339
+ LIST(etc_rn) AS etc_rn,
340
+ LIST(etccode) AS etccode,
341
+ LIST(etcdescription) AS etcdescription,
342
+ FROM medrecon
343
+ GROUP BY stay_id;
344
+ """
345
+ )
346
+ con.sql(
347
+ """
348
+ CREATE OR REPLACE TABLE studies AS
349
+ SELECT
350
+ s.*,
351
+ m.charttime AS medrecon_charttime,
352
+ m.name AS medrecon_name,
353
+ m.gsn AS medrecon_gsn,
354
+ m.ndc AS medrecon_ndc,
355
+ m.etc_rn AS medrecon_etc_rn,
356
+ m.etccode AS medrecon_etccode,
357
+ m.etcdescription AS medrecon_etcdescription,
358
+ FROM studies s
359
+ LEFT JOIN medrecon_aggregated m
360
+ ON s.stay_id = m.stay_id;
361
+ """
362
+ )
363
+
364
+ # Aggregate and then add the pyxis table (ensuring no rows with a charttime after the latest study_datetime):
365
+ con.sql(
366
+ """
367
+ CREATE OR REPLACE TABLE pyxis_causal AS
368
+ SELECT p.*, s.latest_study_datetime, s.study_id,
369
+ FROM pyxis p
370
+ JOIN studies s ON p.stay_id = s.stay_id
371
+ WHERE p.charttime < s.latest_study_datetime;
372
+ """
373
+ ) # This duplicates the rows for stay_ids that cover multiple study_ids. Hence, the following joins must be on study_id, not stay_id.
374
+ con.sql(
375
+ """
376
+ CREATE OR REPLACE TABLE pyxis_aggregated AS
377
+ SELECT
378
+ study_id,
379
+ FIRST(subject_id) AS subject_id,
380
+ FIRST(stay_id) as stay_id,
381
+ LIST(charttime) AS charttime,
382
+ LIST(med_rn) as med_rn,
383
+ LIST(name) as name,
384
+ LIST(gsn_rn) AS gsn_rn,
385
+ LIST(gsn) AS gsn,
386
+ FROM pyxis_causal
387
+ GROUP BY study_id;
388
+ """
389
+ )
390
+ con.sql(
391
+ """
392
+ CREATE OR REPLACE TABLE studies AS
393
+ SELECT
394
+ s.*,
395
+ p.charttime AS pyxis_charttime,
396
+ p.med_rn AS pyxis_med_rn,
397
+ p.name AS pyxis_name,
398
+ p.gsn_rn AS pyxis_gsn_rn,
399
+ p.gsn AS pyxis_gsn,
400
+ FROM studies s
401
+ LEFT JOIN pyxis_aggregated p
402
+ ON s.study_id = p.study_id;
403
+ """
404
+ )
405
+
406
+ # Add the reports:
407
+ con.sql(
408
+ """
409
+ CREATE OR REPLACE TABLE studies AS
410
+ SELECT s.*, r.findings, r.impression, r.indication, r.history, r.comparison, r.last_paragraph, r.technique,
411
+ FROM studies s
412
+ LEFT JOIN mimic_cxr_sectioned r
413
+ ON s.study_id = r.study_id
414
+ """
415
+ )
416
+
417
+ # Aggregate and then add the splits:
418
+ con.sql(
419
+ """
420
+ CREATE OR REPLACE TABLE split_aggregated AS
421
+ SELECT
422
+ study_id,
423
+ FIRST(split) AS split,
424
+ FROM mimic_cxr_2_0_0_split
425
+ GROUP BY study_id;
426
+ """
427
+ )
428
+ con.sql(
429
+ """
430
+ CREATE OR REPLACE TABLE studies AS
431
+ SELECT s.*, x.split,
432
+ FROM studies s
433
+ JOIN split_aggregated x
434
+ ON s.study_id = x.study_id;
435
+ """
436
+ )
437
+
438
+ # Prior studies column:
439
+ con.sql(
440
+ """
441
+ CREATE OR REPLACE TABLE prior_studies AS
442
+ WITH sorted AS (
443
+ SELECT *,
444
+ ROW_NUMBER() OVER (PARTITION BY subject_id ORDER BY latest_study_datetime) AS rn
445
+ FROM studies
446
+ ),
447
+ aggregated AS (
448
+ SELECT subject_id,
449
+ study_id,
450
+ latest_study_datetime,
451
+ ARRAY_AGG(study_id) OVER (PARTITION BY subject_id ORDER BY rn ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) AS prior_study_ids,
452
+ ARRAY_AGG(latest_study_datetime) OVER (PARTITION BY subject_id ORDER BY rn ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) AS prior_study_datetimes
453
+ FROM sorted
454
+ )
455
+ SELECT *
456
+ FROM aggregated;
457
+ """
458
+ )
459
+ con.sql(
460
+ """
461
+ CREATE OR REPLACE TABLE studies AS
462
+ SELECT s.*, p.prior_study_ids, p.prior_study_datetimes,
463
+ FROM studies s
464
+ LEFT JOIN prior_studies p
465
+ ON s.study_id = p.study_id
466
+ ORDER BY s.subject_id, s.study_datetime DESC;
467
+ """
468
+ )
469
+
470
+ # Text columns:
471
+ text_columns = [
+ f'{k}_{j}' if k != 'mimic_cxr_sectioned' else j
+ for k, v in tables.items() if 'text_columns' in v
+ for j in (v['text_columns'] if isinstance(v['text_columns'], list) else [v['text_columns']])
+ ] + ['findings', 'impression']
472
+
473
+ pattern = os.path.join(physionet_dir, 'mimic-cxr-jpg', '*', 'files')
474
+ mimic_cxr_jpg_dir = glob(pattern)
475
+ assert len(mimic_cxr_jpg_dir) == 1, f'Expected exactly one directory to match the pattern {pattern}, got: {mimic_cxr_jpg_dir}.'
476
+ mimic_cxr_jpg_dir = mimic_cxr_jpg_dir[0]
477
+
478
+ def load_image(row):
479
+ images = []
480
+ for dicom_ids, study_id, subject_id in zip(row['dicom_id'], row['study_id'], row['subject_id']):
481
+ study_images = []
482
+ for dicom_id in dicom_ids:
483
+ image_path = mimic_cxr_image_path(mimic_cxr_jpg_dir, subject_id, study_id, dicom_id, 'jpg')
484
+ with open(image_path, 'rb') as f:
485
+ image = f.read()
486
+ study_images.append(image)
487
+ images.append(study_images)
488
+ row['images'] = images
489
+ return row
490
+
491
+ dataset_dict = {}
492
+ for split in ['test', 'validate', 'train']:
493
+ df = con.sql(f"FROM studies WHERE split = '{split}'").df()
494
+
495
+ # Format text columns:
496
+ for i in text_columns:
497
+ df[i] = df[i].apply(format)
498
+
499
+ # Save the study_ids for each split:
500
+ df[df['findings'].notna() & df['impression'].notna()]['study_id'].to_json(
501
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_{split}_study_ids.json'),
502
+ orient='records',
503
+ lines=False,
504
+ )
505
+ df_stay_id = df[df['findings'].notna() & df['impression'].notna() & df['stay_id'].notna()][['study_id', 'stay_id']]
506
+ df_stay_id['stay_id'] = df_stay_id['stay_id'].astype(int)
507
+ df_stay_id['study_id'].to_json(
508
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_iv_ed_mimic_cxr_jpg_{split}_study_ids.json'),
509
+ orient='records',
510
+ lines=False,
511
+ )
512
+
513
+ if split == 'test':
514
+ pyxis_columns = [col for col in df.columns if col.startswith('pyxis_')]
515
+ df_pyxis = df[df['findings'].notna() & df['impression'].notna() & df['stay_id'].notna()]
516
+ df_pyxis = df_pyxis[~df_pyxis[pyxis_columns].isna().all(axis=1)]
517
+ df_pyxis['study_id'].to_json(
518
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_iv_ed_mimic_cxr_jpg_pyxis_{split}_study_ids.json'),
519
+ orient='records',
520
+ lines=False,
521
+ )
522
+
523
+ vitalsign_columns = [col for col in df.columns if col.startswith('vitalsign_')]
524
+ df_vitalsign = df[df['findings'].notna() & df['impression'].notna() & df['stay_id'].notna()]
525
+ df_vitalsign = df_vitalsign[~df_vitalsign[vitalsign_columns].isna().all(axis=1)]
526
+ df_vitalsign['study_id'].to_json(
527
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_iv_ed_mimic_cxr_jpg_vitalsign_{split}_study_ids.json'),
528
+ orient='records',
529
+ lines=False,
530
+ )
531
+
532
+ dataset_dict[split] = datasets.Dataset.from_pandas(df)
+ cache_dir = os.path.join(database_dir, '.cache')
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
+ dataset_dict[split] = dataset_dict[split].map(
+ load_image,
+ num_proc=num_workers,
+ writer_batch_size=8,
+ batched=True,
+ batch_size=8,
+ keep_in_memory=False,
+ cache_file_name=os.path.join(cache_dir, f'.{split}'),
+ load_from_cache_file=False,
+ )
+ dataset_dict[split].cleanup_cache_files()
+ shutil.rmtree(cache_dir)
+
+ dataset = datasets.DatasetDict(dataset_dict)
+ dataset.save_to_disk(os.path.join(database_dir, 'mimic_iv_ed_mimic_cxr_jpg_dataset'))
+
+ con.close()
552
+
553
+
554
+ if __name__ == "__main__":
555
+ physionet_dir = '/datasets/work/hb-mlaifsp-mm/work/archive/physionet.org/files' # Where MIMIC-CXR, MIMIC-CXR-JPG, and MIMIC-IV-ED are stored.
556
+ database_dir = '/scratch3/nic261/database/cxrmate_ed' # Where the resultant database will be stored.
557
+
558
+ prepare_dataset(physionet_dir=physionet_dir, database_dir=database_dir)
utils.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+
3
+
4
+ def compute_time_delta(event_time, reference_time, time_delta_map, denominator=3600, to_tensor=True):
5
+ """
6
+ How do we transform time delta inputs? It appears that minutes are used as the input to
7
+ a weight matrix in "Self-Supervised Transformer for Sparse and Irregularly Sampled Multivariate
8
+ Clinical Time-Series". This is almost confirmed by the CVE class defined here:
9
+ https://github.com/sindhura97/STraTS/blob/main/strats_notebook.ipynb, where the input has
10
+ a size of one.
11
+ """
12
+ time_delta = reference_time - event_time
13
+ time_delta = time_delta.total_seconds() / denominator
14
+ assert isinstance(time_delta, float), f'time_delta should be float, not {type(time_delta)}.'
15
+ if time_delta < 0:
16
+ raise ValueError(f'time_delta should be greater than or equal to zero, not {time_delta}.')
17
+ time_delta = time_delta_map(time_delta)
18
+ if to_tensor:
19
+ time_delta = torch.tensor(time_delta)
20
+ return time_delta
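# Example (hypothetical timestamps): an event 2 hours before the study gives
# time_delta = 2.0 (hours), and with the model's map 1 / sqrt(x + 1) this
# becomes 1 / sqrt(3) ≈ 0.577:
# from datetime import datetime
# compute_time_delta(
#     datetime(2180, 5, 9, 10, 0), datetime(2180, 5, 9, 12, 0),
#     time_delta_map=lambda x: 1 / (x + 1) ** 0.5, to_tensor=False,
# )  # ≈ 0.577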