OpenSound's picture
Upload 211 files
9d3cb0a verified
raw
history blame
2.77 kB
"""
Utility class for experiment tracking that also saves a snapshot of the
git-tracked code alongside the experiment's files.
""" # fmt: skip
import datetime
import os
import shlex
import shutil
import subprocess
import typing
from pathlib import Path
import randomname
class Experiment:
    """Context manager that runs code inside a dedicated experiment folder.

    Constructing an ``Experiment`` creates the folder
    ``exp_directory/exp_name`` (if it does not exist yet) and records the
    files tracked by git at ``HEAD``. Entering the context changes the
    working directory to that folder; exiting restores the working
    directory that was active on entry. :meth:`snapshot` copies the
    git-tracked files into the current working directory, so it is meant
    to be called from inside the context.

    Parameters
    ----------
    exp_directory : str
        Folder where all experiments are saved, by default "runs/".
    exp_name : str, optional
        Name of the experiment. By default (``None``) a name is generated by
        :meth:`generate_exp_name` from the current date and a random
        adjective-noun pair.

    Attributes
    ----------
    exp_dir : Path
        Experiment folder as constructed (relative paths are relative to
        ``parent_directory``).
    exp_name : str
        Name of the experiment.
    git_tracked_files : list of str
        Paths tracked by git at ``HEAD``, relative to the repository root.
    parent_directory : Path
        Absolute working directory at construction time.
    repo_root : Path
        Top-level directory of the git repository.

    Raises
    ------
    subprocess.CalledProcessError
        If a git command fails, e.g. because the working directory is not
        inside a git repository or the repository has no commits.
    FileNotFoundError
        If the ``git`` executable cannot be found.
    """

    def __init__(
        self,
        exp_directory: str = "runs/",
        exp_name: typing.Optional[str] = None,
    ):
        if exp_name is None:
            exp_name = self.generate_exp_name()
        exp_dir = Path(exp_directory) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        self.exp_dir = exp_dir
        self.exp_name = exp_name
        # `-z` makes git emit raw, NUL-separated paths. Without it git quotes
        # paths containing special characters (and, with the default
        # core.quotePath, non-ASCII bytes), so the quoted names would not
        # match any file on disk.
        self.git_tracked_files = self._split_nul_paths(
            subprocess.check_output(
                shlex.split("git ls-tree --full-tree --name-only -r -z HEAD")
            )
        )
        self.parent_directory = Path(".").absolute()
        # `--full-tree` reports paths relative to the repository root, which
        # differs from `parent_directory` when launched from a subdirectory,
        # so snapshots are copied from the root instead.
        self.repo_root = Path(
            os.fsdecode(
                subprocess.check_output(
                    shlex.split("git rev-parse --show-toplevel")
                )
            ).rstrip("\n")
        )

    @staticmethod
    def _split_nul_paths(output: bytes) -> typing.List[str]:
        """Decode NUL-separated path output (``git ... -z``) into a list.

        ``os.fsdecode`` is used so the resulting strings round-trip to the
        same file names on disk.
        """
        return [os.fsdecode(path) for path in output.split(b"\0") if path]

    def __enter__(self):
        self.prev_dir = os.getcwd()
        # Resolve a relative ``exp_dir`` against the directory it was created
        # from, so entering still works if the working directory changed
        # since construction (an absolute ``exp_dir`` is used unchanged).
        os.chdir(self.parent_directory / self.exp_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the directory active on entry; exceptions propagate.
        os.chdir(self.prev_dir)

    @staticmethod
    def generate_exp_name():
        """Generates a random experiment name based on the date
        and a randomly generated adjective-noun pair.

        Returns
        -------
        str
            Name of the form ``YYMMDD-<random name>``.
        """
        date = datetime.datetime.now().strftime("%y%m%d")
        name = f"{date}-{randomname.get_name()}"
        return name

    def snapshot(
        self, filter_fn: typing.Callable[[str], bool] = lambda f: True
    ):
        """Copy every git-tracked file into the current working directory.

        Each file is copied from the repository root to the same relative
        path under the current working directory, creating intermediate
        folders as needed. Call this inside the ``with`` block so the copies
        land in the experiment folder. Only the current working-tree
        contents are copied; no diff against the committed code is recorded.

        Parameters
        ----------
        filter_fn : typing.Callable, optional
            Predicate called with each tracked path (relative to the
            repository root); files for which it returns a falsy value are
            skipped. By default all files are accepted.
        """
        for rel_path in self.git_tracked_files:
            if not filter_fn(rel_path):
                continue
            source = self.repo_root / rel_path
            # Tracked entries that are not regular files in the working tree
            # (submodule directories, files deleted since HEAD, broken
            # symlinks) cannot be copied with copyfile; skip them.
            if not source.is_file():
                continue
            Path(rel_path).parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(source, rel_path)