Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformer models have achieved superior performance in NLP tasks.
- Attention mechanism has quadratic computational cost, limiting its practicality for long sequences.
- Existing attention variants improve computational efficiency but lack the ability to capture global information.
- State space models are tailored for long sequences but lack flexibility to capture complicated local information.
Paper Content
Introduction
- Transformer models have achieved superior performance on various natural language processing tasks
- These models rely on the attention mechanism, which has quadratic time and space complexity
- This complexity is prohibitive for tasks with long sequences
- Transformer models with full attention are prone to overfitting
- Various approaches have been proposed to reduce complexity and introduce structural biases
- Approximation methods approximate full attention with linear complexity
- Partial attention methods reduce complexity and introduce structural biases
- State space models introduce a different structural bias
- Proposed SPADE model is a multi-layer Transformer model that captures global and local information
- SPADE outperforms existing approaches on Long Range Arena benchmark
- SPADE is faster and yields better performance in autoregressive language modeling
- SPADE is scalable and outperforms baselines on natural language understanding and natural language generation benchmarks
- Code and pre-trained model checkpoints are publicly available
Background
Attention mechanism
- Attention mechanism takes input X and outputs alignment between any pair of input tokens.
- Attention mechanism models long-range dependencies better than recurrent neural networks.
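For reference, the quadratic cost mentioned in the introduction can be read off the standard scaled dot-product attention formula. The notation here is an assumption of this summary rather than a quotation from the paper: X is an L x d input, and W_q, W_k, W_v are learned projection matrices.

```latex
\[
Q = X W_q,\quad K = X W_k,\quad V = X W_v,\qquad
\mathrm{Attn}(X) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V,
\qquad X \in \mathbb{R}^{L \times d}.
\]
```

The score matrix Q K^T has L x L entries, one per pair of tokens, which is where the quadratic time and space cost in the sequence length L comes from.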
State space models
- Continuous time state space model maps 1-dimensional input signal to d_s-dimensional latent state
- Discrete time state space model uses bilinear method to discretize model
- Output can be written as convolutional representation
- Structured State Space Sequence model (S4) developed to efficiently compute convolution kernel
- S4 convolution kernel can be computed with O(L) time and space complexity
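The steps above (continuous model, bilinear discretization, convolutional form) can be illustrated with a minimal sketch. It assumes a scalar latent state (d_s = 1) so that plain Python suffices; the function names and the parameters `a`, `b`, `c`, `step` are hypothetical. It is not the S4 algorithm: S4 relies on a structured parameterization to compute the kernel efficiently, whereas this sketch builds the kernel naively, only to show that the recurrence and the convolution give the same output.

```python
def discretize_bilinear(a, b, step):
    """Bilinear discretization of x'(t) = a*x(t) + b*u(t) for a scalar state."""
    denom = 1.0 - step * a / 2.0
    a_bar = (1.0 + step * a / 2.0) / denom
    b_bar = step * b / denom
    return a_bar, b_bar


def ssm_recurrence(a_bar, b_bar, c, u):
    """Run x_k = a_bar*x_{k-1} + b_bar*u_k, y_k = c*x_k with a zero initial state."""
    x = 0.0
    ys = []
    for u_k in u:
        x = a_bar * x + b_bar * u_k
        ys.append(c * x)
    return ys


def ssm_kernel(a_bar, b_bar, c, length):
    """Convolution kernel K = (c*b_bar, c*a_bar*b_bar, ..., c*a_bar^(L-1)*b_bar)."""
    return [c * (a_bar ** k) * b_bar for k in range(length)]


def causal_conv(kernel, u):
    """Naive causal convolution y_k = sum_j kernel[j] * u[k-j] (quadratic time)."""
    return [
        sum(kernel[j] * u[k - j] for j in range(k + 1))
        for k in range(len(u))
    ]


if __name__ == "__main__":
    a_bar, b_bar = discretize_bilinear(a=-1.0, b=1.0, step=0.1)
    u = [1.0, 0.5, -0.25, 2.0]
    y_rec = ssm_recurrence(a_bar, b_bar, 1.0, u)
    y_conv = causal_conv(ssm_kernel(a_bar, b_bar, 1.0, len(u)), u)
    # The two representations agree up to floating-point error.
    print(max(abs(p - q) for p, q in zip(y_rec, y_conv)))
```

The convolutional form matters because the whole output sequence can be computed at once from a precomputed kernel, instead of stepping through the recurrence token by token.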
Method
- SSMs do not model local information well
- SPADE combines global and local information by integrating an SSM into the Transformer architecture
Attention vs. state space models
- SPADE is a multi-layer Transformer model that captures global and local information.
- SSMs perform poorly on language modeling tasks.
- Local information is more important than global information for language modeling.
- SPADE has a hierarchical structure with a global layer and local layers.
- SPADE uses an SSM to capture global information and local attention methods to capture local information.
- SPADE does not need additional positional embedding techniques.
Experiments
- Implemented models using PyTorch and Fairseq
- Training details in appendix
Long range arena
- Long Range Arena (LRA) benchmark used to evaluate model’s ability to model long sequences
- Six tasks: ListOps, Text, Retrieval, Image, Pathfinder, Path-X
- Models used are small (less than 2M parameters)
- Two approaches used to aggregate local information: window attention and chunk attention
- SPADE (softmax-window and MEGA-chunk) significantly outperforms all baselines in terms of average accuracy
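To make the two local aggregation schemes concrete, here is a minimal sketch of the attention patterns they imply. The exact definitions are assumptions for illustration: window attention is taken to mean that each token attends to neighbours within `window // 2` positions on either side, and chunk attention is taken to mean attention restricted to non-overlapping chunks of fixed size. The function names are hypothetical.

```python
def window_mask(length, window):
    """mask[i][j] is True when token i may attend to token j (|i - j| <= window // 2)."""
    half = window // 2
    return [[abs(i - j) <= half for j in range(length)] for i in range(length)]


def chunk_mask(length, chunk):
    """mask[i][j] is True when tokens i and j fall in the same non-overlapping chunk."""
    return [[i // chunk == j // chunk for j in range(length)] for i in range(length)]


def attended_pairs(mask):
    """Number of (query, key) pairs that are actually scored."""
    return sum(sum(row) for row in mask)


if __name__ == "__main__":
    length = 8
    print(attended_pairs(window_mask(length, window=2)))  # 22
    print(attended_pairs(chunk_mask(length, chunk=4)))    # 32
    print(length * length)                                # 64 pairs for full attention
```

For a fixed window or chunk size, the number of scored pairs grows roughly linearly with the sequence length, whereas full attention scores every pair and grows quadratically.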
Language modeling
- Evaluated model by conducting language modeling experiments on Wikitext-103
- SPADE (softmax-window) achieved significant performance improvement and outperformed all baselines
- SPADE with softmax-window was faster than Transformer with full attention and yielded better performance
- Parameters in S4 were initialized using Eq. 5 and frozen during training
Language model pre-training
- Implemented model pre-training using Fairseq
- Implemented model fine-tuning using MT-DNN
- All experiments only use single task fine-tuning
- Hyper-parameter settings in appendix
Pre-training details
- Pre-trained an encoder-decoder variant of SPADE
- Model architecture is the same as T5 base
- Embedding dimension is 768, hidden dimension of FFN is 3072, number of attention heads is 12, encoder and decoder have 12 layers
- Added S4 module to bottom layer of SPADE
- Used softmax-window attention as local information extractor, window size set to 128
- Model contains 290M parameters
- Two pre-training settings with different datasets and number of training steps
Natural language understanding
- Pre-trained models are fine-tuned on GLUE benchmark
- GLUE includes two single-sentence classification tasks (CoLA, SST-2)
- GLUE also includes three similarity and paraphrase tasks (MRPC, STS-B, QQP)
- GLUE also includes natural language inference tasks (MNLI, QNLI, RTE, WNLI)
- Experiments show SPADE base++ outperforms T5 base
Natural language generation
- Pre-trained models are fine-tuned on abstractive summarization datasets
- Evaluation metric is ROUGE-2
- SPADE is compared to LongT5, which is tailored for long sequences
- SPADE is more general and has superior performance on natural language understanding tasks
- SPADE base++ has 290M parameters, LongT5 large has 770M parameters and LongT5 xl has 3B parameters
- SPADE base++ performs on par with or better than LongT5 large on all tasks
- SPADE outperforms LongT5 xl on MultiNews dataset
Efficiency comparison
- SPADE is efficient in terms of both training speed and GPU memory usage
- SPADE trains significantly faster than the vanilla Transformer
- S4 may be less efficient than the vanilla Transformer
- SPADE with softmax-window has only marginally different training speed and memory usage from the window-attention Transformer
Location and number of global layers
- SPADE has a bottom layer equipped with an SSM, which serves as the global layer.
- Results show that model performance decreases when more global layers are used.
- Results also show that model performance drops significantly when the global layer is used as the top layer.
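The layouts compared in this ablation can be described with a small configuration sketch. The function name `spade_layer_plan` and the labels are hypothetical and serve only to show what "global layer at the bottom", "global layer at the top", and "more global layers" mean as layer orderings; this is not the authors' code.

```python
def spade_layer_plan(num_layers, global_positions=(0,)):
    """Return one label per layer, ordered from bottom (index 0) to top.

    Layers listed in global_positions are SSM-equipped global layers;
    all remaining layers are local layers.
    """
    return [
        "global" if i in global_positions else "local"
        for i in range(num_layers)
    ]


if __name__ == "__main__":
    # Default layout described above: a single global layer at the bottom.
    print(spade_layer_plan(4))
    # Ablation variants: global layer at the top, or two global layers.
    print(spade_layer_plan(4, global_positions=(3,)))
    print(spade_layer_plan(4, global_positions=(0, 1)))
```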
Different configurations
- Performance of Transformer with full attention increases when sequence length increases from 512 to 3k, but decreases when further increased to 4k.
- Performance of Transformer with window attention increases when window size increases.
- Performance of SPADE increases when window size increases and marginally decreases when sequence length increases from 4k to 6k.
Pre-trained language models
- Pre-trained language models have achieved state-of-the-art performance on various natural language processing tasks.
- BERT cannot handle sequences with length more than 512.
- LongT5 facilitates training on long sequences.
- LongT5 targets long sequence modeling tasks such as text summarization.
Conclusion
- Propose SPADE, a state space augmented Transformer model
- Bottom layer is global layer, rest are local layers
- Use SSM to augment coarse global information
- Local layers use off-the-shelf efficient attention methods
- Linear time and space complexity
- Extensive experiments on LRA benchmark and language modeling datasets
- Pre-train encoder-decoder models to demonstrate scalability
- Fine-tuning experiments on GLUE and summarization tasks
- SPADE outperforms the baselines across these experiments