anicolson committed
Commit 80559a6
1 Parent(s): 93d47b7

Upload model
config.json CHANGED
@@ -1,4 +1,5 @@
  {
  "architectures": [
  "CXRMateEDModel"
  ],
@@ -6,49 +7,84 @@
  "AutoConfig": "configuration_cxrmate_ed.CXRMateEDConfig",
  "AutoModel": "modelling_cxrmate_ed.CXRMateEDModel"
  },
- "decoder": {
- "add_time_deltas": true,
  "hidden_size": 768,
- "history": 0,
- "include_time_delta": true,
- "index_value_encoder_intermediate_size": 2048,
  "intermediate_size": 3072,
- "is_decoder": true,
  "model_type": "llama",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "num_key_value_heads": 12,
- "pad_token_id": 4,
- "prompt_report_sections_filter": [
- "indication",
- "history"
- ],
- "tables_filter": [
- "mimic_cxr_sectioned",
- "triage",
- "medrecon"
- ],
- "time_delta_monotonic_inversion": true,
  "vocab_size": 30000
  },
- "encoder": {
  "_name_or_path": "aehrc/uniformer_base_tl_384",
  "architectures": [
  "UniFormerModel"
  ],
  "auto_map": {
  "AutoConfig": "aehrc/uniformer_base_tl_384--configuration_uniformer.UniFormerWithProjectionHeadConfig",
  "AutoModel": "aehrc/uniformer_base_tl_384--modelling_uniformer.UniFormerModel"
  },
  "init_value": 1e-06,
  "layer_scale": false,
  "model_type": "uniformer",
- "projection_size": 768,
  "torch_dtype": "float32"
  },
- "is_encoder_decoder": false,
- "model_type": "cxrmate-ed",
- "tie_word_embeddings": false,
- "torch_dtype": "float32",
- "transformers_version": "4.39.3"
  }
 
  {
+ "add_time_deltas": true,
  "architectures": [
  "CXRMateEDModel"
  ],
  "AutoConfig": "configuration_cxrmate_ed.CXRMateEDConfig",
  "AutoModel": "modelling_cxrmate_ed.CXRMateEDModel"
  },
+ "hidden_size": 768,
+ "history": 0,
+ "ignore_index": -100,
+ "image_seq_length": 576,
+ "image_token_index": 32000,
+ "include_time_delta": true,
+ "index_value_encoder_intermediate_size": 2048,
+ "model_type": "cxrmate-ed",
+ "pad_token_id": 4,
+ "projector_hidden_act": "gelu",
+ "prompt_report_sections_filter": [
+ "indication",
+ "history"
+ ],
+ "tables_filter": [
+ "mimic_cxr_sectioned",
+ "triage",
+ "medrecon"
+ ],
+ "text_config": {
+ "head_dim": 64,
  "hidden_size": 768,
  "intermediate_size": 3072,
  "model_type": "llama",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "num_key_value_heads": 12,
  "vocab_size": 30000
  },
+ "time_delta_monotonic_inversion": true,
+ "torch_dtype": "float32",
+ "transformers_version": "4.47.0",
+ "vision_config": {
  "_name_or_path": "aehrc/uniformer_base_tl_384",
  "architectures": [
  "UniFormerModel"
  ],
+ "attn_drop_rate": 0.0,
  "auto_map": {
  "AutoConfig": "aehrc/uniformer_base_tl_384--configuration_uniformer.UniFormerWithProjectionHeadConfig",
  "AutoModel": "aehrc/uniformer_base_tl_384--modelling_uniformer.UniFormerModel"
  },
+ "conv_stem": false,
+ "depth": [
+ 5,
+ 8,
+ 20,
+ 7
+ ],
+ "drop_path_rate": 0.3,
+ "drop_rate": 0.0,
+ "embed_dim": [
+ 64,
+ 128,
+ 320,
+ 512
+ ],
+ "head_dim": 64,
+ "image_size": 384,
+ "in_chans": 3,
  "init_value": 1e-06,
+ "layer_norm_eps": 1e-06,
  "layer_scale": false,
+ "mlp_ratio": 4,
  "model_type": "uniformer",
+ "num_classes": 1000,
+ "patch_size": [
+ 4,
+ 2,
+ 2,
+ 2
+ ],
+ "projection_size": null,
+ "qk_scale": null,
+ "qkv_bias": true,
+ "representation_size": null,
  "torch_dtype": "float32"
  },
+ "vision_feature_layer": -2,
+ "vision_feature_select_strategy": "default"
  }
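
Note: the updated config.json follows the flattened, LLaVA-style layout (top-level fields plus nested text_config and vision_config) instead of the previous encoder/decoder nesting. A minimal loading sketch under that assumption; the repository id below is a placeholder, not taken from this commit, and trust_remote_code is required because the config and model classes live in the repository:

```python
import transformers

# Placeholder repo id; substitute the actual Hub repository hosting this checkpoint.
repo_id = "<namespace>/cxrmate-ed"

config = transformers.AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type)                # "cxrmate-ed"
print(config.text_config.model_type)    # "llama"
print(config.vision_config.model_type)  # "uniformer"

model = transformers.AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```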
configuration_cxrmate_ed.py CHANGED
@@ -1,61 +1,94 @@
- import transformers
- from transformers.configuration_utils import PretrainedConfig
- from transformers.utils import logging

- logger = logging.get_logger(__name__)


- class CXRMateEDConfig(PretrainedConfig):

- model_type = "cxrmate-ed"

- def __init__(self, **kwargs):
  super().__init__(**kwargs)

- if 'decoder' not in kwargs:
-
- self.decoder = transformers.LlamaConfig(
- vocab_size=30000,
- hidden_size=768,
- intermediate_size=3072,
- num_attention_heads=12,
- num_hidden_layers=6,
- max_position_embeddings=2048,
- )
- self.decoder.is_decoder = True
-
- self.decoder.index_value_encoder_intermediate_size = 2048
- self.decoder.include_time_delta = True
- self.decoder.time_delta_monotonic_inversion = True
- self.decoder.add_time_deltas = True
- self.decoder.history = 0
- self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
- self.decoder.prompt_report_sections_filter = ["indication", "history"]
- self.decoder.pad_token_id = 4
-
- else:
- self.decoder = kwargs.pop("decoder")
-
-
- if 'encoder' not in kwargs:
- self.encoder = transformers.AutoConfig.from_pretrained(
- 'aehrc/uniformer_base_tl_384',
- projection_size=768,
- trust_remote_code=True,
- )
- else:
- self.encoder = kwargs.pop("encoder")

- self.is_encoder_decoder = True

- @classmethod
- def from_encoder_decoder_configs(
- cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
- ) -> PretrainedConfig:

- logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
- decoder_config.is_decoder = True
- decoder_config.add_cross_attention = True

- return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
+ from typing import Any

+ from transformers import LlavaConfig


+ class CXRMateEDConfig(LlavaConfig):

+ model_type = 'cxrmate-ed'

+ def __init__(
+ self,
+ index_value_encoder_intermediate_size: int = 2048,
+ include_time_delta: bool = True,
+ time_delta_monotonic_inversion: bool = True,
+ add_time_deltas: bool = True,
+ history: int = 0,
+ tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
+ prompt_report_sections_filter: list = ['indication', 'history'],
+ pad_token_id: int = 4,
+ **kwargs: Any,
+ ) -> None:
  super().__init__(**kwargs)
+ self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
+ self.include_time_delta = include_time_delta
+ self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
+ self.add_time_deltas = add_time_deltas
+ self.history = history
+ self.tables_filter = tables_filter
+ self.prompt_report_sections_filter = prompt_report_sections_filter
+ self.pad_token_id = pad_token_id
+
+ self.hidden_size = self.text_config.hidden_size
+
+ # import transformers
+ # from transformers.configuration_utils import PretrainedConfig
+ # from transformers.utils import logging
+
+ # logger = logging.get_logger(__name__)
+
+
+ # class CXRMateEDConfig(PretrainedConfig):
+
+ # model_type = "cxrmate-ed"
+
+ # def __init__(self, **kwargs):
+ # super().__init__(**kwargs)
+
+ # if 'decoder' not in kwargs:
+
+ # self.decoder = transformers.LlamaConfig(
+ # vocab_size=30000,
+ # hidden_size=768,
+ # intermediate_size=3072,
+ # num_attention_heads=12,
+ # num_hidden_layers=6,
+ # max_position_embeddings=2048,
+ # )
+ # self.decoder.is_decoder = True
+
+ # self.decoder.index_value_encoder_intermediate_size = 2048
+ # self.decoder.include_time_delta = True
+ # self.decoder.time_delta_monotonic_inversion = True
+ # self.decoder.add_time_deltas = True
+ # self.decoder.history = 0
+ # self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
+ # self.decoder.prompt_report_sections_filter = ["indication", "history"]
+ # self.decoder.pad_token_id = 4
+
+ # else:
+ # self.decoder = kwargs.pop("decoder")
+

+ # if 'encoder' not in kwargs:
+ # self.encoder = transformers.AutoConfig.from_pretrained(
+ # 'aehrc/uniformer_base_tl_384',
+ # projection_size=768,
+ # trust_remote_code=True,
+ # )
+ # else:
+ # self.encoder = kwargs.pop("encoder")


+ # self.is_encoder_decoder = True

+ # @classmethod
+ # def from_encoder_decoder_configs(
+ # cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ # ) -> PretrainedConfig:

+ # logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ # decoder_config.is_decoder = True
+ # decoder_config.add_cross_attention = True

+ # return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
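
For reference, a minimal sketch of constructing the reworked config directly, assuming configuration_cxrmate_ed.py from this commit is importable; the keyword arguments and defaults are the ones defined above:

```python
from configuration_cxrmate_ed import CXRMateEDConfig

# ED-specific fields now live on the top-level LlavaConfig subclass rather than
# on a nested decoder config.
config = CXRMateEDConfig(
    index_value_encoder_intermediate_size=2048,
    include_time_delta=True,
    add_time_deltas=True,
    history=0,
    tables_filter=['mimic_cxr_sectioned', 'triage', 'medrecon'],
    prompt_report_sections_filter=['indication', 'history'],
    pad_token_id=4,
)

print(config.model_type)   # 'cxrmate-ed'
print(config.hidden_size)  # mirrors config.text_config.hidden_size
```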
generation_config.json CHANGED
@@ -3,5 +3,5 @@
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 4,
- "transformers_version": "4.39.3"
+ "transformers_version": "4.47.0"
  }
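
A minimal sketch of the same special-token ids expressed as a transformers GenerationConfig object (illustration only; the repository ships them in generation_config.json):

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(bos_token_id=1, eos_token_id=2, pad_token_id=4)
```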
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:757ff7d2e55bf73d9a170d521fdfacc735b8226dcca11c32b5a20d2b2250ec48
- size 789964216
+ oid sha256:00a9a6697b96ba73294054503626e877190b4c30b95d826d3ca3410d44739aed
+ size 789967160
modelling_cxrmate_ed.py CHANGED
@@ -14,7 +14,7 @@ from transformers import PreTrainedTokenizerFast, VisionEncoderDecoderModel
  from transformers.configuration_utils import PretrainedConfig
  from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
  from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import logging

  from .configuration_cxrmate_ed import CXRMateEDConfig
  from .dataset import PriorsDataset
@@ -108,74 +108,39 @@ class CXRStudyImagesEncoder(torch.nn.Module):
  return ModelOutput(last_hidden_state=last_hidden_state, attention_mask=attention_mask)


-
- class CXRMateEDModel(VisionEncoderDecoderModel):

  config_class = CXRMateEDConfig

- def __init__(
- self,
- config: Optional[PretrainedConfig] = None,
- encoder: Optional[PreTrainedModel] = None,
- decoder: Optional[PreTrainedModel] = None,
- ):
-
- if decoder:
- assert decoder.config.is_decoder, '"is_decoder" must be True for the given decoder'
-
- if config is None and (encoder is None or decoder is None):
- raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
- if config is None:
- config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
- else:
- if not isinstance(config, self.config_class):
- raise ValueError(f"Config: {config} has to be of type {self.config_class}")
-
- config.tie_word_embeddings = False
- config.is_encoder_decoder = False
-
- # Initialize with config:
- PreTrainedModel.__init__(self, config)
-
- # Encoder:
- if encoder is None:
- encoder = transformers.AutoModel.from_pretrained(
- 'aehrc/uniformer_base_tl_384',
- config=config.encoder,
- trust_remote_code=True,
- )
-
- # Decoder:
- if decoder is None:
- decoder = transformers.LlamaForCausalLM(config=config.decoder)
-
- self.encoder = CXRStudyImagesEncoder(encoder, self.config.decoder)
- self.decoder = decoder
-
- if self.encoder.config.to_dict() != self.config.encoder.to_dict():
- logger.warning(
- f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
- f" {self.config.encoder}"
- )
- if self.decoder.config.to_dict() != self.config.decoder.to_dict():
- logger.warning(
- f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
- f" {self.config.decoder}"
- )

- self.encoder.config = self.config.encoder
- self.decoder.config = self.config.decoder

- assert config.decoder.is_decoder
- assert not config.decoder.is_encoder_decoder
- assert 'pad_token_id' in self.decoder.config.__dict__
- assert 'time_delta_monotonic_inversion' in self.decoder.config.__dict__
- assert 'add_time_deltas' in self.decoder.config.__dict__
- assert 'history' in self.decoder.config.__dict__
- assert 'tables_filter' in self.decoder.config.__dict__
- assert 'prompt_report_sections_filter' in self.decoder.config.__dict__

- assert isinstance(self.decoder.config.time_delta_monotonic_inversion, bool)

  with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json'), 'r') as f:
  self.tables = json.load(f)
@@ -186,8 +151,8 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json'), 'r') as f:
  self.token_type_to_token_type_id = json.load(f)

- self.tables = {k: self.tables[k] for k in self.decoder.config.tables_filter}
- self.tables['mimic_cxr_sectioned']['text_columns'] = self.decoder.config.prompt_report_sections_filter

  for k in self.tables.keys():
  if self.luts[k]['total'] > 0:
@@ -196,179 +161,182 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  f'{k}_index_value_encoder',
  FNNEncoder(
  num_features=self.luts[k]['total'],
- intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
- decoder_hidden_size=self.decoder.config.hidden_size,
  ),
  )

- if self.decoder.config.add_time_deltas:
  self.time_delta_encoder = FNNEncoder(
  num_features=1,
- intermediate_size=self.decoder.config.index_value_encoder_intermediate_size,
- decoder_hidden_size=self.decoder.config.hidden_size,
  )

- self.token_type_embeddings = torch.nn.Embedding(max(self.token_type_to_token_type_id.values()) + 1, self.decoder.config.hidden_size)

  self.time_delta_map = lambda x: 1 / math.sqrt(x + 1)
  self.zero_time_delta_value = self.time_delta_map(0)

  self.inf_time_delta_value = self.time_delta_map(float('inf'))
-
- @classmethod
- def from_encoder_decoder_pretrained(
- cls,
- encoder_pretrained_model_name_or_path: str = None,
- decoder_pretrained_model_name_or_path: str = None,
- *model_args,
- **kwargs,
- ) -> PreTrainedModel:
- r"""
- Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
- checkpoints.
-
-
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
- the model, you need to first set it back in training mode with `model.train()`.
-
- Params:
- encoder_pretrained_model_name_or_path (`str`, *optional*):
- Information necessary to initiate the image encoder. Can be either:
-
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
- example is `google/vit-base-patch16-224-in21k`.
- - A path to a *directory* containing model weights saved using
- [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
- this case, `from_tf` should be set to `True` and a configuration object should be provided as
- `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
- PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
-
- decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
- Information necessary to initiate the text decoder. Can be either:
-
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
- this case, `from_tf` should be set to `True` and a configuration object should be provided as
- `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
- PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
-
- model_args (remaining positional arguments, *optional*):
- All remaning positional arguments will be passed to the underlying model's `__init__` method.
-
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
- `output_attentions=True`).
-
- - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
- - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
- - To update the parent model configuration, do not use a prefix for each configuration parameter.
-
- Behaves differently depending on whether a `config` is provided or automatically loaded.
-
- Example:
-
- ```python
- >>> from transformers import VisionEncoderDecoderModel
-
- >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
- >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
- ...     "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
- ... )
- >>> # saving model after fine-tuning
- >>> model.save_pretrained("./vit-bert")
- >>> # load fine-tuned model
- >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
- ```"""
-
- kwargs_encoder = {
- argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
- }
-
- kwargs_decoder = {
- argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
- }
-
- # remove encoder, decoder kwargs from kwargs
- for key in kwargs_encoder.keys():
- del kwargs["encoder_" + key]
- for key in kwargs_decoder.keys():
- del kwargs["decoder_" + key]
-
- # Load and initialize the encoder and decoder
- # The distinction between encoder and decoder at the model level is made
- # by the value of the flag `is_decoder` that we need to set correctly.
- encoder = kwargs_encoder.pop("model", None)
- if encoder is None:
- if encoder_pretrained_model_name_or_path is None:
- raise ValueError(
- "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
- "to be defined."
- )
-
- if "config" not in kwargs_encoder:
- encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
- encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
- )
-
- if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
- logger.info(
- f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
- "from a decoder model. Cross-attention and casual mask are disabled."
- )
- encoder_config.is_decoder = False
- encoder_config.add_cross_attention = False
-
- kwargs_encoder["config"] = encoder_config
-
- encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
-
- decoder = kwargs_decoder.pop("model", None)
- if decoder is None:
- if decoder_pretrained_model_name_or_path is None:
- raise ValueError(
- "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
- "to be defined."
- )
-
- if "config" not in kwargs_decoder:
- decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
- decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
- )
-
- if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
- logger.info(
- f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
- f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
- f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
- )
- decoder_config.is_decoder = True
- decoder_config.add_cross_attention = False
-
- kwargs_decoder["config"] = decoder_config
-
- if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
- logger.warning(
- f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
- f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
- "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
- "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
- "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
- )
-
- decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
-
- # instantiate config with corresponding kwargs
- config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
-
- # make sure input & output embeddings is not tied
- config.tie_word_embeddings = False

- config.is_encoder_decoder = False

- return cls(encoder=encoder, decoder=decoder, config=config)

  def forward(
  self,
@@ -393,14 +361,17 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
  }

- assert decoder_attention_mask.dtype == torch.long, f'The dtype for {decoder_attention_mask} was {decoder_attention_mask.dtype}. It should be torch.long'
-
  if decoder_inputs_embeds is None:
- decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)
  decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)

  # Generation:
- decoder_outputs = self.decoder(
  inputs_embeds=decoder_inputs_embeds,
  attention_mask=decoder_attention_mask,
  position_ids=decoder_position_ids,
@@ -417,7 +388,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  if labels is not None:
  logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
  loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

  if not return_dict:
  if loss is not None:
@@ -448,20 +419,22 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
  """

- report_attention_mask = (input_ids != self.decoder.config.pad_token_id).long()

- if past_key_values is None:

  # 4D attention mask:
- decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(prompt_attention_mask, report_attention_mask)
-
  # Position identifiers accounting for padding:
  report_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
  report_position_ids.masked_fill_(report_attention_mask == 0, 1)
  decoder_position_ids = torch.cat([prompt_position_ids, report_position_ids], dim=1)

  # `inputs_embeds` are only to be used in the 1st generation step:
- inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.decoder.get_input_embeddings()(input_ids)], dim=1)

  decoder_token_type_ids = self.token_ids_to_token_type_ids(
  input_ids, special_token_ids,
@@ -483,7 +456,9 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  else:

  # 4D attention mask:
- decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(prompt_attention_mask, report_attention_mask)

  # Position identifiers accounting for padding:
  decoder_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
@@ -863,7 +838,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  time_delta.append(tokenized['time_delta'])

  # Image encoder:
- encoder_outputs = self.encoder(images)
  inputs_embeds.append(encoder_outputs[0])

  inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]
@@ -883,14 +858,14 @@ class CXRMateEDModel(VisionEncoderDecoderModel):

  # Compute embeddings from token identifiers:
  input_ids = torch.cat(input_ids, dim=1)
- inputs_embeds.append(self.decoder.get_input_embeddings()(input_ids))

  # Concatentate time deltas and input embeddings before adding time delta embedding to prompt:
  time_delta = torch.cat(time_delta, dim=1)
  inputs_embeds = torch.cat(inputs_embeds, dim=1)

  # Add time delta embeddings to prompt:
- if time_delta.shape[1] > 0 and self.decoder.config.add_time_deltas:
  time_delta = time_delta.to(dtype=inputs_embeds.dtype)
  inputs_embeds += self.time_delta_encoder(time_delta)
@@ -902,7 +877,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):

  # Tokenize report:
  if tokenized_report is not None:
- inputs_embeds = torch.cat([inputs_embeds, self.decoder.get_input_embeddings()(tokenized_report['decoder_input_ids'])], dim=1)

  report_token_type_ids = self.token_ids_to_token_type_ids(
  token_ids=tokenized_report['decoder_input_ids'],
@@ -917,7 +892,8 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  position_ids = torch.cat([position_ids, report_position_ids], dim=1)

  # 4D attention mask:
- attention_mask = self.create_4d_attention_mask_mixed_causality(attention_mask, tokenized_report['decoder_attention_mask'])

  # attention_mask_diagonal = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)

  else:
@@ -934,7 +910,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids

  @staticmethod
- def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):

  prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
  report_seq_len = causal_2d_attention_mask.shape[-1]
@@ -982,22 +958,91 @@ class CXRMateEDModel(VisionEncoderDecoderModel):

  mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)

  return mixed_causality_4d_attention_mask

  @staticmethod
- def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):

  non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
  causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

  mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
  return mixed_causality_4d_attention_mask

  def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
- mask_value = torch.finfo(time_deltas.dtype).max if self.decoder.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min

  masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
- _, col_indices = torch.sort(masked_time_deltas, descending=not self.decoder.config.time_delta_monotonic_inversion)

  num_rows, num_cols, _ = time_deltas.shape

@@ -1081,7 +1126,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- train_set = PriorsDataset(train_set, self.decoder.config.history, self.time_delta_map)
  train_set.set_transform(train_set_transform)
  train_set = Subset(train_set, indices)
  else:
@@ -1096,7 +1141,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- val_set = PriorsDataset(val_set, self.decoder.config.history, self.time_delta_map)
  val_set.set_transform(test_set_transform)
  val_set = Subset(val_set, indices)
  else:
@@ -1110,7 +1155,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- test_set = PriorsDataset(test_set, self.decoder.config.history, self.time_delta_map)
  test_set.set_transform(test_set_transform)
  test_set = Subset(test_set, indices)

@@ -1163,7 +1208,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- train_set = PriorsDataset(train_set, self.decoder.config.history, self.time_delta_map)
  train_set.set_transform(train_set_transform)
  train_set = Subset(train_set, indices)
@@ -1175,7 +1220,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- val_set = PriorsDataset(val_set, self.decoder.config.history, self.time_delta_map)
  val_set.set_transform(test_set_transform)
  val_set = Subset(val_set, indices)
@@ -1187,7 +1232,7 @@ class CXRMateEDModel(VisionEncoderDecoderModel):
  index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
- test_set = PriorsDataset(test_set, self.decoder.config.history, self.time_delta_map)
  test_set.set_transform(test_set_transform)
  test_set = Subset(test_set, indices)
 
  from transformers.configuration_utils import PretrainedConfig
  from transformers.modeling_outputs import ModelOutput, Seq2SeqLMOutput
  from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import check_min_version, logging

  from .configuration_cxrmate_ed import CXRMateEDConfig
  from .dataset import PriorsDataset

  return ModelOutput(last_hidden_state=last_hidden_state, attention_mask=attention_mask)


+ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):

  config_class = CXRMateEDConfig

+ def __init__(self, config: CXRMateEDConfig):
+
+ check_min_version("4.46.0.dev0")

+ super(transformers.LlavaPreTrainedModel, self).__init__(config)
+
+ self.config = config
+
+ self.vocab_size = config.text_config.vocab_size
+
+ self.image_encoder = transformers.AutoModel.from_config(self.config.vision_config, trust_remote_code=True)
+
+ self.language_model = transformers.AutoModelForCausalLM.from_config(
+ config.text_config,
+ attn_implementation=config._attn_implementation,
+ )
+
+ self.image_encoder = CXRStudyImagesEncoder(self.image_encoder, config.text_config)
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

+ # assert 'pad_token_id' in self.config.__dict__
+ # assert 'time_delta_monotonic_inversion' in self.config.__dict__
+ # assert 'add_time_deltas' in self.config.__dict__
+ # assert 'history' in self.config.__dict__
+ # assert 'tables_filter' in self.config.__dict__
+ # assert 'prompt_report_sections_filter' in self.config.__dict__

+ # assert isinstance(self.config.time_delta_monotonic_inversion, bool)

  with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json'), 'r') as f:
  self.tables = json.load(f)

  with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json'), 'r') as f:
  self.token_type_to_token_type_id = json.load(f)

+ self.tables = {k: self.tables[k] for k in self.config.tables_filter}
+ self.tables['mimic_cxr_sectioned']['text_columns'] = self.config.prompt_report_sections_filter

  for k in self.tables.keys():
  if self.luts[k]['total'] > 0:

  f'{k}_index_value_encoder',
  FNNEncoder(
  num_features=self.luts[k]['total'],
+ intermediate_size=self.config.index_value_encoder_intermediate_size,
+ decoder_hidden_size=self.config.hidden_size,
  ),
  )

+ if self.config.add_time_deltas:
  self.time_delta_encoder = FNNEncoder(
  num_features=1,
+ intermediate_size=self.config.index_value_encoder_intermediate_size,
+ decoder_hidden_size=self.config.hidden_size,
  )

+ self.token_type_embeddings = torch.nn.Embedding(max(self.token_type_to_token_type_id.values()) + 1, self.config.hidden_size)

  self.time_delta_map = lambda x: 1 / math.sqrt(x + 1)
  self.zero_time_delta_value = self.time_delta_map(0)

  self.inf_time_delta_value = self.time_delta_map(float('inf'))

+ self.post_init()
+
+
+ # @classmethod
+ # def from_encoder_decoder_pretrained(
+ # cls,
+ # encoder_pretrained_model_name_or_path: str = None,
+ # decoder_pretrained_model_name_or_path: str = None,
+ # *model_args,
+ # **kwargs,
+ # ) -> PreTrainedModel:
+ # r"""
+ # Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ # checkpoints.
+
+
+ # The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ # the model, you need to first set it back in training mode with `model.train()`.
+
+ # Params:
+ # encoder_pretrained_model_name_or_path (`str`, *optional*):
+ # Information necessary to initiate the image encoder. Can be either:
+
+ # - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
+ # example is `google/vit-base-patch16-224-in21k`.
+ # - A path to a *directory* containing model weights saved using
+ # [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ # - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ # this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ # `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ # PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ # decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ # Information necessary to initiate the text decoder. Can be either:
+
+ # - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ # - A path to a *directory* containing model weights saved using
+ # [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ # - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ # this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ # `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ # PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ # model_args (remaining positional arguments, *optional*):
+ # All remaning positional arguments will be passed to the underlying model's `__init__` method.
+
+ # kwargs (remaining dictionary of keyword arguments, *optional*):
+ # Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ # `output_attentions=True`).
+
+ # - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ # - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ # - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ # Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ # Example:
+
+ # ```python
+ # >>> from transformers import VisionEncoderDecoderModel
+
+ # >>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
+ # >>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
+ # ... "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
+ # ... )
+ # >>> # saving model after fine-tuning
+ # >>> model.save_pretrained("./vit-bert")
+ # >>> # load fine-tuned model
+ # >>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
+ # ```"""
+
+ # kwargs_encoder = {
+ # argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ # }
+
+ # kwargs_decoder = {
+ # argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ # }
+
+ # # remove encoder, decoder kwargs from kwargs
+ # for key in kwargs_encoder.keys():
+ # del kwargs["encoder_" + key]
+ # for key in kwargs_decoder.keys():
+ # del kwargs["decoder_" + key]
+
+ # # Load and initialize the encoder and decoder
+ # # The distinction between encoder and decoder at the model level is made
+ # # by the value of the flag `is_decoder` that we need to set correctly.
+ # encoder = kwargs_encoder.pop("model", None)
+ # if encoder is None:
+ # if encoder_pretrained_model_name_or_path is None:
+ # raise ValueError(
+ # "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+ # "to be defined."
+ # )
+
+ # if "config" not in kwargs_encoder:
+ # encoder_config, kwargs_encoder = transformers.AutoConfig.from_pretrained(
+ # encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+ # )
+
+ # if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ # logger.info(
+ # f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
+ # "from a decoder model. Cross-attention and casual mask are disabled."
+ # )
+ # encoder_config.is_decoder = False
+ # encoder_config.add_cross_attention = False
+
+ # kwargs_encoder["config"] = encoder_config
+
+ # encoder = transformers.AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+ # decoder = kwargs_decoder.pop("model", None)
+ # if decoder is None:
+ # if decoder_pretrained_model_name_or_path is None:
+ # raise ValueError(
+ # "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ # "to be defined."
+ # )
+
+ # if "config" not in kwargs_decoder:
+ # decoder_config, kwargs_decoder = transformers.AutoConfig.from_pretrained(
+ # decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ # )
+
+ # if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ # logger.info(
+ # f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ # f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ # f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ # )
+ # decoder_config.is_decoder = True
+ # decoder_config.add_cross_attention = False
+
+ # kwargs_decoder["config"] = decoder_config
+
+ # if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ # logger.warning(
+ # f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ # f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ # "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ # "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ # "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ # )
+
+ # decoder = transformers.AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # # instantiate config with corresponding kwargs
+ # config = CXRMateEDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+
+ # # make sure input & output embeddings is not tied
+ # config.tie_word_embeddings = False
+
+ # config.is_encoder_decoder = False

+ # return cls(encoder=encoder, decoder=decoder, config=config)

  def forward(
  self,

  argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
  }

  if decoder_inputs_embeds is None:
+ decoder_inputs_embeds = self.language_model.get_input_embeddings()(decoder_input_ids)
  decoder_inputs_embeds += self.token_type_embeddings(decoder_token_type_ids)

+ if decoder_attention_mask.dim() == 4:
+ assert decoder_attention_mask.dtype == decoder_inputs_embeds.dtype, f'The dtype for {decoder_attention_mask} was {decoder_attention_mask.dtype}. It should be {decoder_inputs_embeds.dtype}'
+ else:
+ assert decoder_attention_mask.dtype == torch.long, f'The dtype for {decoder_attention_mask} was {decoder_attention_mask.dtype}. It should be torch.long'
+
  # Generation:
+ decoder_outputs = self.language_model(
  inputs_embeds=decoder_inputs_embeds,
  attention_mask=decoder_attention_mask,
  position_ids=decoder_position_ids,

  if labels is not None:
  logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
  loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.reshape(-1, self.vocab_size), labels.reshape(-1))

  if not return_dict:
  if loss is not None:

  https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
  """

+ report_attention_mask = (input_ids != self.config.pad_token_id).long()

+ if len(past_key_values) == 0:

  # 4D attention mask:
+ decoder_attention_mask = self.create_4d_attention_mask_mixed_causality(
+ prompt_attention_mask, report_attention_mask, dtype=kwargs['decoder_inputs_embeds'].dtype,
+ )
+
  # Position identifiers accounting for padding:
  report_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]
  report_position_ids.masked_fill_(report_attention_mask == 0, 1)
  decoder_position_ids = torch.cat([prompt_position_ids, report_position_ids], dim=1)

  # `inputs_embeds` are only to be used in the 1st generation step:
+ inputs_embeds = torch.cat([kwargs['decoder_inputs_embeds'], self.language_model.get_input_embeddings()(input_ids)], dim=1)

  decoder_token_type_ids = self.token_ids_to_token_type_ids(
  input_ids, special_token_ids,

  else:

  # 4D attention mask:
+ decoder_attention_mask = self.create_4d_attention_mask_mixed_causality_past_key_values(
+ prompt_attention_mask, report_attention_mask, dtype=kwargs['decoder_inputs_embeds'].dtype,
+ )

  # Position identifiers accounting for padding:
  decoder_position_ids = report_attention_mask.cumsum(-1) + prompt_position_ids.max(dim=1).values[:, None]

  time_delta.append(tokenized['time_delta'])

  # Image encoder:
+ encoder_outputs = self.image_encoder(images)
  inputs_embeds.append(encoder_outputs[0])

  inputs_per_image = encoder_outputs[0].shape[-2] // images.shape[1]

  # Compute embeddings from token identifiers:
  input_ids = torch.cat(input_ids, dim=1)
+ inputs_embeds.append(self.language_model.get_input_embeddings()(input_ids))

  # Concatentate time deltas and input embeddings before adding time delta embedding to prompt:
  time_delta = torch.cat(time_delta, dim=1)
  inputs_embeds = torch.cat(inputs_embeds, dim=1)

  # Add time delta embeddings to prompt:
+ if time_delta.shape[1] > 0 and self.config.add_time_deltas:
  time_delta = time_delta.to(dtype=inputs_embeds.dtype)
  inputs_embeds += self.time_delta_encoder(time_delta)

  # Tokenize report:
  if tokenized_report is not None:
+ inputs_embeds = torch.cat([inputs_embeds, self.language_model.get_input_embeddings()(tokenized_report['decoder_input_ids'])], dim=1)

  report_token_type_ids = self.token_ids_to_token_type_ids(
  token_ids=tokenized_report['decoder_input_ids'],

  position_ids = torch.cat([position_ids, report_position_ids], dim=1)

  # 4D attention mask:
+ attention_mask = self.create_4d_attention_mask_mixed_causality(attention_mask, tokenized_report['decoder_attention_mask'], dtype=inputs_embeds.dtype)
+ # attention_mask = self.create_4d_attention_mask_mixed_causality(attention_mask, tokenized_report['decoder_attention_mask'])
  # attention_mask_diagonal = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)

  else:

  return inputs_embeds, attention_mask, token_type_ids, position_ids, bos_token_ids

  @staticmethod
+ def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):

  prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
  report_seq_len = causal_2d_attention_mask.shape[-1]

  mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)

+ mixed_causality_4d_attention_mask = mixed_causality_4d_attention_mask.to(dtype=dtype)
+ mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 0] = torch.finfo(mixed_causality_4d_attention_mask.dtype).min
+ mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
+
  return mixed_causality_4d_attention_mask

  @staticmethod
+ def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask, dtype):

  non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
  causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]

  mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
+
+ mixed_causality_4d_attention_mask = mixed_causality_4d_attention_mask.to(dtype=dtype)
+ mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 0] = torch.finfo(mixed_causality_4d_attention_mask.dtype).min
+ mixed_causality_4d_attention_mask[mixed_causality_4d_attention_mask == 1] = 0.0
+
  return mixed_causality_4d_attention_mask
+
+ # @staticmethod
+ # def create_4d_attention_mask_mixed_causality(non_causal_2d_attention_mask, causal_2d_attention_mask):
+
+ # prompt_seq_len = non_causal_2d_attention_mask.shape[-1]
+ # report_seq_len = causal_2d_attention_mask.shape[-1]
+
+ # non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
+ # causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
+
+ # # Upper left of attention matrix:
+ # upper_left = non_causal_2d_attention_mask.expand(-1, -1, prompt_seq_len, -1)
+ # upper_left = upper_left * non_causal_2d_attention_mask
+ # upper_left = upper_left * non_causal_2d_attention_mask.permute(0, 1, 3, 2)
+
+ # causal_mask = torch.tril(
+ # torch.ones(
+ # (
+ # report_seq_len,
+ # report_seq_len,
+ # ),
+ # dtype=torch.long,
+ # device=causal_2d_attention_mask.device,
+ # ),
+ # )
+
+ # # Lower right of attention matrix:
+ # lower_right = causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
+ # lower_right = lower_right * causal_2d_attention_mask.permute(0, 1, 3, 2)
+ # lower_right = lower_right * causal_mask
+
+ # # Upper right of attention matrix:
+ # upper_right = torch.zeros(
+ # causal_2d_attention_mask.shape[0],
+ # 1,
+ # prompt_seq_len,
+ # report_seq_len,
+ # dtype=torch.long,
+ # device=causal_2d_attention_mask.device,
+ # )
+
+ # # Lower left of attention matrix:
+ # lower_left = non_causal_2d_attention_mask.expand(-1, -1, report_seq_len, -1)
+ # lower_left = lower_left * causal_2d_attention_mask.permute(0, 1, 3, 2)
+
+ # left = torch.cat((upper_left, lower_left), dim=2)
+ # right = torch.cat((upper_right, lower_right), dim=2)
+
+ # mixed_causality_4d_attention_mask = torch.cat((left, right), dim=-1)
+
+ # return mixed_causality_4d_attention_mask
+
+ # @staticmethod
+ # def create_4d_attention_mask_mixed_causality_past_key_values(non_causal_2d_attention_mask, causal_2d_attention_mask):
+
+ # non_causal_2d_attention_mask = non_causal_2d_attention_mask[:, None, None, :]
+ # causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
+
+ # mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
+ # return mixed_causality_4d_attention_mask

  def position_ids_from_time_deltas_and_attention_mask(self, time_deltas, attention_mask):
+ mask_value = torch.finfo(time_deltas.dtype).max if self.config.time_delta_monotonic_inversion else torch.finfo(time_deltas.dtype).min

  masked_time_deltas = torch.where(attention_mask == 1, time_deltas[:, :, 0], mask_value)
+ _, col_indices = torch.sort(masked_time_deltas, descending=not self.config.time_delta_monotonic_inversion)

  num_rows, num_cols, _ = time_deltas.shape

  index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ train_set = PriorsDataset(train_set, self.config.history, self.time_delta_map)
  train_set.set_transform(train_set_transform)
  train_set = Subset(train_set, indices)
  else:

  index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ val_set = PriorsDataset(val_set, self.config.history, self.time_delta_map)
  val_set.set_transform(test_set_transform)
  val_set = Subset(val_set, indices)
  else:

  index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ test_set = PriorsDataset(test_set, self.config.history, self.time_delta_map)
  test_set.set_transform(test_set_transform)
  test_set = Subset(test_set, indices)

  index_map = {study_id: idx for idx, study_id in enumerate(train_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ train_set = PriorsDataset(train_set, self.config.history, self.time_delta_map)
  train_set.set_transform(train_set_transform)
  train_set = Subset(train_set, indices)

  index_map = {study_id: idx for idx, study_id in enumerate(val_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ val_set = PriorsDataset(val_set, self.config.history, self.time_delta_map)
  val_set.set_transform(test_set_transform)
  val_set = Subset(val_set, indices)

  index_map = {study_id: idx for idx, study_id in enumerate(test_set_study_ids)}
  indices = [index_map[study_id] for study_id in study_ids if study_id in index_map]
  indices.sort()
+ test_set = PriorsDataset(test_set, self.config.history, self.time_delta_map)
  test_set.set_transform(test_set_transform)
  test_set = Subset(test_set, indices)
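
Note on the reworked attention-mask helpers above: they now return an additive float mask rather than a 0/1 mask, so the dtype argument matters; allowed positions become 0.0 and blocked positions the minimum representable value of that dtype. A standalone sketch of that conversion pattern (illustration only, not the model's full mixed-causality mask construction):

```python
import torch

def to_additive_mask(binary_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 1 -> 0.0 (attend), 0 -> finfo(dtype).min (block), matching the conversion
    # applied to the mixed-causality 4D masks in this commit.
    mask = binary_mask.to(dtype=dtype)
    mask[mask == 0] = torch.finfo(dtype).min
    mask[mask == 1] = 0.0
    return mask

# Example: 2 bidirectional prompt positions followed by 2 causal report positions.
binary = torch.tensor([[[[1, 1, 0, 0],
                         [1, 1, 0, 0],
                         [1, 1, 1, 0],
                         [1, 1, 1, 1]]]], dtype=torch.long)
additive = to_additive_mask(binary, torch.float32)
```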