outofray committed
Commit ff8e6c1
1 Parent(s): 4e12dbc

copy data from repo
README.md CHANGED
@@ -1,3 +1,85 @@
- ---
- license: afl-3.0
- ---
+ # Kardio-Net: A Deep Learning Model for Predicting Serum Potassium Levels from Apple Watch ECG in ESRD Patients
+
+ ### I. Input preparation
+
+ 1. Prerequisites
+
+ Ensure that all ECG files are saved in *.npy format with the following shapes:
+ - (5000, 12) for 12-lead ECGs
+ - (15000, 1) for Apple Watch ECGs (30 seconds sampled at 500 Hz)
+
+ Store these files in a single flat directory, and include a manifest CSV file in the same location to accompany them.
+
+ 2. Manifest File Format
+
+ The manifest CSV file should include a header.
+ Each row corresponds to one ECG file with the following columns:
+
+ - filename: Name of the .npy file (including the .npy extension).
+ - label: The serum potassium label.
+
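+ As an illustration, a minimal manifest (with hypothetical filenames and values) could be built like this:
+
+ ```python
+ import pandas as pd
+
+ # Two hypothetical recordings with their measured serum potassium values (mEq/L).
+ # The 'split' column is only needed by the 12-lead script, which selects split == "test".
+ manifest = pd.DataFrame(
+     {
+         "filename": ["patient001_ecg.npy", "patient002_ecg.npy"],
+         "label": [4.1, 5.6],
+         "split": ["test", "test"],
+     }
+ )
+ manifest.to_csv("manifest.csv", index=False)
+ ```
+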
+ ### II. Inference
+
+ ### For 12-Lead ECG
+
+ <!-- #region -->
+ 1. Use predict_potassium_12lead.py to get the potassium level prediction.
+
+ 2. Edit data_path and manifest_path, then run the predict_potassium_12lead.py script.
+
+ 3. Upon completion, a file named "dataloader_0_potassium_predictions.csv" will be saved in the same directory. This file contains the inference results ("preds") from the model.
+
+ 4. Use generate_result.py to get the performance metrics and figure; a sketch of the idea follows below.
+ <!-- #endregion -->
+
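+ generate_result.py is not included in this commit; as a rough sketch of the evaluation step, assuming the predictions CSV contains the model's "preds" column alongside the "label" column merged in from the manifest:
+
+ ```python
+ import pandas as pd
+ from sklearn import metrics
+
+ df = pd.read_csv("dataloader_0_potassium_predictions.csv").dropna(subset=["preds", "label"])
+ print(f"MAE: {metrics.mean_absolute_error(df['label'], df['preds']):.3f} mEq/L")
+ print(f"R2:  {metrics.r2_score(df['label'], df['preds']):.3f}")
+ ```
+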
+ ### For Single Lead ECG
+
+ ### A. ECG preprocessing and segmentation
+
+ <!-- #region -->
+ 1. Use preprocessing.py to denoise and normalize the ECGs and to segment them into 5-second windows for model input (a sketch of the windowing logic follows below).
+
+ 2. Set the following paths in preprocessing.py:
+
+ - raw_directory = "path/to/raw_data_directory" # raw ECG folder for the target task
+ - output_directory = "path/to/output_directory" # output folder for the normalized ECGs
+ - manifest_path = "/path/to/manifest.csv" # manifest file path
+ - output_path = "path/to/output_path" # output path for the segmented ECGs
+
+ 3. Execute preprocessing.py.
+
+ 4. If the ECG files have already been normalized, you can execute the segmentation function only.
+ <!-- #endregion -->
+
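+ To make the windowing concrete, here is a minimal sketch of the sliding-window segmentation performed by preprocessing.py, assuming 500 Hz single-lead input (5-second windows, 1-second step):
+
+ ```python
+ import numpy as np
+
+ def segment_ecg(ecg, fs=500, length_s=5, step_s=1):
+     """Slice a 1-D ECG into overlapping fixed-length windows."""
+     window, step = fs * length_s, fs * step_s
+     return [ecg[start:start + window]
+             for start in range(0, len(ecg) - window + 1, step)]
+
+ # A 30-second Apple Watch recording yields 26 overlapping 5-second segments.
+ segments = segment_ecg(np.zeros(500 * 30))
+ print(len(segments), segments[0].shape)  # 26 (2500,)
+ ```
+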
+ ### B. Potassium regression model
+
+ <!-- #region -->
+ 1. Use predict_potassium_1lead.py to get the potassium level prediction.
+
+ 2. Edit data_path and manifest_path, then run the predict_potassium_1lead.py script.
+
+ 3. Upon completion, a file named "dataloader_0_potassium_predictions.csv" will be saved in the same directory. This file contains the inference results ("preds") from the model.
+
+ 4. Use generate_result.py to get the performance metrics and figure.
+ <!-- #endregion -->
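+ Because each source recording is split into many 5-second segments, one plausible post-processing step (an assumption, not something this commit prescribes) is to average the per-segment predictions back to one value per recording, via the original_filename column written into the segmentation manifest:
+
+ ```python
+ import pandas as pd
+
+ preds = pd.read_csv("dataloader_0_potassium_predictions.csv")
+ # One potassium estimate per source ECG: the mean over its 5-second segments.
+ per_recording = preds.groupby("original_filename")["preds"].mean().reset_index()
+ per_recording.to_csv("per_recording_potassium.csv", index=False)
+ ```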
noise_classifier.py ADDED
@@ -0,0 +1,72 @@
+ # An example of how to run inference using a model trained with cvair.
+ # We want to use a pretrained model to make predictions on a dataset of new examples.
+
+ import pandas as pd
+ import torch
+ from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader
+ from utils.datasets import ECGSingleLeadDataset
+ from utils.models import EffNet
+ from utils.training_models import BinaryClassificationModel
+
+
+ def append_noise_predictions_to_manifest(data_path, manifest_path, weights_path):
+     # Initialize a dataset
+     test_ds = ECGSingleLeadDataset(
+         data_path=data_path,
+         manifest_path=manifest_path,
+         update_manifest_func=None,
+     )
+
+     # Wrap the dataset in a dataloader
+     test_dl = DataLoader(
+         test_ds,
+         num_workers=16,
+         batch_size=512,
+         drop_last=False,
+         shuffle=False,
+     )
+
+     # Initialize the backbone model
+     backbone = EffNet(input_channels=1, output_neurons=1)
+
+     # Pass the backbone to a wrapper
+     model = BinaryClassificationModel(backbone)
+
+     # Load the pretrained weights
+     weights = torch.load(weights_path)
+     model.load_state_dict(weights)
+
+     # Initialize a Trainer object
+     trainer = Trainer(accelerator="gpu", devices=1)
+
+     # Run inference; the wrapper writes its predictions to a CSV in the working directory
+     trainer.predict(model, dataloaders=test_dl)
+
+     # Read the predictions CSV file
+     df = pd.read_csv("dataloader_0_potassium_predictions.csv")
+
+     # Min-max normalize the noise predictions
+     max_preds = df["preds"].max()
+     min_preds = df["preds"].min()
+     df["noise_preds_normal"] = (df["preds"] - min_preds) / (max_preds - min_preds)
+
+     # Drop the original predictions column
+     df.drop(columns="preds", inplace=True)
+
+     # Save the modified dataframe back to the original manifest path
+     df.to_csv(manifest_path, index=False)
+
+
+ if __name__ == "__main__":
+     data_path = "/your/wearable/ecg/data path/"
+     manifest_path = "manifest.csv"
+     weights_path = "model_noise_classifier.pt"
+
+     append_noise_predictions_to_manifest(data_path, manifest_path, weights_path)
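A downstream consumer of the updated manifest might then exclude noisy segments before running the potassium model, e.g. (the 0.5 threshold is purely hypothetical):

```python
import pandas as pd

manifest = pd.read_csv("manifest.csv")
clean = manifest[manifest["noise_preds_normal"] < 0.5]
clean.to_csv("manifest_clean.csv", index=False)
```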
predict_potassium_12lead.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader
+ from utils.datasets import ECGDataset
+ from utils.models import EffNet
+ from utils.training_models import RegressionModel
+
+
+ # +
+ # This is the path where your data samples are stored.
+ data_path = "your/ecg/data/folder"
+
+ # This is the path where your manifest, containing filenames for inference to be run on, is stored.
+ manifest_path = "your/manifest/path"
+ # -
+
+
+ # Initialize a dataset that contains the examples you want to run prediction on.
+ test_ds = ECGDataset(
+     split="test",
+     data_path=data_path,
+     manifest_path=manifest_path,
+     update_manifest_func=None,
+ )
+
+ # Wrap the dataset in a dataloader to handle batching and multithreading.
+ test_dl = DataLoader(
+     test_ds,
+     num_workers=16,
+     batch_size=256,
+     drop_last=False,
+     shuffle=False,
+ )
+
+ # Initialize the "backbone", the core model weights that will act on the data.
+ backbone = EffNet(input_channels=12, output_neurons=1)
+
+ model = RegressionModel(backbone)
+
+ weights = torch.load("model_12_lead.pt")
+ print(model.load_state_dict(weights))
+
+ # +
+ trainer = Trainer(accelerator="gpu", devices=1)
+
+ trainer.predict(model, dataloaders=test_dl)
predict_potassium_1lead.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from pytorch_lightning import Trainer
+ from torch.utils.data import DataLoader
+ from utils.datasets import ECGSingleLeadDataset
+ from utils.models import EffNet
+ from utils.training_models import RegressionModel
+
+ # +
+ # This is the path where your data samples are stored.
+ data_path = "your/ecg/data/folder"
+
+ # This is the path where your manifest, containing filenames for inference to be run on, is stored.
+ manifest_path = "your/manifest/path"
+ # -
+
+ # Initialize a dataset that contains the examples you want to run prediction on.
+ test_ds = ECGSingleLeadDataset(
+     data_path=data_path,
+     manifest_path=manifest_path,
+     update_manifest_func=None,
+ )
+
+ # Wrap the dataset in a dataloader to handle batching and multithreading.
+ test_dl = DataLoader(
+     test_ds,
+     num_workers=16,
+     batch_size=512,
+     drop_last=False,
+     shuffle=False,
+ )
+
+ # +
+ backbone = EffNet()
+
+ model = RegressionModel(backbone)
+ # -
+
+ weights = torch.load("model_single_lead_5seconds_length.pt")
+ print(model.load_state_dict(weights))
+
+ # +
+ trainer = Trainer(accelerator="gpu", devices=1)
+
+ trainer.predict(model, dataloaders=test_dl)
preprocessing.py ADDED
@@ -0,0 +1,147 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+ # %%
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+
+ from utils.ecg_utils import (
+     remove_baseline_wander,
+     wavelet_denoise_signal,
+     plot_12_lead_ecg,
+ )
+
+
+ # %%
+ def calculate_means_stds(npy_directory, n):
+     """Estimate per-lead means and stds from a random sample of n 12-lead ECGs."""
+     npy_directory = Path(npy_directory)
+     filelist = os.listdir(npy_directory)
+     np.random.shuffle(filelist)
+
+     full_batch = np.zeros((n, 5000, 12))
+     count = 0
+
+     for i, npy_filename in enumerate(tqdm(filelist[:n])):
+         npy_filepath = npy_directory / npy_filename
+         ekg_numpy_array = np.load(npy_filepath)
+
+         if ekg_numpy_array.shape[0] != 5000:
+             continue
+
+         full_batch[count] = ekg_numpy_array
+         count += 1
+
+     full_batch = full_batch[:count]  # Trim the array to remove unused entries
+     ecg_means = np.mean(full_batch, axis=(0, 1))
+     ecg_stds = np.std(full_batch, axis=(0, 1))
+
+     if ecg_means.shape[0] == ecg_stds.shape[0] == 12:
+         print("Shapes of the mean and std for ECG normalization are correct!")
+
+     return ecg_means, ecg_stds
+
+
+ # %%
+ def ecg_denoising(raw_directory, output_directory, ecg_means, ecg_stds, sampling_frequency=500):
+     """Denoise and normalize every 12-lead ECG file in raw_directory."""
+     raw_directory = Path(raw_directory)
+     output_directory = Path(output_directory)
+     filelist = os.listdir(raw_directory)
+
+     for i, filename in enumerate(tqdm(filelist)):
+         ecg_filepath = raw_directory / filename
+         ecg_numpy_array = np.load(ecg_filepath)
+
+         # 1. Wandering baseline removal
+         ecg_numpy_array = remove_baseline_wander(
+             ecg_numpy_array, sampling_frequency=sampling_frequency
+         )
+
+         # 2. Discrete wavelet transform denoising
+         for lead in range(12):
+             ecg_numpy_array[:, lead] = wavelet_denoise_signal(ecg_numpy_array[:, lead])
+
+         # 3. Lead-wise normalization with precomputed means and standard deviations
+         ecg_numpy_array = (ecg_numpy_array - ecg_means) / ecg_stds
+
+         np.save(output_directory / filename, ecg_numpy_array)
+
+     return True
+
+
+ # %%
+ def segmentation(data_path, manifest_path, output_path, length, steps):
+     manifest = pd.read_csv(manifest_path)
+
+     data = []
+     print('Starting to segment ECGs......')
+     for index in tqdm(range(manifest.shape[0])):
+         mrn = manifest['MRN'].iloc[index]  # MRN as the column name for medical record number
+         filename = manifest['filename'].iloc[index]  # filename as the column name for the ECG filename
+         k = manifest['TEST_RSLT'].iloc[index]  # TEST_RSLT as the column name for the potassium level
+
+         ecg_array = np.load(os.path.join(data_path, filename))
+         ecg_array = ecg_array[:, 0]  # assume lead I is the first lead in the npy file
+
+         # Slide a window of `length` seconds over the signal, stepping `steps` seconds
+         # at a time (500 samples per second):
+         for start in range(0, len(ecg_array), 500 * steps):
+             end = start + 500 * length
+
+             if end <= len(ecg_array):
+                 sample = ecg_array[start:end]
+
+                 if len(sample) == 500 * length:
+                     data.append({'mrn': mrn, 'original_filename': filename, 'ecg': sample, 'label': k})
+                 else:
+                     print(f'Different sample size for {filename}: {len(sample)}')
+
+     df = pd.DataFrame(data)
+     df['filename'] = None
+
+     if not os.path.exists(output_path):
+         os.makedirs(output_path)
+
+     print('Saving segmented ECGs......')
+     for index, row in df.iterrows():
+         original_filename = row['original_filename']
+         ecg_array = row['ecg']
+         new_file_name = f"{original_filename.rsplit('.', 1)[0]}_{index+1}.npy"
+         new_file_path = os.path.join(output_path, new_file_name)
+
+         df.at[index, 'filename'] = new_file_name
+         np.save(new_file_path, ecg_array)
+
+     df.drop(columns=['ecg'], inplace=True)
+     df.to_csv(f'{output_path}/{length}seconds_length_{steps}seconds_step_ecg_manifest.csv', index=False)
+
+
+ # %%
+ if __name__ == "__main__":
+     npy_directory = "ecg folder for entire database"
+     n = 100000  # the number of ECGs used to calculate the means and stds
+
+     print('Calculating ECG means and stds........')
+     ecg_means, ecg_stds = calculate_means_stds(npy_directory, n)
+
+     raw_directory = "path/to/raw_data_directory"  # raw ECG folder for the target task
+     output_directory = "path/to/output_directory"  # output ECG folder for the target task
+
+     print('Denoising and normalizing ECGs........')
+     ecg_denoising(raw_directory, output_directory, ecg_means, ecg_stds)
+
+     data_path = output_directory  # Output directory from the step above
+     manifest_path = "/path/to/manifest.csv"  # Manifest file path
+     output_path = "path/to/output_path"  # Output path for segmented ECGs
+     length = 5  # Length of each segment in seconds
+     steps = 1  # Step between segment start points, in seconds
+
+     segmentation(data_path, manifest_path, output_path, length, steps)
utils/datasets.py ADDED
@@ -0,0 +1,392 @@
+ import os
+ import random
+ from pathlib import Path
+ from typing import Callable, List, Optional, Iterable, Dict, Union
+ from typing_extensions import TypedDict, Required, NotRequired
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import wandb
+ from torch.utils.data import Dataset
+
+
+ def format_mrn(mrn):
+     return str(mrn).strip().zfill(20)
+
+
+ class CedarsDatasetTypeAnnotations(TypedDict, total=False):
+     """A dummy class used to make IDE autocomplete and tooltips work properly with how we pass **kwargs through in subclasses of CedarsDataset."""
+
+     data_path: Required[Union[Path, str]]
+     manifest_path: Required[Union[Path, str]]
+     split: NotRequired[str]
+     labels: NotRequired[Iterable[str]]
+     extra_inputs: NotRequired[Iterable[str]]
+     update_manifest_func: NotRequired[Callable[[pd.DataFrame], pd.DataFrame]]
+     subsample: NotRequired[Union[Path, str]]
+     augmentations: NotRequired[Union[Iterable[Callable[[torch.Tensor], torch.Tensor]], Callable[[dict], dict], nn.Module]]
+     apply_augmentations_to: NotRequired[Iterable[str]]
+     verify_existing: NotRequired[bool]
+     drop_na_labels: NotRequired[bool]
+     verbose: NotRequired[bool]
+
+
+ class CedarsDataset(Dataset):
+     """
+     Generic parent class for several different kinds of common datasets we use here at Cedars CVAIR.
+
+     Expects to be used in a scenario where you have a big folder full of input examples (videos, ECGs, 3D arrays, images, etc.) and a big CSV that contains metadata and labels for those examples, called a 'manifest'.
+
+     Args:
+         data_path: Path to a directory full of files you want the dataset to load from.
+         manifest_path: Path to a CSV or Parquet file containing the names, labels, and/or metadata of your files.
+         split: Optional. Allows the user to select which split of the manifest to use, assuming the presence of a categorical 'split' column. Defaults to None, meaning that the entire manifest is used.
+         extra_inputs: Optional. A list of column names in the manifest that contain additional inputs to the model. Defaults to None.
+         labels: Optional. Name(s) of column(s) in your manifest which contain training labels, in the order you want them returned. If set to None, the dataset will not return any labels, only filenames and inputs. Defaults to None.
+         update_manifest_func: Optional. Allows the user to pass in a function to preprocess the manifest after it is loaded, but before the dataset does anything to it.
+         subsample: Optional. A number indicating how many examples to randomly subsample from the manifest. Defaults to None.
+         verbose: Whether to print out progress statements when initializing. Defaults to True.
+         augmentations: Optional. Can be a list of augmentation functions which take in a tensor and return a tensor, a single custom augmentation function which takes in a dict and returns a dict, or a single nn.Module. Defaults to None.
+         apply_augmentations_to: Optional. A list of strings indicating which batch elements to apply augmentations to. Defaults to ("primary_input",).
+     """
+
+     def __init__(
+         self,
+         data_path,
+         manifest_path=None,
+         split=None,
+         labels=None,
+         extra_inputs=None,
+         update_manifest_func=None,
+         subsample=None,
+         augmentations=None,
+         apply_augmentations_to=("primary_input",),
+         verify_existing=True,
+         drop_na_labels=True,
+         verbose=True,
+     ):
+
+         self.data_path = Path(data_path)
+         self.augmentations = augmentations
+         self.apply_augmentations_to = apply_augmentations_to
+         self.extra_inputs = extra_inputs
+         self.labels = labels
+
+         if isinstance(self.augmentations, nn.Module):
+             self.augmentations = [self.augmentations]
+
+         if (self.labels is None) and verbose:
+             print(
+                 "No label column names were provided, only filenames and inputs will be returned."
+             )
+         if (self.labels is not None) and isinstance(self.labels, str):
+             self.labels = [self.labels]
+         if (self.extra_inputs is not None) and isinstance(self.extra_inputs, str):
+             self.extra_inputs = [self.extra_inputs]
+
+         # Read manifest file
+         if manifest_path is not None:
+             self.manifest_path = Path(manifest_path)
+         else:
+             self.manifest_path = self.data_path / "manifest.csv"
+
+         if self.manifest_path.exists():
+             if self.manifest_path.suffix == ".csv":
+                 self.manifest = pd.read_csv(self.manifest_path, low_memory=False)
+             elif self.manifest_path.suffix == ".parquet":
+                 self.manifest = pd.read_parquet(self.manifest_path)
+             else:
+                 raise ValueError(f"Unsupported manifest format: {self.manifest_path.suffix}")
+         else:
+             self.manifest = pd.DataFrame(
+                 {
+                     "filename": os.listdir(self.data_path),
+                 }
+             )
+
+         # Do manifest processing that's specific to a given task (different from update_manifest_func,
+         # exists as a method overridden in child classes)
+         self.manifest = self.process_manifest(self.manifest)
+
+         # Apply user-provided update function to manifest
+         if update_manifest_func is not None:
+             self.manifest = update_manifest_func(self, self.manifest)
+
+         # Usually set to "train", "val", or "test". If set to None, the entire manifest is used.
+         if split is not None:
+             self.manifest = self.manifest[self.manifest["split"] == split]
+             if verbose:
+                 print(
+                     f"Manifest loaded. \nSplit: {split}\nLength: {len(self.manifest):,}"
+                 )
+
+         # Make sure all files actually exist. This can be disabled for efficiency if
+         # you have an especially large dataset
+         if verify_existing and "filename" in self.manifest:
+             old_len = len(self.manifest)
+             existing_files = os.listdir(self.data_path)
+             self.manifest = self.manifest[
+                 self.manifest["filename"].isin(existing_files)
+             ]
+             new_len = len(self.manifest)
+             if verbose:
+                 print(
+                     f"{old_len - new_len} files in the manifest are missing from {self.data_path}."
+                 )
+         elif (not verify_existing) and verbose:
+             print(
+                 f"self.verify_existing is set to False, so it's possible for the manifest to contain filenames which are not present in {data_path}"
+             )
+
+         # Option to subsample dataset for doing smaller, faster runs
+         if subsample is not None:
+             if isinstance(subsample, int):
+                 self.manifest = self.manifest.sample(n=subsample)
+             else:
+                 self.manifest = self.manifest.sample(frac=subsample)
+             if verbose:
+                 print(f"{subsample} examples subsampled.")
+
+         # Make sure that there are no NaN labels
+         if (self.labels is not None) and drop_na_labels:
+             old_len = len(self.manifest)
+             self.manifest = self.manifest.dropna(subset=self.labels)
+             new_len = len(self.manifest)
+             if verbose:
+                 print(
+                     f"{old_len - new_len} examples contained NaN value(s) in their labels and were dropped."
+                 )
+         elif (self.labels is not None) and (not drop_na_labels):
+             print(
+                 "drop_na_labels is set to False, so it's possible for the manifest to contain NaN-valued labels."
+             )
+
+         # Save manifest to the Weights & Biases run directory
+         if wandb.run is not None:
+             run_data_path = Path(wandb.run.dir).parent / "data"
+             if not run_data_path.is_dir():
+                 run_data_path.mkdir()
+
+             save_name = "manifest.csv"
+             if split is not None:
+                 save_name = f"{split}_{save_name}"
+
+             self.manifest.to_csv(run_data_path / save_name)
+
+             if verbose:
+                 print(f"Copy of manifest saved to {run_data_path}")
+
+     def __len__(self) -> int:
+         return len(self.manifest)
+
+     def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+         output = {}
+         row = self.manifest.iloc[index]
+         if "filename" in row:
+             output["filename"] = row["filename"]
+         if self.labels is not None:
+             output["labels"] = torch.FloatTensor(row[self.labels])
+         file_results = self.read_file(self.data_path / output["filename"], row)
+         if isinstance(file_results, dict):
+             output.update(file_results)
+         else:
+             output["primary_input"] = file_results
+
+         if self.extra_inputs is not None:
+             output["extra_inputs"] = torch.FloatTensor(row[self.extra_inputs])
+
+         if self.augmentations is not None:
+             output = self.augment(output)
+
+         return output
+
+     def process_manifest(self, manifest: pd.DataFrame) -> pd.DataFrame:
+         if "mrn" in manifest.columns:
+             manifest["mrn"] = manifest["mrn"].apply(format_mrn)
+         if "study_date" in manifest.columns:
+             manifest["study_date"] = pd.to_datetime(manifest["study_date"])
+         if "dob" in manifest.columns:
+             manifest["dob"] = pd.to_datetime(
+                 manifest["dob"], infer_datetime_format=True, errors="coerce"
+             )
+         if ("study_date" in manifest.columns) and ("dob" in manifest.columns):
+             manifest["study_age"] = (
+                 manifest["study_date"] - manifest["dob"]
+             ) / np.timedelta64(1, "Y")
+         return manifest
+
+     def augment(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+
+         if isinstance(self.augmentations, Iterable):
+             # would use torch.stack here for cleanliness, but it seems that torchvision
+             # transforms v1's claims about supporting "arbitrary leading dimensions" is
+             # hogwash. they only support up to 4D. so we have to concatenate along the
+             # channel dimension, then apply the augmentations, then split along the channel
+             # dimension.
+             augmentable_inputs = torch.cat(
+                 [output_dict[key] for key in self.apply_augmentations_to], dim=0
+             )  # (C*N, T, H, W)
+
+             for aug in self.augmentations:
+                 augmentable_inputs = aug(augmentable_inputs)
+
+             place = 0
+             for i, key in enumerate(self.apply_augmentations_to):
+                 n_channels = output_dict[key].shape[0]
+                 output_dict[key] = augmentable_inputs[place:place + n_channels]
+                 place += n_channels
+
+         elif isinstance(self.augmentations, Callable):
+             output_dict = self.augmentations(output_dict)
+
+         else:
+             raise Exception(
+                 "self.augmentations must be either an Iterable of augmentations or a single custom augmentation function."
+             )
+
+         return output_dict
+
+     def read_file(self, filepath: Path, row: Optional[pd.Series] = None) -> torch.Tensor:
+         raise NotImplementedError
+
+
+ class ECGDataset(CedarsDataset):
+     def __init__(
+         self,
+         # CedarsDataset params
+         data_path: Union[Path, str],
+         manifest_path: Union[Path, str] = None,
+         split: str = None,
+         labels: Union[List[str], str] = None,
+         update_manifest_func: Callable = None,
+         subsample: float = None,
+         verbose: bool = True,
+         verify_existing: bool = True,
+         drop_na_labels: bool = True,
+         # ECGDataset params
+         leads: List[str] = None,
+         random_lead: bool = False,  # New parameter for random lead selection
+         data_length: int = 5000,
+         **kwargs,
+     ):
+         """
+         Args:
+             leads: List[str] -- which leads you want passed to the model. Defaults to all 12.
+         """
+
+         super().__init__(
+             data_path=data_path,
+             manifest_path=manifest_path,
+             split=split,
+             labels=labels,
+             update_manifest_func=update_manifest_func,
+             subsample=subsample,
+             verbose=verbose,
+             verify_existing=verify_existing,
+             drop_na_labels=drop_na_labels,
+             **kwargs,
+         )
+
+         self.lead_order = [
+             "I",
+             "II",
+             "III",
+             "aVR",
+             "aVL",
+             "aVF",
+             "V1",
+             "V2",
+             "V3",
+             "V4",
+             "V5",
+             "V6",
+         ]
+         self.leads = leads
+         if self.leads is None:
+             self.leads = self.lead_order
+         if isinstance(self.leads, str):
+             self.leads = [self.leads]
+
+         if "first_lead_only" in kwargs:
+             raise Exception(
+                 '"first_lead_only" has been deprecated. Please pass leads=["I"] '
+                 "instead if you would like to train on only the first lead."
+             )
+
+         self.random_lead = random_lead  # Store the random_lead attribute
+         self.data_length = data_length
+
+     def read_file(self, filepath, row=None):
+         # ECGs are usually stored as .npy files of shape (time, 12); transpose to (12, time).
+         file = np.load(filepath)
+         if file.shape[0] != 12:
+             file = file.T
+         file = torch.tensor(file).float()
+
+         # Slice the data to the specified length
+         file = file[:, :self.data_length]
+
+         if self.random_lead:
+             lead_idx = random.choice(range(12))
+             file = file[lead_idx:lead_idx + 1]  # Select the random lead
+         else:
+             channels = [self.lead_order.index(lead) for lead in self.leads]
+             file = file[channels]
+
+         # Final shape is (num_leads, time).
+         return file
+
+
+ class ECGSingleLeadDataset(CedarsDataset):
+     def __init__(
+         self,
+         # CedarsDataset params
+         data_path: Union[Path, str],
+         manifest_path: Union[Path, str] = None,
+         labels: Union[List[str], str] = None,
+         update_manifest_func: Callable = None,
+         subsample: float = None,
+         verbose: bool = True,
+         verify_existing: bool = True,
+         drop_na_labels: bool = True,
+         **kwargs,
+     ):
+         """A dataset for single-lead ECGs, where each .npy file contains a 1-D array."""
+
+         super().__init__(
+             data_path=data_path,
+             manifest_path=manifest_path,
+             labels=labels,
+             update_manifest_func=update_manifest_func,
+             subsample=subsample,
+             verbose=verbose,
+             verify_existing=verify_existing,
+             drop_na_labels=drop_na_labels,
+             **kwargs,
+         )
+
+     def read_file(self, filepath, row=None):
+         # ECGs are usually stored as .npy files.
+         try:
+             file = np.load(filepath)
+         except Exception as e:
+             print(filepath)
+             print(e)
+             raise
+
+         # Add a channel dimension; final shape is (1, time).
+         file = torch.tensor(file).float().unsqueeze(0)
+         return file
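As a quick sanity check of the dataset contract above (the paths are hypothetical; the manifest must contain a 'filename' column):

```python
from torch.utils.data import DataLoader
from utils.datasets import ECGSingleLeadDataset

ds = ECGSingleLeadDataset(
    data_path="path/to/segmented_ecgs",    # folder of (2500,) .npy segments
    manifest_path="path/to/manifest.csv",
)
batch = next(iter(DataLoader(ds, batch_size=4)))
print(batch["filename"], batch["primary_input"].shape)  # torch.Size([4, 1, 2500])
```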
utils/ecg_utils.py ADDED
@@ -0,0 +1,114 @@
+ # +
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pywt
+ from scipy.ndimage import median_filter as scipy_ndimage_median_filter
+ # -
+
+ lead_order = [
+     "I",
+     "II",
+     "III",
+     "aVR",
+     "aVL",
+     "aVF",
+     "V1",
+     "V2",
+     "V3",
+     "V4",
+     "V5",
+     "V6",
+ ]
+
+
+ def plot_12_lead_ecg(ecg_array, output_filename=None):
+     """
+     Plot each lead of the 12-lead ECG, and optionally save the plot to a file.
+     All leads share the x axis, but each lead gets its own chart.
+     """
+     fig, axs = plt.subplots(12, 1, sharex=True, figsize=(16, 9))
+     for lead, lead_name in enumerate(lead_order):
+         axs[lead].plot(ecg_array[:, lead])
+         axs[lead].set_ylabel(str(lead_name))
+     if output_filename is not None:
+         plt.savefig(output_filename)
+     plt.show()
+     plt.close()
+
+
+ def get_median_filter_width(sampling_frequency, duration):
+     res = int(sampling_frequency * duration)
+     res += (res % 2) - 1  # needs to be an odd number
+     return res
+
+
+ def remove_baseline_wander(waveform: np.ndarray, sampling_frequency: int) -> np.ndarray:
+     """
+     Remove baseline wander from ECG NPYs.
+
+     de Chazal et al. used two median filters to remove baseline wander.
+     Median filters take the median value of a sliding window of a specified size.
+     One median filter of 200-ms width removes QRS complexes and P-waves, and another of
+     600-ms width removes T-waves.
+     Apply one filter, then the next. Then take the result and subtract it from the original signal.
+     https://pubmed.ncbi.nlm.nih.gov/15248536/
+
+     Example of a median filter:
+     medfilt([2,6,5,4,0,3,5,7,9,2,0,1], 5) -> [ 2. 4. 4. 4. 4. 4. 5. 5. 5. 2. 1. 0.]
+     >>> np.median([0, 0, 2, 6, 5])
+     2.0
+     >>> np.median([0, 2, 6, 5, 4])
+     4.0
+     """
+
+     # Depending on the sampling frequency, the widths of the median filters change
+     filter_widths = [
+         get_median_filter_width(sampling_frequency, duration) for duration in [0.2, 0.6]
+     ]
+     filter_widths = np.array(filter_widths, dtype="int")
+
+     # make a copy of the original signal
+     original_waveform = waveform.copy()
+
+     # apply the median filters one by one, on top of each other
+     for filter_width in filter_widths:
+         waveform = scipy_ndimage_median_filter(
+             waveform, size=(filter_width, 1), mode="constant", cval=0.0
+         )
+     waveform = original_waveform - waveform  # finally, subtract from the original signal
+     return waveform
+
+
+ def wavelet_denoise_signal(
+     waveform: np.ndarray,
+     dwt_transform: str = "bior4.4",
+     dlevels: int = 9,
+     cutoff_low: int = 1,
+     cutoff_high: int = 7,
+ ) -> np.ndarray:
+     # cutoff_low determines how flat you want the overall baseline to be.
+     # Higher means a flatter baseline.
+     # cutoff_high determines how much squiggliness to suppress within the small
+     # segments. A lower cutoff_high suppresses more squiggliness, but also
+     # suppresses R-wave morphology.
+
+     original_length = len(waveform)
+     coefficients = pywt.wavedec(
+         waveform, dwt_transform, level=dlevels
+     )  # wavelet transform 'bior4.4'
+     # zero out scales 0 to cutoff_low
+     for low_cutoff_value in range(0, cutoff_low):
+         coefficients[low_cutoff_value] = np.multiply(
+             coefficients[low_cutoff_value], [0.0]
+         )
+     # zero out scales cutoff_high to the end
+     for high_cutoff_value in range(cutoff_high, len(coefficients)):
+         coefficients[high_cutoff_value] = np.multiply(
+             coefficients[high_cutoff_value], [0.0]
+         )
+     waveform = pywt.waverec(coefficients, dwt_transform)  # inverse wavelet transform
+     # waverec can return a slightly longer signal due to internal padding; trim it
+     return waveform[:original_length]
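A minimal sketch of how the two signal-processing utilities compose, on a synthetic 10-second, 500 Hz 12-lead array (the signal itself is illustrative):

```python
import numpy as np
from utils.ecg_utils import remove_baseline_wander, wavelet_denoise_signal

fs = 500
t = np.arange(5000) / fs
# Synthetic "ECG": a 1.2 Hz rhythm plus a slow 0.2 Hz baseline drift, copied to 12 leads.
sig = np.sin(2 * np.pi * 1.2 * t) + 0.5 * np.sin(2 * np.pi * 0.2 * t)
ecg = np.tile(sig[:, None], (1, 12))

ecg = remove_baseline_wander(ecg, sampling_frequency=fs)  # median-filter detrending
for lead in range(12):
    ecg[:, lead] = wavelet_denoise_signal(ecg[:, lead])   # wavelet band-limiting
print(ecg.shape)  # (5000, 12)
```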
utils/models.py ADDED
@@ -0,0 +1,247 @@
+ from collections import OrderedDict
+ from typing import Callable, List
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class EffNet(nn.Module):
+     # lightly retouched version of John's EffNet to add clean support for multiple output
+     # layer designs as well as single-lead inputs
+     def __init__(
+         self,
+         num_extra_inputs: int = 0,
+         output_neurons: int = 1,
+         channels: List[int] = (32, 16, 24, 40, 80, 112, 192, 320, 1280),
+         depth: List[int] = (1, 2, 2, 3, 3, 3, 3),
+         dilation: int = 2,
+         stride: int = 8,
+         expansion: int = 6,
+         embedding_hook: bool = False,
+         input_channels: int = 1,
+         verbose: bool = False,
+         embedding_shift: bool = False,
+     ):
+         super().__init__()
+
+         self.input_channels = input_channels
+         self.channels = channels
+         self.output_neurons = output_neurons
+
+         # backwards compatibility change to prevent the addition of the output_neurons param
+         # from breaking people's existing EffNet initializations
+         if len(self.channels) == 10:
+             self.output_neurons = self.channels[9]
+             print(
+                 "DEPRECATION WARNING: instead of controlling the number of output neurons by changing the 10th item in the channels parameter, use the new output_neurons parameter instead."
+             )
+
+         self.depth = depth
+         self.expansion = expansion
+         self.stride = stride
+         self.dilation = dilation
+         self.embedding_hook = embedding_hook
+         self.embedding_shift = embedding_shift
+
+         if verbose:
+             print("\nEffNet Parameters:")
+             print(f"{self.input_channels=}")
+             print(f"{self.channels=}")
+             print(f"{self.output_neurons=}")
+             print(f"{self.depth=}")
+             print(f"{self.expansion=}")
+             print(f"{self.stride=}")
+             print(f"{self.dilation=}")
+             print(f"{self.embedding_hook=}")
+             print("\n")
+
+         self.stage1 = nn.Conv1d(
+             self.input_channels,
+             self.channels[0],
+             kernel_size=3,
+             stride=stride,
+             padding=1,
+             dilation=dilation,
+         )  # 1 conv
+
+         self.b0 = nn.BatchNorm1d(self.channels[0])
+
+         self.stage2 = MBConv(
+             self.channels[0], self.channels[1], self.expansion, self.depth[0], stride=2
+         )
+
+         self.stage3 = MBConv(
+             self.channels[1], self.channels[2], self.expansion, self.depth[1], stride=2
+         )
+
+         self.Pool = nn.MaxPool1d(3, stride=1, padding=1)
+
+         self.stage4 = MBConv(
+             self.channels[2], self.channels[3], self.expansion, self.depth[2], stride=2
+         )
+
+         self.stage5 = MBConv(
+             self.channels[3], self.channels[4], self.expansion, self.depth[3], stride=2
+         )
+
+         self.stage6 = MBConv(
+             self.channels[4], self.channels[5], self.expansion, self.depth[4], stride=2
+         )
+
+         self.stage7 = MBConv(
+             self.channels[5], self.channels[6], self.expansion, self.depth[5], stride=2
+         )
+
+         self.stage8 = MBConv(
+             self.channels[6], self.channels[7], self.expansion, self.depth[6], stride=2
+         )
+
+         self.stage9 = nn.Conv1d(self.channels[7], self.channels[8], kernel_size=1)
+         self.AAP = nn.AdaptiveAvgPool1d(1)
+         self.act = nn.ReLU()
+         self.drop = nn.Dropout(p=0.3)
+         self.num_extra_inputs = num_extra_inputs
+
+         self.fc = nn.Linear(self.channels[8] + num_extra_inputs, self.output_neurons)
+
+         # Manually initialize the bias of the first output neuron.
+         self.fc.bias.data[0] = 0.275
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.num_extra_inputs > 0:
+             x, extra_inputs = x
+
+         x = self.b0(self.stage1(x))
+         x = self.stage2(x)
+         x = self.stage3(x)
+         x = self.Pool(x)
+         x = self.stage4(x)
+         x = self.stage5(x)
+         x = self.stage6(x)
+         x = self.Pool(x)
+         x = self.stage7(x)
+         x = self.stage8(x)
+         x = self.stage9(x)
+         x = self.act(self.AAP(x)[:, :, 0])
+
+         if self.embedding_hook:
+             return x
+
+         else:
+             if self.embedding_shift:
+                 delta_embedding_array = np.load('/workspace/imin/applewatch_potassium/delta_embedding_poolaverage_5second_to_5second.npy')
+                 delta_embedding_tensor = torch.tensor(delta_embedding_array, device='cuda')
+                 x += delta_embedding_tensor
+
+             x = self.drop(x)
+
+             if self.num_extra_inputs > 0:
+                 x = torch.cat((x, extra_inputs), 1)
+
+             x = self.fc(x)
+             return x
+
+
+ class Bottleneck(nn.Module):
+     def __init__(
+         self,
+         in_channel: int,
+         out_channel: int,
+         expansion: int,
+         activation: Callable,
+         stride: int = 1,
+         padding: int = 1,
+     ):
+         super().__init__()
+
+         self.stride = stride
+         self.conv1 = nn.Conv1d(in_channel, in_channel * expansion, kernel_size=1)
+         self.conv2 = nn.Conv1d(
+             in_channel * expansion,
+             in_channel * expansion,
+             kernel_size=3,
+             groups=in_channel * expansion,
+             padding=padding,
+             stride=stride,
+         )
+         self.conv3 = nn.Conv1d(
+             in_channel * expansion, out_channel, kernel_size=1, stride=1
+         )
+         self.b0 = nn.BatchNorm1d(in_channel * expansion)
+         self.b1 = nn.BatchNorm1d(in_channel * expansion)
+         self.d = nn.Dropout()
+         self.act = activation()
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.stride == 1:
+             y = self.act(self.b0(self.conv1(x)))
+             y = self.act(self.b1(self.conv2(y)))
+             y = self.conv3(y)
+             y = self.d(y)
+             y = x + y  # residual connection when the resolution is unchanged
+             return y
+         else:
+             y = self.act(self.b0(self.conv1(x)))
+             y = self.act(self.b1(self.conv2(y)))
+             y = self.conv3(y)
+             return y
+
+
+ class MBConv(nn.Module):
+     def __init__(
+         self, in_channel, out_channels, expansion, layers, activation=nn.ReLU6, stride=2
+     ):
+         super().__init__()
+
+         self.stack = OrderedDict()
+         for i in range(0, layers - 1):
+             self.stack["s" + str(i)] = Bottleneck(
+                 in_channel, in_channel, expansion, activation
+             )
+
+         self.stack["s" + str(layers + 1)] = Bottleneck(
+             in_channel, out_channels, expansion, activation, stride=stride
+         )
+
+         self.stack = nn.Sequential(self.stack)
+
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.stack(x)
+         return self.bn(x)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_channels, out_channels, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = F.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = F.relu(out)
+
+         return out
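A quick shape check for the single-lead configuration used in this repo (a 5-second segment at 500 Hz is 2500 samples):

```python
import torch
from utils.models import EffNet

model = EffNet(input_channels=1, output_neurons=1)
x = torch.randn(4, 1, 2500)  # a batch of four 5-second single-lead ECGs
print(model(x).shape)        # torch.Size([4, 1])
```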
utils/training_models.py ADDED
@@ -0,0 +1,447 @@
+ import os
+ from pathlib import Path
+ from typing import Iterable
+
+ import numpy as np
+ import pandas as pd
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import wandb
+ from sklearn import metrics as skl_metrics
+
+
+ class TrainingMetric:
+     def __init__(self, metric_func, metric_name, optimum=None):
+         self.func = metric_func
+         self.name = metric_name
+         self.optimum = optimum
+
+     def calc_metric(self, *args, **kwargs):
+         try:
+             return self.func(*args, **kwargs)
+         except ValueError:
+             return np.nan
+
+     def __call__(self, y_true, y_pred, labels=None, split=None, step_type=None) -> dict:
+
+         # if y_true is empty
+         if y_true.shape[0] == 0:  # TODO: handle other cases
+             m = {
+                 f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(None, yp)
+                 for yp, l in zip(y_pred.T, labels)
+             }
+             return m
+
+         # Simple 1:1 case: y_true and y_pred are either shape=(batch, 1) or shape=(batch,)
+         if len(y_pred.shape) == 1 or (y_pred.shape[1] == 1 and y_true.shape[1] == 1):
+             m = {
+                 f"{step_type}_{split}_{self.name}": self.calc_metric(
+                     y_true.flatten(), y_pred.flatten()
+                 )
+             }
+
+         # Multi-binary classification-like: y_true and y_pred are shape=(batch, class)
+         elif y_true.shape[1] != 1 and y_pred.shape[1] != 1:
+             m = {
+                 f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(yt, yp)
+                 for yt, yp, l in zip(y_true.T, y_pred.T, labels)
+             }
+
+         # Multi-class classification-like: y_true is shape=(batch, 1) or shape=(batch,) and y_pred is shape=(batch, class)
+         elif (len(y_true.shape) == 1 or y_true.shape[1] == 1) and y_pred.shape[1] != 1:
+             m = {
+                 f"{step_type}_{split}_{l}_{self.name}": self.calc_metric(
+                     y_true.flatten() == i, yp
+                 )
+                 for i, (yp, l) in enumerate(
+                     zip(y_pred.T, labels)
+                 )  # turn multi-class into binary classification
+             }
+
+         return m
+
+
+ class CumulativeMetric(TrainingMetric):
+     """Wraps a metric to apply to every class in the output and calculate a cumulative value (like mean AUC)."""
+
+     def __init__(
+         self,
+         training_metric: TrainingMetric,
+         metric_func,
+         metric_name="cumulative",
+         optimum=None,
+     ):
+         optimum = optimum or training_metric.optimum
+         metric_name = f"{metric_name}_{training_metric.name}"
+         super().__init__(metric_func, metric_name, optimum)
+         self.base_metric = training_metric
+
+     def __call__(self, y_true, y_pred, labels=None, split=None, step_type=None):
+         vals = list(self.base_metric(y_true, y_pred, labels, split, step_type).values())
+
+         m = {f"{step_type}_{split}_{self.name}": self.func(vals)}
+         return m
+
+
+ r2_metric = TrainingMetric(skl_metrics.r2_score, "r2", optimum="max")
+ roc_auc_metric = TrainingMetric(skl_metrics.roc_auc_score, "roc_auc", optimum="max")
+ accuracy_metric = TrainingMetric(skl_metrics.accuracy_score, "accuracy", optimum="max")
+ mae_metric = TrainingMetric(skl_metrics.mean_absolute_error, "mae", optimum="min")
+ pred_value_mean_metric = TrainingMetric(
+     lambda y_true, y_pred: np.mean(y_pred), "pred_value_mean"
+ )
+ pred_value_std_metric = TrainingMetric(
+     lambda y_true, y_pred: np.std(y_pred), "pred_value_std"
+ )
+
+
+ class TrainingModel(pl.LightningModule):
+     def __init__(
+         self,
+         model,
+         metrics: Iterable[TrainingMetric] = (),
+         tracked_metric=None,
+         early_stop_epochs=10,
+         checkpoint_every_epoch=False,
+         checkpoint_every_n_steps=None,
+         index_labels=None,
+         save_predictions_path=None,
+         lr=0.01,
+     ):
+         super().__init__()
+         self.epoch_preds = {"train": ([], []), "val": ([], [])}
+         self.epoch_losses = {"train": [], "val": []}
+         self.metrics = {}
+         self.metric_funcs = {m.name: m for m in metrics}
+         self.tracked_metric = (
+             f"epoch_val_{tracked_metric}" if tracked_metric is not None else None
+         )
+         self.best_tracked_metric = None
+         self.early_stop_epochs = early_stop_epochs
+         self.checkpoint_every_epoch = checkpoint_every_epoch
+         self.checkpoint_every_n_steps = checkpoint_every_n_steps
+         self.metrics["epochs_since_last_best"] = 0
+         self.m = model
+         self.training_steps = 0
+         self.steps_since_checkpoint = 0
+         self.labels = index_labels
+         if self.labels is not None and isinstance(self.labels, str):
+             self.labels = [self.labels]
+         if isinstance(save_predictions_path, str):
+             save_predictions_path = Path(save_predictions_path)
+         self.save_predictions_path = save_predictions_path
+         self.lr = lr
+         self.step_loss = (None, None)
+
+         self.log_path = Path(wandb.run.dir) if wandb.run is not None else None
+
+     def configure_optimizers(self):
+         return torch.optim.AdamW(self.parameters(), self.lr)
+
+     def forward(self, x: dict):
+         # if anything other than 'primary_input' and 'extra_inputs' is used,
+         # this function must be overridden
+         if 'extra_inputs' in x:
+             return self.m((x['primary_input'], x['extra_inputs']))
+         else:
+             return self.m(x['primary_input'])
+
+     def step(self, batch, step_type='train'):
+         batch = self.prepare_batch(batch)
+         y_pred = self.forward(batch)
+
+         if step_type != 'predict':
+             if 'labels' not in batch:
+                 batch['labels'] = torch.empty(0)
+             loss = self.loss_func(y_pred, batch['labels'])
+             if torch.isnan(loss):
+                 raise ValueError(loss)
+
+             self.log_step(step_type, batch['labels'], y_pred, loss)
+
+             return loss
+         else:
+             return y_pred
+
+     def prepare_batch(self, batch):
+         return batch
+
+     def training_step(self, batch, i):
+         return self.step(batch, "train")
+
+     def validation_step(self, batch, i):
+         return self.step(batch, "val")
+
+     def predict_step(self, batch, *args):
+         y_pred = self.step(batch, "predict")
+         return {"filename": batch["filename"], "prediction": y_pred.cpu().numpy()}
+
+     def on_predict_epoch_end(self, results):
+
+         for i, predict_results in enumerate(results):
+             filename_df = pd.DataFrame(
+                 {
+                     "filename": np.concatenate(
+                         [batch["filename"] for batch in predict_results]
+                     )
+                 }
+             )
+
+             if self.labels is not None:
+                 columns = [f"{class_name}_preds" for class_name in self.labels]
+             else:
+                 columns = ["preds"]
+             outputs_df = pd.DataFrame(
+                 np.concatenate(
+                     [batch["prediction"] for batch in predict_results], axis=0
+                 ),
+                 columns=columns,
+             )
+
+             prediction_df = pd.concat([filename_df, outputs_df], axis=1)
+
+             dataloader = self.trainer.predict_dataloaders[i]
+             manifest = dataloader.dataset.manifest
+             prediction_df = prediction_df.merge(manifest, on="filename", how="outer")
+             if wandb.run is not None:
+                 prediction_df.to_csv(
+                     Path(wandb.run.dir).parent
+                     / "data"
+                     / f"dataloader_{i}_potassium_predictions.csv",
+                     index=False,
+                 )
+             if self.save_predictions_path is not None:
+                 if ".csv" in self.save_predictions_path.name:
+                     prediction_df.to_csv(
+                         self.save_predictions_path.parent
+                         / self.save_predictions_path.name.replace(".csv", f"_{i}_.csv"),
+                         index=False,
+                     )
+                 else:
+                     prediction_df.to_csv(
+                         self.save_predictions_path / f"dataloader_{i}_potassium_predictions.csv",
+                         index=False,
+                     )
+
+             if wandb.run is None and self.save_predictions_path is None:
+                 print(
+                     "WandB is not active and self.save_predictions_path is None. Predictions will be saved to the directory this script is being run in."
+                 )
+                 prediction_df.to_csv(f"dataloader_{i}_potassium_predictions.csv", index=False)
+
+     def log_step(self, step_type, labels, output_tensor, loss):
+         self.step_loss = (step_type, loss.detach().item())
+         self.epoch_preds[step_type][0].append(labels.detach().cpu().numpy())
+         self.epoch_preds[step_type][1].append(output_tensor.detach().cpu().numpy())
+         self.epoch_losses[step_type].append(loss.detach().item())
+         if step_type == "train":
+             self.training_steps += 1
+             self.steps_since_checkpoint += 1
+             if (
+                 self.checkpoint_every_n_steps is not None
+                 and self.steps_since_checkpoint > self.checkpoint_every_n_steps
+             ):
+                 self.steps_since_checkpoint = 0
+                 self.checkpoint_weights(f"step_{self.training_steps}")
+
+     def checkpoint_weights(self, name=""):
+         if wandb.run is not None:
+             weights_path = Path(wandb.run.dir).parent / "weights"
+             if not weights_path.is_dir():
+                 weights_path.mkdir()
+             torch.save(self.state_dict(), weights_path / f"model_{name}.pt")
+         else:
+             print("Did not checkpoint model. wandb not initialized.")
+
+     def validation_epoch_end(self, preds):
+
+         # Save weights
+         self.metrics["epoch"] = self.current_epoch
+         if self.checkpoint_every_epoch:
+             self.checkpoint_weights(f"epoch_{self.current_epoch}")
+
+         # Calculate metrics
+         for m_type in ["train", "val"]:
+
+             y_true, y_pred = self.epoch_preds[m_type]
+             if len(y_true) == 0 or len(y_pred) == 0:
+                 continue
+             y_true, y_pred = np.concatenate(y_true), np.concatenate(y_pred)
+
+             self.metrics[f"epoch_{m_type}_loss"] = np.mean(self.epoch_losses[m_type])
+             for m in self.metric_funcs.values():
+                 self.metrics.update(
+                     m(
+                         y_true,
+                         y_pred,
+                         labels=self.labels,
+                         split=m_type,
+                         step_type="epoch",
+                     )
+                 )
+
+             # Reset predictions
+             self.epoch_losses[m_type] = []
+             self.epoch_preds[m_type] = ([], [])
+
+         # Check if this is a new best epoch
+         if self.metrics is not None and self.tracked_metric is not None:
+             if self.tracked_metric == "epoch_val_loss":
+                 metric_optimization = "min"
+             else:
+                 metric_optimization = self.metric_funcs[
+                     self.tracked_metric.replace("epoch_val_", "")
+                 ].optimum
+             if (
+                 self.metrics[self.tracked_metric] is not None
+                 and (
+                     self.best_tracked_metric is None
+                     or (
+                         metric_optimization == "max"
+                         and self.metrics[self.tracked_metric] > self.best_tracked_metric
+                     )
+                     or (
+                         metric_optimization == "min"
+                         and self.metrics[self.tracked_metric] < self.best_tracked_metric
+                     )
+                 )
+                 and self.current_epoch > 0
+             ):
+                 print(
+                     f"New best epoch! {self.tracked_metric}={self.metrics[self.tracked_metric]}, epoch={self.current_epoch}"
+                 )
+                 self.checkpoint_weights(f"best_{self.tracked_metric}")
+                 self.metrics["epochs_since_last_best"] = 0
+                 self.best_tracked_metric = self.metrics[self.tracked_metric]
+             else:
+                 self.metrics["epochs_since_last_best"] += 1
+                 if self.metrics["epochs_since_last_best"] >= self.early_stop_epochs:
+                     raise KeyboardInterrupt("Early stopping condition met")
+
+         # Log to W&B
+         if wandb.run is not None:
+             wandb.log(self.metrics)
+
+
+ class RegressionModel(TrainingModel):
+     def __init__(
+         self,
+         model,
+         metrics=(r2_metric, mae_metric, pred_value_mean_metric, pred_value_std_metric),
+         tracked_metric="mae",
+         early_stop_epochs=10,
+         checkpoint_every_epoch=False,
+         checkpoint_every_n_steps=None,
+         index_labels=None,
+         save_predictions_path=None,
+         lr=0.01,
+     ):
+         super().__init__(
+             model=model,
+             metrics=metrics,
+             tracked_metric=tracked_metric,
+             early_stop_epochs=early_stop_epochs,
+             checkpoint_every_epoch=checkpoint_every_epoch,
+             checkpoint_every_n_steps=checkpoint_every_n_steps,
+             index_labels=index_labels,
+             save_predictions_path=save_predictions_path,
+             lr=lr,
+         )
+         self.loss_func = nn.MSELoss()
+
+     def prepare_batch(self, batch):
+         if "labels" in batch and len(batch["labels"].shape) == 1:
+             batch["labels"] = batch["labels"][:, None]
+         return batch
+
+
+ class BinaryClassificationModel(TrainingModel):
+     def __init__(
+         self,
+         model,
+         metrics=(roc_auc_metric, CumulativeMetric(roc_auc_metric, np.nanmean, "mean")),
+         tracked_metric="mean_roc_auc",
+         early_stop_epochs=10,
+         checkpoint_every_epoch=False,
+         checkpoint_every_n_steps=None,
+         index_labels=None,
+         save_predictions_path=None,
+         lr=0.01,
+     ):
+         super().__init__(
+             model=model,
+             metrics=metrics,
+             tracked_metric=tracked_metric,
+             early_stop_epochs=early_stop_epochs,
+             checkpoint_every_epoch=checkpoint_every_epoch,
+             checkpoint_every_n_steps=checkpoint_every_n_steps,
+             index_labels=index_labels,
+             save_predictions_path=save_predictions_path,
+             lr=lr,
+         )
+         self.loss_func = nn.BCEWithLogitsLoss()
+
+     def prepare_batch(self, batch):
+         if "labels" in batch and len(batch["labels"].shape) == 1:
+             batch["labels"] = batch["labels"][:, None]
+         return batch
+
+
+ # Addresses a bug caused by labels from a single column in a manifest being delivered as Bx1,
+ # but nn.CrossEntropyLoss wants a simple list of length B.
+ class SqueezeCrossEntropyLoss(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.cross_entropy = nn.CrossEntropyLoss()
+
+     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor):
+         return self.cross_entropy(y_pred, y_true.squeeze(dim=-1))
+
+
+ class MultiClassificationModel(TrainingModel):
+     def __init__(
+         self,
+         model,
+         metrics=(roc_auc_metric, CumulativeMetric(roc_auc_metric, np.mean, "mean")),
+         tracked_metric="mean_roc_auc",
+         early_stop_epochs=10,
+         checkpoint_every_epoch=False,
+         checkpoint_every_n_steps=None,
+         index_labels=None,
+         save_predictions_path=None,
+         lr=0.01,
+     ):
+         metrics = [*metrics]
+         super().__init__(
+             model=model,
+             metrics=metrics,
+             tracked_metric=tracked_metric,
+             early_stop_epochs=early_stop_epochs,
+             checkpoint_every_epoch=checkpoint_every_epoch,
+             checkpoint_every_n_steps=checkpoint_every_n_steps,
+             index_labels=index_labels,
+             save_predictions_path=save_predictions_path,
+             lr=lr,
+         )
+         self.loss_func = SqueezeCrossEntropyLoss()
+
+     def prepare_batch(self, batch):
+         if "labels" in batch:
+             batch["labels"] = batch["labels"].long()
+         batch["primary_input"] = batch["primary_input"].float()
+         return batch
+
+
+ if __name__ == "__main__":
+     os.environ["WANDB_MODE"] = "offline"
+
+     m = torchvision.models.video.r2plus1d_18()
+     m.fc = nn.Linear(512, 1)
+     training_model = RegressionModel(m)
+     x = torch.randn((4, 3, 8, 112, 112))
+     y = m(x)
+     print(y.shape)
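For completeness, a sketch of how the regression wrapper can be driven outside of wandb, using the save_predictions_path argument so predictions land in a chosen directory (the paths are hypothetical, and the directory must already exist):

```python
from utils.models import EffNet
from utils.training_models import RegressionModel

backbone = EffNet(input_channels=1, output_neurons=1)
model = RegressionModel(backbone, save_predictions_path="predictions/")
# trainer.predict(model, dataloaders=...) will then write
# predictions/dataloader_0_potassium_predictions.csv
```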