Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Transformer models have achieved superior performance in NLP tasks.
  • Attention mechanism has quadratic computational cost, limiting its practicality for long sequences.
  • Existing attention variants improve computational efficiency but lack ability to compute global information.
  • State space models are tailored for long sequences but lack flexibility to capture complicated local information.

Paper Content

Introduction

  • Transformer models have achieved superior performance on various natural language processing tasks
  • They rely on the attention mechanism, which has quadratic time and space complexity in the sequence length
  • This complexity is prohibitive for tasks with long sequences
  • Transformer models with full attention are prone to overfitting
  • Various approaches have been proposed to reduce complexity and introduce structural biases
  • Approximation methods approximate full attention with linear complexity
  • Partial attention methods reduce complexity and introduce structural biases
  • State space models introduce a different structural bias
  • Proposed SPADE model is a multi-layer Transformer model that captures global and local information
  • SPADE outperforms existing approaches on Long Range Arena benchmark
  • SPADE is faster and yields better performance in autoregressive language modeling
  • SPADE is scalable and outperforms baselines on natural language understanding and natural language generation benchmarks
  • Code and pre-trained model checkpoints are publicly available

Background

Attention mechanism

  • The attention mechanism takes an input X and computes an alignment score between every pair of input tokens.
  • Attention mechanism models long-range dependencies better than recurrent neural networks.
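
As an illustration of where the quadratic cost comes from, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. The function name, the projection matrices `Wq`, `Wk`, `Wv` and the toy sizes are assumptions made for this example, not details from the paper.

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over X of shape (L, d)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (L, L): one score per token pair
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V, weights

# Toy inputs (hypothetical sizes chosen only for illustration)
rng = np.random.default_rng(0)
L, d = 6, 4
X = rng.normal(size=(L, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out, weights = self_attention(X, Wq, Wk, Wv)
```

The `weights` matrix has one entry for every pair of tokens, so its size grows with the square of the sequence length.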

State space models

  • The continuous-time state space model maps a 1-dimensional input signal to a d_s-dimensional latent state
  • The discrete-time model is obtained by discretizing the continuous-time model with the bilinear method
  • The output can then be written as a convolution of the input with a kernel
  • Structured State Space Sequence model (S4) developed to efficiently compute convolution kernel
  • S4 convolution kernel can be computed with O(L) time and space complexity
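
The steps above can be sketched in NumPy under the standard formulation x'(t) = A x(t) + B u(t), y(t) = C x(t). The sketch discretizes with the bilinear method, builds the convolution kernel naively, and checks that the convolution matches the step-by-step recurrence. The random matrices and the step size are arbitrary choices for illustration, and the naive kernel loop is not the efficient algorithm S4 uses.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - step / 2 * A)
    return inv @ (I + step / 2 * A), inv @ (step * B)

def ssm_kernel(A_bar, B_bar, C, length):
    """Naive kernel K[j] = C A_bar^j B_bar (S4 computes this far more efficiently)."""
    kernel, state = [], B_bar
    for _ in range(length):
        kernel.append((C @ state).item())
        state = A_bar @ state
    return np.array(kernel)

def ssm_recurrence(A_bar, B_bar, C, u):
    """Step the discrete model: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k."""
    x = np.zeros((A_bar.shape[0], 1))
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar * u_k
        ys.append((C @ x).item())
    return np.array(ys)

rng = np.random.default_rng(0)
d_s, L = 4, 16
A = -np.eye(d_s) + 0.1 * rng.normal(size=(d_s, d_s))  # arbitrary matrix, not the S4 initialization
B = rng.normal(size=(d_s, 1))
C = rng.normal(size=(1, d_s))
u = rng.normal(size=L)

A_bar, B_bar = discretize_bilinear(A, B, step=0.1)
K = ssm_kernel(A_bar, B_bar, C, L)
y_conv = np.convolve(u, K)[:L]   # causal convolution, truncated to L outputs
y_rec = ssm_recurrence(A_bar, B_bar, C, u)
```

Both `y_conv` and `y_rec` describe the same output sequence; the convolutional form is what makes training over long sequences parallelizable.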

Method

  • SSMs do not model local information well
  • SPADE combines global and local information by augmenting the Transformer architecture with an SSM

Attention vs. state space models

  • SPADE is a multi-layer Transformer model that captures global and local information.
  • SSMs perform poorly on language modeling tasks.
  • Local information is more important than global information for language modeling.
  • SPADE has a hierarchical structure with a global layer and local layers.
  • SPADE uses an SSM to capture global information and local attention methods to capture local information.
  • SPADE does not need additional positional embedding techniques.
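
To make the idea of a global layer concrete, here is a rough NumPy sketch that mixes a local branch (window attention) with a global branch. Several parts are assumptions of this sketch rather than details given in this summary: the two branches are combined by concatenation followed by a linear projection `W`, the SSM is replaced by a simple causal convolution with an exponentially decaying kernel, and the attention uses the input directly as queries, keys and values.

```python
import numpy as np

def window_attention(X, window):
    """Softmax attention where token i only attends to tokens with |i - j| <= window // 2."""
    L, d = X.shape
    idx = np.arange(L)
    mask = np.abs(idx[:, None] - idx[None, :]) <= window // 2
    scores = np.where(mask, X @ X.T / np.sqrt(d), -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                      # masked entries become exactly 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ X

def global_mixer(X, decay=0.9):
    """Stand-in for the SSM: a causal convolution with an exponentially decaying kernel."""
    L = X.shape[0]
    kernel = decay ** np.arange(L)
    return np.stack(
        [np.convolve(X[:, c], kernel)[:L] for c in range(X.shape[1])], axis=1
    )

def global_layer(X, W, window=4):
    """Hypothetical global layer: concatenate local and global features, then project."""
    local = window_attention(X, window)
    global_ = global_mixer(X)
    return np.concatenate([local, global_], axis=-1) @ W + X  # residual connection

rng = np.random.default_rng(0)
L, d = 8, 4
X = rng.normal(size=(L, d))
W = rng.normal(size=(2 * d, d)) / np.sqrt(2 * d)
Y = global_layer(X, W)
```

In this sketch every output position sees its neighbours through the attention branch and the whole preceding sequence through the convolution branch, which is the division of labour the bullets above describe.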

Experiments

  • Implemented models using PyTorch and Fairseq
  • Training details in appendix

Long range arena

  • Long Range Arena (LRA) benchmark used to evaluate model’s ability to model long sequences
  • Six tasks: ListOps, Text, Retrieval, Image, Pathfinder, Path-X
  • Models used are small (fewer than 2M parameters)
  • Two approaches used to aggregate local information: window attention and chunk attention
  • SPADE (softmax-window and MEGA-chunk) significantly outperforms all baselines in terms of average accuracy
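
The two ways of aggregating local information can be pictured as attention masks. The sketch below is a simplified illustration: the window is assumed to be symmetric around each token, chunks are assumed to be non-overlapping, and the helper names are hypothetical.

```python
import numpy as np

def window_mask(length, window):
    """True where query i may attend to key j, i.e. |i - j| <= window // 2."""
    idx = np.arange(length)
    return np.abs(idx[:, None] - idx[None, :]) <= window // 2

def chunk_mask(length, chunk):
    """True where query i and key j fall in the same non-overlapping chunk."""
    chunk_id = np.arange(length) // chunk
    return chunk_id[:, None] == chunk_id[None, :]

# Number of allowed (query, key) pairs versus the full length * length grid
for length in (64, 128, 256):
    w = int(window_mask(length, window=8).sum())
    c = int(chunk_mask(length, chunk=8).sum())
    print(length, w, c, length * length)
```

With a fixed window or chunk size, the number of allowed pairs grows linearly with the sequence length, while the full attention grid grows quadratically.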

Language modeling

  • Evaluated model by conducting language modeling experiments on Wikitext-103
  • SPADE (softmax-window) achieved significant performance improvement and outperformed all baselines
  • SPADE with softmax-window was faster than Transformer with full attention and yielded better performance
  • Parameters in S4 were initialized using Eq. 5 of the paper and frozen during training

Language model pre-training

  • Implemented model pre-training using Fairseq
  • Implemented model fine-tuning using MT-DNN
  • All experiments only use single task fine-tuning
  • Hyper-parameter settings in appendix

Pre-training details

  • Pre-trained an encoder-decoder variant of SPADE
  • Model architecture is the same as T5 base
  • Embedding dimension is 768, hidden dimension of FFN is 3072, number of attention heads is 12, encoder and decoder have 12 layers
  • Added S4 module to bottom layer of SPADE
  • Used softmax-window attention as local information extractor, window size set to 128
  • Model contains 290M parameters
  • Two pre-training settings with different datasets and number of training steps
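
For reference, the hyperparameters listed above can be collected into a small configuration object. The class and field names below are hypothetical and are not Fairseq arguments; only the numeric values come from the list.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SpadeBaseConfig:
    """Hypothetical container for the hyperparameters listed above."""
    embed_dim: int = 768
    ffn_dim: int = 3072
    num_heads: int = 12
    encoder_layers: int = 12
    decoder_layers: int = 12
    window_size: int = 128        # window size of the softmax-window attention
    global_layer_index: int = 0   # index of the layer carrying the S4 module (bottom layer)

    @property
    def head_dim(self) -> int:
        # Per-head dimension implied by the embedding size and head count
        return self.embed_dim // self.num_heads

config = SpadeBaseConfig()
```

Here `config.head_dim` follows from dividing the embedding dimension by the number of heads, and the FFN hidden dimension is four times the embedding dimension.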

Natural language understanding

  • Pre-trained models are fine-tuned on GLUE benchmark
  • GLUE includes two single-sentence classification tasks
  • GLUE also includes three similarity and paraphrase tasks
  • GLUE also includes natural language inference tasks
  • Experiments show SPADE base++ outperforms T5 base

Natural language generation

  • Pre-trained models are fine-tuned on abstractive summarization datasets
  • Evaluation metric is ROUGE-2 scores
  • SPADE is compared to LongT5, which is tailored for long sequences
  • SPADE is more general and has superior performance on natural language understanding tasks
  • SPADE base++ has 290M parameters, LongT5 large has 770M parameters and LongT5 xl has 3B parameters
  • SPADE base++ performs on par with or better than LongT5 large on all tasks
  • SPADE outperforms LongT5 xl on MultiNews dataset

Efficiency comparison

  • SPADE is efficient in terms of both training speed and GPU memory usage
  • SPADE trains significantly faster than the vanilla Transformer
  • S4 may be less efficient than the vanilla Transformer
  • SPADE with softmax-window has only marginally different training speed and memory usage from the window-attention Transformer

Location and number of global layers

  • SPADE has a bottom layer equipped with an SSM, which serves as the global layer.
  • Results show that model performance decreases when more global layers are used.
  • Results also show that model performance drops significantly when the global layer is used as the top layer.

Different configurations

  • Performance of Transformer with full attention increases when sequence length increases from 512 to 3k, but decreases when further increased to 4k.
  • Performance of Transformer with window attention increases when window size increases.
  • Performance of SPADE increases when window size increases and marginally decreases when sequence length increases from 4k to 6k.

Pre-trained language models

  • Pre-trained language models have achieved state-of-the-art performance on various natural language processing tasks.
  • BERT cannot handle sequences longer than 512 tokens.
  • LongT5 facilitates training on long sequences.
  • LongT5 targets long sequence modeling tasks such as text summarization.

Conclusion

  • Propose SPADE, a state space augmented Transformer model
  • Bottom layer is global layer, rest are local layers
  • Use an SSM in the global layer to provide coarse global information
  • Local layers use off-the-shelf efficient attention methods
  • Linear time and space complexity
  • Extensive experiments on LRA benchmark and language modeling datasets
  • Pre-train encoder-decoder models to demonstrate scalability
  • Fine-tuning experiments on GLUE and summarization tasks
  • SPADE outperforms the baselines across these experiments