---
license: apache-2.0
---

# X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models

X-LoRA works by learning scaling values for LoRA adapters. These learned scaling values are used to gate the LoRA experts in a dense fashion. Additionally, all LoRA adapters and the base model are frozen, allowing efficient fine-tuning due to a low trainable parameter count.
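
To make the gating idea concrete, here is a minimal NumPy sketch of one linear layer whose frozen LoRA experts are mixed densely. It assumes the scalings come from a softmax over per-expert logits and omits LoRA's usual `alpha / r` factor. In X-LoRA the scalings are predicted by a trained component rather than stored as a fixed vector, so `dense_lora_mixture` and `logits` are hypothetical illustration names, not part of the `xlora` API.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def dense_lora_mixture(x, W0, loras, logits):
    """One linear layer with densely gated, frozen LoRA experts (hypothetical sketch).

    x: (d_in,) input, W0: (d_out, d_in) frozen base weight,
    loras: list of frozen (A, B) pairs with A: (r, d_in) and B: (d_out, r),
    logits: (n_experts,) gating logits.
    """
    scalings = softmax(logits)          # dense gating: every expert gets a positive weight
    y = W0 @ x                          # frozen base path
    for s, (A, B) in zip(scalings, loras):
        y = y + s * (B @ (A @ x))       # add each expert's low-rank update, scaled by its weight
    return y, scalings


rng = np.random.default_rng(0)
d_in, d_out, r, n_experts = 8, 6, 2, 3
x = rng.normal(size=d_in)
W0 = rng.normal(size=(d_out, d_in))
loras = [(rng.normal(size=(r, d_in)), rng.normal(size=(d_out, r))) for _ in range(n_experts)]
logits = np.array([2.0, 0.5, -1.0])

y, scalings = dense_lora_mixture(x, W0, loras, logits)
print(scalings.round(3), y.shape)
```

Because every expert receives a nonzero weight, all adapters contribute to the output, in contrast to sparse top-1 routing where a single expert is selected.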

X-LoRA is easily applied to any Hugging Face Transformers model.

## Features

- Effective: Dense gating of experts allows effective mixing
- Efficient fine-tuning: low trainable parameter count
- Hierarchical encapsulated strategy: Re-use existing trained models or model sections to address complex tasks that cut across experts, following a bio-inspired strategy
- Easy-to-use API: `add_xlora_to_model`, broad compatibility
- Dynamically mix LoRA adapters: Deep layer-wise combinations of adapters
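
The layer-wise mixing and the low trainable parameter count can be sketched together in a toy model. As an assumption for illustration only, each layer here has its own row of scalings over the same set of frozen adapters, and the only trainable values are those per-layer logits. In X-LoRA itself the scalings are produced per token by a learned classifier, whose weights are the trainable parameters; `forward` and `layer_logits` are hypothetical names.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(1)
n_layers, n_adapters, d, r = 4, 3, 16, 4

# Frozen: one base weight per layer and one LoRA (A, B) pair per adapter per layer
base = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(n_layers)]
adapters = [[(rng.normal(size=(r, d)), rng.normal(size=(d, r))) for _ in range(n_adapters)]
            for _ in range(n_layers)]

# Trainable (toy assumption): one gating logit per (layer, adapter) pair
layer_logits = rng.normal(size=(n_layers, n_adapters))
scalings = softmax(layer_logits)  # shape (n_layers, n_adapters); each row sums to 1


def forward(x: np.ndarray) -> np.ndarray:
    h = x
    for W, experts, s in zip(base, adapters, scalings):
        # Each layer mixes the same adapters with its own row of scalings
        delta = sum(s_i * (B @ (A @ h)) for s_i, (A, B) in zip(s, experts))
        h = np.tanh(W @ h + delta)
    return h


frozen = sum(W.size for W in base) + sum(A.size + B.size for layer in adapters for A, B in layer)
trainable = layer_logits.size
print(forward(rng.normal(size=d)).shape, frozen, trainable)
```

In this toy setup `frozen` is 2,560 values against 12 trainable ones; the real ratio depends on the base model, the adapters, and the size of the scalings classifier.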

## X-LoRA source code

Install directly from source:

```bash
pip install git+https://github.com/EricLBuehler/xlora.git -U
```

Further details on installation, packages with source code, API details, and more examples:

[https://github.com/EricLBuehler/xlora](https://github.com/EricLBuehler/xlora)

## Converting and loading a model

Example for model conversion:

```python
import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM  # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

# Only config.hidden_size is used below, to size the X-LoRA config
config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

# Convert the model to X-LoRA
model_created = xlora.add_xlora_to_model(
    model=model,
    xlora_config=xlora.xLoRAConfig(config.hidden_size, xlora_depth=8, device=torch.device("cuda")),
    verbose=True,
    adapters={
        "adapter_1": "./path/to/the/checkpoint_adapter_1/",
        "adapter_2": "./path/to/the/checkpoint_adapter_2/",
        "adapter_n": "./path/to/the/checkpoint_adapter_3/",
    },
)
```
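
X-LoRA uses the model's hidden states to compute its scalings with a small classifier, and `xlora_depth=8` in the call above is assumed here to set the number of layers in that classifier. The NumPy sketch below is a simplified, hypothetical version of such a head, not the library's code (`init_classifier`, `predict_scalings` and `width` are illustrative names): an MLP maps each token's hidden state to one logit per (layer, adapter) pair, and a softmax over adapters turns them into per-token, per-layer scalings.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def init_classifier(hidden_size, n_layers, n_adapters, depth, width, rng):
    """Hypothetical MLP with `depth` linear layers: hidden_size -> ... -> n_layers * n_adapters."""
    dims = [hidden_size] + [width] * (depth - 1) + [n_layers * n_adapters]
    # Each entry is (W, b) with W: (out_dim, in_dim)
    return [(rng.normal(size=(m, n)) / np.sqrt(n), np.zeros(m)) for n, m in zip(dims[:-1], dims[1:])]


def predict_scalings(hidden, params, n_layers, n_adapters):
    """hidden: (seq_len, hidden_size) -> scalings: (seq_len, n_layers, n_adapters)."""
    h = hidden
    for i, (W, b) in enumerate(params):
        h = h @ W.T + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)  # ReLU between layers, none after the output layer
    logits = h.reshape(hidden.shape[0], n_layers, n_adapters)
    return softmax(logits)  # per token and per layer, the adapter weights sum to 1


rng = np.random.default_rng(0)
seq_len, hidden_size, n_layers, n_adapters = 5, 32, 4, 3
params = init_classifier(hidden_size, n_layers, n_adapters, depth=8, width=16, rng=rng)
hidden = rng.normal(size=(seq_len, hidden_size))
scalings = predict_scalings(hidden, params, n_layers, n_adapters)
print(len(params), scalings.shape)
```

Because the scalings are computed per token, different tokens can weight the adapters differently at each layer.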

## Loading a trained X-LoRA model from scratch

```python
import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM  # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

# Arguments: saved X-LoRA model directory, base model, adapter name -> checkpoint path, device
model = xlora.from_pretrained(
    "./path/to/saved/model",
    model,
    {
        "adapter_1": "./path/to/the/checkpoint/",
        "adapter_2": "./path/to/the/checkpoint/",
        "adapter_n": "./path/to/the/checkpoint/",
    },
    "cuda",
)
```

## Loading a pre-trained X-LoRA model

```python
import torch
from xlora.xlora_utils import load_model  # type: ignore

XLoRA_model_name = "lamm-mit/x-lora/X-LoRA"

model, tokenizer = load_model(
    model_name="HuggingFaceH4/zephyr-7b-beta",
    device="cuda:0",
    dtype=torch.bfloat16,
    fine_tune_model_name=XLoRA_model_name,
    adapters={
        "adapter_1": "lamm-mit/x-lora/X-LoRA_adapters/1/",
        "adapter_2": "lamm-mit/x-lora/X-LoRA_adapters/2/",
        "adapter_3": "lamm-mit/x-lora/X-LoRA_adapters/3/",
        "adapter_4": "lamm-mit/x-lora/X-LoRA_adapters/4/",
        "adapter_5": "lamm-mit/x-lora/X-LoRA_adapters/5/",
        "adapter_6": "lamm-mit/x-lora/X-LoRA_adapters/6/",
        "adapter_7": "lamm-mit/x-lora/X-LoRA_adapters/7/",
        "adapter_8": "lamm-mit/x-lora/X-LoRA_adapters/8/",
        "adapter_9": "lamm-mit/x-lora/X-LoRA_adapters/9/",
    },
)
```
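
Because the pre-trained adapters follow a numbered layout, the `adapters` mapping can also be generated instead of written out by hand. `make_adapter_paths` is a hypothetical helper, not part of `xlora`; it simply reproduces the dictionary used in the example above.

```python
def make_adapter_paths(repo_prefix: str, n_adapters: int) -> dict:
    """Hypothetical helper: map "adapter_i" to "<repo_prefix>/i/" for i = 1..n_adapters."""
    prefix = repo_prefix.rstrip("/")
    return {f"adapter_{i}": f"{prefix}/{i}/" for i in range(1, n_adapters + 1)}


adapters = make_adapter_paths("lamm-mit/x-lora/X-LoRA_adapters", 9)
print(adapters["adapter_1"], adapters["adapter_9"])
```

Its return value can be passed directly as the `adapters=` argument of `load_model`.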

Inference:

```python
def generate_response(model, tokenizer,
                      text_input="What is the best biomaterial for superior strength?",
                      num_return_sequences=1,
                      temperature=0.75,
                      max_new_tokens=127,
                      num_beams=1,
                      top_k=50,
                      top_p=0.9,
                      repetition_penalty=1.,
                      eos_token_id=2,
                      add_special_tokens=True,
                      device="cuda:0",  # device the model was loaded on
                      ):
    # Return PyTorch tensors (needed for .shape below) on the model's device
    inputs = tokenizer(text_input, add_special_tokens=add_special_tokens, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_new_tokens=max_new_tokens,
                                 temperature=temperature,
                                 num_beams=num_beams,
                                 top_k=top_k,
                                 top_p=top_p,
                                 num_return_sequences=num_return_sequences,
                                 eos_token_id=eos_token_id,
                                 pad_token_id=eos_token_id,
                                 do_sample=True,
                                 repetition_penalty=repetition_penalty,
                                 )
    # Decode only the newly generated tokens (the prompt is sliced off)
    return tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:].detach().cpu().numpy(),
                                  skip_special_tokens=True)


txt = "What is the best biomaterial for superior strength?"
eos_token = tokenizer.eos_token_id

output_text = generate_response(model, tokenizer, text_input=txt, eos_token_id=eos_token,
                                num_return_sequences=1, repetition_penalty=1.1,
                                top_p=0.9, top_k=512,
                                temperature=0.5,
                                max_new_tokens=256)

print(output_text[0])
```
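
The sampling arguments passed to `generate_response` (`temperature`, `top_k`, `top_p`) filter the next-token distribution before a token is drawn. The sketch below illustrates that filtering on a single logits vector using the textbook formulation of temperature scaling, top-k, and nucleus (top-p) sampling. `filter_and_sample` is a hypothetical name, the exact ordering and edge-case handling inside `transformers`' `generate` may differ, and `repetition_penalty` is not modeled.

```python
import numpy as np


def filter_and_sample(logits, temperature=0.75, top_k=50, top_p=0.9, rng=None):
    """Illustrative temperature, top-k and nucleus (top-p) sampling for one step.

    Not the transformers implementation. Returns the sampled token id and the
    ids that survived filtering.
    """
    if rng is None:
        rng = np.random.default_rng()
    z = np.asarray(logits, dtype=float) / temperature  # temperature < 1 sharpens the distribution
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    keep = np.argsort(probs)[::-1][:top_k]  # top-k: the k most likely token ids
    cum = np.cumsum(probs[keep])
    # top-p: smallest prefix whose cumulative probability reaches top_p
    keep = keep[: int(np.searchsorted(cum, top_p)) + 1]
    p = probs[keep] / probs[keep].sum()  # renormalize over the surviving tokens
    return int(rng.choice(keep, p=p)), keep


token, kept = filter_and_sample([5.0, 1.0, 0.0, -1.0], temperature=1.0, top_p=0.5,
                                rng=np.random.default_rng(0))
print(token, kept.tolist())
```

In the usage line the first token holds about 97% of the probability mass, so the nucleus with `top_p=0.5` keeps only that token and the draw is deterministic.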

## Dataset

See [lamm-mit/x-lora-dataset](https://huggingface.co/datasets/lamm-mit/x-lora-dataset) for the dataset used to train the X-LoRA model. Details on the datasets used to train the original adapters are included in the paper (see reference below).

## Acknowledgements

This work is built on the Hugging Face [PEFT library](https://github.com/huggingface/peft/tree/main/src/peft) and other components in the Hugging Face ecosystem.

## Original paper and citation

Cite this work as:

```bibtex
@article{Buehler_XLoRA_2024,
  title = {X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models with Applications in Protein Mechanics and Design},
  author = {E.L. Buehler and M.J. Buehler},
  journal = {},
  year = {2024},
  volume = {},
  pages = {},
  url = {https://arxiv.org/abs/2402.07148}
}
```