|
# Title: Reinforcement Learning for Dynamic Learning Rate Adaptation in Transformer Training |
|
# Experiment description:

1. Implement a simple RL method (e.g., Q-learning) that takes the current state (e.g., validation loss, current learning rate) and determines the adjustment to the learning rate (a minimal sketch of such an agent follows this list).
2. Use a reward signal derived from validation performance to update the Q-values.
3. Modify the training loop so that the RL agent adjusts the learning rate at each evaluation interval.
4. Compare training dynamics, convergence speed, and final performance with a baseline model that uses static or heuristic-based learning rate schedules on multiple datasets (shakespeare_char, enwik8, text8).
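A minimal sketch of the kind of agent step 1 describes is shown below. It assumes a tabular Q-learning agent over a coarsely discretized (validation loss, learning rate) state and a small set of multiplicative learning-rate adjustments; the class and method names are illustrative, not taken from the experiment code.

```python
import math
import random
from collections import defaultdict


class QLearningLRAgent:
    """Tabular Q-learning over a discretized (val_loss, lr) state space.

    Illustrative sketch only; the actual experiment code may differ.
    """

    # Multiplicative adjustments applied to the current learning rate.
    ACTIONS = (0.5, 0.9, 1.0, 1.1, 2.0)

    def __init__(self, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.alpha = alpha      # Q-value step size
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration probability
        self.q = defaultdict(lambda: [0.0] * len(self.ACTIONS))

    def _state(self, val_loss, lr):
        # Coarse discretization keeps the Q-table small.
        return (round(val_loss, 1), round(math.log10(lr), 1))

    def select_action(self, val_loss, lr):
        """Return (action index, adjusted learning rate), epsilon-greedily."""
        s = self._state(val_loss, lr)
        if random.random() < self.epsilon:
            a = random.randrange(len(self.ACTIONS))
        else:
            a = max(range(len(self.ACTIONS)), key=lambda i: self.q[s][i])
        return a, lr * self.ACTIONS[a]

    def update(self, prev_obs, action, reward, new_obs):
        """Standard Q-learning backup; obs are (val_loss, lr) pairs."""
        s, s_next = self._state(*prev_obs), self._state(*new_obs)
        td_target = reward + self.gamma * max(self.q[s_next])
        self.q[s][action] += self.alpha * (td_target - self.q[s][action])
```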
|
## Run 0: Baseline |
|
Results: {'shakespeare_char': {'final_train_loss_mean': 0.8186181902885437, 'best_val_loss_mean': 1.4654763221740723, 'total_train_time_mean': 77.26942734718322, 'avg_inference_tokens_per_second_mean': 666.5076153519527}, 'enwik8': {'final_train_loss_mean': 0.930223822593689, 'best_val_loss_mean': 1.0055421590805054, 'total_train_time_mean': 819.4551751613617, 'avg_inference_tokens_per_second_mean': 671.9918599180683}, 'text8': {'final_train_loss_mean': 1.0013301372528076, 'best_val_loss_mean': 0.979989230632782, 'total_train_time_mean': 801.224205493927, 'avg_inference_tokens_per_second_mean': 671.5678332249411}} |
|
Description: Baseline results using the default static/heuristic learning rate schedule on shakespeare_char, enwik8, and text8.
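The exact baseline schedule is not recorded here; a common heuristic in this kind of character-level transformer setup is linear warmup followed by cosine decay, sketched below under that assumption.

```python
import math


def heuristic_lr(it, max_lr=1e-3, min_lr=1e-4, warmup_iters=100, decay_iters=5000):
    """Linear warmup then cosine decay; an assumed baseline, not the exact schedule used."""
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    if it > decay_iters:
        return min_lr
    ratio = (it - warmup_iters) / (decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # decays from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```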
|
|
|
Plot Descriptions: |
|
1. Training Loss Across Runs for shakespeare_char Dataset: training loss over iterations for each run, showing how it decreases over training. Filename: train_loss_shakespeare_char.png

2. Validation Loss Across Runs for shakespeare_char Dataset: validation loss over iterations for each run, showing how it decreases over training. Filename: val_loss_shakespeare_char.png

3. Training Loss Across Runs for enwik8 Dataset: training loss over iterations for each run, showing how it decreases over training. Filename: train_loss_enwik8.png

4. Validation Loss Across Runs for enwik8 Dataset: validation loss over iterations for each run, showing how it decreases over training. Filename: val_loss_enwik8.png

5. Training Loss Across Runs for text8 Dataset: training loss over iterations for each run, showing how it decreases over training. Filename: train_loss_text8.png

6. Validation Loss Across Runs for text8 Dataset: validation loss over iterations for each run, showing how it decreases over training. Filename: val_loss_text8.png
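The plots above can be generated with a simple matplotlib loop; the sketch below assumes per-run loss histories are available in memory (a data layout chosen here for illustration) and uses the filenames listed above as save targets.

```python
import matplotlib.pyplot as plt


def plot_losses(histories, dataset, split, out_file):
    """histories: {run_name: (iterations, losses)} -- assumed layout, for illustration only."""
    plt.figure()
    for run_name, (iters, losses) in histories.items():
        plt.plot(iters, losses, label=run_name)
    plt.xlabel("iteration")
    plt.ylabel(f"{split} loss")
    plt.title(f"{split.capitalize()} loss across runs ({dataset})")
    plt.legend()
    plt.savefig(out_file)
    plt.close()


# e.g. plot_losses(val_histories, "shakespeare_char", "val", "val_loss_shakespeare_char.png")
```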
|
|
|
## Run 1: Q-learning with initial learning rate adaptation |
|
Results: {'shakespeare_char': {'final_train_loss_mean': 0.8112714489301046, 'best_val_loss_mean': 1.4664853016535442, 'total_train_time_mean': 76.33582202593486, 'avg_inference_tokens_per_second_mean': 680.220956113138}, 'enwik8': {'final_train_loss_mean': 0.9324554204940796, 'best_val_loss_mean': 1.0050768852233887, 'total_train_time_mean': 799.199625492096, 'avg_inference_tokens_per_second_mean': 690.1664700419294}, 'text8': {'final_train_loss_mean': 0.9926028251647949, 'best_val_loss_mean': 0.9795507192611694, 'total_train_time_mean': 796.1075961589813, 'avg_inference_tokens_per_second_mean': 691.9504174462957}} |
|
Description: This run implemented Q-learning for dynamic learning rate adaptation. At each evaluation interval, the agent observed the current state (validation loss, current learning rate), chose an adjustment to the learning rate, and updated its Q-values with a reward derived from validation performance. Training dynamics, convergence speed, and final performance were compared against the baseline's static/heuristic learning rate schedules on shakespeare_char, enwik8, and text8.
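A rough sketch of how such an agent could be wired into the training loop, reusing the QLearningLRAgent sketch from the experiment description above. The helpers `train_step` and `estimate_val_loss` are hypothetical, the optimizer is assumed to expose PyTorch-style `param_groups`, and the reward shown is the negative validation loss (the variant that Run 3 later replaces).

```python
def train_with_rl_lr(model, optimizer, max_iters, eval_interval, base_lr,
                     train_step, estimate_val_loss):
    """Hypothetical training loop; `train_step` and `estimate_val_loss` are assumed callables."""
    agent = QLearningLRAgent()
    lr = base_lr
    prev_obs, prev_action = None, None

    for it in range(max_iters):
        # Apply the agent's current learning rate before each optimizer step.
        for group in optimizer.param_groups:
            group["lr"] = lr
        train_step(model, optimizer)

        if it % eval_interval == 0:
            val_loss = estimate_val_loss(model)
            reward = -val_loss  # Runs 1-2: negative validation loss as the reward
            obs = (val_loss, lr)
            if prev_obs is not None:
                agent.update(prev_obs, prev_action, reward, obs)
            prev_action, lr = agent.select_action(val_loss, lr)
            prev_obs = obs
    return model
```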
|
|
|
## Run 2: Q-learning with different initial learning rates |
|
Results: {'shakespeare_char': {'final_train_loss_mean': 0.8047561645507812, 'best_val_loss_mean': 1.4602874914805095, 'total_train_time_mean': 76.26222737630208, 'avg_inference_tokens_per_second_mean': 675.5019470493302}, 'enwik8': {'final_train_loss_mean': 0.9224221706390381, 'best_val_loss_mean': 0.9933806657791138, 'total_train_time_mean': 806.1875951290131, 'avg_inference_tokens_per_second_mean': 682.6881990162254}, 'text8': {'final_train_loss_mean': 0.9798105955123901, 'best_val_loss_mean': 0.9613448977470398, 'total_train_time_mean': 807.7686207294464, 'avg_inference_tokens_per_second_mean': 652.3187905322042}} |
|
Description: This run used the same Q-learning setup as Run 1 but with dataset-specific initial learning rates: 2e-3 for shakespeare_char and 1e-3 for enwik8 and text8. The agent again adjusted the learning rate at each evaluation interval based on the current state (validation loss, current learning rate) and updated its Q-values with a reward derived from validation performance; results were compared against the baseline on all three datasets.
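The per-dataset initial learning rates could be expressed as a small lookup, using the values stated above; the structure itself is illustrative, not the actual configuration code.

```python
# Run 2: per-dataset initial learning rates (values from the run description).
INITIAL_LR = {
    "shakespeare_char": 2e-3,
    "enwik8": 1e-3,
    "text8": 1e-3,
}


def initial_lr_for(dataset_name: str) -> float:
    """Return the starting learning rate handed to the Q-learning agent."""
    return INITIAL_LR[dataset_name]
```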
|
|
|
## Run 3: Q-learning with reward signal based on improvement in validation loss |
|
Results: {'shakespeare_char': {'final_train_loss_mean': 0.8062439958254496, 'best_val_loss_mean': 1.461962143580119, 'total_train_time_mean': 75.80110216140747, 'avg_inference_tokens_per_second_mean': 668.3102066342188}, 'enwik8': {'final_train_loss_mean': 0.9246289730072021, 'best_val_loss_mean': 0.9944368004798889, 'total_train_time_mean': 796.9592888355255, 'avg_inference_tokens_per_second_mean': 688.6266631351763}, 'text8': {'final_train_loss_mean': 0.9843199849128723, 'best_val_loss_mean': 0.961367666721344, 'total_train_time_mean': 791.6123127937317, 'avg_inference_tokens_per_second_mean': 658.961942825521}} |
|
Description: This run changed the reward signal: instead of the negative validation loss, the reward was the improvement in validation loss between evaluations. The Q-learning agent otherwise operated as in the previous runs, adjusting the learning rate at each evaluation interval based on the current state (validation loss, current learning rate); results were compared against the baseline on all three datasets.
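Assuming the improvement is measured as the difference between consecutive validation losses, the reward change amounts to something like the following sketch.

```python
# Run 3: reward based on improvement in validation loss rather than its negative value.
def improvement_reward(prev_val_loss: float, val_loss: float) -> float:
    # Positive when validation loss decreased since the last evaluation (assumed formulation).
    return prev_val_loss - val_loss
```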
|
|
|
## Run 4: Q-learning with epsilon decay strategy |
|
Results: {'shakespeare_char': {'final_train_loss_mean': 0.7984780073165894, 'best_val_loss_mean': 1.463551680246989, 'total_train_time_mean': 79.24612506230672, 'avg_inference_tokens_per_second_mean': 617.9132836431749}, 'enwik8': {'final_train_loss_mean': 0.925983190536499, 'best_val_loss_mean': 0.9917866587638855, 'total_train_time_mean': 852.1484353542328, 'avg_inference_tokens_per_second_mean': 605.0617699125265}, 'text8': {'final_train_loss_mean': 0.9827583432197571, 'best_val_loss_mean': 0.9615200161933899, 'total_train_time_mean': 846.4471461772919, 'avg_inference_tokens_per_second_mean': 613.2623906747798}} |
|
Description: This run replaced the fixed exploration rate with an epsilon decay strategy, in which epsilon decreases over the course of training. The Q-learning agent otherwise operated as before, adjusting the learning rate at each evaluation interval based on the current state (validation loss, current learning rate) and updating its Q-values with a reward derived from validation performance; results were compared against the baseline on all three datasets.
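The exact decay schedule is not specified here; a typical multiplicative decay with a floor might look like the sketch below.

```python
# Run 4: epsilon decay -- exploration probability shrinks at each evaluation step.
def decayed_epsilon(step: int, eps_start: float = 1.0, eps_min: float = 0.01,
                    decay: float = 0.95) -> float:
    # Multiplicative decay floored at eps_min; the exact schedule is an assumption.
    return max(eps_min, eps_start * decay ** step)
```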
|
|