mrcuddle committed
Commit 11b8aac
1 Parent(s): 48a108f

Update merge.py

Browse files
Files changed (1) hide show
  1. merge.py +108 -32
merge.py CHANGED
@@ -5,21 +5,43 @@ import shutil
 import torch
 import torch.nn.functional as F
 from safetensors.torch import safe_open, save_file
+import logging

+# Set up logging
+logging.basicConfig(level=logging.INFO)
+
-def merge_tensors(tensor1, tensor2, p):
-    # Calculate the delta of the weights
+def merge_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, p: float) -> torch.Tensor:
+    """
+    Merge two tensors using dropout and scaling.
+
+    Args:
+        tensor1 (torch.Tensor): The first tensor.
+        tensor2 (torch.Tensor): The second tensor.
+        p (float): Dropout probability.
+
+    Returns:
+        torch.Tensor: The merged tensor.
+    """
     delta = tensor2 - tensor1
-    # Generate the mask m^t from Bernoulli distribution
     m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.dtype)
-    # Apply the mask to the delta to get δ̃^t
     delta_tilde = m * delta
-    # Scale the masked delta by the dropout rate to get δ̂^t
     delta_hat = delta_tilde / (1 - p)
     return delta_hat

-def merge_safetensors(file_path1, file_path2, p, lambda_val):
-    merged_tensors = {}
+def merge_safetensors(file_path1: str, file_path2: str, p: float, lambda_val: float) -> dict:
+    """
+    Merge two safetensors files.
+
+    Args:
+        file_path1 (str): Path to the first safetensors file.
+        file_path2 (str): Path to the second safetensors file.
+        p (float): Dropout probability.
+        lambda_val (float): Scaling factor.

+    Returns:
+        dict: A dictionary of merged tensors.
+    """
+    merged_tensors = {}
     with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2:
         keys1 = set(f1.keys())
         keys2 = set(f2.keys())
@@ -30,18 +52,31 @@ def merge_safetensors(file_path1, file_path2, p, lambda_val):
             tensor2 = f2.get_tensor(key)
             tensor1, tensor2 = resize_tensors(tensor1, tensor2)
             merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)
-            print("merging", key)
+            logging.info(f"Merging {key}")

     return merged_tensors

-class BinDataHandler():
-    def __init__(self, data):
+class BinDataHandler:
+    """
+    A handler for binary data files.
+    """
+    def __init__(self, data: dict):
         self.data = data

-    def get_tensor(self, key):
+    def get_tensor(self, key: str) -> torch.Tensor:
         return self.data[key]

-def read_tensors(file_path, ext):
+def read_tensors(file_path: str, ext: str) -> tuple:
+    """
+    Read tensors from a file.
+
+    Args:
+        file_path (str): Path to the file.
+        ext (str): File extension.
+
+    Returns:
+        tuple: A tuple containing the file handler and the set of keys.
+    """
     if ext == ".safetensors" and file_path.endswith(".safetensors"):
         f = safe_open(file_path, framework="pt", device="cpu")
         return f, set(f.keys())
@@ -51,11 +86,20 @@ def read_tensors(file_path, ext):
         return f, set(data.keys())
     return None, None

-def resize_tensors(tensor1, tensor2):
+def resize_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor) -> tuple:
+    """
+    Resize tensors to ensure they have the same shape.
+
+    Args:
+        tensor1 (torch.Tensor): The first tensor.
+        tensor2 (torch.Tensor): The second tensor.
+
+    Returns:
+        tuple: A tuple containing the resized tensors.
+    """
     if len(tensor1.shape) not in [1, 2]:
         return tensor1, tensor2

-    # Pad along the last dimension (width)
     if tensor1.shape[-1] < tensor2.shape[-1]:
         padding_size = tensor2.shape[-1] - tensor1.shape[-1]
         tensor1 = F.pad(tensor1, (0, padding_size, 0, 0))
@@ -63,7 +107,6 @@ def resize_tensors(tensor1, tensor2):
         padding_size = tensor1.shape[-1] - tensor2.shape[-1]
         tensor2 = F.pad(tensor2, (0, padding_size, 0, 0))

-    # Pad along the first dimension (height)
     if tensor1.shape[0] < tensor2.shape[0]:
         padding_size = tensor2.shape[0] - tensor1.shape[0]
         tensor1 = F.pad(tensor1, (0, 0, 0, padding_size))
@@ -73,18 +116,28 @@ def resize_tensors(tensor1, tensor2):

     return tensor1, tensor2

-def merge_folder(tensor_map, directory_path, p, lambda_val):
+def merge_folder(tensor_map: dict, directory_path: str, p: float, lambda_val: float) -> dict:
+    """
+    Merge tensors from a directory of model files.
+
+    Args:
+        tensor_map (dict): A dictionary mapping tensor keys to their file paths.
+        directory_path (str): Path to the directory containing model files.
+        p (float): Dropout probability.
+        lambda_val (float): Scaling factor.
+
+    Returns:
+        dict: A dictionary of merged tensors.
+    """
     keys1 = set(tensor_map.keys())
-    # Some repos have both bin and safetensors, choose safetensors if so
     ext = None
     for filename in os.listdir(directory_path):
-        # Default to safetensors
         if filename.endswith(".safetensors"):
             ext = ".safetensors"
         if filename.endswith(".bin") and ext is None:
             ext = ".bin"
     if ext is None:
-        raise "Could not find model files"
+        raise FileNotFoundError("Could not find model files")

     for filename in os.listdir(directory_path):
         file_path = os.path.join(directory_path, filename)
@@ -95,7 +148,7 @@ def merge_folder(tensor_map, directory_path, p, lambda_val):
             if "block_sparse_moe.gate" in key:
                 tensor1 = tensor_map[key]['tensor']
                 tensor2 = f.get_tensor(key)
-                tensor_map[key]['tensor'] = (tensor1 + tensor2) /2.0
+                tensor_map[key]['tensor'] = (tensor1 + tensor2) / 2.0
                 continue
             tensor1 = tensor_map[key]['tensor']
             tensor2 = f.get_tensor(key)
@@ -103,27 +156,48 @@ def merge_folder(tensor_map, directory_path, p, lambda_val):
             tensor_map[key]['tensor'] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p)
     return tensor_map

-def map_tensors_to_files(directory_path):
-    tensor_map = {}
+def map_tensors_to_files(directory_path: str) -> dict:
+    """
+    Map tensors to their respective files in a directory.
+
+    Args:
+        directory_path (str): Path to the directory containing model files.

+    Returns:
+        dict: A dictionary mapping tensor keys to their file paths.
+    """
+    tensor_map = {}
     for filename in os.listdir(directory_path):
         file_path = os.path.join(directory_path, filename)
         f, keys = read_tensors(file_path, '.safetensors')
         if keys:
             for key in keys:
                 tensor = f.get_tensor(key)
-                tensor_map[key] = {'filename':filename, 'shape':tensor.shape, 'tensor': tensor}
-
+                tensor_map[key] = {'filename': filename, 'shape': tensor.shape, 'tensor': tensor}
     return tensor_map

-def copy_nontensor_files(from_path, to_path):
+def copy_nontensor_files(from_path: str, to_path: str):
+    """
+    Copy non-tensor files from one directory to another.
+
+    Args:
+        from_path (str): Path to the source directory.
+        to_path (str): Path to the destination directory.
+    """
     for filename in os.listdir(from_path):
         file_path = os.path.join(from_path, filename)
         if from_path != to_path and not filename.startswith(".") and not filename.startswith("README") and not filename.endswith(".bin") and not filename.endswith(".safetensors") and not filename.endswith(".pt") and not os.path.isdir(file_path):
-            print(f"Copying {file_path} to {to_path}")
-            shutil.copyfile(file_path, to_path+'/'+filename)
+            logging.info(f"Copying {file_path} to {to_path}")
+            shutil.copyfile(file_path, to_path + '/' + filename)
+
+def save_tensor_map(tensor_map: dict, output_folder: str):
+    """
+    Save the merged tensor map to the output directory.

-def save_tensor_map(tensor_map, output_folder):
+    Args:
+        tensor_map (dict): A dictionary of merged tensors.
+        output_folder (str): Path to the output directory.
+    """
     metadata = {'format': 'pt'}
     by_filename = {}

@@ -135,12 +209,14 @@ def save_tensor_map(tensor_map, output_folder):
             by_filename[filename][key] = tensor

     for filename in sorted(by_filename.keys()):
-        output_file = output_folder+'/'+filename
-        print("Saving:", output_file)
+        output_file = output_folder + '/' + filename
+        logging.info(f"Saving: {output_file}")
         save_file(by_filename[filename], output_file, metadata=metadata)

 def main():
-    # Parse command-line arguments
+    """
+    Main function to parse command-line arguments and orchestrate the merging process.
+    """
     parser = argparse.ArgumentParser(description='Merge two safetensor model files.')
     parser.add_argument('base_model', type=str, help='The base model safetensor file')
     parser.add_argument('second_model', type=str, help='The second model safetensor file')
@@ -162,4 +238,4 @@ def main():
     save_file(merged, args.output_model)

 if __name__ == '__main__':
-    main()
+    main()
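
For reference, the merge step shared by both versions keeps each element of the weight delta with probability p (a Bernoulli mask) and rescales the surviving elements by 1 / (1 - p). Below is a minimal sketch of that behaviour on toy tensors; it assumes numpy is imported as np elsewhere in merge.py (the hunks above show only part of the import block), and the tensor sizes, p, and the lambda value are placeholders, not values from the commit.

import numpy as np
import torch

def merge_tensors(tensor1, tensor2, p):
    # Same body as in the diff: mask the weight delta, then rescale it.
    delta = tensor2 - tensor1
    m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.dtype)
    delta_tilde = m * delta
    return delta_tilde / (1 - p)

base = torch.zeros(2, 2)
other = torch.ones(2, 2)
# Same combination rule as merge_safetensors / merge_folder: base + lambda * masked delta.
merged = base + 0.5 * merge_tensors(base, other, p=0.3)
print(merged)  # each entry is either 0.0 or 0.5 / (1 - 0.3) ≈ 0.714, chosen at random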
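In resize_tensors (whose logic is unchanged here, only the docstring is new), the two F.pad calls align mismatched 2-D shapes by zero-padding columns on the right and rows at the bottom. A small sketch of that padding with made-up shapes:

import torch
import torch.nn.functional as F

a = torch.ones(2, 3)
b = torch.ones(4, 5)
# Width: pad the last dimension of `a` on the right by 2 columns.
a = F.pad(a, (0, 2, 0, 0))   # shape becomes (2, 5)
# Height: pad the first dimension of `a` at the bottom by 2 rows.
a = F.pad(a, (0, 0, 0, 2))   # shape becomes (4, 5)
print(a.shape)               # torch.Size([4, 5]); the new cells are zero-filled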
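The script is normally driven by main() and its argparse options (only the first two positional arguments are visible in the hunks above), but the two-file path can also be exercised directly. A hypothetical programmatic use of the updated merge_safetensors, with placeholder file paths and hyperparameters:

from safetensors.torch import save_file

import merge  # merge.py imported as a module; paths and values below are illustrative only

merged = merge.merge_safetensors(
    "base/model.safetensors",
    "second/model.safetensors",
    p=0.1,
    lambda_val=1.0,
)
save_file(merged, "merged/model.safetensors")  # mirrors save_file(merged, args.output_model) in main()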