File size: 3,725 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json

import h5py
import numpy as np
from omegaconf import OmegaConf


def load_eval(dir):
    """Load a cached evaluation from ``dir``.

    Args:
        dir: directory holding ``results.h5`` and ``summaries.json``; must
            support the ``/`` operator (e.g. a ``pathlib.Path``).

    Returns:
        A ``(summaries, results)`` tuple. ``summaries`` is the content of
        ``summaries.json`` with JSON ``null`` values replaced by ``np.nan``.
        ``results`` maps each entry of ``results.h5`` with fewer than three
        dimensions to a numpy array; higher-dimensional entries are skipped.
    """
    results = {}
    with h5py.File(str(dir / "results.h5"), "r") as hfile:
        for k in hfile.keys():
            r = np.array(hfile[k])
            if r.ndim < 3:
                results[k] = r
    # summaries.json is the only source for the summaries. The attrs stored in
    # results.h5 were previously read here and then immediately overwritten,
    # so that dead read has been dropped.
    with open(dir / "summaries.json", "r") as f:
        s = json.load(f)
    summaries = {k: v if v is not None else np.nan for k, v in s.items()}
    return summaries, results


def save_eval(dir, summaries, figures, results):
    """Write the outputs of an evaluation into ``dir``.

    Three kinds of files are produced:

    * ``results.h5``: one dataset per entry of ``results`` (non-numeric arrays
      are cast to ``object`` dtype first) plus every summary as a file attr.
    * ``summaries.json``: non-list summaries as floats (``null`` when not
      finite), followed by the list-valued summaries unchanged.
    * ``<fig_name>.png``: one image per entry of ``figures``, written through
      the figure's ``savefig`` method.

    Args:
        dir: output directory; must support the ``/`` operator.
        summaries: mapping of summary name to a number or a list.
        figures: mapping of figure name to an object exposing ``savefig``.
        results: mapping of result name to array-like data.
    """
    h5_path = str(dir / "results.h5")
    with h5py.File(h5_path, "w") as hfile:
        for name, values in results.items():
            data = np.array(values)
            is_numeric = np.issubdtype(data.dtype, np.number)
            hfile.create_dataset(
                name, data=data if is_numeric else data.astype("object")
            )
        # Redundant copy of the summaries, kept only as a safety net.
        for name, value in summaries.items():
            hfile.attrs[name] = value

    # Split in one pass; scalars come first in the JSON, lists after them.
    scalars, sequences = {}, {}
    for name, value in summaries.items():
        if isinstance(value, list):
            sequences[name] = value
        elif np.isfinite(value):
            scalars[name] = float(value)
        else:
            # NaN / inf are not valid JSON numbers, store null instead.
            scalars[name] = None
    payload = {**scalars, **sequences}

    with open(dir / "summaries.json", "w") as f:
        json.dump(payload, f, indent=4)

    for fig_name, fig in figures.items():
        target = dir / f"{fig_name}.png"
        fig.savefig(target)


def exists_eval(dir):
    """Return True when ``dir`` holds both ``results.h5`` and ``summaries.json``."""
    required = ("results.h5", "summaries.json")
    return all((dir / name).exists() for name in required)


class EvalPipeline:
    """Base class for an export-then-evaluate pipeline.

    Subclasses implement ``get_dataloader``, ``get_predictions`` and
    ``run_eval`` (and optionally ``_init``); ``run`` chains them and caches
    the evaluation output inside the experiment directory.
    """

    # Default configuration, merged with the user configuration in __init__.
    default_conf = {}

    # NOTE(review): not read anywhere in this class; presumably the names of
    # the prediction entries subclasses export -- confirm against subclasses.
    export_keys = []
    optional_export_keys = []

    def __init__(self, conf):
        """Merge ``conf`` over ``default_conf`` and call the ``_init`` hook.

        Args:
            conf: configuration accepted by ``OmegaConf.merge``; its values
                take precedence over ``default_conf``.
        """
        self.default_conf = OmegaConf.create(self.default_conf)
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self._init(self.conf)

    def _init(self, conf):
        """Hook for subclass-specific setup; does nothing by default."""
        pass

    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Returns a data loader with samples for each eval datapoint"""
        raise NotImplementedError

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint"""
        raise NotImplementedError

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions"""
        raise NotImplementedError

    def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
        """Run export+eval loop

        Args:
            experiment_dir: directory (supporting the ``/`` operator) where
                the configuration and the evaluation output are stored.
            model: forwarded to ``get_predictions``.
            overwrite: forwarded to ``get_predictions``; also forces the
                evaluation to be recomputed.
            overwrite_eval: forces the evaluation to be recomputed.

        Returns:
            A ``(summaries, figures, results)`` tuple. Summaries and results
            are always re-read from ``experiment_dir``; ``figures`` is an
            empty dict when a cached evaluation is reused.
        """
        self.save_conf(
            experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval
        )
        pred_file = self.get_predictions(
            experiment_dir, model=model, overwrite=overwrite
        )

        f = {}
        if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
            s, f, r = self.run_eval(self.get_dataloader(), pred_file)
            save_eval(experiment_dir, s, f, r)
        # Always reload so the returned values match what is stored on disk.
        s, r = load_eval(experiment_dir)
        return s, f, r

    def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
        """Store the configuration as ``conf.yaml`` in ``experiment_dir``.

        When a ``conf.yaml`` already exists it is compared with the current
        configuration before being replaced.

        Raises:
            AssertionError: if the ``data`` or ``model`` section changed and
                ``overwrite`` is false, or if the ``eval`` section changed and
                neither ``overwrite`` nor ``overwrite_eval`` is set.
        """
        conf_output_path = experiment_dir / "conf.yaml"
        if conf_output_path.exists():
            saved_conf = OmegaConf.load(conf_output_path)
            # Raised explicitly instead of using `assert`, which is removed
            # under `python -O` and would let a changed config through.
            if (saved_conf.data != self.conf.data) or (
                saved_conf.model != self.conf.model
            ):
                if not overwrite:
                    raise AssertionError(
                        "configs changed, add --overwrite to rerun experiment with new conf"
                    )
            if saved_conf.eval != self.conf.eval:
                if not (overwrite or overwrite_eval):
                    raise AssertionError(
                        "eval configs changed, add --overwrite_eval to rerun evaluation"
                    )
        OmegaConf.save(self.conf, conf_output_path)