import json
import logging
import os
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
    AutoModel,
    AutoConfig,
    PretrainedConfig,
    PreTrainedModel,
    AutoTokenizer,
    LlamaConfig,
    MistralConfig,
    GemmaConfig,
    Qwen2Config,
)

logger = logging.getLogger(__name__)


def batch_to_device(batch, target_device: device):
    """
    Send a PyTorch batch of tensors to a target device (CPU/GPU).
    """
    for key in batch:
        if isinstance(batch[key], Tensor):
            batch[key] = batch[key].to(target_device)
    return batch


class LLMEncoderConfig(PretrainedConfig):
    def __init__(
        self,
        pooling_mode: str = "weighted_mean",
        max_length: int = 512,
        doc_max_length: int = 400,
        skip_instruction: bool = True,
        **kwargs,
    ):
        if pooling_mode not in ["mean", "weighted_mean", "eos_token", "bos_token"]:
            raise ValueError(
                f"Pooling mode {pooling_mode} is not supported. "
                "Please choose one of 'mean', 'weighted_mean', 'eos_token', 'bos_token'."
            )
        self.pooling_mode = pooling_mode
        self.max_length = max_length
        self.doc_max_length = doc_max_length
        self.skip_instruction = skip_instruction
        self.model_config = None
        self.base_model = None

        super().__init__(**kwargs)
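
    # A minimal sketch of constructing the config directly (the values and the model
    # identifier below are illustrative, not recommendations):
    #
    #     config = LLMEncoderConfig(pooling_mode="mean", max_length=256, doc_max_length=200)
    #     encoder = LLMEncoder.from_pretrained("some/base-model", config=config)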


class LLMEncoder(PreTrainedModel):
    config_class = LLMEncoderConfig

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: AutoTokenizer,
        config: LLMEncoderConfig,
    ):
        super().__init__(config)
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_mode = config.pooling_mode
        self.max_length = config.max_length
        self.doc_max_length = config.doc_max_length
        self.skip_instruction = config.skip_instruction
        self.model_config = None

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        config=None,
        **kwargs,
    ):
        """
        Load a pretrained model from a model identifier or path.
        Args:
            base_model_name_or_path: Model identifier or path to the pretrained base model.
            peft_model_name_or_path: Optional path to a PEFT adapter to load and merge into the base model.
            config: Optional LLMEncoderConfig; a default configuration is created if omitted.
        Returns: an LLMEncoder wrapping the loaded model.
        """

        if not config:
            config = LLMEncoderConfig()

        if not config.base_model:
            config.base_model = base_model_name_or_path

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        if config.model_config:
            model_config = AutoConfig.from_pretrained(config.base_model)
            model_config = model_config.from_dict(config.model_config)
        else:
            model_config = AutoConfig.from_pretrained(base_model_name_or_path)
            config.model_config = model_config

        model = AutoModel.from_pretrained(base_model_name_or_path, config=model_config, **kwargs)

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            model = model.merge_and_unload()

        return cls(model=model, tokenizer=tokenizer, config=config)
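
    # Example (hypothetical model identifier, shown only to illustrate the call signature;
    # extra keyword arguments are forwarded to AutoModel.from_pretrained):
    #
    #     encoder = LLMEncoder.from_pretrained(
    #         "meta-llama/Meta-Llama-3-8B-Instruct",
    #         peft_model_name_or_path=None,
    #         torch_dtype=torch.bfloat16,
    #     )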

    def prune(self, percent_prune=0):
        """
        Prune the model to a reduced number of decoder layers. If percent_prune is less than 1,
        it is treated as the fraction of layers to remove; for example, percent_prune=0.3 removes
        30% of the layers. If percent_prune is 1 or greater, it is treated as the number of layers
        to keep; for example, percent_prune=3 keeps only the first 3 layers.
        """

        if percent_prune >= 1:
            new_num_layers = int(percent_prune)
        else:
            new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune))
        print(f"Pruning to {new_num_layers} layers.")
        self.model.layers = self.model.layers[:new_num_layers]
        self.model.config.num_hidden_layers = new_num_layers
        self.config.model_config.num_hidden_layers = new_num_layers
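
    # Examples (illustrative values; use one form or the other):
    #
    #     encoder.prune(percent_prune=0.5)  # remove ~50% of the layers
    #     encoder.prune(percent_prune=4)    # keep only the first 4 layers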

    def prepare_for_tokenization(self, text):
        """
        Wrap the text in the chat template of the underlying model and, for the
        "eos_token" pooling mode, append the model's end-of-sequence token.
        """
        if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
            text = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
                + text.strip()
                + "<|eot_id|>"
            )
            return text
        if self.model.config._name_or_path in [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "meta-llama/Llama-2-7b-chat-hf",
        ]:
            text = "[INST] " + text.strip() + " [/INST]"
        if self.model.config._name_or_path in [
            "google/gemma-2-9b-it",
        ]:
            text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
        if self.model.config._name_or_path in [
            "Qwen/Qwen2-1.5B-Instruct",
            "Qwen/Qwen2-7B-Instruct",
        ]:
            text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
        if self.pooling_mode == "eos_token":
            if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
                text = text.strip() + "<|end_of_text|>"
            elif isinstance(self.model.config, LlamaConfig) or isinstance(
                self.model.config, MistralConfig
            ):
                text = text.strip() + " </s>"
            elif isinstance(self.model.config, GemmaConfig):
                text = text.strip() + "<eos>"
            elif isinstance(self.model.config, Qwen2Config):
                text = text.strip() + "<|endoftext|>"
        return text
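
    # For example, with "mistralai/Mistral-7B-Instruct-v0.2" as the base model and the
    # default pooling mode:
    #
    #     prepare_for_tokenization("query: layer pruning")
    #     # -> "[INST] query: layer pruning [/INST]"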

    def tokenize(self, texts):
        # Each text may contain the separator "!@#$%^&*()" between an instruction and the
        # text to embed (see _convert_to_str). The part after the separator is tokenized
        # again on its own so that an embed_mask covering only those tokens can be built.
        texts_2 = []
        original_texts = []
        for text in texts:
            t = text.split("!@#$%^&*()")
            texts_2.append(t[1] if len(t) > 1 else "")
            original_texts.append("".join(t))

        original = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            # With left padding, the text-to-embed tokens sit at the end of the sequence,
            # so mark the last len(ids) positions.
            e_m = torch.zeros_like(original["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]) :] = torch.ones(
                    len(ids["input_ids"][0])
                )
            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        original["embed_mask"] = embed_mask
        return original

    def _skip_instruction(self, sentence_feature):
        assert (
            sentence_feature["attention_mask"].shape
            == sentence_feature["embed_mask"].shape
        )
        sentence_feature["attention_mask"] = sentence_feature["embed_mask"]

    def forward(self, sentence_feature: Dict[str, Tensor]):
        embed_mask = None
        if "embed_mask" in sentence_feature:
            embed_mask = sentence_feature.pop("embed_mask")
        reps = self.model(**sentence_feature)
        sentence_feature["embed_mask"] = embed_mask

        return self.get_pooling(sentence_feature, reps.last_hidden_state)

    def get_pooling(self, features, last_hidden_states):
        assert (
            self.tokenizer.padding_side == "left"
        ), "Pooling modes are implemented for padding from left."
        if self.skip_instruction:
            self._skip_instruction(features)
        seq_lengths = features["attention_mask"].sum(dim=-1)
        if self.pooling_mode == "mean":
            return torch.stack(
                [
                    last_hidden_states[i, -length:, :].mean(dim=0)
                    for i, length in enumerate(seq_lengths)
                ],
                dim=0,
            )
        elif self.pooling_mode == "weighted_mean":
            # Non-padding positions get linearly increasing weights 1, 2, ..., seq_l,
            # normalized to sum to 1 for each sequence.
            bs, l, _ = last_hidden_states.shape
            complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
            for i, seq_l in enumerate(seq_lengths):
                if seq_l > 0:
                    complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
                    complete_weights[i] /= torch.clamp(
                        complete_weights[i].sum(), min=1e-9
                    )
            return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
        elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
            return last_hidden_states[:, -1]
        elif self.pooling_mode == "bos_token":
            return last_hidden_states[
                features["input_ids"] == self.tokenizer.bos_token_id
            ]
        else:
            raise ValueError(f"{self.pooling_mode} is not implemented yet.")

    def _convert_to_str(self, instruction, text):
        tokenized_q = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
        )
        tokenized_q_length = len(tokenized_q["input_ids"][0])

        while tokenized_q_length > self.doc_max_length:
            reduction_ratio = self.doc_max_length / tokenized_q_length
            reduced_length = int(len(text.split()) * reduction_ratio)
            text = " ".join(text.split()[:reduced_length])
            tokenized_q = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            tokenized_q_length = len(tokenized_q["input_ids"][0])

        return (
            f"{instruction.strip()} !@#$%^&*(){text}"
            if instruction
            else f"!@#$%^&*(){text}"
        )
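
    # For example (output shown for illustration):
    #
    #     _convert_to_str("Retrieve relevant passages.", "layer pruning of LLMs")
    #     # -> "Retrieve relevant passages. !@#$%^&*()layer pruning of LLMs"
    #
    # The "!@#$%^&*()" marker is the internal separator that tokenize() uses to tell the
    # instruction apart from the text to embed.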

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
    ):
        """
        Encode a sentence or a list of sentences into their respective embeddings.
        Args:
            sentences: sentence or sentences to encode.
            batch_size: batch size for turning sentence tokens into embeddings.
            show_progress_bar: whether to show progress bars during encoding steps.
            convert_to_numpy: if true, return numpy arrays instead of torch tensors.
            convert_to_tensor: if true, return torch tensors (this is also the behaviour when
                neither flag is set); overrides convert_to_numpy.
            device: torch backend device identifier (e.g., 'cuda', 'cpu', 'mps'). If not specified,
                cuda is used when available, otherwise cpu. Note that only 'cuda' supports
                multiprocessing as currently implemented.

        Returns: embeddings of the sentences. Embeddings are detached and always on the CPU
        (see the _encode implementation).
        """
        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
            sentences = [sentences]

        if isinstance(sentences[0], str):
            sentences = [[""] + [sentence] for sentence in sentences]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        concatenated_input_texts = []
        for sentence in sentences:
            assert isinstance(sentence[0], str)
            assert isinstance(sentence[1], str)
            concatenated_input_texts.append(
                self._convert_to_str(sentence[0], sentence[1])
            )
        sentences = concatenated_input_texts

        self.eval()

        if convert_to_tensor:
            convert_to_numpy = False

        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        all_embeddings = []

        if torch.cuda.device_count() <= 1:
            self.to(device)
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[
                    start_index : start_index + batch_size
                ]
                embeddings = self._encode(
                    sentences_batch, device=device, convert_to_numpy=convert_to_numpy
                )
                all_embeddings.append(embeddings)
        else:
            num_proc = torch.cuda.device_count()
            cuda_compatible_multiprocess = mp.get_context("spawn")
            with cuda_compatible_multiprocess.Pool(num_proc) as p:
                sentences_batches = [
                    sentences_sorted[start_index : start_index + batch_size]
                    for start_index in range(0, len(sentences), batch_size)
                ]

                progress_bar = tqdm(
                    total=len(sentences_batches),
                    desc="Batches",
                    disable=not show_progress_bar,
                )
                results = []

                def update(*args):
                    progress_bar.update()

                for batch in sentences_batches:
                    results.append(
                        p.apply_async(
                            self._encode,
                            args=(batch, None, convert_to_numpy, True),
                            callback=update,
                        )
                    )

                all_embeddings = [result.get() for result in results]
                progress_bar.close()

        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
        all_embeddings = all_embeddings.to(torch.float32)
        if convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings
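
    # Example (illustrative; plain strings and [instruction, text] pairs are both accepted):
    #
    #     embeddings = encoder.encode(
    #         [
    #             ["Retrieve relevant passages.", "what is layer pruning?"],
    #             ["", "LLMs can be pruned into strong text encoders."],
    #         ],
    #         batch_size=8,
    #         convert_to_numpy=True,
    #     )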

    def save(self, output_path, merge_before_save=False, save_config=True):
        if merge_before_save and isinstance(self.model, PeftModel):
            self.model = self.model.merge_and_unload()
            if hasattr(self.model, "_hf_peft_config_loaded"):
                self.model._hf_peft_config_loaded = False

        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        l3prune_config = {
            "pooling_mode": self.pooling_mode,
            "max_length": self.max_length,
            "doc_max_length": self.doc_max_length,
            "skip_instruction": self.skip_instruction,
        }

        if save_config:
            os.makedirs(output_path, exist_ok=True)
            with open(f"{output_path}/l3prune_config.json", "w") as fOut:
                json.dump(l3prune_config, fOut, indent=4)

    def _encode(
        self,
        sentences_batch,
        device: Optional[str] = None,
        convert_to_numpy: bool = False,
        multiprocessing=False,
    ):
        if multiprocessing:
            # When run in a worker process, pick a GPU based on the worker's rank.
            rank = mp.current_process()._identity[0]
            if device is None and torch.cuda.is_available():
                device = f"cuda:{rank % torch.cuda.device_count()}"

        self.to(device)
        features = self.tokenize(
            [self.prepare_for_tokenization(sentence) for sentence in sentences_batch]
        )
        features = batch_to_device(features, device)

        with torch.no_grad():
            embeddings = self.forward(features)
            embeddings = embeddings.detach()
            embeddings = embeddings.cpu()

        return embeddings

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper function to get the length of the input text. Text can be either a string (a single
        text), a list of ints (a single tokenized text), or a tuple of lists of ints (several text
        inputs to the model).
        """
        if (
            isinstance(text, str)
            or len(text) == 0
            or (isinstance(text, list) and isinstance(text[0], int))
        ):
            return len(text)
        if isinstance(text, dict):
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):
            return 1
        else:
            return sum([len(t) for t in text])

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        return self.model.resize_token_embeddings(
            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )

    def save_pretrained(self, save_directory, **kwargs):
        self.tokenizer.save_pretrained(save_directory, **kwargs)
        super().save_pretrained(save_directory, **kwargs)

    def push_to_hub(self, repo_id, **kwargs):
        self.tokenizer.push_to_hub(repo_id, **kwargs)
        super().push_to_hub(repo_id, **kwargs)
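

# A minimal end-to-end sketch (the model name and output path below are illustrative
# assumptions, not part of this module):
#
#     encoder = LLMEncoder.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
#     encoder.prune(percent_prune=0.5)    # drop roughly half of the decoder layers
#     embs = encoder.encode(["layer-pruned LLMs still make good text encoders"])
#     encoder.save("./pruned-encoder", save_config=True)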