import inspect
import shutil
import tempfile
import typing
from pathlib import Path

import torch
from torch import nn


class BaseModel(nn.Module):
    """This is a class that adds useful save/load functionality to a
    ``torch.nn.Module`` object. ``BaseModel`` objects can be saved as
    ``torch.package`` easily, making them super easy to port between
    machines without requiring a ton of dependencies. Files can also be
    saved as just weights, in the standard way.

    >>> class Model(ml.BaseModel):
    >>>     def __init__(self, arg1: float = 1.0):
    >>>         super().__init__()
    >>>         self.arg1 = arg1
    >>>         self.linear = nn.Linear(1, 1)
    >>>
    >>>     def forward(self, x):
    >>>         return self.linear(x)
    >>>
    >>> model1 = Model()
    >>>
    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    >>>     model1.save(
    >>>         f.name,
    >>>     )
    >>>     model2 = Model.load(f.name)
    >>>     out2 = seed_and_run(model2, x)
    >>>     assert torch.allclose(out1, out2)
    >>>
    >>>     model1.save(f.name, package=True)
    >>>     model2 = Model.load(f.name)
    >>>     model2.save(f.name, package=False)
    >>>     model3 = Model.load(f.name)
    >>>     out3 = seed_and_run(model3, x)
    >>>
    >>> with tempfile.TemporaryDirectory() as d:
    >>>     model1.save_to_folder(d, {"data": 1.0})
    >>>     Model.load_from_folder(d)

    """

    EXTERN = [
        "audiotools.**",
        "tqdm",
        "__main__",
        "numpy.**",
        "julius.**",
        "torchaudio.**",
        "scipy.**",
        "einops",
    ]
    """Names of libraries that are external to the torch.package saving mechanism.
    Source code from these libraries will not be packaged into the model. This can
    be edited by the user of this class by editing ``model.EXTERN``."""

    INTERN = []
    """Names of libraries that are internal to the torch.package saving mechanism.
    Source code from these libraries will be saved alongside the model."""

    def save(
        self,
        path: str,
        metadata: dict = None,
        package: bool = True,
        intern: list = None,
        extern: list = None,
        mock: list = None,
    ):
        """Saves the model, either as a torch package, or just as
        weights, alongside some specified metadata.

        Parameters
        ----------
        path : str
            Path to save model to.
        metadata : dict, optional
            Any metadata to save alongside the model, by default None.
            The dict is copied internally, so the caller's object is
            never mutated.
        package : bool, optional
            Whether to use ``torch.package`` to save the model in
            a format that is portable, by default True
        intern : list, optional
            List of additional libraries that are internal to the model,
            used with torch.package, by default None (treated as [])
        extern : list, optional
            List of additional libraries that are external to the model,
            used with torch.package, by default None (treated as [])
        mock : list, optional
            List of libraries to mock, used with torch.package,
            by default None (treated as [])

        Returns
        -------
        str
            Path to saved model.
        """
        # Record the constructor's keyword defaults so the model can be
        # re-instantiated later by ``load`` without the caller knowing them.
        sig = inspect.signature(self.__class__)
        args = {}
        for key, val in sig.parameters.items():
            arg_val = val.default
            if arg_val is not inspect.Parameter.empty:
                args[key] = arg_val

        # Look up attributes in self, and if any of them are in args,
        # overwrite them in args (so the *actual* values used at
        # construction time win over the signature defaults).
        for attribute in dir(self):
            if attribute in args:
                args[attribute] = getattr(self, attribute)

        # Copy so that inserting "kwargs" does not mutate the caller's dict.
        metadata = {} if metadata is None else dict(metadata)
        metadata["kwargs"] = args
        if not hasattr(self, "metadata"):
            self.metadata = {}
        self.metadata.update(metadata)

        if not package:
            state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
            torch.save(state_dict, path)
        else:
            self._save_package(path, intern=intern, extern=extern, mock=mock)

        return path

    @property
    def device(self):
        """Gets the device the model is on by looking at the device of
        the first parameter. May not be valid if model is split across
        multiple devices.
        """
        # next() avoids materializing the entire parameter list.
        return next(self.parameters()).device

    @classmethod
    def load(
        cls,
        location: str,
        *args,
        package_name: str = None,
        strict: bool = False,
        **kwargs,
    ):
        """Load model from a path. Tries first to load as a package, and if
        that fails, tries to load as weights. The arguments to the class are
        specified inside the model weights file.

        Parameters
        ----------
        location : str
            Path to file.
        package_name : str, optional
            Name of package, by default ``cls.__name__``.
        strict : bool, optional
            Ignore unmatched keys, by default False
        kwargs : dict
            Additional keyword arguments to the model instantiation, if
            not loading from package.

        Returns
        -------
        BaseModel
            A model that inherits from BaseModel.
        """
        try:
            model = cls._load_package(location, package_name=package_name)
        except Exception:
            # Not a torch.package archive (or loading it failed); fall back
            # to a plain weights file saved by ``save(..., package=False)``.
            model_dict = torch.load(location, "cpu")
            metadata = model_dict["metadata"]
            metadata["kwargs"].update(kwargs)

            # Drop any saved kwargs the current class signature no longer
            # accepts, so instantiation doesn't fail on stale keys.
            sig = inspect.signature(cls)
            class_keys = list(sig.parameters.keys())
            for k in list(metadata["kwargs"].keys()):
                if k not in class_keys:
                    metadata["kwargs"].pop(k)

            model = cls(*args, **metadata["kwargs"])
            model.load_state_dict(model_dict["state_dict"], strict=strict)
            model.metadata = metadata

        return model

    def _save_package(self, path, intern=None, extern=None, mock=None, **kwargs):
        """Saves the model as a ``torch.package`` archive at ``path``.

        ``intern``/``extern``/``mock`` extend the class-level INTERN/EXTERN
        lists for this one save. Returns ``path``.
        """
        intern = [] if intern is None else intern
        extern = [] if extern is None else extern
        mock = [] if mock is None else mock

        package_name = type(self).__name__
        resource_name = f"{type(self).__name__}.pth"

        # Below is for loading and re-saving a package.
        if hasattr(self, "importer"):
            kwargs["importer"] = (self.importer, torch.package.sys_importer)
            del self.importer

        # Why do we use a tempfile, you ask?
        # It's so we can load a packaged model and then re-save
        # it to the same location. torch.package throws an
        # error if it's loading and writing to the same
        # file (this is undocumented).
        with tempfile.NamedTemporaryFile(suffix=".pth") as f:
            with torch.package.PackageExporter(f.name, **kwargs) as exp:
                exp.intern(self.INTERN + intern)
                exp.mock(mock)
                exp.extern(self.EXTERN + extern)
                exp.save_pickle(package_name, resource_name, self)

                if hasattr(self, "metadata"):
                    exp.save_pickle(
                        package_name, f"{package_name}.metadata", self.metadata
                    )

            shutil.copyfile(f.name, path)

        # Must reset the importer back to `self` if it existed
        # so that you can save the model again!
        if "importer" in kwargs:
            self.importer = kwargs["importer"][0]
        return path

    @classmethod
    def _load_package(cls, path, package_name=None):
        """Loads a model pickled into a ``torch.package`` archive at ``path``.

        The importer is kept on the model so it can be re-packaged later
        by ``_save_package``.
        """
        package_name = cls.__name__ if package_name is None else package_name
        resource_name = f"{package_name}.pth"

        imp = torch.package.PackageImporter(path)
        model = imp.load_pickle(package_name, resource_name, "cpu")
        try:
            model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
        except Exception:  # pragma: no cover
            # Metadata is optional in the archive; older saves may lack it.
            pass
        model.importer = imp

        return model

    def save_to_folder(
        self,
        folder: typing.Union[str, Path],
        extra_data: dict = None,
        package: bool = True,
    ):
        """Dumps a model into a folder, as both a package
        and as weights, as well as anything specified in
        ``extra_data``. ``extra_data`` is a dictionary of other pickleable
        files, with the keys being the paths
        to save them in. The model is saved under a subfolder specified
        by the name of the class (e.g. ``folder/generator/[package, weights].pth``
        if the model name was ``Generator``).

        >>> with tempfile.TemporaryDirectory() as d:
        >>>     extra_data = {
        >>>         "optimizer.pth": optimizer.state_dict()
        >>>     }
        >>>     model.save_to_folder(d, extra_data)
        >>>     Model.load_from_folder(d)

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder to dump the model (and extra data) into.
        extra_data : dict, optional
            Mapping of relative file name -> pickleable object, by default None
        package : bool, optional
            Whether to also save a ``torch.package`` archive
            (``package.pth``), by default True

        Returns
        -------
        Path
            Path to folder the model was saved into.
        """
        extra_data = {} if extra_data is None else extra_data
        model_name = type(self).__name__.lower()
        target_base = Path(f"{folder}/{model_name}/")
        target_base.mkdir(exist_ok=True, parents=True)

        if package:
            package_path = target_base / "package.pth"
            self.save(package_path)

        weights_path = target_base / "weights.pth"
        self.save(weights_path, package=False)

        for path, obj in extra_data.items():
            torch.save(obj, target_base / path)

        return target_base

    @classmethod
    def load_from_folder(
        cls,
        folder: typing.Union[str, Path],
        package: bool = True,
        strict: bool = False,
        **kwargs,
    ):
        """Loads the model from a folder generated by
        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
        Like that function, this one looks for a subfolder that has
        the name of the class (e.g. ``folder/generator/[package, weights].pth``
        if the model name was ``Generator``).

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder containing the subfolder written by ``save_to_folder``.
        package : bool, optional
            Whether to use ``torch.package`` to load the model,
            loading the model from ``package.pth``.
        strict : bool, optional
            Ignore unmatched keys, by default False
        kwargs : dict
            Extra keyword arguments forwarded to ``torch.load`` for the
            extra-data files.

        Returns
        -------
        tuple
            tuple of model and extra data as saved by
            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
        """
        folder = Path(folder) / cls.__name__.lower()
        model_pth = "package.pth" if package else "weights.pth"
        model_pth = folder / model_pth

        model = cls.load(model_pth, strict=strict)
        extra_data = {}
        excluded = ["package.pth", "weights.pth"]
        files = [
            x for x in folder.glob("*") if x.is_file() and x.name not in excluded
        ]
        for f in files:
            extra_data[f.name] = torch.load(f, **kwargs)

        return model, extra_data