nina-m-m committed on
Commit
44cdef4
1 Parent(s): bf0f149

Update model and implement abstract class

Browse files
ECG2HRV.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3669bf4b201fd873f1fb1b48083c41f26c42be00f58d14d234af6b9aac1f6433
3
+ size 39
notebooks/01_Model_Deployment_Research.ipynb CHANGED
@@ -399,21 +399,18 @@
399
  },
400
  {
401
  "cell_type": "code",
402
- "execution_count": 25,
403
  "outputs": [],
404
  "source": [
405
  "from huggingface_hub import hf_hub_download\n",
406
  "import joblib\n",
407
  "import torch\n",
 
408
  "\n",
409
- "from src.model import HR2HRV"
410
  ],
411
  "metadata": {
412
- "collapsed": false,
413
- "ExecuteTime": {
414
- "end_time": "2024-02-21T11:39:25.775871100Z",
415
- "start_time": "2024-02-21T11:39:25.755838Z"
416
- }
417
  }
418
  },
419
  {
@@ -427,11 +424,11 @@
427
  },
428
  {
429
  "cell_type": "code",
430
- "execution_count": 27,
431
  "outputs": [],
432
  "source": [
433
  "# Instantiate model\n",
434
- "model = HR2HRV()\n",
435
  "# Save\n",
436
  "joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
437
  "# Load in notebook\n",
@@ -440,15 +437,15 @@
440
  "metadata": {
441
  "collapsed": false,
442
  "ExecuteTime": {
443
- "end_time": "2024-02-21T12:01:28.600527100Z",
444
- "start_time": "2024-02-21T12:01:28.580278200Z"
445
  }
446
  }
447
  },
448
  {
449
  "cell_type": "markdown",
450
  "source": [
451
- "**Test if the model can be loaded from the hub and used**"
452
  ],
453
  "metadata": {
454
  "collapsed": false
@@ -456,38 +453,84 @@
456
  },
457
  {
458
  "cell_type": "code",
459
- "execution_count": 19,
460
  "outputs": [],
461
  "source": [
462
- "# Load from hub\n",
463
- "REPO_ID = \"HUBII-Platform/ECG2HRV\"\n",
464
- "FILENAME = \"feature-extractor.joblib\"\n",
465
  "\n",
466
- "model = joblib.load(\n",
467
- " hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
468
- ")"
 
 
 
 
 
 
 
 
469
  ],
470
  "metadata": {
471
  "collapsed": false,
472
  "ExecuteTime": {
473
- "end_time": "2024-02-21T11:36:52.302912800Z",
474
- "start_time": "2024-02-21T11:36:52.145834500Z"
475
  }
476
  }
477
  },
478
  {
479
  "cell_type": "code",
480
- "execution_count": 20,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  "outputs": [],
482
  "source": [
483
- "# Create a example tensor input\n",
484
- "tensor = torch.tensor([2.0, 3.0, 4.0])"
 
 
 
 
 
485
  ],
486
  "metadata": {
487
  "collapsed": false,
488
  "ExecuteTime": {
489
- "end_time": "2024-02-21T11:36:55.990475100Z",
490
- "start_time": "2024-02-21T11:36:55.989181100Z"
491
  }
492
  }
493
  },
@@ -504,9 +547,8 @@
504
  }
505
  ],
506
  "source": [
507
- "# Run forward pass\n",
508
- "output = model.forward(tensor)\n",
509
- "print(output)\n"
510
  ],
511
  "metadata": {
512
  "collapsed": false,
 
399
  },
400
  {
401
  "cell_type": "code",
402
+ "execution_count": 1,
403
  "outputs": [],
404
  "source": [
405
  "from huggingface_hub import hf_hub_download\n",
406
  "import joblib\n",
407
  "import torch\n",
408
+ "import numpy as np\n",
409
  "\n",
410
+ "from src.model import ECG2HRV"
411
  ],
412
  "metadata": {
413
+ "collapsed": false
 
 
 
 
414
  }
415
  },
