Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: afl-3.0
|
3 |
+
language: en
|
4 |
+
tags:
|
5 |
+
- t5
|
6 |
+
datasets:
|
7 |
+
- wikipedia
|
8 |
+
---
|
9 |
+
|
10 |
+
# chunked T5 - small (cT5-small)
|
11 |
+
|
12 |
+
Github: https://github.com/mtreviso/chunked-t5
|
13 |
+
|
14 |
+
A T5 model that uses a new loss where a special end-of-chunk token `</c>` is appended after sentinel tokens.
|
15 |
+
The decoder has to predict the full input with masked tokens followed by `</c>`.
|
16 |
+
This allows a much faster auto-regressive generation since the decoder can predict multiple tokens in parallel.
|
17 |
+
|
18 |
+
For example, for the input `the quick brown fox jumps over the lazy dog`:
|
19 |
+
```
|
20 |
+
encoder: the <extra_id_0> fox jumps <extra_id_1> the lazy dog
|
21 |
+
|
22 |
+
T5 decoder : <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
|
23 |
+
cT5 decoder: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2>
|
24 |
+
```
|
25 |
+
|
26 |
+
The generation may look like this for T5 and cT5:
|
27 |
+
```
|
28 |
+
T5: <extra_id_0>
|
29 |
+
T5: <extra_id_0> quick
|
30 |
+
T5: <extra_id_0> quick brown
|
31 |
+
T5: <extra_id_0> quick brown <extra_id_1>
|
32 |
+
T5: <extra_id_0> quick brown <extra_id_1> over
|
33 |
+
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
|
34 |
+
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2> </s>
|
35 |
+
|
36 |
+
cT5: <extra_id_0> <pad> <extra_id_1> <pad> <extra_id_2> </s>
|
37 |
+
cT5: <extra_id_0> quick <pad> <extra_id_1> over <pad> <extra_id_2> </s>
|
38 |
+
cT5: <extra_id_0> quick brown <pad> <extra_id_1> over </c> <extra_id_2> </s>
|
39 |
+
cT5: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2> </s>
|
40 |
+
```
|
41 |
+
|
42 |
+
In the original T5, the decoder is called \\(n_s + 1 + \sum_i |s_i|\\) times autoregressively,
|
43 |
+
where \\(n_s\\) is the number of sentinel tokens and \\(s_1,...,s_{n_s}\\) are the predicted chunks.
|
44 |
+
In contrast, cT5's decoder is called just \\(max_i |s_i| + 1\\) times.
|
45 |
+
The generation stops when all sentences were fully translated to complete chunks, i.e., until all `</c>` tokens were generated.
|
46 |
+
Alternatively, you can also set `max_chunk_size` to manually force the model to stop after generating a chunk with `max_chunk_size` tokens.
|
47 |
+
The overhead of calling the decoder with a longer input is less pronounced since this computation can be parallelized in GPUs/TPUs.
|
48 |
+
|
49 |
+
## Training details
|
50 |
+
|
51 |
+
cT5 models used T5's weights as a starting point, and then it was finetuned on the
|
52 |
+
English [wikipedia](https://huggingface.co/datasets/wikipedia) for 3 epochs,
|
53 |
+
achieving ~74% validation accuracy (ct5-small).
|
54 |
+
The training script is in JAX + Flax and can be found in `pretrain_ct5.py`.
|
55 |
+
|
56 |
+
Flax checkpoints can be converted to PyTorch via `convert_flax_to_pytorch.py [flax_dirname]`.
|
57 |
+
|
58 |
+
|
59 |
+
## Checkpoints
|
60 |
+
|
61 |
+
- ct5-small: https://huggingface.co/mtreviso/ct5-small-en-wiki
|
62 |
+
- ct5-base: todo
|
63 |
+
- ct5-large: todo
|
64 |
+
|
65 |
+
|
66 |
+
## Usage
|
67 |
+
|
68 |
+
```python
|
69 |
+
from transformers import AutoTokenizer
|
70 |
+
from modeling_ct5 import CT5ForConditionalGeneration
|
71 |
+
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained("mtreviso/ct5-small-en-wiki")
|
73 |
+
model = CT5ForConditionalGeneration.from_pretrained("mtreviso/ct5-small-en-wiki")
|
74 |
+
```
|
75 |
+
|
76 |
+
For training:
|
77 |
+
|
78 |
+
```python
|
79 |
+
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
80 |
+
labels = tokenizer("<extra_id_0> man </c> <extra_id_1> the </c> <extra_id_2>", return_tensors="pt").input_ids
|
81 |
+
outputs = model(input_ids=input_ids, labels=labels)
|
82 |
+
loss = outputs.loss
|
83 |
+
logits = outputs.logits
|
84 |
+
```
|
85 |
+
|
86 |
+
For generation:
|
87 |
+
|
88 |
+
```python
|
89 |
+
texts = [
|
90 |
+
"The <extra_id_0> walks in <extra_id_1> park",
|
91 |
+
"UN Chief says there is no way to <extra_id_0> in Syria",
|
92 |
+
]
|
93 |
+
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
|
94 |
+
generated_ids = model.generate(
|
95 |
+
input_ids,
|
96 |
+
use_cache=False, # important to set to False to avoid caching
|
97 |
+
eoc_token_id=tokenizer.vocab['</c>'], # important to set to the correct end-of-chunk id
|
98 |
+
max_chunk_size=5, # the default is 9999999, which is a large number
|
99 |
+
)
|
100 |
+
```
|
101 |
+
|
102 |
+
This will produce the following tokens:
|
103 |
+
```python
|
104 |
+
>> ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>', '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>']
|
105 |
+
>> ['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '</c>', '<extra_id_1>', '</s>', '<pad>', '<pad>', '<pad>']
|
106 |
+
```
|
107 |
+
|
108 |
+
You have to pass `use_cache=False` to `generate()` in order to avoid caching during the generation procedure as caching is not available for parallel decoding.
|
109 |
+
Currently, parallel decoding is only supported for PyTorch (greedy search, greedy sampling, beam search, beam sampling) and JAX (greedy search and greedy sampling).
|
110 |
+
|
111 |
+
**Note on the beam search implementation**: my beam search implementation is slower than optimal.
|
112 |
+
This is because I use the structures provided by HuggingFace's implementation, namely, BeamScores and BeamHypotheses to store the beam search results for each chunk in the input.
|
113 |
+
In other words, my implementation computes independent "beams" for each chunk rather than for each input sequence.
|
114 |
+
It is possible to make it faster by using a custom BeamScores and BeamHypotheses class, but I haven't done that yet.
|
115 |
+
|
116 |
+
|
117 |
+
## Evaluation
|
118 |
+
|
119 |
+
See the notebook `evaluate_ct5.ipynb` for an example of how to evaluate cT5 in terms of accuracy and perplexity.
|
120 |
+
The notebook `profile.ipynb` shows how to profile the model to get runtimes.
|
121 |
+
|
122 |
+
Here is a comparison between cT5-small and T5-small on a subset of the WikiText-103 dataset using deterministic greedy search:
|
123 |
+
|
124 |
+
| Model | Exact match ↑ | Edit distance ratio ↑ | Perplexity ↓ | Time (seconds) ↓ |
|
125 |
+
|-------|---------------|----------------------|--------------|-----------------|
|
126 |
+
| T5-small | 0.11 | 0.60 | 2.22 | 44.71 |
|
127 |
+
| cT5-small | 0.09 | 0.58 | 1.48 | 10.63 |
|
128 |
+
|
129 |
+
On this toy dataset, cT5-small has a lower perplexity while being faster than T5-small. However, more experiments are needed for a rigorous evaluation.
|
130 |
+
|
131 |
+
If you are interested in applying cT5 to real data, please contact me.
|