Upload model
Browse files- modelling_cxrmate_ed.py +29 -25
modelling_cxrmate_ed.py
CHANGED
@@ -143,13 +143,19 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
143 |
|
144 |
# assert isinstance(self.config.time_delta_monotonic_inversion, bool)
|
145 |
|
146 |
-
|
|
|
|
|
147 |
self.tables = json.load(f)
|
148 |
|
149 |
-
|
|
|
|
|
150 |
self.luts = json.load(f)
|
151 |
-
|
152 |
-
|
|
|
|
|
153 |
self.token_type_to_token_type_id = json.load(f)
|
154 |
|
155 |
self.tables = {k: self.tables[k] for k in self.config.tables_filter}
|
@@ -183,21 +189,6 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
183 |
|
184 |
self.post_init()
|
185 |
|
186 |
-
@classmethod
|
187 |
-
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
188 |
-
|
189 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='tables.json')
|
190 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='token_type_ids.json')
|
191 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='lookup_tables.json')
|
192 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_train_study_ids.json')
|
193 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_validate_study_ids.json')
|
194 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_cxr_jpg_test_study_ids.json')
|
195 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_train_study_ids.json')
|
196 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_validate_study_ids.json')
|
197 |
-
hf_hub_download(repo_id=pretrained_model_name_or_path, filename='mimic_iv_ed_mimic_cxr_jpg_test_study_ids.json')
|
198 |
-
|
199 |
-
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
200 |
-
|
201 |
# @classmethod
|
202 |
# def from_encoder_decoder_pretrained(
|
203 |
# cls,
|
@@ -1134,7 +1125,10 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1134 |
|
1135 |
# Train set:
|
1136 |
if not test_set_only:
|
1137 |
-
|
|
|
|
|
|
|
1138 |
study_ids = json.load(f)
|
1139 |
train_set = dataset['train']
|
1140 |
train_set_study_ids = train_set['study_id']
|
@@ -1149,7 +1143,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1149 |
|
1150 |
# Validation set:
|
1151 |
if not test_set_only:
|
1152 |
-
|
|
|
|
|
1153 |
study_ids = json.load(f)
|
1154 |
val_set = dataset['validate']
|
1155 |
val_set_study_ids = val_set['study_id']
|
@@ -1163,7 +1159,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1163 |
val_set = None
|
1164 |
|
1165 |
# Test set:
|
1166 |
-
|
|
|
|
|
1167 |
study_ids = json.load(f)
|
1168 |
test_set = dataset['test']
|
1169 |
test_set_study_ids = test_set['study_id']
|
@@ -1216,7 +1214,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1216 |
dataset = datasets.load_from_disk(dataset_path)
|
1217 |
|
1218 |
# Train set:
|
1219 |
-
|
|
|
|
|
1220 |
study_ids = json.load(f)
|
1221 |
train_set = dataset['train']
|
1222 |
train_set_study_ids = train_set['study_id']
|
@@ -1228,7 +1228,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1228 |
train_set = Subset(train_set, indices)
|
1229 |
|
1230 |
# Validation set:
|
1231 |
-
|
|
|
|
|
1232 |
study_ids = json.load(f)
|
1233 |
val_set = dataset['validate']
|
1234 |
val_set_study_ids = val_set['study_id']
|
@@ -1240,7 +1242,9 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
1240 |
val_set = Subset(val_set, indices)
|
1241 |
|
1242 |
# Test set:
|
1243 |
-
|
|
|
|
|
1244 |
study_ids = json.load(f)
|
1245 |
test_set = dataset['test']
|
1246 |
test_set_study_ids = test_set['study_id']
|
|
|
143 |
|
144 |
# assert isinstance(self.config.time_delta_monotonic_inversion, bool)
|
145 |
|
146 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tables.json')
|
147 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='tables.json')
|
148 |
+
with open(path, 'r') as f:
|
149 |
self.tables = json.load(f)
|
150 |
|
151 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lookup_tables.json')
|
152 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='lookup_tables.json')
|
153 |
+
with open(path, 'r') as f:
|
154 |
self.luts = json.load(f)
|
155 |
+
|
156 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'token_type_ids.json')
|
157 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='token_type_ids.json')
|
158 |
+
with open(path, 'r') as f:
|
159 |
self.token_type_to_token_type_id = json.load(f)
|
160 |
|
161 |
self.tables = {k: self.tables[k] for k in self.config.tables_filter}
|
|
|
189 |
|
190 |
self.post_init()
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
# @classmethod
|
193 |
# def from_encoder_decoder_pretrained(
|
194 |
# cls,
|
|
|
1125 |
|
1126 |
# Train set:
|
1127 |
if not test_set_only:
|
1128 |
+
|
1129 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_train_study_ids.json')
|
1130 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_train_study_ids.json')
|
1131 |
+
with open(path, 'r') as f:
|
1132 |
study_ids = json.load(f)
|
1133 |
train_set = dataset['train']
|
1134 |
train_set_study_ids = train_set['study_id']
|
|
|
1143 |
|
1144 |
# Validation set:
|
1145 |
if not test_set_only:
|
1146 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_validate_study_ids.json')
|
1147 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_validate_study_ids.json')
|
1148 |
+
with open(path, 'r') as f:
|
1149 |
study_ids = json.load(f)
|
1150 |
val_set = dataset['validate']
|
1151 |
val_set_study_ids = val_set['study_id']
|
|
|
1159 |
val_set = None
|
1160 |
|
1161 |
# Test set:
|
1162 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{study_id_split}_test_study_ids.json')
|
1163 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename=f'{study_id_split}_test_study_ids.json')
|
1164 |
+
with open(path, 'r') as f:
|
1165 |
study_ids = json.load(f)
|
1166 |
test_set = dataset['test']
|
1167 |
test_set_study_ids = test_set['study_id']
|
|
|
1214 |
dataset = datasets.load_from_disk(dataset_path)
|
1215 |
|
1216 |
# Train set:
|
1217 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_train_study_ids.json')
|
1218 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_train_study_ids.json')
|
1219 |
+
with open(path, 'r') as f:
|
1220 |
study_ids = json.load(f)
|
1221 |
train_set = dataset['train']
|
1222 |
train_set_study_ids = train_set['study_id']
|
|
|
1228 |
train_set = Subset(train_set, indices)
|
1229 |
|
1230 |
# Validation set:
|
1231 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_validate_study_ids.json')
|
1232 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_validate_study_ids.json')
|
1233 |
+
with open(path, 'r') as f:
|
1234 |
study_ids = json.load(f)
|
1235 |
val_set = dataset['validate']
|
1236 |
val_set_study_ids = val_set['study_id']
|
|
|
1242 |
val_set = Subset(val_set, indices)
|
1243 |
|
1244 |
# Test set:
|
1245 |
+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'mimic_cxr_jpg_test_study_ids.json')
|
1246 |
+
path = path if os.path.exists(path) else hf_hub_download(repo_id='aehrc/cxrmate-ed', filename='mimic_cxr_jpg_test_study_ids.json')
|
1247 |
+
with open(path, 'r') as f:
|
1248 |
study_ids = json.load(f)
|
1249 |
test_set = dataset['test']
|
1250 |
test_set_study_ids = test_set['study_id']
|