File size: 2,772 Bytes
9d3cb0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Useful class for Experiment tracking, and ensuring code is
saved alongside files.
"""  # fmt: skip
import datetime
import os
import shlex
import shutil
import subprocess
import typing
from pathlib import Path

import randomname


class Experiment:
    """Utilities for managing experiment folders.

    ``Experiment`` is a context manager. Entering it changes the current
    working directory to the experiment folder (``exp_directory/exp_name``,
    created on construction); exiting it restores the previous working
    directory, whether or not an exception was raised. The experiment name is
    either given explicitly or generated from the current date plus a random
    adjective-noun name.

    On construction, the list of files tracked by git at ``HEAD`` and the
    repository root are recorded so that :meth:`snapshot` can later copy the
    code into the experiment folder. Construction therefore requires ``git``
    on the PATH and a repository with at least one commit; otherwise
    ``subprocess.CalledProcessError`` (or ``FileNotFoundError`` if git is
    missing) is raised and no experiment folder is created.

    Parameters
    ----------
    exp_directory : str
        Folder where all experiments are saved, by default "runs/".
    exp_name : str, optional
        Name of the experiment, by default generated by
        :meth:`generate_exp_name` (current date plus a random name).
    """

    def __init__(
        self,
        exp_directory: str = "runs/",
        exp_name: typing.Optional[str] = None,
    ):
        # Query git before touching the filesystem, so that a git failure does
        # not leave an empty experiment folder behind.
        # ``-z`` makes git emit raw NUL-separated paths; without it, paths with
        # non-ASCII or special characters come back C-quoted and can't be opened.
        self.git_tracked_files = self._parse_paths(
            self._git("ls-tree", "--full-tree", "--name-only", "-r", "-z", "HEAD")
        )
        # ``--full-tree`` paths are relative to the repository root, not to the
        # current directory, so remember the root to resolve them against.
        self.git_root = Path(
            self._git("rev-parse", "--show-toplevel").decode("utf-8").rstrip("\r\n")
        )

        if exp_name is None:
            exp_name = self.generate_exp_name()
        exp_dir = Path(exp_directory) / exp_name
        exp_dir.mkdir(parents=True, exist_ok=True)

        self.exp_dir = exp_dir
        self.exp_name = exp_name
        # Directory the experiment was created from (kept for compatibility;
        # snapshot() resolves tracked paths against ``git_root`` instead).
        self.parent_directory = Path(".").absolute()

    @staticmethod
    def _git(*args: str) -> bytes:
        """Run ``git`` with ``args`` and return its raw stdout.

        Raises ``subprocess.CalledProcessError`` on a non-zero exit status.
        """
        return subprocess.check_output(["git", *args])

    @staticmethod
    def _parse_paths(raw: bytes) -> typing.List[str]:
        """Split NUL-separated git output (``-z``) into a list of paths."""
        return [p for p in raw.decode("utf-8").split("\0") if p]

    def __enter__(self):
        # Remember where we came from so __exit__ can restore it.
        self.prev_dir = os.getcwd()
        os.chdir(self.exp_dir)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception from the with-body propagate.
        os.chdir(self.prev_dir)

    @staticmethod
    def generate_exp_name():
        """Generates a random experiment name based on the date
        and a randomly generated adjective-noun name.

        Returns
        -------
        str
            Name of the form ``YYMMDD-<random name>``.
        """
        date = datetime.datetime.now().strftime("%y%m%d")
        name = f"{date}-{randomname.get_name()}"
        return name

    def snapshot(self, filter_fn: typing.Callable[[str], bool] = lambda f: True):
        """Copy every file tracked by git at ``HEAD`` into the current directory.

        Files are copied from the working tree, so uncommitted modifications
        to tracked files are included. Each file keeps its path relative to
        the repository root. The destination is the current working directory,
        so call this inside the ``with`` block (where the working directory is
        the experiment folder); calling it from the repository root would try
        to copy each file onto itself and raise ``shutil.SameFileError``.

        Tracked entries that are not regular files in the working tree (files
        deleted since the last commit, submodule directories) are skipped.

        Parameters
        ----------
        filter_fn : typing.Callable, optional
            Predicate receiving each repository-relative path (str); return
            False to exclude that file. By default accepts all files.
        """
        for f in self.git_tracked_files:
            if not filter_fn(f):
                continue
            src = self.git_root / f
            # Deleted-but-uncommitted files and submodules (listed as a single
            # path that is a directory) cannot be copied with copyfile.
            if not src.is_file():
                continue
            dst = Path(f)
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(src, dst)