import optuna
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling, TrainerCallback
)
import torch
from datasets import load_dataset
import numpy as np
import gc
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
import matplotlib.pyplot as plt
from scipy.stats import norm
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

import argparse
import json

# Configuration parameters
num_trials = 10  # Adjust this value to control the number of optimization trials
DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
CONTEXT_WINDOW = 1024

# Initialize tokenizer once
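# Zamba2's tokenizer may not define a dedicated pad token; in that case the EOS
# token is reused for padding so batches of variable-length inputs can be built.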
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

def prepare_chat_format(examples):
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )
            chats.append(chat)
        except Exception as e:
            print(f"Error applying chat template: {e}")
            print("Fallback format if chat template fails")
            text = ""
            for message in messages:
                role = message["role"]
                content = message["content"]
                text += f"<|{role}|>\n{content}</s>\n"
            
            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )["input_ids"]
            
            chats.append(chat)
    return {"input_ids": chats}

# Prepare dataset once
tokenized_dataset = DATASET.map(
    prepare_chat_format,
    batched=True,
    remove_columns=DATASET.column_names
)

def clear_memory():
    """Clear GPU memory between trials"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

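# Minimal callback that records the running training loss each time the Trainer
# logs (every `logging_steps` optimizer steps); the objective below averages
# the tail of this list as its score.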
class LossCallback(TrainerCallback):
    def __init__(self):
        self.losses = []
        
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])

def objective(trial):
    # Clear memory from previous trial
    clear_memory()
    
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    
    # Initialize model with fresh state
    torch.manual_seed(42)
    model = AutoModelForCausalLM.from_pretrained(
        "Zyphra/Zamba2-1.2B",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    model.config.pad_token_id = tokenizer.pad_token_id

    # Per-device batch of 4 with 8 gradient-accumulation steps gives an
    # effective batch size of 32
    batch_size = 4
    grad_accum_steps = 8
    effective_batch_size = batch_size * grad_accum_steps
    total_steps = len(tokenized_dataset) // effective_batch_size
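    # With the 1,000-example split above this gives 1000 // 32 = 31 optimizer
    # steps, so warmup_steps works out to ~3 and a loss is logged about every
    # step; these numbers change if DATASET or the batch settings change.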
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./optuna_runs/trial_{trial.number}",
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        logging_steps=max(total_steps // 20, 1),
        learning_rate=lr,
        weight_decay=0.01,
        fp16=False,
        bf16=True,
        warmup_steps=total_steps // 10,
        save_steps=1000000,
        save_total_limit=None,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,  # Added for faster data loading
        gradient_checkpointing=True,  # Added to optimize memory usage
        max_grad_norm=1.0  # Added for stability
    )

    print(f"\nTrial {trial.number}:")
    print(f"Learning rate: {lr}")
    print(f"Total steps: {total_steps}")
    print(f"Logging every {training_args.logging_steps} steps")

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )
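    # With mlm=False the collator prepares causal-LM batches: it pads input_ids
    # to a common length and copies them into labels, setting padded positions
    # to -100 so they are ignored by the loss.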

    class CustomTrainer(Trainer):
        """Trainer that skips device placement: the model has already been
        dispatched across available devices by device_map="auto"."""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def _move_model_to_device(self, model, device):
            # No-op: moving an accelerate-dispatched model would break its placement
            pass

    # Initialize callback
    loss_callback = LossCallback()

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[loss_callback]  # Use the proper callback
    )

    try:
        train_result = trainer.train()
        
        # Calculate mean of last 20% of losses
        losses = loss_callback.losses  # Get losses from callback
        n_losses = max(len(losses) // 5, 1)
        final_losses = losses[-n_losses:]
        mean_loss = np.mean(final_losses) if final_losses else float('inf')
        
        # Clean up
        del model
        del trainer
        clear_memory()
        
        return mean_loss
    
    except Exception as e:
        print(f"Trial failed with error: {e}")
        # Clean up on failure
        del model
        del trainer
        clear_memory()
        return float('inf')

# Create and run the study
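# TPESampler with a fixed seed keeps the search reproducible; learning rates
# are sampled log-uniformly over [1e-6, 1e-4] via suggest_float(..., log=True).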
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    study_name="learning_rate_optimization"
)

study.optimize(objective, n_trials=num_trials)

# Print results
print(f"\nOptimization Results ({num_trials} trials):")
print("Best learning rate:", study.best_params["learning_rate"])
print("Best loss:", study.best_value)
print("\nAll trials:")
for trial in study.trials:
    print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")

# Save results
results = {
    "best_learning_rate": study.best_params["learning_rate"],
    "best_loss": study.best_value,
    "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
}
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)

# Plot optimization history
try:
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()
except Exception as e:
    print(f"Could not create visualization: {e}")

# Add sophisticated final optimization using Gaussian Process Regression
def optimize_final_lr(study):
    try:
        # Extract learning rates and losses
        X = np.array([[trial.params['learning_rate']] for trial in study.trials])
        y = np.array([trial.value for trial in study.trials])
        
        # Check if we have any valid results
        valid_mask = np.isfinite(y)
        if not np.any(valid_mask):
            print("No valid trials found. Returning default learning rate.")
            return {
                'gpr_optimal_lr': 2e-5,  # default fallback
                'ei_optimal_lr': 2e-5,
                'predicted_loss': float('inf'),
                'uncertainty': float('inf')
            }
        
        # Filter out infinite values
        X = X[valid_mask]
        y = y[valid_mask]
        
        # Ensure we have enough points for fitting
        if len(X) < 2:
            print("Not enough valid trials for GPR. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }
        
        # Transform to log space
        X_log = np.log10(X)
        
        # Normalize y values
        y_mean = np.mean(y)
        y_std = np.std(y)
        if y_std == 0:
            y_std = 1
        y_normalized = (y - y_mean) / y_std
        
        # Define kernel
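        # Matern(nu=2.5) over log10(learning rate) is a common choice for
        # hyperparameter response surfaces: smoother than nu=1.5 but less
        # rigid than a pure RBF kernel.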
        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
        
        # Fit Gaussian Process
        gpr = GaussianProcessRegressor(
            kernel=kernel,
            n_restarts_optimizer=10,
            random_state=42,
            normalize_y=False  # we're manually normalizing
        )
        
        try:
            gpr.fit(X_log, y_normalized)
        except np.linalg.LinAlgError:
            print("GPR fitting failed. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }
        
        # Create fine grid of points for prediction
        X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)
        
        # Predict mean and std
        y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)
        
        # Denormalize predictions
        y_pred = y_pred_normalized * y_std + y_mean
        sigma = sigma * y_std
        
        # Find the point with lowest predicted value
        best_idx = np.argmin(y_pred)
        optimal_lr = 10 ** X_pred_log[best_idx, 0]
        
        # Calculate acquisition function (Expected Improvement)
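        # For minimization, EI(x) = sigma(x) * (Z * Phi(Z) + phi(Z)) with
        # Z = (f_best - mu(x)) / sigma(x), where Phi and phi are the standard
        # normal CDF and PDF; the lines below evaluate this on the prediction grid.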
        best_f = np.min(y)
        Z = (best_f - y_pred) / (sigma + 1e-9)  # add small constant to prevent division by zero
        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))
        
        # Find point with highest expected improvement
        ei_best_idx = np.argmax(ei)
        ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]
        
        return {
            'gpr_optimal_lr': float(optimal_lr),
            'ei_optimal_lr': float(ei_optimal_lr),
            'predicted_loss': float(y_pred[best_idx]),
            'uncertainty': float(sigma[best_idx])
        }
        
    except Exception as e:
        print(f"Optimization failed with error: {e}")
        return {
            'gpr_optimal_lr': 2e-5,  # default fallback
            'ei_optimal_lr': 2e-5,
            'predicted_loss': float('inf'),
            'uncertainty': float('inf')
        }

# Run final optimization and handle potential failures
try:
    final_optimization = optimize_final_lr(study)
    print("\nAdvanced Optimization Results:")
    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
except Exception as e:
    print(f"Final optimization failed: {e}")
    final_optimization = {
        'gpr_optimal_lr': 2e-5,
        'ei_optimal_lr': 2e-5,
        'predicted_loss': float('inf'),
        'uncertainty': float('inf')
    }

# Save extended results
results.update({
    "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
    "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
    "predicted_loss": float(final_optimization['predicted_loss']),
    "uncertainty": float(final_optimization['uncertainty'])
})

# Visualization of the GPR results
def plot_gpr_results(study, final_optimization):
    # Extract data and filter out infinite values
    X = np.array([[trial.params['learning_rate']] for trial in study.trials])
    y = np.array([trial.value for trial in study.trials])
    
    # Create mask for finite values
    finite_mask = np.isfinite(y)
    X = X[finite_mask]
    y = y[finite_mask]
    
    # Check if we have enough valid points
    if len(X) < 2:
        print("Not enough valid points for GPR visualization")
        return
    
    # Create prediction points
    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
    X_pred_log = np.log10(X_pred)
    
    # Fit GPR for plotting
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10,
                                   random_state=42, normalize_y=True)
    gpr.fit(np.log10(X), y)
    
    # Predict mean and std
    y_pred, sigma = gpr.predict(X_pred_log, return_std=True)
    
    plt.figure(figsize=(12, 6))
    plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
    plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
    plt.fill_between(X_pred.ravel(),
                     y_pred - 2*sigma,
                     y_pred + 2*sigma,
                     color='blue',
                     alpha=0.2,
                     label='95% Confidence')
    
    # Only plot optimal lines if they are finite
    if np.isfinite(final_optimization['gpr_optimal_lr']):
        plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--', 
                    label='GPR Optimal LR')
    if np.isfinite(final_optimization['ei_optimal_lr']):
        plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--', 
                    label='EI Optimal LR')
    
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Optimization Results with GPR')
    plt.legend()
    plt.grid(True)
    plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
    plt.close()

plot_gpr_results(study, final_optimization)

# Save all results
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)

# Expose the best learning rate for downstream use (e.g. by importing this
# module from finetune.py); it is also persisted in lr_optimization_results.json
best_lr = study.best_params["learning_rate"]
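# A downstream script such as finetune.py (not part of this file) could pick up
# the result either by importing best_lr from this module or by reading the
# JSON artifact, e.g.:
#
#     import json
#     with open("lr_optimization_results.json") as f:
#         best_lr = json.load(f)["best_learning_rate"]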