ECG2HRV / pydantic_models.py
nina-m-m's picture
Upload source files
0ffeb19 verified
raw
history blame
7.63 kB
""" Pydantic models for use in the API. """
import json
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, Union
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, model_validator

from .configs import NormalizationMethodEnum, SignalEnum, WindowSlicingMethodEnum
# Load the example payloads that ECGBatch exposes in its JSON schema.
# The paths are relative to the current working directory. When the files are
# missing, both examples fall back to empty dicts so the module still imports.
try:
    # JSON is read as UTF-8 explicitly instead of relying on the locale default.
    with open('data/examples/example0_input.json', encoding='utf-8') as json_file:
        example0 = json.load(json_file)
    with open('data/examples/example1_input.json', encoding='utf-8') as json_file:
        example1 = json.load(json_file)
except FileNotFoundError:
    print(
        "Example Files for interface not found. Please run the Jupyter Notebook in notebooks/1_Data_Formatting_and_transformation.py first.")
    example0 = {}
    example1 = {}
class ECGSample(BaseModel):
    """ Model of the results of a single subject of an experiment with ECG biosignals. """
    sample_id: UUID = Field(example="f70c1033-36ae-4b8b-8b89-099a96dccca5", default_factory=uuid4)
    subject_id: str = Field(..., example="participant_1")
    frequency: int = Field(..., example=1000)
    # Optional field: the default is None, so the annotation has to allow None.
    device_name: str | None = Field(example="bioplux", default=None)
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY - MM - DD[T]HH: MM[:SS[.ffffff]][Z or [±]HH[:]MM]")
    timestamp_idx: list[datetime] = Field(..., min_length=2, example=[1679709871, 1679713471, 1679720671])
    ecg: list[float] = Field(..., min_length=2, example=[1.0, -1.100878, -3.996840])
    # Optional in the input; set_label_default fills it in for dict input.
    label: list[str] = Field(min_length=2, example=["undefined", "stress", "undefined"], default=None)

    class Config:
        json_schema_extra = {
            "example": {
                "sample_id": "f70c1033-36ae-4b8b-8b89-099a96dccca5",
                "subject_id": "participant_1",
                "frequency": 1000,
                "device_name": "bioplux",
                "timestamp_idx": [1679709871, 1679713471, 1679720671],
                "ecg": [1.0, -1.100878, -3.996840],
                "label": ["undefined", "stress", "undefined"]
            }
        }

    @model_validator(mode='before')
    @classmethod
    def set_label_default(cls, values: Any) -> Any:
        """
        Fill or pad the "label" list with 'undefined' entries.

        A missing/None label becomes a list of 'undefined' as long as the longer
        of "timestamp_idx" and "ecg"; a shorter label list is padded to that
        length. Only dict input is handled, and the input dict itself is never
        modified: a new dict is returned when a label is set.
        """
        if not isinstance(values, dict):
            return values
        try:
            max_len = max(len(values.get('timestamp_idx')), len(values.get('ecg')))
        except TypeError:
            # A required list is missing or has no length. Leave the input as it
            # is so that the regular field validation reports the problem
            # instead of this validator failing with KeyError/TypeError.
            return values
        label = values.get('label')
        if label is None:
            padded = ['undefined'] * max_len
        elif isinstance(label, (list, tuple)) and len(label) < max_len:
            padded = list(label) + ['undefined'] * (max_len - len(label))
        else:
            # Long enough already, or not a list at all (left to field validation).
            return values
        return {**values, 'label': padded}

    @model_validator(mode='after')
    def check_length(self) -> 'ECGSample':
        """
        Validates that the timestamp and ecg lists have the same length.

        Raises:
            ValueError: if len(timestamp_idx) != len(ecg).
        """
        # NOTE(review): "label" is not compared here, so a label list longer
        # than the signal passes validation - confirm whether that is intended.
        if len(self.timestamp_idx) != len(self.ecg):
            raise ValueError('Given timestamp and ecg list must have the same length!')
        return self
