nina-m-m committed on
Commit
ec3f61b
1 Parent(s): 875bdf8

Implement inference pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +51 -11
pipeline.py CHANGED
@@ -1,18 +1,52 @@
 
1
  from typing import Dict, List, Union
2
- import os
 
 
 
 
3
 
4
  class PreTrainedPipeline():
5
  def __init__(self, path=""):
6
- # IMPLEMENT_THIS
7
  # Preload all the elements you are going to need at inference.
8
  # For instance your model, processors, tokenizer that might be needed.
9
- # This function is only called once, so do all the heavy processing I/O here"""
10
- raise NotImplementedError(
11
- "Please implement PreTrainedPipeline __init__ function"
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def __call__(
15
- self, inputs: Dict[str, Dict[str, List[Union[str, float]]]]
16
  ) -> List[Union[str, float]]:
17
  """
18
  Args:
@@ -22,7 +56,13 @@ class PreTrainedPipeline():
22
  Return:
23
  A :obj:`list` of floats or strings: The classification output for each row.
24
  """
25
- # IMPLEMENT_THIS
26
- raise NotImplementedError(
27
- "Please implement PreTrainedPipeline __call__ function"
28
- )
 
 
 
 
 
 
 
1
+ import pandas as pd
2
  from typing import Dict, List, Union
3
+
4
+ from src.conversion import csv_to_pandas
5
+ from src.ecg_processing import process_batch
6
+ from src.pydantic_models import ECGConfig, ECGSample
7
+
8
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
 
11
  # Preload all the elements you are going to need at inference.
12
  # For instance your model, processors, tokenizer that might be needed.
13
+ # This function is only called once, so do all the heavy processing I/O here
14
+ self.path = path
15
+ self.df = None # Placeholder for the DataFrame
16
+
17
+ if path:
18
+ self.load_data()
19
+
20
+ def load_data(self):
21
+ # Load CSV file into DataFrame
22
+ self.df = csv_to_pandas(self.path)
23
+
24
+ def process_data(self):
25
+ # Read csv file
26
+ df = self.df
27
+ # Implode
28
+ cols_to_implode = ['timestamp_idx', 'ecg', 'label']
29
+ df_imploded = df.groupby(list(set(df.columns) - set(cols_to_implode))) \
30
+ .agg({'timestamp_idx': list,
31
+ 'ecg': list,
32
+ 'label': list}) \
33
+ .reset_index()
34
+ # Get metadata
35
+ config_cols = [col for col in df.columns if col.startswith('configs.')]
36
+ configs = df_imploded[config_cols].iloc[0].to_dict()
37
+ configs = {key.removeprefix('configs.'): value for key, value in configs.items()}
38
+ configs = ECGConfig(**configs)
39
+ batch_cols = [col for col in df.columns if col.startswith('batch.')]
40
+ batch = df_imploded[batch_cols].iloc[0].to_dict()
41
+ batch = {key.removeprefix('batch.'): value for key, value in batch.items()}
42
+ # Get samples
43
+ samples = df_imploded.to_dict(orient='records')
44
+ samples = [ECGSample(**sample) for sample in samples]
45
+
46
+ features_df = process_batch(samples, configs)
47
 
48
  def __call__(
49
+ self, inputs: Dict[str, Dict[str, List[Union[str, float]]]]
50
  ) -> List[Union[str, float]]:
51
  """
52
  Args:
 
56
  Return:
57
  A :obj:`list` of floats or strings: The classification output for each row.
58
  """
59
+ if not self.df:
60
+ raise ValueError("No data loaded. Please provide a valid CSV path.")
61
+
62
+ # Implement your processing logic here, if needed
63
+ self.process_data()
64
+
65
+ # Assuming you want to return a list of strings or floats from the DataFrame
66
+ result = self.df.values.flatten().tolist()
67
+
68
+ return result