Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
import os | |
import shutil | |
import torch | |
import os.path as osp | |
from yolov6.utils.events import LOGGER | |
from yolov6.utils.torch_utils import fuse_model | |
def load_state_dict(weights, model, map_location=None):
    """Load weights from a checkpoint into ``model``, skipping incompatible layers.

    Only entries whose name exists in ``model.state_dict()`` and whose tensor
    shape matches the model's entry are copied; all other entries are ignored
    (loading uses ``strict=False``).

    Args:
        weights: Checkpoint path passed to ``torch.load``. The checkpoint must be
            a dict whose ``'model'`` entry is a module object (it is called with
            ``.float().state_dict()``).
        model: Target module; its parameters/buffers are updated in place.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The same ``model`` object, with the matching weights loaded.
    """
    import inspect

    # The checkpoint holds a pickled module, not a bare state dict, so it cannot
    # be read with ``torch.load(weights_only=True)`` -- the default since
    # PyTorch 2.6. Opt out explicitly on versions that understand the flag
    # (older versions lack it and always unpickle fully).
    # SECURITY: full unpickling can execute arbitrary code; only load trusted files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    ckpt = torch.load(weights, map_location=map_location, **load_kwargs)
    state_dict = ckpt['model'].float().state_dict()
    model_state_dict = model.state_dict()
    # Keep only entries the target model has, with an identical shape.
    state_dict = {
        k: v
        for k, v in state_dict.items()
        if k in model_state_dict and v.shape == model_state_dict[k].shape
    }
    model.load_state_dict(state_dict, strict=False)
    # Drop references to the loaded checkpoint tensors before returning.
    del ckpt, state_dict, model_state_dict
    return model
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
    """Load a full model object from a checkpoint file and put it in eval mode.

    The EMA model (``ckpt['ema']``) is used when present; otherwise
    ``ckpt['model']`` is used. The model is cast with ``.float()`` and, when
    ``fuse`` is true, passed through ``fuse_model``.

    Args:
        weights: Checkpoint path passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``.
        inplace: Not used by this function; kept for call-site compatibility.
        fuse: Whether to run ``fuse_model`` on the loaded model.

    Returns:
        The loaded model, in eval mode.
    """
    import inspect

    LOGGER.info("Loading checkpoint from {}".format(weights))
    # The checkpoint stores pickled module objects, which
    # ``torch.load(weights_only=True)`` (the default since PyTorch 2.6) refuses
    # to load. Opt out explicitly on versions that understand the flag.
    # SECURITY: full unpickling can execute arbitrary code; only load trusted files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    ckpt = torch.load(weights, map_location=map_location, **load_kwargs)  # load
    # Test against None rather than truthiness: ``bool()`` on a module that
    # defines ``__len__`` (e.g. nn.Sequential) reflects its length, not presence.
    model = ckpt['ema' if ckpt.get('ema') is not None else 'model'].float()
    if fuse:
        LOGGER.info("\nFusing model...")
        model = fuse_model(model).eval()
    else:
        model = model.eval()
    return model
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
    """Save a checkpoint to ``<save_dir>/<model_name>.pt``.

    ``save_dir`` (including any missing parents) is created if needed. When
    ``is_best`` is true, the saved file is also copied to
    ``<save_dir>/best_ckpt.pt``.

    Args:
        ckpt: Object to serialise with ``torch.save``.
        is_best: Whether to also write the ``best_ckpt.pt`` copy.
        save_dir: Output directory.
        model_name: File stem for the checkpoint; ``.pt`` is appended.
    """
    # ``exist_ok`` replaces the exists()/makedirs() pair, which could raise
    # FileExistsError if another process created the directory in between.
    os.makedirs(save_dir, exist_ok=True)
    filename = osp.join(save_dir, model_name + '.pt')
    torch.save(ckpt, filename)
    if is_best:
        best_filename = osp.join(save_dir, 'best_ckpt.pt')
        shutil.copyfile(filename, best_filename)
def strip_optimizer(ckpt_dir, epoch):
    """Shrink ``best_ckpt.pt`` and ``last_ckpt.pt`` in ``ckpt_dir`` in place.

    For each of those two files that exists:
    - the EMA model (if present) replaces ``ckpt['model']``;
    - ``'optimizer'``, ``'ema'`` and ``'updates'`` are set to ``None``;
    - ``'epoch'`` is set to ``epoch``;
    - the model is converted to FP16 and its parameters stop requiring grad;
    - the file is overwritten with the result.

    Missing files are skipped.

    Args:
        ckpt_dir: Directory containing the checkpoint files.
        epoch: Value stored under ``ckpt['epoch']``.
    """
    import inspect

    # The checkpoints hold pickled module objects, which
    # ``torch.load(weights_only=True)`` (the default since PyTorch 2.6) refuses
    # to load. Opt out explicitly on versions that understand the flag.
    # SECURITY: full unpickling can execute arbitrary code; only load trusted files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    for s in ('best', 'last'):
        ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
        if not osp.exists(ckpt_path):
            continue
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), **load_kwargs)
        # Compare with None: a module defining ``__len__`` may be falsy.
        if ckpt.get('ema') is not None:
            ckpt['model'] = ckpt['ema']  # replace model with EMA
        for k in ('optimizer', 'ema', 'updates'):
            ckpt[k] = None
        ckpt['epoch'] = epoch
        ckpt['model'].half()  # to FP16 (Module.half() converts in place)
        for p in ckpt['model'].parameters():
            p.requires_grad = False
        torch.save(ckpt, ckpt_path)