|
from transformers import Pipeline |
|
|
|
from src.conversion import csv_to_pandas |
|
from src.pydantic_models import ECGConfig, ECGSample |
|
from src.ecg_processing import process_batch |
|
|
|
|
|
class MyPipeline(Pipeline):
    """Custom pipeline that loads a CSV, builds ECG samples and runs ``process_batch``.

    The stages follow the ``Pipeline`` contract: ``preprocess`` parses the
    input into validated objects, ``_forward`` delegates to ``process_batch``
    and ``postprocess`` returns that result unchanged.
    """

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``maybe_arg`` is recognised; it is routed to ``preprocess``.

        Returns:
            A 3-tuple of dicts. The first holds the ``preprocess`` keyword
            arguments; the other two are always empty.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    @staticmethod
    def _first_row_fields(frame, columns, prefix: str) -> dict:
        """Return the first row of ``frame[columns]`` as a dict, stripping ``prefix`` from keys."""
        first_row = frame[columns].iloc[0].to_dict()
        return {key.removeprefix(prefix): value for key, value in first_row.items()}

    def preprocess(self, inputs: str, maybe_arg=None) -> dict:
        """Load ``inputs`` via ``csv_to_pandas`` and assemble the model input.

        Rows are grouped on every column except ``timestamp_idx``, ``ecg`` and
        ``label``; those three are collected into one list per group. Each
        grouped row becomes one ``ECGSample``. The ``configs.*`` and
        ``batch.*`` values are read from the first grouped row only.

        Args:
            inputs: Value handed to ``csv_to_pandas`` (expected to yield a
                pandas DataFrame -- confirm against that helper).
            maybe_arg: Accepted because ``_sanitize_parameters`` forwards it
                here; without this parameter, passing ``maybe_arg`` to the
                pipeline raised ``TypeError``. Currently unused.

        Returns:
            ``{"model_input": {"samples": [...], "configs": ECGConfig, "batch": dict}}``.

        Raises:
            IndexError: If grouping yields no rows, since ``.iloc[0]`` is
                applied to the grouped frame.
        """
        df = csv_to_pandas(inputs)

        cols_to_implode = ['timestamp_idx', 'ecg', 'label']
        # Keep the original column order for the group keys. The previous set
        # difference iterated in hash order, which differs between interpreter
        # runs and therefore changed the sort order of the grouped rows (and
        # which row ``.iloc[0]`` selected below).
        group_cols = [col for col in df.columns if col not in cols_to_implode]
        # NOTE(review): groupby drops rows whose key columns contain NaN by
        # default (dropna=True) -- confirm the metadata columns are always set.
        df_imploded = (
            df.groupby(group_cols)
            .agg({col: list for col in cols_to_implode})
            .reset_index()
        )

        config_cols = [col for col in df.columns if col.startswith('configs.')]
        configs = ECGConfig(**self._first_row_fields(df_imploded, config_cols, 'configs.'))

        batch_cols = [col for col in df.columns if col.startswith('batch.')]
        batch = self._first_row_fields(df_imploded, batch_cols, 'batch.')

        # One ECGSample per grouped row; each record still carries every column.
        samples = [
            ECGSample(**record)
            for record in df_imploded.to_dict(orient='records')
        ]

        model_input = {"samples": samples, "configs": configs, "batch": batch}
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run ``process_batch`` on the samples and configs from ``preprocess``.

        The ``batch`` entry of ``model_input`` is not consumed here.

        Returns:
            Whatever ``process_batch(samples, configs)`` returns.
        """
        model_input = model_inputs["model_input"]
        return process_batch(model_input["samples"], model_input["configs"])

    def postprocess(self, model_outputs):
        """Return the forward output unchanged."""
        return model_outputs
|
|