"""Learning-rate search for fine-tuning Zamba2-1.2B on Dutch Dolly chat data.

Runs an Optuna TPE search over the learning rate, then fits a Gaussian
Process to the (log10 lr, loss) observations to propose a refined learning
rate. Writes ``lr_optimization_results.json`` and ``lr_optimization_plot.png``.
"""
import argparse  # noqa: F401  (currently unused)
import gc
import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from datasets import load_dataset
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

warnings.filterwarnings('ignore', category=UserWarning)

# Configuration parameters
num_trials = 10  # Adjust this value to control the number of optimization trials
MODEL_NAME = "Zyphra/Zamba2-1.2B"
DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
CONTEXT_WINDOW = 1024
# Learning rate reported when the GPR analysis cannot produce an estimate.
DEFAULT_LR = 2e-5

# Initialize tokenizer once
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # NOTE(review): with pad == eos, DataCollatorForLanguageModeling(mlm=False)
    # presumably masks real EOS tokens out of the loss as well — confirm.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"


def prepare_chat_format(examples):
    """Tokenize a batch of chat conversations for causal-LM training.

    Args:
        examples: A batched dataset slice; ``examples['messages']`` is a list
            of conversations, each a list of dicts with ``"role"`` and
            ``"content"`` keys.

    Returns:
        ``{"input_ids": [...]}`` with one token-id list per conversation,
        truncated to ``CONTEXT_WINDOW`` tokens. If the tokenizer's chat
        template fails, a plain ``<|role|>\\ncontent\\n`` rendering is used.
    """
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None,
            )
        except Exception as e:
            print(f"Error applying chat template: {e}")
            print("Fallback format if chat template fails")
            text = "".join(
                f"<|{message['role']}|>\n{message['content']}\n" for message in messages
            )
            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None,
            )["input_ids"]
        chats.append(chat)
    return {"input_ids": chats}


# Prepare dataset once
tokenized_dataset = DATASET.map(
    prepare_chat_format,
    batched=True,
    remove_columns=DATASET.column_names,
)


def clear_memory():
    """Clear GPU memory between trials.

    Garbage is collected first so that tensors reachable only through
    reference cycles are actually freed before the CUDA caching allocator is
    asked to release its unused blocks.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class LossCallback(TrainerCallback):
    """Collect every training-loss value the Trainer logs."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record ``logs["loss"]`` whenever it is present in a log event."""
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])


class CustomTrainer(Trainer):
    """Trainer whose device-placement step is a no-op.

    The model is loaded with ``device_map="auto"``, so it is already placed;
    ``_move_model_to_device`` is overridden to leave it where it is.
    """

    def _move_model_to_device(self, model, device):
        pass


def _tail_mean_loss(losses):
    """Return the mean of the last 20% (at least one) of ``losses``.

    Returns ``inf`` when nothing was logged, so such a trial ranks last.
    """
    if not losses:
        return float('inf')
    n_tail = max(len(losses) // 5, 1)
    return float(np.mean(losses[-n_tail:]))


def objective(trial):
    """Fine-tune a fresh model for one epoch at a sampled learning rate.

    Args:
        trial: Optuna trial; ``learning_rate`` is sampled log-uniformly from
            [1e-6, 1e-4].

    Returns:
        Mean training loss over the last 20% of logged steps, or ``inf`` if
        model loading or training raises. Failures are reported as ``inf``
        so they rank last instead of aborting the whole study.
    """
    # Clear memory from previous trial
    clear_memory()

    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)

    model = None
    trainer = None
    try:
        # Initialize model with fresh state
        torch.manual_seed(42)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model.config.pad_token_id = tokenizer.pad_token_id

        # Effective batch of 32 sequences per optimizer step (4 x 8 accumulation).
        batch_size = 4
        grad_accum_steps = 8
        effective_batch_size = batch_size * grad_accum_steps
        total_steps = len(tokenized_dataset) // effective_batch_size

        training_args = TrainingArguments(
            output_dir=f"./optuna_runs/trial_{trial.number}",
            num_train_epochs=1,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=grad_accum_steps,
            logging_steps=max(total_steps // 20, 1),
            learning_rate=lr,
            weight_decay=0.01,
            fp16=False,
            bf16=True,
            warmup_steps=total_steps // 10,
            save_steps=1000000,  # effectively disables intermediate checkpoints
            save_total_limit=None,
            report_to="none",
            seed=42,
            dataloader_num_workers=4,
            gradient_checkpointing=True,  # trade compute for activation memory
            max_grad_norm=1.0,
        )

        print(f"\nTrial {trial.number}:")
        print(f"Learning rate: {lr}")
        print(f"Total steps: {total_steps}")
        print(f"Logging every {training_args.logging_steps} steps")

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        loss_callback = LossCallback()

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            callbacks=[loss_callback],
        )
        trainer.train()
        return _tail_mean_loss(loss_callback.losses)
    except Exception as e:
        print(f"Trial failed with error: {e}")
        return float('inf')
    finally:
        # Drop the last references before clearing, or nothing can be freed.
        del model, trainer
        clear_memory()


