m7mdal7aj commited on
Commit
a1f367e
·
verified ·
1 Parent(s): 9ba7b62

Update my_model/fine_tuner/fine_tuning_data_handler.py

Browse files
my_model/fine_tuner/fine_tuning_data_handler.py CHANGED
@@ -1,11 +1,10 @@
 
1
  from my_model.utilities.gen_utilities import is_pycharm
2
  import seaborn as sns
3
  from transformers import AutoTokenizer
4
  from datasets import Dataset, load_dataset
5
  import my_model.config.fine_tuning_config as config
6
  from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
7
- from typing import Tuple
8
-
9
 
10
 
11
  class FinetuningDataHandler:
@@ -13,7 +12,7 @@ class FinetuningDataHandler:
13
  A class dedicated to handling data for fine-tuning language models. It manages loading,
14
  inspecting, preparing, and splitting the dataset, specifically designed to filter out
15
  data samples exceeding a specified token count limit. This is crucial for models with
16
- token count constraints and it helps control the level of GPU RAM tolernace based on the number of tokens,
17
  ensuring efficient and effective model fine-tuning.
18
 
19
  Attributes:
@@ -22,14 +21,15 @@ class FinetuningDataHandler:
22
  max_token_count (int): Maximum allowable token count per data sample.
23
 
24
  Methods:
25
- load_llm_tokenizer(): Loads the LLM tokenizer and adds special tokens, if not already loaded.
26
- load_dataset(): Loads the dataset from a specified file path.
27
- plot_tokens_count_distribution(token_counts, title): Plots the distribution of token counts in the dataset.
28
- filter_dataset_by_indices(dataset, valid_indices): Filters the dataset based on valid indices, removing samples exceeding token limits.
29
- get_token_counts(dataset): Calculates token counts for each sample in the dataset.
30
- prepare_dataset(): Tokenizes and filters the dataset, preparing it for training. Also visualizes token count distribution before and after filtering.
31
- split_dataset_for_train_eval(dataset): Divides the dataset into training and evaluation sets.
32
- inspect_prepare_split_data(): Coordinates the data preparation and splitting process for fine-tuning.
 
33
  """
34
 
35
  def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
@@ -37,17 +37,21 @@ class FinetuningDataHandler:
37
  Initializes the FinetuningDataHandler class.
38
 
39
  Args:
40
- tokenizer (AutoTokenizer): Tokenizer to use for tokenizing the dataset.
41
- dataset_file (str): Path to the dataset file.
42
  """
 
43
  self.tokenizer = tokenizer # The tokenizer used for processing the dataset.
44
  self.dataset_file = dataset_file # Path to the fine-tuning dataset file.
45
- self.max_token_count = config.MAX_TOKEN_COUNT # Max token count for filtering.
46
 
47
- def load_llm_tokenizer(self):
48
  """
49
  Loads the LLM tokenizer and adds special tokens, if not already loaded.
50
  If the tokenizer is already loaded, this method does nothing.
 
 
 
51
  """
52
 
53
  if self.tokenizer is None:
@@ -63,21 +67,26 @@ class FinetuningDataHandler:
63
  Returns:
64
  Dataset: The loaded dataset, ready for processing.
65
  """
 
66
  return load_dataset('csv', data_files=self.dataset_file)
67
 
68
- def plot_tokens_count_distribution(self, token_counts: list, title: str = "Token Count Distribution") -> None:
69
  """
70
  Plots the distribution of token counts in the dataset for visualization purposes.
71
 
72
  Args:
73
- token_counts (list): List of token counts, each count representing the number of tokens in a dataset sample.
 
74
  title (str): Title for the plot, highlighting the nature of the distribution.
 
 
 
75
  """
76
 
77
  if is_pycharm(): # Ensuring compatibility with PyCharm's environment for interactive plot.
78
- import matplotlib
79
  matplotlib.use('TkAgg') # Set the backend to 'TkAgg'
80
- import matplotlib.pyplot as plt
81
  sns.set_style("whitegrid")
82
  plt.figure(figsize=(15, 6))
83
  plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
@@ -89,21 +98,21 @@ class FinetuningDataHandler:
89
  plt.tight_layout()
90
  plt.show()
91
 
92
- def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: list) -> Dataset:
93
  """
