def load_matched_state_dict(model, state_dict, print_stats=True):
    """
    Only loads weights that matched in key and shape. Ignore other weights.

    Every entry of ``model.state_dict()`` whose key also appears in
    ``state_dict`` with an equal ``.shape`` is replaced by the value from
    ``state_dict``. All other entries keep their current value. The merged
    dict is then passed to ``model.load_state_dict``. Keys that exist only in
    ``state_dict`` are ignored.

    Entries without a ``.shape`` attribute (e.g. non-tensor extra state) are
    not treated as weights. They are left at the model's current value
    instead of raising ``AttributeError``.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict(dict)``
            (e.g. a ``torch.nn.Module``).
        state_dict: Mapping from parameter/buffer names to candidate values.
        print_stats: If true, print how many of the model's entries were
            matched and loaded.
    """
    num_matched, num_total = 0, 0
    curr_state_dict = model.state_dict()
    # Reassigning existing keys does not change the dict's size, so updating
    # values while iterating over items() is safe.
    for key, curr_value in curr_state_dict.items():
        num_total += 1
        if key not in state_dict:
            continue
        new_value = state_dict[key]
        curr_shape = getattr(curr_value, 'shape', None)
        # `is not None` rather than truthiness: a 0-d tensor's shape is empty
        # (falsy) but still a legitimate shape to match on.
        if curr_shape is not None and curr_shape == getattr(new_value, 'shape', None):
            curr_state_dict[key] = new_value
            num_matched += 1
    # Every key of the model is present, so a strict load is satisfied.
    model.load_state_dict(curr_state_dict)
    if print_stats:
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')