Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- FiD is a powerful retrieval-augmented language model.
- Inference is expensive for FiD.
- Memory bandwidth constraints in the decoder cause most of the inference time.
- Two changes to the FiD architecture speed up inference by 7x.
- FiDO-Large-XXL runs inference faster than FiD-Base and achieves better performance than FiD-Large.
Paper Content
Introduction
- Language model performance can be improved by augmenting with retrieved text
- Fusion-in-Decoder (FiD) architecture stands out for strong performance
- FiD is expensive and has high computational burden
- FiD's strong performance and its high computational cost are two sides of the same coin
- Encoder requires more Floating Point Operations (FLOPs) than decoder
- Majority of inference time is spent in decoder
- Memory bandwidth bottleneck is eliminated with proposed changes
- FiDO outperforms vanilla and efficient FiD models on question-answering datasets
Analysis
- Retrieval-augmented models process many context tokens for each question or answer token.
- Most inference time for Fusion-in-Decoder is spent in the decoder.
Fusion-in-Decoder
- Fusion-in-Decoder model uses a T5 encoder-decoder architecture
- Model is given a question and relevant text passages
- Question is prepended to each retrieved passage
- Encoder is applied to each passage separately
- Representations are concatenated
- Decoder cross-attends to concatenated representations to generate an answer
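The data flow above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: `toy_encoder` is a hypothetical stand-in for the T5 encoder, whitespace splitting stands in for real tokenization, and the decoder is omitted.

```python
from typing import Callable, List

Vector = List[float]


def toy_encoder(tokens: List[str]) -> List[Vector]:
    """Hypothetical stand-in for the T5 encoder: one small vector per token."""
    return [[float(len(tok)), float(i)] for i, tok in enumerate(tokens)]


def fid_encode(
    question: str,
    passages: List[str],
    encoder: Callable[[List[str]], List[Vector]] = toy_encoder,
) -> List[Vector]:
    """Encode each (question + passage) pair separately, then concatenate."""
    fused: List[Vector] = []
    for passage in passages:
        # The question is prepended to every retrieved passage.
        tokens = (question + " " + passage).split()
        # Each passage is encoded on its own, independently of the others.
        fused.extend(encoder(tokens))
    # A decoder would cross-attend over this single concatenated sequence.
    return fused


fused = fid_encode("who wrote hamlet", ["hamlet is a play", "shakespeare wrote it"])
print(len(fused))  # total number of encoded tokens across both passages
```

Because the encoder never sees two passages at once, encoder cost grows linearly with the number of passages, while the decoder's cross-attention input grows with the concatenated length.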
FLOPs of FiD model
- Model speed is determined by the number of FLOPs and FLOP/s.
- Operations in a Transformer can be divided into MLP layers, attention projection layers, and attention operations.
- Decoder FLOPs are roughly 1/6 of encoder FLOPs.
- Actual measured training time mirrors FLOPs approximation.
- Decoder is more expensive for inference due to memory bandwidth constraints.
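A rough sketch of where the 1/6 figure can come from. The constants below are assumptions, not taken from this summary: a multiply-accumulate counts as one operation, the MLP hidden size is 4d (8d^2 per token), each attention projection costs d^2, and the attention score computation is ignored. The number of passages and the output length are hypothetical.

```python
def fid_flops(n_source: int, n_target: int, num_layers: int, d_model: int):
    """Approximate encoder and decoder operation counts for a FiD-style model."""
    d2 = d_model ** 2
    # Encoder, per source token and layer: MLP (8 d^2) + q/k/v/o projections (4 d^2).
    encoder = 12 * n_source * num_layers * d2
    # Decoder: cross-attention key/value projections over all source tokens (2 d^2 each),
    # plus self-attention (4 d^2), cross-attention q/o (2 d^2) and MLP (8 d^2) per target token.
    decoder = (2 * n_source + 14 * n_target) * num_layers * d2
    return encoder, decoder


# Hypothetical workload: 40 passages of 256 tokens, 32 output tokens.
enc, dec = fid_flops(n_source=40 * 256, n_target=32, num_layers=12, d_model=768)
ratio = dec / enc
print(f"decoder/encoder FLOPs ratio: {ratio:.3f}")
```

Under these assumptions the ratio approaches 2/12 = 1/6 whenever the input is much longer than the output.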
Effective computational throughput
- Data transmission between global memory and registers can limit performance
- Roofline model can be used to model peak FLOP/s
- Data constraint is determined by device memory bandwidth and operational intensity
- High operational intensity is necessary for good performance on modern GPU/TPU hardware
- Low operational intensity can cause the accelerator to spend time waiting for data
- Inverse operational intensity can be derived for Fusion-in-Decoder
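The roofline model fits in one line of code: attainable throughput is the smaller of the compute peak and the memory bandwidth multiplied by the operational intensity. The hardware numbers below are hypothetical and do not describe any specific accelerator.

```python
def attainable_flops_per_s(
    peak_flops_per_s: float,
    memory_bandwidth_bytes_per_s: float,
    operational_intensity: float,
) -> float:
    """Roofline model: throughput is capped by compute or by data movement.

    operational_intensity is measured in FLOPs per byte moved.
    """
    return min(peak_flops_per_s, memory_bandwidth_bytes_per_s * operational_intensity)


# Hypothetical accelerator: 100 TFLOP/s peak, 1 TB/s memory bandwidth.
PEAK = 100e12
BANDWIDTH = 1e12

low = attainable_flops_per_s(PEAK, BANDWIDTH, operational_intensity=1.0)
high = attainable_flops_per_s(PEAK, BANDWIDTH, operational_intensity=500.0)
print(low / PEAK, high / PEAK)  # fraction of peak reached in each case
```

With these made-up numbers, any operation below 100 FLOPs per byte is memory-bound, which is the regime the bullets above describe for decoder inference.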
Operational intensity of FiD inference
- MLP layers need a large batch size to reach high operational intensity
- Memory bandwidth is a bottleneck for attention inference
- During incremental decoding, the model applies the attention projections to a single token at a time
- Query/key/value/output projections take O(bd^2) operations
- Model needs to load projection matrices and past keys and values
- Cross-attention operational intensity is very low
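The accounting in the bullets above can be turned into a small calculation. This sketch counts elements rather than bytes and drops constant factors, following the O(bd^2) accounting; the function names and workload sizes are hypothetical.

```python
def inverse_intensity_mlp(batch: int, d_model: int) -> float:
    """Memory moved per operation for one decoding step of an MLP layer."""
    ops = batch * d_model ** 2
    # Load the weight matrix once, plus one activation vector per sample.
    memory = d_model ** 2 + batch * d_model
    return memory / ops  # equals 1/batch + 1/d_model


def inverse_intensity_attention(batch: int, d_model: int, kv_len: int) -> float:
    """Memory moved per operation for one decoding step of an attention layer."""
    ops = batch * d_model ** 2
    # Load the projection matrices once, plus past keys and values for every sample.
    memory = d_model ** 2 + batch * kv_len * d_model
    return memory / ops  # equals 1/batch + kv_len/d_model


# Hypothetical workload: batch 64, d_model 1024, 10240 cross-attended input tokens.
mlp = inverse_intensity_mlp(batch=64, d_model=1024)
cross = inverse_intensity_attention(batch=64, d_model=1024, kv_len=10240)
print(mlp, cross)
```

A larger batch shrinks the 1/batch term but leaves kv_len/d_model untouched, so batching alone cannot fix cross-attention over a long concatenated input.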
Method
- Encoder accounts for most of the FLOPs and training cost in FiD
- FiD is more expensive than other methods
Layer-sparse cross-attention
- Decoder cross-attention layer is a bottleneck for inference
- FiD-Light improves operational intensity by reducing input length
- Layer-sparse cross-attention (LSA) proposed to remove cross-attention from some decoder layers
- LSA achieves speedups similar to FiD-Light without the corresponding drop in quality
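A minimal sketch of how layer-sparse cross-attention changes the decoder layout. Keeping cross-attention in every `stride`-th layer is an assumption made for illustration; the summary does not say which layers keep it, and the function names are hypothetical.

```python
from typing import List


def cross_attention_layers(num_decoder_layers: int, stride: int) -> List[int]:
    """Indices of decoder layers that keep cross-attention (assumed: every stride-th)."""
    if stride < 1:
        raise ValueError("stride must be at least 1")
    return [i for i in range(num_decoder_layers) if (i + 1) % stride == 0]


def decoder_layer_plan(num_decoder_layers: int, stride: int) -> List[List[str]]:
    """Sublayers of each decoder layer under layer-sparse cross-attention."""
    keep = set(cross_attention_layers(num_decoder_layers, stride))
    plan = []
    for i in range(num_decoder_layers):
        sublayers = ["self_attention"]
        if i in keep:
            # Only these layers read the long concatenated encoder output.
            sublayers.append("cross_attention")
        sublayers.append("mlp")
        plan.append(sublayers)
    return plan


plan = decoder_layer_plan(num_decoder_layers=12, stride=6)
print(cross_attention_layers(12, 6))
```

With `stride=1` the plan reduces to a standard decoder, which makes the sketch easy to compare against the dense baseline.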
Multi-query attention
- Shazeer (2019) proposes a way to increase the operational intensity of decoder attention layers.
- Keys and values share a single head each and only queries have multiple heads.
- This reduces memory and makes loading faster.
- Inference cost and memory are reduced, but training speed does not improve.
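The memory saving from sharing a single key head and a single value head can be estimated by counting cached elements. The shapes below are hypothetical and the function name is illustrative.

```python
def kv_cache_elements(
    batch: int, kv_len: int, num_heads: int, d_head: int, multi_query: bool
) -> int:
    """Number of cached key and value elements for one attention layer."""
    # Multi-query attention keeps one shared key head and one shared value head;
    # only the queries retain num_heads separate heads.
    kv_heads = 1 if multi_query else num_heads
    return 2 * batch * kv_len * kv_heads * d_head  # factor 2: keys and values


# Hypothetical shapes: batch 64, 10240 input tokens, 12 heads of size 64.
multi_head = kv_cache_elements(64, 10240, 12, 64, multi_query=False)
multi_query = kv_cache_elements(64, 10240, 12, 64, multi_query=True)
print(multi_head // multi_query)  # reduction factor equals the number of heads
```

Fewer cached elements means less data to load at every decoding step, which is why the change raises operational intensity.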
Asymmetric decoder
- Layer-sparse cross-attention and multi-query attention make the decoder cheaper than the encoder.
- Replacing the Base-sized decoder with an XL-sized decoder increases total inference time per sample by 21%.
- Pre-training costs increase more strongly than inference costs.
- FiDO uses decoders that are two T5 sizes larger than the encoder.
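The "two sizes larger" rule can be written as a lookup over the T5 size ladder. The helper names are hypothetical, and the `FiDO-<encoder>-<decoder>` naming pattern is inferred from "FiDO-Large-XXL" rather than stated as a convention.

```python
T5_SIZES = ["Small", "Base", "Large", "XL", "XXL"]


def fido_decoder_size(encoder_size: str, offset: int = 2) -> str:
    """Decoder size that sits `offset` steps above the encoder on the T5 ladder."""
    index = T5_SIZES.index(encoder_size)  # raises ValueError for unknown sizes
    target = index + offset
    if target >= len(T5_SIZES):
        raise ValueError(f"no T5 size {offset} steps above {encoder_size}")
    return T5_SIZES[target]


def fido_name(encoder_size: str) -> str:
    """Model name in the inferred FiDO-<encoder>-<decoder> pattern."""
    return f"FiDO-{encoder_size}-{fido_decoder_size(encoder_size)}"


print(fido_name("Base"), fido_name("Large"))
```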
Related work
- Retrieval-augmented models condition a language model on retrieved text
- REALM, RAG, RETRO and Fusion-in-Decoder are examples of retrieval-augmented models
- Fusion-in-Decoder has achieved state-of-the-art performance on a variety of tasks
- Efficient Transformers can reduce memory bandwidth constraints
- Multi-query attention and quantizing models can reduce data movement and memory usage
- FiDO reduces the cost of the decoder and increases the benefit of other approaches
Experiments
Experiment setup
- Models based on T5.1.1 architecture
- Pre-trained from scratch on C4
- Standard T5 training recipe with modified Adafactor optimizer
- Report results on open-domain QA splits from Lee et al.
- Each sample paired with 100-word Wikipedia passages
- Question prepended to each passage, truncated to 256 tokens
- Fine-tune each model with batch size 64 and learning rate 0.001
- Inference on single TPUv4
- Batch size of 64 (or largest that fits) for main experiments
- Predictions generated with greedy decoding
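For reference, the setup above can be collected into a single configuration object. The class and field names are hypothetical; the values are the ones listed in the bullets, and anything not listed there (for example the number of passages per question) is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FiDOExperimentConfig:
    """Hypothetical container for the experiment settings listed above."""

    architecture: str = "T5.1.1"
    pretraining_corpus: str = "C4"
    optimizer: str = "Adafactor (modified)"
    passage_length_words: int = 100
    max_input_tokens: int = 256  # question + passage, after truncation
    finetune_batch_size: int = 64
    learning_rate: float = 0.001
    inference_batch_size: int = 64  # or the largest that fits
    inference_hardware: str = "single TPUv4"
    decoding: str = "greedy"


config = FiDOExperimentConfig()
print(config)
```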
Main results
- FiDO outperforms FiD at any given inference budget
- Layer-sparse cross-attention reduces inference cost with modest performance degradation
- Token-sparse cross-attention from FiD-Light has similar speedup but higher performance penalty
- Multi-query attention achieves large cost reduction with minimal performance degradation
- Asymmetric decoders achieve better performance-inference trade-off
Other analysis
- Layer-sparse cross-attention and multi-query attention are important for all settings
- As output length increases, the cost of the decoder rises
- For tasks with long outputs, reducing the level of decoder asymmetry is recommended
- For real-time serving, smaller batch sizes are used
- Layer-sparse and multi-query attention are still important for lower batch sizes
- Scaling the decoder is more cost-efficient than beam search for most tasks
Conclusion
- The FiD encoder requires more FLOPs than the decoder, but most inference time is spent in the decoder due to memory bandwidth constraints
- FiDO removes most cross-attention layers and employs multi-query attention to reduce the cost of the decoder
- FiDO achieves higher performance given the same inference budget relative to existing standard and efficient FiD models
- FiDO uses layer-sparse cross-attention and multi-query attention to reduce decoder inference cost with minor performance penalty
- FiDO is evaluated on dev sets for ablation results and compared to published results