from transformers import Pipeline
from src.deprecated.conversion import csv_to_pandas
from src.deprecated.pydantic_models import ECGConfig, ECGSample
from src.deprecated.ecg_processing import process_batch


def _first_row_fields(frame, columns, prefix):
    """Return the first row's values for columns starting with *prefix*.

    Args:
        frame: DataFrame to read the first row from.
        columns: Iterable of column names to filter by *prefix*.
        prefix: Column-name prefix to select on; it is stripped from the
            keys of the returned dict.

    Returns:
        A dict mapping the un-prefixed column name to the first row's value.
    """
    selected = [col for col in columns if col.startswith(prefix)]
    first_row = frame[selected].iloc[0].to_dict()
    return {key.removeprefix(prefix): value for key, value in first_row.items()}


class MyPipeline(Pipeline):
    """Pipeline that turns a CSV input into ECG samples and runs `process_batch`."""

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into per-stage keyword dicts.

        Only ``maybe_arg`` is forwarded, and only to ``preprocess``.

        Returns:
            A 3-tuple of dicts: kwargs for ``preprocess``, ``_forward`` and
            ``postprocess`` respectively (the last two are always empty).
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs: str) -> dict:
        """Load the input with `csv_to_pandas` and build the model input.

        Rows are grouped on every column except ``timestamp_idx``, ``ecg``
        and ``label``; those three columns are collected into one list per
        group. ``configs.*`` and ``batch.*`` metadata are read from the first
        grouped row only.

        Args:
            inputs: Passed straight to `csv_to_pandas` (presumably a CSV
                file path -- confirm against that helper).

        Returns:
            ``{"model_input": {"samples": ..., "configs": ..., "batch": ...}}``
            where ``samples`` is a list of `ECGSample`, ``configs`` an
            `ECGConfig` and ``batch`` a plain dict.
        """
        # inputs are csv files
        df = csv_to_pandas(inputs)

        # Implode
        cols_to_implode = ['timestamp_idx', 'ecg', 'label']
        # Derive the group keys in the frame's own column order. Going
        # through a set made the key order (and therefore the sort order of
        # the groups and the row picked by `.iloc[0]` below) vary between
        # interpreter runs because of string hash randomization.
        group_cols = [col for col in df.columns if col not in cols_to_implode]
        # NOTE(review): groupby drops rows whose key columns contain NaN by
        # default -- confirm that is acceptable for this data.
        df_imploded = (
            df.groupby(group_cols)
            .agg({'timestamp_idx': list, 'ecg': list, 'label': list})
            .reset_index()
        )

        # Get metadata (taken from the first grouped row only)
        configs = ECGConfig(**_first_row_fields(df_imploded, df.columns, 'configs.'))
        batch = _first_row_fields(df_imploded, df.columns, 'batch.')

        # Get samples
        records = df_imploded.to_dict(orient='records')
        samples = [ECGSample(**record) for record in records]

        model_input = {"samples": samples, "configs": configs, "batch": batch}
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run `process_batch` on the samples and configs from ``preprocess``.

        Args:
            model_inputs: ``{"model_input": {...}}`` as returned by
                ``preprocess``. The ``"batch"`` entry is not used here.

        Returns:
            Whatever `process_batch(samples, configs)` returns.
        """
        # model_inputs == {"model_input": model_input}
        model_input = model_inputs["model_input"]
        samples = model_input["samples"]
        configs = model_input["configs"]
        features_df = process_batch(samples, configs)
        return features_df

    def postprocess(self, model_outputs):
        """Return the output of ``_forward`` unchanged."""
        return model_outputs