MantraDas's picture
Duplicate from facebook/MusicGen
8aafaec
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf, DictConfig
import torch
def _clean_lm_cfg(cfg: DictConfig, card: int = 2048, n_q: int = 4) -> DictConfig:
    """Patch an LM training config in place so it can be shipped in a release checkpoint.

    Overwrites ``transformer_lm.card`` and ``transformer_lm.n_q`` and removes
    experimental ``transformer_lm`` parameters that are no longer supported.
    Struct mode is disabled while editing and set back to ``True`` afterwards,
    even if one of the edits raises.

    Args:
        cfg: Experiment config containing a ``transformer_lm`` section.
        card: Value written to ``transformer_lm.card`` (default 2048).
        n_q: Value written to ``transformer_lm.n_q`` (default 4).

    Returns:
        The same ``cfg`` object, modified in place.
    """
    OmegaConf.set_struct(cfg, False)
    try:
        lm_cfg = cfg['transformer_lm']
        # This used to be set automatically in the LM solver, need a more robust
        # solution for the future.
        lm_cfg['card'] = card
        lm_cfg['n_q'] = n_q
        # Experimental params no longer supported. A given checkpoint may not
        # contain all of them, so absent keys are skipped instead of raising.
        bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                      'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
        for name in bad_params:
            lm_cfg.pop(name, None)
    finally:
        # Always restore struct mode so a failed edit does not leave the
        # config permissive.
        OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export an EnCodec training checkpoint to a lightweight release checkpoint.

    The checkpoint's parent folder name is used as the Dora signature and must
    be 8 characters long. The exported file, ``<out_folder>/<sig>.th``, keeps
    only the model state stored under ``pkg['ema']['state']['model']`` (as
    ``'best_state'``) and the experiment config serialized to YAML (as
    ``'xp.cfg'``).

    Args:
        checkpoint_path: Path to the training checkpoint file.
        out_folder: Destination folder; created if it does not exist.

    Returns:
        Path of the written release checkpoint.

    Raises:
        ValueError: If the parent folder name is not an 8-character signature.
    """
    sig = Path(checkpoint_path).parent.name
    # Raise instead of `assert` so the check is not stripped under `python -O`.
    if len(sig) != 8:
        raise ValueError(f"Not a valid Dora signature: {sig!r}")
    # NOTE: torch.load unpickles arbitrary objects -- only use on trusted
    # checkpoints. PyTorch >= 2.6 defaults to `weights_only=True`, which may
    # reject the pickled config object; confirm for the PyTorch version in use.
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export an LM training checkpoint to a lightweight release checkpoint.

    The checkpoint's parent folder name is used as the Dora signature and must
    be 8 characters long. The exported file, ``<out_folder>/<sig>.th``, keeps
    only the model state stored under ``pkg['fsdp_best_state']['model']`` (as
    ``'best_state'``) and the experiment config, cleaned by ``_clean_lm_cfg``
    and serialized to YAML (as ``'xp.cfg'``).

    Args:
        checkpoint_path: Path to the training checkpoint file.
        out_folder: Destination folder; created if it does not exist.

    Returns:
        Path of the written release checkpoint.

    Raises:
        ValueError: If the parent folder name is not an 8-character signature.
    """
    sig = Path(checkpoint_path).parent.name
    # Raise instead of `assert` so the check is not stripped under `python -O`.
    if len(sig) != 8:
        raise ValueError(f"Not a valid Dora signature: {sig!r}")
    # NOTE: torch.load unpickles arbitrary objects -- only use on trusted
    # checkpoints. PyTorch >= 2.6 defaults to `weights_only=True`, which may
    # reject the pickled config object; confirm for the PyTorch version in use.
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        # _clean_lm_cfg modifies pkg['xp.cfg'] in place before serialization.
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_dir = Path(out_folder)
    # torch.save does not create missing parent directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file