416
  {
 
424
  },
425
  {
426
  "cell_type": "code",
427
+ "execution_count": 2,
428
  "outputs": [],
429
  "source": [
430
  "# Instantiate model\n",
431
+ "model = ECG2HRV()\n",
432
  "# Save\n",
433
  "joblib.dump(model, \"..\\ECG2HRV.joblib\")\n",
434
  "# Load in notebook\n",
 
437
  "metadata": {
438
  "collapsed": false,
439
  "ExecuteTime": {
440
+ "end_time": "2024-02-21T16:08:51.659030Z",
441
+ "start_time": "2024-02-21T16:08:51.605730100Z"
442
  }
443
  }
444
  },
445
  {
446
  "cell_type": "markdown",
447
  "source": [
448
+ "**Test the model locally with a synthetic ECG signal**"
449
  ],
450
  "metadata": {
451
  "collapsed": false
 
453
  },
454
  {
455
  "cell_type": "code",
456
+ "execution_count": 3,
457
  "outputs": [],
458
  "source": [
459
+ "duration_seconds = 10 # Time duration for ECG signal (in seconds)\n",
460
+ "sample_rate = 100 # Sample rate (samples per second)\n",
461
+ "num_samples = duration_seconds * sample_rate # Number of samples\n",
462
  "\n",
463
+ "t = np.linspace(0, duration_seconds, num_samples) # Time array\n",
464
+ "\n",
465
+ "# Generate ECG signal (example synthetic data)\n",
466
+ "ecg_signal = (\n",
467
+ " 0.2 * np.sin(2 * np.pi * 1 * t) +\n",
468
+ " 0.5 * np.sin(2 * np.pi * 0.5 * t) -\n",
469
+ " 0.1 * np.sin(2 * np.pi * 2.5 * t)\n",
470
+ ")\n",
471
+ "\n",
472
+ "# Add some random noise\n",
473
+ "ecg_signal += np.random.normal(scale=0.1, size=num_samples)"
474
  ],
475
  "metadata": {
476
  "collapsed": false,
477
  "ExecuteTime": {
478
+ "end_time": "2024-02-21T16:08:51.669938Z",
479
+ "start_time": "2024-02-21T16:08:51.635032600Z"
480
  }
481
  }
482
  },
483
  {
484
  "cell_type": "code",
485
+ "execution_count": 4,
486
+ "outputs": [
487
+ {
488
+ "data": {
489
+ "text/plain": "[{'HRV_MeanNN': 413.4782608695652,\n 'HRV_SDNN': 100.97743652790477,\n 'HRV_SDANN1': nan,\n 'HRV_SDNNI1': nan,\n 'HRV_SDANN2': nan,\n 'HRV_SDNNI2': nan,\n 'HRV_SDANN5': nan,\n 'HRV_SDNNI5': nan,\n 'HRV_RMSSD': 92.78518690551262,\n 'HRV_SDSD': 94.96410805236795,\n 'HRV_CVNN': 0.24421462041449105,\n 'HRV_CVSD': 0.22440160870944167,\n 'HRV_MedianNN': 400.0,\n 'HRV_MadNN': 118.60799999999999,\n 'HRV_MCVNN': 0.29651999999999995,\n 'HRV_IQRNN': 150.0,\n 'HRV_SDRMSSD': 1.0882926455785953,\n 'HRV_Prc20NN': 320.0,\n 'HRV_Prc80NN': 490.0,\n 'HRV_pNN50': 52.17391304347826,\n 'HRV_pNN20': 69.56521739130434,\n 'HRV_MinNN': 310.0,\n 'HRV_MaxNN': 640.0,\n 'HRV_HTI': 5.75,\n 'HRV_TINN': 0.0}]"
490
+ },
491
+ "execution_count": 4,
492
+ "metadata": {},
493
+ "output_type": "execute_result"
494
+ }
495
+ ],
496
+ "source": [
497
+ "model(input_data=ecg_signal, frequency=100.0)"
498
+ ],
499
+ "metadata": {
500
+ "collapsed": false,
501
+ "ExecuteTime": {
502
+ "end_time": "2024-02-21T16:08:51.755181400Z",
503
+ "start_time": "2024-02-21T16:08:51.671014900Z"
504
+ }
505
+ }
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "source": [
510
+ "**Test if the model can be loaded from the hub and used**"
511
+ ],
512
+ "metadata": {
513
+ "collapsed": false
514
+ }
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": 19,
519
  "outputs": [],
520
  "source": [
521
+ "# Load from hub\n",
522
+ "REPO_ID = \"HUBII-Platform/ECG2HRV\"\n",
523
+ "FILENAME = \"feature-extractor.joblib\"\n",
524
+ "\n",
525
+ "model = joblib.load(\n",
526
+ " hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
527
+ ")"
528
  ],
529
  "metadata": {
530
  "collapsed": false,
531
  "ExecuteTime": {
532
+ "end_time": "2024-02-21T11:36:52.302912800Z",
533
+ "start_time": "2024-02-21T11:36:52.145834500Z"
534
  }
535
  }
536
  },
 
547
  }
548
  ],
