|
""" |
|
@date: 2021/11/22 |
|
@description: Convert a training checkpoint into an inference checkpoint (keeps only its 'net' entry).
|
""" |
|
import argparse |
|
import os |
|
|
|
import torch |
|
|
|
from config.defaults import merge_from_file |
|
|
|
|
|
def parse_option(argv=None):
    """Parse and echo the command-line options for checkpoint conversion.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so the
            command-line behaviour is unchanged; passing a list allows the
            function to be called programmatically (e.g. from tests).

    Returns:
        argparse.Namespace with:
            cfg: path of the config file (required).
            output_path: path of the output ckpt, or ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(description='Conversion training ckpt into inference ckpt')
    parser.add_argument('--cfg',
                        type=str,
                        required=True,
                        metavar='FILE',
                        help='path of config file')
    parser.add_argument('--output_path',
                        type=str,
                        help='path of output ckpt')

    args = parser.parse_args(argv)

    # Echo the parsed options so the run log records what was used.
    print("arguments:")
    for arg, value in vars(args).items():
        print(arg, ":", value)
    print("-" * 50)

    return args
|
|
|
|
|
def convert_ckpt():
    """Extract the 'net' entry of the best training checkpoint and save it alone.

    Reads the config given by ``--cfg`` and locates the checkpoint directory
    ``checkpoints/<decoder_name>_<output_name>_Net/<TAG>`` (names taken from
    ``config.MODEL.ARGS[0]`` and ``config.TAG``). It then loads the file whose
    name contains ``'_best_'`` and writes ``checkpoint['net']`` to
    ``--output_path``, or to ``<ck_dir>/best.pkl`` when that option is omitted.

    Problems are reported by printing a message and returning early:
    a missing checkpoint directory, no ``'_best_'`` file, or an empty
    ``--output_path``.
    """
    args = parse_option()
    config = merge_from_file(args.cfg)

    model_args = config.MODEL.ARGS[0]
    ck_dir = os.path.join("checkpoints",
                          f"{model_args['decoder_name']}_{model_args['output_name']}_Net",
                          config.TAG)
    print(f"Processing {ck_dir}")

    if not os.path.isdir(ck_dir):
        print(f"Checkpoint directory does not exist: {ck_dir}")
        return

    # os.listdir order is arbitrary; sort so the chosen ckpt is deterministic.
    model_paths = sorted(name for name in os.listdir(ck_dir) if '_best_' in name)
    if not model_paths:
        print("Not find best ckpt")
        return
    if len(model_paths) > 1:
        print(f"Multiple best ckpts found {model_paths}, using {model_paths[0]}")

    model_path = os.path.join(ck_dir, model_paths[0])
    print(f"Loading {model_path}")

    # Keep loading onto cuda:0 when a GPU exists (as before), but fall back to
    # CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    net = checkpoint['net']

    if args.output_path is None:
        output_path = os.path.join(ck_dir, 'best.pkl')
    elif not args.output_path.strip():
        print("Output path is invalid")
        return
    else:
        output_path = args.output_path

    print(f"Save on: {output_path}")

    # A bare filename has dirname '' (current directory); os.makedirs('')
    # would raise FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    torch.save(net, output_path)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the conversion.
    convert_ckpt()
|
|