euclaise committed
Commit 896b291
1 Parent(s): 0b86149

Update README.md

Files changed (1)
  1. README.md +23 -7
README.md CHANGED
@@ -68,7 +68,7 @@ Consider the following chat interaction:
 
 The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
 
-We then compute a distance loss `D(p_masked, p_full)` between the two predictions. This approach resembles self-distillation, and MSE tends to perform better than KL Divergence for distillation, along with being easier to tune, so I went with MSE (note that R-TeaFor uses a mix of reverse and forward KL divergence).
+We then compute a distance loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.
 
 Finally, we add this loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:
 ```
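
For concreteness, the procedure described in the hunk above (randomly masking response tokens, running the model on both the masked and the full sequence, taking a symmetric KL distance `D(p_masked, p_full)`, and adding it to the two cross-entropy losses with a weight) can be sketched roughly as below. This is an illustrative sketch, not the code used to train this model: the HF-style `model(...).logits` interface, the use of a `mask_token_id`, masking every labelled position, and the exact way the distance term is weighted against the cross-entropy terms are all assumptions.

```python
import torch
import torch.nn.functional as F


def remask_loss(model, input_ids, labels, mask_token_id, mask_prob=0.4, w=1.0):
    """Illustrative ReMask-style loss: masked + full forward passes,
    symmetric KL distance between them, plus both cross-entropy LM losses.
    mask_prob=0.4 mirrors the 40% input-masking probability listed in the
    training details; the rest of the hyperparameters are assumptions."""
    # Randomly corrupt tokens on the positions the model must predict
    # (positions with a label, i.e. the "bolded" response spans).
    predict_mask = labels != -100
    rand = torch.rand_like(input_ids, dtype=torch.float)
    corrupt = predict_mask & (rand < mask_prob)
    masked_ids = input_ids.masked_fill(corrupt, mask_token_id)

    # One pass on the masked sequence, one on the full sequence.
    logits_masked = model(masked_ids).logits
    logits_full = model(input_ids).logits

    # Standard next-token cross-entropy for each prediction.
    def ce(logits):
        return F.cross_entropy(
            logits[:, :-1].flatten(0, 1), labels[:, 1:].flatten(), ignore_index=-100
        )

    ce_masked, ce_full = ce(logits_masked), ce(logits_full)

    # D(p_masked, p_full): average of the forward and backward KL divergences.
    logp_m = F.log_softmax(logits_masked, dim=-1)
    logp_f = F.log_softmax(logits_full, dim=-1)
    kl_fm = F.kl_div(logp_m, logp_f, log_target=True, reduction="batchmean")
    kl_mf = F.kl_div(logp_f, logp_m, log_target=True, reduction="batchmean")
    dist = 0.5 * (kl_fm + kl_mf)

    # Combine with a weighting value w (the exact combination is an assumption).
    return ce_masked + ce_full + w * dist
```

Averaging the forward and backward KL makes `D` symmetric, so neither prediction acts as a fixed teacher for the other.
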
@@ -91,12 +91,28 @@ Keeping this in mind:
 
 I trained StableLM-3B-4e1t repeatedly on [TinyCoT](https://huggingface.co/datasets/euclaise/TinyCoT), along with 1000 examples from [reddit-instruct-curated](https://huggingface.co/datasets/euclaise/reddit-instruct-curated) and 1000 examples from [oasst2-curated](https://huggingface.co/datasets/sablo/oasst2_curated).
 
-I trained once with ReMask (ReMask-CoT for CoT examples), once with Masked Thought (w/ partial label-masking for CoT), and once with SFT.
+I trained once with ReMask/ReMask-CoT, once without regularization to match Masked Thought (w/ partial label-masking for CoT), and once with SFT.
+
+If my hypothesis regarding exposure bias is correct, ReMask should significantly improve generative benchmarks like GSM8K, but would not necessarily improve logprob-based benchmarks like ARC-c (as implemented by the evaluation harness):
 
 Here are some benchmark results, computed using the LM Evaluation Harness with vllm:
 
-| Model | GSM8K (strict, 5-shot) | AGIEval (Nous subset, 0-shot) | ARC-C | BBH
-|:--------------:|-----------------------:|------------------------------:|------:|-----
-| SFT | 23.81% |
-| Masked Thought | 20.24% | 23.80%
-| **ReMask** | **24.03%** | 24.71%
+| Model | GSM8K (strict, 5-shot) | ARC-c (acc_norm, 25-shot) |
+|:--------------:|-----------------------:|--------------------------:|
+| SFT | 24.34% | 42.92% |
+| Masked Thought | 24.18% | **43.60%** |
+| **ReMask** | **27.90%** | 43.26% |
+
+As I expected, ReMask improves GSM8K but doesn't do much for ARC.
+
+## Training details
+- Framework: PyTorch Lightning
+- Optimizer: [Lilith](https://github.com/euclaise/supertrainer2000/blob/master/src/supertrainer2k/optim/lilith.py)
+- Training sequence length: 256
+- Input masking probability: 40%
+- Label masking probability: 10%
+- Answer-only (full rationale masking) probability: 10%
+- Batch size: 16, accumulated to 256
+- Epochs: 6
+- Learning rate: 1e-5
+- Learning rate schedule: One Cycle, cosine, no cycle_momentum
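
As a rough illustration of how the hyperparameters in the list above fit together, here is a minimal PyTorch Lightning sketch. It is not the actual training script: the data pipeline and ReMask batch construction are omitted, `remask_loss` refers to the sketch earlier on this page, and `AdamW` stands in for Lilith because Lilith's constructor isn't shown here.

```python
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import OneCycleLR


class ReMaskFinetune(pl.LightningModule):
    """Hyperparameter wiring only; dataset code and the ReMask batch
    construction (40% input masking, 10% label masking, 10% answer-only)
    are assumed to happen in the dataloader/collator."""

    def __init__(self, model, mask_token_id: int, steps_per_epoch: int, lr: float = 1e-5):
        super().__init__()
        self.model = model
        self.mask_token_id = mask_token_id
        self.steps_per_epoch = steps_per_epoch
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # remask_loss is the sketch shown earlier; sequences are assumed to be
        # truncated to 256 tokens upstream.
        return remask_loss(self.model, batch["input_ids"], batch["labels"], self.mask_token_id)

    def configure_optimizers(self):
        # The actual run used Lilith (supertrainer2k/optim/lilith.py); AdamW is
        # a stand-in here because Lilith's signature isn't documented on this page.
        opt = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        sched = OneCycleLR(
            opt,
            max_lr=self.lr,
            epochs=6,
            steps_per_epoch=self.steps_per_epoch,
            anneal_strategy="cos",   # "One Cycle, cosine"
            cycle_momentum=False,    # "no cycle_momentum"
        )
        return [opt], [{"scheduler": sched, "interval": "step"}]


# Micro-batch of 16, accumulated 16x for an effective batch size of 256, over 6 epochs.
trainer = pl.Trainer(max_epochs=6, accumulate_grad_batches=16)
```
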