paulilioaica committed
Commit ae3dbb5 · Parent(s): e298d48
Update README.md

README.md CHANGED
@@ -109,947 +109,17 @@ Steps

**Before:**

1. Modify `mixtral_moe.py` from `/content/mergekit/mergekit/scripts/mixtral_moe.py` to your HF repo

***mixtral_moe.py***

```
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
import os
import sys
from typing import Dict, List, Optional, Union

import click
import torch
import tqdm
import transformers
import yaml
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    LlamaForCausalLM,
    MistralConfig,
    MistralForCausalLM,
    MixtralConfig,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

import mergekit.architecture
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.options import add_merge_options

# Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
# Takes the path to a yml config and an output path.
# Config schema is the two classes below.


class Expert(BaseModel):
    source_model: str

    positive_prompts: List[str]
    negative_prompts: Optional[List[str]] = None
    noise_scale: Optional[float] = None

    @property
    def model_ref(self):
        return ModelReference.parse(self.source_model)


class MistralMOEConfig(BaseModel):
    base_model: str
    experts: List[Expert]
    gate_mode: str = "hidden"  # possible values: "hidden", "cheap_embed", "random"
    # "hidden" uses hidden state vectors for the given prompts for each layer
    # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
    # "random" is random
    dtype: Optional[str] = None
    experts_per_token: int = 2


def get_hidden_states(
    model: Union[MistralForCausalLM, LlamaForCausalLM],
    tokenized: transformers.BatchEncoding,
    average: bool = True,
) -> List[torch.Tensor]:
    with torch.no_grad():
        output: CausalLMOutputWithPast = model(
            **tokenized.to(model.device), output_hidden_states=True, return_dict=True
        )
    hidden_states = torch.stack(
        output.hidden_states[:-1]
    )  # (num_layers, batch_size, seq_len, hidden_size)
    if average:
        # use average over sequence
        hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
    else:
        # take last value
        hidden_states = hidden_states[:, :, -1, :]
    return hidden_states.sum(dim=1) / hidden_states.shape[1]


def get_cheap_embedding(
    embed: torch.Tensor,
    tokenized: Dict[str, torch.Tensor],
    num_layers: int,
    vocab_size: int,
) -> torch.Tensor:
    onehot = torch.nn.functional.one_hot(
        tokenized["input_ids"], num_classes=vocab_size
    )  # (batch_size, seq_len, 32000)
    h = onehot.float() @ embed.float()  # (batch_size, seq_len, hidden_size)
    embedded = (
        (h * tokenized["attention_mask"].unsqueeze(-1))
        .sum(dim=1)
        .sum(dim=0, keepdim=True)
    )  # (1, hidden_size)
    res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
        min=1e-8
    )  # (1, hidden_size)
    return res.repeat(num_layers, 1)


def tokenize_prompts(
    prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
):
    return tokenizer(
        [tokenizer.bos_token + p for p in prompts],
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )


def get_gate_params(
    model_ref: ModelReference,
    tokenizer: transformers.PreTrainedTokenizerBase,
    experts: List[Expert],
    mode: str = "hidden",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    lazy_unpickle: bool = False,
    trust_remote_code: bool = False,
    device: str = "auto",
):
    gate_vecs = []
    _do_it = None

    model_cfg = model_ref.config(trust_remote_code=trust_remote_code)

    if mode == "random":
        return torch.randn(
            (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
        )
    elif mode == "cheap_embed":
        embed = LazyTensorLoader(
            model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
        ).get_tensor("transformer.embd.wte.weight")

        def _do_it(tokenized):
            return get_cheap_embedding(
                embed,
                tokenized,
                num_layers=model_cfg.num_hidden_layers,
                vocab_size=model_cfg.vocab_size,
            )

    elif mode in ("hidden", "hidden_avg", "hidden_last"):
        model = AutoModelForCausalLM.from_pretrained(
            model_ref.model.path,
            revision=model_ref.model.revision,
            torch_dtype=torch.bfloat16,
            device_map=device,
            low_cpu_mem_usage=True,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
        )

        def _do_it(tokenized):
            return get_hidden_states(
                model, tokenized=tokenized, average=mode == "hidden_avg"
            )

    gate_vecs = []
    print(experts)
    for expert in tqdm.tqdm(experts, desc="expert prompts"):
        print(_do_it)
        hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
        if expert.negative_prompts:
            hidden_states -= _do_it(
                tokenize_prompts(expert.negative_prompts, tokenizer)
            )

        hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
        gate_vecs.append(hidden_states)
    gate_vecs = torch.stack(gate_vecs, dim=0)  # (num_expert, num_layer, hidden_size)
    return gate_vecs.permute(1, 0, 2)


def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
    degen_indices = []
    num_layers, _num_experts, _hidden_size = gate_vecs.shape
    for idx in range(num_layers):
        c = torch.linalg.cond(gate_vecs[idx, :, :].float())
        if c > threshold:
            degen_indices.append(idx)

    if degen_indices:
        if len(degen_indices) == 1:
            layer_str = f"layer {degen_indices[0]}"
            verb = "has"
        elif len(degen_indices) == 2:
            layer_str = f"layers {' and '.join(map(str, degen_indices))}"
            verb = "have"
        elif len(degen_indices) >= num_layers:
            layer_str = "ALL layers"
            verb = "have"
        else:
            layer_str = (
                "layers "
                + ", ".join(map(str, degen_indices[:-1]))
                + ", and "
                + str(degen_indices[-1])
            )
            verb = "have"

        logging.warning(
            f"{layer_str} {verb} degenerate routing parameters "
            "- your prompts may be too similar."
        )
        logging.warning("One or more experts will be underutilized in your model.")


def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
    if len(config.experts) < 2:
        logging.error("Must include at least two experts.")
        return True

    if config.gate_mode == "random":
        return False  # eh we're good

    def prompt_tup(e: Expert):
        return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))

    # let's just nip this trend in the bud
    p_first = prompt_tup(config.experts[0])
    if all(prompt_tup(e) == p_first for e in config.experts[1:]):
        logging.error(
            "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
        )
        logging.error(
            "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
        )
        return True

    if not allow_all_same:
        if all(
            e.source_model == config.experts[0].source_model for e in config.experts[1:]
        ):
            logging.error(
                "All of your expert models are the same. This will produce "
                "a model that uses more resources but gives the exact same output. "
                "If you plan to train the model after merging, proceed with the "
                "--i-understand-this-is-not-useful-without-training flag."
            )
            return True


def build(
    config: MistralMOEConfig,
    out_path: str,
    merge_options: MergeOptions,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    device: str = "auto",
    allow_all_same: bool = False,
):
    if is_bad_config(config, allow_all_same=allow_all_same):
        sys.exit(1)

    if config.experts_per_token < 1:
        logging.error("Experts per token must be >= 1")
        sys.exit(1)
    if config.experts_per_token > len(config.experts):
        logging.error("Experts per token must be <= number of experts")
        sys.exit(1)

    base_model = ModelReference.parse(config.base_model)
    base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
    if not isinstance(base_cfg, MistralConfig):
        base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
        base_cfg_mistral.sliding_window = None
        base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
        base_cfg = base_cfg_mistral

    out_cfg = MixtralConfig(**base_cfg.to_dict())
    out_cfg.architectures = ["PhiForCausalLM"]
    out_cfg.num_local_experts = len(config.experts)
    out_cfg.num_experts_per_tok = config.experts_per_token
    out_cfg.sliding_window = None
    if config.dtype:
        out_cfg.torch_dtype = config.dtype
    out_cfg.save_pretrained(out_path)

    if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
        logging.warning(
            f"Your model has {out_cfg.num_local_experts} experts, which is "
            "not a power of two. The model will not be usable in llama.cpp."
        )

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(
        [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = LazyTensorLoader(
            model.tensor_index(cache_dir=merge_options.transformers_cache),
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)
    elif base_cfg.torch_dtype:
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    else:
        out_dtype = None

    logging.info("Copying parameters...")
    MISTRAL_INFO = mergekit.architecture.PHI2_INFO
    for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
        tensor = base_loader.get_tensor(tensor_name)
        if not out_dtype:
            # All else has failed, take the first dtype we see
            out_dtype = tensor.dtype
        writer.save_tensor(
            tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
        )
    set_of_seen_tensors = set()

    for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
        for layer_idx in range(base_cfg.num_hidden_layers):
            tensor_name = name_format.format(idx=layer_idx)
            if ".mlp.fc" in name_format:
                for moe_index, expert in enumerate(config.experts):
                    if tensor_name in set_of_seen_tensors:
                        expert_name = tensor_name.replace(
                            ".mlp.fc", f".moe.mlp.1.fc"
                        )
                    else:
                        expert_name = tensor_name.replace(
                            ".mlp.fc", f".moe.mlp.0.fc"
                        )
                        set_of_seen_tensors.add(tensor_name)

                    expert_loader = loaders.get(expert.model_ref)
                    tensor = expert_loader.get_tensor(tensor_name)
                    if expert.noise_scale:
                        tensor += torch.randn_like(tensor) * expert.noise_scale
                    writer.save_tensor(
                        expert_name, tensor.to(dtype=out_dtype), clone=True
                    )
                    print(expert_name, tensor_name)
                continue
            writer.save_tensor(
                tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
            )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        base_model.model.path, revision=base_model.model.revision
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.bos_token_id

    logging.info("Getting gate parameters...")
    gate_vecs = get_gate_params(
        base_model,
        tokenizer,
        config.experts,
        mode=config.gate_mode,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        lazy_unpickle=merge_options.lazy_unpickle,
        trust_remote_code=merge_options.trust_remote_code,
        device=device,
    )
    # gate_vecs: (num_layers, num_experts, hidden_size)

    warn_degenerate_gates(gate_vecs)

    for layer_idx in range(base_cfg.num_hidden_layers):
        writer.save_tensor(
            f"transformer.h.{layer_idx}.moe.gate.weight",
            gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
        )
    writer.finalize()

    if merge_options.copy_tokenizer:
        logging.info("Saving tokenizer...")
        tokenizer.save_pretrained(out_path, safe_serialization=True)

    logging.info("Done.")


@click.command("mergekit-moe")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=click.Path())
@click.option(
    "--load-in-4bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 4bit for computing hidden states",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 8bit for computing hidden states",
)
@click.option(
    "--device",
    type=str,
    default="auto",
    help="Device to use to compute embeddings",
    show_default=True,
)
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--i-understand-this-is-not-useful-without-training",
    type=bool,
    default=False,
    is_flag=True,
    help="Really make the questionable model you want.",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    load_in_4bit: bool,
    load_in_8bit: bool,
    device: str,
    merge_options: MergeOptions,
    verbose: bool,
    i_understand_this_is_not_useful_without_training: bool,
):
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    if merge_options.cuda:
        logging.warning(
            '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
        )

    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
    build(
        config,
        out_path=out_path,
        merge_options=merge_options,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        device=device,
        allow_all_same=i_understand_this_is_not_useful_without_training,
    )

    if merge_options.write_model_card:
        # TODO: generate a README.md as well
        with open(
            os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
        ) as fp:
            fp.write(config_source)


if __name__ == "__main__":
    main()
```
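
The script reads a YAML file matching the `Expert` / `MistralMOEConfig` schema defined above. The snippet below is only an illustrative sketch of such a config: the base model, expert repo names, prompts, and the output filename are placeholders, not part of the original instructions.

```python
import yaml

# Illustrative config matching the MistralMOEConfig schema above.
# All model names and prompts below are placeholders.
moe_config = {
    "base_model": "microsoft/phi-2",
    "gate_mode": "hidden",        # "hidden", "cheap_embed", or "random"
    "dtype": "bfloat16",
    "experts_per_token": 2,
    "experts": [
        {
            "source_model": "your-username/phi-2-finetune-a",   # placeholder
            "positive_prompts": ["solve this math problem step by step"],
        },
        {
            "source_model": "your-username/phi-2-finetune-b",   # placeholder
            "positive_prompts": ["write a short story about"],
        },
    ],
}

# Write the config to disk so the modified script can consume it.
with open("phi_moe_config.yml", "w", encoding="utf-8") as f:
    yaml.safe_dump(moe_config, f, sort_keys=False)
```

You would then point the modified script at that file, for example with something like `python mixtral_moe.py phi_moe_config.yml ./phi-moe-out`, plus any of the flags defined in `main` (such as `--device cuda` or `--load-in-4bit`).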

2. Modify `architecture.py` at `/content/mergekit/mergekit/architecture.py`
   (you can take this from the commit linked in the description)

***architecture.py***

```
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel
from transformers import PretrainedConfig


class ArchitectureInfo(ABC):
    @abstractmethod
    def pre_weights(self) -> List[str]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self) -> List[str]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weight_formats(self) -> List[str]:
        """Return a list of format strings all weights associated with a layer."""
        ...

    @abstractmethod
    def embed_weights(self) -> List[str]:
        ...

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.num_hidden_layers

    def num_layers_config_key(self) -> str:
        """Key in config that represents number of layers"""
        return "num_hidden_layers"


class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
    name: str

    pre_weight_names: List[str]  # weights applied before first layer
    post_weight_names: List[str]  # weights applied after last layer
    embed_weight_names: List[str]  # weights for embed/lm_head
    layer_prefix_format: str
    layer_weight_suffixes: List[str]
    num_layers_key: Optional[str] = None

    def pre_weights(self) -> List[str]:
        return self.pre_weight_names

    def post_weights(self) -> List[str]:
        return self.post_weight_names

    def embed_weights(self) -> List[str]:
        return self.embed_weight_names

    def layer_weight_formats(self) -> List[str]:
        res = []
        for suffix in self.layer_weight_suffixes:
            res.append(self.layer_prefix_format + "." + suffix)
        return res

    def num_layers_config_key(self) -> str:
        if self.num_layers_key:
            return self.num_layers_key
        return super().num_layers_config_key()

    def num_layers(self, config: PretrainedConfig) -> int:
        return getattr(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[str]:
        num_layers = self.num_layers(config)
        tensor_names = list(self.pre_weights())
        for layer_idx in range(num_layers):
            for f in self.layer_weight_formats():
                tensor_names.append(f.format(idx=layer_idx))
        tensor_names.extend(self.post_weights())
        return tensor_names


LLAMA_INFO = StaticTensorNames(
    name="LlamaForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
        "mlp.gate_proj.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
    ],
)

MISTRAL_INFO = StaticTensorNames(
    name="MistralForCausalLM",
    # lol
    **LLAMA_INFO.model_dump(exclude=["name"]),
)


STABLELM_INFO = StaticTensorNames(
    name="StableLMEpochForCausalLM",
    post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
    layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
    + [
        "input_layernorm.bias",
        "post_attention_layernorm.bias",
    ],
    **LLAMA_INFO.model_dump(
        exclude=["name", "layer_weight_suffixes", "post_weight_names"]
    ),
)

GPT_NEOX_INFO = StaticTensorNames(
    name="GPTNeoXForCausalLM",
    pre_weight_names=["gpt_neox.embed_in.weight"],
    post_weight_names=[
        "gpt_neox.final_layer_norm.bias",
        "gpt_neox.final_layer_norm.weight",
        "embed_out.weight",
    ],
    embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
    layer_prefix_format="gpt_neox.layers.{idx}",
    layer_weight_suffixes=sum(
        (
            [f"{prefix}.weight", f"{prefix}.bias"]
            for prefix in [
                "attention.dense",
                "attention.query_key_value",
                "input_layernorm",
                "mlp.dense_4h_to_h",
                "mlp.dense_h_to_4h",
                "post_attention_layernorm",
            ]
        ),
        start=[],
    )
    + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
)

GPT2_INFO = StaticTensorNames(
    name="GPT2LMHeadModel",
    pre_weight_names=["wte.weight", "wpe.weight"],
    post_weight_names=["ln_f.weight", "ln_f.bias"],
    embed_weight_names=["wte.weight"],
    layer_prefix_format="h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

JAIS_INFO = StaticTensorNames(
    name="JAISLMHeadModel",
    pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
    post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
    embed_weight_names=["transformer.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_fc2.weight",
        "mlp.c_fc2.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

GPT2_SEQCLASS_INFO = StaticTensorNames(
    name="GPT2ForSequenceClassification",
    pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "score.weight",
    ],
    layer_prefix_format="transformer.h.{idx}",
    embed_weight_names=GPT2_INFO.embed_weight_names,
    layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
    num_layers_key=GPT2_INFO.num_layers_key,
)


QWEN_INFO = StaticTensorNames(
    name="QWenLMHeadModel",
    pre_weight_names=["transformer.wte.weight"],
    post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
    embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.bias",
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "ln_1.weight",
        "ln_2.weight",
        "mlp.c_proj.weight",
        "mlp.w1.weight",
        "mlp.w2.weight",
    ],
)

CHATGLM_INFO = StaticTensorNames(
    name="ChatGLMModel",
    pre_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.rotary_pos_emb.inv_freq",
    ],
    post_weight_names=[
        "transformer.encoder.final_layernorm.weight",
        "transformer.output_layer.weight",
    ],
    embed_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.output_layer.weight",
    ],
    layer_prefix_format="transformer.encoder.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "post_attention_layernorm.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.bias",
        "self_attention.query_key_value.weight",
    ],
)

FALCON_INFO = StaticTensorNames(
    name="FalconForCausalLM",
    pre_weight_names=["transformer.word_embeddings.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "lm_head.weight",
    ],
    embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln_attn.bias",
        "ln_attn.weight",
        "ln_mlp.bias",
        "ln_mlp.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.weight",
    ],
)


class PhiTensorNames(ArchitectureInfo):
    architecture_name: str = "MixFormerSequentialForCausalLM"

    def __init__(self, config: PretrainedConfig):
        self.config = config

    def __eq__(self, rhs: "PhiTensorNames"):
        if not isinstance(rhs, PhiTensorNames):
            return False
        return self.num_layers() == rhs.num_layers()

    def pre_weights(self) -> List[str]:
        return ["layers.0.wte.weight"]

    def post_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            f"layers.{fake_layer_idx}.{suffix}"
            for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
        ]

    def embed_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            "layers.0.wte.weight",
            f"layers.{fake_layer_idx}.linear.weight",
            f"layers.{fake_layer_idx}.linear.bias",
        ]

    def layer_weight_formats(self) -> List[str]:
        return [
            ("layers.{idx}." + suffix)
            for suffix in [
                "ln.bias",
                "ln.weight",
                "mixer.Wqkv.bias",
                "mixer.Wqkv.weight",
                "mixer.out_proj.bias",
                "mixer.out_proj.weight",
                "mixer.rotary_emb.inv_freq",
                "mlp.fc1.bias",
                "mlp.fc1.weight",
                "mlp.fc2.bias",
                "mlp.fc2.weight",
            ]
        ]

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.n_layer

    def num_layers_config_key(self) -> str:
        return "n_layer"


PHI2_INFO = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["transformer.embd.wte.weight"],
    post_weight_names=[
        "lm_head.linear.bias",
        "lm_head.linear.weight",
        "lm_head.ln.bias",
        "lm_head.ln.weight",
    ],
    embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln.bias",
        "ln.weight",
        "mixer.out_proj.bias",
        "mixer.out_proj.weight",
        "mixer.Wqkv.bias",
        "mixer.Wqkv.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
    num_layers_key="n_layer",
)


PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=[
        "lm_head.bias",
        "lm_head.weight",
        "model.final_layernorm.bias",
        "model.final_layernorm.weight",
    ],
    embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.bias",
        "input_layernorm.weight",
        "self_attn.dense.bias",
        "self_attn.dense.weight",
        "self_attn.q_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.v_proj.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
)


BAICHUAN_INFO = StaticTensorNames(
    name="BaichuanForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "self_attn.W_pack.weight",
        "self_attn.o_proj.weight",
        "post_attention_layernorm.weight",
        "mlp.gate_proj.weight",
        "mlp.down_proj.weight",
        "mlp.up_proj.weight",
    ],
)


def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
    if len(config.architectures) != 1:
        raise RuntimeError("More than one architecture in config?")

    arch_name = config.architectures[0]
    if arch_name == PhiTensorNames.architecture_name:
        return PhiTensorNames(config)

    if arch_name == PHI2_INFO.name:
        if config.model_type == "phi-msft":
            return PHI2_INFO
        elif config.model_type == "phi":
            return PHI2_INFO_AGAIN_BUT_DIFFERENT

    supported = [
        LLAMA_INFO,
        MISTRAL_INFO,
        GPT_NEOX_INFO,
        QWEN_INFO,
        GPT2_INFO,
        GPT2_SEQCLASS_INFO,
        CHATGLM_INFO,
        STABLELM_INFO,
        JAIS_INFO,
        BAICHUAN_INFO,
        FALCON_INFO,
    ]
    for arch in supported:
        if arch.name == arch_name:
            return arch

    raise RuntimeError(f"Unsupported architecture {arch_name}")
```
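
To see what `PHI2_INFO` contributes to the merge, note that `mixtral_moe.py` iterates over the names produced by `layer_weight_formats()`. The standalone sketch below (it does not import mergekit, and the suffix list is abbreviated) mirrors how `layer_prefix_format` and `layer_weight_suffixes` expand into the concrete Phi-2 tensor names that get copied or split into experts.

```python
# Standalone illustration of StaticTensorNames.layer_weight_formats():
# the prefix/suffix pairs expand into the tensor names the copy loop visits.
layer_prefix_format = "transformer.h.{idx}"
layer_weight_suffixes = ["ln.weight", "mixer.Wqkv.weight", "mlp.fc1.weight", "mlp.fc2.weight"]

def layer_weight_formats():
    return [layer_prefix_format + "." + suffix for suffix in layer_weight_suffixes]

# For a toy 2-layer model this prints names like transformer.h.0.mlp.fc1.weight,
# which is where the ".mlp.fc" -> ".moe.mlp.N.fc" expert renaming kicks in.
for idx in range(2):
    for fmt in layer_weight_formats():
        print(fmt.format(idx=idx))
```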

3. Replace `configs.json` with the one from **this repo**.
4. Add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo.

**After:**

1. Modify `mixtral_moe.py` from `/content/mergekit/mergekit/scripts/mixtral_moe.py` to your HF repo

[***mixtral_moe.py***](https://github.com/paulilioaica/Phi-MOE/blob/main/mixtral_moe.py)

2. Modify `architecture.py` at `/content/mergekit/mergekit/architecture.py`
   (you can take this from the commit linked in the description)

[***architecture.py***](https://github.com/paulilioaica/Phi-MOE/blob/main/architecture.py)

3. Replace `configs.json` with the one from **this repo**.
4. Add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo.
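
Once the merged weights, config, and the custom modeling/configuration files are in your repo, the result can be loaded like any Hugging Face causal LM. The snippet below is a hypothetical usage sketch: the repo name is a placeholder, and `trust_remote_code=True` is assumed to be required because the repo ships its own modeling code (steps 3 and 4 above).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo name; swap in the repo you pushed the merged model to.
repo = "your-username/phi-moe"

tokenizer = AutoTokenizer.from_pretrained(repo)
# trust_remote_code lets transformers use the modeling_phi.py shipped in the repo.
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))
```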