import math
import warnings
import hashlib
import os
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LlamaForCausalLM
from transformers.models.gemma.modeling_gemma import GemmaModel, GemmaPreTrainedModel, GemmaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, Qwen2PreTrainedModel, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache

def svd_with_cache(matrix, cache_dir, max_rank=1024):
    """
    SVD with a cache mechanism to avoid repeated SVD computation.
    SVD can be very slow for large matrices, so the truncated factors are stored on disk.
    Note that the cached factors keep at most `max_rank` components, so callers that later
    slice to a larger rank will silently get at most `max_rank` components back.
    """
    in_dim, out_dim = matrix.shape

    # Cache key: a lightweight fingerprint built from the matrix shape and a small sample
    # of its entries, so that different weights of the same size do not collide.
    sample = matrix.detach().flatten()[:16].float().cpu().tolist()
    weight_hash = hashlib.md5(f'{in_dim}x{out_dim}_{sample}'.encode()).hexdigest()
    cache_file = os.path.join(cache_dir, f'{weight_hash}.pt')

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    if os.path.exists(cache_file):
        # Reuse the previously computed truncated factors.
        U, S, Vh = torch.load(cache_file)
    else:
        # Full SVD in float32, then keep only the top `max_rank` components.
        U, S, Vh = torch.linalg.svd(matrix.float())
        U = U[:, :max_rank].clone()
        S = S[:max_rank].clone()
        Vh = Vh[:max_rank, :].clone()

        torch.save((U, S, Vh), cache_file)
    return U, S, Vh

def create_factorized_compression_for_linear(source_linear, rank, svd_cache_dir='experiment_cache/'):
    """
    Adapted from: https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/cli_svd.py
    Create a factorized (low-rank) compression of a given linear layer using SVD.
    Args:
        source_linear (nn.Linear): The original linear layer to be compressed.
        rank (int): The rank of the factorization; must be smaller than both dimensions of the weight.
        svd_cache_dir (str, optional): Directory used to cache SVD results. Default is 'experiment_cache/'.
    Returns:
        Tuple[nn.Linear, nn.Linear]: The down-projection and up-projection layers whose
        composition approximates the original linear layer.
    """
    with torch.no_grad():
        dtype = source_linear.weight.dtype

        bias = source_linear.bias if hasattr(source_linear, 'bias') else None

        source_num_params = sum(param.numel() for param in source_linear.parameters())

        source_linear_weight = source_linear.weight.data

        assert rank < min(source_linear_weight.shape)

        # Truncated SVD of the weight (cached on disk because it is expensive).
        U, S, Vh = svd_with_cache(source_linear_weight, svd_cache_dir)

        U = U[:, :rank].contiguous()
        S = S[:rank].contiguous()
        Vh = Vh[:rank, :].contiguous()

        # Fold the singular values into U so that U @ Vh approximates the original weight.
        U = U @ torch.diag(S)

        U_flatten = U.flatten()
        Vh_flatten = Vh.flatten()

        # torch.quantile cannot handle very large inputs, so for big factors the clipping
        # threshold is estimated from truncated slices instead of the full tensors.
        max_quant_size = 2**23

        if len(U_flatten) + len(Vh_flatten) >= max_quant_size:
            dist2 = U_flatten[:min(len(U_flatten), max_quant_size)]
            dist3 = Vh_flatten[:min(len(Vh_flatten), max_quant_size)]
            hi_val = max(torch.quantile(dist3, 1), torch.quantile(dist2, 1))
        else:
            dist = torch.cat([U_flatten, Vh_flatten])
            hi_val = torch.quantile(dist, 1)
        low_val = -hi_val

        # Clamp the factors to a symmetric range before casting back to the original dtype.
        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)

        # Down projection: in_features -> rank (initialized from Vh).
        lora_down = nn.Linear(Vh.shape[1], Vh.shape[0], dtype=dtype, bias=False, device=source_linear_weight.device)
        lora_down.weight.data = Vh.to(device=source_linear_weight.device, dtype=dtype)

        # Up projection: rank -> out_features (initialized from U @ diag(S)); carries the original bias, if any.
        lora_up = nn.Linear(U.shape[1], U.shape[0], dtype=dtype, bias=bias is not None, device=source_linear_weight.device)
        lora_up.weight.data = U.to(device=source_linear_weight.device, dtype=dtype)

        if bias is not None:
            lora_up.bias = nn.Parameter(bias.clone())

    return lora_down, lora_up
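
# The helper below is purely an illustrative usage sketch (it is not called anywhere in
# this module): it factorizes a randomly initialized projection with assumed toy
# dimensions and reports the reconstruction error of the low-rank approximation.
# Note that it writes an SVD cache file to the default 'experiment_cache/' directory.
def _example_factorized_compression():
    toy_head = nn.Linear(1024, 8000, bias=False)   # assumed hidden_size=1024, vocab_size=8000
    lora_down, lora_up = create_factorized_compression_for_linear(toy_head, rank=128)
    x = torch.randn(2, 1024)
    # lora_up(lora_down(x)) is the rank-128 approximation of toy_head(x).
    approx = lora_up(lora_down(x))
    full = toy_head(x)
    return (approx - full).abs().mean()
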

@dataclass
class AdaCausalLMOutputWithPast(CausalLMOutputWithPast):
    """
    Extends `CausalLMOutputWithPast` with the AdaVocab head outputs: the original
    lm_head logits and the individual (already weighted) loss terms.
    """
    lm_head_logits: Optional[torch.FloatTensor] = None
    lm_loss: Optional[torch.FloatTensor] = None
    mask_loss: Optional[torch.FloatTensor] = None
    topk_loss: Optional[torch.FloatTensor] = None

class AdaVocabHead_MLP(nn.Module):
    """
    A small MLP head (hidden_size -> sub_vocab_dim -> sub_vocab_dim -> vocab_size) that
    scores, for every position, which vocabulary entries should be kept.
    """
    def __init__(self, lm_head, sub_vocab_dim, activation_func=torch.nn.GELU()):
        hidden_size, vocab_size = lm_head.in_features, lm_head.out_features
        super().__init__()

        self.A = nn.Linear(hidden_size, sub_vocab_dim, bias=False)
        self.B = nn.Linear(sub_vocab_dim, sub_vocab_dim, bias=True)
        self.C = nn.Linear(sub_vocab_dim, vocab_size, bias=False)
        std_dev = 1 / math.sqrt(sub_vocab_dim)
        nn.init.normal_(self.A.weight, 0, std_dev)
        nn.init.normal_(self.B.weight, 0, std_dev)
        # C starts at zero so the head initially outputs all-zero logits.
        nn.init.zeros_(self.C.weight)
        self.activation_func = activation_func

    def forward(self, x):
        # x: (batch_size, seq_len, hidden_size) -> logits over the full vocabulary.
        logits = self.A(x)
        logits = self.activation_func(logits)
        logits = self.B(logits)

        ada_vocab_logits = self.C(logits)

        return ada_vocab_logits

class AdaVocabHead_LORA(nn.Module):
    """
    A LoRA-style low-rank head (hidden_size -> sub_vocab_dim -> vocab_size). When `svd`
    is True, the two factors are initialized from a truncated SVD of the original lm_head.
    """
    def __init__(self, lm_head, sub_vocab_dim, svd=False):
        hidden_size, vocab_size = lm_head.in_features, lm_head.out_features
        super().__init__()
        if svd:
            self.A, self.B = create_factorized_compression_for_linear(lm_head, sub_vocab_dim)
        else:
            self.A = nn.Linear(hidden_size, sub_vocab_dim, bias=False)
            self.B = nn.Linear(sub_vocab_dim, vocab_size, bias=False)
            std_dev = 1 / math.sqrt(sub_vocab_dim)
            nn.init.normal_(self.A.weight, 0, std_dev)
            # B starts at zero so the head initially outputs all-zero logits.
            nn.init.zeros_(self.B.weight)

    def forward(self, x):
        # x: (batch_size, seq_len, hidden_size) -> logits over the full vocabulary.
        logits = self.A(x)
        ada_vocab_logits = self.B(logits)
        return ada_vocab_logits
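
# Rough parameter accounting, as a sketch under assumed Llama-7B-like dimensions
# (hidden_size=4096, vocab_size=32000, sub_vocab_dim=1024):
#   LoRA head:  4096*1024 + 1024*32000 ~= 37M parameters
#   full head:  4096*32000             ~= 131M parameters
# i.e. roughly a 3.5x reduction, before any Top-K slicing of the real lm_head.
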

def create_AdaVocabCausalLM(base_class):
    class AdaVocabCausalLM(base_class):

        _tied_weights_keys = ["lm_head.weight"]

        def __init__(self, config):
            super().__init__(config)
            self.sub_vocab_dim = config.ADA_DIM
            self.offload_tag = False

            # The AdaVocab head is either a small MLP or a LoRA-style low-rank projection.
            if config.ADA_ACT:
                self.adavocab_head = AdaVocabHead_MLP(self.lm_head, self.sub_vocab_dim, activation_func=nn.GELU())
            else:
                self.adavocab_head = AdaVocabHead_LORA(self.lm_head, self.sub_vocab_dim, svd=config.ADA_SVD)

            self.freeze_original_model()

        def freeze_original_model(self):
            # Only the AdaVocab head is trained; the backbone and the original lm_head stay frozen.
            for param in self.model.parameters():
                param.requires_grad = False
            for param in self.lm_head.parameters():
                param.requires_grad = False
            for param in self.adavocab_head.parameters():
                param.requires_grad = True

        def offload_lm_head(self):
            # Move the full lm_head to CPU; at decode time only the selected rows are copied back.
            self.offload_tag = True
            self.lm_head = self.lm_head.to(torch.device('cpu'))

        def topk_mask(self, logits):
            # Binary mask that marks the Top-K entries of `logits` along the vocabulary dimension.
            topk_values, topk_indices = torch.topk(logits, self.config.ADA_TOPK, dim=-1)

            mask = torch.zeros_like(logits)

            mask.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(mask))
            return mask

        def pred_with_sliced_lm_head_simple(self, ada_logits, hidden_states):
            # Keep only the Top-K AdaVocab logits per position (sorted in descending order).
            ada_logits, topk_indices = torch.topk(ada_logits, self.config.ADA_TOPK, dim=-1)

            # Count how many of the Top-K logits at the last position are positive; only those
            # vocabulary entries are fetched from the (possibly CPU-offloaded) lm_head.
            gt_zero_pos = torch.nonzero(ada_logits[:, -1, :] > 0, as_tuple=True)[-1].shape[0]
            ada_index_slice = topk_indices[:, :, :gt_zero_pos].flatten().to(self.lm_head.weight.device)

            # Gather the selected lm_head rows and project the hidden states onto that vocabulary slice.
            sliced_lm_head_weight = self.lm_head.weight[ada_index_slice, :].contiguous().to(hidden_states.device)
            lm_logits_sliced = hidden_states @ sliced_lm_head_weight.T

            return lm_logits_sliced, ada_index_slice
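
        # A hedged sketch of how the sliced logits can be mapped back to token ids at decode
        # time, assuming batch size 1 (the variable names below are illustrative only):
        #
        #   sliced_logits, index_slice = model.pred_with_sliced_lm_head_simple(ada_logits, h[:, -1:, :])
        #   next_token_id = index_slice[sliced_logits[:, -1, :].argmax(dim=-1)]
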

        def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
        ) -> Union[Tuple, CausalLMOutputWithPast]:

            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            hidden_states = outputs[0]
            batch_size, seq_len, _ = hidden_states.size()
            vocab_size = self.lm_head.weight.shape[0]

            if labels is not None:
                # Training / evaluation needs AdaVocab logits for every position.
                ada_logits = self.adavocab_head(hidden_states)
            else:
                # Generation only needs the last position; when the lm_head has been offloaded,
                # the (small) AdaVocab factors are shuttled between CPU and GPU around the call.
                if self.offload_tag:
                    self.adavocab_head.A.to(hidden_states.device)
                    self.adavocab_head.B.to(hidden_states.device)
                ada_logits = self.adavocab_head(hidden_states[:, -1:, :])
                if self.offload_tag:
                    self.adavocab_head.A.to("cpu")
                    self.adavocab_head.B.to("cpu")

            lm_head_logits = None
            lm_loss, mask_loss, topk_loss = None, None, None
            loss = None

            if labels is not None:
                # The frozen lm_head provides the Top-K supervision targets.
                lm_head_logits = self.lm_head(hidden_states)
                lm_head_logits = lm_head_logits.float()

                if self.training:
                    # Standard next-token cross-entropy on the AdaVocab logits.
                    shift_logits = ada_logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()

                    loss_fct = CrossEntropyLoss()
                    shift_logits = shift_logits.view(-1, self.config.vocab_size)

                    shift_labels = shift_labels.view(-1)
                    shift_labels = shift_labels.to(shift_logits.device)

                    lm_loss = loss_fct(shift_logits, shift_labels)
                else:
                    # Evaluation path: `pred_with_sliced_lm_head` is not defined in this file and
                    # is assumed to be provided elsewhere.
                    _, lm_loss = self.pred_with_sliced_lm_head(ada_logits, hidden_states, input_ids, labels, min_logit=-100)

                ada_logits_flat = ada_logits.view(-1, self.config.vocab_size)
                ada_probs = torch.sigmoid(ada_logits_flat)

                # BCE between the AdaVocab scores and the Top-K mask of the frozen lm_head.
                topk_gt_mask = self.topk_mask(lm_head_logits)
                topk_gt_mask = topk_gt_mask.view(-1, self.config.vocab_size)

                mask_loss_fct = BCEWithLogitsLoss()
                mask_loss = mask_loss_fct(ada_logits_flat, topk_gt_mask)

                # Encourage the expected number of selected tokens to stay close to ADA_TOPK per position.
                ada_ones = ada_probs.sum()

                target_ones = batch_size * seq_len * self.config.ADA_TOPK
                target_ones = torch.tensor(target_ones, dtype=torch.float32).to(ada_ones.device)

                topk_loss = F.l1_loss(ada_ones, target_ones) / target_ones

                loss = self.config.ADA_LOSS_WEIGHT * lm_loss + self.config.ADA_MASK_WEIGHT * mask_loss + self.config.ADA_TOPK_WEIGHT * topk_loss
            else:
                with torch.no_grad():
                    # Here `ada_logits` receives the logits over the sliced vocabulary and
                    # `lm_head_logits` receives the corresponding vocabulary indices
                    # (see `pred_with_sliced_lm_head_simple`).
                    ada_logits, lm_head_logits = self.pred_with_sliced_lm_head_simple(ada_logits, hidden_states[:, -1:, :])

            if not return_dict:
                output = (ada_logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return AdaCausalLMOutputWithPast(
                loss=loss,
                logits=ada_logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                lm_head_logits=lm_head_logits if lm_head_logits is not None else None,
                lm_loss=self.config.ADA_LOSS_WEIGHT * lm_loss if lm_loss is not None else None,
                mask_loss=self.config.ADA_MASK_WEIGHT * mask_loss if mask_loss is not None else None,
                topk_loss=self.config.ADA_TOPK_WEIGHT * topk_loss if topk_loss is not None else None,
            )

        def get_input_embeddings(self):
            return self.model.embed_tokens

        def set_input_embeddings(self, value):
            self.model.embed_tokens = value

        def get_output_embeddings(self):
            return self.lm_head

        def set_output_embeddings(self, new_embeddings):
            self.lm_head = new_embeddings

    return AdaVocabCausalLM


AdaVocabLlamaForCausalLM = create_AdaVocabCausalLM(LlamaForCausalLM)
AdaVocabGemmaforCausalLM = create_AdaVocabCausalLM(GemmaForCausalLM)
AdaVocabQwen2ForCausalLM = create_AdaVocabCausalLM(Qwen2ForCausalLM)
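

# The block below is a minimal, illustrative usage sketch, not part of the module proper.
# It assumes a Llama checkpoint name and AdaVocab hyperparameter values chosen purely for
# demonstration; the ADA_* attribute names themselves are the ones read by the classes above.
if __name__ == "__main__":
    from transformers import AutoConfig

    model_name = "meta-llama/Llama-2-7b-hf"   # illustrative checkpoint name
    config = AutoConfig.from_pretrained(model_name)

    # AdaVocab hyperparameters (values here are assumptions for the sketch).
    config.ADA_DIM = 1024          # rank / sub-vocabulary dimension of the AdaVocab head
    config.ADA_ACT = False         # False -> LoRA-style head, True -> MLP head
    config.ADA_SVD = False         # set True to initialize the LoRA head from an SVD of lm_head
    config.ADA_TOPK = 1000         # target number of retained vocabulary entries per position
    config.ADA_LOSS_WEIGHT = 1.0
    config.ADA_MASK_WEIGHT = 1.0
    config.ADA_TOPK_WEIGHT = 1.0

    model = AdaVocabLlamaForCausalLM.from_pretrained(model_name, config=config)
    model.offload_lm_head()        # keep the full lm_head on CPU; slices are fetched on demand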