def _completed_trials(study):
    """Return the study's trials that finished and reported a value."""
    return [
        t for t in study.trials
        if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None
    ]


def _json_float(value):
    """Return ``value`` as a float, or ``None`` for None/NaN/inf.

    ``json.dump`` would otherwise emit ``Infinity``/``NaN``, which are not
    valid JSON and are rejected by strict parsers.
    """
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


# Create and run the study
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    study_name="learning_rate_optimization",
)
study.optimize(objective, n_trials=num_trials)
completed_trials = _completed_trials(study)

# Print results
print(f"\nOptimization Results ({num_trials} trials):")
print("Best learning rate:", study.best_params["learning_rate"])
print("Best loss:", study.best_value)
print("\nAll trials:")
for t in completed_trials:
    print(f"Learning rate: {t.params['learning_rate']:.2e}, Loss: {t.value:.4f}")

# Save results (written now so they survive a failure in the GPR stage)
results = {
    "best_learning_rate": study.best_params["learning_rate"],
    "best_loss": _json_float(study.best_value),
    "all_trials": [
        (t.params["learning_rate"], _json_float(t.value)) for t in completed_trials
    ],
}
with open("lr_optimization_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)

# Plot optimization history
try:
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()
except Exception as e:
    print(f"Could not create visualization: {e}")


def _finite_lr_observations(study):
    """Return ``(X, y)`` for completed trials with a finite loss.

    ``X`` is an (n, 1) float array of learning rates and ``y`` the matching
    (n,) float array of losses; both are empty when no trial qualifies.
    """
    trials = [t for t in _completed_trials(study) if np.isfinite(t.value)]
    X = np.array([t.params['learning_rate'] for t in trials], dtype=float).reshape(-1, 1)
    y = np.array([t.value for t in trials], dtype=float)
    return X, y


def _fit_gpr(X_log, y):
    """Fit a Matern(nu=2.5) GP to standardized losses over log10(lr).

    Returns a ``predict(X_query_log) -> (mean, std)`` function whose outputs
    are in the original loss units. ``gpr.fit`` may raise
    ``numpy.linalg.LinAlgError`` on ill-conditioned inputs.
    """
    y_mean = np.mean(y)
    y_std = np.std(y)
    if y_std == 0:
        y_std = 1
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    gpr = GaussianProcessRegressor(
        kernel=kernel,
        n_restarts_optimizer=10,
        random_state=42,
        normalize_y=False,  # y is standardized manually above
    )
    gpr.fit(X_log, (y - y_mean) / y_std)

    def predict(X_query_log):
        mean_norm, std_norm = gpr.predict(X_query_log, return_std=True)
        return mean_norm * y_std + y_mean, std_norm * y_std

    return predict


def _fallback_lr_result():
    """Result reported when no learning rate can be estimated from trials."""
    return {
        'gpr_optimal_lr': DEFAULT_LR,
        'ei_optimal_lr': DEFAULT_LR,
        'predicted_loss': float('inf'),
        'uncertainty': float('inf'),
    }


def _best_observed_result(X, y):
    """Result built from the lowest observed loss (no GP estimate available)."""
    best = int(np.argmin(y))
    best_lr = float(X[best, 0])
    return {
        'gpr_optimal_lr': best_lr,
        'ei_optimal_lr': best_lr,
        'predicted_loss': float(y[best]),
        'uncertainty': float('inf'),
    }


def optimize_final_lr(study):
    """Propose a refined learning rate from the study via GP regression.

    A GP is fit to (log10 lr, loss) over the finite trial results and
    evaluated on 1000 log-spaced points between the smallest and largest
    tried learning rate.

    Returns:
        dict with ``gpr_optimal_lr`` (argmin of the GP mean),
        ``ei_optimal_lr`` (argmax of Expected Improvement),
        ``predicted_loss`` and ``uncertainty`` (GP mean/std at the GP optimum).
        Falls back to the best observed trial when fewer than two finite
        results exist or the fit fails, and to ``DEFAULT_LR`` when there are
        none or anything else goes wrong.
    """
    try:
        X, y = _finite_lr_observations(study)
        if len(y) == 0:
            print("No valid trials found. Returning default learning rate.")
            return _fallback_lr_result()
        if len(y) < 2:
            print("Not enough valid trials for GPR. Returning best observed value.")
            return _best_observed_result(X, y)

        # Learning rates span orders of magnitude; model them in log space.
        X_log = np.log10(X)
        try:
            predict = _fit_gpr(X_log, y)
        except np.linalg.LinAlgError:
            print("GPR fitting failed. Returning best observed value.")
            return _best_observed_result(X, y)

        X_pred_log = np.linspace(X_log.min(), X_log.max(), 1000).reshape(-1, 1)
        y_pred, sigma = predict(X_pred_log)

        # Point with the lowest predicted loss
        best_idx = int(np.argmin(y_pred))
        optimal_lr = 10 ** X_pred_log[best_idx, 0]

        # Expected Improvement (minimization): E[max(best_f - f(x), 0)]
        best_f = np.min(y)
        Z = (best_f - y_pred) / (sigma + 1e-9)  # epsilon avoids division by zero
        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))
        ei_optimal_lr = 10 ** X_pred_log[int(np.argmax(ei)), 0]

        return {
            'gpr_optimal_lr': float(optimal_lr),
            'ei_optimal_lr': float(ei_optimal_lr),
            'predicted_loss': float(y_pred[best_idx]),
            'uncertainty': float(sigma[best_idx]),
        }
    except Exception as e:
        print(f"Optimization failed with error: {e}")
        return _fallback_lr_result()


