copy data from repo

Files changed:
- README.md +85 -3
- noise_classifier.py +72 -0
- predict_potassium_12lead.py +46 -0
- predict_potassium_1lead.py +44 -0
- preprocessing.py +147 -0
- utils/datasets.py +392 -0
- utils/ecg_utils.py +114 -0
- utils/models.py +247 -0
- utils/training_models.py +447 -0
README.md
CHANGED
@@ -1,3 +1,85 @@
# Kardio-Net: A Deep Learning Model for Predicting Serum Potassium Levels from Apple Watch ECG in ESRD Patients


### I. Input preparation

1. Prerequisites

Ensure that all ECG files are saved in *.npy format with the following shapes:
- (5000, 12) for 12-lead ECGs
- (500 * 30 seconds, 1), i.e. (15000, 1), for 30-second Apple Watch ECGs recorded at 500 Hz

Store these files in a single flat directory, and include a manifest CSV file in the same location to accompany them.


2. Manifest File Format

The manifest CSV file should include a header.
Each row corresponds to one ECG file with the following columns:

- filename: Name of the .npy file (including the .npy extension, since the dataset classes match manifest filenames against the files on disk).
- label: serum potassium label.
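For reference, a minimal manifest can be built with pandas. The rows below are hypothetical; the split column is an assumption needed only by predict_potassium_12lead.py, which selects split == "test".

```python
import pandas as pd

# Hypothetical example rows; filenames must match files in the ECG directory.
manifest = pd.DataFrame(
    {
        "filename": ["ecg_0001.npy", "ecg_0002.npy"],
        "label": [4.2, 5.1],        # serum potassium values
        "split": ["test", "test"],  # assumed; used by ECGDataset(split="test")
    }
)
manifest.to_csv("manifest.csv", index=False)
```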
### II. Inference


### For 12-Lead ECG

<!-- #region -->
1. Use predict_potassium_12lead.py to get the potassium level prediction.


2. Edit data_path and manifest_path and run the predict_potassium_12lead.py script.


3. Upon completion, a file named "dataloader_0_predictions.csv" will be saved in the same directory. This file contains the inference results ("preds" column) from the model.


4. Use generate_result.py to get the performance metric and figure.
<!-- #endregion -->

### For Single Lead ECG


### A. ECG preprocessing and segmentation

<!-- #region -->
1. Use preprocessing.py to denoise, normalize, and segment the ECGs into 5-second inputs.


2. Set the following paths in preprocessing.py:

- raw_directory = "path/to/raw_data_directory" # raw ecg folder for target task
- output_directory = "path/to/output_directory" # output folder for normalized ECGs
- manifest_path = "/path/to/manifest.csv" # Manifest file path
- output_path = "path/to/output_path" # Output path for segmented ECGs


3. Execute preprocessing.py.


4. If the ECG files are already normalized, you can execute the segmentation function only, as in the sketch below.
<!-- #endregion -->
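A minimal sketch of running only the segmentation step (step 4), assuming the ECGs in data_path are already normalized and the manifest has the MRN, filename, and TEST_RSLT columns that segmentation() expects:

```python
from preprocessing import segmentation

# Assumed paths; segment 5-second windows, advancing the start point 1 second at a time.
segmentation(
    data_path="path/to/normalized_ecgs",
    manifest_path="/path/to/manifest.csv",
    output_path="path/to/segmented_ecgs",
    length=5,  # segment length in seconds
    steps=1,   # step between segment start points, in seconds
)
```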

### B. Potassium regression model

<!-- #region -->
1. Use predict_potassium_1lead.py to get the potassium level prediction.


2. Edit data_path and manifest_path and run the predict_potassium_1lead.py script.


3. Upon completion, a file named "dataloader_0_predictions.csv" will be saved in the same directory. This file contains the inference results ("preds" column) from the model.


4. Use generate_result.py to get the performance metric and figure.
<!-- #endregion -->
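generate_result.py is referenced above but not included in this commit. As a stand-in sketch, the saved predictions can be scored directly, assuming the output CSV contains the model's "preds" column merged with the manifest's "label" column:

```python
import pandas as pd
from sklearn import metrics

df = pd.read_csv("dataloader_0_predictions.csv")
df = df.dropna(subset=["preds", "label"])  # rows missing either value can't be scored

mae = metrics.mean_absolute_error(df["label"], df["preds"])
r2 = metrics.r2_score(df["label"], df["preds"])
print(f"MAE: {mae:.3f}, R^2: {r2:.3f}")
```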
noise_classifier.py
ADDED
@@ -0,0 +1,72 @@
# An example of how to run inference using a model trained with cvair.
# We want to use a pretrained model to make predictions on a dataset of new examples.

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from utils.datasets import ECGSingleLeadDataset
from utils.models import EffNet
from utils.training_models import BinaryClassificationModel


def append_noise_predictions_to_manifest(data_path, manifest_path, weights_path):
    # Initialize a dataset
    test_ds = ECGSingleLeadDataset(
        data_path=data_path,
        manifest_path=manifest_path,
        update_manifest_func=None,
    )

    # Wrap the dataset in a dataloader
    test_dl = DataLoader(
        test_ds,
        num_workers=16,
        batch_size=512,
        drop_last=False,
        shuffle=False,
    )

    # Initialize the backbone model
    backbone = EffNet(input_channels=1, output_neurons=1)

    # Pass the backbone to a wrapper
    model = BinaryClassificationModel(backbone)

    # Load the pretrained weights
    weights = torch.load(weights_path)
    model.load_state_dict(weights)

    # Initialize a Trainer object
    trainer = Trainer(accelerator="gpu", devices=1)

    # Run inference
    trainer.predict(model, dataloaders=test_dl)

    # Read the predictions CSV file
    df = pd.read_csv('dataloader_0_predictions.csv')

    # Min-max normalize the noise predictions
    max_preds = df['preds'].max()
    min_preds = df['preds'].min()
    df['noise_preds_normal'] = (df['preds'] - min_preds) / (max_preds - min_preds)

    # Drop the original predictions column
    df.drop(columns='preds', inplace=True)

    # Save the modified dataframe to the original manifest path
    # (index=False so re-reading the manifest does not pick up an extra index column)
    df.to_csv(manifest_path, index=False)


if __name__ == "__main__":

    data_path = "/your/wearable/ecg/data path/"

    manifest_path = "manifest"

    weights_path = "model_noise_classifier.pt"

    append_noise_predictions_to_manifest(data_path, manifest_path, weights_path)
predict_potassium_12lead.py
ADDED
@@ -0,0 +1,46 @@
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from utils.datasets import ECGDataset
from utils.models import EffNet
from utils.training_models import RegressionModel


# +
# This is the path where your data samples are stored.
data_path = "your/ecg/data/folder"

# This is the path where your manifest, containing filenames for inference to be run on, is stored.
manifest_path = 'your/manifest/path'
# -


# Initialize a dataset that contains the examples you want to run prediction on.
test_ds = ECGDataset(
    split="test",
    data_path=data_path,
    manifest_path=manifest_path,
    update_manifest_func=None,
)

# Wrap the dataset in a dataloader to handle batching and multithreading.
test_dl = DataLoader(
    test_ds,
    num_workers=16,
    batch_size=256,
    drop_last=False,
    shuffle=False,
)

# Initialize the "backbone", the core model weights that will act on the data.
backbone = EffNet(input_channels=12, output_neurons=1)

model = RegressionModel(backbone)

weights = torch.load("model_12_lead.pt")
print(model.load_state_dict(weights))

# +
trainer = Trainer(accelerator="gpu", devices=1)

trainer.predict(model, dataloaders=test_dl)
predict_potassium_1lead.py
ADDED
@@ -0,0 +1,44 @@
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from utils.datasets import ECGSingleLeadDataset
from utils.models import EffNet
from utils.training_models import RegressionModel

# +
# This is the path where your data samples are stored.
data_path = "your/ecg/data/folder"

# This is the path where your manifest, containing filenames for inference to be run on, is stored.
manifest_path = 'your/manifest/path'
# -

# Initialize a dataset that contains the examples you want to run prediction on.
test_ds = ECGSingleLeadDataset(
    data_path=data_path,
    manifest_path=manifest_path,
    update_manifest_func=None,
)

# Wrap the dataset in a dataloader to handle batching and multithreading.
test_dl = DataLoader(
    test_ds,
    num_workers=16,
    batch_size=512,
    drop_last=False,
    shuffle=False,
)

# +
backbone = EffNet()

model = RegressionModel(backbone)
# -

weights = torch.load("model_single_lead_5seconds_length.pt")
print(model.load_state_dict(weights))

# +
trainer = Trainer(accelerator="gpu", devices=1)

trainer.predict(model, dataloaders=test_dl)
preprocessing.py
ADDED
@@ -0,0 +1,147 @@
#!/usr/bin/env python
# coding: utf-8
# %%
import numpy as np
import os
import pandas as pd  # needed by segmentation(); missing from the original imports
from pathlib import Path
from tqdm import tqdm
from utils.ecg_utils import (
    remove_baseline_wander,
    wavelet_denoise_signal,
    plot_12_lead_ecg,
)


# %%
def calculate_means_stds(npy_directory, n):

    npy_directory = Path(npy_directory)
    filelist = os.listdir(npy_directory)
    np.random.shuffle(filelist)

    full_batch = np.zeros((n, 5000, 12))
    count = 0

    for i, npy_filename in enumerate(tqdm(filelist[:n])):
        npy_filepath = npy_directory / npy_filename
        ekg_numpy_array = np.load(npy_filepath)

        if ekg_numpy_array.shape[0] != 5000:
            continue

        full_batch[count] = ekg_numpy_array
        count += 1

    full_batch = full_batch[:count]  # Trim the array to remove unused entries
    ecg_means = np.mean(full_batch, axis=(0, 1))
    ecg_stds = np.std(full_batch, axis=(0, 1))

    if ecg_means.shape[0] == ecg_stds.shape[0] == 12:
        print('Shape of mean and std for ECG normalization are correct!')

    return ecg_means, ecg_stds


# %%
# Run the function on every file in the raw directory.
def ecg_denoising(raw_directory, output_directory, ecg_means, ecg_stds, sampling_frequency=500):
    # Note: the original signature used module-level names as default argument
    # values, which raises a NameError at definition time; plain parameters are
    # used instead. The 500 Hz default is an assumption consistent with the
    # expected inputs (5000 samples per 10-second 12-lead ECG).

    raw_directory = Path(raw_directory)
    output_directory = Path(output_directory)
    filelist = os.listdir(raw_directory)

    for i, filename in enumerate(tqdm(filelist)):

        # Signal processing
        ecg_filepath = raw_directory / filename
        ecg_numpy_array = np.load(ecg_filepath)
        # 1. Wandering baseline removal
        ecg_numpy_array = remove_baseline_wander(
            ecg_numpy_array, sampling_frequency=sampling_frequency
        )

        # Discrete wavelet transform denoising
        for lead in range(12):
            ecg_numpy_array[:, lead] = wavelet_denoise_signal(ecg_numpy_array[:, lead])

        # Lead-wise normalization with precomputed means and standard deviations
        ecg_numpy_array = (ecg_numpy_array - ecg_means) / ecg_stds

        np.save(output_directory / filename, ecg_numpy_array)

    return True


# %%
def segmentation(data_path, manifest_path, output_path, length, steps):
    manifest = pd.read_csv(manifest_path)

    data = []
    print('Starting segmenting ECG......')
    for index in tqdm(range(manifest.shape[0])):
        mrn = manifest['MRN'].iloc[index]  # MRN as column name for medical record number
        filename = manifest['filename'].iloc[index]  # filename as column name for ecg filename
        k = manifest['TEST_RSLT'].iloc[index]  # TEST_RSLT as column name for potassium level

        ecg_array = np.load(os.path.join(data_path, filename))
        ecg_array = ecg_array[:, 0]  # assume lead I is the first lead in npy file

        # Loop through the recording, advancing the start point by `steps` seconds:
        for start in range(0, len(ecg_array), 500 * steps):
            end = start + 500 * length  # 500 points for each second

            if start >= 0 and end <= len(ecg_array):
                sample = ecg_array[start:end]

                if len(sample) == 500 * length:
                    data.append({'mrn': mrn, 'original_filename': filename, 'ecg': sample, 'label': k})
                else:
                    print(f'Different sample size for {filename}: {len(sample)}')

    df = pd.DataFrame(data)
    df['filename'] = None

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    print('Saving segmented ECG......')
    for index, row in df.iterrows():
        original_filename = row['original_filename']
        ecg_array = row['ecg']
        new_file_name = f"{original_filename.rsplit('.', 1)[0]}_{index+1}.npy"
        new_file_path = os.path.join(output_path, new_file_name)

        df.at[index, 'filename'] = new_file_name
        np.save(new_file_path, ecg_array)

    df.drop(columns=['ecg'], inplace=True)
    df.to_csv(f'{output_path}/{length}seconds_length_{steps}seconds_step_ecg_manifest.csv')


# %%
if __name__ == "__main__":
    npy_directory = "ecg folder for entire database"
    n = 100000  # the number of ecgs used to calculate the mean and std

    print('Calculating ECG means and stds........')
    ecg_means, ecg_stds = calculate_means_stds(npy_directory, n)

    raw_directory = "path/to/raw_data_directory"  # raw ecg folder for target task
    output_directory = "path/to/output_directory"  # output ecg folder for target task

    print('Denoising and Normalizing ECGs........')
    ecg_denoising(raw_directory, output_directory, ecg_means, ecg_stds)


    data_path = output_directory  # Output directory from the above step
    manifest_path = "/path/to/manifest.csv"  # Manifest file path
    output_path = "path/to/output_path"  # Output path for segmented ECGs
    length = 5  # Length of each segment in seconds
    steps = 1  # Step between segment start points, in seconds

    # The original called process_ecg(), which is not defined; segmentation() is the function above.
    segmentation(data_path, manifest_path, output_path, length, steps)
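As a worked check of the windowing arithmetic in segmentation(): at 500 samples per second, a 30-second recording has 15000 samples, and with length=5 and steps=1 a window is kept while start + 2500 <= 15000, giving 26 segments:

```python
# Count 5-second windows over a 30-second, 500 Hz recording (15000 samples),
# stepping the start point by 1 second, exactly as segmentation() does.
n_samples, length, steps = 15000, 5, 1
starts = [s for s in range(0, n_samples, 500 * steps) if s + 500 * length <= n_samples]
print(len(starts))  # 26
```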
utils/datasets.py
ADDED
@@ -0,0 +1,392 @@
import os
import random
from pathlib import Path
from typing import Callable, List, Tuple, Optional, Iterable, Dict, Union
from typing_extensions import TypedDict, Unpack, Required, NotRequired

import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as TF
import wandb
from torch.utils.data import Dataset
import torch.nn as nn


def format_mrn(mrn):
    return str(mrn).strip().zfill(20)


class CedarsDatasetTypeAnnotations(TypedDict, total=False):
    """A dummy class used to make IDE autocomplete and tooltips work properly with how we pass **kwargs through in subclasses of CedarsDataset."""

    data_path: Required[Union[Path, str]]
    manifest_path: Required[Union[Path, str]]
    split: NotRequired[str]
    labels: NotRequired[Iterable[str]]
    extra_inputs: NotRequired[Iterable[str]]
    update_manifest_func: NotRequired[Callable[[pd.DataFrame], pd.DataFrame]]
    subsample: NotRequired[Union[Path, str]]
    augmentations: NotRequired[Union[Iterable[Callable[[torch.Tensor], torch.Tensor]], Callable[[dict], dict], nn.Module]]
    apply_augmentations_to: NotRequired[Iterable[str]]
    verify_existing: NotRequired[bool]
    drop_na_labels: NotRequired[bool]
    verbose: NotRequired[bool]


class CedarsDataset(Dataset):
    """
    Generic parent class for several different kinds of common datasets we use here at Cedars CVAIR.

    Expects to be used in a scenario where you have a big folder full of input examples (videos, ecgs, 3d arrays, images, etc.) and a big CSV that contains metadata and labels for those examples, called a 'manifest'.

    Args:
        data_path: Path to a directory full of files you want the dataset to load from.
        manifest_path: Path to a CSV or Parquet file containing the names, labels, and/or metadata of your files.
        split: Optional. Allows user to select which split of the manifest to use, assuming the presence of a categorical 'split' column. Defaults to None, meaning that the entire manifest is used by default.
        extra_inputs: Optional. A list of column names in the manifest that contain additional inputs to the model. Defaults to None.
        labels: Optional. Name(s) of column(s) in your manifest which contain training labels, in the order you want them returned. If set to None, the dataset will not return any labels, only filenames and inputs. Defaults to None.
        update_manifest_func: Optional. Allows user to pass in a function to preprocess the manifest after it is loaded, but before the dataset does anything to it.
        subsample: Optional. A number indicating how many examples to randomly subsample from the manifest. Defaults to None.
        verbose: Whether to print out progress statements when initializing. Defaults to True.
        augmentations: Optional. Can be a list of augmentation functions which take in a tensor and return a tensor, a single custom augmentation function which takes in a dict and returns a dict, or a single nn.Module. Defaults to None.
        apply_augmentations_to: Optional. A list of strings indicating which batch elements to apply augmentations to. Defaults to ("primary_input",).
    """

    def __init__(
        self,
        data_path,
        manifest_path=None,
        split=None,
        labels=None,
        extra_inputs=None,
        update_manifest_func=None,
        subsample=None,
        augmentations=None,
        apply_augmentations_to=("primary_input",),
        verify_existing=True,
        drop_na_labels=True,
        verbose=True,
    ):

        self.data_path = Path(data_path)
        self.augmentations = augmentations
        self.apply_augmentations_to = apply_augmentations_to
        self.extra_inputs = extra_inputs
        self.labels = labels

        if isinstance(self.augmentations, nn.Module):
            self.augmentations = [self.augmentations]

        if (self.labels is None) and verbose:
            print(
                "No label column names were provided, only filenames and inputs will be returned."
            )
        if (self.labels is not None) and isinstance(self.labels, str):
            self.labels = [self.labels]
        if (self.extra_inputs is not None) and isinstance(self.extra_inputs, str):
            self.extra_inputs = [self.extra_inputs]

        # Read manifest file
        if manifest_path is not None:
            self.manifest_path = Path(manifest_path)
        else:
            self.manifest_path = self.data_path / "manifest.csv"

        if self.manifest_path.exists():
            if self.manifest_path.suffix == ".csv":
                self.manifest = pd.read_csv(self.manifest_path, low_memory=False)
            elif self.manifest_path.suffix == ".parquet":
                self.manifest = pd.read_parquet(self.manifest_path)
        else:
            self.manifest = pd.DataFrame(
                {
                    "filename": os.listdir(self.data_path),
                }
            )

        # do manifest processing that's specific to a given task (different from update_manifest_func,
        # exists as a method overridden in child classes)
        self.manifest = self.process_manifest(self.manifest)

        # Apply user-provided update function to manifest
        if update_manifest_func is not None:
            self.manifest = update_manifest_func(self, self.manifest)

        # Usually set to "train", "val", or "test". If set to None, the entire manifest is used.
        if split is not None:
            self.manifest = self.manifest[self.manifest["split"] == split]
        if verbose:
            print(
                f"Manifest loaded. \nSplit: {split}\nLength: {len(self.manifest):,}"
            )

        # Make sure all files actually exist. This can be disabled for efficiency if
        # you have an especially large dataset
        if verify_existing and "filename" in self.manifest:
            old_len = len(self.manifest)
            existing_files = os.listdir(self.data_path)
            self.manifest = self.manifest[
                self.manifest["filename"].isin(existing_files)
            ]
            new_len = len(self.manifest)
            if verbose:
                print(
                    f"{old_len - new_len} files in the manifest are missing from {self.data_path}."
                )
        elif (not verify_existing) and verbose:
            print(
                f"self.verify_existing is set to False, so it's possible for the manifest to contain filenames which are not present in {data_path}"
            )

        # Option to subsample dataset for doing smaller, faster runs
        if subsample is not None:
            if isinstance(subsample, int):
                self.manifest = self.manifest.sample(n=subsample)
            else:
                self.manifest = self.manifest.sample(frac=subsample)
            if verbose:
                print(f"{subsample} examples subsampled.")

        # Make sure that there are no NaN labels
        if (self.labels is not None) and drop_na_labels:
            old_len = len(self.manifest)
            self.manifest = self.manifest.dropna(subset=self.labels)
            new_len = len(self.manifest)
            if verbose:
                print(
                    f"{old_len - new_len} examples contained NaN value(s) in their labels and were dropped."
                )
        elif (self.labels is not None) and (not drop_na_labels):
            print(
                "drop_na_labels is set to False, so it's possible for the manifest to contain NaN-valued labels."
            )

        # Save manifest to weights and biases run directory
        if wandb.run is not None:
            run_data_path = Path(wandb.run.dir).parent / "data"
            if not run_data_path.is_dir():
                run_data_path.mkdir()

            save_name = "manifest.csv"
            if split is not None:
                save_name = f"{split}_{save_name}"

            self.manifest.to_csv(run_data_path / save_name)

            if verbose:
                print(f"Copy of manifest saved to {run_data_path}")

    def __len__(self) -> int:
        return len(self.manifest)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        output = {}
        row = self.manifest.iloc[index]
        if "filename" in row:
            output["filename"] = row["filename"]
        if self.labels is not None:
            output["labels"] = torch.FloatTensor(row[self.labels])
        file_results = self.read_file(self.data_path / output["filename"], row)
        if isinstance(file_results, dict):
            output.update(file_results)
        else:
            output["primary_input"] = file_results

        if self.extra_inputs is not None:
            output["extra_inputs"] = row["extra_inputs"]

        if self.augmentations is not None:
            output = self.augment(output)

        return output

    def process_manifest(self, manifest: pd.DataFrame) -> pd.DataFrame:
        if "mrn" in manifest.columns:
            manifest["mrn"] = manifest["mrn"].apply(format_mrn)
        if "study_date" in manifest.columns:
            manifest["study_date"] = pd.to_datetime(manifest["study_date"])
        if "dob" in manifest.columns:
            manifest["dob"] = pd.to_datetime(
                manifest["dob"], infer_datetime_format=True, errors="coerce"
            )
        if ("study_date" in manifest.columns) and ("dob" in manifest.columns):
            manifest["study_age"] = (
                manifest["study_date"] - manifest["dob"]
            ) / np.timedelta64(1, "Y")
        return manifest

    def augment(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

        if isinstance(self.augmentations, Iterable):
            # would use torch.stack here for cleanliness, but it seems that torchvision
            # transforms v1's claims about supporting "arbitrary leading dimensions" are
            # hogwash. they only support up to 4D. so we have to concatenate along the
            # channel dimension, then apply the augmentations, then split along the channel
            # dimension.
            augmentable_inputs = torch.cat(
                [output_dict[key] for key in self.apply_augmentations_to], dim=0
            )  # (C*N, T, H, W)

            for aug in self.augmentations:
                augmentable_inputs = aug(augmentable_inputs)

            place = 0
            for i, key in enumerate(self.apply_augmentations_to):
                n_channels = output_dict[key].shape[0]
                output_dict[key] = augmentable_inputs[place:place + n_channels]
                place += n_channels

        elif isinstance(self.augmentations, Callable):
            output_dict = self.augmentations(output_dict)

        else:
            raise Exception(
                "self.augmentations must be either an Iterable of augmentations or a single custom augmentation function."
            )

        return output_dict

    def read_file(self, filepath: Path, row: Optional[pd.Series] = None) -> torch.Tensor:
        raise NotImplementedError


class ECGDataset(CedarsDataset):
    def __init__(
        self,
        # CedarsDataset params
        data_path: Union[Path, str],
        manifest_path: Union[Path, str] = None,
        split: str = None,
        labels: Union[List[str], str] = None,
        update_manifest_func: Callable = None,
        subsample: float = None,
        verbose: bool = True,
        verify_existing: bool = True,
        drop_na_labels: bool = True,
        # ECGDataset params
        leads: List[str] = None,
        random_lead: bool = False,  # New parameter for random lead selection
        data_length: int = 5000,
        **kwargs,
    ):
        """
        Args:
            leads: List[str] -- which leads you want passed to the model. Defaults to all 12.
        """

        super().__init__(
            data_path=data_path,
            manifest_path=manifest_path,
            split=split,
            labels=labels,
            update_manifest_func=update_manifest_func,
            subsample=subsample,
            verbose=verbose,
            verify_existing=verify_existing,
            drop_na_labels=drop_na_labels,
            **kwargs,
        )

        self.lead_order = [
            "I",
            "II",
            "III",
            "aVR",
            "aVL",
            "aVF",
            "V1",
            "V2",
            "V3",
            "V4",
            "V5",
            "V6",
        ]
        self.leads = leads
        if self.leads is None:
            self.leads = self.lead_order
        if isinstance(self.leads, str):
            self.leads = [self.leads]

        if "first_lead_only" in kwargs:
            raise (
                Exception(
                    '"first_lead_only" has been deprecated. Please pass leads=["I"] \
                    instead if you would like to train on only the first lead.'
                )
            )

        self.random_lead = random_lead  # Storing the random_lead attribute

        self.data_length = data_length

    def read_file(self, filepath, row=None):
        # ECGs are usually stored as .npy files.
        file = np.load(filepath)
        if file.shape[0] != 12:
            file = file.T
        file = torch.tensor(file).float()

        # Slice the data to the specified length
        file = file[:, :self.data_length]

        if self.random_lead:
            lead_idx = random.choice(range(12))
            file = file[lead_idx:lead_idx + 1]  # Select the random lead
        else:
            channels = [self.lead_order.index(lead) for lead in self.leads]
            file = file[channels]

        # Final shape should be NumLeads x Time.
        return file


class ECGSingleLeadDataset(CedarsDataset):
    def __init__(
        self,
        # CedarsDataset params
        data_path: Union[Path, str],
        manifest_path: Union[Path, str] = None,
        labels: Union[List[str], str] = None,
        update_manifest_func: Callable = None,
        subsample: float = None,
        verbose: bool = True,
        verify_existing: bool = True,
        drop_na_labels: bool = True,
        **kwargs,
    ):
        """Single-lead variant: each .npy file holds one lead and is returned with a leading channel dimension."""

        super().__init__(
            data_path=data_path,
            manifest_path=manifest_path,
            labels=labels,
            update_manifest_func=update_manifest_func,
            subsample=subsample,
            verbose=verbose,
            verify_existing=verify_existing,
            drop_na_labels=drop_na_labels,
            **kwargs,
        )

    def read_file(self, filepath, row=None):
        # ECGs are usually stored as .npy files.
        try:
            file = np.load(filepath)
        except Exception as e:
            print(filepath)
            print(e)
            raise  # re-raise so a corrupt file fails loudly instead of a NameError below

        file = torch.tensor(file).float().unsqueeze(0)

        # Final shape should be NumLeads x Time.
        return file
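A minimal usage sketch for ECGDataset (paths hypothetical): given a manifest with filename and label columns and a folder of (5000, 12) .npy files, requesting a single lead should yield batches shaped (batch, 1, 5000):

```python
from torch.utils.data import DataLoader
from utils.datasets import ECGDataset

ds = ECGDataset(
    data_path="path/to/ecgs",             # folder of (5000, 12) .npy files
    manifest_path="path/to/manifest.csv",
    labels="label",                       # column(s) returned as the training target
    leads=["I"],                          # a subset of leads, or None for all 12
)
dl = DataLoader(ds, batch_size=4, shuffle=False)
batch = next(iter(dl))
print(batch["primary_input"].shape)  # torch.Size([4, 1, 5000])
print(batch["labels"].shape)         # torch.Size([4, 1])
```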
utils/ecg_utils.py
ADDED
@@ -0,0 +1,114 @@
# +
import base64
import struct

import numpy as np
import xmltodict
from scipy.ndimage import median_filter as scipy_ndimage_median_filter
import matplotlib.pyplot as plt
import pywt
# -

lead_order = [
    "I",
    "II",
    "III",
    "aVR",
    "aVL",
    "aVF",
    "V1",
    "V2",
    "V3",
    "V4",
    "V5",
    "V6",
]


def plot_12_lead_ecg(ecg_array, output_filename=None):
    """
    Plot each lead of the 12-lead ECG, and save the plot to a file.
    All leads share the x axis, but each lead gets its own chart.
    """
    fig, axs = plt.subplots(12, 1, sharex=True, figsize=(16, 9))
    for lead, lead_name in enumerate(lead_order):
        axs[lead].plot(ecg_array[:, lead])
        axs[lead].set_ylabel(str(lead_name))
    if output_filename is not None:
        plt.savefig(output_filename)
    plt.show()
    plt.close()


def get_median_filter_width(sampling_frequency, duration):
    res = int(sampling_frequency * duration)
    res += (res % 2) - 1  # needs to be an odd number
    return res


def remove_baseline_wander(waveform: np.ndarray, sampling_frequency: int) -> np.ndarray:

    """
    Remove baseline wander from ECG NPYs.
    de Chazal et al. used two median filters to remove baseline wander.
    Median filters take the median value of a sliding window of a specified size.
    One median filter of 200-ms width to remove QRS complexes and P-waves, and another of
    600-ms width to remove T-waves.
    Do one filter and then the next filter. Then take the result and subtract it from the original signal.
    https://pubmed.ncbi.nlm.nih.gov/15248536/
    Example of median filter:
    medfilt([2,6,5,4,0,3,5,7,9,2,0,1], 5) -> [ 2. 4. 4. 4. 4. 4. 5. 5. 5. 2. 1. 0.]
    >>> np.median([0, 0, 2, 6, 5])
    2.0
    >>> np.median([0, 2, 6, 5, 4])
    4.0

    """

    # Depending on the sampling frequency, the widths of the convolutional median filters change
    filter_widths = [
        get_median_filter_width(sampling_frequency, duration) for duration in [0.2, 0.6]
    ]
    filter_widths = np.array(filter_widths, dtype="int")

    # make a copy of the original signal
    original_waveform = waveform.copy()

    # apply median filters one by one on top of each other
    for filter_width in filter_widths:
        waveform = scipy_ndimage_median_filter(
            waveform, size=(filter_width, 1), mode="constant", cval=0.0
        )
    waveform = original_waveform - waveform  # finally subtract from the original signal
    return waveform


def wavelet_denoise_signal(
    waveform: np.ndarray,
    dwt_transform: str = "bior4.4",
    dlevels: int = 9,
    cutoff_low: int = 1,
    cutoff_high: int = 7,
) -> np.ndarray:

    # cutoff_low determines how flat you want the overall baseline to be.
    # Higher means a flatter baseline.
    # cutoff_high determines, within the small segments, how much you
    # want to suppress the squiggliness. A lower cutoff_high
    # suppresses more squiggliness but also suppresses R wave morphology.

    coefficients = pywt.wavedec(
        waveform, dwt_transform, level=dlevels
    )  # wavelet transform 'bior4.4'
    # scale 0 to cutoff_low
    for low_cutoff_value in range(0, cutoff_low):
        coefficients[low_cutoff_value] = np.multiply(
            coefficients[low_cutoff_value], [0.0]
        )
    # scale cutoff_high to end
    for high_cutoff_value in range(cutoff_high, len(coefficients)):
        coefficients[high_cutoff_value] = np.multiply(
            coefficients[high_cutoff_value], [0.0]
        )
    waveform = pywt.waverec(coefficients, dwt_transform)  # inverse wavelet transform
    return waveform
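A small sketch of the two denoising utilities on a synthetic 12-lead array, assuming a 500 Hz sampling rate:

```python
import numpy as np
from utils.ecg_utils import remove_baseline_wander, wavelet_denoise_signal

# Synthetic 10-second, 500 Hz, 12-lead array with a slow linear baseline drift.
t = np.linspace(0, 10, 5000)
ecg = np.sin(2 * np.pi * 1.2 * t)[:, None].repeat(12, axis=1)
ecg += 0.5 * t[:, None]  # baseline wander

flat = remove_baseline_wander(ecg, sampling_frequency=500)  # (5000, 12)
denoised = np.stack(
    [wavelet_denoise_signal(flat[:, lead]) for lead in range(12)], axis=1
)
print(denoised.shape)  # (5000, 12)
```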
utils/models.py
ADDED
@@ -0,0 +1,247 @@
from torch import nn
import torch
import numpy as np  # needed by the embedding_shift branch in EffNet.forward
from functools import reduce
from operator import __add__
import torch.nn.functional as F
from collections import OrderedDict
from typing import Callable, List
from torch import Tensor


class EffNet(nn.Module):
    # lightly retouched version of John's EffNet to add clean support for multiple output
    # layer designs as well as single-lead inputs
    def __init__(
        self,
        num_extra_inputs: int = 0,
        output_neurons: int = 1,
        channels: List[int] = (32, 16, 24, 40, 80, 112, 192, 320, 1280),
        depth: List[int] = (1, 2, 2, 3, 3, 3, 3),
        dilation: int = 2,
        stride: int = 8,
        expansion: int = 6,
        embedding_hook: bool = False,
        input_channels: int = 1,
        verbose: bool = False,
        embedding_shift: bool = False,
    ):
        super().__init__()

        self.input_channels = input_channels
        self.channels = channels
        self.output_nerons = output_neurons

        # backwards compatibility change to prevent the addition of the output_neurons param
        # from breaking people's existing EffNet initializations
        if len(self.channels) == 10:
            self.output_nerons = self.channels[9]
            print(
                "DEPRECATION WARNING: instead of controlling the number of output neurons by changing the 10th item in the channels parameter, use the new output_neurons parameter instead."
            )

        self.depth = depth
        self.expansion = expansion
        self.stride = stride
        self.dilation = dilation
        self.embedding_hook = embedding_hook
        self.embedding_shift = embedding_shift

        if verbose:
            print("\nEffNet Parameters:")
            print(f"{self.input_channels=}")
            print(f"{self.channels=}")
            print(f"{self.output_nerons=}")
            print(f"{self.depth=}")
            print(f"{self.expansion=}")
            print(f"{self.stride=}")
            print(f"{self.dilation=}")
            print(f"{self.embedding_hook=}")
            print("\n")

        self.stage1 = nn.Conv1d(
            self.input_channels,
            self.channels[0],
            kernel_size=3,
            stride=stride,
            padding=1,
            dilation=dilation,
        )  # 1 conv

        self.b0 = nn.BatchNorm1d(self.channels[0])

        self.stage2 = MBConv(
            self.channels[0], self.channels[1], self.expansion, self.depth[0], stride=2
        )

        self.stage3 = MBConv(
            self.channels[1], self.channels[2], self.expansion, self.depth[1], stride=2
        )

        self.Pool = nn.MaxPool1d(3, stride=1, padding=1)

        self.stage4 = MBConv(
            self.channels[2], self.channels[3], self.expansion, self.depth[2], stride=2
        )

        self.stage5 = MBConv(
            self.channels[3], self.channels[4], self.expansion, self.depth[3], stride=2
        )

        self.stage6 = MBConv(
            self.channels[4], self.channels[5], self.expansion, self.depth[4], stride=2
        )

        self.stage7 = MBConv(
            self.channels[5], self.channels[6], self.expansion, self.depth[5], stride=2
        )

        self.stage8 = MBConv(
            self.channels[6], self.channels[7], self.expansion, self.depth[6], stride=2
        )

        self.stage9 = nn.Conv1d(self.channels[7], self.channels[8], kernel_size=1)
        self.AAP = nn.AdaptiveAvgPool1d(1)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.3)
        self.num_extra_inputs = num_extra_inputs

        self.fc = nn.Linear(self.channels[8] + num_extra_inputs, self.output_nerons)

        self.fc.bias.data[0] = 0.275

    def forward(self, x: Tensor) -> Tensor:
        if self.num_extra_inputs > 0:
            x, extra_inputs = x

        x = self.b0(self.stage1(x))
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.Pool(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.Pool(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)
        x = self.act(self.AAP(x)[:, :, 0])
        if self.embedding_hook:
            return x

        else:

            if self.embedding_shift:
                delta_embedding_array = np.load('/workspace/imin/applewatch_potassium/delta_embedding_poolaverage_5second_to_5second.npy')
                delta_embedding_tensor = torch.tensor(delta_embedding_array, device='cuda')
                x += delta_embedding_tensor

            x = self.drop(x)

            if self.num_extra_inputs > 0:
                x = torch.cat((x, extra_inputs), 1)

            x = self.fc(x)
            return x


class Bottleneck(nn.Module):
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        expansion: int,
        activation: Callable,
        stride: int = 1,
        padding: int = 1,
    ):
        super().__init__()

        self.stride = stride
        self.conv1 = nn.Conv1d(in_channel, in_channel * expansion, kernel_size=1)
        self.conv2 = nn.Conv1d(
            in_channel * expansion,
            in_channel * expansion,
            kernel_size=3,
            groups=in_channel * expansion,
            padding=padding,
            stride=stride,
        )
        self.conv3 = nn.Conv1d(
            in_channel * expansion, out_channel, kernel_size=1, stride=1
        )
        self.b0 = nn.BatchNorm1d(in_channel * expansion)
        self.b1 = nn.BatchNorm1d(in_channel * expansion)
        self.d = nn.Dropout()
        self.act = activation()

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            y = self.act(self.b0(self.conv1(x)))
            y = self.act(self.b1(self.conv2(y)))
            y = self.conv3(y)
            y = self.d(y)
            y = x + y
            return y
        else:
            y = self.act(self.b0(self.conv1(x)))
            y = self.act(self.b1(self.conv2(y)))
            y = self.conv3(y)
            return y


class MBConv(nn.Module):
    def __init__(
        self, in_channel, out_channels, expansion, layers, activation=nn.ReLU6, stride=2
    ):
        super().__init__()

        self.stack = OrderedDict()
        for i in range(0, layers - 1):
            self.stack["s" + str(i)] = Bottleneck(
                in_channel, in_channel, expansion, activation
            )

        self.stack["s" + str(layers + 1)] = Bottleneck(
            in_channel, out_channels, expansion, activation, stride=stride
        )

        self.stack = nn.Sequential(self.stack)

        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stack(x)
        return self.bn(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out)

        return out
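A quick shape check for EffNet as configured by the inference scripts; the 2500-sample input assumes a 5-second single-lead segment at 500 Hz (matching the preprocessing defaults):

```python
import torch
from utils.models import EffNet

backbone = EffNet(input_channels=1, output_neurons=1)
backbone.eval()  # put BatchNorm/Dropout in inference mode

x = torch.randn(8, 1, 2500)  # (batch, leads, samples): 5 s at 500 Hz
with torch.no_grad():
    y = backbone(x)
print(y.shape)  # torch.Size([8, 1])
```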
utils/training_models.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Iterable
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
import wandb
|
8 |
+
from sklearn import metrics as skl_metrics
|
9 |
+
import torchvision
|
10 |
+
import os
|
11 |
+
from pathlib import Path
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
|
15 |
+
class TrainingMetric:
|
16 |
+
def __init__(self, metric_func, metric_name, optimum=None):
|
17 |
+
self.func = metric_func
|
18 |
+
self.name = metric_name
|
19 |
+
self.optimum = optimum
|
20 |
+
|
21 |
+
def calc_metric(self, *args, **kwargs):
|
22 |
+
try:
|
23 |
+
return self.func(*args, **kwargs)
|
24 |
+
except ValueError as e:
|
25 |
+
return np.nan
|
26 |
+
|
27 |
+
def __call__(self, y_true, y_pred, labels=None, split=None, step_type=None) -> dict:
|
28 |
+
|
29 |
+
# if y_true is empty
|
30 |
+
if y_true.shape[0] == 0: # TODO: handle other cases
|
31 |
+
m = {
|
32 |
+
f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(None, yp)
|
33 |
+
for yp, l in zip(y_pred.T, labels)
|
34 |
+
}
|
35 |
+
return m
|
36 |
+
|
37 |
+
# Simple 1:1 y_true and y_pred are either shape=(batch, 1) or shape=(batch,)
|
38 |
+
if len(y_pred.shape) == 1 or (y_pred.shape[1] == 1 and y_true.shape[1] == 1):
|
39 |
+
m = {
|
40 |
+
f"{step_type}_{split}_{self.name}": self.calc_metric(
|
41 |
+
y_true.flatten(), y_pred.flatten()
|
42 |
+
)
|
43 |
+
}
|
44 |
+
|
45 |
+
# Multi-binary classification-like y_true and y_pred are shape=(batch, class)
|
46 |
+
elif y_true.shape[1] != 1 and y_pred.shape[1] != 1:
|
47 |
+
m = {
|
48 |
+
f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(yt, yp)
|
49 |
+
for yt, yp, l in zip(y_true.T, y_pred.T, labels)
|
50 |
+
}
|
51 |
+
|
52 |
+
# Multi-class classification-like y_true is shape=(batch, 1) or shape=(batch,) and y_pred is shape=(batch, class)
|
53 |
+
elif (len(y_true.shape) == 1 or y_true.shape[1] == 1) and y_pred.shape[1] != 1:
|
54 |
+
m = {
|
55 |
+
f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(
|
56 |
+
y_true.flatten() == i, yp
|
57 |
+
)
|
58 |
+
for i, (yp, l) in enumerate(
|
59 |
+
zip(y_pred.T, labels)
|
60 |
+
) # turn multi class into binary classification
|
61 |
+
}
|
62 |
+
|
63 |
+
return m
|
64 |
+
|
65 |
+
|
66 |
+
class CumulativeMetric(TrainingMetric):
|
67 |
+
|
68 |
+
"""Wraps a metric to apply to every class in output and calculate a cumulative value (like mean AUC)"""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
training_metric: TrainingMetric,
|
73 |
+
metric_func,
|
74 |
+
metric_name="cumulative",
|
75 |
+
optimum=None,
|
76 |
+
):
|
77 |
+
optimum = optimum or training_metric.optimum
|
78 |
+
metric_name = f"{metric_name}_{training_metric.name}"
|
79 |
+
super().__init__(metric_func, metric_name, optimum)
|
80 |
+
self.base_metric = training_metric
|
81 |
+
|
82 |
+
def __call__(self, y_true, y_pred, labels=None, split=None, step_type=None):
|
83 |
+
vals = list(self.base_metric(y_true, y_pred, labels, split, step_type).values())
|
84 |
+
|
85 |
+
m = {f"{step_type}_{split}_{self.name}": self.func(vals)}
|
86 |
+
return m
|
87 |
+
|
88 |
+
|
89 |
+
r2_metric = TrainingMetric(skl_metrics.r2_score, "r2", optimum="max")
|
90 |
+
roc_auc_metric = TrainingMetric(skl_metrics.roc_auc_score, "roc_auc", optimum="max")
|
91 |
+
accuracy_metric = TrainingMetric(skl_metrics.accuracy_score, "accuracy", optimum="max")
|
92 |
+
mae_metric = TrainingMetric(skl_metrics.mean_absolute_error, "mae", optimum="min")
|
93 |
+
pred_value_mean_metric = TrainingMetric(
|
94 |
+
lambda y_true, y_pred: np.mean(y_pred), "pred_value_mean"
|
95 |
+
)
|
96 |
+
pred_value_std_metric = TrainingMetric(
|
97 |
+
lambda y_true, y_pred: np.std(y_pred), "pred_value_std"
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
class TrainingModel(pl.LightningModule):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
model,
|
105 |
+
metrics: Iterable[TrainingMetric] = dict(),
|
106 |
+
tracked_metric=None,
|
107 |
+
early_stop_epochs=10,
|
108 |
+
checkpoint_every_epoch=False,
|
109 |
+
checkpoint_every_n_steps=None,
|
110 |
+
index_labels=None,
|
111 |
+
save_predictions_path=None,
|
112 |
+
lr=0.01,
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
self.epoch_preds = {"train": ([], []), "val": ([], [])}
|
116 |
+
self.epoch_losses = {"train": [], "val": []}
|
117 |
+
self.metrics = {}
|
118 |
+
self.metric_funcs = {m.name: m for m in metrics}
|
119 |
+
self.tracked_metric = f"epoch_val_{tracked_metric}"
|
120 |
+
self.best_tracked_metric = None
|
121 |
+
self.early_stop_epochs = early_stop_epochs
|
122 |
+
self.checkpoint_every_epoch = checkpoint_every_epoch
|
123 |
+
self.checkpoint_every_n_steps = checkpoint_every_n_steps
|
124 |
+
self.metrics["epochs_since_last_best"] = 0
|
125 |
+
self.m = model
|
126 |
+
self.training_steps = 0
|
127 |
+
self.steps_since_checkpoint = 0
|
128 |
+
self.labels = index_labels
|
129 |
+
if self.labels is not None and isinstance(self.labels, str):
|
130 |
+
self.labels = [self.labels]
|
131 |
+
if isinstance(save_predictions_path, str):
|
132 |
+
save_predictions_path = Path(save_predictions_path)
|
133 |
+
self.save_predictions_path = save_predictions_path
|
134 |
+
self.lr = lr
|
135 |
+
self.step_loss = (None, None)
|
136 |
+
|
137 |
+
self.log_path = Path(wandb.run.dir) if wandb.run is not None else None
|
138 |
+
|
139 |
+
def configure_optimizers(self):
|
140 |
+
return torch.optim.AdamW(self.parameters(), self.lr)
|
141 |
+
|
142 |
+
def forward(self, x: dict):
|
143 |
+
# if anything other than 'primary_input' and 'extra_inputs' is used,
|
144 |
+
# this function must be overridden
|
145 |
+
if 'extra_inputs' in x:
|
146 |
+
return self.m((x['primary_input'], x['extra_inputs']))
|
147 |
+
else:
|
148 |
+
return self.m(x['primary_input'])
|
149 |
+
|
150 |
+
def step(self, batch, step_type='train'):
|
151 |
+
batch = self.prepare_batch(batch)
|
152 |
+
y_pred = self.forward(batch)
|
153 |
+
|
154 |
+
if step_type != 'predict':
|
155 |
+
if 'labels' not in batch:
|
156 |
+
batch['labels'] = torch.empty(0)
|
157 |
+
loss = self.loss_func(y_pred, batch['labels'])
|
158 |
+
if torch.isnan(loss):
|
159 |
+
raise ValueError(loss)
|
160 |
+
|
161 |
+
self.log_step(step_type, batch['labels'], y_pred, loss)
|
162 |
+
|
163 |
+
return loss
|
164 |
+
else:
|
165 |
+
return y_pred
|
166 |
+
|
167 |
+
def prepare_batch(self, batch):
|
168 |
+
return batch
|
169 |
+
|
170 |
+
def training_step(self, batch, i):
|
171 |
+
return self.step(batch, "train")
|
172 |
+
|
173 |
+
def validation_step(self, batch, i):
|
174 |
+
return self.step(batch, "val")
|
175 |
+
|
176 |
+
def predict_step(self, batch, *args):
|
177 |
+
y_pred = self.step(batch, "predict")
|
178 |
+
return {"filename": batch["filename"], "prediction": y_pred.cpu().numpy()}
|
179 |
+
|
    def on_predict_epoch_end(self, results):

        for i, predict_results in enumerate(results):
            filename_df = pd.DataFrame(
                {
                    "filename": np.concatenate(
                        [batch["filename"] for batch in predict_results]
                    )
                }
            )

            if self.labels is not None:
                columns = [f"{class_name}_preds" for class_name in self.labels]
            else:
                columns = ["preds"]
            outputs_df = pd.DataFrame(
                np.concatenate(
                    [batch["prediction"] for batch in predict_results], axis=0
                ),
                columns=columns,
            )

            prediction_df = pd.concat([filename_df, outputs_df], axis=1)

            dataloader = self.trainer.predict_dataloaders[i]
            manifest = dataloader.dataset.manifest
            prediction_df = prediction_df.merge(manifest, on="filename", how="outer")
            if wandb.run is not None:
                prediction_df.to_csv(
                    Path(wandb.run.dir).parent
                    / "data"
                    / f"dataloader_{i}_potassium_predictions.csv",
                    index=False,
                )
            if self.save_predictions_path is not None:
                if self.save_predictions_path.name.endswith(".csv"):
                    prediction_df.to_csv(
                        self.save_predictions_path.parent
                        / self.save_predictions_path.name.replace(".csv", f"_{i}_.csv"),
                        index=False,
                    )
                else:
                    prediction_df.to_csv(
                        self.save_predictions_path
                        / f"dataloader_{i}_potassium_predictions.csv",
                        index=False,
                    )

            if wandb.run is None and self.save_predictions_path is None:
                print(
                    "WandB is not active and self.save_predictions_path is None. "
                    "Predictions will be saved to the current working directory."
                )
                prediction_df.to_csv(
                    f"dataloader_{i}_potassium_predictions.csv", index=False
                )

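    # Inference is typically driven through Lightning; a minimal hedged
    # sketch (dl is a hypothetical DataLoader over the ECG dataset):
    #   trainer = pl.Trainer()
    #   trainer.predict(model, dataloaders=[dl])
    # which triggers on_predict_epoch_end above and writes
    # dataloader_0_potassium_predictions.csv.
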
    def log_step(self, step_type, labels, output_tensor, loss):
        self.step_loss = (step_type, loss.detach().item())
        self.epoch_preds[step_type][0].append(labels.detach().cpu().numpy())
        self.epoch_preds[step_type][1].append(output_tensor.detach().cpu().numpy())
        self.epoch_losses[step_type].append(loss.detach().item())
        if step_type == "train":
            self.training_steps += 1
            self.steps_since_checkpoint += 1
            if (
                self.checkpoint_every_n_steps is not None
                and self.steps_since_checkpoint > self.checkpoint_every_n_steps
            ):
                self.steps_since_checkpoint = 0
                self.checkpoint_weights(f"step_{self.training_steps}")

    def checkpoint_weights(self, name=""):
        if wandb.run is not None:
            weights_path = Path(wandb.run.dir).parent / "weights"
            if not weights_path.is_dir():
                weights_path.mkdir()
            torch.save(self.state_dict(), weights_path / f"model_{name}.pt")
        else:
            print("Did not checkpoint model. wandb not initialized.")

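    # Given the names passed to checkpoint_weights(), checkpoints land in
    # <wandb run dir>/../weights/ as model_step_<n>.pt, model_epoch_<n>.pt,
    # or model_best_<tracked metric>.pt.
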
    def validation_epoch_end(self, preds):

        # Save weights
        self.metrics["epoch"] = self.current_epoch
        if self.checkpoint_every_epoch:
            self.checkpoint_weights(f"epoch_{self.current_epoch}")

        # Calculate metrics
        for m_type in ["train", "val"]:

            y_true, y_pred = self.epoch_preds[m_type]
            if len(y_true) == 0 or len(y_pred) == 0:
                continue
            y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)

            self.metrics[f"epoch_{m_type}_loss"] = np.mean(self.epoch_losses[m_type])
            for m in self.metric_funcs.values():
                self.metrics.update(
                    m(
                        y_true,
                        y_pred,
                        labels=self.labels,
                        split=m_type,
                        step_type="epoch",
                    )
                )

            # Reset predictions
            self.epoch_losses[m_type] = []
            self.epoch_preds[m_type] = ([], [])

        # Check whether this is a new best epoch
        if self.metrics is not None and self.tracked_metric is not None:
            if self.tracked_metric == "epoch_val_loss":
                metric_optimization = "min"
            else:
                metric_optimization = self.metric_funcs[
                    self.tracked_metric.replace("epoch_val_", "")
                ].optimum
            if (
                self.metrics[self.tracked_metric] is not None
                and (
                    self.best_tracked_metric is None
                    or (
                        metric_optimization == "max"
                        and self.metrics[self.tracked_metric] > self.best_tracked_metric
                    )
                    or (
                        metric_optimization == "min"
                        and self.metrics[self.tracked_metric] < self.best_tracked_metric
                    )
                )
                and self.current_epoch > 0
            ):
                print(
                    f"New best epoch! {self.tracked_metric}={self.metrics[self.tracked_metric]}, epoch={self.current_epoch}"
                )
                self.checkpoint_weights(f"best_{self.tracked_metric}")
                self.metrics["epochs_since_last_best"] = 0
                self.best_tracked_metric = self.metrics[self.tracked_metric]
            else:
                self.metrics["epochs_since_last_best"] += 1
            if self.metrics["epochs_since_last_best"] >= self.early_stop_epochs:
                # Lightning traps KeyboardInterrupt and shuts the run down
                # gracefully, so this acts as a simple early-stopping signal.
                raise KeyboardInterrupt("Early stopping condition met")

        # Log to W&B
        if wandb.run is not None:
            wandb.log(self.metrics)


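# Each subclass below pairs a loss with a default metric set; e.g.
# RegressionModel tracks "mae", which validation_epoch_end reads back as
# "epoch_val_mae" when deciding whether a new best epoch was reached.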
class RegressionModel(TrainingModel):
    def __init__(
        self,
        model,
        metrics=(r2_metric, mae_metric, pred_value_mean_metric, pred_value_std_metric),
        tracked_metric="mae",
        early_stop_epochs=10,
        checkpoint_every_epoch=False,
        checkpoint_every_n_steps=None,
        index_labels=None,
        save_predictions_path=None,
        lr=0.01,
    ):
        super().__init__(
            model=model,
            metrics=metrics,
            tracked_metric=tracked_metric,
            early_stop_epochs=early_stop_epochs,
            checkpoint_every_epoch=checkpoint_every_epoch,
            checkpoint_every_n_steps=checkpoint_every_n_steps,
            index_labels=index_labels,
            save_predictions_path=save_predictions_path,
            lr=lr,
        )
        self.loss_func = nn.MSELoss()

    def prepare_batch(self, batch):
        # MSELoss expects predictions and targets of matching shape, so lift
        # 1-D label vectors to column vectors: (B,) -> (B, 1).
        if "labels" in batch and len(batch["labels"].shape) == 1:
            batch["labels"] = batch["labels"][:, None]
        return batch


class BinaryClassificationModel(TrainingModel):
    def __init__(
        self,
        model,
        metrics=(roc_auc_metric, CumulativeMetric(roc_auc_metric, np.nanmean, "mean")),
        tracked_metric="mean_roc_auc",
        early_stop_epochs=10,
        checkpoint_every_epoch=False,
        checkpoint_every_n_steps=None,
        index_labels=None,
        save_predictions_path=None,
        lr=0.01,
    ):
        super().__init__(
            model=model,
            metrics=metrics,
            tracked_metric=tracked_metric,
            early_stop_epochs=early_stop_epochs,
            checkpoint_every_epoch=checkpoint_every_epoch,
            checkpoint_every_n_steps=checkpoint_every_n_steps,
            index_labels=index_labels,
            save_predictions_path=save_predictions_path,
            lr=lr,
        )
        self.loss_func = nn.BCEWithLogitsLoss()

    def prepare_batch(self, batch):
        if "labels" in batch and len(batch["labels"].shape) == 1:
            batch["labels"] = batch["labels"][:, None]
        return batch


# Addresses a bug where labels read from a single manifest column arrive with
# shape (B, 1), while nn.CrossEntropyLoss expects class indices of shape (B,).
class SqueezeCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
        return self.cross_entropy(y_pred, y_true.squeeze(dim=-1))


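# Shape illustration (hypothetical values): with logits of shape (4, 3) and
# manifest-style labels of shape (4, 1),
#   SqueezeCrossEntropyLoss()(torch.randn(4, 3), torch.zeros(4, 1).long())
# squeezes the targets down to shape (4,) before applying nn.CrossEntropyLoss.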
class MultiClassificationModel(TrainingModel):
    def __init__(
        self,
        model,
        metrics=(roc_auc_metric, CumulativeMetric(roc_auc_metric, np.mean, "mean")),
        tracked_metric="mean_roc_auc",
        early_stop_epochs=10,
        checkpoint_every_epoch=False,
        checkpoint_every_n_steps=None,
        index_labels=None,
        save_predictions_path=None,
        lr=0.01,
    ):
        metrics = [*metrics]
        super().__init__(
            model=model,
            metrics=metrics,
            tracked_metric=tracked_metric,
            early_stop_epochs=early_stop_epochs,
            checkpoint_every_epoch=checkpoint_every_epoch,
            checkpoint_every_n_steps=checkpoint_every_n_steps,
            index_labels=index_labels,
            save_predictions_path=save_predictions_path,
            lr=lr,
        )
        self.loss_func = SqueezeCrossEntropyLoss()

    def prepare_batch(self, batch):
        if "labels" in batch:
            batch["labels"] = batch["labels"].long()
        batch["primary_input"] = batch["primary_input"].float()
        return batch


if __name__ == "__main__":
    os.environ["WANDB_MODE"] = "offline"

    # Smoke test: wrap a video backbone with a single regression output and
    # check the shape of a raw forward pass.
    m = torchvision.models.video.r2plus1d_18()
    m.fc = nn.Linear(512, 1)
    training_model = RegressionModel(m)
    x = torch.randn((4, 3, 8, 112, 112))
    y = m(x)
    print(y.shape)

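    # Training would be launched through Lightning; a hedged sketch, with
    # train_dl and val_dl as hypothetical DataLoaders:
    #   trainer = pl.Trainer(max_epochs=10)
    #   trainer.fit(training_model, train_dl, val_dl)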