import logging

from mmengine.runner.checkpoint import CheckpointLoader
from mmengine.logging.logger import print_log
from huggingface_hub import hf_hub_download

# Filenames starting with this prefix are treated as Hugging Face Hub repo ids.
HF_HUB_PREFIX = 'hf-hub:'


def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load partial pretrained model with specific prefix.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details. A filename of the form ``hf-hub:<repo_id>`` is resolved
            by downloading ``pytorch_model.bin`` from that Hugging Face Hub
            repo via ``hf_hub_download``.
        prefix (str, optional): The prefix of sub-module. A trailing ``'.'``
            is appended if missing. Only keys starting with the prefix are
            kept, with the prefix stripped. If falsy, the full state dict is
            returned unchanged. Defaults to None.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to ``'cpu'``.
        logger: Logger passed through to
            ``CheckpointLoader.load_checkpoint``. Defaults to ``'current'``.

    Returns:
        dict or OrderedDict: The loaded state dict. If the checkpoint contains
        a ``'state_dict'`` or ``'model'`` entry, that entry is used
        (``'state_dict'`` takes precedence). Otherwise the checkpoint itself
        is used.

    Raises:
        AssertionError: If ``prefix`` is given but no key starts with it.
    """
    if filename.startswith(HF_HUB_PREFIX):
        model_id = filename[len(HF_HUB_PREFIX):]
        filename = hf_hub_download(model_id, 'pytorch_model.bin')

    checkpoint = CheckpointLoader.load_checkpoint(
        filename, map_location=map_location, logger=logger)

    # Unwrap common checkpoint layouts; fall back to treating the whole
    # checkpoint as the state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if not prefix:
        return state_dict

    # Normalise so that e.g. 'backbone' does not also match 'backbone2.*'.
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }

    # Explicit raise (not `assert`) so the check survives `python -O`;
    # the exception type is kept as AssertionError for existing callers.
    if not state_dict:
        raise AssertionError(f'{prefix} is not in the pretrained model')
    return state_dict


def load_state_dict_to_model(model, state_dict, logger='current'):
    """Load ``state_dict`` into ``model``, requiring an exact key match.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``
            that returns ``(missing_keys, unexpected_keys)``. Presumably a
            ``torch.nn.Module``; confirm against callers.
        state_dict (dict): Weights to load.
        logger: Logger name or instance for ``print_log``.
            Defaults to ``'current'``.

    Raises:
        RuntimeError: If any key is missing from or unexpected in
            ``state_dict``. Both lists are logged at ERROR level first.
    """
    # strict=False makes torch return the mismatched keys instead of raising
    # its own error. With the default strict=True the checks below could
    # never run.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    errors = []
    if missing_keys:
        errors.append(f'Missing keys: {missing_keys}')
    if unexpected_keys:
        errors.append(f'Unexpected keys: {unexpected_keys}')
    if errors:
        msg = '; '.join(errors)
        print_log(msg, logger=logger, level=logging.ERROR)
        raise RuntimeError(msg)

    print_log("Loaded checkpoint successfully", logger=logger)