# Run final optimization and handle potential failures
try:
    final_optimization = optimize_final_lr(study)
    print("\nAdvanced Optimization Results:")
    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
except Exception as e:
    print(f"Final optimization failed: {e}")
    final_optimization = _fallback_lr_result()

# Save extended results
results.update({
    "gpr_optimal_lr": _json_float(final_optimization['gpr_optimal_lr']),
    "ei_optimal_lr": _json_float(final_optimization['ei_optimal_lr']),
    "predicted_loss": _json_float(final_optimization['predicted_loss']),
    "uncertainty": _json_float(final_optimization['uncertainty']),
})


def plot_gpr_results(study, final_optimization):
    """Save the trial losses and GP fit to ``lr_optimization_plot.png``.

    Uses the same GP model as ``optimize_final_lr``, so the plotted mean
    agrees with the reported optimum. Does nothing (besides a message) when
    fewer than two finite trial results exist.
    """
    X, y = _finite_lr_observations(study)
    if len(X) < 2:
        print("Not enough valid points for GPR visualization")
        return

    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
    y_pred, sigma = _fit_gpr(np.log10(X), y)(np.log10(X_pred))

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
        plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
        plt.fill_between(
            X_pred.ravel(),
            y_pred - 2 * sigma,
            y_pred + 2 * sigma,
            color='blue',
            alpha=0.2,
            label='95% Confidence',
        )

        # Only plot optimal lines if they are finite
        if np.isfinite(final_optimization['gpr_optimal_lr']):
            plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
                        label='GPR Optimal LR')
        if np.isfinite(final_optimization['ei_optimal_lr']):
            plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
                        label='EI Optimal LR')

        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.title('Learning Rate Optimization Results with GPR')
        plt.legend()
        plt.grid(True)
        plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)


# A plotting failure must not prevent the final results from being saved.
try:
    plot_gpr_results(study, final_optimization)
except Exception as e:
    print(f"Could not create GPR plot: {e}")

# Save all results
with open("lr_optimization_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)