paulilioaica
commited on
Commit
•
e298d48
1
Parent(s):
97c6fa6
Update README.md
Browse files
README.md
CHANGED
@@ -95,4 +95,961 @@ The continents can be divided into several subregions, such as islands, archipel
|
|
95 |
Each continent has its own unique geography, climate, flora, fauna, and human cultures. The continents are interconnected through various landforms, bodies of water, and global trade routes.
|
96 |
|
97 |
In summary, there are seven continents on Earth, each with its own distinct characteristics and unique contributions to the world's diversity. While the number may vary depending on the categorization of Antarctica, all seven continents together make
|
98 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
Each continent has its own unique geography, climate, flora, fauna, and human cultures. The continents are interconnected through various landforms, bodies of water, and global trade routes.
|
96 |
|
97 |
In summary, there are seven continents on Earth, each with its own distinct characteristics and unique contributions to the world's diversity. While the number may vary depending on the categorization of Antarctica, all seven continents together make
|
98 |
+
```
|
99 |
+
|
100 |
+
|
101 |
+
## ♻️ Replicate this repo
|
102 |
+
|
103 |
+
**beware** this will only work with 2 phis, you might have to tinker in the naming thing for more layers
|
104 |
+
|
105 |
+
**AFTER** all the file modifications and run, you need to replace `configs.json` with the one from **this repo**
|
106 |
+
**AFTER** that you need to add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo
|
107 |
+
|
108 |
+
Steps
|
109 |
+
|
110 |
+
1. Modify moe_mixtral.py from `/content/mergekit/mergekit/scripts/mixtral_moe.py` to your hf repo
|
111 |
+
|
112 |
+
***mixtral_moe.py***
|
113 |
+
|
114 |
+
```
|
115 |
+
# Copyright (C) 2024 Charles O. Goddard
|
116 |
+
#
|
117 |
+
# This software is free software: you can redistribute it and/or
|
118 |
+
# modify it under the terms of the GNU Lesser General Public License as
|
119 |
+
# published by the Free Software Foundation, either version 3 of the
|
120 |
+
# License, or (at your option) any later version.
|
121 |
+
#
|
122 |
+
# This software is distributed in the hope that it will be useful, but
|
123 |
+
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
124 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
125 |
+
# Lesser General Public License for more details.
|
126 |
+
#
|
127 |
+
# You should have received a copy of the GNU Lesser General Public License
|
128 |
+
# along with this program. If not, see http://www.gnu.org/licenses/.
|
129 |
+
|
130 |
+
import logging
|
131 |
+
import os
|
132 |
+
import sys
|
133 |
+
from typing import Dict, List, Optional, Union
|
134 |
+
|
135 |
+
import click
|
136 |
+
import torch
|
137 |
+
import tqdm
|
138 |
+
import transformers
|
139 |
+
import yaml
|
140 |
+
from pydantic import BaseModel
|
141 |
+
from transformers import (
|
142 |
+
AutoModelForCausalLM,
|
143 |
+
LlamaForCausalLM,
|
144 |
+
MistralConfig,
|
145 |
+
MistralForCausalLM,
|
146 |
+
MixtralConfig,
|
147 |
+
)
|
148 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
149 |
+
|
150 |
+
import mergekit.architecture
|
151 |
+
from mergekit.common import ModelReference, dtype_from_name
|
152 |
+
from mergekit.io import LazyTensorLoader, TensorWriter
|
153 |
+
from mergekit.merge import MergeOptions
|
154 |
+
from mergekit.options import add_merge_options
|
155 |
+
|
156 |
+
# Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
|
157 |
+
# Takes the path to a yml config and an output path.
|
158 |
+
# Config schema is the two classes below.
|
159 |
+
|
160 |
+
|
161 |
+
class Expert(BaseModel):
|
162 |
+
source_model: str
|
163 |
+
|
164 |
+
positive_prompts: List[str]
|
165 |
+
negative_prompts: Optional[List[str]] = None
|
166 |
+
noise_scale: Optional[float] = None
|
167 |
+
|
168 |
+
@property
|
169 |
+
def model_ref(self):
|
170 |
+
return ModelReference.parse(self.source_model)
|
171 |
+
|
172 |
+
|
173 |
+
class MistralMOEConfig(BaseModel):
|
174 |
+
base_model: str
|
175 |
+
experts: List[Expert]
|
176 |
+
gate_mode: str = "hidden" # possible values: "hidden", "cheap_embed", "random"
|
177 |
+
# "hidden" uses hidden state vectors for the given prompts for each layer
|
178 |
+
# "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
|
179 |
+
# "random" is random
|
180 |
+
dtype: Optional[str] = None
|
181 |
+
experts_per_token: int = 2
|
182 |
+
|
183 |
+
|
184 |
+
def get_hidden_states(
|
185 |
+
model: Union[MistralForCausalLM, LlamaForCausalLM],
|
186 |
+
tokenized: transformers.BatchEncoding,
|
187 |
+
average: bool = True,
|
188 |
+
) -> List[torch.Tensor]:
|
189 |
+
with torch.no_grad():
|
190 |
+
output: CausalLMOutputWithPast = model(
|
191 |
+
**tokenized.to(model.device), output_hidden_states=True, return_dict=True
|
192 |
+
)
|
193 |
+
hidden_states = torch.stack(
|
194 |
+
output.hidden_states[:-1]
|
195 |
+
) # (num_layers, batch_size, seq_len, hidden_size)
|
196 |
+
if average:
|
197 |
+
# use average over sequence
|
198 |
+
hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
|
199 |
+
else:
|
200 |
+
# take last value
|
201 |
+
hidden_states = hidden_states[:, :, -1, :]
|
202 |
+
return hidden_states.sum(dim=1) / hidden_states.shape[1]
|
203 |
+
|
204 |
+
|
205 |
+
def get_cheap_embedding(
|
206 |
+
embed: torch.Tensor,
|
207 |
+
tokenized: Dict[str, torch.Tensor],
|
208 |
+
num_layers: int,
|
209 |
+
vocab_size: int,
|
210 |
+
) -> torch.Tensor:
|
211 |
+
onehot = torch.nn.functional.one_hot(
|
212 |
+
tokenized["input_ids"], num_classes=vocab_size
|
213 |
+
) # (batch_size, seq_len, 32000)
|
214 |
+
h = onehot.float() @ embed.float() # (batch_size, seq_len, hidden_size)
|
215 |
+
embedded = (
|
216 |
+
(h * tokenized["attention_mask"].unsqueeze(-1))
|
217 |
+
.sum(dim=1)
|
218 |
+
.sum(dim=0, keepdim=True)
|
219 |
+
) # (1, hidden_size)
|
220 |
+
res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
|
221 |
+
min=1e-8
|
222 |
+
) # (1, hidden_size)
|
223 |
+
return res.repeat(num_layers, 1)
|
224 |
+
|
225 |
+
|
226 |
+
def tokenize_prompts(
|
227 |
+
prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
|
228 |
+
):
|
229 |
+
return tokenizer(
|
230 |
+
[tokenizer.bos_token + p for p in prompts],
|
231 |
+
return_tensors="pt",
|
232 |
+
padding=True,
|
233 |
+
add_special_tokens=False,
|
234 |
+
)
|
235 |
+
|
236 |
+
|
237 |
+
def get_gate_params(
|
238 |
+
model_ref: ModelReference,
|
239 |
+
tokenizer: transformers.PreTrainedTokenizerBase,
|
240 |
+
experts: List[Expert],
|
241 |
+
mode: str = "hidden",
|
242 |
+
load_in_4bit: bool = False,
|
243 |
+
load_in_8bit: bool = False,
|
244 |
+
lazy_unpickle: bool = False,
|
245 |
+
trust_remote_code: bool = False,
|
246 |
+
device: str = "auto",
|
247 |
+
):
|
248 |
+
gate_vecs = []
|
249 |
+
_do_it = None
|
250 |
+
|
251 |
+
model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
|
252 |
+
|
253 |
+
if mode == "random":
|
254 |
+
return torch.randn(
|
255 |
+
(model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
|
256 |
+
)
|
257 |
+
elif mode == "cheap_embed":
|
258 |
+
embed = LazyTensorLoader(
|
259 |
+
model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
|
260 |
+
).get_tensor("transformer.embd.wte.weight")
|
261 |
+
|
262 |
+
def _do_it(tokenized):
|
263 |
+
return get_cheap_embedding(
|
264 |
+
embed,
|
265 |
+
tokenized,
|
266 |
+
num_layers=model_cfg.num_hidden_layers,
|
267 |
+
vocab_size=model_cfg.vocab_size,
|
268 |
+
)
|
269 |
+
|
270 |
+
elif mode in ("hidden", "hidden_avg", "hidden_last"):
|
271 |
+
model = AutoModelForCausalLM.from_pretrained(
|
272 |
+
model_ref.model.path,
|
273 |
+
revision=model_ref.model.revision,
|
274 |
+
torch_dtype=torch.bfloat16,
|
275 |
+
device_map=device,
|
276 |
+
low_cpu_mem_usage=True,
|
277 |
+
load_in_4bit=load_in_4bit,
|
278 |
+
load_in_8bit=load_in_8bit,
|
279 |
+
trust_remote_code=trust_remote_code,
|
280 |
+
)
|
281 |
+
|
282 |
+
def _do_it(tokenized):
|
283 |
+
return get_hidden_states(
|
284 |
+
model, tokenized=tokenized, average=mode == "hidden_avg"
|
285 |
+
)
|
286 |
+
|
287 |
+
|
288 |
+
gate_vecs = []
|
289 |
+
print(experts)
|
290 |
+
for expert in tqdm.tqdm(experts, desc="expert prompts"):
|
291 |
+
print(_do_it)
|
292 |
+
hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
|
293 |
+
if expert.negative_prompts:
|
294 |
+
hidden_states -= _do_it(
|
295 |
+
tokenize_prompts(expert.negative_prompts, tokenizer)
|
296 |
+
)
|
297 |
+
|
298 |
+
hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
|
299 |
+
gate_vecs.append(hidden_states)
|
300 |
+
gate_vecs = torch.stack(gate_vecs, dim=0) # (num_expert, num_layer, hidden_size)
|
301 |
+
return gate_vecs.permute(1, 0, 2)
|
302 |
+
|
303 |
+
|
304 |
+
def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
|
305 |
+
degen_indices = []
|
306 |
+
num_layers, _num_experts, _hidden_size = gate_vecs.shape
|
307 |
+
for idx in range(num_layers):
|
308 |
+
c = torch.linalg.cond(gate_vecs[idx, :, :].float())
|
309 |
+
if c > threshold:
|
310 |
+
degen_indices.append(idx)
|
311 |
+
|
312 |
+
if degen_indices:
|
313 |
+
if len(degen_indices) == 1:
|
314 |
+
layer_str = f"layer {degen_indices[0]}"
|
315 |
+
verb = "has"
|
316 |
+
elif len(degen_indices) == 2:
|
317 |
+
layer_str = f"layers {' and '.join(map(str, degen_indices))}"
|
318 |
+
verb = "have"
|
319 |
+
elif len(degen_indices) >= num_layers:
|
320 |
+
layer_str = "ALL layers"
|
321 |
+
verb = "have"
|
322 |
+
else:
|
323 |
+
layer_str = (
|
324 |
+
"layers "
|
325 |
+
+ ", ".join(map(str, degen_indices[:-1]))
|
326 |
+
+ ", and "
|
327 |
+
+ str(degen_indices[-1])
|
328 |
+
)
|
329 |
+
verb = "have"
|
330 |
+
|
331 |
+
logging.warning(
|
332 |
+
f"{layer_str} {verb} degenerate routing parameters "
|
333 |
+
"- your prompts may be too similar."
|
334 |
+
)
|
335 |
+
logging.warning("One or more experts will be underutilized in your model.")
|
336 |
+
|
337 |
+
|
338 |
+
def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
|
339 |
+
if len(config.experts) < 2:
|
340 |
+
logging.error("Must include at least two experts.")
|
341 |
+
return True
|
342 |
+
|
343 |
+
if config.gate_mode == "random":
|
344 |
+
return False # eh we're good
|
345 |
+
|
346 |
+
def prompt_tup(e: Expert):
|
347 |
+
return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))
|
348 |
+
|
349 |
+
# let's just nip this trend in the bud
|
350 |
+
p_first = prompt_tup(config.experts[0])
|
351 |
+
if all(prompt_tup(e) == p_first for e in config.experts[1:]):
|
352 |
+
logging.error(
|
353 |
+
"Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
|
354 |
+
)
|
355 |
+
logging.error(
|
356 |
+
"For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
|
357 |
+
)
|
358 |
+
return True
|
359 |
+
|
360 |
+
if not allow_all_same:
|
361 |
+
if all(
|
362 |
+
e.source_model == config.experts[0].source_model for e in config.experts[1:]
|
363 |
+
):
|
364 |
+
logging.error(
|
365 |
+
"All of your expert models are the same. This will produce "
|
366 |
+
"a model that uses more resources but gives the exact same output. "
|
367 |
+
"If you plan to train the model after merging, proceed with the "
|
368 |
+
"--i-understand-this-is-not-useful-without-training flag."
|
369 |
+
)
|
370 |
+
return True
|
371 |
+
|
372 |
+
|
373 |
+
def build(
|
374 |
+
config: MistralMOEConfig,
|
375 |
+
out_path: str,
|
376 |
+
merge_options: MergeOptions,
|
377 |
+
load_in_4bit: bool = False,
|
378 |
+
load_in_8bit: bool = False,
|
379 |
+
device: str = "auto",
|
380 |
+
allow_all_same: bool = False,
|
381 |
+
):
|
382 |
+
if is_bad_config(config, allow_all_same=allow_all_same):
|
383 |
+
sys.exit(1)
|
384 |
+
|
385 |
+
if config.experts_per_token < 1:
|
386 |
+
logging.error("Experts per token must be >= 1")
|
387 |
+
sys.exit(1)
|
388 |
+
if config.experts_per_token > len(config.experts):
|
389 |
+
logging.error("Experts per token must be <= number of experts")
|
390 |
+
sys.exit(1)
|
391 |
+
|
392 |
+
base_model = ModelReference.parse(config.base_model)
|
393 |
+
base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
|
394 |
+
if not isinstance(base_cfg, MistralConfig):
|
395 |
+
base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
|
396 |
+
base_cfg_mistral.sliding_window = None
|
397 |
+
base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
|
398 |
+
base_cfg = base_cfg_mistral
|
399 |
+
|
400 |
+
out_cfg = MixtralConfig(**base_cfg.to_dict())
|
401 |
+
out_cfg.architectures = ["PhiForCausalLM"]
|
402 |
+
out_cfg.num_local_experts = len(config.experts)
|
403 |
+
out_cfg.num_experts_per_tok = config.experts_per_token
|
404 |
+
out_cfg.sliding_window = None
|
405 |
+
if config.dtype:
|
406 |
+
out_cfg.torch_dtype = config.dtype
|
407 |
+
out_cfg.save_pretrained(out_path)
|
408 |
+
|
409 |
+
if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
|
410 |
+
logging.warning(
|
411 |
+
f"Your model has {out_cfg.num_local_experts} experts, which is "
|
412 |
+
"not a power of two. The model will not be usable in llama.cpp."
|
413 |
+
)
|
414 |
+
|
415 |
+
loaders: Dict[ModelReference, LazyTensorLoader] = {}
|
416 |
+
for model in tqdm.tqdm(
|
417 |
+
[base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
|
418 |
+
):
|
419 |
+
loaders[model] = LazyTensorLoader(
|
420 |
+
model.tensor_index(cache_dir=merge_options.transformers_cache),
|
421 |
+
lazy_unpickle=merge_options.lazy_unpickle,
|
422 |
+
)
|
423 |
+
|
424 |
+
base_loader = loaders.get(base_model)
|
425 |
+
writer = TensorWriter(
|
426 |
+
out_path=out_path,
|
427 |
+
max_shard_size=merge_options.out_shard_size,
|
428 |
+
safe_serialization=merge_options.safe_serialization,
|
429 |
+
)
|
430 |
+
|
431 |
+
if config.dtype:
|
432 |
+
out_dtype = dtype_from_name(config.dtype)
|
433 |
+
elif base_cfg.torch_dtype:
|
434 |
+
out_dtype = base_cfg.torch_dtype
|
435 |
+
if isinstance(out_dtype, str):
|
436 |
+
out_dtype = dtype_from_name(out_dtype)
|
437 |
+
else:
|
438 |
+
out_dtype = None
|
439 |
+
|
440 |
+
logging.info("Copying parameters...")
|
441 |
+
MISTRAL_INFO = mergekit.architecture.PHI2_INFO
|
442 |
+
for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
|
443 |
+
tensor = base_loader.get_tensor(tensor_name)
|
444 |
+
if not out_dtype:
|
445 |
+
# All else has failed, take the first dtype we see
|
446 |
+
out_dtype = tensor.dtype
|
447 |
+
writer.save_tensor(
|
448 |
+
tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
|
449 |
+
)
|
450 |
+
set_of_seen_tensors = set()
|
451 |
+
|
452 |
+
for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
|
453 |
+
for layer_idx in range(base_cfg.num_hidden_layers):
|
454 |
+
tensor_name = name_format.format(idx=layer_idx)
|
455 |
+
if ".mlp.fc" in name_format:
|
456 |
+
for moe_index, expert in enumerate(config.experts):
|
457 |
+
if tensor_name in set_of_seen_tensors:
|
458 |
+
expert_name = tensor_name.replace(
|
459 |
+
".mlp.fc", f".moe.mlp.1.fc"
|
460 |
+
)
|
461 |
+
else:
|
462 |
+
expert_name = tensor_name.replace(
|
463 |
+
".mlp.fc", f".moe.mlp.0.fc"
|
464 |
+
)
|
465 |
+
set_of_seen_tensors.add(tensor_name)
|
466 |
+
|
467 |
+
expert_loader = loaders.get(expert.model_ref)
|
468 |
+
tensor = expert_loader.get_tensor(tensor_name)
|
469 |
+
if expert.noise_scale:
|
470 |
+
tensor += torch.randn_like(tensor) * expert.noise_scale
|
471 |
+
writer.save_tensor(
|
472 |
+
expert_name, tensor.to(dtype=out_dtype), clone=True
|
473 |
+
)
|
474 |
+
print(expert_name, tensor_name)
|
475 |
+
continue
|
476 |
+
writer.save_tensor(
|
477 |
+
tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
|
478 |
+
)
|
479 |
+
|
480 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
481 |
+
base_model.model.path, revision=base_model.model.revision
|
482 |
+
)
|
483 |
+
tokenizer.padding_side = "left"
|
484 |
+
tokenizer.pad_token_id = tokenizer.bos_token_id
|
485 |
+
|
486 |
+
logging.info("Getting gate parameters...")
|
487 |
+
gate_vecs = get_gate_params(
|
488 |
+
base_model,
|
489 |
+
tokenizer,
|
490 |
+
config.experts,
|
491 |
+
mode=config.gate_mode,
|
492 |
+
load_in_4bit=load_in_4bit,
|
493 |
+
load_in_8bit=load_in_8bit,
|
494 |
+
lazy_unpickle=merge_options.lazy_unpickle,
|
495 |
+
trust_remote_code=merge_options.trust_remote_code,
|
496 |
+
device=device,
|
497 |
+
)
|
498 |
+
# gate_vecs: (num_layers, num_experts, hidden_size)
|
499 |
+
|
500 |
+
warn_degenerate_gates(gate_vecs)
|
501 |
+
|
502 |
+
for layer_idx in range(base_cfg.num_hidden_layers):
|
503 |
+
writer.save_tensor(
|
504 |
+
f"transformer.h.{layer_idx}.moe.gate.weight",
|
505 |
+
gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
|
506 |
+
)
|
507 |
+
writer.finalize()
|
508 |
+
|
509 |
+
if merge_options.copy_tokenizer:
|
510 |
+
logging.info("Saving tokenizer...")
|
511 |
+
tokenizer.save_pretrained(out_path, safe_serialization=True)
|
512 |
+
|
513 |
+
logging.info("Done.")
|
514 |
+
|
515 |
+
|
516 |
+
@click.command("mergekit-moe")
|
517 |
+
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
|
518 |
+
@click.argument("out_path", type=click.Path())
|
519 |
+
@click.option(
|
520 |
+
"--load-in-4bit",
|
521 |
+
is_flag=True,
|
522 |
+
type=bool,
|
523 |
+
default=False,
|
524 |
+
help="Load model in 4bit for computing hidden states",
|
525 |
+
)
|
526 |
+
@click.option(
|
527 |
+
"--load-in-8bit",
|
528 |
+
is_flag=True,
|
529 |
+
type=bool,
|
530 |
+
default=False,
|
531 |
+
help="Load model in 8bit for computing hidden states",
|
532 |
+
)
|
533 |
+
@click.option(
|
534 |
+
"--device",
|
535 |
+
type=str,
|
536 |
+
default="auto",
|
537 |
+
help="Device to use to compute embeddings",
|
538 |
+
show_default=True,
|
539 |
+
)
|
540 |
+
@click.option(
|
541 |
+
"--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
|
542 |
+
)
|
543 |
+
@click.option(
|
544 |
+
"--i-understand-this-is-not-useful-without-training",
|
545 |
+
type=bool,
|
546 |
+
default=False,
|
547 |
+
is_flag=True,
|
548 |
+
help="Really make the questionable model you want.",
|
549 |
+
)
|
550 |
+
@add_merge_options
|
551 |
+
def main(
|
552 |
+
config_path: str,
|
553 |
+
out_path: str,
|
554 |
+
load_in_4bit: bool,
|
555 |
+
load_in_8bit: bool,
|
556 |
+
device: str,
|
557 |
+
merge_options: MergeOptions,
|
558 |
+
verbose: bool,
|
559 |
+
i_understand_this_is_not_useful_without_training: bool,
|
560 |
+
):
|
561 |
+
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)
|
562 |
+
|
563 |
+
if merge_options.cuda:
|
564 |
+
logging.warning(
|
565 |
+
'--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
|
566 |
+
)
|
567 |
+
|
568 |
+
with open(config_path, "r", encoding="utf-8") as file:
|
569 |
+
config_source = file.read()
|
570 |
+
|
571 |
+
config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
|
572 |
+
build(
|
573 |
+
config,
|
574 |
+
out_path=out_path,
|
575 |
+
merge_options=merge_options,
|
576 |
+
load_in_4bit=load_in_4bit,
|
577 |
+
load_in_8bit=load_in_8bit,
|
578 |
+
device=device,
|
579 |
+
allow_all_same=i_understand_this_is_not_useful_without_training,
|
580 |
+
)
|
581 |
+
|
582 |
+
if merge_options.write_model_card:
|
583 |
+
# TODO: generate a README.md as well
|
584 |
+
with open(
|
585 |
+
os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
|
586 |
+
) as fp:
|
587 |
+
fp.write(config_source)
|
588 |
+
|
589 |
+
|
590 |
+
if __name__ == "__main__":
|
591 |
+
main()
|
592 |
+
```
|
593 |
+
|
594 |
+
2. Modify architecture.py `/content/mergekit/mergekit/architecture.py`
|
595 |
+
(this you can take from the link to the commit i have in description)
|
596 |
+
|
597 |
+
***architecture.py***
|
598 |
+
|
599 |
+
```
|
600 |
+
# Copyright (C) 2024 Charles O. Goddard
|
601 |
+
#
|
602 |
+
# This software is free software: you can redistribute it and/or
|
603 |
+
# modify it under the terms of the GNU Lesser General Public License as
|
604 |
+
# published by the Free Software Foundation, either version 3 of the
|
605 |
+
# License, or (at your option) any later version.
|
606 |
+
#
|
607 |
+
# This software is distributed in the hope that it will be useful, but
|
608 |
+
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
609 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
610 |
+
# Lesser General Public License for more details.
|
611 |
+
#
|
612 |
+
# You should have received a copy of the GNU Lesser General Public License
|
613 |
+
# along with this program. If not, see http://www.gnu.org/licenses/.
|
614 |
+
|
615 |
+
from abc import ABC, abstractmethod
|
616 |
+
from typing import List, Optional
|
617 |
+
|
618 |
+
from pydantic import BaseModel
|
619 |
+
from transformers import PretrainedConfig
|
620 |
+
|
621 |
+
|
622 |
+
class ArchitectureInfo(ABC):
|
623 |
+
@abstractmethod
|
624 |
+
def pre_weights(self) -> List[str]:
|
625 |
+
"""Return a list of all weights preceding the first layer."""
|
626 |
+
...
|
627 |
+
|
628 |
+
@abstractmethod
|
629 |
+
def post_weights(self) -> List[str]:
|
630 |
+
"""Return a list of all weights following the final layer."""
|
631 |
+
...
|
632 |
+
|
633 |
+
@abstractmethod
|
634 |
+
def layer_weight_formats(self) -> List[str]:
|
635 |
+
"""Return a list of format strings all weights associated with a layer."""
|
636 |
+
...
|
637 |
+
|
638 |
+
@abstractmethod
|
639 |
+
def embed_weights(self) -> List[str]:
|
640 |
+
...
|
641 |
+
|
642 |
+
def num_layers(self, config: PretrainedConfig) -> int:
|
643 |
+
return config.num_hidden_layers
|
644 |
+
|
645 |
+
def num_layers_config_key(self) -> str:
|
646 |
+
"""Key in config that represents number of layers"""
|
647 |
+
return "num_hidden_layers"
|
648 |
+
|
649 |
+
|
650 |
+
class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
|
651 |
+
name: str
|
652 |
+
|
653 |
+
pre_weight_names: List[str] # weights applied before first layer
|
654 |
+
post_weight_names: List[str] # weights applied after last layer
|
655 |
+
embed_weight_names: List[str] # weights for embed/lm_head
|
656 |
+
layer_prefix_format: str
|
657 |
+
layer_weight_suffixes: List[str]
|
658 |
+
num_layers_key: Optional[str] = None
|
659 |
+
|
660 |
+
def pre_weights(self) -> List[str]:
|
661 |
+
return self.pre_weight_names
|
662 |
+
|
663 |
+
def post_weights(self) -> List[str]:
|
664 |
+
return self.post_weight_names
|
665 |
+
|
666 |
+
def embed_weights(self) -> List[str]:
|
667 |
+
return self.embed_weight_names
|
668 |
+
|
669 |
+
def layer_weight_formats(self) -> List[str]:
|
670 |
+
res = []
|
671 |
+
for suffix in self.layer_weight_suffixes:
|
672 |
+
res.append(self.layer_prefix_format + "." + suffix)
|
673 |
+
return res
|
674 |
+
|
675 |
+
def num_layers_config_key(self) -> str:
|
676 |
+
if self.num_layers_key:
|
677 |
+
return self.num_layers_key
|
678 |
+
return super().num_layers_config_key()
|
679 |
+
|
680 |
+
def num_layers(self, config: PretrainedConfig) -> int:
|
681 |
+
return getattr(config, self.num_layers_config_key())
|
682 |
+
|
683 |
+
def all_weights(self, config: PretrainedConfig) -> List[str]:
|
684 |
+
num_layers = self.num_layers(config)
|
685 |
+
tensor_names = list(self.pre_weights())
|
686 |
+
for layer_idx in range(num_layers):
|
687 |
+
for f in self.layer_weight_formats():
|
688 |
+
tensor_names.append(f.format(idx=layer_idx))
|
689 |
+
tensor_names.extend(self.post_weights())
|
690 |
+
return tensor_names
|
691 |
+
|
692 |
+
|
693 |
+
LLAMA_INFO = StaticTensorNames(
|
694 |
+
name="LlamaForCausalLM",
|
695 |
+
pre_weight_names=["model.embed_tokens.weight"],
|
696 |
+
post_weight_names=["model.norm.weight", "lm_head.weight"],
|
697 |
+
embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
|
698 |
+
layer_prefix_format="model.layers.{idx}",
|
699 |
+
layer_weight_suffixes=[
|
700 |
+
"input_layernorm.weight",
|
701 |
+
"mlp.up_proj.weight",
|
702 |
+
"mlp.down_proj.weight",
|
703 |
+
"mlp.gate_proj.weight",
|
704 |
+
"post_attention_layernorm.weight",
|
705 |
+
"self_attn.q_proj.weight",
|
706 |
+
"self_attn.k_proj.weight",
|
707 |
+
"self_attn.v_proj.weight",
|
708 |
+
"self_attn.o_proj.weight",
|
709 |
+
],
|
710 |
+
)
|
711 |
+
|
712 |
+
MISTRAL_INFO = StaticTensorNames(
|
713 |
+
name="MistralForCausalLM",
|
714 |
+
# lol
|
715 |
+
**LLAMA_INFO.model_dump(exclude=["name"]),
|
716 |
+
)
|
717 |
+
|
718 |
+
|
719 |
+
STABLELM_INFO = StaticTensorNames(
|
720 |
+
name="StableLMEpochForCausalLM",
|
721 |
+
post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
|
722 |
+
layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
|
723 |
+
+ [
|
724 |
+
"input_layernorm.bias",
|
725 |
+
"post_attention_layernorm.bias",
|
726 |
+
],
|
727 |
+
**LLAMA_INFO.model_dump(
|
728 |
+
exclude=["name", "layer_weight_suffixes", "post_weight_names"]
|
729 |
+
),
|
730 |
+
)
|
731 |
+
|
732 |
+
GPT_NEOX_INFO = StaticTensorNames(
|
733 |
+
name="GPTNeoXForCausalLM",
|
734 |
+
pre_weight_names=["gpt_neox.embed_in.weight"],
|
735 |
+
post_weight_names=[
|
736 |
+
"gpt_neox.final_layer_norm.bias",
|
737 |
+
"gpt_neox.final_layer_norm.weight",
|
738 |
+
"embed_out.weight",
|
739 |
+
],
|
740 |
+
embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
|
741 |
+
layer_prefix_format="gpt_neox.layers.{idx}",
|
742 |
+
layer_weight_suffixes=sum(
|
743 |
+
(
|
744 |
+
[f"{prefix}.weight", f"{prefix}.bias"]
|
745 |
+
for prefix in [
|
746 |
+
"attention.dense",
|
747 |
+
"attention.query_key_value",
|
748 |
+
"input_layernorm",
|
749 |
+
"mlp.dense_4h_to_h",
|
750 |
+
"mlp.dense_h_to_4h",
|
751 |
+
"post_attention_layernorm",
|
752 |
+
]
|
753 |
+
),
|
754 |
+
start=[],
|
755 |
+
)
|
756 |
+
+ ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
|
757 |
+
)
|
758 |
+
|
759 |
+
GPT2_INFO = StaticTensorNames(
|
760 |
+
name="GPT2LMHeadModel",
|
761 |
+
pre_weight_names=["wte.weight", "wpe.weight"],
|
762 |
+
post_weight_names=["ln_f.weight", "ln_f.bias"],
|
763 |
+
embed_weight_names=["wte.weight"],
|
764 |
+
layer_prefix_format="h.{idx}",
|
765 |
+
layer_weight_suffixes=[
|
766 |
+
"attn.c_attn.weight",
|
767 |
+
"attn.c_attn.bias",
|
768 |
+
"attn.c_proj.weight",
|
769 |
+
"attn.c_proj.bias",
|
770 |
+
"ln_1.weight",
|
771 |
+
"ln_1.bias",
|
772 |
+
"ln_2.weight",
|
773 |
+
"ln_2.bias",
|
774 |
+
"mlp.c_proj.weight",
|
775 |
+
"mlp.c_proj.bias",
|
776 |
+
"mlp.c_fc.weight",
|
777 |
+
"mlp.c_fc.bias",
|
778 |
+
"mlp.c_proj.weight",
|
779 |
+
"mlp.c_proj.bias",
|
780 |
+
],
|
781 |
+
num_layers_key="n_layer",
|
782 |
+
)
|
783 |
+
|
784 |
+
JAIS_INFO = StaticTensorNames(
|
785 |
+
name="JAISLMHeadModel",
|
786 |
+
pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
|
787 |
+
post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
|
788 |
+
embed_weight_names=["transformer.wte.weight"],
|
789 |
+
layer_prefix_format="transformer.h.{idx}",
|
790 |
+
layer_weight_suffixes=[
|
791 |
+
"attn.c_attn.weight",
|
792 |
+
"attn.c_attn.bias",
|
793 |
+
"attn.c_proj.weight",
|
794 |
+
"attn.c_proj.bias",
|
795 |
+
"ln_1.weight",
|
796 |
+
"ln_1.bias",
|
797 |
+
"ln_2.weight",
|
798 |
+
"ln_2.bias",
|
799 |
+
"mlp.c_fc.weight",
|
800 |
+
"mlp.c_fc.bias",
|
801 |
+
"mlp.c_fc2.weight",
|
802 |
+
"mlp.c_fc2.bias",
|
803 |
+
"mlp.c_proj.weight",
|
804 |
+
"mlp.c_proj.bias",
|
805 |
+
],
|
806 |
+
num_layers_key="n_layer",
|
807 |
+
)
|
808 |
+
|
809 |
+
GPT2_SEQCLASS_INFO = StaticTensorNames(
|
810 |
+
name="GPT2ForSequenceClassification",
|
811 |
+
pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
|
812 |
+
post_weight_names=[
|
813 |
+
"transformer.ln_f.weight",
|
814 |
+
"transformer.ln_f.bias",
|
815 |
+
"score.weight",
|
816 |
+
],
|
817 |
+
layer_prefix_format="transformer.h.{idx}",
|
818 |
+
embed_weight_names=GPT2_INFO.embed_weight_names,
|
819 |
+
layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
|
820 |
+
num_layers_key=GPT2_INFO.num_layers_key,
|
821 |
+
)
|
822 |
+
|
823 |
+
|
824 |
+
QWEN_INFO = StaticTensorNames(
|
825 |
+
name="QWenLMHeadModel",
|
826 |
+
pre_weight_names=["transformer.wte.weight"],
|
827 |
+
post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
|
828 |
+
embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
|
829 |
+
layer_prefix_format="transformer.h.{idx}",
|
830 |
+
layer_weight_suffixes=[
|
831 |
+
"attn.c_attn.bias",
|
832 |
+
"attn.c_attn.weight",
|
833 |
+
"attn.c_proj.weight",
|
834 |
+
"ln_1.weight",
|
835 |
+
"ln_2.weight",
|
836 |
+
"mlp.c_proj.weight",
|
837 |
+
"mlp.w1.weight",
|
838 |
+
"mlp.w2.weight",
|
839 |
+
],
|
840 |
+
)
|
841 |
+
|
842 |
+
CHATGLM_INFO = StaticTensorNames(
|
843 |
+
name="ChatGLMModel",
|
844 |
+
pre_weight_names=[
|
845 |
+
"transformer.embedding.word_embeddings.weight",
|
846 |
+
"transformer.rotary_pos_emb.inv_freq",
|
847 |
+
],
|
848 |
+
post_weight_names=[
|
849 |
+
"transformer.encoder.final_layernorm.weight",
|
850 |
+
"transformer.output_layer.weight",
|
851 |
+
],
|
852 |
+
embed_weight_names=[
|
853 |
+
"transformer.embedding.word_embeddings.weight",
|
854 |
+
"transformer.output_layer.weight",
|
855 |
+
],
|
856 |
+
layer_prefix_format="transformer.encoder.layers.{idx}",
|
857 |
+
layer_weight_suffixes=[
|
858 |
+
"input_layernorm.weight",
|
859 |
+
"mlp.dense_4h_to_h.weight",
|
860 |
+
"mlp.dense_h_to_4h.weight",
|
861 |
+
"post_attention_layernorm.weight",
|
862 |
+
"self_attention.dense.weight",
|
863 |
+
"self_attention.query_key_value.bias",
|
864 |
+
"self_attention.query_key_value.weight",
|
865 |
+
],
|
866 |
+
)
|
867 |
+
|
868 |
+
FALCON_INFO = StaticTensorNames(
|
869 |
+
name="FalconForCausalLM",
|
870 |
+
pre_weight_names=["transformer.word_embeddings.weight"],
|
871 |
+
post_weight_names=[
|
872 |
+
"transformer.ln_f.weight",
|
873 |
+
"transformer.ln_f.bias",
|
874 |
+
"lm_head.weight",
|
875 |
+
],
|
876 |
+
embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
|
877 |
+
layer_prefix_format="transformer.h.{idx}",
|
878 |
+
layer_weight_suffixes=[
|
879 |
+
"ln_attn.bias",
|
880 |
+
"ln_attn.weight",
|
881 |
+
"ln_mlp.bias",
|
882 |
+
"ln_mlp.weight",
|
883 |
+
"mlp.dense_4h_to_h.weight",
|
884 |
+
"mlp.dense_h_to_4h.weight",
|
885 |
+
"self_attention.dense.weight",
|
886 |
+
"self_attention.query_key_value.weight",
|
887 |
+
],
|
888 |
+
)
|
889 |
+
|
890 |
+
|
891 |
+
class PhiTensorNames(ArchitectureInfo):
|
892 |
+
architecture_name: str = "MixFormerSequentialForCausalLM"
|
893 |
+
|
894 |
+
def __init__(self, config: PretrainedConfig):
|
895 |
+
self.config = config
|
896 |
+
|
897 |
+
def __eq__(self, rhs: "PhiTensorNames"):
|
898 |
+
if not isinstance(rhs, PhiTensorNames):
|
899 |
+
return False
|
900 |
+
return self.num_layers() == rhs.num_layers()
|
901 |
+
|
902 |
+
def pre_weights(self) -> List[str]:
|
903 |
+
return ["layers.0.wte.weight"]
|
904 |
+
|
905 |
+
def post_weights(self) -> List[str]:
|
906 |
+
fake_layer_idx = self.config.n_layer + 1
|
907 |
+
return [
|
908 |
+
f"layers.{fake_layer_idx}.{suffix}"
|
909 |
+
for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
|
910 |
+
]
|
911 |
+
|
912 |
+
def embed_weights(self) -> List[str]:
|
913 |
+
fake_layer_idx = self.config.n_layer + 1
|
914 |
+
return [
|
915 |
+
"layers.0.wte.weight",
|
916 |
+
f"layers.{fake_layer_idx}.linear.weight",
|
917 |
+
f"layers.{fake_layer_idx}.linear.bias",
|
918 |
+
]
|
919 |
+
|
920 |
+
def layer_weight_formats(self) -> List[str]:
|
921 |
+
return [
|
922 |
+
("layers.{idx}." + suffix)
|
923 |
+
for suffix in [
|
924 |
+
"ln.bias",
|
925 |
+
"ln.weight",
|
926 |
+
"mixer.Wqkv.bias",
|
927 |
+
"mixer.Wqkv.weight",
|
928 |
+
"mixer.out_proj.bias",
|
929 |
+
"mixer.out_proj.weight",
|
930 |
+
"mixer.rotary_emb.inv_freq",
|
931 |
+
"mlp.fc1.bias",
|
932 |
+
"mlp.fc1.weight",
|
933 |
+
"mlp.fc2.bias",
|
934 |
+
"mlp.fc2.weight",
|
935 |
+
]
|
936 |
+
]
|
937 |
+
|
938 |
+
def num_layers(self, config: PretrainedConfig) -> int:
|
939 |
+
return config.n_layer
|
940 |
+
|
941 |
+
def num_layers_config_key(self) -> str:
|
942 |
+
return "n_layer"
|
943 |
+
|
944 |
+
|
945 |
+
PHI2_INFO = StaticTensorNames(
|
946 |
+
name="PhiForCausalLM",
|
947 |
+
pre_weight_names=["transformer.embd.wte.weight"],
|
948 |
+
post_weight_names=[
|
949 |
+
"lm_head.linear.bias",
|
950 |
+
"lm_head.linear.weight",
|
951 |
+
"lm_head.ln.bias",
|
952 |
+
"lm_head.ln.weight",
|
953 |
+
],
|
954 |
+
embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
|
955 |
+
layer_prefix_format="transformer.h.{idx}",
|
956 |
+
layer_weight_suffixes=[
|
957 |
+
"ln.bias",
|
958 |
+
"ln.weight",
|
959 |
+
"mixer.out_proj.bias",
|
960 |
+
"mixer.out_proj.weight",
|
961 |
+
"mixer.Wqkv.bias",
|
962 |
+
"mixer.Wqkv.weight",
|
963 |
+
"mlp.fc1.bias",
|
964 |
+
"mlp.fc1.weight",
|
965 |
+
"mlp.fc2.bias",
|
966 |
+
"mlp.fc2.weight",
|
967 |
+
],
|
968 |
+
num_layers_key="n_layer",
|
969 |
+
)
|
970 |
+
|
971 |
+
|
972 |
+
PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
|
973 |
+
name="PhiForCausalLM",
|
974 |
+
pre_weight_names=["model.embed_tokens.weight"],
|
975 |
+
post_weight_names=[
|
976 |
+
"lm_head.bias",
|
977 |
+
"lm_head.weight",
|
978 |
+
"model.final_layernorm.bias",
|
979 |
+
"model.final_layernorm.weight",
|
980 |
+
],
|
981 |
+
embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
|
982 |
+
layer_prefix_format="model.layers.{idx}",
|
983 |
+
layer_weight_suffixes=[
|
984 |
+
"input_layernorm.bias",
|
985 |
+
"input_layernorm.weight",
|
986 |
+
"self_attn.dense.bias",
|
987 |
+
"self_attn.dense.weight",
|
988 |
+
"self_attn.q_proj.bias",
|
989 |
+
"self_attn.q_proj.weight",
|
990 |
+
"self_attn.k_proj.bias",
|
991 |
+
"self_attn.k_proj.weight",
|
992 |
+
"self_attn.v_proj.bias",
|
993 |
+
"self_attn.v_proj.weight",
|
994 |
+
"mlp.fc1.bias",
|
995 |
+
"mlp.fc1.weight",
|
996 |
+
"mlp.fc2.bias",
|
997 |
+
"mlp.fc2.weight",
|
998 |
+
],
|
999 |
+
)
|
1000 |
+
|
1001 |
+
|
1002 |
+
BAICHUAN_INFO = StaticTensorNames(
|
1003 |
+
name="BaichuanForCausalLM",
|
1004 |
+
pre_weight_names=["model.embed_tokens.weight"],
|
1005 |
+
post_weight_names=["model.norm.weight", "lm_head.weight"],
|
1006 |
+
embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
|
1007 |
+
layer_prefix_format="model.layers.{idx}",
|
1008 |
+
layer_weight_suffixes=[
|
1009 |
+
"input_layernorm.weight",
|
1010 |
+
"self_attn.W_pack.weight",
|
1011 |
+
"self_attn.o_proj.weight",
|
1012 |
+
"post_attention_layernorm.weight",
|
1013 |
+
"mlp.gate_proj.weight",
|
1014 |
+
"mlp.down_proj.weight",
|
1015 |
+
"mlp.up_proj.weight",
|
1016 |
+
],
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
|
1020 |
+
def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
|
1021 |
+
if len(config.architectures) != 1:
|
1022 |
+
raise RuntimeError("More than one architecture in config?")
|
1023 |
+
|
1024 |
+
arch_name = config.architectures[0]
|
1025 |
+
if arch_name == PhiTensorNames.architecture_name:
|
1026 |
+
return PhiTensorNames(config)
|
1027 |
+
|
1028 |
+
if arch_name == PHI2_INFO.name:
|
1029 |
+
if config.model_type == "phi-msft":
|
1030 |
+
return PHI2_INFO
|
1031 |
+
elif config.model_type == "phi":
|
1032 |
+
return PHI2_INFO_AGAIN_BUT_DIFFERENT
|
1033 |
+
|
1034 |
+
supported = [
|
1035 |
+
LLAMA_INFO,
|
1036 |
+
MISTRAL_INFO,
|
1037 |
+
GPT_NEOX_INFO,
|
1038 |
+
QWEN_INFO,
|
1039 |
+
GPT2_INFO,
|
1040 |
+
GPT2_SEQCLASS_INFO,
|
1041 |
+
CHATGLM_INFO,
|
1042 |
+
STABLELM_INFO,
|
1043 |
+
JAIS_INFO,
|
1044 |
+
BAICHUAN_INFO,
|
1045 |
+
FALCON_INFO,
|
1046 |
+
]
|
1047 |
+
for arch in supported:
|
1048 |
+
if arch.name == arch_name:
|
1049 |
+
return arch
|
1050 |
+
|
1051 |
+
raise RuntimeError(f"Unsupported architecture {arch_name}")
|
1052 |
+
```
|
1053 |
+
|
1054 |
+
3) replace `configs.json` with the one from **this repo**
|
1055 |
+
4) you need to add `modeling_phi.py` and `configurations.phi` from **this repo** to your repo
|