class ECGConfig(BaseModel):
    """ Model of the configuration of an experiment with ECG biosignals. """
    # Fields that default to None are annotated as optional so the annotation
    # matches the default.
    signal: SignalEnum | None = Field(example=SignalEnum.chest, default=None)
    window_slicing_method: WindowSlicingMethodEnum = Field(example=WindowSlicingMethodEnum.time_related,
                                                           default=WindowSlicingMethodEnum.time_related)
    window_size: float = Field(example=1.0, default=5.0)
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY - MM - DD[T]HH: MM[:SS[.ffffff]][Z or [±]HH[:]MM]")
    baseline_start: datetime | None = Field(example="2034-01-16T00:00:00", default=None)
    baseline_end: datetime | None = Field(example="2034-01-16T00:01:00", default=None)
    baseline_duration: int | None = Field(example=60, default=None)  # in seconds
    normalization_method: NormalizationMethodEnum | None = Field(
        example=NormalizationMethodEnum.baseline_difference,
        default=NormalizationMethodEnum.baseline_difference)
    # Collects every input key that is not a declared field (see build_extra).
    extra: Dict[str, Any] | None = Field(default=None)

    class Config:
        json_schema_extra = {
            "example": {
                "signal": "chest",
                "window_slicing_method": "time_related",
                "window_size": 60,
                "baseline_start": "2023-05-23 22:58:01.335",
                "baseline_duration": 60,
                "test": "test"
            }
        }

    @model_validator(mode='before')
    @classmethod
    def build_extra(cls, values: Any) -> Any:
        """
        Move all keys that are not declared model fields into the "extra" dict.

        Only dict input is handled; it is not modified, a new dict is returned.
        An incoming "extra" key is itself treated as unknown and ends up nested
        inside the collected "extra" dict.
        """
        if not isinstance(values, dict):
            return values
        # FieldInfo.alias is None unless an alias was declared, so fall back to
        # the attribute name; otherwise no input key would ever match and every
        # value would be swept into "extra".
        known_fields = {
            info.alias or name
            for name, info in cls.model_fields.items()
            if name != 'extra'
        }
        declared: Dict[str, Any] = {}
        extra: Dict[str, Any] = {}
        for key, value in values.items():
            target = declared if key in known_fields else extra
            target[key] = value
        declared['extra'] = extra
        return declared

    @model_validator(mode='after')
    def check_baseline_start(self) -> 'ECGConfig':
        """
        Validates the combination of the baseline parameters.

        If baseline_start is given, either baseline_end or baseline_duration
        must be given as well; a missing baseline_end is calculated as
        baseline_start + baseline_duration (seconds). If baseline_start is not
        given, neither baseline_duration nor baseline_end may be given.

        Raises:
            ValueError: if the baseline parameters are inconsistent.
        """
        if self.baseline_start is not None:
            if self.baseline_end is None:
                if self.baseline_duration is None:
                    raise ValueError(
                        'If baseline_start id given, either baseline_duration or baseline_end must be provided.')
                self.baseline_end = self.baseline_start + timedelta(seconds=self.baseline_duration)
        # Each value is tested against None on its own so that a duration of 0
        # still counts as "given".
        elif self.baseline_duration is not None or self.baseline_end is not None:
            raise ValueError(
                'If basleine_duration or baseline_end is given, baseline_start must be provided in order. Delete the '
                'baseline Parameters if the baseline is not needed.')
        return self

    # NOTE(review): __get_validators__ is the pydantic v1 custom-type hook, while
    # the rest of this class uses the v2 API - confirm it is still picked up.
    @classmethod
    def __get_validators__(cls):
        yield cls.validate_to_json

    @classmethod
    def validate_to_json(cls, value):
        """ Build the model from a JSON string or from any other input accepted by model_validate. """
        if isinstance(value, str):
            return cls.model_validate(json.loads(value))
        return cls.model_validate(value)
class ECGBatch(BaseModel):
    """ Input Model for Data Validation. The Input being the results of an experiment with ECG biosignals,
    including a batch of ecg data of different subjects. """
    supervisor: str = Field(..., example="Lieschen Mueller")
    # pydantic will process either an int or float (unix timestamp) (e.g. 1496498400),
    # an int or float as a string (assumed as Unix timestamp), or
    # a string representing the date (e.g. "YYYY-MM-DD")
    # The default is today's date in UTC. The factory returns a real `date`
    # because default values are not converted to the annotated type.
    record_date: date = Field(example="2034-01-16",
                              default_factory=lambda: datetime.now(timezone.utc).date())
    # The field examples are the example payloads themselves, not the
    # surrounding json_schema_extra dict with its "example" key.
    configs: ECGConfig = Field(..., example=ECGConfig.Config.json_schema_extra["example"])
    samples: list[ECGSample] = Field(..., min_length=1,
                                     example=[ECGSample.Config.json_schema_extra["example"],
                                              ECGSample.Config.json_schema_extra["example"]])

    class Config:
        json_schema_extra = {
            "example": example1,
            "examples": [
                example0,
                example1
            ]
        }