Update README.md
Browse files
README.md
CHANGED
@@ -68,7 +68,7 @@ Consider the following chat interaction:
|
|
68 |
|
69 |
The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
|
70 |
|
71 |
-
We then compute a distance loss `D(p_masked, p_full)` between the two predictions.
|
72 |
|
73 |
Finally, we add this loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:
|
74 |
```
|
@@ -91,12 +91,28 @@ Keeping this in mind:
|
|
91 |
|
92 |
I trained StableLM-3B-4e1t repeatedly on [https://huggingface.co/datasets/euclaise/TinyCoT](TinyCoT), along with 1000 examples from [reddit-instruct-curated](https://huggingface.co/datasets/euclaise/reddit-instruct-curated) and 1000 examples from [oasst2-curated](https://huggingface.co/datasets/sablo/oasst2_curated).
|
93 |
|
94 |
-
I trained once with ReMask
|
|
|
|
|
95 |
|
96 |
Here are some benchmark results, computed using the the LM Evaluation Harness with vllm:
|
97 |
|
98 |
-
| Model | GSM8K (strict, 5-shot) |
|
99 |
-
|
100 |
-
| SFT |
|
101 |
-
| Masked Thought |
|
102 |
-
| **ReMask** | **
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
|
70 |
|
71 |
+
We then compute a distance loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.
|
72 |
|
73 |
Finally, we add this loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:
|
74 |
```
|
|
|
91 |
|
92 |
I trained StableLM-3B-4e1t repeatedly on [https://huggingface.co/datasets/euclaise/TinyCoT](TinyCoT), along with 1000 examples from [reddit-instruct-curated](https://huggingface.co/datasets/euclaise/reddit-instruct-curated) and 1000 examples from [oasst2-curated](https://huggingface.co/datasets/sablo/oasst2_curated).
|
93 |
|
94 |
+
I trained once with ReMask/ReMask-CoT, once without regularization to match Masked Thought (w/ partial label-masking for CoT), and once with SFT.
|
95 |
+
|
96 |
+
If my hypothesis regarding exposure bias is correct, ReMask should significantly improve generative benchmarks like GSM8K, but would not necessarily improve logprob-based benchmarks like ARC-c (as implemented by the evaluation harness):
|
97 |
|
98 |
Here are some benchmark results, computed using the the LM Evaluation Harness with vllm:
|
99 |
|
100 |
+
| Model | GSM8K (strict, 5-shot) | ARC-c (acc_norm, 25-shot) |
|
101 |
+
|:--------------:|-----------------------:|--------------------------:|
|
102 |
+
| SFT | 24.34% | 42.92% |
|
103 |
+
| Masked Thought | 24.18% | **43.60%** |
|
104 |
+
| **ReMask** | **27.90%** | 43.26% |
|
105 |
+
|
106 |
+
As I expected, it improves GSM8K doesn't do much to ARC.
|
107 |
+
|
108 |
+
## Training details
|
109 |
+
- Framework: PyTorch Lightning
|
110 |
+
- Optimizer: [Lilith](https://github.com/euclaise/supertrainer2000/blob/master/src/supertrainer2k/optim/lilith.py)
|
111 |
+
- Training sequence length: 256
|
112 |
+
- Input masking probability: 40%
|
113 |
+
- Label masking probability: 10%
|
114 |
+
- Answer-only (full rationale masking) probability: 10%
|
115 |
+
- Batch size: 16, accumulated to 256
|
116 |
+
- Epochs: 6
|
117 |
+
- Learning rate: 1e-5
|
118 |
+
- Learning rate schedule: One Cycle, cosine, no cycle_momentum
|