|
import torch |
|
from pytorch_lightning import Trainer |
|
from torch.utils.data import DataLoader |
|
from utils.datasets import ECGDataset |
|
from utils.models import EffNet |
|
from utils.training_models import RegressionModel |
|
|
|
|
|
|
|
|
|
# Placeholder locations of the ECG test data and its manifest; replace them
# before running. Kept at module level so existing references still resolve.
data_path = "your/ecg/data/folder"
manifest_path = 'your/manifest/path'


def main(
    data_path=data_path,
    manifest_path=manifest_path,
    weights_path="model_12_lead.pt",
    *,
    batch_size=256,
    num_workers=16,
    accelerator="gpu",
    devices=1,
):
    """Run the 12-lead EffNet regression model over the test split.

    Builds the test ``ECGDataset``/``DataLoader``, wraps an ``EffNet``
    backbone (12 input channels, 1 output neuron) in ``RegressionModel``,
    loads the checkpoint at ``weights_path`` and runs ``Trainer.predict``.

    Args:
        data_path: Folder holding the ECG data passed to ``ECGDataset``.
        manifest_path: Manifest file passed to ``ECGDataset``.
        weights_path: Path to a state dict saved with ``torch.save``.
        batch_size: Test batch size.
        num_workers: Number of DataLoader worker processes.
        accelerator: Lightning accelerator name (e.g. ``"gpu"``, ``"cpu"``).
        devices: Number of devices for the Lightning ``Trainer``.

    Returns:
        The value returned by ``Trainer.predict`` — presumably the per-batch
        outputs of ``RegressionModel.predict_step`` (TODO confirm format).

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default) when the
            checkpoint keys do not match the model.
    """
    test_ds = ECGDataset(
        split="test",
        data_path=data_path,
        manifest_path=manifest_path,
        update_manifest_func=None,
    )

    # Keep order and every sample so predictions line up with the manifest.
    test_dl = DataLoader(
        test_ds,
        num_workers=num_workers,
        batch_size=batch_size,
        drop_last=False,
        shuffle=False,
    )

    backbone = EffNet(input_channels=12, output_neurons=1)
    model = RegressionModel(backbone)

    # Load onto CPU so the checkpoint opens regardless of which device it was
    # saved from; Lightning moves the model to the accelerator for predict.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (or pass weights_only=True on torch versions that support it).
    weights = torch.load(weights_path, map_location="cpu")
    # load_state_dict is strict by default; printing its result reports the
    # (empty) missing/unexpected key lists as a sanity check.
    print(model.load_state_dict(weights))

    trainer = Trainer(accelerator=accelerator, devices=devices)

    # Previously the predictions were discarded; return them to the caller.
    return trainer.predict(model, dataloaders=test_dl)


# Guard is required: with num_workers > 0 and the "spawn" start method
# (Windows/macOS default), DataLoader workers re-import this module, and an
# unguarded script would re-run inference inside every worker.
if __name__ == "__main__":
    predictions = main()
|
|