70b Distillation Experiment
This is not the full-fledged run I plan to do for a large-scale distillation of Llama3 70b. Instead, it's a preliminary test run of the custom distillation trainer, which uses a KL divergence objective to distill the larger Llama3 70b teacher model into a 4x8b student. I'm releasing it here mainly so that anyone interested can tinker with it or finetune it to see how it behaves before I'm ready to do a larger run.
Training details
The MLP layers of the original Llama3 8b are duplicated 3x to form the experts (4 per layer, hence 4x8b) in a typical Mixtral-style Sparse MoE layout.
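A minimal sketch of this kind of "upcycling", in plain Python with toy dimensions. It assumes Llama-style SwiGLU MLPs and Mixtral-style routing (softmax over the top_k router logits); the helper names (`llama_mlp`, `upcycle`, `moe_forward`) and the random router init are hypothetical and not taken from the actual trainer. Because every expert starts as an exact copy and the routing weights sum to 1, the MoE layer initially computes the same function as the dense MLP for any top_k.

```python
import copy
import math
import random


def matvec(weight, x):
    # weight is a list of rows; returns weight @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def silu(v):
    return v / (1.0 + math.exp(-v))


def llama_mlp(params, x):
    # Llama-style SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x))
    g = matvec(params["gate_proj"], x)
    u = matvec(params["up_proj"], x)
    return matvec(params["down_proj"], [silu(a) * b for a, b in zip(g, u)])


def upcycle(dense_mlp, num_experts=4):
    # The original MLP plus (num_experts - 1) exact duplicates.
    return [copy.deepcopy(dense_mlp) for _ in range(num_experts)]


def moe_forward(experts, router, x, top_k):
    # Assumed Mixtral-style routing: take the top_k router logits,
    # softmax over just those, and mix the chosen experts' outputs.
    logits = matvec(router, x)
    chosen = sorted(range(len(experts)), key=lambda i: logits[i], reverse=True)[:top_k]
    m = max(logits[i] for i in chosen)
    weights = {i: math.exp(logits[i] - m) for i in chosen}
    total = sum(weights.values())
    out = [0.0] * len(x)
    for i in chosen:
        y = llama_mlp(experts[i], x)
        out = [o + (weights[i] / total) * yi for o, yi in zip(out, y)]
    return out


def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
hidden, inter, n_experts = 4, 6, 4
dense = {
    "gate_proj": rand_matrix(rng, inter, hidden),
    "up_proj": rand_matrix(rng, inter, hidden),
    "down_proj": rand_matrix(rng, hidden, inter),
}
experts = upcycle(dense, n_experts)
router = rand_matrix(rng, n_experts, hidden)  # gate weights; random init is an assumption
x = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
dense_out = llama_mlp(dense, x)
moe_outs = {k: moe_forward(experts, router, x, k) for k in range(1, n_experts + 1)}
```

The experts only start to differ from one another once training updates them, which is why the routing schedule during training matters.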
Over the course of the training run, the number of selected experts was gradually increased from the minimum (topk=1) to the maximum (topk=4), as in the paper "Sparse MoE as the New Dropout". Following that paper's recommendation, top_k expert selection was stochastic / randomized and the gate layers were kept frozen.
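One way such a schedule could look is sketched below. This is an assumption, not the trainer's actual code: the expected top_k ramps linearly from 1 to 4 over training, and stochastic rounding picks an integer top_k around that target at each step. The function name `sample_top_k` and its parameters are hypothetical. Freezing the gates is not shown; in PyTorch it would typically be done with something like `gate.requires_grad_(False)` on each router module.

```python
import math
import random


def sample_top_k(step, total_steps, k_min=1, k_max=4, rng=random):
    """Hypothetical stochastic top_k schedule.

    The expected top_k ramps linearly from k_min to k_max over training;
    stochastic rounding picks floor(target) or floor(target) + 1 each step.
    """
    progress = min(max(step / max(total_steps - 1, 1), 0.0), 1.0)
    target = k_min + progress * (k_max - k_min)
    base = math.floor(target)
    k = base + (1 if rng.random() < target - base else 0)
    return min(max(k, k_min), k_max)


rng = random.Random(0)
total = 1000
ks = [sample_top_k(s, total, rng=rng) for s in range(total)]
print(ks[0], ks[-1])  # 1 4
```

Early steps mostly route through a single expert, while late steps use all four, which is the gradual increase described above.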
LR = 2e-6; ~2.5M tokens of Python instruct data, with each sample around ~8k tokens (~300 samples total). Despite the use of instruct data, the model does not necessarily behave like an instruct model, since the training objective is to mimic a larger base model's distributions over that data.
1 epoch of distillation on 70b logprobs, using the topk=200 logits from the fp16 Llama3-70b.
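A minimal per-position sketch of a top-k logit distillation loss, assuming forward KL(teacher || student) with the teacher's top-k log-probs renormalized over their own support. The direction of the KL and how the tail outside the top 200 is handled are assumptions, and the helper names are hypothetical. A real trainer would average this over all token positions in a batch.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def select_topk(logprobs, k):
    # Keep the k highest-probability (token_id, logprob) pairs, as a teacher dump would.
    top = sorted(range(len(logprobs)), key=lambda i: logprobs[i], reverse=True)[:k]
    return [(i, logprobs[i]) for i in top]


def topk_kl(teacher_topk, student_logits):
    # Renormalize the teacher's top-k logprobs so they sum to 1 over that support.
    t_lp = log_softmax([lp for _, lp in teacher_topk])
    # Student log-probs over the full vocabulary, gathered at the teacher's token ids.
    s_lp = log_softmax(student_logits)
    return sum(math.exp(t) * (t - s_lp[i]) for (i, _), t in zip(teacher_topk, t_lp))


student = [2.0, 1.0, 0.5, -1.0, -3.0]
teacher_lp = log_softmax(student)  # a teacher identical to the student
full = topk_kl(select_topk(teacher_lp, len(student)), student)
top2 = topk_kl(select_topk(teacher_lp, 2), student)
```

With k equal to the vocabulary size and identical distributions, `full` is zero. With k = 2, `top2` equals minus the log of the student's total probability on those two tokens, so under this variant the loss also pushes the student to concentrate its mass on the teacher's top-k set.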
Evals
llama3-4x8b-pythonT2_step_final
- mmlu: 65.10 (66.69) - 0.97x
- arc: 57.94 (59.47) - 0.97x
- hellaswag: 81.93 (82.09) - 0.99x
- winogrande: 77.03 (77.35) - 0.99x
- gsm8k: 50.95 (45.79) - 1.11x
- truthfulqa-mc1: 27.66
- truthfulqa-mc2: 44.53 (43.9) - 1.01x
- humaneval+: 32.9 (29.3) - 1.12x
- humaneval: 37.2 (33.5) - 1.11x
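The multipliers can be reproduced from the listed scores: each matches the score divided by the parenthesized reference score, truncated (not rounded) to two decimals. A small check, with truthfulqa-mc1 left out because it has no reference score. The list itself does not state which model the reference scores come from.

```python
import math

# (score, parenthesized reference score), copied from the eval list above.
EVALS = {
    "mmlu": (65.10, 66.69),
    "arc": (57.94, 59.47),
    "hellaswag": (81.93, 82.09),
    "winogrande": (77.03, 77.35),
    "gsm8k": (50.95, 45.79),
    "truthfulqa-mc2": (44.53, 43.9),
    "humaneval+": (32.9, 29.3),
    "humaneval": (37.2, 33.5),
}


def multiplier(score, reference):
    # Truncate (rather than round) the ratio to two decimals.
    return math.floor(score / reference * 100) / 100


for name, (score, ref) in EVALS.items():
    print(f"{name}: {multiplier(score, ref):.2f}x")
```

Rounding instead of truncating would give 0.98x for mmlu and 1.00x for hellaswag.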
Current Conclusions
Going by evals (and evals alone), full finetuning seems to have caused some mild catastrophic forgetting outside the domains that were specifically distilled, as you might expect given how little data was used. I plan to remedy this with lower LRs and/or bigger batch sizes and, of course, a much larger dataset than the limited selection used here. The plan is to do at least 1 billion unique tokens; we are still running custom tests of alternative loss functions (e.g., something in the vein of a weighted cross-entropy loss used in tandem with KL divergence).
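For illustration only, here is one hypothetical shape such a combined objective could take at a single token position: a blend of a top-k KL term and a per-token weighted cross-entropy term on the ground-truth token. The blend factor `alpha`, the `token_weight` parameter, and the KL variant (forward KL with the teacher's top-k log-probs renormalized) are all assumptions; the actual loss functions under test are not described here.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def topk_kl(teacher_topk, student_logits):
    # Forward KL(teacher || student) over the teacher's top-k token ids,
    # with the teacher's top-k logprobs renormalized to sum to 1.
    t_lp = log_softmax([lp for _, lp in teacher_topk])
    s_lp = log_softmax(student_logits)
    return sum(math.exp(t) * (t - s_lp[i]) for (i, _), t in zip(teacher_topk, t_lp))


def combined_loss(student_logits, target_id, teacher_topk, alpha=0.5, token_weight=1.0):
    """Hypothetical KL + weighted cross-entropy objective for one position.

    alpha blends the two terms; token_weight scales the CE term for this
    token (how such weights would be chosen is left open).
    """
    kl = topk_kl(teacher_topk, student_logits)
    ce = -token_weight * log_softmax(student_logits)[target_id]
    return alpha * kl + (1.0 - alpha) * ce


logits = [2.0, 1.0, 0.5, -1.0]
teacher = list(enumerate(log_softmax(logits)))  # teacher identical to the student
pure_kl = combined_loss(logits, 1, teacher, alpha=1.0)
pure_ce = combined_loss(logits, 1, teacher, alpha=0.0, token_weight=2.0)
```

Setting `alpha=1.0` recovers pure KL distillation (zero here, since teacher and student match), while `alpha=0.0` reduces to a weighted cross-entropy on the data alone.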