File size: 11,104 Bytes
9d3cb0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import inspect
import shutil
import tempfile
import typing
from pathlib import Path

import torch
from torch import nn


class BaseModel(nn.Module):
    """This is a class that adds useful save/load functionality to a
    ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
    as ``torch.package`` easily, making them super easy to port between
    machines without requiring a ton of dependencies. Files can also be
    saved as just weights, in the standard way.

    In the example below, ``x`` and ``seed_and_run`` are not defined by this
    class; they stand for an input and a helper supplied by the caller.

    >>> class Model(ml.BaseModel):
    >>>     def __init__(self, arg1: float = 1.0):
    >>>         super().__init__()
    >>>         self.arg1 = arg1
    >>>         self.linear = nn.Linear(1, 1)
    >>>
    >>>     def forward(self, x):
    >>>         return self.linear(x)
    >>>
    >>> model1 = Model()
    >>> out1 = seed_and_run(model1, x)
    >>>
    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    >>>     model1.save(
    >>>         f.name,
    >>>     )
    >>>     model2 = Model.load(f.name)
    >>>     out2 = seed_and_run(model2, x)
    >>>     assert torch.allclose(out1, out2)
    >>>
    >>>     model1.save(f.name, package=True)
    >>>     model2 = Model.load(f.name)
    >>>     model2.save(f.name, package=False)
    >>>     model3 = Model.load(f.name)
    >>>     out3 = seed_and_run(model3, x)
    >>>
    >>> with tempfile.TemporaryDirectory() as d:
    >>>     model1.save_to_folder(d, {"data": 1.0})
    >>>     Model.load_from_folder(d)

    """

    EXTERN = [
        "audiotools.**",
        "tqdm",
        "__main__",
        "numpy.**",
        "julius.**",
        "torchaudio.**",
        "scipy.**",
        "einops",
    ]
    """Names of libraries that are external to the torch.package saving mechanism.
    Source code from these libraries will not be packaged into the model. This can
    be edited by the user of this class by editing ``model.EXTERN``."""
    INTERN = []
    """Names of libraries that are internal to the torch.package saving mechanism.
    Source code from these libraries will be saved alongside the model."""

    def save(
        self,
        path: str,
        metadata: typing.Optional[dict] = None,
        package: bool = True,
        intern: typing.Optional[list] = None,
        extern: typing.Optional[list] = None,
        mock: typing.Optional[list] = None,
    ):
        """Saves the model, either as a torch package, or just as
        weights, alongside some specified metadata.

        The constructor keyword arguments are recorded under
        ``metadata["kwargs"]``: every constructor parameter that has a
        default is included, using the value of the same-named attribute
        on ``self`` when one exists and the default otherwise.
        ``self.metadata`` is created if missing and updated with the
        resulting metadata. The ``metadata`` dict passed in by the caller
        is copied and never modified.

        Parameters
        ----------
        path : str
            Path to save model to.
        metadata : dict, optional
            Any metadata to save alongside the model,
            by default None
        package : bool, optional
            Whether to use ``torch.package`` to save the model in
            a format that is portable, by default True
        intern : list, optional
            List of additional libraries that are internal
            to the model, used with torch.package, by default None
            (no additional libraries)
        extern : list, optional
            List of additional libraries that are external to
            the model, used with torch.package, by default None
            (no additional libraries)
        mock : list, optional
            List of libraries to mock, used with torch.package,
            by default None (nothing is mocked)

        Returns
        -------
        str
            Path to saved model (``path``, returned unchanged).
        """
        sig = inspect.signature(self.__class__)
        args = {
            key: param.default
            for key, param in sig.parameters.items()
            if param.default is not inspect.Parameter.empty
        }

        # Look up attributes on self, and if any of them share a name
        # with a constructor argument, overwrite the default in args.
        missing = object()
        for key in list(args):
            current = getattr(self, key, missing)
            if current is not missing:
                args[key] = current

        # Copy so that the caller's dictionary is never mutated.
        metadata = {} if metadata is None else dict(metadata)
        metadata["kwargs"] = args
        if not hasattr(self, "metadata"):
            self.metadata = {}
        self.metadata.update(metadata)

        if not package:
            state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
            torch.save(state_dict, path)
        else:
            self._save_package(path, intern=intern, extern=extern, mock=mock)

        return path

    @property
    def device(self):
        """Gets the device the model is on by looking at the device of
        the first parameter. May not be valid if model is split across
        multiple devices.

        Raises
        ------
        IndexError
            If the model has no parameters.
        """
        # Only the first parameter is needed, so don't materialize them all.
        first_param = next(self.parameters(), None)
        if first_param is None:
            raise IndexError(
                f"{type(self).__name__} has no parameters; cannot infer device."
            )
        return first_param.device

    @classmethod
    def load(
        cls,
        location: str,
        *args,
        package_name: typing.Optional[str] = None,
        strict: bool = False,
        **kwargs,
    ):
        """Load model from a path. Tries first to load as a package, and if
        that fails, tries to load as weights. The arguments to the class are
        specified inside the model weights file.

        Parameters
        ----------
        location : str
            Path to file.
        *args
            Positional arguments passed to the model constructor, if
            not loading from package.
        package_name : str, optional
            Name of package, by default ``cls.__name__``.
        strict : bool, optional
            Passed to ``load_state_dict`` when loading from weights. When
            False, unmatched keys are ignored, by default False
        kwargs : dict
            Additional keyword arguments to the model instantiation, if
            not loading from package. These override the keyword arguments
            stored in the file; keys that are not parameters of the class
            constructor are dropped.

        Returns
        -------
        BaseModel
            A model that inherits from BaseModel.
        """
        try:
            model = cls._load_package(location, package_name=package_name)
        except Exception:
            # Not a loadable package: fall back to a plain weights file.
            # NOTE(security): torch.load may unpickle arbitrary objects;
            # only load files from trusted sources.
            model_dict = torch.load(location, "cpu")
            metadata = model_dict["metadata"]
            metadata["kwargs"].update(kwargs)

            # Keep only the keyword arguments the constructor accepts.
            class_keys = set(inspect.signature(cls).parameters.keys())
            metadata["kwargs"] = {
                k: v for k, v in metadata["kwargs"].items() if k in class_keys
            }

            model = cls(*args, **metadata["kwargs"])
            model.load_state_dict(model_dict["state_dict"], strict=strict)
            model.metadata = metadata

        return model

    def _save_package(self, path, intern=None, extern=None, mock=None, **kwargs):
        """Export the model to ``path`` with ``torch.package``.

        ``intern`` and ``extern`` are appended to ``self.INTERN`` and
        ``self.EXTERN``; ``mock`` lists modules to mock. Any remaining
        ``kwargs`` are forwarded to ``torch.package.PackageExporter``.
        Returns ``path``.
        """
        intern = [] if intern is None else list(intern)
        extern = [] if extern is None else list(extern)
        mock = [] if mock is None else list(mock)

        package_name = type(self).__name__
        resource_name = f"{package_name}.pth"

        # Below is for loading and re-saving a package.
        if hasattr(self, "importer"):
            kwargs["importer"] = (self.importer, torch.package.sys_importer)
            del self.importer

        try:
            # Why do we use a tempfile, you ask?
            # It's so we can load a packaged model and then re-save
            # it to the same location. torch.package throws an
            # error if it's loading and writing to the same
            # file (this is undocumented).
            with tempfile.NamedTemporaryFile(suffix=".pth") as f:
                with torch.package.PackageExporter(f.name, **kwargs) as exp:
                    exp.intern(self.INTERN + intern)
                    exp.mock(mock)
                    exp.extern(self.EXTERN + extern)
                    exp.save_pickle(package_name, resource_name, self)

                    if hasattr(self, "metadata"):
                        exp.save_pickle(
                            package_name, f"{package_name}.metadata", self.metadata
                        )

                shutil.copyfile(f.name, path)
        finally:
            # Must reset the importer back to `self` if it existed
            # so that you can save the model again! Doing this in a
            # ``finally`` keeps the importer even when the export fails.
            if "importer" in kwargs:
                self.importer = kwargs["importer"][0]
        return path

    @classmethod
    def _load_package(cls, path, package_name=None):
        """Load a model saved with ``torch.package`` from ``path``.

        ``package_name`` defaults to ``cls.__name__``. The model is loaded
        onto the CPU, its metadata is attached when present in the package,
        and the importer is stored on ``model.importer`` so that the model
        can be re-saved as a package later.
        """
        package_name = cls.__name__ if package_name is None else package_name
        resource_name = f"{package_name}.pth"

        imp = torch.package.PackageImporter(path)
        model = imp.load_pickle(package_name, resource_name, "cpu")
        try:
            model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
        except Exception:  # pragma: no cover
            # Metadata is optional: packages saved without it still load.
            pass
        model.importer = imp

        return model

    def save_to_folder(
        self,
        folder: typing.Union[str, Path],
        extra_data: typing.Optional[dict] = None,
        package: bool = True,
    ):
        """Dumps a model into a folder, as weights and (optionally) as a
        package, as well as anything specified in
        ``extra_data``. ``extra_data`` is a dictionary of other
        pickleable files, with the keys being the paths
        to save them in. The model is saved under a subfolder
        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
        if the model name was ``Generator``).

        >>> with tempfile.TemporaryDirectory() as d:
        >>>     extra_data = {
        >>>         "optimizer.pth": optimizer.state_dict()
        >>>     }
        >>>     model.save_to_folder(d, extra_data)
        >>>     Model.load_from_folder(d)

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder under which the lower-cased class-name subfolder is
            created (including missing parents).
        extra_data : dict, optional
            Mapping from file path (relative to the model subfolder) to
            the object saved there with ``torch.save``, by default None
        package : bool, optional
            Whether to also save ``package.pth`` using ``torch.package``.
            ``weights.pth`` is always saved, by default True

        Returns
        -------
        Path
            Path to the model subfolder that was written to.
        """
        extra_data = {} if extra_data is None else extra_data
        model_name = type(self).__name__.lower()
        # Built the same way as in ``load_from_folder`` so both agree
        # on the location for every value of ``folder``.
        target_base = Path(folder) / model_name
        target_base.mkdir(exist_ok=True, parents=True)

        if package:
            package_path = target_base / "package.pth"
            self.save(package_path)

        weights_path = target_base / "weights.pth"
        self.save(weights_path, package=False)

        for path, obj in extra_data.items():
            torch.save(obj, target_base / path)

        return target_base

    @classmethod
    def load_from_folder(
        cls,
        folder: typing.Union[str, Path],
        package: bool = True,
        strict: bool = False,
        **kwargs,
    ):
        """Loads the model from a folder generated by
        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
        Like that function, this one looks for a subfolder that has
        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
        model name was ``Generator``).

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder that contains the lower-cased class-name subfolder.
        package : bool, optional
            Whether to use ``torch.package`` to load the model,
            loading the model from ``package.pth``. Otherwise
            ``weights.pth`` is loaded, by default True
        strict : bool, optional
            Passed on to :py:meth:`load`. When False, unmatched keys are
            ignored when loading weights, by default False
        kwargs : dict
            Keyword arguments forwarded to ``torch.load`` for every
            extra-data file.

        Returns
        -------
        tuple
            tuple of model and extra data as saved by
            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
            The extra data is a dictionary keyed by file name.
        """
        folder = Path(folder) / cls.__name__.lower()
        model_pth = "package.pth" if package else "weights.pth"
        model_pth = folder / model_pth

        model = cls.load(model_pth, strict=strict)

        # Every other regular file directly inside the subfolder is
        # treated as extra data.
        excluded = ("package.pth", "weights.pth")
        extra_data = {
            f.name: torch.load(f, **kwargs)
            for f in folder.glob("*")
            if f.is_file() and f.name not in excluded
        }

        return model, extra_data