nina-m-m
committed on
Commit
•
ec3f61b
1
Parent(s):
875bdf8
Implement inference pipeline
Browse files- pipeline.py +51 -11
pipeline.py
CHANGED
@@ -1,18 +1,52 @@
|
|
|
|
1 |
from typing import Dict, List, Union
|
2 |
-
|
|
|
|
|
|
|
|
|
3 |
|
4 |
class PreTrainedPipeline():
|
5 |
def __init__(self, path=""):
|
6 |
-
# IMPLEMENT_THIS
|
7 |
# Preload all the elements you are going to need at inference.
|
8 |
# For instance your model, processors, tokenizer that might be needed.
|
9 |
-
# This function is only called once, so do all the heavy processing I/O here
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def __call__(
|
15 |
-
|
16 |
) -> List[Union[str, float]]:
|
17 |
"""
|
18 |
Args:
|
@@ -22,7 +56,13 @@ class PreTrainedPipeline():
|
|
22 |
Return:
|
23 |
A :obj:`list` of floats or strings: The classification output for each row.
|
24 |
"""
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
from typing import Dict, List, Union
|
3 |
+
|
4 |
+
from src.conversion import csv_to_pandas
|
5 |
+
from src.ecg_processing import process_batch
|
6 |
+
from src.pydantic_models import ECGConfig, ECGSample
|
7 |
+
|
8 |
|
9 |
class PreTrainedPipeline():
    """Inference pipeline that loads an ECG CSV and runs batch processing on it.

    The table is read once at construction time (when a path is given) and
    kept on ``self.df``; every call re-runs :meth:`process_data` on it.
    """

    def __init__(self, path=""):
        """Remember ``path`` and, when it is non-empty, load the CSV behind it.

        Args:
            path: Location handed to ``csv_to_pandas``. A falsy value skips
                loading and leaves ``self.df`` as ``None``.
        """
        # Preload all the elements you are going to need at inference.
        # This function is only called once, so do all the heavy processing I/O here
        self.path = path
        self.df = None  # Filled in by load_data()

        if path:
            self.load_data()

    def load_data(self):
        """Load the CSV at ``self.path`` into ``self.df`` via ``csv_to_pandas``."""
        self.df = csv_to_pandas(self.path)

    def process_data(self):
        """Collapse the loaded table into samples and run ``process_batch``.

        Rows are grouped by every column except ``timestamp_idx``, ``ecg`` and
        ``label``; those three are aggregated into lists per group. The
        ``configs.*`` columns of the first group are turned into an
        ``ECGConfig`` and every group into an ``ECGSample``.

        Returns:
            Whatever ``process_batch(samples, configs)`` returns
            (presumably a feature table -- confirm against ``src.ecg_processing``).
        """
        df = self.df

        # Implode: one row per group, with the per-timestamp columns as lists.
        cols_to_implode = ['timestamp_idx', 'ecg', 'label']
        # Keep the DataFrame's own column order for the group keys; a set
        # difference would make the key order (and thus which group ends up
        # first) vary from run to run.
        group_cols = [col for col in df.columns if col not in cols_to_implode]
        df_imploded = (
            df.groupby(group_cols)
            .agg({'timestamp_idx': list, 'ecg': list, 'label': list})
            .reset_index()
        )

        # Metadata: configuration values are read from the first group only.
        config_cols = [col for col in df.columns if col.startswith('configs.')]
        raw_configs = df_imploded[config_cols].iloc[0].to_dict()
        configs = ECGConfig(
            **{key.removeprefix('configs.'): value for key, value in raw_configs.items()}
        )

        # Samples: one ECGSample per imploded row.
        samples = [ECGSample(**record) for record in df_imploded.to_dict(orient='records')]

        return process_batch(samples, configs)

    def __call__(
        self, inputs: Dict[str, Dict[str, List[Union[str, float]]]]
    ) -> List[Union[str, float]]:
        """
        Args:
            inputs (:obj:`dict`):
                Request payload. It is currently not read; the pipeline works
                on the CSV loaded at construction time.
        Return:
            A :obj:`list` of floats or strings: every cell of the loaded
            DataFrame, flattened row by row.
        Raises:
            ValueError: If no CSV has been loaded.
        """
        # Compare against None explicitly: `not df` on a DataFrame raises
        # "The truth value of a DataFrame is ambiguous".
        if self.df is None:
            raise ValueError("No data loaded. Please provide a valid CSV path.")

        # Run the processing step; its output is not part of the response yet.
        self.process_data()

        # The response is the raw loaded table, flattened row-major.
        result = self.df.values.flatten().tolist()

        return result
|