import deepspeed
from torch import nn

try:  # the helper moved in newer transformers releases
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import is_deepspeed_zero3_enabled


def remove_mismatched_weights(model, pretrained_state_dict):
    """Drop checkpoint entries whose name or shape does not match the model."""
    own_state = model.state_dict()
    mismatch_keys = []
    for name in list(pretrained_state_dict.keys()):
        if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape:
            mismatch_keys.append(name)
            pretrained_state_dict.pop(name)
    return pretrained_state_dict, mismatch_keys


def load_zero3_checkpoint(module: nn.Module, state_dict, prefix="", error_msgs=None, top=True):
    """Load `state_dict` into `module`, gathering ZeRO-3 partitioned parameters as needed."""
    if error_msgs is None:  # avoid a shared mutable default list across calls
        error_msgs = []

    if not is_deepspeed_zero3_enabled():
        # Plain (non-partitioned) path: filter out incompatible weights, then load.
        state_dict, mismatch_keys = remove_mismatched_weights(module, state_dict)
        info = module.load_state_dict(state_dict, strict=False)
        if len(mismatch_keys) > 0:
            print("shape mismatch keys:", mismatch_keys)
        if len(info.missing_keys) > 0:
            print("missing keys:", info.missing_keys)
        if len(info.unexpected_keys) > 0:
            print("unexpected keys:", info.unexpected_keys)
    else:
        local_metadata = {}
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of this module and its children all start with `prefix`,
        # so we can exit early if the state_dict has no keys under it.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict if k in named_parameters]
            # Buffers (e.g. BatchNorm running stats) are not partitioned by ZeRO-3,
            # but they still need to be written by _load_from_state_dict below.
            named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
            buffers_to_gather = [named_buffers[k] for k in state_dict if k in named_buffers]

            if len(params_to_gather) > 0 or len(buffers_to_gather) > 0:
                # Gather this module's parameter shards on every rank; with
                # modifier_rank=0 the values written on rank 0 are re-broadcast
                # when the context exits. Running the load on all ranks (rather
                # than only rank 0) keeps buffers in sync too, since buffers are
                # not managed by GatheredParameters.
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    module._load_from_state_dict(*args)

        # Recurse into children, forwarding the shared error list.
        for name, child in module._modules.items():
            if child is not None:
                load_zero3_checkpoint(child, state_dict, prefix + name + ".",
                                      error_msgs=error_msgs, top=False)

    if top:
        if len(error_msgs) > 0:
            print("loading zero3 model weights hit errors:")
            print(error_msgs)
        else:
            print("loading zero3 model weights succeeded!")
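
# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original utility).
# Without a ZeRO-3 config registered, is_deepspeed_zero3_enabled() returns
# False, so this exercises the plain fallback path: mismatched and unknown
# checkpoint entries are filtered and reported, the rest are loaded. The toy
# model and fabricated checkpoint below are hypothetical placeholders; in a
# real ZeRO-3 run you would call load_zero3_checkpoint() the same way, but on
# a model whose parameters were partitioned (e.g. built under
# deepspeed.zero.Init()) inside a launched training script.
if __name__ == "__main__":
    import torch

    model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32))
    ckpt = {k: v.clone() for k, v in model.state_dict().items()}
    ckpt["0.weight"] = torch.zeros(8, 16)  # wrong shape -> reported as mismatch, then missing
    ckpt["extra.bias"] = torch.zeros(4)    # unknown key -> filtered out as mismatch

    load_zero3_checkpoint(model, ckpt)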