94
  Filters the dataset based on a list of valid indices. This method is used to exclude
95
  data samples that have a token count exceeding the specified maximum token count.
96
 
97
  Args:
98
  dataset (Dataset): The dataset to be filtered.
99
- valid_indices (list): Indices of samples with token counts within the limit.
100
 
101
  Returns:
102
  Dataset: Filtered dataset containing only samples with valid indices.
103
  """
104
  return dataset['train'].select(valid_indices) # Select only samples with valid indices based on token count.
105
 
106
- def get_token_counts(self, dataset):
107
  """
108
  Calculates and returns the token counts for each sample in the dataset.
109
  This function assumes the dataset has a 'train' split and a 'text' field.
@@ -131,6 +140,7 @@ class FinetuningDataHandler:
131
  Returns:
132
  Tuple[Dataset, Dataset]: The train and evaluate datasets, post-filtering.
133
  """
 
134
  dataset = self.load_dataset()
135
  self.load_llm_tokenizer()
136
 
@@ -148,7 +158,7 @@ class FinetuningDataHandler:
148
 
149
  return self.split_dataset_for_train_eval(filtered_dataset) # split the dataset into training and evaluation.
150
 
151
- def split_dataset_for_train_eval(self, dataset) -> Tuple[Dataset, Dataset]:
152
  """
153
  Splits the dataset into training and evaluation datasets.
154
 
@@ -156,27 +166,29 @@ class FinetuningDataHandler:
156
  dataset (Dataset): The dataset to split.
157
 
158
  Returns:
159
- tuple[Dataset, Dataset]: The split training and evaluation datasets.
160
  """
 
161
  split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
162
  train_data, eval_data = split_data['train'], split_data['test']
163
  return train_data, eval_data
164
 
165
- def inspect_prepare_split_data(self) -> tuple[Dataset, Dataset]:
166
  """
167
  Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.
168
 
169
  Returns:
170
- tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
171
  """
 
172
  return self.prepare_dataset()
173
 
174
 
175
  # Example usage
176
  if __name__ == "__main__":
177
-
178
- # Please uncomment the below lines to test the data prep.
179
- #data_handler = FinetuningDataHandler()
180
- #fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
181
- #print(fine_tuning_data_train, fine_tuning_data_eval)
182
  pass
 
1
+ from typing import Tuple, List
2
  from my_model.utilities.gen_utilities import is_pycharm
3
  import seaborn as sns
4
  from transformers import AutoTokenizer
5
  from datasets import Dataset, load_dataset
6
  import my_model.config.fine_tuning_config as config
7
  from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
 
 
8
 
9
 
10
  class FinetuningDataHandler:
 
12
  A class dedicated to handling data for fine-tuning language models. It manages loading,
13
  inspecting, preparing, and splitting the dataset, specifically designed to filter out
14
  data samples exceeding a specified token count limit. This is crucial for models with
15
+ token count constraints and it helps control the level of GPU RAM tolerance based on the number of tokens,
16
  ensuring efficient and effective model fine-tuning.
17
 
18
  Attributes:
 
21
  max_token_count (int): Maximum allowable token count per data sample.
22
 
23
  Methods:
24
+ load_llm_tokenizer: Loads the LLM tokenizer and adds special tokens, if not already loaded.
25
+ load_dataset: Loads the dataset from a specified file path.
26
+ plot_tokens_count_distribution: Plots the distribution of token counts in the dataset.
27
+ filter_dataset_by_indices: Filters the dataset based on valid indices, removing samples exceeding token limits.
28
+ get_token_counts: Calculates token counts for each sample in the dataset.
29
+ prepare_dataset: Tokenizes and filters the dataset, preparing it for training. Also visualizes token count
30
+ distribution before and after filtering.
31
+ split_dataset_for_train_eval: Divides the dataset into training and evaluation sets.
32
+ inspect_prepare_split_data: Coordinates the data preparation and splitting process for fine-tuning.
33
  """
34
 
35
  def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
 
37
  Initializes the FinetuningDataHandler class.
38
 
39
  Args:
