anicolson commited on
Commit
6ab63da
1 Parent(s): b7f4ef0

Delete records.py

Browse files
Files changed (1) hide show
  1. records.py +0 -369
records.py DELETED
@@ -1,369 +0,0 @@
1
- import functools
2
- import os
3
- import re
4
- from collections import OrderedDict
5
- from typing import Dict, List, Optional
6
-
7
- import duckdb
8
- import pandas as pd
9
- import torch
10
-
11
- from .tables import ed_cxr_token_type_ids, ed_module_tables, mimic_cxr_tables
12
-
13
-
14
- def mimic_cxr_text_path(dir, subject_id, study_id, ext='txt'):
15
- return os.path.join(dir, 'p' + str(subject_id)[:2], 'p' + str(subject_id),
16
- 's' + str(study_id) + '.' + ext)
17
-
18
- def format(text):
19
- # Remove newline, tab, repeated whitespaces, and leading and trailing whitespaces:
20
- text = re.sub(r'\n|\t', ' ', text)
21
- text = re.sub(r'\s+', ' ', text)
22
- text = text.strip()
23
- return text
24
-
25
-
26
- def rgetattr(obj, attr, *args):
27
- def _getattr(obj, attr):
28
- return getattr(obj, attr, *args)
29
- return functools.reduce(_getattr, [obj] + attr.split('.'))
30
-
31
-
32
- def df_to_tensor_index_columns(
33
- df: pd.DataFrame,
34
- tensor: torch.Tensor,
35
- group_idx_to_y_idx: Dict,
36
- groupby: str,
37
- index_columns: List[str],
38
- ):
39
- """
40
- Converts a dataframe with index columns to a tensor, where each index of the y-axis is determined by the
41
- 'groupby' column.
42
- """
43
- assert len(group_idx_to_y_idx) == tensor.shape[0]
44
- all_columns = index_columns + [groupby]
45
- y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows() for i in index_columns if row[i] == row[i]]
46
- x_indices = [row[i] for _, row in df[all_columns].iterrows() for i in index_columns if row[i] == row[i]]
47
- tensor[y_indices, x_indices] = 1.0
48
- return tensor
49
-
50
-
51
- def df_to_tensor_value_columns(
52
- df: pd.DataFrame,
53
- tensor: torch.Tensor,
54
- group_idx_to_y_idx: Dict,
55
- groupby: str,
56
- value_columns: List[str],
57
- value_column_to_idx: Dict,
58
- ):
59
- """
60
- Converts a dataframe with value columns to a tensor, where each index of the y-axis is determined by the
61
- 'groupby' column. The x-index is determined by a dictionary using the column name.
62
- """
63
- assert len(group_idx_to_y_idx) == tensor.shape[0]
64
- all_columns = value_columns + [groupby]
65
- y_indices = [group_idx_to_y_idx[row[groupby]] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
66
- x_indices = [value_column_to_idx[i] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
67
- element_value = [row[i] for _, row in df[all_columns].iterrows() for i in value_columns if row[i] == row[i]]
68
- tensor[y_indices, x_indices] = torch.tensor(element_value, dtype=tensor.dtype)
69
- return tensor
70
-
71
-
72
- class EDCXRSubjectRecords:
73
- def __init__(
74
- self,
75
- database_path: str,
76
- dataset_dir: Optional[str] = None,
77
- reports_dir: Optional[str] = None,
78
- token_type_ids_starting_idx: Optional[int] = None,
79
- time_delta_map = lambda x: x,
80
- debug: bool = False
81
- ):
82
-
83
- self.database_path = database_path
84
- self.dataset_dir = dataset_dir
85
- self.reports_dir = reports_dir
86
- self.time_delta_map = time_delta_map
87
- self.debug = debug
88
-
89
- self.connect = duckdb.connect(self.database_path, read_only=True)
90
-
91
- self.streamlit_flag = False
92
-
93
- self.clear_start_end_times()
94
-
95
- # Module configurations:
96
- self.ed_module_tables = ed_module_tables
97
- self.mimic_cxr_tables = mimic_cxr_tables
98
-
99
- lut_info = self.connect.sql("FROM lut_info").df()
100
-
101
- for k, v in (self.ed_module_tables | self.mimic_cxr_tables).items():
102
- if v.load and (v.value_columns or v.index_columns):
103
- v.value_column_to_idx = {}
104
- if v.index_columns:
105
- v.total_indices = lut_info[lut_info['table_name'] == k]['end_index'].item() + 1
106
- else:
107
- v.total_indices = 0
108
- for i in v.value_columns:
109
- v.value_column_to_idx[i] = v.total_indices
110
- v.total_indices += 1
111
-
112
- # Token type identifiers:
113
- self.token_type_to_token_type_id = ed_cxr_token_type_ids
114
- if token_type_ids_starting_idx is not None:
115
- self.token_type_to_token_type_id = {k: v + token_type_ids_starting_idx for k, v in self.token_type_to_token_type_id.items()}
116
-
117
- def __len__(self):
118
- return len(self.subject_ids)
119
-
120
- def clear_start_end_times(self):
121
- self.start_time, self.end_time = None, None
122
-
123
- def admission_ed_stay_ids(self, hadm_id):
124
- if hadm_id:
125
- return self.connect.sql(f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id} AND hadm_id = {hadm_id}').df()['stay_id'].tolist()
126
- else:
127
- return self.connect.sql(f'SELECT stay_id FROM edstays WHERE subject_id = {self.subject_id}').df()['stay_id'].tolist()
128
-
129
- def subject_study_ids(self):
130
- mimic_cxr = self.connect.sql(
131
- f'SELECT study_id, study_datetime FROM mimic_cxr WHERE subject_id = {self.subject_id}',
132
- ).df()
133
- if self.start_time and self.end_time:
134
- mimic_cxr = self.filter_admissions_by_time_span(mimic_cxr, 'study_datetime')
135
- mimic_cxr = mimic_cxr.drop_duplicates(subset=['study_id']).sort_values(by='study_datetime')
136
- return dict(zip(mimic_cxr['study_id'], mimic_cxr['study_datetime']))
137
-
138
- def load_ed_module(self, hadm_id=None, stay_id=None, reference_time=None):
139
- if not self.start_time and stay_id is not None:
140
- edstay = self.connect.sql(
141
- f"""
142
- SELECT intime, outtime
143
- FROM edstays
144
- WHERE stay_id = {stay_id}
145
- """
146
- ).df()
147
- self.start_time = edstay['intime'].item()
148
- self.end_time = edstay['outtime'].item()
149
- self.load_module(self.ed_module_tables, hadm_id=hadm_id, stay_id=stay_id, reference_time=reference_time)
150
-
151
- def load_mimic_cxr(self, study_id, reference_time=None):
152
- self.load_module(self.mimic_cxr_tables, study_id=study_id, reference_time=reference_time)
153
- if self.streamlit_flag:
154
- self.report_path = mimic_cxr_text_path(self.reports_dir, self.subject_id, study_id, 'txt')
155
-
156
- def load_module(self, module_dict, hadm_id=None, stay_id=None, study_id=None, reference_time=None):
157
- for k, v in module_dict.items():
158
-
159
- if self.streamlit_flag or v.load:
160
-
161
- query = f"FROM {k}"
162
-
163
- conditions = []
164
- if hasattr(self, 'subject_id') and v.subject_id_filter:
165
- conditions.append(f"subject_id={self.subject_id}")
166
- if v.hadm_id_filter:
167
- assert hadm_id is not None
168
- conditions.append(f"hadm_id={hadm_id}")
169
- if v.stay_id_filter:
170
- assert stay_id is not None
171
- conditions.append(f"stay_id={stay_id}")
172
- if v.study_id_filter:
173
- assert study_id is not None
174
- conditions.append(f"study_id={study_id}")
175
- if v.mimic_cxr_sectioned:
176
- assert study_id is not None
177
- conditions.append(f"study='s{study_id}'")
178
- ands = ['AND'] * (len(conditions) * 2 - 1)
179
- ands[0::2] = conditions
180
-
181
- if conditions:
182
- query += " WHERE ("
183
- query += ' '.join(ands)
184
- query += ")"
185
-
186
- df = self.connect.sql(query).df()
187
-
188
- if v.load:
189
-
190
- columns = [v.groupby] + v.time_columns + v.index_columns + v.text_columns + v.value_columns + v.target_sections
191
-
192
- # Use the starting time of the stay/admission as the time:
193
- if v.use_start_time:
194
- df['start_time'] = self.start_time
195
- columns += ['start_time']
196
-
197
- if reference_time is not None:
198
- time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'
199
-
200
- # Remove rows that are after the reference time to maintain causality:
201
- df = df[df[time_column] < reference_time]
202
-
203
- if self.streamlit_flag:
204
- setattr(self, k, df)
205
-
206
- if v.load:
207
- columns = list(dict.fromkeys(columns)) # remove repetitions.
208
- df = df.drop(columns=df.columns.difference(columns), axis=1)
209
- setattr(self, f'{k}_feats', df)
210
-
211
- def return_ed_module_features(self, stay_id, reference_time=None):
212
- example_dict = {}
213
- if stay_id is not None:
214
- self.load_ed_module(stay_id=stay_id, reference_time=reference_time)
215
- for k, v in self.ed_module_tables.items():
216
- if v.load:
217
-
218
- df = getattr(self, f'{k}_feats')
219
-
220
- if self.debug:
221
- example_dict.setdefault('ed_tables', []).append(k)
222
-
223
- if not df.empty:
224
-
225
- assert f'{k}_index_value_feats' not in example_dict
226
-
227
- # The y-index and the time for each group:
228
- time_column = v.time_columns[-1] if not v.use_start_time else 'start_time'
229
- group_idx_to_y_idx, group_idx_to_datetime = OrderedDict(), OrderedDict()
230
- groups = df.dropna(subset=v.index_columns + v.value_columns + v.text_columns, axis=0, how='all')
231
- groups = groups.drop_duplicates(subset=[v.groupby])[list(dict.fromkeys([v.groupby, time_column]))]
232
- groups = groups.reset_index(drop=True)
233
- for i, row in groups.iterrows():
234
- group_idx_to_y_idx[row[v.groupby]] = i
235
- group_idx_to_datetime[row[v.groupby]] = row[time_column]
236
-
237
- if (v.index_columns or v.value_columns) and group_idx_to_y_idx:
238
- example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
239
- if v.index_columns:
240
- example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
241
- df=df,
242
- tensor=example_dict[f'{k}_index_value_feats'],
243
- group_idx_to_y_idx=group_idx_to_y_idx,
244
- groupby=v.groupby,
245
- index_columns=v.index_columns,
246
- )
247
- if v.value_columns:
248
- example_dict[f'{k}_index_value_feats'] = df_to_tensor_value_columns(
249
- df=df,
250
- tensor=example_dict[f'{k}_index_value_feats'],
251
- group_idx_to_y_idx=group_idx_to_y_idx,
252
- groupby=v.groupby,
253
- value_columns=v.value_columns,
254
- value_column_to_idx=v.value_column_to_idx
255
- )
256
-
257
- example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
258
- [example_dict[f'{k}_index_value_feats'].shape[0]],
259
- self.token_type_to_token_type_id[k],
260
- dtype=torch.long,
261
- )
262
-
263
- event_times = list(group_idx_to_datetime.values())
264
- assert all([i == i for i in event_times])
265
- time_delta = [self.compute_time_delta(i, reference_time) for i in event_times]
266
- example_dict[f'{k}_index_value_time_delta'] = torch.tensor(time_delta)[:, None]
267
-
268
- assert example_dict[f'{k}_index_value_feats'].shape[0] == example_dict[f'{k}_index_value_time_delta'].shape[0]
269
-
270
- if v.text_columns:
271
- for j in group_idx_to_y_idx.keys():
272
- group_text = df[df[v.groupby] == j]
273
- for i in v.text_columns:
274
-
275
- column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
276
-
277
- if column_text:
278
-
279
- example_dict.setdefault(f'{k}_{i}', []).append(f"{', '.join(column_text)}.")
280
-
281
- event_times = group_text[time_column].iloc[0]
282
- time_delta = self.compute_time_delta(event_times, reference_time, to_tensor=False)
283
- example_dict.setdefault(f'{k}_{i}_time_delta', []).append(time_delta)
284
-
285
- return example_dict
286
-
287
- def return_mimic_cxr_features(self, study_id, reference_time=None):
288
- example_dict = {}
289
- if study_id is not None:
290
- self.load_mimic_cxr(study_id=study_id, reference_time=reference_time)
291
- for k, v in self.mimic_cxr_tables.items():
292
- if v.load:
293
-
294
- df = getattr(self, f'{k}_feats')
295
-
296
- if self.debug:
297
- example_dict.setdefault('mimic_cxr_inputs', []).append(k)
298
-
299
- if not df.empty:
300
-
301
- # The y-index for each group:
302
- group_idx_to_y_idx = OrderedDict()
303
- groups = df.dropna(
304
- subset=v.index_columns + v.value_columns + v.text_columns + v.target_sections,
305
- axis=0,
306
- how='all'
307
- )
308
- groups = groups.drop_duplicates(subset=[v.groupby])[[v.groupby]]
309
- groups = groups.reset_index(drop=True)
310
- for i, row in groups.iterrows():
311
- group_idx_to_y_idx[row[v.groupby]] = i
312
-
313
- if v.index_columns and group_idx_to_y_idx:
314
-
315
- example_dict[f'{k}_index_value_feats'] = torch.zeros(len(group_idx_to_y_idx), v.total_indices)
316
- if v.index_columns:
317
- example_dict[f'{k}_index_value_feats'] = df_to_tensor_index_columns(
318
- df=df,
319
- tensor=example_dict[f'{k}_index_value_feats'],
320
- group_idx_to_y_idx=group_idx_to_y_idx,
321
- groupby=v.groupby,
322
- index_columns=v.index_columns,
323
- )
324
-
325
- example_dict[f'{k}_index_value_token_type_ids'] = torch.full(
326
- [example_dict[f'{k}_index_value_feats'].shape[0]],
327
- self.token_type_to_token_type_id[k],
328
- dtype=torch.long,
329
- )
330
-
331
- if v.text_columns:
332
- for j in group_idx_to_y_idx.keys():
333
- group_text = df[df[v.groupby] == j]
334
- for i in v.text_columns:
335
- column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
336
- if column_text:
337
- example_dict.setdefault(f'{i}', []).append(f"{', '.join(column_text)}.")
338
-
339
- if v.target_sections:
340
- for j in group_idx_to_y_idx.keys():
341
- group_text = df[df[v.groupby] == j]
342
- for i in v.target_sections:
343
- column_text = [format(k) for k in list(dict.fromkeys(group_text[i].tolist())) if k is not None]
344
- assert len(column_text) == 1
345
- example_dict[i] = column_text[-1]
346
-
347
- return example_dict
348
-
349
- def compute_time_delta(self, event_time, reference_time, denominator = 3600, to_tensor=True):
350
- """
351
- How to we transform time-delta inputs? It appears that minutes are used as the input to
352
- a weight matrix in "Self-Supervised Transformer for Sparse and Irregularly Sampled Multivariate
353
- Clinical Time-Series". This is almost confirmed by the CVE class defined here:
354
- https://github.com/sindhura97/STraTS/blob/main/strats_notebook.ipynb, where the input has
355
- a size of one.
356
- """
357
- time_delta = reference_time - event_time
358
- time_delta = time_delta.total_seconds() / (denominator)
359
- assert isinstance(time_delta, float), f'time_delta should be float, not {type(time_delta)}.'
360
- if time_delta < 0:
361
- raise ValueError(f'time_delta should be greater than or equal to zero, not {time_delta}.')
362
- time_delta = self.time_delta_map(time_delta)
363
- if to_tensor:
364
- time_delta = torch.tensor(time_delta)
365
- return time_delta
366
-
367
- def filter_admissions_by_time_span(self, df, time_column):
368
- return df[(df[time_column] > self.start_time) & (df[time_column] <= self.end_time)]
369
-