namespace-Pt committed on
Commit • 2fee8e7
1 Parent(s): d7bd91c
Upload modeling_llama.py with huggingface_hub
modeling_llama.py +180 -71
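Because the commit ships its own modeling_llama.py, the intended way to use the checkpoint is through transformers' remote-code loading path. A minimal loading sketch follows; the repository id is a hypothetical placeholder, and it assumes the repo's config.json maps the auto classes to this file:

from transformers import AutoModelForCausalLM, AutoTokenizer

# hypothetical repository id -- substitute the repo this commit belongs to
repo_id = "namespace-Pt/activation-beacon-llama"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
# trust_remote_code=True makes transformers import the LlamaForCausalLM defined
# in the uploaded modeling_llama.py instead of the built-in implementation
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)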
modeling_llama.py
CHANGED
@@ -226,8 +226,35 @@ class LlamaMLP(nn.Module):
226         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
227         self.act_fn = ACT2FN[config.hidden_act]
228
229 -
230         if self.config.pretraining_tp > 1:
231             slice = self.intermediate_size // self.config.pretraining_tp
232             gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
233             up_proj_slices = self.up_proj.weight.split(slice, dim=0)

@@ -243,8 +270,28 @@ class LlamaMLP(nn.Module):
243                 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
244             ]
245             down_proj = sum(down_proj)
246         else:
247 -
248
249         return down_proj
250
@@ -297,16 +344,20 @@ class LlamaAttention(nn.Module):
297         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
298         self._init_rope()
299
300 -       # NOTE: add extra parameters for
301 -       self.beacon_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
302 -       self.beacon_k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
303 -       self.beacon_v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
304 -       self.beacon_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
305         # skip post initialization to speed up loading
306 -
307 -
308 -
309 -
310
311     def _init_rope(self):
312         if self.config.rope_scaling is None:

@@ -335,22 +386,33 @@ class LlamaAttention(nn.Module):
335         else:
336             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
337
338 -   def _init_beacon_proj(self):
339 -       """Initialize the
340         if is_deepspeed_zero3_enabled():
341             import deepspeed
342             params = [self.beacon_q_proj.weight, self.beacon_k_proj.weight, self.beacon_v_proj.weight, self.beacon_o_proj.weight, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight, self.o_proj.weight]
343             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
344                 self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
345                 self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
346                 self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
347                 self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
348 -       else:
349 -           # only copy the value in-place, without tieing the weight
350 -           self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
351 -           self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
352 -           self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
353 -           self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
354
355     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
356         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@@ -360,17 +422,26 @@ class LlamaAttention(nn.Module):
360             ordinal_hidden_states = hidden_states[:, :-beacon_size]
361             beacon_hidden_states = hidden_states[:, -beacon_size:]
362
363 -
364 -
365 -
366 -
367 -
368 -
369 -
370
371 -
372 -
373 -
374
375         else:
376             query_states = self.q_proj(hidden_states)

@@ -378,6 +449,18 @@ class LlamaAttention(nn.Module):
378             value_states = self.v_proj(hidden_states)
379
380         return query_states, key_states, value_states
381
382     def forward(
383         self,
@@ -403,8 +486,10 @@ class LlamaAttention(nn.Module):
403         else:
404             past_seq_len = 0
405
406 -       # TODO: support pretraining_tp
407         if self.config.pretraining_tp > 1:
408             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
409             query_slices = self.q_proj.weight.split(
410                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0

@@ -430,13 +515,14 @@ class LlamaAttention(nn.Module):
430
431         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
432
433         if past_key is not None:
434             # reuse k, v, self_attention
435             key_states = torch.cat([past_key, key_states], dim=2)
436             value_states = torch.cat([past_value, value_states], dim=2)
437 -
438 -       # return keys and values before rope
439 -       past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
440
441         key_position_ids = position_ids
442         # align query position_ids with key

@@ -480,16 +566,13 @@ class LlamaAttention(nn.Module):
480
481         if self.config.pretraining_tp > 1:
482             # TODO: support pretraining_tp
483             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
484             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
485             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
486         else:
487 -
488 -               regular_attn_output = self.o_proj(attn_output[:, :-beacon_size])
489 -               beacon_attn_output = self.beacon_o_proj(attn_output[:, -beacon_size:])
490 -               attn_output = torch.cat([regular_attn_output, beacon_attn_output], dim=1)
491 -           else:
492 -               attn_output = self.o_proj(attn_output)
493
494         if not output_attentions:
495             attn_weights = None
@@ -545,14 +628,15 @@ class LlamaSdpaAttention(LlamaAttention):
545
546         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
547
548         if past_key is not None:
549             # reuse k, v, self_attention
550             key_states = torch.cat([past_key, key_states], dim=2)
551             value_states = torch.cat([past_value, value_states], dim=2)
552
553 -       # return keys and values before rope
554 -       past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
555 -
556         key_position_ids = position_ids
557         # align query position_ids with key
558         query_position_ids = key_position_ids[:, -q_len:]

@@ -588,13 +672,20 @@ class LlamaSdpaAttention(LlamaAttention):
588
589         attn_output = attn_output.transpose(1, 2).contiguous()
590         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
591 -
592 -
593 -
594 -
595 -
596 -
597 -
598
599         return attn_output, None, past_key_value
600
@@ -645,6 +736,9 @@ class LlamaDecoderLayer(nn.Module):
645                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
646             )
647
648         residual = hidden_states
649
650         hidden_states = self.input_layernorm(hidden_states)

@@ -664,7 +758,7 @@ class LlamaDecoderLayer(nn.Module):
664         # Fully Connected
665         residual = hidden_states
666         hidden_states = self.post_attention_layernorm(hidden_states)
667 -       hidden_states = self.mlp(hidden_states)
668         hidden_states = residual + hidden_states
669
670         outputs = (hidden_states,)
@@ -843,10 +937,6 @@ def compute_loss(logits, labels, shift=False):
843     if (valid_token_num == 0).any():
844         batch_loss = batch_loss.masked_fill(valid_token_num == 0, 0.)
845
846 -   # print("beacon")
847 -   # print(f"token_loss: {token_loss[:, :100].tolist()}")
848 -   # print(f"batch_loss: {batch_loss}")
849 -   # input()
850     return loss, batch_loss, valid_token_num
851
852 @dataclass

@@ -895,7 +985,7 @@ class LlamaModel(LlamaPreTrainedModel):
895         self.post_init()
896
897     def _init_beacon_embed(self):
898 -       """Initialize the
899         if is_deepspeed_zero3_enabled():
900             import deepspeed
901             params = [self.beacon_embed_tokens.weight, self.embed_tokens.weight]

@@ -1109,6 +1199,15 @@ class LlamaModel(LlamaPreTrainedModel):
1109
1110         hidden_states = self.norm(hidden_states)
1111
1112         # add hidden states from the last decoder layer
1113         if output_hidden_states:
1114             all_hidden_states += (hidden_states,)
@@ -1139,9 +1238,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1139
1140     def set_memory(self):
1141         config: LlamaConfig = self.config
1142 -       info = f"applying activation beacon on {'all' if config.beacon_layers is None else config.beacon_layers} layers, with window size {config.beacon_window}, stride {config.beacon_stride} (mixed by {config.beacon_stride_mix}), {config.beacon_attn} attention, and condensing ratio {config.beacon_ratio} (mixed by {config.beacon_ratio_mix}), seed {config.beacon_seed}..."
1143 -       logger.info(info)
1144 -
1145         self.memory = Memory(
1146             model_config=config,
1147             beacon_window=config.beacon_window,

@@ -1151,10 +1247,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1151             beacon_ratio=config.beacon_ratio,
1152             beacon_stride_mix=config.beacon_stride_mix,
1153             beacon_ratio_mix=config.beacon_ratio_mix,
1154 -
1155 -           beacon_layers=config.beacon_layers,
1156             k_seq_dim=2,
1157             v_seq_dim=2,
1158         )
1159
1160     def get_input_embeddings(self):

@@ -1180,15 +1277,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1180         """Override the default from_pretrained to extend vocab size according to beacon_size."""
1181         model, loading_info = super().from_pretrained(*args, **kwargs, output_loading_info=True)
1182         missing_keys = loading_info["missing_keys"]
1183 -       # only initialize weights when they are missing from the checkpoint
1184 -
1185 -
1186 -
1187 -
1188 -
1189 -
1190 -
1191 -
1192         return model
1193
1194     def _native_forward(
@@ -1397,7 +1505,7 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
1397
1398         # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
1399         if hasattr(output, "batch_loss"):
1400 -           # output from
1401             batch_loss = output.batch_loss
1402             valid_token_num = output.valid_token_num
1403         else:

@@ -1415,9 +1523,10 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
1415             all_loss[_id].append((_loss * _num, _num))
1416
1417     for _id, loss_and_num in all_loss.items():
1418 -       # sum up the loss for all valid tokens, and divide the number of valid tokens
1419         all_loss[_id] = sum([x[0] for x in loss_and_num]) / sum(x[1] for x in loss_and_num)
1420
1421     perplexity = math.exp(sum(all_loss.values()) / len(all_loss))
1422     return perplexity
1423
@@ -226,8 +226,35 @@ class LlamaMLP(nn.Module):
226         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
227         self.act_fn = ACT2FN[config.hidden_act]
228
229 +       if "mlp" in config.beacon_param:
230 +           self.beacon_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
231 +           self.beacon_down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
232 +           self.beacon_up_proj._is_hf_initialized = True
233 +           self.beacon_down_proj._is_hf_initialized = True
234 +
235 +   def _init_beacon_proj(self, beacon_param=None):
236 +       """Initialize the beacon projection weight with that of the ordinal projection."""
237 +       if beacon_param is None:
238 +           beacon_param = self.config.beacon_param
239 +
240 +       if is_deepspeed_zero3_enabled():
241 +           import deepspeed
242 +           params = [self.up_proj, self.down_proj, self.beacon_up_proj, self.beacon_down_proj]
243 +           with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
244 +               if "mlp" in beacon_param:
245 +                   self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
246 +                   self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
247 +       else:
248 +           # only copy the value in-place, without tieing the weight
249 +           if "mlp" in beacon_param:
250 +               self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
251 +               self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
252 +
253 +   def forward(self, x, beacon_size):
254         if self.config.pretraining_tp > 1:
255 +           # TODO: support pretraining_tp
256 +           raise NotImplementedError
257 +
258             slice = self.intermediate_size // self.config.pretraining_tp
259             gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
260             up_proj_slices = self.up_proj.weight.split(slice, dim=0)

@@ -243,8 +270,28 @@ class LlamaMLP(nn.Module):
270                 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
271             ]
272             down_proj = sum(down_proj)
273 +
274         else:
275 +           if "mlp" in self.config.beacon_param:
276 +               if beacon_size > 0:
277 +                   ordinal_hidden_states = x[:, :-beacon_size]
278 +                   beacon_hidden_states = x[:, -beacon_size:]
279 +
280 +                   # ordinal_up_proj = self.up_proj(ordinal_hidden_states)
281 +                   # beacon_up_proj = self.beacon_up_proj(beacon_hidden_states)
282 +                   # up_proj = torch.cat([ordinal_up_proj, beacon_up_proj], dim=1)
283 +                   # intermediate = self.act_fn(self.gate_proj(x)) * up_proj
284 +                   # ordinal_down_proj = self.down_proj(intermediate[:, :-beacon_size])
285 +                   # beacon_down_proj = self.beacon_down_proj(intermediate[:, -beacon_size:])
286 +                   # down_proj = torch.cat([ordinal_down_proj, beacon_down_proj], dim=1)
287 +
288 +                   ordinal_down_proj = self.down_proj(self.act_fn(self.gate_proj(ordinal_hidden_states)) * self.up_proj(ordinal_hidden_states))
289 +                   beacon_down_proj = self.beacon_down_proj(self.act_fn(self.gate_proj(beacon_hidden_states)) * self.beacon_up_proj(beacon_hidden_states))
290 +                   down_proj = torch.cat([ordinal_down_proj, beacon_down_proj], dim=1)
291 +               else:
292 +                   down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
293 +           else:
294 +               down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
295
296         return down_proj
297

@@ -297,16 +344,20 @@ class LlamaAttention(nn.Module):
344         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
345         self._init_rope()
346
347 +       # NOTE: add extra parameters for beacon tokens
348         # skip post initialization to speed up loading
349 +       if "q" in config.beacon_param:
350 +           self.beacon_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
351 +           self.beacon_q_proj._is_hf_initialized = True
352 +       if "k" in config.beacon_param:
353 +           self.beacon_k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
354 +           self.beacon_k_proj._is_hf_initialized = True
355 +       if "v" in config.beacon_param:
356 +           self.beacon_v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
357 +           self.beacon_v_proj._is_hf_initialized = True
358 +       if "o" in config.beacon_param:
359 +           self.beacon_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
360 +           self.beacon_o_proj._is_hf_initialized = True
361
362     def _init_rope(self):
363         if self.config.rope_scaling is None:

@@ -335,22 +386,33 @@ class LlamaAttention(nn.Module):
386         else:
387             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
388
389 +   def _init_beacon_proj(self, beacon_param=None):
390 +       """Initialize the beacon projection weight with that of the ordinal projection."""
391 +       if beacon_param is None:
392 +           beacon_param = self.config.beacon_param
393 +
394         if is_deepspeed_zero3_enabled():
395             import deepspeed
396             params = [self.beacon_q_proj.weight, self.beacon_k_proj.weight, self.beacon_v_proj.weight, self.beacon_o_proj.weight, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight, self.o_proj.weight]
397             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
398 +               if "q" in beacon_param:
399 +                   self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
400 +               if "k" in beacon_param:
401 +                   self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
402 +               if "v" in beacon_param:
403 +                   self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
404 +               if "o" in beacon_param:
405 +                   self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
406 +       else:
407 +           # only copy the value in-place, without tieing the weight
408 +           if "q" in beacon_param:
409                 self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
410 +           if "k" in beacon_param:
411                 self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
412 +           if "v" in beacon_param:
413                 self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
414 +           if "o" in beacon_param:
415                 self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
416
417     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
418         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@@ -360,17 +422,26 @@ class LlamaAttention(nn.Module):
422             ordinal_hidden_states = hidden_states[:, :-beacon_size]
423             beacon_hidden_states = hidden_states[:, -beacon_size:]
424
425 +           if "q" in self.config.beacon_param:
426 +               ordinal_query_states = self.q_proj(ordinal_hidden_states)
427 +               beacon_query_states = self.beacon_q_proj(beacon_hidden_states)
428 +               query_states = torch.cat([ordinal_query_states, beacon_query_states], dim=1)
429 +           else:
430 +               query_states = self.q_proj(hidden_states)
431 +
432 +           if "k" in self.config.beacon_param:
433 +               ordinal_key_states = self.k_proj(ordinal_hidden_states)
434 +               beacon_key_states = self.beacon_k_proj(beacon_hidden_states)
435 +               key_states = torch.cat([ordinal_key_states, beacon_key_states], dim=1)
436 +           else:
437 +               key_states = self.k_proj(hidden_states)
438
439 +           if "v" in self.config.beacon_param:
440 +               ordinal_value_states = self.v_proj(ordinal_hidden_states)
441 +               beacon_value_states = self.beacon_v_proj(beacon_hidden_states)
442 +               value_states = torch.cat([ordinal_value_states, beacon_value_states], dim=1)
443 +           else:
444 +               value_states = self.v_proj(hidden_states)
445
446         else:
447             query_states = self.q_proj(hidden_states)

@@ -378,6 +449,18 @@ class LlamaAttention(nn.Module):
449             value_states = self.v_proj(hidden_states)
450
451         return query_states, key_states, value_states
452 +
453 +   def o_proj_with_beacon(self, attn_output, beacon_size=0):
454 +       if beacon_size > 0:
455 +           if "o" in self.config.beacon_param:
456 +               ordinal_attn_output = self.o_proj(attn_output[:, :-beacon_size])
457 +               beacon_attn_output = self.beacon_o_proj(attn_output[:, -beacon_size:])
458 +               attn_output = torch.cat([ordinal_attn_output, beacon_attn_output], dim=1)
459 +           else:
460 +               attn_output = self.o_proj(attn_output)
461 +       else:
462 +           attn_output = self.o_proj(attn_output)
463 +       return attn_output
464
465     def forward(
466         self,

@@ -403,8 +486,10 @@ class LlamaAttention(nn.Module):
486         else:
487             past_seq_len = 0
488
489         if self.config.pretraining_tp > 1:
490 +           # TODO: support pretraining_tp
491 +           raise NotImplementedError
492 +
493             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
494             query_slices = self.q_proj.weight.split(
495                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0

@@ -430,13 +515,14 @@ class LlamaAttention(nn.Module):
515
516         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
517
518 +       # return keys and values before rope
519 +       # NOTE: incrementally return keys and values for efficiency
520 +       past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
521 +
522         if past_key is not None:
523             # reuse k, v, self_attention
524             key_states = torch.cat([past_key, key_states], dim=2)
525             value_states = torch.cat([past_value, value_states], dim=2)
526
527         key_position_ids = position_ids
528         # align query position_ids with key

@@ -480,16 +566,13 @@ class LlamaAttention(nn.Module):
566
567         if self.config.pretraining_tp > 1:
568             # TODO: support pretraining_tp
569 +           raise NotImplementedError
570             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
571             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
572             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
573 +
574         else:
575 +           attn_output = self.o_proj_with_beacon(attn_output, beacon_size)
576
577         if not output_attentions:
578             attn_weights = None

@@ -545,14 +628,15 @@ class LlamaSdpaAttention(LlamaAttention):
628
629         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
630
631 +       # return keys and values before rope
632 +       # NOTE: incrementally return keys and values for efficiency
633 +       past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
634 +
635         if past_key is not None:
636             # reuse k, v, self_attention
637             key_states = torch.cat([past_key, key_states], dim=2)
638             value_states = torch.cat([past_value, value_states], dim=2)
639
640         key_position_ids = position_ids
641         # align query position_ids with key
642         query_position_ids = key_position_ids[:, -q_len:]

@@ -588,13 +672,20 @@ class LlamaSdpaAttention(LlamaAttention):
672
673         attn_output = attn_output.transpose(1, 2).contiguous()
674         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
675 +       attn_output = self.o_proj_with_beacon(attn_output, beacon_size)
676 +
677 +       # for debug
678 +       # if torch.distributed.get_rank() == 4 and self.layer_idx == 0:
679 +       #     torch.save({
680 +       #         "hidden_states": hidden_states,
681 +       #         "past_key_value": past_key_value,
682 +       #         "query_states": query_states,
683 +       #         "key_states": key_states,
684 +       #         "value_states": value_states,
685 +       #         "attn_output": attn_output,
686 +       #         "attention_mask": attention_mask,
687 +       #         "key_position_ids": key_position_ids,
688 +       #     }, "beacon_llama_layer_0")
689
690         return attn_output, None, past_key_value
691

@@ -645,6 +736,9 @@ class LlamaDecoderLayer(nn.Module):
736                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
737             )
738
739 +       # NOTE: get beacon_size in case the mlp is included in beacon_param
740 +       past_key, past_value, beacon_size, raw_size_to_cache, window_size = past_key_value
741 +
742         residual = hidden_states
743
744         hidden_states = self.input_layernorm(hidden_states)

@@ -664,7 +758,7 @@ class LlamaDecoderLayer(nn.Module):
758         # Fully Connected
759         residual = hidden_states
760         hidden_states = self.post_attention_layernorm(hidden_states)
761 +       hidden_states = self.mlp(hidden_states, beacon_size)
762         hidden_states = residual + hidden_states
763
764         outputs = (hidden_states,)

@@ -843,10 +937,6 @@ def compute_loss(logits, labels, shift=False):
937     if (valid_token_num == 0).any():
938         batch_loss = batch_loss.masked_fill(valid_token_num == 0, 0.)
939
940     return loss, batch_loss, valid_token_num
941
942 @dataclass

@@ -895,7 +985,7 @@ class LlamaModel(LlamaPreTrainedModel):
985         self.post_init()
986
987     def _init_beacon_embed(self):
988 +       """Initialize the beacon token embedding with that of the eos token."""
989         if is_deepspeed_zero3_enabled():
990             import deepspeed
991             params = [self.beacon_embed_tokens.weight, self.embed_tokens.weight]

@@ -1109,6 +1199,15 @@ class LlamaModel(LlamaPreTrainedModel):
1199
1200         hidden_states = self.norm(hidden_states)
1201
1202 +       # for debug
1203 +       # if torch.distributed.get_rank() == 4:
1204 +       #     torch.save({
1205 +       #         "hidden_states": hidden_states,
1206 +       #         "past_key_values": past_key_values,
1207 +       #         "attention_mask": attention_mask,
1208 +       #         "position_ids": position_ids,
1209 +       #     }, "beacon_llama_inputs")
1210 +
1211         # add hidden states from the last decoder layer
1212         if output_hidden_states:
1213             all_hidden_states += (hidden_states,)

@@ -1139,9 +1238,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1238
1239     def set_memory(self):
1240         config: LlamaConfig = self.config
1241         self.memory = Memory(
1242             model_config=config,
1243             beacon_window=config.beacon_window,

@@ -1151,10 +1247,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1247             beacon_ratio=config.beacon_ratio,
1248             beacon_stride_mix=config.beacon_stride_mix,
1249             beacon_ratio_mix=config.beacon_ratio_mix,
1250 +           beacon_param=config.beacon_param,
1251             k_seq_dim=2,
1252             v_seq_dim=2,
1253 +           retrieval_method=config.retrieval_method,
1254 +           retrieval_topk=config.retrieval_topk,
1255         )
1256
1257     def get_input_embeddings(self):

@@ -1180,15 +1277,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1277         """Override the default from_pretrained to extend vocab size according to beacon_size."""
1278         model, loading_info = super().from_pretrained(*args, **kwargs, output_loading_info=True)
1279         missing_keys = loading_info["missing_keys"]
1280 +       # only initialize beacon weights when they are missing from the checkpoint
1281 +       beacon_param = set()
1282 +       for missing_key in missing_keys:
1283 +           if "beacon_embed_tokens" in missing_key:
1284 +               model.model._init_beacon_embed()
1285 +           elif "beacon_q_proj" in missing_key:
1286 +               beacon_param.add("q")
1287 +           elif "beacon_k_proj" in missing_key:
1288 +               beacon_param.add("k")
1289 +           elif "beacon_v_proj" in missing_key:
1290 +               beacon_param.add("v")
1291 +           elif "beacon_o_proj" in missing_key:
1292 +               beacon_param.add("o")
1293 +           elif "beacon_up_proj" in missing_key:
1294 +               beacon_param.add("mlp")
1295 +
1296 +       # initialize weights of possible q,k,v,o,mlp
1297 +       for layer in model.model.layers:
1298 +           layer.self_attn._init_beacon_proj(beacon_param)
1299 +           layer.mlp._init_beacon_proj(beacon_param)
1300         return model
1301
1302     def _native_forward(

@@ -1397,7 +1505,7 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
1505
1506         # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
1507         if hasattr(output, "batch_loss"):
1508 +           # output from our model has batch_loss by default
1509             batch_loss = output.batch_loss
1510             valid_token_num = output.valid_token_num
1511         else:

@@ -1415,9 +1523,10 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
1523             all_loss[_id].append((_loss * _num, _num))
1524
1525     for _id, loss_and_num in all_loss.items():
1526 +       # sum up the loss for all valid tokens in the entire sequence, and divide the number of valid tokens
1527         all_loss[_id] = sum([x[0] for x in loss_and_num]) / sum(x[1] for x in loss_and_num)
1528
1529 +   # average across then take exp
1530     perplexity = math.exp(sum(all_loss.values()) / len(all_loss))
1531     return perplexity
1532
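The last two hunks spell out how evaluate_perplexity aggregates its losses: chunks that belong to the same sequence are merged by summing loss over valid tokens before re-normalizing, and perplexity is the exponential of the mean per-sequence loss. A self-contained sketch of that aggregation (the function and variable names here are illustrative, not taken from the file):

import math
from collections import defaultdict

def sequence_perplexity(chunks):
    # each chunk is (sequence id, mean loss over its valid tokens, number of valid tokens)
    all_loss = defaultdict(list)
    for seq_id, loss, num_valid in chunks:
        # undo the per-chunk averaging so chunks of the same sequence can be merged
        all_loss[seq_id].append((loss * num_valid, num_valid))

    for seq_id, loss_and_num in all_loss.items():
        # sum the loss over all valid tokens of the sequence, then divide by the token count
        all_loss[seq_id] = sum(x[0] for x in loss_and_num) / sum(x[1] for x in loss_and_num)

    # average across sequences, then take exp
    return math.exp(sum(all_loss.values()) / len(all_loss))

# two sequences, the first split across two chunks
print(sequence_perplexity([(0, 2.0, 100), (0, 1.5, 50), (1, 2.3, 80)]))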