Spaces:
Runtime error
Runtime error
import argparse | |
from pathlib import Path | |
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict | |
if __name__ == '__main__':
    # Command-line entry point: read the checkpoint location from the
    # arguments and forward it to convert_zero_checkpoint_to_fp32_state_dict.
    parser = argparse.ArgumentParser()
    # --input is mandatory: it is handed to the converter and is also used to
    # derive the default output location, so a missing value would previously
    # surface as an opaque TypeError from Path(None) instead of a usage error.
    parser.add_argument('--input', type=str, required=True, help='path to the desired checkpoint folder')
    parser.add_argument('--output', type=str, default=None, help='path to the pytorch fp32 state_dict output file')
    # parser.add_argument('--tag', type=str, help='checkpoint tag used as a unique identifier for checkpoint')
    args = parser.parse_args()

    # When no output is given, write 'converted.ckpt' under the input path.
    if args.output is None:
        args.output = Path(args.input) / 'converted.ckpt'

    convert_zero_checkpoint_to_fp32_state_dict(args.input, args.output)