paulilioaica committed
Commit
e298d48
1 Parent(s): 97c6fa6

Update README.md

Files changed (1)
  1. README.md +958 -1
README.md CHANGED
@@ -95,4 +95,961 @@ The continents can be divided into several subregions, such as islands, archipel
95
  Each continent has its own unique geography, climate, flora, fauna, and human cultures. The continents are interconnected through various landforms, bodies of water, and global trade routes.
96
 
97
  In summary, there are seven continents on Earth, each with its own distinct characteristics and unique contributions to the world's diversity. While the number may vary depending on the categorization of Antarctica, all seven continents together make
98
- ```
98
+ ```
99
+
100
+
101
+ ## ♻️ Replicate this repo
102
+
103
+ **Beware:** this will only work with two Phi models; for more experts you may have to tinker with the tensor-naming logic (see the renaming sketch after the steps)
104
+
105
+ **AFTER** making all the file modifications below and running the merge, you need to replace `configs.json` with the one from **this repo**
106
+ **AFTER** that, you need to add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo
107
+
108
+ Steps
109
+
110
+ 1. Modify `mixtral_moe.py` at `/content/mergekit/mergekit/scripts/mixtral_moe.py` for your HF repo, as shown below
111
+
112
+ ***mixtral_moe.py***
113
+
114
+ ```python
115
+ # Copyright (C) 2024 Charles O. Goddard
116
+ #
117
+ # This software is free software: you can redistribute it and/or
118
+ # modify it under the terms of the GNU Lesser General Public License as
119
+ # published by the Free Software Foundation, either version 3 of the
120
+ # License, or (at your option) any later version.
121
+ #
122
+ # This software is distributed in the hope that it will be useful, but
123
+ # WITHOUT ANY WARRANTY; without even the implied warranty of
124
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
125
+ # Lesser General Public License for more details.
126
+ #
127
+ # You should have received a copy of the GNU Lesser General Public License
128
+ # along with this program. If not, see http://www.gnu.org/licenses/.
129
+
130
+ import logging
131
+ import os
132
+ import sys
133
+ from typing import Dict, List, Optional, Union
134
+
135
+ import click
136
+ import torch
137
+ import tqdm
138
+ import transformers
139
+ import yaml
140
+ from pydantic import BaseModel
141
+ from transformers import (
142
+ AutoModelForCausalLM,
143
+ LlamaForCausalLM,
144
+ MistralConfig,
145
+ MistralForCausalLM,
146
+ MixtralConfig,
147
+ )
148
+ from transformers.modeling_outputs import CausalLMOutputWithPast
149
+
150
+ import mergekit.architecture
151
+ from mergekit.common import ModelReference, dtype_from_name
152
+ from mergekit.io import LazyTensorLoader, TensorWriter
153
+ from mergekit.merge import MergeOptions
154
+ from mergekit.options import add_merge_options
155
+
156
+ # Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
157
+ # Takes the path to a yml config and an output path.
158
+ # Config schema is the two classes below.
159
+
160
+
161
+ class Expert(BaseModel):
162
+ source_model: str
163
+
164
+ positive_prompts: List[str]
165
+ negative_prompts: Optional[List[str]] = None
166
+ noise_scale: Optional[float] = None
167
+
168
+ @property
169
+ def model_ref(self):
170
+ return ModelReference.parse(self.source_model)
171
+
172
+
173
+ class MistralMOEConfig(BaseModel):
174
+ base_model: str
175
+ experts: List[Expert]
176
+ gate_mode: str = "hidden" # possible values: "hidden", "cheap_embed", "random"
177
+ # "hidden" uses hidden state vectors for the given prompts for each layer
178
+ # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
179
+ # "random" is random
180
+ dtype: Optional[str] = None
181
+ experts_per_token: int = 2
182
+
183
+
184
+ def get_hidden_states(
185
+ model: Union[MistralForCausalLM, LlamaForCausalLM],
186
+ tokenized: transformers.BatchEncoding,
187
+ average: bool = True,
188
+ ) -> List[torch.Tensor]:
189
+ with torch.no_grad():
190
+ output: CausalLMOutputWithPast = model(
191
+ **tokenized.to(model.device), output_hidden_states=True, return_dict=True
192
+ )
193
+ hidden_states = torch.stack(
194
+ output.hidden_states[:-1]
195
+ ) # (num_layers, batch_size, seq_len, hidden_size)
196
+ if average:
197
+ # use average over sequence
198
+ hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
199
+ else:
200
+ # take last value
201
+ hidden_states = hidden_states[:, :, -1, :]
202
+ return hidden_states.sum(dim=1) / hidden_states.shape[1]
203
+
204
+
205
+ def get_cheap_embedding(
206
+ embed: torch.Tensor,
207
+ tokenized: Dict[str, torch.Tensor],
208
+ num_layers: int,
209
+ vocab_size: int,
210
+ ) -> torch.Tensor:
211
+ onehot = torch.nn.functional.one_hot(
212
+ tokenized["input_ids"], num_classes=vocab_size
213
+ ) # (batch_size, seq_len, 32000)
214
+ h = onehot.float() @ embed.float() # (batch_size, seq_len, hidden_size)
215
+ embedded = (
216
+ (h * tokenized["attention_mask"].unsqueeze(-1))
217
+ .sum(dim=1)
218
+ .sum(dim=0, keepdim=True)
219
+ ) # (1, hidden_size)
220
+ res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
221
+ min=1e-8
222
+ ) # (1, hidden_size)
223
+ return res.repeat(num_layers, 1)
224
+
225
+
226
+ def tokenize_prompts(
227
+ prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
228
+ ):
229
+ return tokenizer(
230
+ [tokenizer.bos_token + p for p in prompts],
231
+ return_tensors="pt",
232
+ padding=True,
233
+ add_special_tokens=False,
234
+ )
235
+
236
+
237
+ def get_gate_params(
238
+ model_ref: ModelReference,
239
+ tokenizer: transformers.PreTrainedTokenizerBase,
240
+ experts: List[Expert],
241
+ mode: str = "hidden",
242
+ load_in_4bit: bool = False,
243
+ load_in_8bit: bool = False,
244
+ lazy_unpickle: bool = False,
245
+ trust_remote_code: bool = False,
246
+ device: str = "auto",
247
+ ):
248
+ gate_vecs = []
249
+ _do_it = None
250
+
251
+ model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
252
+
253
+ if mode == "random":
254
+ return torch.randn(
255
+ (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
256
+ )
257
+ elif mode == "cheap_embed":
258
+ embed = LazyTensorLoader(
259
+ model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
260
+ ).get_tensor("transformer.embd.wte.weight")
261
+
262
+ def _do_it(tokenized):
263
+ return get_cheap_embedding(
264
+ embed,
265
+ tokenized,
266
+ num_layers=model_cfg.num_hidden_layers,
267
+ vocab_size=model_cfg.vocab_size,
268
+ )
269
+
270
+ elif mode in ("hidden", "hidden_avg", "hidden_last"):
271
+ model = AutoModelForCausalLM.from_pretrained(
272
+ model_ref.model.path,
273
+ revision=model_ref.model.revision,
274
+ torch_dtype=torch.bfloat16,
275
+ device_map=device,
276
+ low_cpu_mem_usage=True,
277
+ load_in_4bit=load_in_4bit,
278
+ load_in_8bit=load_in_8bit,
279
+ trust_remote_code=trust_remote_code,
280
+ )
281
+
282
+ def _do_it(tokenized):
283
+ return get_hidden_states(
284
+ model, tokenized=tokenized, average=mode == "hidden_avg"
285
+ )
286
+
287
+
288
+ gate_vecs = []
289
+ print(experts)
290
+ for expert in tqdm.tqdm(experts, desc="expert prompts"):
291
+ print(_do_it)
292
+ hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
293
+ if expert.negative_prompts:
294
+ hidden_states -= _do_it(
295
+ tokenize_prompts(expert.negative_prompts, tokenizer)
296
+ )
297
+
298
+ hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
299
+ gate_vecs.append(hidden_states)
300
+ gate_vecs = torch.stack(gate_vecs, dim=0) # (num_expert, num_layer, hidden_size)
301
+ return gate_vecs.permute(1, 0, 2)
302
+
303
+
304
+ def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
305
+ degen_indices = []
306
+ num_layers, _num_experts, _hidden_size = gate_vecs.shape
307
+ for idx in range(num_layers):
308
+ c = torch.linalg.cond(gate_vecs[idx, :, :].float())
309
+ if c > threshold:
310
+ degen_indices.append(idx)
311
+
312
+ if degen_indices:
313
+ if len(degen_indices) == 1:
314
+ layer_str = f"layer {degen_indices[0]}"
315
+ verb = "has"
316
+ elif len(degen_indices) == 2:
317
+ layer_str = f"layers {' and '.join(map(str, degen_indices))}"
318
+ verb = "have"
319
+ elif len(degen_indices) >= num_layers:
320
+ layer_str = "ALL layers"
321
+ verb = "have"
322
+ else:
323
+ layer_str = (
324
+ "layers "
325
+ + ", ".join(map(str, degen_indices[:-1]))
326
+ + ", and "
327
+ + str(degen_indices[-1])
328
+ )
329
+ verb = "have"
330
+
331
+ logging.warning(
332
+ f"{layer_str} {verb} degenerate routing parameters "
333
+ "- your prompts may be too similar."
334
+ )
335
+ logging.warning("One or more experts will be underutilized in your model.")
336
+
337
+
338
+ def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
339
+ if len(config.experts) < 2:
340
+ logging.error("Must include at least two experts.")
341
+ return True
342
+
343
+ if config.gate_mode == "random":
344
+ return False # eh we're good
345
+
346
+ def prompt_tup(e: Expert):
347
+ return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))
348
+
349
+ # let's just nip this trend in the bud
350
+ p_first = prompt_tup(config.experts[0])
351
+ if all(prompt_tup(e) == p_first for e in config.experts[1:]):
352
+ logging.error(
353
+ "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
354
+ )
355
+ logging.error(
356
+ "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
357
+ )
358
+ return True
359
+
360
+ if not allow_all_same:
361
+ if all(
362
+ e.source_model == config.experts[0].source_model for e in config.experts[1:]
363
+ ):
364
+ logging.error(
365
+ "All of your expert models are the same. This will produce "
366
+ "a model that uses more resources but gives the exact same output. "
367
+ "If you plan to train the model after merging, proceed with the "
368
+ "--i-understand-this-is-not-useful-without-training flag."
369
+ )
370
+ return True
371
+
372
+
373
+ def build(
374
+ config: MistralMOEConfig,
375
+ out_path: str,
376
+ merge_options: MergeOptions,
377
+ load_in_4bit: bool = False,
378
+ load_in_8bit: bool = False,
379
+ device: str = "auto",
380
+ allow_all_same: bool = False,
381
+ ):
382
+ if is_bad_config(config, allow_all_same=allow_all_same):
383
+ sys.exit(1)
384
+
385
+ if config.experts_per_token < 1:
386
+ logging.error("Experts per token must be >= 1")
387
+ sys.exit(1)
388
+ if config.experts_per_token > len(config.experts):
389
+ logging.error("Experts per token must be <= number of experts")
390
+ sys.exit(1)
391
+
392
+ base_model = ModelReference.parse(config.base_model)
393
+ base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
394
+ if not isinstance(base_cfg, MistralConfig):
395
+ base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
396
+ base_cfg_mistral.sliding_window = None
397
+ base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
398
+ base_cfg = base_cfg_mistral
399
+
400
+ out_cfg = MixtralConfig(**base_cfg.to_dict())
401
+ out_cfg.architectures = ["PhiForCausalLM"]
402
+ out_cfg.num_local_experts = len(config.experts)
403
+ out_cfg.num_experts_per_tok = config.experts_per_token
404
+ out_cfg.sliding_window = None
405
+ if config.dtype:
406
+ out_cfg.torch_dtype = config.dtype
407
+ out_cfg.save_pretrained(out_path)
408
+
409
+ if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
410
+ logging.warning(
411
+ f"Your model has {out_cfg.num_local_experts} experts, which is "
412
+ "not a power of two. The model will not be usable in llama.cpp."
413
+ )
414
+
415
+ loaders: Dict[ModelReference, LazyTensorLoader] = {}
416
+ for model in tqdm.tqdm(
417
+ [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
418
+ ):
419
+ loaders[model] = LazyTensorLoader(
420
+ model.tensor_index(cache_dir=merge_options.transformers_cache),
421
+ lazy_unpickle=merge_options.lazy_unpickle,
422
+ )
423
+
424
+ base_loader = loaders.get(base_model)
425
+ writer = TensorWriter(
426
+ out_path=out_path,
427
+ max_shard_size=merge_options.out_shard_size,
428
+ safe_serialization=merge_options.safe_serialization,
429
+ )
430
+
431
+ if config.dtype:
432
+ out_dtype = dtype_from_name(config.dtype)
433
+ elif base_cfg.torch_dtype:
434
+ out_dtype = base_cfg.torch_dtype
435
+ if isinstance(out_dtype, str):
436
+ out_dtype = dtype_from_name(out_dtype)
437
+ else:
438
+ out_dtype = None
439
+
440
+ logging.info("Copying parameters...")
441
+ MISTRAL_INFO = mergekit.architecture.PHI2_INFO
442
+ for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
443
+ tensor = base_loader.get_tensor(tensor_name)
444
+ if not out_dtype:
445
+ # All else has failed, take the first dtype we see
446
+ out_dtype = tensor.dtype
447
+ writer.save_tensor(
448
+ tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
449
+ )
450
+ set_of_seen_tensors = set()
451
+
452
+ for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
453
+ for layer_idx in range(base_cfg.num_hidden_layers):
454
+ tensor_name = name_format.format(idx=layer_idx)
455
+ if ".mlp.fc" in name_format:
456
+ for moe_index, expert in enumerate(config.experts):
457
+ if tensor_name in set_of_seen_tensors:
458
+ expert_name = tensor_name.replace(
459
+ ".mlp.fc", f".moe.mlp.1.fc"
460
+ )
461
+ else:
462
+ expert_name = tensor_name.replace(
463
+ ".mlp.fc", f".moe.mlp.0.fc"
464
+ )
465
+ set_of_seen_tensors.add(tensor_name)
466
+
467
+ expert_loader = loaders.get(expert.model_ref)
468
+ tensor = expert_loader.get_tensor(tensor_name)
469
+ if expert.noise_scale:
470
+ tensor += torch.randn_like(tensor) * expert.noise_scale
471
+ writer.save_tensor(
472
+ expert_name, tensor.to(dtype=out_dtype), clone=True
473
+ )
474
+ print(expert_name, tensor_name)
475
+ continue
476
+ writer.save_tensor(
477
+ tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
478
+ )
479
+
480
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
481
+ base_model.model.path, revision=base_model.model.revision
482
+ )
483
+ tokenizer.padding_side = "left"
484
+ tokenizer.pad_token_id = tokenizer.bos_token_id
485
+
486
+ logging.info("Getting gate parameters...")
487
+ gate_vecs = get_gate_params(
488
+ base_model,
489
+ tokenizer,
490
+ config.experts,
491
+ mode=config.gate_mode,
492
+ load_in_4bit=load_in_4bit,
493
+ load_in_8bit=load_in_8bit,
494
+ lazy_unpickle=merge_options.lazy_unpickle,
495
+ trust_remote_code=merge_options.trust_remote_code,
496
+ device=device,
497
+ )
498
+ # gate_vecs: (num_layers, num_experts, hidden_size)
499
+
500
+ warn_degenerate_gates(gate_vecs)
501
+
502
+ for layer_idx in range(base_cfg.num_hidden_layers):
503
+ writer.save_tensor(
504
+ f"transformer.h.{layer_idx}.moe.gate.weight",
505
+ gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
506
+ )
507
+ writer.finalize()
508
+
509
+ if merge_options.copy_tokenizer:
510
+ logging.info("Saving tokenizer...")
511
+ tokenizer.save_pretrained(out_path, safe_serialization=True)
512
+
513
+ logging.info("Done.")
514
+
515
+
516
+ @click.command("mergekit-moe")
517
+ @click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
518
+ @click.argument("out_path", type=click.Path())
519
+ @click.option(
520
+ "--load-in-4bit",
521
+ is_flag=True,
522
+ type=bool,
523
+ default=False,
524
+ help="Load model in 4bit for computing hidden states",
525
+ )
526
+ @click.option(
527
+ "--load-in-8bit",
528
+ is_flag=True,
529
+ type=bool,
530
+ default=False,
531
+ help="Load model in 8bit for computing hidden states",
532
+ )
533
+ @click.option(
534
+ "--device",
535
+ type=str,
536
+ default="auto",
537
+ help="Device to use to compute embeddings",
538
+ show_default=True,
539
+ )
540
+ @click.option(
541
+ "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
542
+ )
543
+ @click.option(
544
+ "--i-understand-this-is-not-useful-without-training",
545
+ type=bool,
546
+ default=False,
547
+ is_flag=True,
548
+ help="Really make the questionable model you want.",
549
+ )
550
+ @add_merge_options
551
+ def main(
552
+ config_path: str,
553
+ out_path: str,
554
+ load_in_4bit: bool,
555
+ load_in_8bit: bool,
556
+ device: str,
557
+ merge_options: MergeOptions,
558
+ verbose: bool,
559
+ i_understand_this_is_not_useful_without_training: bool,
560
+ ):
561
+ logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)
562
+
563
+ if merge_options.cuda:
564
+ logging.warning(
565
+ '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
566
+ )
567
+
568
+ with open(config_path, "r", encoding="utf-8") as file:
569
+ config_source = file.read()
570
+
571
+ config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
572
+ build(
573
+ config,
574
+ out_path=out_path,
575
+ merge_options=merge_options,
576
+ load_in_4bit=load_in_4bit,
577
+ load_in_8bit=load_in_8bit,
578
+ device=device,
579
+ allow_all_same=i_understand_this_is_not_useful_without_training,
580
+ )
581
+
582
+ if merge_options.write_model_card:
583
+ # TODO: generate a README.md as well
584
+ with open(
585
+ os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
586
+ ) as fp:
587
+ fp.write(config_source)
588
+
589
+
590
+ if __name__ == "__main__":
591
+ main()
592
+ ```
593
+
594
+ 2. Modify `architecture.py` at `/content/mergekit/mergekit/architecture.py`, as shown below
595
+ (you can also take this from the commit linked in the description)
596
+
597
+ ***architecture.py***
598
+
599
+ ```python
600
+ # Copyright (C) 2024 Charles O. Goddard
601
+ #
602
+ # This software is free software: you can redistribute it and/or
603
+ # modify it under the terms of the GNU Lesser General Public License as
604
+ # published by the Free Software Foundation, either version 3 of the
605
+ # License, or (at your option) any later version.
606
+ #
607
+ # This software is distributed in the hope that it will be useful, but
608
+ # WITHOUT ANY WARRANTY; without even the implied warranty of
609
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
610
+ # Lesser General Public License for more details.
611
+ #
612
+ # You should have received a copy of the GNU Lesser General Public License
613
+ # along with this program. If not, see http://www.gnu.org/licenses/.
614
+
615
+ from abc import ABC, abstractmethod
616
+ from typing import List, Optional
617
+
618
+ from pydantic import BaseModel
619
+ from transformers import PretrainedConfig
620
+
621
+
622
+ class ArchitectureInfo(ABC):
623
+ @abstractmethod
624
+ def pre_weights(self) -> List[str]:
625
+ """Return a list of all weights preceding the first layer."""
626
+ ...
627
+
628
+ @abstractmethod
629
+ def post_weights(self) -> List[str]:
630
+ """Return a list of all weights following the final layer."""
631
+ ...
632
+
633
+ @abstractmethod
634
+ def layer_weight_formats(self) -> List[str]:
635
+ """Return a list of format strings all weights associated with a layer."""
636
+ ...
637
+
638
+ @abstractmethod
639
+ def embed_weights(self) -> List[str]:
640
+ ...
641
+
642
+ def num_layers(self, config: PretrainedConfig) -> int:
643
+ return config.num_hidden_layers
644
+
645
+ def num_layers_config_key(self) -> str:
646
+ """Key in config that represents number of layers"""
647
+ return "num_hidden_layers"
648
+
649
+
650
+ class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
651
+ name: str
652
+
653
+ pre_weight_names: List[str] # weights applied before first layer
654
+ post_weight_names: List[str] # weights applied after last layer
655
+ embed_weight_names: List[str] # weights for embed/lm_head
656
+ layer_prefix_format: str
657
+ layer_weight_suffixes: List[str]
658
+ num_layers_key: Optional[str] = None
659
+
660
+ def pre_weights(self) -> List[str]:
661
+ return self.pre_weight_names
662
+
663
+ def post_weights(self) -> List[str]:
664
+ return self.post_weight_names
665
+
666
+ def embed_weights(self) -> List[str]:
667
+ return self.embed_weight_names
668
+
669
+ def layer_weight_formats(self) -> List[str]:
670
+ res = []
671
+ for suffix in self.layer_weight_suffixes:
672
+ res.append(self.layer_prefix_format + "." + suffix)
673
+ return res
674
+
675
+ def num_layers_config_key(self) -> str:
676
+ if self.num_layers_key:
677
+ return self.num_layers_key
678
+ return super().num_layers_config_key()
679
+
680
+ def num_layers(self, config: PretrainedConfig) -> int:
681
+ return getattr(config, self.num_layers_config_key())
682
+
683
+ def all_weights(self, config: PretrainedConfig) -> List[str]:
684
+ num_layers = self.num_layers(config)
685
+ tensor_names = list(self.pre_weights())
686
+ for layer_idx in range(num_layers):
687
+ for f in self.layer_weight_formats():
688
+ tensor_names.append(f.format(idx=layer_idx))
689
+ tensor_names.extend(self.post_weights())
690
+ return tensor_names
691
+
692
+
693
+ LLAMA_INFO = StaticTensorNames(
694
+ name="LlamaForCausalLM",
695
+ pre_weight_names=["model.embed_tokens.weight"],
696
+ post_weight_names=["model.norm.weight", "lm_head.weight"],
697
+ embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
698
+ layer_prefix_format="model.layers.{idx}",
699
+ layer_weight_suffixes=[
700
+ "input_layernorm.weight",
701
+ "mlp.up_proj.weight",
702
+ "mlp.down_proj.weight",
703
+ "mlp.gate_proj.weight",
704
+ "post_attention_layernorm.weight",
705
+ "self_attn.q_proj.weight",
706
+ "self_attn.k_proj.weight",
707
+ "self_attn.v_proj.weight",
708
+ "self_attn.o_proj.weight",
709
+ ],
710
+ )
711
+
712
+ MISTRAL_INFO = StaticTensorNames(
713
+ name="MistralForCausalLM",
714
+ # lol
715
+ **LLAMA_INFO.model_dump(exclude=["name"]),
716
+ )
717
+
718
+
719
+ STABLELM_INFO = StaticTensorNames(
720
+ name="StableLMEpochForCausalLM",
721
+ post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
722
+ layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
723
+ + [
724
+ "input_layernorm.bias",
725
+ "post_attention_layernorm.bias",
726
+ ],
727
+ **LLAMA_INFO.model_dump(
728
+ exclude=["name", "layer_weight_suffixes", "post_weight_names"]
729
+ ),
730
+ )
731
+
732
+ GPT_NEOX_INFO = StaticTensorNames(
733
+ name="GPTNeoXForCausalLM",
734
+ pre_weight_names=["gpt_neox.embed_in.weight"],
735
+ post_weight_names=[
736
+ "gpt_neox.final_layer_norm.bias",
737
+ "gpt_neox.final_layer_norm.weight",
738
+ "embed_out.weight",
739
+ ],
740
+ embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
741
+ layer_prefix_format="gpt_neox.layers.{idx}",
742
+ layer_weight_suffixes=sum(
743
+ (
744
+ [f"{prefix}.weight", f"{prefix}.bias"]
745
+ for prefix in [
746
+ "attention.dense",
747
+ "attention.query_key_value",
748
+ "input_layernorm",
749
+ "mlp.dense_4h_to_h",
750
+ "mlp.dense_h_to_4h",
751
+ "post_attention_layernorm",
752
+ ]
753
+ ),
754
+ start=[],
755
+ )
756
+ + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
757
+ )
758
+
759
+ GPT2_INFO = StaticTensorNames(
760
+ name="GPT2LMHeadModel",
761
+ pre_weight_names=["wte.weight", "wpe.weight"],
762
+ post_weight_names=["ln_f.weight", "ln_f.bias"],
763
+ embed_weight_names=["wte.weight"],
764
+ layer_prefix_format="h.{idx}",
765
+ layer_weight_suffixes=[
766
+ "attn.c_attn.weight",
767
+ "attn.c_attn.bias",
768
+ "attn.c_proj.weight",
769
+ "attn.c_proj.bias",
770
+ "ln_1.weight",
771
+ "ln_1.bias",
772
+ "ln_2.weight",
773
+ "ln_2.bias",
774
+ "mlp.c_proj.weight",
775
+ "mlp.c_proj.bias",
776
+ "mlp.c_fc.weight",
777
+ "mlp.c_fc.bias",
778
+ "mlp.c_proj.weight",
779
+ "mlp.c_proj.bias",
780
+ ],
781
+ num_layers_key="n_layer",
782
+ )
783
+
784
+ JAIS_INFO = StaticTensorNames(
785
+ name="JAISLMHeadModel",
786
+ pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
787
+ post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
788
+ embed_weight_names=["transformer.wte.weight"],
789
+ layer_prefix_format="transformer.h.{idx}",
790
+ layer_weight_suffixes=[
791
+ "attn.c_attn.weight",
792
+ "attn.c_attn.bias",
793
+ "attn.c_proj.weight",
794
+ "attn.c_proj.bias",
795
+ "ln_1.weight",
796
+ "ln_1.bias",
797
+ "ln_2.weight",
798
+ "ln_2.bias",
799
+ "mlp.c_fc.weight",
800
+ "mlp.c_fc.bias",
801
+ "mlp.c_fc2.weight",
802
+ "mlp.c_fc2.bias",
803
+ "mlp.c_proj.weight",
804
+ "mlp.c_proj.bias",
805
+ ],
806
+ num_layers_key="n_layer",
807
+ )
808
+
809
+ GPT2_SEQCLASS_INFO = StaticTensorNames(
810
+ name="GPT2ForSequenceClassification",
811
+ pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
812
+ post_weight_names=[
813
+ "transformer.ln_f.weight",
814
+ "transformer.ln_f.bias",
815
+ "score.weight",
816
+ ],
817
+ layer_prefix_format="transformer.h.{idx}",
818
+ embed_weight_names=GPT2_INFO.embed_weight_names,
819
+ layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
820
+ num_layers_key=GPT2_INFO.num_layers_key,
821
+ )
822
+
823
+
824
+ QWEN_INFO = StaticTensorNames(
825
+ name="QWenLMHeadModel",
826
+ pre_weight_names=["transformer.wte.weight"],
827
+ post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
828
+ embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
829
+ layer_prefix_format="transformer.h.{idx}",
830
+ layer_weight_suffixes=[
831
+ "attn.c_attn.bias",
832
+ "attn.c_attn.weight",
833
+ "attn.c_proj.weight",
834
+ "ln_1.weight",
835
+ "ln_2.weight",
836
+ "mlp.c_proj.weight",
837
+ "mlp.w1.weight",
838
+ "mlp.w2.weight",
839
+ ],
840
+ )
841
+
842
+ CHATGLM_INFO = StaticTensorNames(
843
+ name="ChatGLMModel",
844
+ pre_weight_names=[
845
+ "transformer.embedding.word_embeddings.weight",
846
+ "transformer.rotary_pos_emb.inv_freq",
847
+ ],
848
+ post_weight_names=[
849
+ "transformer.encoder.final_layernorm.weight",
850
+ "transformer.output_layer.weight",
851
+ ],
852
+ embed_weight_names=[
853
+ "transformer.embedding.word_embeddings.weight",
854
+ "transformer.output_layer.weight",
855
+ ],
856
+ layer_prefix_format="transformer.encoder.layers.{idx}",
857
+ layer_weight_suffixes=[
858
+ "input_layernorm.weight",
859
+ "mlp.dense_4h_to_h.weight",
860
+ "mlp.dense_h_to_4h.weight",
861
+ "post_attention_layernorm.weight",
862
+ "self_attention.dense.weight",
863
+ "self_attention.query_key_value.bias",
864
+ "self_attention.query_key_value.weight",
865
+ ],
866
+ )
867
+
868
+ FALCON_INFO = StaticTensorNames(
869
+ name="FalconForCausalLM",
870
+ pre_weight_names=["transformer.word_embeddings.weight"],
871
+ post_weight_names=[
872
+ "transformer.ln_f.weight",
873
+ "transformer.ln_f.bias",
874
+ "lm_head.weight",
875
+ ],
876
+ embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
877
+ layer_prefix_format="transformer.h.{idx}",
878
+ layer_weight_suffixes=[
879
+ "ln_attn.bias",
880
+ "ln_attn.weight",
881
+ "ln_mlp.bias",
882
+ "ln_mlp.weight",
883
+ "mlp.dense_4h_to_h.weight",
884
+ "mlp.dense_h_to_4h.weight",
885
+ "self_attention.dense.weight",
886
+ "self_attention.query_key_value.weight",
887
+ ],
888
+ )
889
+
890
+
891
+ class PhiTensorNames(ArchitectureInfo):
892
+ architecture_name: str = "MixFormerSequentialForCausalLM"
893
+
894
+ def __init__(self, config: PretrainedConfig):
895
+ self.config = config
896
+
897
+ def __eq__(self, rhs: "PhiTensorNames"):
898
+ if not isinstance(rhs, PhiTensorNames):
899
+ return False
900
+ return self.num_layers() == rhs.num_layers()
901
+
902
+ def pre_weights(self) -> List[str]:
903
+ return ["layers.0.wte.weight"]
904
+
905
+ def post_weights(self) -> List[str]:
906
+ fake_layer_idx = self.config.n_layer + 1
907
+ return [
908
+ f"layers.{fake_layer_idx}.{suffix}"
909
+ for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
910
+ ]
911
+
912
+ def embed_weights(self) -> List[str]:
913
+ fake_layer_idx = self.config.n_layer + 1
914
+ return [
915
+ "layers.0.wte.weight",
916
+ f"layers.{fake_layer_idx}.linear.weight",
917
+ f"layers.{fake_layer_idx}.linear.bias",
918
+ ]
919
+
920
+ def layer_weight_formats(self) -> List[str]:
921
+ return [
922
+ ("layers.{idx}." + suffix)
923
+ for suffix in [
924
+ "ln.bias",
925
+ "ln.weight",
926
+ "mixer.Wqkv.bias",
927
+ "mixer.Wqkv.weight",
928
+ "mixer.out_proj.bias",
929
+ "mixer.out_proj.weight",
930
+ "mixer.rotary_emb.inv_freq",
931
+ "mlp.fc1.bias",
932
+ "mlp.fc1.weight",
933
+ "mlp.fc2.bias",
934
+ "mlp.fc2.weight",
935
+ ]
936
+ ]
937
+
938
+ def num_layers(self, config: PretrainedConfig) -> int:
939
+ return config.n_layer
940
+
941
+ def num_layers_config_key(self) -> str:
942
+ return "n_layer"
943
+
944
+
945
+ PHI2_INFO = StaticTensorNames(
946
+ name="PhiForCausalLM",
947
+ pre_weight_names=["transformer.embd.wte.weight"],
948
+ post_weight_names=[
949
+ "lm_head.linear.bias",
950
+ "lm_head.linear.weight",
951
+ "lm_head.ln.bias",
952
+ "lm_head.ln.weight",
953
+ ],
954
+ embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
955
+ layer_prefix_format="transformer.h.{idx}",
956
+ layer_weight_suffixes=[
957
+ "ln.bias",
958
+ "ln.weight",
959
+ "mixer.out_proj.bias",
960
+ "mixer.out_proj.weight",
961
+ "mixer.Wqkv.bias",
962
+ "mixer.Wqkv.weight",
963
+ "mlp.fc1.bias",
964
+ "mlp.fc1.weight",
965
+ "mlp.fc2.bias",
966
+ "mlp.fc2.weight",
967
+ ],
968
+ num_layers_key="n_layer",
969
+ )
970
+
971
+
972
+ PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
973
+ name="PhiForCausalLM",
974
+ pre_weight_names=["model.embed_tokens.weight"],
975
+ post_weight_names=[
976
+ "lm_head.bias",
977
+ "lm_head.weight",
978
+ "model.final_layernorm.bias",
979
+ "model.final_layernorm.weight",
980
+ ],
981
+ embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
982
+ layer_prefix_format="model.layers.{idx}",
983
+ layer_weight_suffixes=[
984
+ "input_layernorm.bias",
985
+ "input_layernorm.weight",
986
+ "self_attn.dense.bias",
987
+ "self_attn.dense.weight",
988
+ "self_attn.q_proj.bias",
989
+ "self_attn.q_proj.weight",
990
+ "self_attn.k_proj.bias",
991
+ "self_attn.k_proj.weight",
992
+ "self_attn.v_proj.bias",
993
+ "self_attn.v_proj.weight",
994
+ "mlp.fc1.bias",
995
+ "mlp.fc1.weight",
996
+ "mlp.fc2.bias",
997
+ "mlp.fc2.weight",
998
+ ],
999
+ )
1000
+
1001
+
1002
+ BAICHUAN_INFO = StaticTensorNames(
1003
+ name="BaichuanForCausalLM",
1004
+ pre_weight_names=["model.embed_tokens.weight"],
1005
+ post_weight_names=["model.norm.weight", "lm_head.weight"],
1006
+ embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
1007
+ layer_prefix_format="model.layers.{idx}",
1008
+ layer_weight_suffixes=[
1009
+ "input_layernorm.weight",
1010
+ "self_attn.W_pack.weight",
1011
+ "self_attn.o_proj.weight",
1012
+ "post_attention_layernorm.weight",
1013
+ "mlp.gate_proj.weight",
1014
+ "mlp.down_proj.weight",
1015
+ "mlp.up_proj.weight",
1016
+ ],
1017
+ )
1018
+
1019
+
1020
+ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
1021
+ if len(config.architectures) != 1:
1022
+ raise RuntimeError("More than one architecture in config?")
1023
+
1024
+ arch_name = config.architectures[0]
1025
+ if arch_name == PhiTensorNames.architecture_name:
1026
+ return PhiTensorNames(config)
1027
+
1028
+ if arch_name == PHI2_INFO.name:
1029
+ if config.model_type == "phi-msft":
1030
+ return PHI2_INFO
1031
+ elif config.model_type == "phi":
1032
+ return PHI2_INFO_AGAIN_BUT_DIFFERENT
1033
+
1034
+ supported = [
1035
+ LLAMA_INFO,
1036
+ MISTRAL_INFO,
1037
+ GPT_NEOX_INFO,
1038
+ QWEN_INFO,
1039
+ GPT2_INFO,
1040
+ GPT2_SEQCLASS_INFO,
1041
+ CHATGLM_INFO,
1042
+ STABLELM_INFO,
1043
+ JAIS_INFO,
1044
+ BAICHUAN_INFO,
1045
+ FALCON_INFO,
1046
+ ]
1047
+ for arch in supported:
1048
+ if arch.name == arch_name:
1049
+ return arch
1050
+
1051
+ raise RuntimeError(f"Unsupported architecture {arch_name}")
1052
+ ```
1053
+
1054
+ 3. Replace `configs.json` with the one from **this repo**
1055
+ 4. Add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo
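
For reference, the merge run mentioned in the notes above is driven by a YAML config that matches the `MistralMOEConfig` schema in the modified script. Below is a minimal sketch (the base model, expert repos, and prompts are placeholders, not the ones used for this model), followed by the invocation the script's click command expects.

```python
# Sketch of a two-expert config for the modified mixtral_moe.py above.
# All model names and prompts are placeholders; substitute your own Phi repos.
import yaml

config_yaml = """
base_model: microsoft/phi-2                     # placeholder base
gate_mode: hidden                               # "hidden", "cheap_embed" or "random"
dtype: float16
experts_per_token: 2
experts:
  - source_model: your-user/phi-2-expert-a      # placeholder expert 1
    positive_prompts:
      - "Write a Python function that"
  - source_model: your-user/phi-2-expert-b      # placeholder expert 2
    positive_prompts:
      - "Explain step by step why"
"""

with open("phi_moe.yml", "w", encoding="utf-8") as f:
    f.write(config_yaml)

# Sanity-check that the YAML parses before running the merge.
print(yaml.safe_load(config_yaml)["experts"])

# The modified script can then be run directly, e.g.:
#   python /content/mergekit/mergekit/scripts/mixtral_moe.py phi_moe.yml ./phi-2x-moe --device cuda -v
```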
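
The warning about two Phi models comes from the expert-naming hack in the modified `build()` loop: the first time a given `.mlp.fc*` tensor is encountered it is saved under `.moe.mlp.0.*`, and every later occurrence under `.moe.mlp.1.*`. A standalone sketch of that logic (not part of mergekit) shows why a third expert would collide with the second:

```python
from typing import List

def expert_tensor_names(tensor_name: str, num_experts: int) -> List[str]:
    """Mirror the seen-set renaming from the modified build() loop."""
    seen = set()
    names = []
    for _ in range(num_experts):
        if tensor_name in seen:
            # every expert after the first gets the ".moe.mlp.1" slot
            names.append(tensor_name.replace(".mlp.fc", ".moe.mlp.1.fc"))
        else:
            names.append(tensor_name.replace(".mlp.fc", ".moe.mlp.0.fc"))
            seen.add(tensor_name)
    return names

print(expert_tensor_names("transformer.h.0.mlp.fc1.weight", 2))
# ['...moe.mlp.0.fc1.weight', '...moe.mlp.1.fc1.weight']  -> one slot per expert
print(expert_tensor_names("transformer.h.0.mlp.fc1.weight", 3))
# the third expert maps to ".moe.mlp.1" again and would overwrite the second
```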