# Source: KB-VQA / my_model/LLAMA2/LLAMA2_model.py (Hugging Face Hub, author: m7mdal7aj)
# Commit 045e961 (verified): "Update my_model/LLAMA2/LLAMA2_model.py" - 7.57 kB
import warnings
from typing import List, Optional, Tuple

import accelerate  # only for using on GPU
import bitsandbytes  # only for using on GPU
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from my_model.config import LLAMA2_config as config
# Silence FutureWarning messages attributed to the transformers package only;
# all other warning categories and modules keep their default handling.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="transformers",
)
class Llama2ModelManager:
    """
    Manages loading and configuring the LLaMA-2 model and tokenizer.

    All settings are read from ``my_model.config.LLAMA2_config`` when the
    manager is instantiated.

    Attributes:
        device (str): Device setting read from the config ('cuda' or 'cpu').
            The loaders in this class place the model with ``device_map="auto"``.
        model_name (str): Name or path of the pre-trained model.
        tokenizer_name (str): Name or path of the tokenizer.
        quantization (str or None): Specifies the quantization level ('4bit', '8bit', or None).
        from_saved (bool): Flag to load the model from a saved path.
        model_path (str or None): Path to the saved model if `from_saved` is True.
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        add_eos_token (bool): Whether to add an EOS token to the tokenizer.
        access_token (str): Access token for Hugging Face Hub.
        model (AutoModelForCausalLM or None): Loaded model, initially None.
        tokenizer (AutoTokenizer or None): Loaded tokenizer, initially None.
    """

    def __init__(self) -> None:
        """
        Initializes the Llama2ModelManager class with configuration settings.
        """
        self.device: str = config.DEVICE
        self.model_name: str = config.MODEL_NAME
        self.tokenizer_name: str = config.TOKENIZER_NAME
        self.quantization: Optional[str] = config.QUANTIZATION
        self.from_saved: bool = config.FROM_SAVED
        self.model_path: Optional[str] = config.MODEL_PATH
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.access_token: str = config.ACCESS_TOKEN
        self.model: Optional[AutoModelForCausalLM] = None
        # Defined up front so add_special_tokens() can test it for None even
        # when load_tokenizer() has not been called yet.
        self.tokenizer: Optional[AutoTokenizer] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: Configuration for the '4bit' or '8bit'
            setting; None for any other value of `self.quantization`.
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        if self.quantization == '8bit':
            # NF4 and double quantization are 4-bit-only options; there are no
            # `bnb_8bit_*` parameters, so only the 8-bit switch is set here.
            return BitsAndBytesConfig(load_in_8bit=True)
        # Any other setting means "no quantization config".
        return None

    def load_model(self) -> AutoModelForCausalLM:
        """
        Loads the LLaMA-2 model based on the specified configuration.

        If the model is already loaded, returns the existing model. When
        `from_saved` is True the model is loaded from `model_path`; otherwise it
        is loaded from `model_name` with the configured quantization.

        Returns:
            AutoModelForCausalLM: Loaded LLaMA-2 model.
        """
        if self.model is not None:
            print("Model is already loaded.")
            return self.model

        if self.from_saved:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto")
        else:
            bnb_config = None if self.quantization is None else self.create_bnb_config()
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                quantization_config=bnb_config,
                torch_dtype=torch.float16,
                token=self.access_token,
            )

        if self.model is not None:
            print(f"LLAMA2 Model loaded successfully in {self.quantization} quantization.")
        else:
            print("LLAMA2 Model failed to load.")
        return self.model

    def load_tokenizer(self) -> AutoTokenizer:
        """
        Loads the tokenizer for the LLaMA-2 model with the specified configuration.

        The loaded tokenizer is also stored on `self.tokenizer`.

        Returns:
            AutoTokenizer: Loaded tokenizer for LLaMA-2 model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            use_fast=self.use_fast,
            token=self.access_token,
            trust_remote_code=self.trust_remote,
            add_eos_token=self.add_eos_token,
        )

        if self.tokenizer is not None:
            print("LLAMA2 Tokenizer loaded successfully.")
        else:
            print("LLAMA2 Tokenizer failed to load.")
        return self.tokenizer

    def load_model_and_tokenizer(
        self, for_fine_tuning: bool
    ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """
        Loads the LLaMA-2 model and tokenizer, and optionally adds special tokens for fine-tuning.

        Args:
            for_fine_tuning (bool): Whether to prepare the model and tokenizer for fine-tuning.

        Returns:
            Tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
        """
        # The tokenizer is loaded first, then the model, in both modes.
        self.tokenizer = self.load_tokenizer()
        self.model = self.load_model()
        if for_fine_tuning:
            self.add_special_tokens()
        return self.model, self.tokenizer

    def add_special_tokens(self, tokens: Optional[List[str]] = None) -> None:
        """
        Adds special tokens to the tokenizer and updates the model's token embeddings if the model is loaded.

        Also registers '<pad>' as the padding token and sets right-side padding.

        Args:
            tokens (Optional[List[str]]): Special tokens to add. Defaults to a predefined set.

        Returns:
            None
        """
        if self.tokenizer is None:
            print("Tokenizer is not loaded. Cannot add special tokens.")
            return

        if tokens is None:
            tokens = ['[CAP]', '[/CAP]', '[QES]', '[/QES]', '[OBJ]', '[/OBJ]']

        # Update the tokenizer with new tokens
        print(f"Original vocabulary size: {len(self.tokenizer)}")
        print(f"Adding the following tokens: {tokens}")
        self.tokenizer.add_tokens(tokens, special_tokens=True)
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print(f"Adding Padding Token {self.tokenizer.pad_token}")
        self.tokenizer.padding_side = "right"
        print(f'Padding side: {self.tokenizer.padding_side}')

        # Resize the model token embeddings if the model is loaded, so the
        # embedding matrix matches the enlarged vocabulary.
        if self.model is not None:
            self.model.resize_token_embeddings(len(self.tokenizer))
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        print(f'Updated Vocabulary Size: {len(self.tokenizer)}')
        print(f'Padding Token: {self.tokenizer.pad_token}')
        print(f'Special Tokens: {self.tokenizer.added_tokens_decoder}')
if __name__ == "__main__":
    # Demo entry point: load the model and tokenizer, then add the designed special tokens.
    LLAMA2_manager = Llama2ModelManager()
    LLAMA2_model = LLAMA2_manager.load_model()  # First time loading the model
    LLAMA2_tokenizer = LLAMA2_manager.load_tokenizer()
    # add_special_tokens() operates on the manager's own model and tokenizer; its
    # only parameter is an optional token list, so the loaded objects are not passed.
    LLAMA2_manager.add_special_tokens()