Efficient Long-Range Transformers: You Need to Attend More, but Not Necessarily at Every Layer
Abstract
Pretrained transformer models have demonstrated remarkable performance across various natural language processing tasks. These models leverage the attention mechanism to capture long- and short-range dependencies in the sequence. However, the (full) attention mechanism incurs high computational cost, quadratic in the sequence length, which is prohibitive for tasks with long sequences, e.g., inputs of 8k tokens. Although sparse attention can be used to improve computational efficiency, as suggested in existing work, it has limited modeling capacity and often fails to capture complicated dependencies in long sequences. To tackle this challenge, we propose MASFormer, an easy-to-implement transformer variant with Mixed Attention Spans. Specifically, MASFormer is equipped with full attention to capture long-range dependencies, but only at a small number of layers. For the remaining layers, MASFormer employs only sparse attention to capture short-range dependencies. Our experiments on natural language modeling and generation tasks show that a decoder-only MASFormer model of 1.3B parameters can achieve performance competitive with vanilla transformers using full attention at every layer, while significantly reducing computational cost (by up to 75%). Additionally, we investigate the effectiveness of continual training with long-sequence data and how sequence length impacts downstream generation performance, which may be of independent interest.
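Below is a minimal PyTorch sketch of the mixed-attention-span idea described in the abstract: a decoder stack where only a couple of layers use full causal attention and the rest use a local (sliding-window) attention mask. The class names, window size, and placement of the full-attention layers are illustrative assumptions, not the authors' released implementation, and the boolean-mask approach only illustrates the attention pattern; realizing the compute savings requires a dedicated sparse/windowed attention kernel.

```python
# Minimal sketch (assumed names, not the authors' code): a decoder stack in which
# only a few layers use full causal attention and the rest use local attention.
import torch
import torch.nn as nn


def full_causal_mask(seq_len, device):
    # True = allowed to attend; each token sees all earlier tokens and itself.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def local_causal_mask(seq_len, window, device):
    # True = allowed to attend; each token sees only the previous `window` tokens.
    idx = torch.arange(seq_len, device=device)
    causal = idx[None, :] <= idx[:, None]
    nearby = (idx[:, None] - idx[None, :]) < window
    return causal & nearby


class MixedSpanBlock(nn.Module):
    """One pre-norm decoder block whose attention span is either 'full' or 'local'."""

    def __init__(self, d_model, n_heads, span, window=256):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.span, self.window = span, window

    def forward(self, x):
        seq_len, device = x.size(1), x.device
        if self.span == "full":
            allowed = full_causal_mask(seq_len, device)
        else:
            allowed = local_causal_mask(seq_len, self.window, device)
        # nn.MultiheadAttention expects True = "masked out", so invert the mask.
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=~allowed, need_weights=False)
        x = x + h
        return x + self.mlp(self.ln2(x))


class MixedSpanDecoder(nn.Module):
    """Decoder stack with full attention only at `full_layers`; placement is illustrative."""

    def __init__(self, n_layers=12, full_layers=(10, 11), d_model=512, n_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                MixedSpanBlock(d_model, n_heads, span="full" if i in full_layers else "local")
                for i in range(n_layers)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = MixedSpanDecoder()
x = torch.randn(2, 1024, 512)   # (batch, seq_len, d_model)
print(model(x).shape)           # torch.Size([2, 1024, 512])
```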
Community
I love this. This is a really interesting way to reduce the cost of attention AND get some improvements in long-context modeling by addressing the lost-in-the-middle problem.
A few additional related works on varying local/global attention:
- https://huggingface.co/papers/2007.03356 - Rae and Razavi attempt a very similar study on Transformer-XL. Note that they found quite different results there (though the setup and model architecture are different for Transformer-XL):
  > However perhaps more surprisingly, we see a model with 12 LRMs at the lower layers of the network is actually worse than a model with a single LRM on the final layer. We then see that the full TXL with 24 LRMs is seemingly identical to the 12 LRM models, with either LRMs interleaved across the whole model or LRMs placed in the final 12 layers.

  > Our finding is that we do not need long-range memories at every layer of the network. Comparable performance can be obtained with a fraction (1/6th) of long-range memories if they are spaced equally across the network, or in the latter layers.
- https://huggingface.co/papers/2312.08618 - Attempts a study of just interleaving local (window, not block) and global attention layers, with a 3-to-1 local/global split. Like your comment on GPT-Neo, this is also mainly about moving the FLOPs frontier on efficiency, i.e., achieving comparable loss with fewer FLOPs (but more steps). A toy sketch of these layer-schedule choices appears after this list.
- https://huggingface.co/papers/2305.12689 - Further afield, but it also investigates the idea of alternating local sparse attention layers with global sparse attention layers.
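To make the contrast between these layer-placement strategies concrete, here is a tiny hypothetical sketch (function names and defaults are mine, not from any of the papers above) that builds an interleaved 3-to-1 local/global schedule versus a MASFormer-style schedule with only a few full-attention layers; where exactly those layers sit is left as a knob rather than a claim about the paper.

```python
# Hypothetical helpers, for illustration only: two ways to assign attention spans per layer.

def interleaved_schedule(n_layers, local_per_global=3):
    # e.g. 3-to-1 local/global interleaving: local, local, local, global, ...
    return [
        "global" if (i + 1) % (local_per_global + 1) == 0 else "local"
        for i in range(n_layers)
    ]


def few_full_schedule(n_layers, n_full=2, at_bottom=True):
    # MASFormer-style: a small number of full-attention layers, the rest local.
    # Whether they sit at the bottom or the top of the stack is just a knob here.
    full, local = ["full"] * n_full, ["local"] * (n_layers - n_full)
    return full + local if at_bottom else local + full


print(interleaved_schedule(8))
# ['local', 'local', 'local', 'global', 'local', 'local', 'local', 'global']
print(few_full_schedule(8))
# ['full', 'full', 'local', 'local', 'local', 'local', 'local', 'local']
```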