import json

import h5py
import numpy as np
from omegaconf import OmegaConf


def load_eval(dir):
    """Load the cached summaries and per-datapoint results of an eval run."""
    summaries, results = {}, {}
    with h5py.File(str(dir / "results.h5"), "r") as hfile:
        for k in hfile.keys():
            r = np.array(hfile[k])
            if len(r.shape) < 3:
                results[k] = r
        for k, v in hfile.attrs.items():
            summaries[k] = v
    # summaries.json is the authoritative source: it overrides the HDF5
    # attributes loaded above and maps null back to NaN.
    with open(dir / "summaries.json", "r") as f:
        s = json.load(f)
    summaries = {k: v if v is not None else np.nan for k, v in s.items()}
    return summaries, results


def save_eval(dir, summaries, figures, results):
    """Write per-datapoint results (HDF5), summaries (JSON), and figures (PNG)."""
    with h5py.File(str(dir / "results.h5"), "w") as hfile:
        for k, v in results.items():
            arr = np.array(v)
            if not np.issubdtype(arr.dtype, np.number):
                arr = arr.astype("object")
            hfile.create_dataset(k, data=arr)
        # just to be safe, not used in practice
        for k, v in summaries.items():
            hfile.attrs[k] = v
    # JSON cannot represent NaN/inf, so non-finite scalars are stored as null;
    # lists are passed through unchanged.
    s = {
        k: float(v) if np.isfinite(v) else None
        for k, v in summaries.items()
        if not isinstance(v, list)
    }
    s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}}
    with open(dir / "summaries.json", "w") as f:
        json.dump(s, f, indent=4)
    for fig_name, fig in figures.items():
        fig.savefig(dir / f"{fig_name}.png")


def exists_eval(dir):
    return (dir / "results.h5").exists() and (dir / "summaries.json").exists()
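

# On-disk layout written by save_eval() and read back by load_eval()
# (a sketch of the convention above, not an exhaustive spec):
#
#   <experiment_dir>/
#       results.h5       per-datapoint arrays as datasets, summaries as attrs
#       summaries.json   summary metrics; non-finite values stored as null
#       <fig_name>.png   one image per entry in `figures`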


class EvalPipeline:
    default_conf = {}

    export_keys = []
    optional_export_keys = []

    def __init__(self, conf):
"""Assumes""" | |
        self.default_conf = OmegaConf.create(self.default_conf)
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self._init(self.conf)

    def _init(self, conf):
        pass

    def get_dataloader(self, data_conf=None):
        """Returns a data loader with samples for each eval datapoint."""
        raise NotImplementedError

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint."""
        raise NotImplementedError

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions."""
        raise NotImplementedError

    def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
        """Run the export+eval loop."""
        self.save_conf(
            experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval
        )
        pred_file = self.get_predictions(
            experiment_dir, model=model, overwrite=overwrite
        )
        # Figures are only produced when the eval actually runs; on a pure
        # cache hit, `f` stays empty while summaries and results are reloaded.
        f = {}
        if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
            s, f, r = self.run_eval(self.get_dataloader(), pred_file)
            save_eval(experiment_dir, s, f, r)
        s, r = load_eval(experiment_dir)
        return s, f, r

    def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
        # store config
        conf_output_path = experiment_dir / "conf.yaml"
        if conf_output_path.exists():
            saved_conf = OmegaConf.load(conf_output_path)
            if (saved_conf.data != self.conf.data) or (
                saved_conf.model != self.conf.model
            ):
                assert (
                    overwrite
                ), "configs changed, add --overwrite to rerun experiment with new conf"
            if saved_conf.eval != self.conf.eval:
                assert (
                    overwrite or overwrite_eval
                ), "eval configs changed, add --overwrite_eval to rerun evaluation"
        OmegaConf.save(self.conf, conf_output_path)
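

# A minimal usage sketch (hypothetical, not part of the pipeline above): a toy
# subclass with stand-in data, predictions, and metrics, exercising the full
# run() -> save_eval() -> load_eval() round trip in a temporary directory.
# `ToyEval`, `predictions.h5`, and the metric names are illustrative only.
if __name__ == "__main__":
    import tempfile
    from pathlib import Path

    class ToyEval(EvalPipeline):
        # save_conf() expects `data`, `model`, and `eval` keys in the config.
        default_conf = {"data": {}, "model": {}, "eval": {}}

        def get_dataloader(self, data_conf=None):
            return [0, 1, 2]  # stand-in for a real DataLoader

        def get_predictions(self, experiment_dir, model=None, overwrite=False):
            return experiment_dir / "predictions.h5"  # hypothetical cache path

        def run_eval(self, loader, pred_file):
            summaries = {"mean_error": 0.2}
            figures = {}  # a real pipeline would return matplotlib figures
            results = {"per_item_error": [0.1, 0.2, 0.3]}
            return summaries, figures, results

    with tempfile.TemporaryDirectory() as tmp:
        exp_dir = Path(tmp)
        s, f, r = ToyEval({}).run(exp_dir)
        print(s)  # {'mean_error': 0.2}
        print({k: v.shape for k, v in r.items()})  # {'per_item_error': (3,)}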