license: apache-2.0

X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models

X-LoRA works by learning scaling values for LoRA adapters. These learned scaling values are used to gate the LoRA experts in a dense fashion. Additionally, all LoRA adapters and the base model are frozen, which allows efficient fine-tuning thanks to a low trainable parameter count.
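The dense gating described above can be illustrated for a single linear layer. This is a simplified, dependency-free sketch, not code from the xlora package. It assumes the scalings come from a softmax over per-adapter gate logits, which are passed in directly here rather than predicted by a trained network. The names `xlora_linear` and `lora_scale` are hypothetical. Every adapter contributes to the output in proportion to its scaling, while the base weight and adapter matrices stay fixed.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]


def xlora_linear(x, base_weight, adapters, gate_logits, lora_scale=1.0):
    """Densely gated layer: y = W x + sum_i s_i * lora_scale * B_i (A_i x).

    base_weight (W) and each adapter pair (A_i, B_i) are treated as frozen.
    The scalings s = softmax(gate_logits) weight every adapter; none is dropped.
    """
    scalings = softmax(gate_logits)
    y = matvec(base_weight, x)
    for s, (a, b) in zip(scalings, adapters):
        delta = matvec(b, matvec(a, x))  # low-rank update: project down with A, up with B
        y = [yi + s * lora_scale * di for yi, di in zip(y, delta)]
    return y, scalings


# Toy example: 2-dim input, identity base weight, two rank-1 adapters.
x = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0]]
adapters = [
    ([[1.0, 0.0]], [[1.0], [0.0]]),  # adapter 1 only touches the first coordinate
    ([[0.0, 1.0]], [[0.0], [1.0]]),  # adapter 2 only touches the second coordinate
]
y, s = xlora_linear(x, W, adapters, gate_logits=[0.0, 0.0])
print(s)  # [0.5, 0.5]
print(y)  # [1.5, 3.0]
```

With equal logits, both adapters are blended at half strength. Pushing one logit much higher moves the layer toward that single expert without ever switching the other one off completely.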

X-LoRA is easily applied to any HuggingFace Transformers model.

Features

  • Effective: Dense gating of experts allows effective mixing
  • Efficient fine-tuning: low trainable parameter count
  • Hierarchical encapsulated strategy: Reuse existing trained models or model sections to address complex tasks that cut across experts, following a bio-inspired strategy
  • Easy-to-use API: add_xlora_to_model, broad compatibility
  • Dynamically mix LoRA adapters: Deep layer-wise combinations of adapters.
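The layer-wise mixing named in the last feature can be pictured as a table of scalings with one row per layer and one column per adapter. The sketch below is hypothetical: `layerwise_scalings` and `dominant_adapter_per_layer` are illustrative names, the per-layer softmax is an assumption, and it covers a single token. Because each layer is normalized independently, different layers can lean on different adapters for the same input.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def layerwise_scalings(logits_per_layer):
    """Normalize each layer's adapter logits independently.

    logits_per_layer[l][i] is a hypothetical gate logit for adapter i in layer l
    (for a single token). Each row of the result sums to 1.
    """
    return [softmax(row) for row in logits_per_layer]


def dominant_adapter_per_layer(scalings):
    """Index of the most heavily weighted adapter in each layer."""
    return [max(range(len(row)), key=row.__getitem__) for row in scalings]


# Three layers, three adapters: each layer prefers a different adapter.
logits = [
    [2.0, 0.0, 0.0],
    [0.0, 3.0, 0.0],
    [0.0, 0.0, 1.0],
]
scalings = layerwise_scalings(logits)
print(dominant_adapter_per_layer(scalings))  # [0, 1, 2]
```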

X-LoRA source code

Installation, source code, API details and more examples:

https://github.com/EricLBuehler/xlora

Converting and loading a model

Example of model conversion:

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

### Convert the model to X-LoRA
model_created = xlora.add_xlora_to_model(
    model=model,
    xlora_config=xlora.xLoRAConfig(config.hidden_size, xlora_depth=8, device=torch.device("cuda")),
    verbose=True,
    adapters={
        "adapter_1": "./path/to/the/checkpoint_adapter_1/",
        "adapter_2": "./path/to/the/checkpoint_adapter_2/",
        "adapter_n": "./path/to/the/checkpoint_adapter_3/",
    },
)
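Because the base model and all adapters are frozen, a quick sanity check after conversion is to compare trainable and total parameter counts. The helper below is a generic sketch; `count_parameters` and `parameter_report` are hypothetical names, not part of xlora. It works with any object whose `.parameters()` yields tensors that expose `.numel()` and `.requires_grad`, such as a `torch.nn.Module`.

```python
def count_parameters(model):
    """Return (trainable, total) parameter counts for a PyTorch-style module."""
    trainable = 0
    total = 0
    for p in model.parameters():
        n = p.numel()  # number of elements in this parameter tensor
        total += n
        if p.requires_grad:  # frozen weights have requires_grad == False
            trainable += n
    return trainable, total


def parameter_report(model):
    """Human-readable summary of the trainable fraction."""
    trainable, total = count_parameters(model)
    pct = 100.0 * trainable / total if total else 0.0
    return f"trainable: {trainable:,} / total: {total:,} ({pct:.4f}%)"


# Usage (assumes model_created behaves like a torch.nn.Module):
# print(parameter_report(model_created))
```

If the conversion froze the base model and adapters as described, you would expect the trainable share to be a small fraction of the total.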

Loading a trained X-LoRA model from scratch

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

model = xlora.from_pretrained(
    "./path/to/saved/model",
    model,
    {
        "adapter_1": "./path/to/the/checkpoint/",
        "adapter_2": "./path/to/the/checkpoint/",
        "adapter_n": "./path/to/the/checkpoint/",
    },
    "cuda",
)

Loading the pre-trained X-LoRA model

import torch
from xlora.xlora_utils import load_model  # type: ignore

XLoRA_model_name = "lamm-mit/x-lora/X-LoRA"

model, tokenizer = load_model(
    model_name="HuggingFaceH4/zephyr-7b-beta",
    device="cuda:0",
    dtype=torch.bfloat16,
    fine_tune_model_name=XLoRA_model_name,
    adapters={
        "adapter_1": "lamm-mit/x-lora/X-LoRA_adapters/1/",
        "adapter_2": "lamm-mit/x-lora/X-LoRA_adapters/2/",
        "adapter_3": "lamm-mit/x-lora/X-LoRA_adapters/3/",
        "adapter_4": "lamm-mit/x-lora/X-LoRA_adapters/4/",
        "adapter_5": "lamm-mit/x-lora/X-LoRA_adapters/5/",
        "adapter_6": "lamm-mit/x-lora/X-LoRA_adapters/6/",
        "adapter_7": "lamm-mit/x-lora/X-LoRA_adapters/7/",
        "adapter_8": "lamm-mit/x-lora/X-LoRA_adapters/8/",
        "adapter_9": "lamm-mit/x-lora/X-LoRA_adapters/9/",
    },
)
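The nine adapter entries above follow a regular pattern, so the same mapping can also be built with a dictionary comprehension. `hub_adapter_map` is a hypothetical helper. It assumes the adapters are numbered consecutively from 1, and it reproduces the literal dictionary above with the same keys, values, and key order.

```python
def hub_adapter_map(repo_prefix, n_adapters):
    """Map "adapter_1".."adapter_n" to "<repo_prefix>/1/".."<repo_prefix>/n/"."""
    return {f"adapter_{i}": f"{repo_prefix}/{i}/" for i in range(1, n_adapters + 1)}


adapters = hub_adapter_map("lamm-mit/x-lora/X-LoRA_adapters", 9)
print(adapters["adapter_1"])  # lamm-mit/x-lora/X-LoRA_adapters/1/
print(len(adapters))          # 9
```

The result can be passed as `adapters=adapters` in the `load_model` call in place of the literal dictionary.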

Inference:

import torch

def generate_response(model, tokenizer,
                      text_input="What is the best biomaterial for superior strength?",
                      num_return_sequences=1,
                      temperature=0.75,
                      max_new_tokens=127,
                      num_beams=1,
                      top_k=50,
                      top_p=0.9,
                      repetition_penalty=1.0,
                      eos_token_id=2,
                      add_special_tokens=True,
                      device="cuda:0",  # should match the device the model was loaded on
                      ):
    # Tokenize the prompt into PyTorch tensors on the same device as the model
    inputs = tokenizer(text_input, add_special_tokens=add_special_tokens, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_new_tokens=max_new_tokens,
                                 temperature=temperature,
                                 num_beams=num_beams,
                                 top_k=top_k,
                                 top_p=top_p,
                                 num_return_sequences=num_return_sequences,
                                 eos_token_id=eos_token_id,
                                 pad_token_id=eos_token_id,  # reuse EOS as the padding token
                                 do_sample=True,
                                 repetition_penalty=repetition_penalty,
                                 )
    # Drop the prompt tokens so that only the newly generated text is decoded
    return tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:].detach().cpu().numpy(), skip_special_tokens=True)

txt = "What is the best biomaterial for superior strength?"  # example prompt
eos_token = tokenizer.eos_token_id

output_text = generate_response(model, tokenizer, text_input=txt, eos_token_id=eos_token,
                                num_return_sequences=1, repetition_penalty=1.1,
                                top_p=0.9, top_k=512,
                                temperature=0.5,
                                max_new_tokens=256)

print(output_text[0])
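The sampling arguments used above (`temperature`, `top_k`, `top_p`) reshape the next-token distribution before a token is drawn. The sketch below shows the standard recipe on a toy list of logits. It is a simplified illustration and not transformers' actual implementation, which applies these steps as logits processors with its own edge-case handling. `sampling_distribution` is a hypothetical name, and the order used here (temperature, then top-k, then top-p) is an assumption of the sketch.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    z = sum(exps)
    return [e / z for e in exps]


def sampling_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    probs = softmax([v / temperature for v in logits])
    # Candidate token indices, most probable first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # keep at most top_k candidates
    # Top-p (nucleus): keep the smallest prefix whose renormalized mass reaches top_p.
    z = sum(probs[i] for i in order)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i] / z
        if mass >= top_p:
            break
    z_kept = sum(probs[i] for i in kept)
    return {i: probs[i] / z_kept for i in kept}


# Same knobs as the call above, on a 4-token toy vocabulary.
dist = sampling_distribution([2.0, 1.0, 0.0, -1.0], temperature=0.5, top_k=512, top_p=0.9)
print(sorted(dist))  # [0, 1]
```

In the call above, `top_k=512` is a loose cap, so `top_p=0.9` is likely to decide how many candidates survive at each step.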

Acknowledgements

This work is built on the Hugging Face PEFT library and other components in the Hugging Face ecosystem.

Original paper and citation

Cite this work as:

@article{NiBuehler_2024,
    title   = {X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models with Applications in Protein Mechanics and Design},
    author  = {E.L. Buehler and M.J. Buehler},
    journal = {},
    year    = {2024},
    volume  = {},
    pages   = {},
    url     = {https://arxiv.org/abs/XXXX.YYYYY}
}