549
  "source": [
550
+ "# Run model\n",
551
+ "model(input_data=ecg_signal, frequency=100.0)"
 
552
  ],
553
  "metadata": {
554
  "collapsed": false,
feature-extractor.joblib → notebooks/feature-extractor.joblib RENAMED
File without changes
src/ecg2hrv.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import neurokit2 as nk
4
+ import torch
5
+
6
+ from src.feature_extractor import FeatureExtractor
7
+
8
+
9
class ECG2HRV(FeatureExtractor):
    """Extract time-domain HRV features from a raw ECG signal.

    The signal is cleaned with neurokit2, R-peaks are detected, and the
    result of ``nk.hrv_time`` is returned as a list of records (one dict of
    feature name -> value per row of the neurokit2 result). Optionally the
    features are normalized against the features of a baseline recording.
    """

    # Algorithm name handed to neurokit2 for both cleaning and peak detection.
    _NK_METHOD = "pantompkins1985"

    # Normalization modes accepted by ``normalize_features``.
    _NORMALIZATION_METHODS = ("difference", "relative")

    def __init__(self):
        super().__init__()

    def extract_features(self, ecg, frequency, baseline=None, normalization_method=None):
        """Compute HRV features for ``ecg``, optionally normalized by a baseline.

        Args:
            ecg: Array-like ECG samples with at least one dimension.
            frequency: Sampling rate passed to neurokit2 as ``sampling_rate``.
            baseline: Optional array-like baseline ECG. It is processed with
                the same ``frequency`` as ``ecg`` (assumes both recordings
                share one sampling rate - TODO confirm with callers).
            normalization_method: ``"difference"`` or ``"relative"``.
                Normalization only happens when both ``baseline`` and
                ``normalization_method`` are given.

        Returns:
            The list of feature records produced by ``get_hrv_features``,
            normalized when requested.

        Raises:
            ValueError: If a signal is a scalar, or if a baseline is given
                together with an unsupported ``normalization_method``.
        """
        features = self.get_hrv_features(self._clean(ecg, frequency), frequency)

        if baseline is not None and normalization_method is not None:
            # The baseline goes through exactly the same pipeline as the
            # signal so that both feature sets are comparable.
            baseline_features = self.get_hrv_features(
                self._clean(baseline, frequency), frequency
            )
            features = self.normalize_features(
                features, baseline_features, normalization_method
            )

        return features

    def _clean(self, signal, frequency):
        """Validate ``signal`` and return its neurokit2-cleaned version."""
        # np.asarray lets plain sequences (lists, tuples) through the
        # dimensionality check, which needs the ``ndim`` attribute.
        signal = np.asarray(signal)
        if signal.ndim < 1:
            raise ValueError("Array must have at least one dimension")

        return nk.ecg_clean(
            ecg_signal=signal, sampling_rate=frequency, method=self._NK_METHOD
        )

    def get_hrv_features(self, ecg, frequency):
        """Detect peaks in a cleaned ECG and return time-domain HRV features.

        Args:
            ecg: Cleaned ECG signal.
            frequency: Sampling rate passed to neurokit2 as ``sampling_rate``.

        Returns:
            ``nk.hrv_time(...)`` converted with ``to_dict(orient="records")``,
            i.e. a list of ``{feature_name: value}`` dicts.
        """
        peaks, _info = nk.ecg_peaks(
            ecg, sampling_rate=frequency, method=self._NK_METHOD
        )

        # TODO: frequency-domain features (nk.hrv_frequency) are not computed
        # yet; only the time-domain features are returned.
        hrv_time_features = nk.hrv_time(peaks, sampling_rate=frequency)

        return hrv_time_features.to_dict(orient="records")

    def normalize_features(self, features, baseline_features, normalization_method=None):
        """Normalize ``features`` against ``baseline_features``.

        Args:
            features: Feature records as returned by ``get_hrv_features``
                (list of dicts), or any object supporting ``-`` / division
                with ``baseline_features`` directly.
            baseline_features: Baseline counterpart of ``features`` with the
                same structure and feature names.
            normalization_method: ``"difference"`` (features - baseline) or
                ``"relative"`` (features / baseline).

        Returns:
            The normalized features in the same structure as ``features``.

        Raises:
            ValueError: If ``normalization_method`` is not supported.
        """
        if normalization_method not in self._NORMALIZATION_METHODS:
            raise ValueError(f"Normalization method {normalization_method} not supported")

        if isinstance(features, list):
            # Lists of records do not support arithmetic, so normalize each
            # feature value against the baseline value of the same name.
            return [
                {
                    name: self._combine(value, reference[name], normalization_method)
                    for name, value in record.items()
                }
                for record, reference in zip(features, baseline_features)
            ]

        return self._combine(features, baseline_features, normalization_method)

    @staticmethod
    def _combine(value, reference, normalization_method):
        """Apply one already-validated normalization mode to a value pair."""
        if normalization_method == "difference":
            return value - reference

        # "relative": a zero baseline value yields inf/nan instead of raising
        # ZeroDivisionError on plain Python floats.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.divide(value, reference)
src/feature_extractor.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class FeatureExtractor(ABC):
    """Abstract base class for callable feature extractors.

    Calling an instance forwards to ``extract_features``. Subclasses must
    implement ``extract_features`` and ``normalize_features``.
    """

    def __init__(self):
        pass

    def __call__(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Run ``extract_features`` with the given arguments.

        Arguments are forwarded positionally in the order
        ``(input_data, frequency, baseline_data, normalization_method)`` so
        that subclasses are free to pick their own parameter names.
        """
        return self.extract_features(input_data, frequency, baseline_data, normalization_method)

    @abstractmethod
    def extract_features(self, input_data, frequency, baseline_data=None, normalization_method=None):
        """Extract features from ``input_data``.

        The parameter order matches the positional call made by
        ``__call__``; implementations must accept the arguments in this
        order.

        Args:
            input_data: The data to extract features from.
            frequency: Forwarded unchanged from ``__call__``.
            baseline_data: Optional baseline data, forwarded from ``__call__``.
            normalization_method: Optional normalization selector, forwarded
                from ``__call__``.
        """

    @abstractmethod
    def normalize_features(self, features, baseline_features=None, normalization_method=None):
        """Normalize ``features``, optionally using ``baseline_features``.

        Args:
            features: The extracted features.
            baseline_features: Optional baseline features to normalize against.
            normalization_method: Optional selector for the normalization
                mode; its accepted values are defined by the implementation.
        """