anicolson committed on
Commit
19994b0
1 Parent(s): cef87b0

Upload model

Browse files
Files changed (1) hide show
  1. modelling_cxrmate_ed.py +29 -25
modelling_cxrmate_ed.py CHANGED
@@ -143,13 +143,19 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
143
 
144
  # assert isinstance(self.config.time_delta_monotonic_inversion, bool)
145
 
146
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json'), 'r') as f:
 
 
147
  self.tables = json.load(f)
148
 
149
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lookup_tables.json'), 'r') as f:
 
 
150
  self.luts = json.load(f)
151
-
152
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json'), 'r') as f:
 
 
153
  self.token_type_to_token_type_id = json.load(f)
154
 
155
  self.tables = {k: self.tables[k] for k in self.config.tables_filter}
@@ -183,21 +189,6 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
183
 
184
  self.post_init()
185
 
186
- @classmethod
187
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
188
-
189
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='tables.json')
190
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='token_type_ids.json')
191
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='lookup_tables.json')
192
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_train_study_ids.json')
193
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_validate_study_ids.json')
194
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_test_study_ids.json')
195
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_train_study_ids.json')
196
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_validate_study_ids.json')
197
- hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_test_study_ids.json')
198
-
199
- return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
200
-
201
  # @classmethod
202
  # def from_encoder_decoder_pretrained(
203
  # cls,
@@ -1134,7 +1125,10 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1134
 
1135
  # Train set:
1136
  if not test_set_only:
1137
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_train_study_ids.json'), 'r') as f:
 
 
 
1138
  study_ids = json.load(f)
1139
  train_set = dataset['train']
1140
  train_set_study_ids = train_set['study_id']
@@ -1149,7 +1143,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1149
 
1150
  # Validation set:
1151
  if not test_set_only:
1152
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_validate_study_ids.json'), 'r') as f:
 
 
1153
  study_ids = json.load(f)
1154
  val_set = dataset['validate']
1155
  val_set_study_ids = val_set['study_id']
@@ -1163,7 +1159,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1163
  val_set = None
1164
 
1165
  # Test set:
1166
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_test_study_ids.json'), 'r') as f:
 
 
1167
  study_ids = json.load(f)
1168
  test_set = dataset['test']
1169
  test_set_study_ids = test_set['study_id']
@@ -1216,7 +1214,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1216
  dataset = datasets.load_from_disk(dataset_path)
1217
 
1218
  # Train set:
1219
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_train_study_ids.json'), 'r') as f:
 
 
1220
  study_ids = json.load(f)
1221
  train_set = dataset['train']
1222
  train_set_study_ids = train_set['study_id']
@@ -1228,7 +1228,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1228
  train_set = Subset(train_set, indices)
1229
 
1230
  # Validation set:
1231
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_validate_study_ids.json'), 'r') as f:
 
 
1232
  study_ids = json.load(f)
1233
  val_set = dataset['validate']
1234
  val_set_study_ids = val_set['study_id']
@@ -1240,7 +1242,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
1240
  val_set = Subset(val_set, indices)
1241
 
1242
  # Test set:
1243
- with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_test_study_ids.json'), 'r') as f:
 
 
1244
  study_ids = json.load(f)
1245
  test_set = dataset['test']
1246
  test_set_study_ids = test_set['study_id']
 
143
 
144
  # assert isinstance(self.config.time_delta_monotonic_inversion, bool)
145
 
146
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json')
147
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='tables.json')
148
+ with open(path, 'r') as f:
149
  self.tables = json.load(f)
150
 
151
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lookup_tables.json')
152
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='lookup_tables.json')
153
+ with open(path, 'r') as f:
154
  self.luts = json.load(f)
155
+
156
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json')
157
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='token_type_ids.json')
158
+ with open(path, 'r') as f:
159
  self.token_type_to_token_type_id = json.load(f)
160
 
161
  self.tables = {k: self.tables[k] for k in self.config.tables_filter}
 
189
 
190
  self.post_init()
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  # @classmethod
193
  # def from_encoder_decoder_pretrained(
194
  # cls,
 
1125
 
1126
  # Train set:
1127
  if not test_set_only:
1128
+
1129
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_train_study_ids.json')
1130
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_train_study_ids.json')
1131
+ with open(path, 'r') as f:
1132
  study_ids = json.load(f)
1133
  train_set = dataset['train']
1134
  train_set_study_ids = train_set['study_id']
 
1143
 
1144
  # Validation set:
1145
  if not test_set_only:
1146
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_validate_study_ids.json')
1147
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_validate_study_ids.json')
1148
+ with open(path, 'r') as f:
1149
  study_ids = json.load(f)
1150
  val_set = dataset['validate']
1151
  val_set_study_ids = val_set['study_id']
 
1159
  val_set = None
1160
 
1161
  # Test set:
1162
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_test_study_ids.json')
1163
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_test_study_ids.json')
1164
+ with open(path, 'r') as f:
1165
  study_ids = json.load(f)
1166
  test_set = dataset['test']
1167
  test_set_study_ids = test_set['study_id']
 
1214
  dataset = datasets.load_from_disk(dataset_path)
1215
 
1216
  # Train set:
1217
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_train_study_ids.json')
1218
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_train_study_ids.json')
1219
+ with open(path, 'r') as f:
1220
  study_ids = json.load(f)
1221
  train_set = dataset['train']
1222
  train_set_study_ids = train_set['study_id']
 
1228
  train_set = Subset(train_set, indices)
1229
 
1230
  # Validation set:
1231
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_validate_study_ids.json')
1232
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_validate_study_ids.json')
1233
+ with open(path, 'r') as f:
1234
  study_ids = json.load(f)
1235
  val_set = dataset['validate']
1236
  val_set_study_ids = val_set['study_id']
 
1242
  val_set = Subset(val_set, indices)
1243
 
1244
  # Test set:
1245
+ path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_test_study_ids.json')
1246
+ path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_test_study_ids.json')
1247
+ with open(path, 'r') as f:
1248
  study_ids = json.load(f)
1249
  test_set = dataset['test']
1250
  test_set_study_ids = test_set['study_id']