nina-m-m
commited on
Commit
•
0f64eae
1
Parent(s):
9e4b4a3
Refactor source code with new organization name and update README
Browse files- README.md +5 -4
- notebooks/01_Model_Deployment_Research.ipynb +44 -29
- src/ecg2hrv.py +0 -2
- src/feature_extractor.py +1 -1
README.md
CHANGED
@@ -18,8 +18,8 @@ from huggingface_hub import hf_hub_download
|
|
18 |
import joblib
|
19 |
|
20 |
# Define parameters
|
21 |
-
REPO_ID = "
|
22 |
-
FILENAME = "
|
23 |
|
24 |
# Load the model
|
25 |
model = joblib.load(
|
@@ -28,6 +28,7 @@ model = joblib.load(
|
|
28 |
```
|
29 |
Example usage of the model:
|
30 |
```python
|
31 |
-
#
|
32 |
-
|
|
|
33 |
```
|
|
|
18 |
import joblib
|
19 |
|
20 |
# Define parameters
|
21 |
+
REPO_ID = "hubii-world/ECG2HRV"
|
22 |
+
FILENAME = "ECG2HRV.joblib"
|
23 |
|
24 |
# Load the model
|
25 |
model = joblib.load(
|
|
|
28 |
```
|
29 |
Example usage of the model:
|
30 |
```python
|
31 |
+
# ecg should be a 1D numpy array with the ECG signal
|
32 |
+
hrv_features = model(input_data=ecg, frequency=100.0)
|
33 |
+
# returns hrv_features in a dictionary with the feature names as keys
|
34 |
```
|
notebooks/01_Model_Deployment_Research.ipynb
CHANGED
@@ -40,7 +40,9 @@
|
|
40 |
"source": [
|
41 |
"import timm\n",
|
42 |
"import torch\n",
|
43 |
-
"from
|
|
|
|
|
44 |
],
|
45 |
"metadata": {
|
46 |
"collapsed": false,
|
@@ -248,7 +250,6 @@
|
|
248 |
],
|
249 |
"source": [
|
250 |
"# Example with pipeline\n",
|
251 |
-
"from transformers import pipeline\n",
|
252 |
"checkpoint = \"facebook/bart-base\"\n",
|
253 |
"feature_extractor = pipeline(\"feature-extraction\", framework=\"pt\",model=checkpoint)\n",
|
254 |
"text = \"Transformers is an awesome library!\""
|
@@ -366,7 +367,6 @@
|
|
366 |
],
|
367 |
"source": [
|
368 |
"# Example with AutoModel\n",
|
369 |
-
"from transformers import AutoTokenizer, AutoModel\n",
|
370 |
"model = AutoModel.from_pretrained('HUBII-Platform/ECG2HRV')"
|
371 |
],
|
372 |
"metadata": {
|
@@ -398,25 +398,23 @@
|
|
398 |
}
|
399 |
},
|
400 |
{
|
401 |
-
"cell_type": "
|
402 |
-
"execution_count": 1,
|
403 |
-
"outputs": [],
|
404 |
"source": [
|
405 |
-
"
|
406 |
-
"import joblib\n",
|
407 |
-
"import torch\n",
|
408 |
-
"import numpy as np\n",
|
409 |
-
"\n",
|
410 |
-
"from src.model import ECG2HRV"
|
411 |
],
|
412 |
"metadata": {
|
413 |
"collapsed": false
|
414 |
}
|
415 |
},
|
416 |
{
|
417 |
-
"cell_type": "
|
|
|
|
|
418 |
"source": [
|
419 |
-
"
|
|
|
|
|
|
|
420 |
],
|
421 |
"metadata": {
|
422 |
"collapsed": false
|
@@ -507,7 +505,7 @@
|
|
507 |
{
|
508 |
"cell_type": "markdown",
|
509 |
"source": [
|
510 |
-
"**Test
|
511 |
],
|
512 |
"metadata": {
|
513 |
"collapsed": false
|
@@ -515,12 +513,28 @@
|
|
515 |
},
|
516 |
{
|
517 |
"cell_type": "code",
|
518 |
-
"execution_count":
|
519 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
"source": [
|
|
|
|
|
|
|
521 |
"# Load from hub\n",
|
522 |
-
"REPO_ID = \"
|
523 |
-
"FILENAME = \"
|
524 |
"\n",
|
525 |
"model = joblib.load(\n",
|
526 |
" hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
|
@@ -529,21 +543,22 @@
|
|
529 |
"metadata": {
|
530 |
"collapsed": false,
|
531 |
"ExecuteTime": {
|
532 |
-
"end_time": "2024-02-
|
533 |
-
"start_time": "2024-02-
|
534 |
}
|
535 |
}
|
536 |
},
|
537 |
{
|
538 |
"cell_type": "code",
|
539 |
-
"execution_count":
|
540 |
"outputs": [
|
541 |
{
|
542 |
-
"
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
|
|
547 |
}
|
548 |
],
|
549 |
"source": [
|
@@ -553,8 +568,8 @@
|
|
553 |
"metadata": {
|
554 |
"collapsed": false,
|
555 |
"ExecuteTime": {
|
556 |
-
"end_time": "2024-02-
|
557 |
-
"start_time": "2024-02-
|
558 |
}
|
559 |
}
|
560 |
},
|
|
|
40 |
"source": [
|
41 |
"import timm\n",
|
42 |
"import torch\n",
|
43 |
+
"from transformers import pipeline, AutoTokenizer, AutoModel\n",
|
44 |
+
"\n",
|
45 |
+
"from src.deprecated.pipeline_wrapper import MyPipeline"
|
46 |
],
|
47 |
"metadata": {
|
48 |
"collapsed": false,
|
|
|
250 |
],
|
251 |
"source": [
|
252 |
"# Example with pipeline\n",
|
|
|
253 |
"checkpoint = \"facebook/bart-base\"\n",
|
254 |
"feature_extractor = pipeline(\"feature-extraction\", framework=\"pt\",model=checkpoint)\n",
|
255 |
"text = \"Transformers is an awesome library!\""
|
|
|
367 |
],
|
368 |
"source": [
|
369 |
"# Example with AutoModel\n",
|
|
|
370 |
"model = AutoModel.from_pretrained('HUBII-Platform/ECG2HRV')"
|
371 |
],
|
372 |
"metadata": {
|
|
|
398 |
}
|
399 |
},
|
400 |
{
|
401 |
+
"cell_type": "markdown",
|
|
|
|
|
402 |
"source": [
|
403 |
+
"**Instantiate model and save the model as a joblib file in the huggingface repository**"
|
|
|
|
|
|
|
|
|
|
|
404 |
],
|
405 |
"metadata": {
|
406 |
"collapsed": false
|
407 |
}
|
408 |
},
|
409 |
{
|
410 |
+
"cell_type": "code",
|
411 |
+
"execution_count": 1,
|
412 |
+
"outputs": [],
|
413 |
"source": [
|
414 |
+
"import joblib\n",
|
415 |
+
"import numpy as np\n",
|
416 |
+
"\n",
|
417 |
+
"from src.ecg2hrv import ECG2HRV"
|
418 |
],
|
419 |
"metadata": {
|
420 |
"collapsed": false
|
|
|
505 |
{
|
506 |
"cell_type": "markdown",
|
507 |
"source": [
|
508 |
+
"**Test the model loaded from the hub with random ecg**"
|
509 |
],
|
510 |
"metadata": {
|
511 |
"collapsed": false
|
|
|
513 |
},
|
514 |
{
|
515 |
"cell_type": "code",
|
516 |
+
"execution_count": 7,
|
517 |
+
"outputs": [
|
518 |
+
{
|
519 |
+
"data": {
|
520 |
+
"text/plain": "ECG2HRV.joblib: 0%| | 0.00/39.0 [00:00<?, ?B/s]",
|
521 |
+
"application/vnd.jupyter.widget-view+json": {
|
522 |
+
"version_major": 2,
|
523 |
+
"version_minor": 0,
|
524 |
+
"model_id": "aef3c2ac2c9a4d91a392ec8091d4c779"
|
525 |
+
}
|
526 |
+
},
|
527 |
+
"metadata": {},
|
528 |
+
"output_type": "display_data"
|
529 |
+
}
|
530 |
+
],
|
531 |
"source": [
|
532 |
+
"from huggingface_hub import hf_hub_download\n",
|
533 |
+
"import joblib\n",
|
534 |
+
"\n",
|
535 |
"# Load from hub\n",
|
536 |
+
"REPO_ID = \"hubii-world/ECG2HRV\"\n",
|
537 |
+
"FILENAME = \"ECG2HRV.joblib\"\n",
|
538 |
"\n",
|
539 |
"model = joblib.load(\n",
|
540 |
" hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n",
|
|
|
543 |
"metadata": {
|
544 |
"collapsed": false,
|
545 |
"ExecuteTime": {
|
546 |
+
"end_time": "2024-02-21T16:26:49.913818400Z",
|
547 |
+
"start_time": "2024-02-21T16:26:49.506802900Z"
|
548 |
}
|
549 |
}
|
550 |
},
|
551 |
{
|
552 |
"cell_type": "code",
|
553 |
+
"execution_count": 8,
|
554 |
"outputs": [
|
555 |
{
|
556 |
+
"data": {
|
557 |
+
"text/plain": "[{'HRV_MeanNN': 413.4782608695652,\n 'HRV_SDNN': 100.97743652790477,\n 'HRV_SDANN1': nan,\n 'HRV_SDNNI1': nan,\n 'HRV_SDANN2': nan,\n 'HRV_SDNNI2': nan,\n 'HRV_SDANN5': nan,\n 'HRV_SDNNI5': nan,\n 'HRV_RMSSD': 92.78518690551262,\n 'HRV_SDSD': 94.96410805236795,\n 'HRV_CVNN': 0.24421462041449105,\n 'HRV_CVSD': 0.22440160870944167,\n 'HRV_MedianNN': 400.0,\n 'HRV_MadNN': 118.60799999999999,\n 'HRV_MCVNN': 0.29651999999999995,\n 'HRV_IQRNN': 150.0,\n 'HRV_SDRMSSD': 1.0882926455785953,\n 'HRV_Prc20NN': 320.0,\n 'HRV_Prc80NN': 490.0,\n 'HRV_pNN50': 52.17391304347826,\n 'HRV_pNN20': 69.56521739130434,\n 'HRV_MinNN': 310.0,\n 'HRV_MaxNN': 640.0,\n 'HRV_HTI': 5.75,\n 'HRV_TINN': 0.0}]"
|
558 |
+
},
|
559 |
+
"execution_count": 8,
|
560 |
+
"metadata": {},
|
561 |
+
"output_type": "execute_result"
|
562 |
}
|
563 |
],
|
564 |
"source": [
|
|
|
568 |
"metadata": {
|
569 |
"collapsed": false,
|
570 |
"ExecuteTime": {
|
571 |
+
"end_time": "2024-02-21T16:26:58.064981500Z",
|
572 |
+
"start_time": "2024-02-21T16:26:58.041072600Z"
|
573 |
}
|
574 |
}
|
575 |
},
|
src/ecg2hrv.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import numpy as np
|
2 |
-
import pandas as pd
|
3 |
import neurokit2 as nk
|
4 |
-
import torch
|
5 |
|
6 |
from src.feature_extractor import FeatureExtractor
|
7 |
|
|
|
1 |
import numpy as np
|
|
|
2 |
import neurokit2 as nk
|
|
|
3 |
|
4 |
from src.feature_extractor import FeatureExtractor
|
5 |
|
src/feature_extractor.py
CHANGED
@@ -13,5 +13,5 @@ class FeatureExtractor(ABC):
|
|
13 |
pass
|
14 |
|
15 |
@abstractmethod
|
16 |
-
def normalize_features(self, features, baseline_features=None):
|
17 |
pass
|
|
|
13 |
pass
|
14 |
|
15 |
@abstractmethod
|
16 |
+
def normalize_features(self, features, baseline_features, normalization_method=None):
|
17 |
pass
|