Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformers are slow and memory-hungry on long sequences.
- Approximate attention methods have attempted to reduce compute complexity, but often do not achieve wall-clock speedup.
- FlashAttention is an IO-aware exact attention algorithm that uses tiling to reduce memory reads/writes.
- FlashAttention trains Transformers faster than existing baselines.
- FlashAttention enables longer context in Transformers, yielding higher quality models.
Paper Content
Introduction
- Transformer models are widely used in applications such as natural language processing and image classification
- Transformers have grown larger and deeper, but equipping them with longer context remains difficult
- Making attention faster and more memory-efficient can help Transformer models address their runtime and memory challenges for long sequences
- Many approximate attention methods have aimed to reduce the compute and memory requirements of attention
- These methods focus on FLOP reduction and tend to ignore overheads from memory access
- FlashAttention is a new attention algorithm that computes exact attention with far fewer memory accesses
- FlashAttention avoids reading and writing the attention matrix to and from HBM
- FlashAttention is up to 7.6x faster than the PyTorch implementation of attention on GPT-2
- FlashAttention is up to 3x faster than the standard attention implementation across common sequence lengths
- FlashAttention trains Transformer models faster in wall-clock time
- FlashAttention enables the first Transformer that can achieve better-than-chance performance on the Path-X challenge
- Block-sparse FlashAttention enables a Transformer to scale to even longer sequences
- Block-sparse FlashAttention is faster than all existing approximate attention methods
Background
- Performance characteristics of deep learning operations on GPUs
- Standard implementation of attention
Hardware performance
- GPUs are used in this paper
- Performance on other hardware accelerators is similar
- GPU memory hierarchy has multiple forms of memory of different sizes and speeds
- Operations can be classified as compute-bound or memory-bound
- Kernel fusion is used to accelerate memory-bound operations
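The idea behind kernel fusion can be illustrated with a toy model. The sketch below is not GPU code: the function names are hypothetical, and the `traffic` counter is an assumed stand-in for reads and writes to slow memory, counted per element.

```python
import math

def scale_then_exp_unfused(xs, scale):
    # Pass 1: read every input, write a full intermediate buffer.
    scaled = [x * scale for x in xs]
    # Pass 2: read the intermediate buffer back, write the result.
    out = [math.exp(s) for s in scaled]
    # Toy traffic model (assumption): 2 reads + 2 writes per element.
    traffic = 4 * len(xs)
    return out, traffic

def scale_then_exp_fused(xs, scale):
    # Single pass: read each input once, apply both ops, write once.
    out = [math.exp(x * scale) for x in xs]
    # Toy traffic model (assumption): 1 read + 1 write per element.
    traffic = 2 * len(xs)
    return out, traffic
```

Both functions perform the same arithmetic, but under this toy model the fused version halves the memory traffic, which is the effect that matters for memory-bound operations.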
Standard attention implementation
- Input sequences $Q, K, V \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension
- Standard attention implementation requires $O(N^2)$ memory
- Standard attention implementation performs HBM accesses quadratic in sequence length
- FlashAttention requires fewer HBM reads/writes and no large intermediate matrices
- FlashAttention is both memory efficient and faster in wall-clock time
- FlashAttention can handle block-sparse attention
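As a point of reference, the standard implementation can be sketched in plain Python on nested lists. This is a minimal illustration, not the paper's code: the helper names are hypothetical, the usual scaling of the scores is omitted for simplicity, and masking and dropout are left out. The two sequence-length by sequence-length intermediates `s` and `p` are what make memory and HBM traffic quadratic.

```python
import math

def transpose(a):
    return [list(col) for col in zip(*a)]

def matmul(a, b):
    # (n x k) @ (k x m) -> (n x m) on plain nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)                          # subtract the row max for numerical stability
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def standard_attention(q, k, v):
    s = matmul(q, transpose(k))   # N x N score matrix, materialized in full
    p = softmax_rows(s)           # N x N probability matrix, materialized in full
    return matmul(p, v)           # N x d output
```

In a GPU implementation each of these three steps would typically write its result to HBM and the next step would read it back.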
An efficient attention algorithm with tiling and recomputation
- Aim to compute the attention output O with a number of HBM accesses that is sub-quadratic in sequence length
- Split inputs Q, K, V into blocks
- Compute softmax values along with extra statistics
- Store output O and softmax normalization statistics
- Compute gradients with respect to Q, K, V
- Implement in one CUDA kernel
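- Algorithm 1 returns $O = \mathrm{softmax}(QK^\top)V$ with $O(N^2 d)$ FLOPs and $O(N)$ additional memory beyond inputs and output, for sequence length $N$ and head dimension $d$
- FlashAttention requires many times fewer HBM accesses than the standard implementation
- Load blocks of K, V of size $\Theta(M)$ each, where $M$ is the size of on-chip SRAM
- Iterate over all blocks of Q to compute intermediate values
- Standard attention requires $\Theta(Nd + N^2)$ HBM accesses
- FlashAttention requires $\Theta(N^2 d^2 M^{-1})$ HBM accesses
- No exact attention algorithm can asymptotically improve on this number of HBM accesses over all SRAM sizes
- Beyond a large enough block size, the runtime of FlashAttention is bottlenecked by other factors, such as arithmetic operations
- Block-sparse FlashAttention requires $\Theta(Nd + N^2 d^2 M^{-1} s)$ HBM accesses, where $s$ is the fraction of nonzero blocks in the block-sparsity mask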
- Block-sparse FlashAttention achieves 2.8× speedup on the LRA benchmark
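The tiling steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the paper's CUDA kernel: the function name and the `block_q` and `block_kv` parameters are hypothetical, the score scaling, masking, and dropout are omitted, and nested lists stand in for blocks held in SRAM. The outer loop walks over K/V blocks and the inner loop over Q blocks, keeping a running row maximum `m` and a running softmax denominator `l` per row so that the full score matrix is never formed.

```python
import math

def tiled_attention(q, k, v, block_q=2, block_kv=2):
    n, d = len(q), len(v[0])
    o = [[0.0] * d for _ in range(n)]   # running output rows (kept normalized)
    l = [0.0] * n                       # running softmax denominators
    m = [-math.inf] * n                 # running row maxima
    for j0 in range(0, len(k), block_kv):          # outer loop: one K/V block at a time
        kj, vj = k[j0:j0 + block_kv], v[j0:j0 + block_kv]
        for i0 in range(0, n, block_q):            # inner loop: all Q blocks
            for i in range(i0, min(i0 + block_q, n)):
                # Scores of query row i against the current K block only.
                s = [sum(a * b for a, b in zip(q[i], krow)) for krow in kj]
                m_tile = max(s)
                p = [math.exp(x - m_tile) for x in s]
                l_tile = sum(p)
                # Merge the tile statistics into the running statistics.
                m_new = max(m[i], m_tile)
                a_old = math.exp(m[i] - m_new)     # rescales what was accumulated so far
                a_new = math.exp(m_tile - m_new)   # rescales the current tile
                l_new = a_old * l[i] + a_new * l_tile
                pv = [sum(pc * vrow[c] for pc, vrow in zip(p, vj)) for c in range(d)]
                o[i] = [(l[i] * a_old * o[i][c] + a_new * pv[c]) / l_new
                        for c in range(d)]
                m[i], l[i] = m_new, l_new
    return o, l, m
```

The returned `l` and `m` are the per-row softmax normalization statistics; together with `o` they are enough to recompute attention probabilities block by block when gradients are needed, which is why only a linear amount of extra memory is required.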
Experiments
- FlashAttention outperforms MLPerf 1.1 speed record for BERT by 15%
- FlashAttention speeds up GPT-2 up to 3x over HuggingFace and 1.8x over Megatron
- FlashAttention trains GPT-2 with context length 4K faster than Megatron with context length 1K while achieving 0.7 better perplexity
- FlashAttention yields 6.4 points of lift on two long-document classification tasks
- FlashAttention achieves better-than-random performance on Path-X task (sequence length 16K)
- Block-sparse FlashAttention yields better-than-random performance on Path-256 (sequence length 64K)
- FlashAttention's memory footprint scales linearly with sequence length, and it runs up to 3x faster than standard attention
- Block-sparse FlashAttention runtime scales linearly with sequence length and is faster than all existing approximate attention baselines
Better models with longer sequences
- FlashAttention allows GPT-2's context length to be increased by 4x while still running faster than the Megatron-LM implementation.
- GPT-2 with FlashAttention and context length 4K is 30% faster than GPT-2 from Megatron with context length 1K.
- FlashAttention enables Transformers to scale to sequence length 64K.
Benchmarking attention
- Vary sequence length and measure runtime and memory usage of FlashAttention and block-sparse FlashAttention
- Compare against reference implementations for exact, approximate, and sparse attention
- Results reported in main body, more baselines and details in Appendix E
Limitations and future directions
- Discuss limitations of current approach to building IO-aware implementations of attention
- Requires writing attention algorithm in lower-level language than PyTorch
- Implementations may not be transferable across GPU architectures
- Need for a method that supports writing attention algorithms in high-level language
- IO-aware approach can extend beyond attention
- Multi-GPU IO-aware methods require data transfer between GPUs
- IO-aware runtime optimization has a long history in computer science
- Efficient ML models with structured matrices
- Sparse training
- Efficient Transformer
- Derive forward and backward passes of attention
- Compute forward pass with linear extra memory
- Compute backward pass with linear extra memory
- Do not need to store dropout mask from forward pass
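The last point, not storing the dropout mask, can be illustrated with a small sketch: if the random number generator state from the forward pass is saved, the backward pass can regenerate the identical mask instead of keeping it in memory. The code below is an assumption-laden toy, not the paper's kernel: it uses Python's `random.Random` with an integer seed in place of GPU RNG state, and the function names are hypothetical.

```python
import random

def dropout_mask(seed, n, p_drop):
    rng = random.Random(seed)   # same seed -> same sequence of draws
    keep_scale = 1.0 / (1.0 - p_drop)
    return [0.0 if rng.random() < p_drop else keep_scale for _ in range(n)]

def forward(xs, p_drop, seed):
    mask = dropout_mask(seed, len(xs), p_drop)
    ys = [x * m for x, m in zip(xs, mask)]
    return ys, seed             # save only the seed, not the mask

def backward(grad_ys, p_drop, saved_seed):
    # Regenerate the identical mask from the saved seed.
    mask = dropout_mask(saved_seed, len(grad_ys), p_drop)
    return [g * m for g, m in zip(grad_ys, mask)]
```

The saved state is a constant-size value, whereas a stored mask for attention would have one entry per element of the attention matrix.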