import argparse import numpy as np import os import shutil import torch import torch.nn.functional as F from safetensors.torch import safe_open, save_file def merge_tensors(tensor1, tensor2, p): # Calculate the delta of the weights delta = tensor2 - tensor1 # Generate the mask m^t from Bernoulli distribution m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.dtype) # Apply the mask to the delta to get δ̃^t delta_tilde = m * delta # Scale the masked delta by the dropout rate to get δ̂^t delta_hat = delta_tilde / (1 - p) return delta_hat def merge_safetensors(file_path1, file_path2, p, lambda_val): merged_tensors = {} with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2: keys1 = set(f1.keys()) keys2 = set(f2.keys()) common_keys = keys1.intersection(keys2) for key in common_keys: tensor1 = f1.get_tensor(key) tensor2 = f2.get_tensor(key) tensor1, tensor2 = resize_tensors(tensor1, tensor2) merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) print("merging", key) return merged_tensors class BinDataHandler(): def __init__(self, data): self.data = data def get_tensor(self, key): return self.data[key] def read_tensors(file_path, ext): if ext == ".safetensors" and file_path.endswith(".safetensors"): f = safe_open(file_path, framework="pt", device="cpu") return f, set(f.keys()) if ext == ".bin" and file_path.endswith(".bin"): data = torch.load(file_path, map_location=torch.device('cpu')) f = BinDataHandler(data) return f, set(data.keys()) return None, None def resize_tensors(tensor1, tensor2): if len(tensor1.shape) not in [1, 2]: return tensor1, tensor2 # Pad along the last dimension (width) if tensor1.shape[-1] < tensor2.shape[-1]: padding_size = tensor2.shape[-1] - tensor1.shape[-1] tensor1 = F.pad(tensor1, (0, padding_size, 0, 0)) elif tensor2.shape[-1] < tensor1.shape[-1]: padding_size = tensor1.shape[-1] - tensor2.shape[-1] tensor2 = F.pad(tensor2, (0, padding_size, 0, 0)) # Pad along the first dimension (height) if tensor1.shape[0] < tensor2.shape[0]: padding_size = tensor2.shape[0] - tensor1.shape[0] tensor1 = F.pad(tensor1, (0, 0, 0, padding_size)) elif tensor2.shape[0] < tensor1.shape[0]: padding_size = tensor1.shape[0] - tensor2.shape[0] tensor2 = F.pad(tensor2, (0, 0, 0, padding_size)) return tensor1, tensor2 def merge_folder(tensor_map, directory_path, p, lambda_val): keys1 = set(tensor_map.keys()) # Some repos have both bin and safetensors, choose safetensors if so ext = None for filename in os.listdir(directory_path): # Default to safetensors if filename.endswith(".safetensors"): ext = ".safetensors" if filename.endswith(".bin") and ext is None: ext = ".bin" if ext is None: raise "Could not find model files" for filename in os.listdir(directory_path): file_path = os.path.join(directory_path, filename) f, keys2 = read_tensors(file_path, ext) if keys2: common_keys = keys1.intersection(keys2) for key in common_keys: if "block_sparse_moe.gate" in key: tensor1 = tensor_map[key]['tensor'] tensor2 = f.get_tensor(key) tensor_map[key]['tensor'] = (tensor1 + tensor2) /2.0 continue tensor1 = tensor_map[key]['tensor'] tensor2 = f.get_tensor(key) tensor1, tensor2 = resize_tensors(tensor1, tensor2) tensor_map[key]['tensor'] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) return tensor_map def map_tensors_to_files(directory_path): tensor_map = {} for filename in os.listdir(directory_path): file_path = os.path.join(directory_path, filename) f, keys = read_tensors(file_path, '.safetensors') if keys: for key in keys: tensor = f.get_tensor(key) tensor_map[key] = {'filename':filename, 'shape':tensor.shape, 'tensor': tensor} return tensor_map def copy_nontensor_files(from_path, to_path): for filename in os.listdir(from_path): file_path = os.path.join(from_path, filename) if from_path != to_path and not filename.startswith(".") and not filename.startswith("README") and not filename.endswith(".bin") and not filename.endswith(".safetensors") and not filename.endswith(".pt") and not os.path.isdir(file_path): print(f"Copying {file_path} to {to_path}") shutil.copyfile(file_path, to_path+'/'+filename) def save_tensor_map(tensor_map, output_folder): metadata = {'format': 'pt'} by_filename = {} for key, value in tensor_map.items(): filename = value["filename"] tensor = value["tensor"] if filename not in by_filename: by_filename[filename] = {} by_filename[filename][key] = tensor for filename in sorted(by_filename.keys()): output_file = output_folder+'/'+filename print("Saving:", output_file) save_file(by_filename[filename], output_file, metadata=metadata) def main(): # Parse command-line arguments parser = argparse.ArgumentParser(description='Merge two safetensor model files.') parser.add_argument('base_model', type=str, help='The base model safetensor file') parser.add_argument('second_model', type=str, help='The second model safetensor file') parser.add_argument('output_model', type=str, help='The output merged model safetensor file') parser.add_argument('-p', type=float, default=0.5, help='Dropout probability') parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta') args = parser.parse_args() if os.path.isdir(args.base_model): if not os.path.exists(args.output_model): os.makedirs(args.output_model) tensor_map = map_tensors_to_files(args.base_model) tensor_map = merge_folder(tensor_map, args.second_model, args.p, args.lambda_val) copy_nontensor_files(args.base_model, args.output_model) save_tensor_map(tensor_map, args.output_model) else: merged = merge_safetensors(args.base_model, args.second_model, args.p, args.lambda_val) save_file(merged, args.output_model) if __name__ == '__main__': main()