Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Transformers are slow and memory-hungry on long sequences.
  • Approximate attention methods have attempted to reduce compute complexity, but often do not achieve wall-clock speedup.
  • FlashAttention is an IO-aware exact attention algorithm that uses tiling to reduce memory reads/writes.
  • FlashAttention trains Transformers faster than existing baselines.
  • FlashAttention enables longer context in Transformers, yielding higher quality models.

Paper Content

Introduction

  • Transformer models are widely used in applications such as natural language processing and image classification
  • Transformers have grown larger and deeper, but equipping them with longer context remains difficult
  • Making attention faster and more memory-efficient can help Transformer models address their runtime and memory challenges for long sequences
  • Many approximate attention methods have aimed to reduce the compute and memory requirements of attention
  • These methods focus on FLOP reduction and tend to ignore overheads from memory access
  • FlashAttention is a new attention algorithm that computes exact attention with far fewer memory accesses
  • FlashAttention avoids reading and writing the attention matrix to and from HBM
  • FlashAttention is up to 7.6x faster than the PyTorch implementation of attention on GPT-2
  • FlashAttention is up to 3x faster than the standard attention implementation across common sequence lengths
  • FlashAttention trains Transformer models faster in wall-clock time
  • FlashAttention enables the first Transformer that can achieve better-than-chance performance on the Path-X challenge
  • Block-sparse FlashAttention enables a Transformer to scale to even longer sequences
  • Block-sparse FlashAttention is faster than all existing approximate attention methods

Background

  • Performance characteristics of deep learning operations on GPUs
  • Standard implementation of attention

Hardware performance

  • GPUs are used in this paper
  • Performance on other hardware accelerators is similar
  • GPU memory hierarchy has multiple forms of memory of different sizes and speeds
  • Operations can be classified as compute-bound or memory-bound
  • Kernel fusion is used to accelerate memory-bound operations
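
The compute-bound versus memory-bound distinction can be made concrete by comparing an operation's arithmetic intensity (FLOPs per byte moved to or from HBM) with the hardware's ratio of peak compute to memory bandwidth. The sketch below is illustrative only: the peak-throughput and bandwidth constants are assumed round numbers for an A100-class GPU, and all function names are hypothetical.

```python
# Assumed round numbers for an A100-class GPU (illustrative, not measured).
PEAK_FLOPS = 312e12        # FP16 FLOPs per second (assumption)
HBM_BANDWIDTH = 1.5e12     # bytes per second (assumption)
RIDGE = PEAK_FLOPS / HBM_BANDWIDTH  # FLOPs per byte needed to keep compute busy


def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte read from or written to HBM."""
    return flops / bytes_moved


def elementwise_intensity(n, bytes_per_elem=2):
    # One FLOP per element; read n elements and write n elements.
    return arithmetic_intensity(n, 2 * n * bytes_per_elem)


def matmul_intensity(n, bytes_per_elem=2):
    # 2*n^3 FLOPs; read two n x n inputs and write one n x n output.
    return arithmetic_intensity(2 * n**3, 3 * n * n * bytes_per_elem)


def is_memory_bound(intensity):
    """An op is memory-bound when it cannot reach the ridge intensity."""
    return intensity < RIDGE


def bytes_moved_elementwise_chain(n, num_ops, fused, bytes_per_elem=2):
    # Unfused: every op reads n elements from HBM and writes n back.
    # Fused: the input is read once and the final result is written once.
    passes = 1 if fused else num_ops
    return passes * 2 * n * bytes_per_elem


print(is_memory_bound(elementwise_intensity(10**6)))   # elementwise op
print(is_memory_bound(matmul_intensity(4096)))         # large matmul
```

Under these assumptions, fusing a chain of three elementwise operations moves one third of the bytes of the unfused version, which is why fusion helps memory-bound operations.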

Standard attention implementation

  • Input sequences $Q, K, V \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension
  • Standard attention implementation requires $O(N^2)$ memory
  • Standard attention implementation performs HBM accesses quadratic in sequence length
  • FlashAttention requires fewer HBM reads/writes and no large intermediate matrices
  • FlashAttention is both memory efficient and faster in wall-clock time
  • FlashAttention can handle block-sparse attention
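
For reference, the standard implementation described above can be sketched in NumPy as follows. This is a minimal single-head illustration, assuming no scaling factor, masking, or dropout; the function name `standard_attention` is hypothetical. The point is that both the score matrix `S` and the probability matrix `P` have shape N x N and are fully materialized.

```python
import numpy as np


def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices."""
    S = Q @ K.T                              # scores, shape (N, N)
    S = S - S.max(axis=1, keepdims=True)     # subtract row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)     # row-wise softmax, shape (N, N)
    return P @ V                             # output, shape (N, d)


rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)
```

On a GPU, each of the three steps (matrix multiply, softmax, matrix multiply) would read and write an N x N array in HBM, which is the source of the quadratic memory traffic.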

An efficient attention algorithm with tiling and recomputation

  • Aim to compute attention output O in sub-quadratic HBM accesses
  • Split inputs Q, K, V into blocks
  • Compute softmax values along with extra statistics
  • Store output O and softmax normalization statistics
  • Compute gradients with respect to Q, K, V
  • Implement in one CUDA kernel
  • Algorithm 1 returns $O = \mathrm{softmax}(QK^\top)V$ with $O(N^2 d)$ FLOPs and $O(N)$ additional memory beyond inputs and output
  • FlashAttention requires many times fewer HBM accesses than standard implementation
  • Load blocks of K, V of size $\Theta(M)$ each, where $M$ is the size of on-chip SRAM
  • Iterate over all blocks of Q to compute intermediate values
  • Standard attention requires $\Theta(Nd + N^2)$ HBM accesses
  • FlashAttention requires $\Theta(N^2 d^2 M^{-1})$ HBM accesses
  • No exact attention algorithm can asymptotically improve on FlashAttention's number of HBM accesses across all SRAM sizes
  • Runtime of FlashAttention is bottlenecked by other factors for large block size
  • Block-sparse FlashAttention requires $\Theta(Nd + N^2 d^2 M^{-1} s)$ HBM accesses, where $s$ is the fraction of nonzero blocks in the block-sparsity mask
  • Block-sparse FlashAttention achieves a 2.8× speedup over the standard Transformer on the Long-Range Arena benchmark
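
The tiling steps above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the paper's CUDA kernel: the names `tiled_attention`, `block_q`, and `block_k` are hypothetical, the block sizes are arbitrary rather than derived from the SRAM size, and scaling, masking, and dropout are omitted. It keeps only per-row running statistics (the row maximum `m` and the softmax denominator `l`) and never forms the full N x N matrix.

```python
import numpy as np


def tiled_attention(Q, K, V, block_q=32, block_k=32):
    """Exact attention computed block by block with an online softmax."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row maxima
    l = np.zeros(N)           # running softmax denominators
    for j in range(0, N, block_k):            # outer loop over K, V blocks
        Kj, Vj = K[j:j + block_k], V[j:j + block_k]
        for i in range(0, N, block_q):        # inner loop over Q blocks
            rows = slice(i, i + block_q)
            S = Q[rows] @ Kj.T                # small (block_q, block_k) tile
            m_new = np.maximum(m[rows], S.max(axis=1))
            P = np.exp(S - m_new[:, None])    # unnormalized tile probabilities
            scale = np.exp(m[rows] - m_new)   # rescales the old accumulators
            l_new = scale * l[rows] + P.sum(axis=1)
            # Undo the old normalization, add the new tile, renormalize.
            O[rows] = ((scale * l[rows])[:, None] * O[rows] + P @ Vj) / l_new[:, None]
            m[rows], l[rows] = m_new, l_new
    return O, m, l


def reference_attention(Q, K, V):
    """Dense softmax(Q K^T) V, used only to check the tiled result."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


rng = np.random.default_rng(0)
N, d = 200, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O_tiled, m, l = tiled_attention(Q, K, V)
print(np.allclose(O_tiled, reference_attention(Q, K, V)))
```

The extra state is two vectors of length N, which matches the linear additional memory stated above. The result is exact because each update rescales the previously accumulated output to the new running maximum before adding the current tile.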

Experiments

  • FlashAttention outperforms MLPerf 1.1 speed record for BERT by 15%
  • FlashAttention speeds up GPT-2 up to 3x over HuggingFace and 1.8x over Megatron
  • FlashAttention trains GPT-2 with context length 4K faster than Megatron with context length 1K while achieving 0.7 better perplexity
  • FlashAttention yields 6.4 points of lift on two long-document classification tasks
  • FlashAttention achieves better-than-random performance on Path-X task (sequence length 16K)
  • Block-sparse FlashAttention yields better-than-random performance on Path-256 (sequence length 64K)
  • FlashAttention's memory footprint scales linearly with sequence length, and it runs up to 3x faster than standard attention at common sequence lengths
  • Block-sparse FlashAttention runtime scales linearly with sequence length and is faster than all existing approximate attention baselines

Better models with longer sequences

  • FlashAttention allows for 4x longer context length than GPT-2 while running faster.
  • GPT-2 with FlashAttention and context length 4K is 30% faster than GPT-2 from Megatron with context length 1K.
  • FlashAttention enables Transformers to scale to sequence length 64K.

Benchmarking attention

  • Vary sequence length and measure runtime and memory usage of FlashAttention and block-sparse FlashAttention
  • Compare against reference implementations for exact, approximate, and sparse attention
  • Results reported in main body, more baselines and details in Appendix E

Limitations and future directions

  • Discuss limitations of current approach to building IO-aware implementations of attention
  • Requires writing attention algorithm in lower-level language than PyTorch
  • Implementations may not be transferable across GPU architectures
  • Need for a method that supports writing attention algorithms in high-level language
  • IO-aware approach can extend beyond attention
  • Multi-GPU IO-aware methods require data transfer between GPUs
  • IO-aware runtime optimization has a long history in computer science
  • Related work: efficient ML models with structured matrices
  • Related work: sparse training
  • Related work: efficient Transformers
  • Derive forward and backward passes of attention
  • Compute forward pass with linear extra memory
  • Compute backward pass with linear extra memory
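  • The dropout mask from the forward pass does not need to be stored; it can be regenerated in the backward pass from the saved random number generator state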
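
The backward pass with linear extra memory can be sketched as follows. This is an illustration under stated assumptions rather than the paper's kernel: dropout and masking are omitted, the function names are hypothetical, and the softmax statistics `m` and `l` are computed densely here only to keep the example short (in the tiled forward pass they are produced block by block). Each tile of the probability matrix is recomputed from `Q`, `K`, and the saved statistics instead of being read from a stored N x N matrix.

```python
import numpy as np


def blockwise_backward(Q, K, V, O, dO, m, l, block=32):
    """Gradients of O = softmax(Q K^T) V, recomputing P one tile at a time."""
    N, d = Q.shape
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)                  # per-row term, O(N) memory
    for j in range(0, N, block):              # loop over K, V blocks
        cols = slice(j, j + block)
        for i in range(0, N, block):          # loop over Q blocks
            rows = slice(i, i + block)
            S = Q[rows] @ K[cols].T
            # Recompute the softmax tile from the saved row statistics.
            P = np.exp(S - m[rows][:, None]) / l[rows][:, None]
            dV[cols] += P.T @ dO[rows]
            dP = dO[rows] @ V[cols].T
            dS = P * (dP - D[rows][:, None])  # softmax Jacobian applied per row
            dQ[rows] += dS @ K[cols]
            dK[cols] += dS.T @ Q[rows]
    return dQ, dK, dV


def dense_backward(Q, K, V, dO):
    """Reference gradients that store the full N x N matrix P."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    dP = dO @ V.T
    dS = P * (dP - (P * dP).sum(axis=1, keepdims=True))
    return dS @ K, dS.T @ Q, P.T @ dO


rng = np.random.default_rng(1)
N, d = 96, 8
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

# Demo shortcut: dense statistics and output (the tiled forward yields the same).
S_full = Q @ K.T
m = S_full.max(axis=1)
l = np.exp(S_full - m[:, None]).sum(axis=1)
O = (np.exp(S_full - m[:, None]) / l[:, None]) @ V

grads = blockwise_backward(Q, K, V, O, dO, m, l)
print(all(np.allclose(a, b) for a, b in zip(grads, dense_backward(Q, K, V, dO))))
```

The per-row term `D` equals the row sum of `P * dP` because `O = P V`, so it can be computed from `O` and `dO` alone without the full probability matrix.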