Upload 11 files
Browse files- .gitattributes +2 -0
- my_model/LLAMA2/LLAMA2_config.py +15 -0
- my_model/LLAMA2/LLAMA2_model.py +173 -0
- my_model/extract_objects.py +45 -0
- my_model/fine_tuner/fine_tuner.py +347 -0
- my_model/fine_tuner/fine_tuning_config.py +114 -0
- my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv +3 -0
- my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv +3 -0
- my_model/fine_tuner/fine_tuning_data/read_me.txt +8 -0
- my_model/fine_tuner/fine_tuning_data_handler.py +182 -0
- my_model/object_detection.py +259 -0
- my_model/utilities.py +278 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv filter=lfs diff=lfs merge=lfs -text
|
my_model/LLAMA2/LLAMA2_config.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration parameters for LLaMA-2 model
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
|
5 |
+
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
|
6 |
+
TOKENIZER_NAME = "meta-llama/Llama-2-7b-chat-hf"
|
7 |
+
QUANTIZATION = '4bit' # Options: '4bit', '8bit', or None
|
8 |
+
FROM_SAVED = False
|
9 |
+
MODEL_PATH = None
|
10 |
+
TRUST_REMOTE = False
|
11 |
+
USE_FAST = True
|
12 |
+
ADD_EOS_TOKEN = True
|
13 |
+
# ACCESS_TOKEN = "xx" # My HF Read-only Token, to be added here if needed
|
14 |
+
huggingface_token = os.getenv('HUGGINGFACE_TOKEN') # for use as a secret on hf space
|
15 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
my_model/LLAMA2/LLAMA2_model.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
3 |
+
from typing import Optional
|
4 |
+
import bitsandbytes # only for using on GPU
|
5 |
+
import accelerate # only for using on GPU
|
6 |
+
from my_model.LLAMA2 import LLAMA2_config as config # Importing LLAMA2 configuration file
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
# Suppress only FutureWarning from transformers
|
10 |
+
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
|
11 |
+
|
12 |
+
|
13 |
+
class Llama2ModelManager:
|
14 |
+
"""
|
15 |
+
Manages loading and configuring the LLaMA-2 model and tokenizer.
|
16 |
+
|
17 |
+
Attributes:
|
18 |
+
device (str): Device to use for the model ('cuda' or 'cpu').
|
19 |
+
model_name (str): Name or path of the pre-trained model.
|
20 |
+
tokenizer_name (str): Name or path of the tokenizer.
|
21 |
+
quantization (str): Specifies the quantization level ('4bit', '8bit', or None).
|
22 |
+
from_saved (bool): Flag to load the model from a saved path.
|
23 |
+
model_path (str or None): Path to the saved model if `from_saved` is True.
|
24 |
+
trust_remote (bool): Whether to trust remote code when loading the tokenizer.
|
25 |
+
use_fast (bool): Whether to use the fast version of the tokenizer.
|
26 |
+
add_eos_token (bool): Whether to add an EOS token to the tokenizer.
|
27 |
+
access_token (str): Access token for Hugging Face Hub.
|
28 |
+
model (AutoModelForCausalLM or None): Loaded model, initially None.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self) -> None:
|
32 |
+
"""
|
33 |
+
Initializes the Llama2ModelManager class with configuration settings.
|
34 |
+
"""
|
35 |
+
self.device: str = config.DEVICE
|
36 |
+
self.model_name: str = config.MODEL_NAME
|
37 |
+
self.tokenizer_name: str = config.TOKENIZER_NAME
|
38 |
+
self.quantization: str = config.QUANTIZATION
|
39 |
+
self.from_saved: bool = config.FROM_SAVED
|
40 |
+
self.model_path: Optional[str] = config.MODEL_PATH
|
41 |
+
self.trust_remote: bool = config.TRUST_REMOTE
|
42 |
+
self.use_fast: bool = config.USE_FAST
|
43 |
+
self.add_eos_token: bool = config.ADD_EOS_TOKEN
|
44 |
+
self.access_token: str = config.ACCESS_TOKEN
|
45 |
+
self.model: Optional[AutoModelForCausalLM] = None
|
46 |
+
|
47 |
+
def create_bnb_config(self) -> BitsAndBytesConfig:
|
48 |
+
"""
|
49 |
+
Creates a BitsAndBytes configuration based on the quantization setting.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
BitsAndBytesConfig: Configuration for BitsAndBytes optimized model.
|
53 |
+
"""
|
54 |
+
if self.quantization == '4bit':
|
55 |
+
return BitsAndBytesConfig(
|
56 |
+
load_in_4bit=True,
|
57 |
+
bnb_4bit_use_double_quant=True,
|
58 |
+
bnb_4bit_quant_type="nf4",
|
59 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
60 |
+
)
|
61 |
+
elif self.quantization == '8bit':
|
62 |
+
return BitsAndBytesConfig(
|
63 |
+
load_in_8bit=True,
|
64 |
+
bnb_8bit_use_double_quant=True,
|
65 |
+
bnb_8bit_quant_type="nf4",
|
66 |
+
bnb_8bit_compute_dtype=torch.bfloat16
|
67 |
+
)
|
68 |
+
|
69 |
+
def load_model(self) -> AutoModelForCausalLM:
|
70 |
+
"""
|
71 |
+
Loads the LLaMA-2 model based on the specified configuration. If the model is already loaded, returns the existing model.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
AutoModelForCausalLM: Loaded LLaMA-2 model.
|
75 |
+
"""
|
76 |
+
if self.model is not None:
|
77 |
+
print("Model is already loaded.")
|
78 |
+
return self.model
|
79 |
+
|
80 |
+
if self.from_saved:
|
81 |
+
self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto")
|
82 |
+
else:
|
83 |
+
bnb_config = None if self.quantization is None else self.create_bnb_config()
|
84 |
+
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto",
|
85 |
+
quantization_config=bnb_config,
|
86 |
+
torch_dtype=torch.float16,
|
87 |
+
token=self.access_token)
|
88 |
+
|
89 |
+
if self.model is not None:
|
90 |
+
print(f"LLAMA2 Model loaded successfully in {self.quantization} quantization.")
|
91 |
+
else:
|
92 |
+
print("LLAMA2 Model failed to load.")
|
93 |
+
return self.model
|
94 |
+
|
95 |
+
def load_tokenizer(self) -> AutoTokenizer:
|
96 |
+
"""
|
97 |
+
Loads the tokenizer for the LLaMA-2 model with the specified configuration.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
AutoTokenizer: Loaded tokenizer for LLaMA-2 model.
|
101 |
+
"""
|
102 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=self.use_fast,
|
103 |
+
token=self.access_token,
|
104 |
+
trust_remote_code=self.trust_remote,
|
105 |
+
add_eos_token=self.add_eos_token)
|
106 |
+
|
107 |
+
if self.tokenizer is not None:
|
108 |
+
print(f"LLAMA2 Tokenizer loaded successfully.")
|
109 |
+
else:
|
110 |
+
print("LLAMA2 Tokenizer failed to load.")
|
111 |
+
|
112 |
+
return self.tokenizer
|
113 |
+
|
114 |
+
def load_model_and_tokenizer(self, for_fine_tuning):
|
115 |
+
"""
|
116 |
+
Loads LLAMa2 model and tokenizer in one method and adds special tokens if the purpose if fine tuning.
|
117 |
+
:param for_fine_tuning: YES(True) / NO (False)
|
118 |
+
:return: LLAMA2 Model and Tokenizer
|
119 |
+
"""
|
120 |
+
if for_fine_tuning:
|
121 |
+
self.tokenizer = self.load_tokenizer()
|
122 |
+
self.model = self.load_model()
|
123 |
+
self.add_special_tokens()
|
124 |
+
else:
|
125 |
+
self.tokenizer = self.load_tokenizer()
|
126 |
+
self.model = self.load_model()
|
127 |
+
|
128 |
+
return self.model, self.tokenizer
|
129 |
+
|
130 |
+
|
131 |
+
def add_special_tokens(self, tokens: Optional[list[str]] = None) -> None:
|
132 |
+
"""
|
133 |
+
Adds special tokens to the tokenizer and updates the model's token embeddings if the model is loaded,
|
134 |
+
only if the tokenizer is loaded.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
tokens (list of str, optional): Special tokens to add. Defaults to a predefined set.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
None
|
141 |
+
"""
|
142 |
+
if self.tokenizer is None:
|
143 |
+
print("Tokenizer is not loaded. Cannot add special tokens.")
|
144 |
+
return
|
145 |
+
|
146 |
+
if tokens is None:
|
147 |
+
tokens = ['[CAP]', '[/CAP]', '[QES]', '[/QES]', '[OBJ]', '[/OBJ]']
|
148 |
+
|
149 |
+
# Update the tokenizer with new tokens
|
150 |
+
print(f"Original vocabulary size: {len(self.tokenizer)}")
|
151 |
+
print(f"Adding the following tokens: {tokens}")
|
152 |
+
self.tokenizer.add_tokens(tokens, special_tokens=True)
|
153 |
+
self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
|
154 |
+
print(f"Adding Padding Token {self.tokenizer.pad_token}")
|
155 |
+
self.tokenizer.padding_side = "right"
|
156 |
+
print(f'Padding side: {self.tokenizer.padding_side}')
|
157 |
+
|
158 |
+
# Resize the model token embeddings if the model is loaded
|
159 |
+
if self.model is not None:
|
160 |
+
self.model.resize_token_embeddings(len(self.tokenizer))
|
161 |
+
self.model.config.pad_token_id = self.tokenizer.pad_token_id
|
162 |
+
|
163 |
+
print(f'Updated Vocabulary Size: {len(self.tokenizer)}')
|
164 |
+
print(f'Padding Token: {self.tokenizer.pad_token}')
|
165 |
+
print(f'Special Tokens: {self.tokenizer.added_tokens_decoder}')
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
pass
|
170 |
+
LLAMA2_manager = Llama2ModelManager()
|
171 |
+
LLAMA2_model = LLAMA2_manager.load_model() # First time loading the model
|
172 |
+
LLAMA2_tokenizer = LLAMA2_manager.load_tokenizer()
|
173 |
+
LLAMA2_manager.add_special_tokens(LLAMA2_model, LLAMA2_tokenizer)
|
my_model/extract_objects.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from object_detection import ObjectDetector
|
2 |
+
import os
|
3 |
+
|
4 |
+
def detect_objects_for_image(image_name, detector):
|
5 |
+
|
6 |
+
if os.path.exists(image_path):
|
7 |
+
image = detector.process_image(image_path)
|
8 |
+
detected_objects_str, _ = detector.detect_objects(image)
|
9 |
+
return detected_objects_str
|
10 |
+
else:
|
11 |
+
return "Image not found"
|
12 |
+
|
13 |
+
def add_detected_objects_to_dataframe(df, image_directory, detector):
|
14 |
+
"""
|
15 |
+
Adds a column to the DataFrame with detected objects for each image specified in the 'image_name' column.
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
df (pd.DataFrame): DataFrame containing a column 'image_name' with image filenames.
|
19 |
+
image_directory (str): Path to the directory containing images.
|
20 |
+
detector (ObjectDetector): An instance of the ObjectDetector class.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
pd.DataFrame: The original DataFrame with an additional column 'detected_objects'.
|
24 |
+
"""
|
25 |
+
|
26 |
+
# Ensure 'image_name' column exists in the DataFrame
|
27 |
+
if 'image_name' not in df.columns:
|
28 |
+
raise ValueError("DataFrame must contain an 'image_name' column.")
|
29 |
+
|
30 |
+
image_path = os.path.join(image_directory, image_name)
|
31 |
+
|
32 |
+
# Function to detect objects for a given image filename
|
33 |
+
|
34 |
+
|
35 |
+
# Apply the function to each row in the DataFrame
|
36 |
+
df['detected_objects'] = df['image_name'].apply(detect_objects_for_image)
|
37 |
+
|
38 |
+
return df
|
39 |
+
|
40 |
+
# Example usage (assuming the function will be used in a context where 'detector' is defined and configured):
|
41 |
+
# df_images = pd.DataFrame({"image_name": ["image1.jpg", "image2.jpg", ...]})
|
42 |
+
# image_directory = "path/to/image_directory"
|
43 |
+
# updated_df = add_detected_objects_to_dataframe(df_images, image_directory, detector)
|
44 |
+
# updated_df.head()
|
45 |
+
|
my_model/fine_tuner/fine_tuner.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Main Fine-Tuning Script for meta-llama/Llama-2-7b-chat-hf
|
3 |
+
|
4 |
+
# This script is the central executable for fine-tuning large language models, specifically designed for the LLAMA2
|
5 |
+
# model.
|
6 |
+
# It encompasses the entire process of fine-tuning, starting from data preparation to the final model training.
|
7 |
+
# The script leverages the 'FinetuningDataHandler' class for data loading, inspection, preparation, and splitting.
|
8 |
+
# This ensures that the dataset is correctly processed and prepared for effective training.
|
9 |
+
|
10 |
+
# The fine-tuning process is managed by the Finetuner class, which handles the training of the model using specific
|
11 |
+
# training arguments and datasets. Advanced configurations for Quantized Low-Rank Adaptation (QLoRA) and Parameter
|
12 |
+
# Efficient Fine-Tuning (PEFT) are utilized to optimize the training process on limited hardware resources.
|
13 |
+
|
14 |
+
# The script is designed to be executed as a standalone process, providing an end-to-end solution for fine-tuning
|
15 |
+
# LLMs. It is a part of a larger project aimed at optimizing the performance of language model to adapt to
|
16 |
+
# OK-VQA dataset.
|
17 |
+
|
18 |
+
# Ensure all dependencies are installed and the required files are in place before running this script.
|
19 |
+
# The configurations for the fine-tuning process are defined in the 'fine_tuning_config.py' file.
|
20 |
+
|
21 |
+
# ---------- Please run this file for the full fine-tuning process to start ----------#
|
22 |
+
# ---------- Please ensure this is run on a GPU ----------#
|
23 |
+
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, TRANSFORMERS_CACHE
|
27 |
+
from trl import SFTTrainer
|
28 |
+
from datasets import Dataset, load_dataset
|
29 |
+
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
|
30 |
+
import fine_tuning_config as config
|
31 |
+
from typing import List
|
32 |
+
import bitsandbytes # only on GPU
|
33 |
+
import gc
|
34 |
+
import os
|
35 |
+
import shutil
|
36 |
+
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
|
37 |
+
from fine_tuning_data_handler import FinetuningDataHandler
|
38 |
+
|
39 |
+
|
40 |
+
class QLoraConfig:
|
41 |
+
"""
|
42 |
+
Configures QLoRA (Quantized Low-Rank Adaptation) parameters for efficient model fine-tuning.
|
43 |
+
LoRA allows adapting large language models with a minimal number of trainable parameters.
|
44 |
+
|
45 |
+
Attributes:
|
46 |
+
lora_config (LoraConfig): Configuration object for LoRA parameters.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self) -> None:
|
50 |
+
"""
|
51 |
+
Initializes QLoraConfig with specific LoRA parameters.
|
52 |
+
|
53 |
+
"""
|
54 |
+
# please refer to config file 'fine_tuning_config.py' for QLORA arguments description.
|
55 |
+
self.lora_config = LoraConfig(
|
56 |
+
lora_alpha=config.LORA_ALPHA,
|
57 |
+
lora_dropout=config.LORA_DROPOUT,
|
58 |
+
r=config.LORA_R,
|
59 |
+
bias="none", # bias is already accounted for in LLAMA2 pre-trained model layers.
|
60 |
+
task_type="CAUSAL_LM",
|
61 |
+
target_modules=['up_proj', 'down_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj'] # modules for fine-tuning.
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
class Finetuner:
|
66 |
+
"""
|
67 |
+
The Finetuner class manages the fine-tuning process of a pre-trained language model using specific
|
68 |
+
training arguments and datasets. It is designed to adapt a pre-trained model on a specific dataset
|
69 |
+
to enhance its performance on similar data.
|
70 |
+
|
71 |
+
This class not only facilitates the fine-tuning of LLAMA2 but also includes advanced
|
72 |
+
resource management capabilities. It provides methods for deleting model and trainer objects,
|
73 |
+
clearing GPU memory, and cleaning up Hugging Face's Transformers cache. These functionalities
|
74 |
+
make the Finetuner class especially useful in environments with limited computational resources
|
75 |
+
or when managing multiple models or training sessions.
|
76 |
+
|
77 |
+
Additionally, the class supports configurations for Quantized Low-Rank Adaptation (QLoRA)
|
78 |
+
to fine-tune models with minimal trainable parameters, and Parameter Efficient Fine-Tuning (PEFT)
|
79 |
+
for training efficiency on limited hardware.
|
80 |
+
|
81 |
+
Attributes:
|
82 |
+
base_model (AutoModelForCausalLM): The pre-trained language model to be fine-tuned.
|
83 |
+
tokenizer (AutoTokenizer): The tokenizer associated with the model.
|
84 |
+
train_dataset (Dataset): The dataset used for training.
|
85 |
+
eval_dataset (Dataset): The dataset used for evaluation.
|
86 |
+
training_arguments (TrainingArguments): Configuration for training the model.
|
87 |
+
|
88 |
+
Key Methods:
|
89 |
+
- load_LLAMA2_for_finetuning: Loads the LLAMA2 model and tokenizer for fine-tuning.
|
90 |
+
- train: Trains the model using PEFT configuration.
|
91 |
+
- delete_model: Deletes a specified model attribute.
|
92 |
+
- delete_trainer: Deletes a specified trainer object.
|
93 |
+
- clear_training_resources: Clears GPU memory.
|
94 |
+
- clear_cache_and_collect_garbage: Clears Transformers cache and performs garbage collection.
|
95 |
+
- find_all_linear_names: Identifies linear layer names suitable for LoRA application.
|
96 |
+
- print_trainable_parameters: Prints the number of trainable parameters in the model.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, train_dataset: Dataset, eval_dataset: Dataset) -> None:
|
100 |
+
"""
|
101 |
+
Initializes the Finetuner class with the model, tokenizer, and datasets.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
model (AutoModelForCausalLM): The pre-trained language model.
|
105 |
+
tokenizer (AutoTokenizer): The tokenizer for the model.
|
106 |
+
train_dataset (Dataset): The dataset for training the model.
|
107 |
+
eval_dataset (Dataset): The dataset for evaluating the model.
|
108 |
+
"""
|
109 |
+
|
110 |
+
self.base_model, self.tokenizer = self.load_LLAMA2_for_finetuning()
|
111 |
+
self.merged_model = None
|
112 |
+
self.train_dataset = train_dataset
|
113 |
+
self.eval_dataset = eval_dataset
|
114 |
+
# please refer to config file 'fine_tuning_config.py' for training arguments description.
|
115 |
+
self.training_arguments = TrainingArguments(
|
116 |
+
output_dir=config.OUTPUT_DIR,
|
117 |
+
num_train_epochs=config.NUM_TRAIN_EPOCHS,
|
118 |
+
per_device_train_batch_size=config.PER_DEVICE_TRAIN_BATCH_SIZE,
|
119 |
+
per_device_eval_batch_size=config.PER_DEVICE_EVAL_BATCH_SIZE,
|
120 |
+
gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
|
121 |
+
fp16=config.FP16,
|
122 |
+
bf16=config.BF16,
|
123 |
+
evaluation_strategy=config.Evaluation_STRATEGY,
|
124 |
+
eval_steps=config.EVALUATION_STEPS,
|
125 |
+
max_grad_norm=config.MAX_GRAD_NORM,
|
126 |
+
learning_rate=config.LEARNING_RATE,
|
127 |
+
weight_decay=config.WEIGHT_DECAY,
|
128 |
+
optim=config.OPTIM,
|
129 |
+
lr_scheduler_type=config.LR_SCHEDULER_TYPE,
|
130 |
+
max_steps=config.MAX_STEPS,
|
131 |
+
warmup_ratio=config.WARMUP_RATIO,
|
132 |
+
group_by_length=config.GROUP_BY_LENGTH,
|
133 |
+
save_steps=config.SAVE_STEPS,
|
134 |
+
logging_steps=config.LOGGING_STEPS,
|
135 |
+
report_to="tensorboard"
|
136 |
+
)
|
137 |
+
|
138 |
+
def load_LLAMA2_for_finetuning(self):
|
139 |
+
"""
|
140 |
+
Loads the LLAMA2 model and tokenizer, specifically configured for fine-tuning.
|
141 |
+
This method ensures the model is ready to be adapted to a specific task or dataset.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
|
145 |
+
"""
|
146 |
+
|
147 |
+
llm_manager = Llama2ModelManager()
|
148 |
+
base_model, tokenizer = llm_manager.load_model_and_tokenizer(for_fine_tuning=True)
|
149 |
+
|
150 |
+
return base_model, tokenizer
|
151 |
+
|
152 |
+
def find_all_linear_names(self) -> List[str]:
|
153 |
+
"""
|
154 |
+
Identifies all linear layer names in the model that are suitable for applying LoRA.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
List[str]: A list of linear layer names.
|
158 |
+
"""
|
159 |
+
cls = bitsandbytes.nn.Linear4bit
|
160 |
+
lora_module_names = set()
|
161 |
+
for name, module in self.base_model.named_modules():
|
162 |
+
if isinstance(module, cls):
|
163 |
+
names = name.split('.')
|
164 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
165 |
+
|
166 |
+
# We dont want to train these two modules to avoid computational overhead.
|
167 |
+
lora_module_names -= {'lm_head', 'gate_proj'}
|
168 |
+
return list(lora_module_names)
|
169 |
+
|
170 |
+
def print_trainable_parameters(self, use_4bit: bool = False) -> None:
|
171 |
+
"""
|
172 |
+
Calculates and prints the number of trainable parameters in the model.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
use_4bit (bool): If true, calculates the parameter count considering 4-bit quantization.
|
176 |
+
"""
|
177 |
+
trainable_params = sum(p.numel() for p in self.base_model.parameters() if p.requires_grad)
|
178 |
+
if use_4bit:
|
179 |
+
trainable_params /= 2
|
180 |
+
|
181 |
+
total_params = sum(p.numel() for p in self.base_model.parameters())
|
182 |
+
print(f"All Parameters: {total_params:,d} || Trainable Parameters: {trainable_params:,d} "
|
183 |
+
f"|| Trainable Parameters %: {100 * trainable_params / total_params:.2f}%")
|
184 |
+
|
185 |
+
def train(self, peft_config: LoraConfig) -> None:
|
186 |
+
"""
|
187 |
+
Trains the model using the specified PEFT (Progressive Effort Fine-Tuning) configuration.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
peft_config (LoraConfig): Configuration for the PEFT training process.
|
191 |
+
"""
|
192 |
+
self.base_model.config.use_cache = False
|
193 |
+
# Set the pretraining_tp flag to 1 to enable the use of LoRA (Low-Rank Adapters) layers.
|
194 |
+
self.base_model.config.pretraining_tp = 1
|
195 |
+
# Prepare the model for k-bit training by quantizing the weights to 4 bits using bitsandbytes.
|
196 |
+
self.base_model = prepare_model_for_kbit_training(self.base_model)
|
197 |
+
self.trainer = SFTTrainer(
|
198 |
+
model=self.base_model,
|
199 |
+
train_dataset=self.train_dataset,
|
200 |
+
eval_dataset=self.eval_dataset,
|
201 |
+
peft_config=peft_config,
|
202 |
+
dataset_text_field='text',
|
203 |
+
max_seq_length=config.MAX_TOKEN_COUNT,
|
204 |
+
tokenizer=self.tokenizer,
|
205 |
+
args=self.training_arguments,
|
206 |
+
packing=config.PACKING
|
207 |
+
)
|
208 |
+
self.trainer.train()
|
209 |
+
|
210 |
+
def save_model(self):
|
211 |
+
|
212 |
+
"""
|
213 |
+
Saves the fine-tuned model to the specified directory.
|
214 |
+
|
215 |
+
This method saves the model weights and configuration of the fine-tuned model.
|
216 |
+
The save directory and filename are determined by the configuration provided in
|
217 |
+
the 'fine_tuning_config.py' file. It is useful for persisting the fine-tuned model
|
218 |
+
for later use or evaluation.
|
219 |
+
|
220 |
+
The saved model can be easily loaded using Hugging Face's model loading utilities.
|
221 |
+
"""
|
222 |
+
|
223 |
+
self.fine_tuned_adapter_name = config.ADAPTER_SAVE_NAME
|
224 |
+
self.trainer.model.save_pretrained(self.fine_tuned_adapter_name)
|
225 |
+
|
226 |
+
def merge_weights(self):
|
227 |
+
"""
|
228 |
+
Merges the weights of the fine-tuned adapter with the base model.
|
229 |
+
|
230 |
+
This method integrates the fine-tuned adapter weights into the base model,
|
231 |
+
resulting in a single consolidated model. The merged model can then be used
|
232 |
+
for inference or further training.
|
233 |
+
|
234 |
+
After merging, the weights of the adapter are no longer separate from the
|
235 |
+
base model, enabling more efficient storage and deployment. The merged model
|
236 |
+
is stored in the 'self.merged_model' attribute of the Finetuner class.
|
237 |
+
"""
|
238 |
+
|
239 |
+
self.merged_model = PeftModel.from_pretrained(self.base_model, self.fine_tuned_adapter_name)
|
240 |
+
self.merged_model = self.merged_model.merge_and_unload()
|
241 |
+
|
242 |
+
def delete_model(self, model_name: str):
|
243 |
+
"""
|
244 |
+
Deletes a specified model attribute.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
model_name (str): The name of the model attribute to delete.
|
248 |
+
"""
|
249 |
+
try:
|
250 |
+
if hasattr(self, model_name) and getattr(self, model_name) is not None:
|
251 |
+
delattr(self, model_name)
|
252 |
+
print(f"Model '{model_name}' has been deleted.")
|
253 |
+
else:
|
254 |
+
print(f"Warning: Model '{model_name}' has already been cleared or does not exist.")
|
255 |
+
except Exception as e:
|
256 |
+
print(f"Error occurred while deleting model '{model_name}': {str(e)}")
|
257 |
+
|
258 |
+
def delete_trainer(self, trainer_name: str):
|
259 |
+
"""
|
260 |
+
Deletes a specified trainer object.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
trainer_name (str): The name of the trainer object to delete.
|
264 |
+
"""
|
265 |
+
try:
|
266 |
+
if hasattr(self, trainer_name) and getattr(self, trainer_name) is not None:
|
267 |
+
delattr(self, trainer_name)
|
268 |
+
print(f"Trainer object '{trainer_name}' has been deleted.")
|
269 |
+
else:
|
270 |
+
print(f"Warning: Trainer object '{trainer_name}' has already been cleared or does not exist.")
|
271 |
+
except Exception as e:
|
272 |
+
print(f"Error occurred while deleting trainer object '{trainer_name}': {str(e)}")
|
273 |
+
|
274 |
+
def clear_training_resources(self):
|
275 |
+
"""
|
276 |
+
Clears GPU memory.
|
277 |
+
"""
|
278 |
+
try:
|
279 |
+
if torch.cuda.is_available():
|
280 |
+
torch.cuda.empty_cache()
|
281 |
+
print("GPU memory has been cleared.")
|
282 |
+
except Exception as e:
|
283 |
+
print(f"Error occurred while clearing GPU memory: {str(e)}")
|
284 |
+
|
285 |
+
def clear_cache_and_collect_garbage(self):
|
286 |
+
"""
|
287 |
+
Clears Hugging Face's Transformers cache and runs garbage collection.
|
288 |
+
"""
|
289 |
+
try:
|
290 |
+
if os.path.exists(TRANSFORMERS_CACHE):
|
291 |
+
shutil.rmtree(TRANSFORMERS_CACHE, ignore_errors=True)
|
292 |
+
print("Transformers cache has been cleared.")
|
293 |
+
|
294 |
+
gc.collect()
|
295 |
+
print("Garbage collection has been executed.")
|
296 |
+
except Exception as e:
|
297 |
+
print(f"Error occurred while clearing cache and collecting garbage: {str(e)}")
|
298 |
+
|
299 |
+
def fine_tune(save_fine_tuned_adapter=False, merge=False, delete_trainer_after_fine_tune=False):
|
300 |
+
"""
|
301 |
+
Conducts the fine-tuning process of a pre-trained language model using specified configurations.
|
302 |
+
This function encompasses the complete workflow of fine-tuning, including data handling, training,
|
303 |
+
and optional steps like saving the fine-tuned model and merging weights.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
save_fine_tuned_adapter (bool): If True, saves the fine-tuned adapter after training.
|
307 |
+
merge (bool): If True, merges the weights of the fine-tuned adapter into the base model.
|
308 |
+
delete_trainer_after_fine_tune (bool): If True, deletes the trainer object after fine-tuning to free up resources.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
The fine-tuned model after the fine-tuning process. This could be either the merged model
|
312 |
+
or the trained model based on the provided arguments.
|
313 |
+
|
314 |
+
The function initiates by preparing the training and evaluation datasets using the `FinetuningDataHandler`.
|
315 |
+
It then sets up the QLoRA configuration for the fine-tuning process. The actual training is carried out by
|
316 |
+
the `Finetuner` class. Post training, based on the arguments, the function can save the fine-tuned model,
|
317 |
+
merge the adapter weights with the base model, and clean up resources by deleting the trainer object.
|
318 |
+
"""
|
319 |
+
|
320 |
+
data_handler = FinetuningDataHandler()
|
321 |
+
fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
|
322 |
+
qlora = QLoraConfig()
|
323 |
+
peft_config = qlora.lora_config
|
324 |
+
tuner = Finetuner(fine_tuning_data_train, fine_tuning_data_eval)
|
325 |
+
tuner.train(peft_config=peft_config)
|
326 |
+
if save_fine_tuned_adapter:
|
327 |
+
tuner.save_model()
|
328 |
+
|
329 |
+
if merge:
|
330 |
+
tuner.merge_weights()
|
331 |
+
|
332 |
+
if delete_trainer_after_fine_tune:
|
333 |
+
tuner.delete_trainer("trainer")
|
334 |
+
|
335 |
+
tuner.delete_model("base_model") # We always delete this as it is not required after the merger.
|
336 |
+
|
337 |
+
if save_fine_tuned_adapter:
|
338 |
+
tuner.save_model()
|
339 |
+
if tuner.merged_model is not None:
|
340 |
+
return tuner.merged_model
|
341 |
+
else:
|
342 |
+
return tuner.trainer.model
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
if __name__ == "__main__":
|
347 |
+
fine_tune()
|
my_model/fine_tuner/fine_tuning_config.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configurable parameters for fine-tuning
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
# *** Dataset ***
|
7 |
+
# Base directory where the script is running
|
8 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
9 |
+
# Path to the folder containing the data files, relative to the configuration file
|
10 |
+
DATA_FOLDER = 'fine_tuning_data'
|
11 |
+
# Full path to the data folder
|
12 |
+
DATA_FOLDER_PATH = os.path.join(BASE_DIR, DATA_FOLDER)
|
13 |
+
# Path to the dataset file (CSV format)
|
14 |
+
DATASET_FILE = os.path.join(DATA_FOLDER_PATH, 'fine_tuning_data_yolov5.csv') # or 'fine_tuning_data_detic.csv'
|
15 |
+
|
16 |
+
|
17 |
+
# *** Fine-tuned Adapter ***
|
18 |
+
TRAINED_ADAPTER_NAME = 'fine_tuned_adapter' # name of fine-tuned adapter.
|
19 |
+
FINE_TUNED_ADAPTER_FOLDER = 'fine_tuned_model'
|
20 |
+
FINE_TUNED_ADAPTER_PATH = os.path.join(BASE_DIR, FINE_TUNED_ADAPTER_FOLDER)
|
21 |
+
ADAPTER_SAVE_NAME = os.path.join(FINE_TUNED_ADAPTER_PATH, TRAINED_ADAPTER_NAME)
|
22 |
+
|
23 |
+
|
24 |
+
# Proportion of the dataset to include in the test split (e.g., 0.1 for 10%)
|
25 |
+
TEST_SIZE = 0.1
|
26 |
+
|
27 |
+
# Seed for random operations to ensure reproducibility
|
28 |
+
SEED = 123
|
29 |
+
|
30 |
+
# *** QLoRA Configuration Parameters ***
|
31 |
+
# LoRA attention dimension: number of additional parameters in each LoRA layer
|
32 |
+
LORA_R = 64
|
33 |
+
|
34 |
+
# Alpha parameter for LoRA scaling: controls the scaling of LoRA weights
|
35 |
+
LORA_ALPHA = 32
|
36 |
+
|
37 |
+
# Dropout probability for LoRA layers: probability of dropping a unit in LoRA layers
|
38 |
+
LORA_DROPOUT = 0.05
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
# *** TrainingArguments Configuration Parameters for the Transformers library ***
|
43 |
+
# Output directory to save model predictions and checkpoints
|
44 |
+
OUTPUT_DIR = "./TUNED_MODEL_LLAMA"
|
45 |
+
|
46 |
+
# Number of epochs to train the model
|
47 |
+
NUM_TRAIN_EPOCHS = 1
|
48 |
+
|
49 |
+
# Enable mixed-precision training using fp16 (set to True for faster training)
|
50 |
+
FP16 = True
|
51 |
+
|
52 |
+
# Enable mixed-precision training using bf16 (set to True if using an A100 GPU)
|
53 |
+
BF16 = False
|
54 |
+
|
55 |
+
# Batch size per GPU/Device for training
|
56 |
+
PER_DEVICE_TRAIN_BATCH_SIZE = 16
|
57 |
+
|
58 |
+
# Batch size per GPU/Device for evaluation
|
59 |
+
PER_DEVICE_EVAL_BATCH_SIZE = 8
|
60 |
+
|
61 |
+
# Number of update steps to accumulate gradients before performing a backward/update pass
|
62 |
+
GRADIENT_ACCUMULATION_STEPS = 1
|
63 |
+
|
64 |
+
# Enable gradient checkpointing to reduce memory usage at the cost of a slight slowdown
|
65 |
+
GRADIENT_CHECKPOINTING = True
|
66 |
+
|
67 |
+
# Maximum gradient norm for gradient clipping to prevent exploding gradients
|
68 |
+
MAX_GRAD_NORM = 0.3
|
69 |
+
|
70 |
+
# Initial learning rate for the AdamW optimizer
|
71 |
+
LEARNING_RATE = 2e-4
|
72 |
+
|
73 |
+
# Weight decay coefficient for regularization (applied to all layers except bias/LayerNorm weights)
|
74 |
+
WEIGHT_DECAY = 0.01
|
75 |
+
|
76 |
+
# Optimizer type, here using 'paged_adamw_8bit' for efficient training
|
77 |
+
OPTIM = "paged_adamw_8bit"
|
78 |
+
|
79 |
+
# Learning rate scheduler type (e.g., 'linear', 'cosine', etc.)
|
80 |
+
LR_SCHEDULER_TYPE = "linear"
|
81 |
+
|
82 |
+
# Maximum number of training steps, overrides 'num_train_epochs' if set to a positive number
|
83 |
+
# Setting MAX_STEPS = -1 in training arguments for SFTTrainer means that the number of steps will be determined by the
|
84 |
+
# number of epochs, the size of the dataset, the batch size, and the number of GPUs1. This is the default behavior
|
85 |
+
# when MAX_STEPS is not specified or set to a negative value2.
|
86 |
+
MAX_STEPS = -1
|
87 |
+
|
88 |
+
# Ratio of the total number of training steps used for linear warmup
|
89 |
+
WARMUP_RATIO = 0.03
|
90 |
+
|
91 |
+
# Whether to group sequences into batches with the same length to save memory and increase speed
|
92 |
+
GROUP_BY_LENGTH = False
|
93 |
+
|
94 |
+
# Save a model checkpoint every X update steps
|
95 |
+
SAVE_STEPS = 50
|
96 |
+
|
97 |
+
# Log training information every X update steps
|
98 |
+
LOGGING_STEPS = 25
|
99 |
+
|
100 |
+
PACKING = False
|
101 |
+
|
102 |
+
# Evaluation strategy during training ("steps", "epoch, "no")
|
103 |
+
Evaluation_STRATEGY = "steps"
|
104 |
+
|
105 |
+
# Number of update steps between two evaluations if `evaluation_strategy="steps"`.
|
106 |
+
# Will default to the same value as `logging_steps` if not set.
|
107 |
+
EVALUATION_STEPS = 5
|
108 |
+
|
109 |
+
# Maximum number of tokens per sample in the dataset
|
110 |
+
MAX_TOKEN_COUNT = 1024
|
111 |
+
|
112 |
+
|
113 |
+
if __name__=="__main__":
|
114 |
+
pass
|
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:77089f24dd5414b0d1dcb5b8f3b34aac3daea86e68c1c70e2da6490482ac9d4b
|
3 |
+
size 54670629
|
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a44d22827c212a9d7a30bb3fd94cb7d7ad82a968a55eaa09e0ff5a61f85fde05
|
3 |
+
size 14547559
|
my_model/fine_tuner/fine_tuning_data/read_me.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The data files 'fine_tuning_data_detic.csv' and 'fine_tuning_data_yolov5.csv' are the result of the preparation and
|
2 |
+
filtration after performing below steps:
|
3 |
+
|
4 |
+
- Generate the captions for all the images.
|
5 |
+
- Delete all samples with corrupted or rubbish data. (Please refer to the report for details)
|
6 |
+
- Run object detection models ('yolov5' and 'detic') and generate the corresponding objects for the images corresponding to the remaining samples.
|
7 |
+
- Convert all the question, answer, caption, objects together with the system prompt into the desired template for all
|
8 |
+
the samples (Please refer to the report for the detailed template design).
|
my_model/fine_tuner/fine_tuning_data_handler.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from my_model.utilities import is_pycharm
|
2 |
+
import seaborn as sns
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
from datasets import Dataset, load_dataset
|
5 |
+
import fine_tuning_config as config
|
6 |
+
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
|
7 |
+
from typing import Tuple
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
class FinetuningDataHandler:
|
12 |
+
"""
|
13 |
+
A class dedicated to handling data for fine-tuning language models. It manages loading,
|
14 |
+
inspecting, preparing, and splitting the dataset, specifically designed to filter out
|
15 |
+
data samples exceeding a specified token count limit. This is crucial for models with
|
16 |
+
token count constraints and it helps control the level of GPU RAM tolernace based on the number of tokens,
|
17 |
+
ensuring efficient and effective model fine-tuning.
|
18 |
+
|
19 |
+
Attributes:
|
20 |
+
tokenizer (AutoTokenizer): Tokenizer used for tokenizing the dataset.
|
21 |
+
dataset_file (str): File path to the dataset.
|
22 |
+
max_token_count (int): Maximum allowable token count per data sample.
|
23 |
+
|
24 |
+
Methods:
|
25 |
+
load_llm_tokenizer(): Loads the LLM tokenizer and adds special tokens, if not already loaded.
|
26 |
+
load_dataset(): Loads the dataset from a specified file path.
|
27 |
+
plot_tokens_count_distribution(token_counts, title): Plots the distribution of token counts in the dataset.
|
28 |
+
filter_dataset_by_indices(dataset, valid_indices): Filters the dataset based on valid indices, removing samples exceeding token limits.
|
29 |
+
get_token_counts(dataset): Calculates token counts for each sample in the dataset.
|
30 |
+
prepare_dataset(): Tokenizes and filters the dataset, preparing it for training. Also visualizes token count distribution before and after filtering.
|
31 |
+
split_dataset_for_train_eval(dataset): Divides the dataset into training and evaluation sets.
|
32 |
+
inspect_prepare_split_data(): Coordinates the data preparation and splitting process for fine-tuning.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
|
36 |
+
"""
|
37 |
+
Initializes the FinetuningDataHandler class.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
tokenizer (AutoTokenizer): Tokenizer to use for tokenizing the dataset.
|
41 |
+
dataset_file (str): Path to the dataset file.
|
42 |
+
"""
|
43 |
+
self.tokenizer = tokenizer # The tokenizer used for processing the dataset.
|
44 |
+
self.dataset_file = dataset_file # Path to the fine-tuning dataset file.
|
45 |
+
self.max_token_count = config.MAX_TOKEN_COUNT # Max token count for filtering.
|
46 |
+
|
47 |
+
def load_llm_tokenizer(self):
|
48 |
+
"""
|
49 |
+
Loads the LLM tokenizer and adds special tokens, if not already loaded.
|
50 |
+
If the tokenizer is already loaded, this method does nothing.
|
51 |
+
"""
|
52 |
+
|
53 |
+
if self.tokenizer is None:
|
54 |
+
llm_manager = Llama2ModelManager() # Initialize Llama2 model manager.
|
55 |
+
# we only need the tokenizer for the data inspection not the model itself.
|
56 |
+
self.tokenizer = llm_manager.load_tokenizer()
|
57 |
+
llm_manager.add_special_tokens() # Add special tokens specific to LLAMA2 vocab for efficient tokenization.
|
58 |
+
|
59 |
+
def load_dataset(self) -> Dataset:
|
60 |
+
"""
|
61 |
+
Loads the dataset from the specified file path. The dataset is expected to be in CSV format.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dataset: The loaded dataset, ready for processing.
|
65 |
+
"""
|
66 |
+
return load_dataset('csv', data_files=self.dataset_file)
|
67 |
+
|
68 |
+
def plot_tokens_count_distribution(self, token_counts: list, title: str = "Token Count Distribution") -> None:
|
69 |
+
"""
|
70 |
+
Plots the distribution of token counts in the dataset for visualization purposes.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
token_counts (list): List of token counts, each count representing the number of tokens in a dataset sample.
|
74 |
+
title (str): Title for the plot, highlighting the nature of the distribution.
|
75 |
+
"""
|
76 |
+
|
77 |
+
if is_pycharm(): # Ensuring compatibility with PyCharm's environment for interactive plot.
|
78 |
+
import matplotlib
|
79 |
+
matplotlib.use('TkAgg') # Set the backend to 'TkAgg'
|
80 |
+
import matplotlib.pyplot as plt
|
81 |
+
sns.set_style("whitegrid")
|
82 |
+
plt.figure(figsize=(15, 6))
|
83 |
+
plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
|
84 |
+
plt.title(title, fontsize=16)
|
85 |
+
plt.xlabel("Number of Tokens", fontsize=14)
|
86 |
+
plt.ylabel("Number of Samples", fontsize=14)
|
87 |
+
plt.xticks(fontsize=12)
|
88 |
+
plt.yticks(fontsize=12)
|
89 |
+
plt.tight_layout()
|
90 |
+
plt.show()
|
91 |
+
|
92 |
+
def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: list) -> Dataset:
|
93 |
+
"""
|
94 |
+
Filters the dataset based on a list of valid indices. This method is used to exclude
|
95 |
+
data samples that have a token count exceeding the specified maximum token count.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
dataset (Dataset): The dataset to be filtered.
|
99 |
+
valid_indices (list): Indices of samples with token counts within the limit.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Dataset: Filtered dataset containing only samples with valid indices.
|
103 |
+
"""
|
104 |
+
return dataset['train'].select(valid_indices) # Select only samples with valid indices based on token count.
|
105 |
+
|
106 |
+
def get_token_counts(self, dataset):
|
107 |
+
"""
|
108 |
+
Calculates and returns the token counts for each sample in the dataset.
|
109 |
+
This function assumes the dataset has a 'train' split and a 'text' field.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
dataset (Dataset): The dataset for which to count tokens.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
List[int]: List of token counts per sample in the dataset.
|
116 |
+
"""
|
117 |
+
|
118 |
+
if 'train' in dataset:
|
119 |
+
return [len(self.tokenizer.tokenize(s)) for s in dataset["train"]["text"]]
|
120 |
+
else:
|
121 |
+
# After filtering the samples with unacceptable token count, the dataset is already
|
122 |
+
# dataset = dataset['train']
|
123 |
+
return [len(self.tokenizer.tokenize(s)) for s in dataset["text"]]
|
124 |
+
|
125 |
+
def prepare_dataset(self) -> Tuple[Dataset, Dataset]:
|
126 |
+
"""
|
127 |
+
Prepares the dataset for fine-tuning by tokenizing the data and filtering out samples
|
128 |
+
that exceed the maximum used context window (configurable through max_token_count).
|
129 |
+
It also visualizes the token count distribution before and after filtering.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
Tuple[Dataset, Dataset]: The train and evaluate datasets, post-filtering.
|
133 |
+
"""
|
134 |
+
dataset = self.load_dataset()
|
135 |
+
self.load_llm_tokenizer()
|
136 |
+
|
137 |
+
# Count tokens in each dataset sample before filtering
|
138 |
+
token_counts_before_filtering = self.get_token_counts(dataset)
|
139 |
+
# Plot token count distribution before filtering for visualization.
|
140 |
+
self.plot_tokens_count_distribution(token_counts_before_filtering, "Token Count Distribution Before Filtration")
|
141 |
+
# Identify valid indices based on max token count.
|
142 |
+
valid_indices = [i for i, count in enumerate(token_counts_before_filtering) if count <= self.max_token_count]
|
143 |
+
# Filter the dataset to exclude samples with excessive token counts.
|
144 |
+
filtered_dataset = self.filter_dataset_by_indices(dataset, valid_indices)
|
145 |
+
|
146 |
+
token_counts_after_filtering = self.get_token_counts(filtered_dataset)
|
147 |
+
self.plot_tokens_count_distribution(token_counts_after_filtering, "Token Count Distribution After Filtration")
|
148 |
+
|
149 |
+
return self.split_dataset_for_train_eval(filtered_dataset) # split the dataset into training and evaluation.
|
150 |
+
|
151 |
+
def split_dataset_for_train_eval(self, dataset) -> Tuple[Dataset, Dataset]:
|
152 |
+
"""
|
153 |
+
Splits the dataset into training and evaluation datasets.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
dataset (Dataset): The dataset to split.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
tuple[Dataset, Dataset]: The split training and evaluation datasets.
|
160 |
+
"""
|
161 |
+
split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
|
162 |
+
train_data, eval_data = split_data['train'], split_data['test']
|
163 |
+
return train_data, eval_data
|
164 |
+
|
165 |
+
def inspect_prepare_split_data(self) -> tuple[Dataset, Dataset]:
|
166 |
+
"""
|
167 |
+
Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
|
171 |
+
"""
|
172 |
+
return self.prepare_dataset()
|
173 |
+
|
174 |
+
|
175 |
+
# Example usage
|
176 |
+
if __name__ == "__main__":
|
177 |
+
|
178 |
+
# Please uncomment the below lines to test the data prep.
|
179 |
+
#data_handler = FinetuningDataHandler()
|
180 |
+
#fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
|
181 |
+
#print(fine_tuning_data_train, fine_tuning_data_eval)
|
182 |
+
pass
|
my_model/object_detection.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import AutoImageProcessor, AutoModelForObjectDetection
|
3 |
+
import torch
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
import os
|
8 |
+
from utilities import get_path, show_image
|
9 |
+
|
10 |
+
|
11 |
+
class ObjectDetector:
|
12 |
+
"""
|
13 |
+
A class for detecting objects in images using models like Detic and YOLOv5.
|
14 |
+
|
15 |
+
This class supports loading and using different object detection models to identify objects
|
16 |
+
in images and draw bounding boxes around them.
|
17 |
+
|
18 |
+
Attributes:
|
19 |
+
model (torch.nn.Module): The loaded object detection model.
|
20 |
+
processor (transformers.AutoImageProcessor): Processor for the Detic model.
|
21 |
+
model_name (str): Name of the model used for detection.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self):
|
25 |
+
"""
|
26 |
+
Initializes the ObjectDetector class with default values.
|
27 |
+
"""
|
28 |
+
|
29 |
+
self.model = None
|
30 |
+
self.processor = None
|
31 |
+
self.model_name = None
|
32 |
+
|
33 |
+
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
34 |
+
"""
|
35 |
+
Load the specified object detection model.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
model_name (str): Name of the model to load. Options are 'detic' and 'yolov5'.
|
39 |
+
pretrained (bool): Boolean indicating if a pretrained model should be used.
|
40 |
+
model_version (str): Version of the YOLOv5 model, applicable only when using YOLOv5.
|
41 |
+
|
42 |
+
Raises:
|
43 |
+
ValueError: If an unsupported model name is provided.
|
44 |
+
"""
|
45 |
+
|
46 |
+
self.model_name = model_name
|
47 |
+
if model_name == 'detic':
|
48 |
+
self._load_detic_model(pretrained)
|
49 |
+
elif model_name == 'yolov5':
|
50 |
+
self._load_yolov5_model(pretrained, model_version)
|
51 |
+
else:
|
52 |
+
raise ValueError(f"Unsupported model name: {model_name}")
|
53 |
+
|
54 |
+
def _load_detic_model(self, pretrained):
|
55 |
+
"""
|
56 |
+
Load the Detic model.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
pretrained (bool): If True, load a pretrained model.
|
60 |
+
"""
|
61 |
+
|
62 |
+
try:
|
63 |
+
model_path = get_path('deformable-detr-detic', 'models')
|
64 |
+
self.processor = AutoImageProcessor.from_pretrained(model_path)
|
65 |
+
self.model = AutoModelForObjectDetection.from_pretrained(model_path)
|
66 |
+
except Exception as e:
|
67 |
+
print(f"Error loading Detic model: {e}")
|
68 |
+
raise
|
69 |
+
|
70 |
+
def _load_yolov5_model(self, pretrained, model_version):
|
71 |
+
"""
|
72 |
+
Load the YOLOv5 model.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
pretrained (bool): If True, load a pretrained model.
|
76 |
+
model_version (str): Version of the YOLOv5 model.
|
77 |
+
"""
|
78 |
+
|
79 |
+
try:
|
80 |
+
model_path = get_path('yolov5', 'models')
|
81 |
+
if model_path and os.path.exists(model_path):
|
82 |
+
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
|
83 |
+
else:
|
84 |
+
self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained)
|
85 |
+
except Exception as e:
|
86 |
+
print(f"Error loading YOLOv5 model: {e}")
|
87 |
+
raise
|
88 |
+
|
89 |
+
def process_image(self, image_path):
|
90 |
+
"""
|
91 |
+
Process the image from the given path.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
image_path (str): Path to the image file.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Image.Image: Processed image in RGB format.
|
98 |
+
|
99 |
+
Raises:
|
100 |
+
Exception: If an error occurs during image processing.
|
101 |
+
"""
|
102 |
+
|
103 |
+
try:
|
104 |
+
with Image.open(image_path) as image:
|
105 |
+
return image.convert("RGB")
|
106 |
+
except Exception as e:
|
107 |
+
print(f"Error processing image: {e}")
|
108 |
+
raise
|
109 |
+
|
110 |
+
def detect_objects(self, image, threshold=0.4):
|
111 |
+
"""
|
112 |
+
Detect objects in the given image using the loaded model.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
image (Image.Image): Image in which to detect objects.
|
116 |
+
threshold (float): Model detection confidence.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
tuple: A tuple containing a string representation and a list of detected objects.
|
120 |
+
|
121 |
+
Raises:
|
122 |
+
ValueError: If the model is not loaded or the model name is unsupported.
|
123 |
+
"""
|
124 |
+
|
125 |
+
if self.model_name == 'detic':
|
126 |
+
return self._detect_with_detic(image, threshold)
|
127 |
+
elif self.model_name == 'yolov5':
|
128 |
+
return self._detect_with_yolov5(image, threshold)
|
129 |
+
else:
|
130 |
+
raise ValueError("Model not loaded or unsupported model name")
|
131 |
+
|
132 |
+
def _detect_with_detic(self, image, threshold):
|
133 |
+
"""
|
134 |
+
Detect objects using the Detic model.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
image (Image.Image): The image in which to detect objects.
|
138 |
+
threshold (float): The confidence threshold for detections.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
tuple: A tuple containing a string representation and a list of detected objects.
|
142 |
+
Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
|
143 |
+
"""
|
144 |
+
|
145 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
146 |
+
outputs = self.model(**inputs)
|
147 |
+
target_sizes = torch.tensor([image.size[::-1]])
|
148 |
+
results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[
|
149 |
+
0]
|
150 |
+
|
151 |
+
detected_objects_str = ""
|
152 |
+
detected_objects_list = []
|
153 |
+
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
154 |
+
if score >= threshold:
|
155 |
+
label_name = self.model.config.id2label[label.item()]
|
156 |
+
box_rounded = [round(coord, 2) for coord in box.tolist()]
|
157 |
+
certainty = round(score.item() * 100, 2)
|
158 |
+
detected_objects_str += f"{{object: {label_name}, bounding box: {box_rounded}, certainty: {certainty}%}}\n"
|
159 |
+
detected_objects_list.append((label_name, box_rounded, certainty))
|
160 |
+
return detected_objects_str, detected_objects_list
|
161 |
+
|
162 |
+
def _detect_with_yolov5(self, image, threshold):
|
163 |
+
"""
|
164 |
+
Detect objects using the YOLOv5 model.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
image (Image.Image): The image in which to detect objects.
|
168 |
+
threshold (float): The confidence threshold for detections.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
tuple: A tuple containing a string representation and a list of detected objects.
|
172 |
+
Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
|
173 |
+
"""
|
174 |
+
|
175 |
+
cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
176 |
+
results = self.model(cv2_img)
|
177 |
+
|
178 |
+
detected_objects_str = ""
|
179 |
+
detected_objects_list = []
|
180 |
+
for *bbox, conf, cls in results.xyxy[0]:
|
181 |
+
if conf >= threshold:
|
182 |
+
label_name = results.names[int(cls)]
|
183 |
+
box_rounded = [round(coord.item(), 2) for coord in bbox]
|
184 |
+
certainty = round(conf.item() * 100, 2)
|
185 |
+
detected_objects_str += f"{{object: {label_name}, bounding box: {box_rounded}, certainty: {certainty}%}}\n"
|
186 |
+
detected_objects_list.append((label_name, box_rounded, certainty))
|
187 |
+
return detected_objects_str, detected_objects_list
|
188 |
+
|
189 |
+
def draw_boxes(self, image, detected_objects, show_confidence=True):
|
190 |
+
"""
|
191 |
+
Draw bounding boxes around detected objects in the image.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
image (Image.Image): Image on which to draw.
|
195 |
+
detected_objects (list): List of detected objects.
|
196 |
+
show_confidence (bool): Whether to show confidence scores.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
Image.Image: Image with drawn boxes.
|
200 |
+
"""
|
201 |
+
|
202 |
+
draw = ImageDraw.Draw(image)
|
203 |
+
try:
|
204 |
+
font = ImageFont.truetype("arial.ttf", 15)
|
205 |
+
except IOError:
|
206 |
+
font = ImageFont.load_default()
|
207 |
+
|
208 |
+
colors = ["red", "green", "blue", "yellow", "purple", "orange"]
|
209 |
+
label_color_map = {}
|
210 |
+
|
211 |
+
for label_name, box, score in detected_objects:
|
212 |
+
if label_name not in label_color_map:
|
213 |
+
label_color_map[label_name] = colors[len(label_color_map) % len(colors)]
|
214 |
+
|
215 |
+
color = label_color_map[label_name]
|
216 |
+
draw.rectangle(box, outline=color, width=3)
|
217 |
+
|
218 |
+
label_text = f"{label_name}"
|
219 |
+
if show_confidence:
|
220 |
+
label_text += f" ({round(score, 2)}%)"
|
221 |
+
draw.text((box[0], box[1]), label_text, fill=color, font=font)
|
222 |
+
|
223 |
+
return image
|
224 |
+
|
225 |
+
|
226 |
+
def detect_and_draw_objects(image_path, model_type='yolov5', threshold=0.2, show_confidence=True):
|
227 |
+
"""
|
228 |
+
Detects objects in an image, draws bounding boxes around them, and returns the processed image and a string description.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
image_path (str): Path to the image file.
|
232 |
+
model_type (str): Type of model to use for detection ('yolov5' or 'detic').
|
233 |
+
threshold (float): Detection threshold.
|
234 |
+
show_confidence (bool): Whether to show confidence scores on the output image.
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
tuple: A tuple containing the processed Image.Image and a string of detected objects.
|
238 |
+
"""
|
239 |
+
|
240 |
+
detector = ObjectDetector()
|
241 |
+
detector.load_model(model_type)
|
242 |
+
image = detector.process_image(image_path)
|
243 |
+
detected_objects_string, detected_objects_list = detector.detect_objects(image, threshold=threshold)
|
244 |
+
image_with_boxes = detector.draw_boxes(image, detected_objects_list, show_confidence=show_confidence)
|
245 |
+
return image_with_boxes, detected_objects_string
|
246 |
+
|
247 |
+
|
248 |
+
# Example usage
|
249 |
+
if __name__ == "__main__":
|
250 |
+
pass
|
251 |
+
|
252 |
+
# 'Sample_Images' is the folder conatining sample images for demo.
|
253 |
+
image_path = get_path('horse.jpg', 'Sample_Images')
|
254 |
+
processed_image, objects_string = detect_and_draw_objects(image_path,
|
255 |
+
model_type='detic',
|
256 |
+
threshold=0.2,
|
257 |
+
show_confidence=False)
|
258 |
+
show_image(processed_image)
|
259 |
+
print("Detected Objects:", objects_string)
|
my_model/utilities.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from collections import Counter
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from IPython import get_ipython
|
10 |
+
import sys
|
11 |
+
|
12 |
+
|
13 |
+
class VQADataProcessor:
|
14 |
+
"""
|
15 |
+
A class to process OKVQA dataset.
|
16 |
+
|
17 |
+
Attributes:
|
18 |
+
questions_file_path (str): The file path for the questions JSON file.
|
19 |
+
annotations_file_path (str): The file path for the annotations JSON file.
|
20 |
+
questions (list): List of questions extracted from the JSON file.
|
21 |
+
annotations (list): List of annotations extracted from the JSON file.
|
22 |
+
df_questions (DataFrame): DataFrame created from the questions list.
|
23 |
+
df_answers (DataFrame): DataFrame created from the annotations list.
|
24 |
+
merged_df (DataFrame): DataFrame resulting from merging questions and answers.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, questions_file_path, annotations_file_path):
|
28 |
+
"""
|
29 |
+
Initializes the VQADataProcessor with file paths for questions and annotations.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
questions_file_path (str): The file path for the questions JSON file.
|
33 |
+
annotations_file_path (str): The file path for the annotations JSON file.
|
34 |
+
"""
|
35 |
+
self.questions_file_path = questions_file_path
|
36 |
+
self.annotations_file_path = annotations_file_path
|
37 |
+
self.questions, self.annotations = self.read_json_files()
|
38 |
+
self.df_questions = pd.DataFrame(self.questions)
|
39 |
+
self.df_answers = pd.DataFrame(self.annotations)
|
40 |
+
self.merged_df = None
|
41 |
+
|
42 |
+
def read_json_files(self):
|
43 |
+
"""
|
44 |
+
Reads the JSON files for questions and annotations.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
tuple: A tuple containing two lists: questions and annotations.
|
48 |
+
"""
|
49 |
+
with open(self.questions_file_path, 'r') as file:
|
50 |
+
data = json.load(file)
|
51 |
+
questions = data['questions']
|
52 |
+
|
53 |
+
with open(self.annotations_file_path, 'r') as file:
|
54 |
+
data = json.load(file)
|
55 |
+
annotations = data['annotations']
|
56 |
+
|
57 |
+
return questions, annotations
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def find_most_frequent(my_list):
|
61 |
+
"""
|
62 |
+
Finds the most frequent item in a list.
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
my_list (list): A list of items.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
The most frequent item in the list. Returns None if the list is empty.
|
69 |
+
"""
|
70 |
+
if not my_list:
|
71 |
+
return None
|
72 |
+
counter = Counter(my_list)
|
73 |
+
most_common = counter.most_common(1)
|
74 |
+
return most_common[0][0]
|
75 |
+
|
76 |
+
def merge_dataframes(self):
|
77 |
+
"""
|
78 |
+
Merges the questions and answers DataFrames on 'question_id' and 'image_id'.
|
79 |
+
"""
|
80 |
+
self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])
|
81 |
+
|
82 |
+
def join_words_with_hyphen(self, sentence):
|
83 |
+
|
84 |
+
return '-'.join(sentence.split())
|
85 |
+
|
86 |
+
def process_answers(self):
|
87 |
+
"""
|
88 |
+
Processes the answers by extracting raw and processed answers and finding the most frequent ones.
|
89 |
+
"""
|
90 |
+
if self.merged_df is not None:
|
91 |
+
self.merged_df['raw_answers'] = self.merged_df['answers'].apply(lambda x: [ans['raw_answer'] for ans in x])
|
92 |
+
self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
|
93 |
+
lambda x: [ans['answer'] for ans in x])
|
94 |
+
self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(self.find_most_frequent)
|
95 |
+
self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
|
96 |
+
self.find_most_frequent)
|
97 |
+
self.merged_df.drop(columns=['answers'], inplace=True)
|
98 |
+
else:
|
99 |
+
print("DataFrames have not been merged yet.")
|
100 |
+
|
101 |
+
# Apply the function to the 'most_frequent_processed_answer' column
|
102 |
+
self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
|
103 |
+
self.join_words_with_hyphen)
|
104 |
+
|
105 |
+
def get_processed_data(self):
|
106 |
+
"""
|
107 |
+
Retrieves the processed DataFrame.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
DataFrame: The processed DataFrame. Returns None if the DataFrame is empty or not processed.
|
111 |
+
"""
|
112 |
+
if self.merged_df is not None:
|
113 |
+
return self.merged_df
|
114 |
+
else:
|
115 |
+
print("DataFrame is empty or not processed yet.")
|
116 |
+
return None
|
117 |
+
|
118 |
+
def save_to_csv(self, df, saved_file_name):
|
119 |
+
|
120 |
+
if saved_file_name is not None:
|
121 |
+
if ".csv" not in saved_file_name:
|
122 |
+
df.to_csv(os.path.join(saved_file_name, ".csv"), index=None)
|
123 |
+
|
124 |
+
else:
|
125 |
+
df.to_csv(saved_file_name, index=None)
|
126 |
+
|
127 |
+
else:
|
128 |
+
df.to_csv("data.csv", index=None)
|
129 |
+
|
130 |
+
def display_dataframe(self):
|
131 |
+
"""
|
132 |
+
Displays the processed DataFrame.
|
133 |
+
"""
|
134 |
+
if self.merged_df is not None:
|
135 |
+
print(self.merged_df)
|
136 |
+
else:
|
137 |
+
print("DataFrame is empty.")
|
138 |
+
|
139 |
+
|
140 |
+
def process_okvqa_dataset(questions_file_path, annotations_file_path, save_to_csv=False, saved_file_name=None):
|
141 |
+
"""
|
142 |
+
Processes the OK-VQA dataset given the file paths for questions and annotations.
|
143 |
+
|
144 |
+
Parameters:
|
145 |
+
questions_file_path (str): The file path for the questions JSON file.
|
146 |
+
annotations_file_path (str): The file path for the annotations JSON file.
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
DataFrame: The processed DataFrame containing merged and processed VQA data.
|
150 |
+
"""
|
151 |
+
# Create an instance of the class
|
152 |
+
processor = VQADataProcessor(questions_file_path, annotations_file_path)
|
153 |
+
|
154 |
+
# Process the data
|
155 |
+
processor.merge_dataframes()
|
156 |
+
processor.process_answers()
|
157 |
+
|
158 |
+
# Retrieve the processed DataFrame
|
159 |
+
processed_data = processor.get_processed_data()
|
160 |
+
|
161 |
+
if save_to_csv:
|
162 |
+
processor.save_to_csv(processed_data, saved_file_name)
|
163 |
+
|
164 |
+
return processed_data
|
165 |
+
|
166 |
+
|
167 |
+
def show_image(image):
|
168 |
+
"""
|
169 |
+
Display an image in various environments (Jupyter, PyCharm, Hugging Face Spaces).
|
170 |
+
Handles different types of image inputs (file path, PIL Image, numpy array, OpenCV, PyTorch tensor).
|
171 |
+
|
172 |
+
Args:
|
173 |
+
image (str or PIL.Image or numpy.ndarray or torch.Tensor): The image to display.
|
174 |
+
"""
|
175 |
+
in_jupyter = is_jupyter_notebook()
|
176 |
+
in_colab = is_google_colab()
|
177 |
+
|
178 |
+
# Convert image to PIL Image if it's a file path, numpy array, or PyTorch tensor
|
179 |
+
if isinstance(image, str):
|
180 |
+
|
181 |
+
if os.path.isfile(image):
|
182 |
+
image = Image.open(image)
|
183 |
+
else:
|
184 |
+
raise ValueError("File path provided does not exist.")
|
185 |
+
elif isinstance(image, np.ndarray):
|
186 |
+
|
187 |
+
if image.ndim == 3 and image.shape[2] in [3, 4]:
|
188 |
+
|
189 |
+
image = Image.fromarray(image[..., ::-1] if image.shape[2] == 3 else image)
|
190 |
+
else:
|
191 |
+
|
192 |
+
image = Image.fromarray(image)
|
193 |
+
elif torch.is_tensor(image):
|
194 |
+
|
195 |
+
image = Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8))
|
196 |
+
|
197 |
+
# Display the image
|
198 |
+
if in_jupyter or in_colab:
|
199 |
+
|
200 |
+
from IPython.display import display
|
201 |
+
display(image)
|
202 |
+
else:
|
203 |
+
|
204 |
+
image.show()
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
def show_image_with_matplotlib(image):
|
209 |
+
if isinstance(image, str):
|
210 |
+
image = Image.open(image)
|
211 |
+
elif isinstance(image, np.ndarray):
|
212 |
+
image = Image.fromarray(image)
|
213 |
+
elif torch.is_tensor(image):
|
214 |
+
image = Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8))
|
215 |
+
|
216 |
+
plt.imshow(image)
|
217 |
+
plt.axis('off') # Turn off axis numbers
|
218 |
+
plt.show()
|
219 |
+
|
220 |
+
|
221 |
+
def is_jupyter_notebook():
|
222 |
+
"""
|
223 |
+
Check if the code is running in a Jupyter notebook.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
bool: True if running in a Jupyter notebook, False otherwise.
|
227 |
+
"""
|
228 |
+
try:
|
229 |
+
from IPython import get_ipython
|
230 |
+
if 'IPKernelApp' not in get_ipython().config:
|
231 |
+
return False
|
232 |
+
if 'ipykernel' in str(type(get_ipython())):
|
233 |
+
return True # Running in Jupyter Notebook
|
234 |
+
except (NameError, AttributeError):
|
235 |
+
return False # Not running in Jupyter Notebook
|
236 |
+
|
237 |
+
return False # Default to False if none of the above conditions are met
|
238 |
+
|
239 |
+
|
240 |
+
def is_pycharm():
|
241 |
+
return 'PYCHARM_HOSTED' in os.environ
|
242 |
+
|
243 |
+
|
244 |
+
def is_google_colab():
|
245 |
+
return 'COLAB_GPU' in os.environ or 'google.colab' in sys.modules
|
246 |
+
|
247 |
+
|
248 |
+
def get_path(name, path_type):
|
249 |
+
"""
|
250 |
+
Generates a path for models, images, or data based on the specified type.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
name (str): The name of the model, image, or data folder/file.
|
254 |
+
path_type (str): The type of path needed ('models', 'images', or 'data').
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
str: The full path to the specified resource.
|
258 |
+
"""
|
259 |
+
# Get the current working directory (assumed to be inside 'code' folder)
|
260 |
+
current_dir = os.getcwd()
|
261 |
+
|
262 |
+
# Get the directory one level up (the parent directory)
|
263 |
+
parent_dir = os.path.dirname(current_dir)
|
264 |
+
|
265 |
+
# Construct the path to the specified folder
|
266 |
+
folder_path = os.path.join(parent_dir, path_type)
|
267 |
+
|
268 |
+
# Construct the full path to the specific resource
|
269 |
+
full_path = os.path.join(folder_path, name)
|
270 |
+
|
271 |
+
return full_path
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
if __name__ == "__main__":
|
276 |
+
pass
|
277 |
+
#val_data = process_okvqa_dataset('OpenEnded_mscoco_val2014_questions.json', 'mscoco_val2014_annotations.json', save_to_csv=True, saved_file_name="okvqa_val.csv")
|
278 |
+
#train_data = process_okvqa_dataset('OpenEnded_mscoco_train2014_questions.json', 'mscoco_train2014_annotations.json', save_to_csv=True, saved_file_name="okvqa_train.csv")
|