nina-m-m
committed on
Commit
•
44cdef4
1
Parent(s):
bf0f149
Update model and implement abstract class
Browse files
ECG2HRV.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3669bf4b201fd873f1fb1b48083c41f26c42be00f58d14d234af6b9aac1f6433
|
3 |
+
size 39
|
notebooks/01_Model_Deployment_Research.ipynb
CHANGED
@@ -399,21 +399,18 @@
|
|
399 |
},
|
400 |
{
|
401 |
"cell_type": "code",
|
402 |
-
"execution_count":
|
403 |
"outputs": [],
|
404 |
"source": [
|
405 |
"from huggingface_hub import hf_hub_download\n",
|
406 |
"import joblib\n",
|
407 |
"import torch\n",
|
|
|
408 |
"\n",
|
409 |
-
"from src.model import
|
410 |
],
|
411 |
"metadata": {
|
412 |
-
"collapsed": false
|
413 |
-
"ExecuteTime": {
|
414 |
-
"end_time": "2024-02-21T11:39:25.775871100Z",
|
415 |
-
"start_time": "2024-02-21T11:39:25.755838Z"
|
416 |
-
}
|
417 |
}
|
418 |
},
|
419 |
{
|
@@ -427,11 +424,11 @@
|
|
427 |
},
|
428 |
{
|
429 |
"cell_type": "code",
|
430 |
-
"execution_count":
|
431 |
"outputs": [],
|
432 |
"source": [
|
433 |
"# Instantiate model\n",
|
434 |
-
"model =
|
435 |
"# Save\n",
|
436 |
"joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
|
437 |
"# Load in notebook\n",
|
@@ -440,15 +437,15 @@
|
|
440 |
"metadata": {
|
441 |
"collapsed": false,
|
442 |
"ExecuteTime": {
|
443 |
-
"end_time": "2024-02-
|
444 |
-
"start_time": "2024-02-
|
445 |
}
|
446 |
}
|
447 |
},
|
448 |
{
|
449 |
"cell_type": "markdown",
|
450 |
"source": [
|
451 |
-
"**Test
|
452 |
],
|
453 |
"metadata": {
|
454 |
"collapsed": false
|
@@ -456,38 +453,84 @@
|
|
456 |
},
|
457 |
{
|
458 |
"cell_type": "code",
|
459 |
-
"execution_count":
|
460 |
"outputs": [],
|
461 |
"source": [
|
462 |
-
"#
|
463 |
-
"
|
464 |
-
"
|
465 |
"\n",
|
466 |
-
"
|
467 |
-
"
|
468 |
-
")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
],
|
470 |
"metadata": {
|
471 |
"collapsed": false,
|
472 |
"ExecuteTime": {
|
473 |
-
"end_time": "2024-02-
|
474 |
-
"start_time": "2024-02-
|
475 |
}
|
476 |
}
|
477 |
},
|
478 |
{
|
479 |
"cell_type": "code",
|
480 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
"outputs": [],
|
482 |
"source": [
|
483 |
-
"#
|
484 |
-
"
|
|
|
|
|
|
|
|
|
|
|
485 |
],
|
486 |
"metadata": {
|
487 |
"collapsed": false,
|
488 |
"ExecuteTime": {
|
489 |
-
"end_time": "2024-02-21T11:36:
|
490 |
-
"start_time": "2024-02-21T11:36:
|
491 |
}
|
492 |
}
|
493 |
},
|
@@ -504,9 +547,8 @@
|
|
504 |
}
|
505 |
],
|
506 |
"source": [
|
507 |
-
"# Run
|
508 |
-
"
|
509 |
-
"print(output)\n"
|
510 |
],
|
511 |
"metadata": {
|
512 |
"collapsed": false,
|
|
|
399 |
},
|
400 |
{
|
401 |
"cell_type": "code",
|
402 |
+
"execution_count": 1,
|
403 |
"outputs": [],
|
404 |
"source": [
|
405 |
"from huggingface_hub import hf_hub_download\n",
|
406 |
"import joblib\n",
|
407 |
"import torch\n",
|
408 |
+
"import numpy as np\n",
|
409 |
"\n",
|
410 |
+
"from src.model import ECG2HRV"
|
411 |
],
|
412 |
"metadata": {
|
413 |
+
"collapsed": false
|
|
|
|
|
|
|
|
|
414 |
}
|
415 |
},
|
416 |
{
|
|
|
424 |
},
|
425 |
{
|
426 |
"cell_type": "code",
|
427 |
+
"execution_count": 2,
|
428 |
"outputs": [],
|
429 |
"source": [
|
430 |
"# Instantiate model\n",
|
431 |
+
"model = ECG2HRV()\n",
|
432 |
"# Save\n",
|
433 |
"joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
|
434 |
"# Load in notebook\n",
|
|
|
437 |
"metadata": {
|
438 |
"collapsed": false,
|
439 |
"ExecuteTime": {
|
440 |
+
"end_time": "2024-02-21T16:08:51.659030Z",
|
441 |
+
"start_time": "2024-02-21T16:08:51.605730100Z"
|
442 |
}
|
443 |
}
|
444 |
},
|
445 |
{
|
446 |
"cell_type": "markdown",
|
447 |
"source": [
|
448 |
+
"**Test the model locally with random ecg**"
|
449 |
],
|
450 |
"metadata": {
|
451 |
"collapsed": false
|
|
|
453 |
},
|
454 |
{
|
455 |
"cell_type": "code",
|
456 |
+
"execution_count": 3,
|
457 |
"outputs": [],
|
458 |
"source": [
|
459 |
+
"duration_seconds = 10 # Time duration for ECG signal (in seconds)\n",
|
460 |
+
"sample_rate = 100 # Sample rate (samples per second)\n",
|
461 |
+
"num_samples = duration_seconds * sample_rate # Number of samples\n",
|
462 |
"\n",
|
463 |
+
"t = np.linspace(0, duration_seconds, num_samples) # Time array\n",
|
464 |
+
"\n",
|
465 |
+
"# Generate ECG signal (example synthetic data)\n",
|
466 |
+
"ecg_signal = (\n",
|
467 |
+
" 0.2 * np.sin(2 * np.pi * 1 * t) +\n",
|
468 |
+
" 0.5 * np.sin(2 * np.pi * 0.5 * t) -\n",
|
469 |
+
" 0.1 * np.sin(2 * np.pi * 2.5 * t)\n",
|
470 |
+
")\n",
|
471 |
+
"\n",
|
472 |
+
"# Add some random noise\n",
|
473 |
+
"ecg_signal += np.random.normal(scale=0.1, size=num_samples)"
|
474 |
],
|
475 |
"metadata": {
|
476 |
"collapsed": false,
|
477 |
"ExecuteTime": {
|
478 |
+
"end_time": "2024-02-21T16:08:51.669938Z",
|
479 |
+
"start_time": "2024-02-21T16:08:51.635032600Z"
|
480 |
}
|
481 |
}
|
482 |
},
|
483 |
{
|
484 |
"cell_type": "code",
|
485 |
+
"execution_count": 4,
|
486 |
+
"outputs": [
|
487 |
+
{
|
488 |
+
"data": {
|
489 |
+
"text/plain": "[{'HRV_MeanNN': 413.4782608695652,\n 'HRV_SDNN': 100.97743652790477,\n 'HRV_SDANN1': nan,\n 'HRV_SDNNI1': nan,\n 'HRV_SDANN2': nan,\n 'HRV_SDNNI2': nan,\n 'HRV_SDANN5': nan,\n 'HRV_SDNNI5': nan,\n 'HRV_RMSSD': 92.78518690551262,\n 'HRV_SDSD': 94.96410805236795,\n 'HRV_CVNN': 0.24421462041449105,\n 'HRV_CVSD': 0.22440160870944167,\n 'HRV_MedianNN': 400.0,\n 'HRV_MadNN': 118.60799999999999,\n 'HRV_MCVNN': 0.29651999999999995,\n 'HRV_IQRNN': 150.0,\n 'HRV_SDRMSSD': 1.0882926455785953,\n 'HRV_Prc20NN': 320.0,\n 'HRV_Prc80NN': 490.0,\n 'HRV_pNN50': 52.17391304347826,\n 'HRV_pNN20': 69.56521739130434,\n 'HRV_MinNN': 310.0,\n 'HRV_MaxNN': 640.0,\n 'HRV_HTI': 5.75,\n 'HRV_TINN': 0.0}]"
|
490 |
+
},
|
491 |
+
"execution_count": 4,
|
492 |
+
"metadata": {},
|
493 |
+
"output_type": "execute_result"
|
494 |
+
}
|
495 |
+
],
|
496 |
+
"source": [
|
497 |
+
"model(input_data=ecg_signal, frequency=100.0)"
|
498 |
+
],
|
499 |
+
"metadata": {
|
500 |
+
"collapsed": false,
|
501 |
+
"ExecuteTime": {
|
502 |
+
"end_time": "2024-02-21T16:08:51.755181400Z",
|
503 |
+
"start_time": "2024-02-21T16:08:51.671014900Z"
|
504 |
+
}
|
505 |
+
}
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"cell_type": "markdown",
|
509 |
+
"source": [
|
510 |
+
"**Test if the model can be loaded from the hub and used**"
|
511 |
+
],
|
512 |
+
"metadata": {
|
513 |
+
"collapsed": false
|
514 |
+
}
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"cell_type": "code",
|
518 |
+
"execution_count": 19,
|
519 |
"outputs": [],
|
520 |
"source": [
|
521 |
+
"# Load from hub\n",
|
522 |
+
"REPO_ID = \"HUBII-Platform/ECG2HRV\"\n",
|
523 |
+
"FILENAME = \"feature-extractor.joblib\"\n",
|
524 |
+
"\n",
|
525 |
+
"model = joblib.load(\n",
|
526 |
+
" hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
|
527 |
+
")"
|
528 |
],
|
529 |
"metadata": {
|
530 |
"collapsed": false,
|
531 |
"ExecuteTime": {
|
532 |
+
"end_time": "2024-02-21T11:36:52.302912800Z",
|
533 |
+
"start_time": "2024-02-21T11:36:52.145834500Z"
|
534 |
}
|
535 |
}
|
536 |
},
|
|
|
547 |
}
|
548 |
],
|
549 |
"source": [
|
550 |
+
"# Run model\n",
|
551 |
+
"model(input_data=ecg_signal, frequency=100.0)"
|
|
|
552 |
],
|
553 |
"metadata": {
|
554 |
"collapsed": false,
|
feature-extractor.joblib → notebooks/feature-extractor.joblib
RENAMED
File without changes
|
src/ecg2hrv.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import neurokit2 as nk
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from src.feature_extractor import FeatureExtractor
|
7 |
+
|
8 |
+
|
9 |
+
class ECG2HRV(FeatureExtractor):
    """Feature extractor that derives time-domain HRV features from an ECG signal.

    The signal is cleaned with ``nk.ecg_clean``, R-peaks are located with
    ``nk.ecg_peaks`` and the result of ``nk.hrv_time`` is returned as a list of
    record dictionaries (``DataFrame.to_dict(orient="records")``).
    """

    def __init__(self):
        super().__init__()

    def extract_features(self, ecg, frequency, baseline=None, normalization_method=None):
        """Compute HRV features for ``ecg``, optionally normalized by a baseline.

        Args:
            ecg: ECG samples; anything ``np.asarray`` can turn into an array
                with at least one dimension.
            frequency: Sampling rate passed to NeuroKit2 as ``sampling_rate``.
            baseline: Optional baseline ECG recording. It is processed with the
                same ``frequency`` as ``ecg`` -- assumes both recordings share
                one sampling rate; TODO confirm with callers.
            normalization_method: ``"difference"`` or ``"relative"``. The
                features are normalized only when both ``baseline`` and
                ``normalization_method`` are given.

        Returns:
            A list of dictionaries mapping HRV feature names to values.

        Raises:
            ValueError: If a signal is a scalar (zero-dimensional), or if
                ``normalization_method`` is not supported.
        """
        features = self.get_hrv_features(self._clean(ecg, frequency), frequency)

        if baseline is not None and normalization_method is not None:
            # The baseline goes through the very same pipeline (cleaning,
            # sampling rate) so that both feature sets are comparable.
            baseline_features = self.get_hrv_features(
                self._clean(baseline, frequency), frequency
            )
            features = self.normalize_features(
                features, baseline_features, normalization_method
            )

        return features

    @staticmethod
    def _clean(signal, frequency):
        """Validate ``signal`` and clean it with the ``pantompkins1985`` method."""
        # Coerce first so plain sequences are accepted and scalars are
        # rejected with a ValueError instead of failing on a missing ``ndim``.
        signal = np.asarray(signal)
        if signal.ndim < 1:
            raise ValueError("Array must have at least one dimension")
        return nk.ecg_clean(
            ecg_signal=signal, sampling_rate=frequency, method="pantompkins1985"
        )

    def get_hrv_features(self, ecg, frequency):
        """Detect R-peaks in ``ecg`` and return the time-domain HRV features.

        Args:
            ecg: ECG signal (already cleaned by the caller).
            frequency: Sampling rate passed to NeuroKit2 as ``sampling_rate``.

        Returns:
            The ``nk.hrv_time`` result converted with
            ``to_dict(orient="records")``, i.e. a list of dictionaries.
        """
        peaks, _ = nk.ecg_peaks(ecg, sampling_rate=frequency, method="pantompkins1985")

        # NOTE: only time-domain features are computed; frequency-domain
        # features (nk.hrv_frequency) are not part of the output yet.
        hrv_features = nk.hrv_time(peaks, sampling_rate=frequency)

        return hrv_features.to_dict(orient="records")

    def normalize_features(self, features, baseline_features, normalization_method=None):
        """Normalize ``features`` against ``baseline_features``.

        Args:
            features: Either the list of record dictionaries produced by
                ``get_hrv_features`` or any object supporting ``-`` and ``/``.
            baseline_features: Baseline counterpart of ``features`` with the
                same structure.
            normalization_method: ``"difference"`` subtracts the baseline,
                ``"relative"`` divides by it.

        Returns:
            The normalized features, in the same structure as ``features``.

        Raises:
            ValueError: If ``normalization_method`` is neither
                ``"difference"`` nor ``"relative"``.
        """
        if normalization_method not in ("difference", "relative"):
            raise ValueError(f"Normalization method {normalization_method} not supported")

        if isinstance(features, list):
            # Lists of records do not support arithmetic, so normalize each
            # record key by key.
            return [
                self._normalize_record(record, reference, normalization_method)
                for record, reference in zip(features, baseline_features)
            ]

        if normalization_method == "difference":
            return features - baseline_features
        return features / baseline_features

    @staticmethod
    def _normalize_record(record, reference, normalization_method):
        """Normalize one feature dictionary against its baseline dictionary."""
        missing = float("nan")
        normalized = {}
        for key, value in record.items():
            base = reference.get(key, missing)
            if normalization_method == "difference":
                normalized[key] = value - base
            elif base == 0:
                # Plain Python floats raise ZeroDivisionError; a zero baseline
                # has no meaningful ratio, so report NaN instead.
                normalized[key] = missing
            else:
                normalized[key] = value / base
        return normalized
|
src/feature_extractor.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class FeatureExtractor(ABC):
    """Abstract base class for callable feature extractors.

    Calling an instance forwards its arguments *positionally* to
    ``extract_features``, so every subclass must accept them in the order
    ``(input_data, frequency, baseline_data, normalization_method)``.
    """

    def __init__(self):
        pass

    def __call__(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Run ``extract_features`` with the given arguments and return its result."""
        return self.extract_features(input_data, frequency, baseline_data, normalization_method)

    @abstractmethod
    def extract_features(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Extract features from ``input_data``.

        The parameter order matches the positional call made by ``__call__``.

        Args:
            input_data: The data to extract features from.
            frequency: Second positional argument forwarded by ``__call__``.
            baseline_data: Optional baseline data; ``None`` when not supplied.
            normalization_method: Optional normalization method; ``None`` when
                not supplied.
        """

    @abstractmethod
    def normalize_features(self, features, baseline_features=None, normalization_method=None):
        """Normalize ``features`` with respect to ``baseline_features``.

        Args:
            features: Extracted features.
            baseline_features: Optional features of the baseline.
            normalization_method: Optional selector for how to normalize.
        """
|