Upload 2 files
- finetune.py +111 -0
- optimize_lr.py +401 -0
finetune.py
ADDED
@@ -0,0 +1,111 @@
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)

CONTEXT_WINDOW = 1024  # has to fit in a 4090's VRAM
HF_TOKEN = os.getenv("HF_TOKEN")

# Set up the tokenizer
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct", token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # better for inference

# Initialize the model with automatic device mapping
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto"  # handles multi-GPU/CPU mapping
)
model.config.pad_token_id = tokenizer.pad_token_id

# Load the Dutch Dolly dataset
dataset = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft")

def prepare_chat_format(examples):
    """Tokenize each conversation with the chat template, truncated to the context window."""
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fall back to a plain format if the chat template fails
            text = ""
            for message in messages:
                role = message["role"]
                content = message["content"]
                text += f"<|{role}|>\n{content}</s>\n"

            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )["input_ids"]

        chats.append(chat)
    return {"input_ids": chats}

# Process the dataset
tokenized_dataset = dataset.map(
    prepare_chat_format,
    batched=True,
    remove_columns=dataset.column_names
)

# Training config
training_args = TrainingArguments(
    output_dir="./zamba2-finetuned",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=True,
    gradient_accumulation_steps=8,
    dataloader_num_workers=4,
    gradient_checkpointing=True,
    max_grad_norm=1.0,
    warmup_steps=100
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Custom trainer that skips device placement, since device_map="auto" already handled it
class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model

    def _move_model_to_device(self, model, device):
        pass  # model is already mapped to devices

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

# Train and save the final model
trainer.train()
model.save_pretrained("./zamba2-finetuned-final")
tokenizer.save_pretrained("./zamba2-finetuned-final")
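
Usage note (not part of the commit): a quick way to sanity-check the result of finetune.py is to reload the saved checkpoint and generate from a short Dutch prompt. This is only a sketch; it assumes training completed, that ./zamba2-finetuned-final exists, and that the saved tokenizer's chat template matches the format used in training.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the checkpoint written at the end of finetune.py (assumed to exist)
tokenizer = AutoTokenizer.from_pretrained("./zamba2-finetuned-final")
model = AutoModelForCausalLM.from_pretrained(
    "./zamba2-finetuned-final",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Format the prompt with the same chat template used during training
messages = [{"role": "user", "content": "Wat is de hoofdstad van Nederland?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=64,
        pad_token_id=tokenizer.pad_token_id
    )
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))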
optimize_lr.py
ADDED
@@ -0,0 +1,401 @@
import gc
import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from datasets import load_dataset
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling, TrainerCallback
)

warnings.filterwarnings('ignore', category=UserWarning)

# Configuration parameters
num_trials = 10  # adjust this value to control the number of optimization trials
DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
CONTEXT_WINDOW = 1024

# Initialize tokenizer once
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

def prepare_chat_format(examples):
    """Tokenize each conversation with the chat template, truncated to the context window."""
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fall back to a plain format if the chat template fails
            text = ""
            for message in messages:
                role = message["role"]
                content = message["content"]
                text += f"<|{role}|>\n{content}</s>\n"

            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )["input_ids"]

        chats.append(chat)
    return {"input_ids": chats}

# Prepare dataset once
tokenized_dataset = DATASET.map(
    prepare_chat_format,
    batched=True,
    remove_columns=DATASET.column_names
)

def clear_memory():
    """Clear GPU memory between trials"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

class LossCallback(TrainerCallback):
    """Collect the training loss from each logging step."""
    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])

def objective(trial):
    # Clear memory from the previous trial
    clear_memory()

    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)

    # Initialize the model with fresh state
    torch.manual_seed(42)
    model = AutoModelForCausalLM.from_pretrained(
        "Zyphra/Zamba2-1.2B",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    model.config.pad_token_id = tokenizer.pad_token_id

    # Calculate steps from the effective batch size
    batch_size = 4        # increased from 1
    grad_accum_steps = 8  # decreased from 32 to compensate, still 32 samples per step
    effective_batch_size = batch_size * grad_accum_steps
    total_steps = len(tokenized_dataset) // effective_batch_size

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./optuna_runs/trial_{trial.number}",
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        logging_steps=max(total_steps // 20, 1),
        learning_rate=lr,
        weight_decay=0.01,
        fp16=False,
        bf16=True,
        warmup_steps=total_steps // 10,
        save_steps=1000000,  # effectively disables checkpointing during trials
        save_total_limit=None,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,     # faster data loading
        gradient_checkpointing=True,  # trades compute for memory
        max_grad_norm=1.0             # gradient clipping for stability
    )

    print(f"\nTrial {trial.number}:")
    print(f"Learning rate: {lr}")
    print(f"Total steps: {total_steps}")
    print(f"Logging every {training_args.logging_steps} steps")

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Custom trainer that skips device placement, since device_map="auto" already handled it
    class CustomTrainer(Trainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def _move_model_to_device(self, model, device):
            pass

    # Initialize the loss-tracking callback
    loss_callback = LossCallback()

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[loss_callback]
    )

    try:
        trainer.train()

        # Score the trial by the mean of the last 20% of logged losses
        losses = loss_callback.losses
        n_losses = max(len(losses) // 5, 1)
        final_losses = losses[-n_losses:]
        mean_loss = np.mean(final_losses) if final_losses else float('inf')

        # Clean up
        del model
        del trainer
        clear_memory()

        return mean_loss

    except Exception as e:
        print(f"Trial failed with error: {e}")
        # Clean up on failure
        del model
        del trainer
        clear_memory()
        return float('inf')

# Create and run the study
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    study_name="learning_rate_optimization"
)

study.optimize(objective, n_trials=num_trials)

# Print results
print(f"\nOptimization Results ({num_trials} trials):")
print("Best learning rate:", study.best_params["learning_rate"])
print("Best loss:", study.best_value)
print("\nAll trials:")
for trial in study.trials:
    print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")

# Save results
results = {
    "best_learning_rate": study.best_params["learning_rate"],
    "best_loss": study.best_value,
    "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
}
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)

# Plot optimization history
try:
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()
except Exception as e:
    print(f"Could not create visualization: {e}")

# Refine the learning rate estimate with Gaussian Process Regression
def optimize_final_lr(study):
    try:
        # Extract learning rates and losses
        X = np.array([[trial.params['learning_rate']] for trial in study.trials])
        y = np.array([trial.value for trial in study.trials])

        # Check whether we have any valid results
        valid_mask = np.isfinite(y)
        if not np.any(valid_mask):
            print("No valid trials found. Returning default learning rate.")
            return {
                'gpr_optimal_lr': 2e-5,  # default fallback
                'ei_optimal_lr': 2e-5,
                'predicted_loss': float('inf'),
                'uncertainty': float('inf')
            }

        # Filter out infinite values
        X = X[valid_mask]
        y = y[valid_mask]

        # Ensure we have enough points for fitting
        if len(X) < 2:
            print("Not enough valid trials for GPR. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Transform to log space
        X_log = np.log10(X)

        # Normalize y values
        y_mean = np.mean(y)
        y_std = np.std(y)
        if y_std == 0:
            y_std = 1
        y_normalized = (y - y_mean) / y_std

        # Define kernel
        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)

        # Fit Gaussian Process
        gpr = GaussianProcessRegressor(
            kernel=kernel,
            n_restarts_optimizer=10,
            random_state=42,
            normalize_y=False  # we normalize manually above
        )

        try:
            gpr.fit(X_log, y_normalized)
        except np.linalg.LinAlgError:
            print("GPR fitting failed. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Create a fine grid of points for prediction
        X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)

        # Predict mean and std
        y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)

        # Denormalize predictions
        y_pred = y_pred_normalized * y_std + y_mean
        sigma = sigma * y_std

        # Find the point with the lowest predicted loss
        best_idx = np.argmin(y_pred)
        optimal_lr = 10 ** X_pred_log[best_idx, 0]

        # Calculate the acquisition function (Expected Improvement)
        best_f = np.min(y)
        Z = (best_f - y_pred) / (sigma + 1e-9)  # small constant prevents division by zero
        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))

        # Find the point with the highest expected improvement
        ei_best_idx = np.argmax(ei)
        ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]

        return {
            'gpr_optimal_lr': float(optimal_lr),
            'ei_optimal_lr': float(ei_optimal_lr),
            'predicted_loss': float(y_pred[best_idx]),
            'uncertainty': float(sigma[best_idx])
        }

    except Exception as e:
        print(f"Optimization failed with error: {e}")
        return {
            'gpr_optimal_lr': 2e-5,  # default fallback
            'ei_optimal_lr': 2e-5,
            'predicted_loss': float('inf'),
            'uncertainty': float('inf')
        }

# Run the final optimization and handle potential failures
try:
    final_optimization = optimize_final_lr(study)
    print("\nAdvanced Optimization Results:")
    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
except Exception as e:
    print(f"Final optimization failed: {e}")
    final_optimization = {
        'gpr_optimal_lr': 2e-5,
        'ei_optimal_lr': 2e-5,
        'predicted_loss': float('inf'),
        'uncertainty': float('inf')
    }

# Save extended results
results.update({
    "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
    "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
    "predicted_loss": float(final_optimization['predicted_loss']),
    "uncertainty": float(final_optimization['uncertainty'])
})

# Visualization of the GPR results
def plot_gpr_results(study, final_optimization):
    # Extract data and filter out infinite values
    X = np.array([[trial.params['learning_rate']] for trial in study.trials])
    y = np.array([trial.value for trial in study.trials])

    # Create mask for finite values
    finite_mask = np.isfinite(y)
    X = X[finite_mask]
    y = y[finite_mask]

    # Check whether we have enough valid points
    if len(X) < 2:
        print("Not enough valid points for GPR visualization")
        return

    # Create prediction points
    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
    X_pred_log = np.log10(X_pred)

    # Fit GPR for plotting
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)
    gpr.fit(np.log10(X), y)

    # Predict mean and std
    y_pred, sigma = gpr.predict(X_pred_log, return_std=True)

    plt.figure(figsize=(12, 6))
    plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
    plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
    plt.fill_between(X_pred.ravel(),
                     y_pred - 2 * sigma,
                     y_pred + 2 * sigma,
                     color='blue',
                     alpha=0.2,
                     label='95% Confidence')

    # Only plot optimal lines if they are finite
    if np.isfinite(final_optimization['gpr_optimal_lr']):
        plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
                    label='GPR Optimal LR')
    if np.isfinite(final_optimization['ei_optimal_lr']):
        plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
                    label='EI Optimal LR')

    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Optimization Results with GPR')
    plt.legend()
    plt.grid(True)
    plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
    plt.close()

plot_gpr_results(study, final_optimization)

# Save all results, now including the GPR refinement
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)
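
Usage note (not part of the commit): the two scripts are meant to work together. optimize_lr.py writes lr_optimization_results.json, and its refined learning rate can replace the hardcoded 2e-5 in finetune.py. A minimal sketch of that glue, assuming the JSON file exists:

import json

# Read the results written by optimize_lr.py (assumed to exist)
with open("lr_optimization_results.json") as f:
    results = json.load(f)

# Prefer the GPR-refined value, falling back to the best raw Optuna trial
lr = results.get("gpr_optimal_lr", results["best_learning_rate"])
print(f"Suggested learning rate: {lr:.2e}")
# ...then pass learning_rate=lr to TrainingArguments in finetune.py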