Llama-3-8B-NOLA
Llama-3-8B-NOLA is a variant of meta-llama/Meta-Llama-3-8B fine-tuned on OpenAssistant/oasst1 for only 100 steps. The goal of this experiment was to try out NOLA, a recent parameter-efficient fine-tuning technique, and see how many trainable parameters it requires. Because of limited compute, the results of this experiment may not be satisfactory. Training was done on a single A5000 GPU.
NOLA (Compressing LoRA using Linear Combination of Random Basis)
NOLA is a novel approach for fine-tuning large models such as LLMs and Vision Transformers. Like LoRA, NOLA represents the fine-tuning weight updates with a low-rank decomposition. However, LoRA faces two primary limitations:
- The number of trainable parameters is lower-bounded by that of a rank-one decomposition.
- The extent of reduction is heavily influenced by both the model architecture and the chosen rank.
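To make the first limitation concrete, here is a minimal Python sketch of LoRA's per-matrix parameter count, r x (d_out + d_in). The layer shapes are assumptions taken from the public Llama-3-8B configuration (hidden size 4096, 32 layers, grouped-query attention giving 1024-wide key/value projections), and adapting only `q_proj`, `k_proj` and `v_proj` is a hypothetical choice; this card does not say which modules were adapted.

```python
# Sketch: trainable-parameter count of LoRA adapters.
# LoRA learns delta_W = B @ A with B: (d_out x r) and A: (r x d_in),
# so each adapted matrix adds r * (d_out + d_in) trainable parameters.

# Assumed Llama-3-8B attention shapes (d_out, d_in); hypothetical module selection.
SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (1024, 4096),  # 8 KV heads x 128 head dim (grouped-query attention)
    "v_proj": (1024, 4096),
}
NUM_LAYERS = 32


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds to a single d_out x d_in weight matrix."""
    return rank * (d_out + d_in)


def total_lora_params(rank: int) -> int:
    """Total LoRA parameters when every module in SHAPES is adapted in every layer."""
    per_layer = sum(lora_params(d_out, d_in, rank) for d_out, d_in in SHAPES.values())
    return per_layer * NUM_LAYERS


for rank in (1, 8, 16):
    print(f"rank={rank}: {total_lora_params(rank):,} trainable parameters")
```

Under these assumptions, even rank 1 costs 589,824 trainable parameters, and the only way below that floor is to adapt fewer modules, which is the architecture dependence noted in the second limitation.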
NOLA brings flexibility in parameter count to LoRA. It achieves this by re-parameterizing the low-rank matrices in LoRA as linear combinations of randomly generated matrices (a basis) and optimizing only the linear coefficients. This decouples the number of trainable parameters from both the choice of rank and the network architecture.
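The re-parameterization can be sketched in NumPy as follows. This is a hypothetical illustration of the idea described above, not the NOLA authors' implementation: the class name, the Gaussian basis, the seeding scheme, and the initialization (small random `alpha`, zero `beta`, so the update starts at zero as in LoRA) are all assumptions. The rank-r factors are built as A = sum_i alpha_i * A_i and B = sum_j beta_j * B_j, where the basis matrices A_i and B_j are regenerated from a seed and only the k + l coefficients would be trained.

```python
import numpy as np


def random_basis(seed: int, num_basis: int, shape: tuple) -> np.ndarray:
    """Regenerate a fixed stack of random matrices from a seed (frozen, never stored)."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((num_basis, *shape)).astype(np.float32)


class NOLALinearDelta:
    """Hypothetical sketch of a NOLA-style update delta_W = A @ B for one weight matrix."""

    def __init__(self, d_out: int, d_in: int, rank: int, k: int, l: int, seed: int = 0):
        # Frozen random bases for the two low-rank factors (assumed seeding scheme).
        self.basis_a = random_basis(seed, k, (d_out, rank))      # k matrices of d_out x rank
        self.basis_b = random_basis(seed + 1, l, (rank, d_in))   # l matrices of rank x d_in
        # The only trainable parameters: k + l mixing coefficients.
        rng = np.random.default_rng(seed + 2)
        self.alpha = (0.01 * rng.standard_normal(k)).astype(np.float32)
        self.beta = np.zeros(l, dtype=np.float32)  # zero init => delta_W starts at 0

    def num_trainable(self) -> int:
        return self.alpha.size + self.beta.size

    def delta_w(self) -> np.ndarray:
        # A = sum_i alpha_i * A_i and B = sum_j beta_j * B_j
        a = np.tensordot(self.alpha, self.basis_a, axes=1)  # shape (d_out, rank)
        b = np.tensordot(self.beta, self.basis_b, axes=1)   # shape (rank, d_in)
        return a @ b                                        # shape (d_out, d_in)


delta = NOLALinearDelta(d_out=1024, d_in=1024, rank=8, k=64, l=64, seed=0)
print(delta.num_trainable())   # 128
print(delta.delta_w().shape)   # (1024, 1024)
```

With k = l = 64, the update for a 1024 x 1024 matrix has 128 trainable coefficients regardless of the rank, whereas a rank-8 LoRA adapter on the same matrix would train 8 x (1024 + 1024) = 16,384 parameters. In this sketch, only the seed and the coefficients would need to be saved, since the basis can be regenerated.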
Evaluation Results
***** train metrics *****
  epoch                    = 0.0907
  total_flos               = 4091386GF
  train_loss               = 1.618
  train_runtime            = 2:02:24.94
  train_samples_per_second = 0.109
  train_steps_per_second   = 0.014
***** eval metrics *****
  epoch                    = 0.0907
  eval_loss                = 1.5115
  eval_runtime             = 0:11:33.00
  eval_samples_per_second  = 0.144
  eval_steps_per_second    = 0.144
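The logged rates can be cross-checked with a little arithmetic. This Python sketch multiplies each throughput by its runtime; the perplexity lines assume the losses are mean per-token cross-entropy in nats, which is what the Hugging Face Trainer reports for causal language modeling. The results are only estimates, because the logged rates are rounded.

```python
import math


def to_seconds(hms: str) -> float:
    """Convert an 'H:MM:SS.ff' runtime string, as logged above, to seconds."""
    hours, minutes, seconds = hms.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


train_runtime = to_seconds("2:02:24.94")  # 7344.94 s
eval_runtime = to_seconds("0:11:33.00")   # 693 s

train_samples = 0.109 * train_runtime  # roughly 800 training samples
train_steps = 0.014 * train_runtime    # roughly 100 steps (rates are rounded)
eval_samples = 0.144 * eval_runtime    # roughly 100 evaluation samples

print(f"train: ~{train_samples:.0f} samples over ~{train_steps:.0f} steps "
      f"(~{train_samples / train_steps:.1f} samples per step)")
print(f"eval: ~{eval_samples:.0f} samples")
# Perplexity = exp(mean token cross-entropy), assuming losses are in nats.
print(f"train perplexity ~ {math.exp(1.618):.2f}, eval perplexity ~ {math.exp(1.5115):.2f}")
```

About 800 training samples over about 100 steps suggests roughly 8 samples per optimizer step, and equal eval samples and steps per second suggest an evaluation batch size of 1; neither setting is stated in this card.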
Usage
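!pip install -qU transformers accelerate bitsandbytes
import torch
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig

model_id = "QueryloopAI/Llama-3-8B-NOLA"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in 4-bit with bitsandbytes, computing in bfloat16
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    model_kwargs={"torch_dtype": torch.bfloat16, "quantization_config": quantization_config},
    device_map="auto",  # place the quantized model on the available GPU(s)
)

prompt = "What is Machine Learning?"
# Sample up to 256 new tokens using temperature, top-k and top-p sampling
outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])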