40
+ tokenizer (AutoTokenizer, optional): Tokenizer to use for tokenizing the dataset. Defaults to None.
41
+ dataset_file (str): Path to the dataset file. Defaults to config.DATASET_FILE.
42
  """
43
+
44
  self.tokenizer = tokenizer # The tokenizer used for processing the dataset.
45
  self.dataset_file = dataset_file # Path to the fine-tuning dataset file.
46
+ self.max_token_count = config.MAX_TOKEN_COUNT # Max token count for filtering set to 1,024.
47
 
48
+ def load_llm_tokenizer(self) -> None:
49
  """
50
  Loads the LLM tokenizer and adds special tokens, if not already loaded.
51
  If the tokenizer is already loaded, this method does nothing.
52
+
53
+ Returns:
54
+ None
55
  """
56
 
57
  if self.tokenizer is None:
 
67
  Returns:
68
  Dataset: The loaded dataset, ready for processing.
69
  """
70
+
71
  return load_dataset('csv', data_files=self.dataset_file)
72
 
73
+ def plot_tokens_count_distribution(self, token_counts: List[int], title: str = "Token Count Distribution") -> None:
74
  """
75
  Plots the distribution of token counts in the dataset for visualization purposes.
76
 
77
  Args:
78
+ token_counts (List[int]): List of token counts, each count representing the number of tokens in a dataset
79
+ sample.
80
  title (str): Title for the plot, highlighting the nature of the distribution.
81
+
82
+ Returns:
83
+ None
84
  """
85
 
86
  if is_pycharm(): # Ensuring compatibility with PyCharm's environment for interactive plot.
87
+ import matplotlib # The import is kept here intentionaly.
88
  matplotlib.use('TkAgg') # Set the backend to 'TkAgg'
89
+ import matplotlib.pyplot as plt # The import is kept here intentionaly.
90
  sns.set_style("whitegrid")
91
  plt.figure(figsize=(15, 6))
92
  plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
 
98
  plt.tight_layout()
99
  plt.show()
100
 
101
+ def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: List[int]) -> Dataset:
102
  """
103
  Filters the dataset based on a list of valid indices. This method is used to exclude
104
  data samples that have a token count exceeding the specified maximum token count.
105
 
106
  Args:
107
  dataset (Dataset): The dataset to be filtered.
108
+ valid_indices (List[int]): Indices of samples with token counts within the limit.
109
 
110
  Returns:
111
  Dataset: Filtered dataset containing only samples with valid indices.
112
  """
113
  return dataset['train'].select(valid_indices) # Select only samples with valid indices based on token count.
114
 
115
+ def get_token_counts(self, dataset: Dataset) -> List[int]:
116
  """
117
  Calculates and returns the token counts for each sample in the dataset.
118
  This function assumes the dataset has a 'train' split and a 'text' field.
 
140
  Returns:
141
  Tuple[Dataset, Dataset]: The train and evaluate datasets, post-filtering.
142
  """
143
+
144
  dataset = self.load_dataset()
145
  self.load_llm_tokenizer()
146
 
 
158
 
159
  return self.split_dataset_for_train_eval(filtered_dataset) # split the dataset into training and evaluation.
160
 
161
+ def split_dataset_for_train_eval(self, dataset: Dataset) -> Tuple[Dataset, Dataset]:
162
  """
163
  Splits the dataset into training and evaluation datasets.
164
 
 
166
  dataset (Dataset): The dataset to split.
167
 
168
  Returns:
169
+ Tuple[Dataset, Dataset]: The split training and evaluation datasets.
170
  """
171
+
172
  split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
173
  train_data, eval_data = split_data['train'], split_data['test']
174
  return train_data, eval_data
175
 
176
+ def inspect_prepare_split_data(self) -> Tuple[Dataset, Dataset]:
177
  """
178
  Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.
179
 
180
  Returns:
181
+ Tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
182
  """
183
+
184
  return self.prepare_dataset()
185
 
186
 
187
  # Example usage
188
  if __name__ == "__main__":
189
+
190
+ # Please uncomment the below lines to test the data prep.
191
+ # data_handler = FinetuningDataHandler()
192
+ # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
193
+ # print(fine_tuning_data_train, fine_tuning_data_eval)
194
  pass