Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Natural language processing tasks benefit from long inputs.
- Processing long documents with Transformers is expensive.
- CoLT5 is a long-input Transformer model that uses conditional computation.
- CoLT5 achieves stronger performance than LongT5 with faster training and inference.
- CoLT5 achieves SOTA on the long-input SCROLLS benchmark.
- CoLT5 can effectively and tractably make use of extremely long inputs.
Paper Content
Introduction
- Natural language processing tasks require machine learning models to encode longform text.
- Processing long documents with a Transformer model is computationally expensive.
- Many “efficient Transformer” approaches have been proposed to reduce the cost of the attention mechanism over long inputs.
- This paper presents CoLT5, a new family of models that enables fast processing of long inputs.
- CoLT5 divides each feedforward layer and each attention layer into a light branch and a heavy branch.
- CoLT5 also includes two other modifications to the LongT5 architecture.
- CoLT5 performs faster finetuning and inference with similar or better model quality.
Background
- Transformer FLOPs: CoLT5 attempts to reduce the computational cost of Transformer models
- Table 1 shows FLOPs for each component of a Transformer encoder layer
- Self-attention mechanism scales quadratically in input length
- Different approaches focus on reducing cost of different components
- Conditional computation avoids applying all model parameters to entire input
- Device utilization influences effective speed of operations
- Training objectives can lead to improved performance
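As a rough illustration of why the quadratic attention term matters for long inputs, the sketch below counts multiply-accumulate operations for one vanilla encoder layer. It assumes a common textbook accounting (four d x d attention projections, a feedforward layer with hidden size 4d, and two n x n x d matrix products for attention). The function name and the example sizes are hypothetical, and the numbers are not copied from Table 1.

```python
def encoder_layer_macs(n: int, d: int, ff_mult: int = 4) -> dict:
    """Approximate multiply-accumulate counts for one vanilla encoder layer.

    n: input length in tokens, d: model dimension.
    ff_mult: feedforward hidden size as a multiple of d (assumed to be 4).
    """
    return {
        # Q, K, V and output projections: four (n x d) @ (d x d) products.
        "projections": 4 * n * d * d,
        # Two dense layers: (n x d) @ (d x ff_mult*d) and back again.
        "feedforward": 2 * ff_mult * n * d * d,
        # QK^T scores and the weighted sum over values: two n x n x d products.
        "attention": 2 * n * n * d,
    }


for n in (512, 16_384):
    macs = encoder_layer_macs(n=n, d=1024)
    share = macs["attention"] / sum(macs.values())
    print(f"n={n:>6}: attention share of layer cost = {share:.0%}")
```

Under these assumptions the projection and feedforward terms grow linearly with n while the attention term grows quadratically, so attention is a small fraction of the layer cost at 512 tokens and the majority at 16k tokens.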
Conditional computation
- A large proportion of Transformer FLOPs arise from feedforward and projection layers that scale with the length of the input sequence.
- LongT5 training and inference on long documents is expensive.
- CoLT5 reduces the cost of processing long documents through conditional computation.
- CoLT5 has three components: routing modules, conditional feedforward layers, and conditional attention layers.
- Routing modules select important tokens from an input at each attention or feedforward layer.
- Conditional feedforward layer applies an additional high-capacity feedforward layer to selected tokens.
- Conditional attention layer applies an additional high-capacity attention layer that attends from selected query tokens to selected key-value tokens.
- CoLT5 uses less than half the projection FLOPs of LongT5 and has an order-of-magnitude smaller quadratic length scaling.
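The routing and conditional feedforward steps described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the routing vector `u`, the helper names, the use of raw (unnormalized) scores as the gate, and the toy sizes are all hypothetical, and the weights are random rather than learned.

```python
import numpy as np


def route_top_k(x, u, k):
    """Score every token against a routing vector and keep the k highest.

    x: (n, d) token representations, u: (d,) routing embedding (hypothetical name).
    Returns the sorted indices of the selected tokens and their raw scores.
    """
    scores = x @ u                          # one scalar score per token
    top = np.sort(np.argsort(scores)[-k:])  # indices of the k largest scores
    return top, scores[top]


def feedforward(x, w_in, w_out):
    """Plain two-layer ReLU feedforward block."""
    return np.maximum(x @ w_in, 0.0) @ w_out


def conditional_feedforward(x, u, k, light, heavy):
    """Light branch on all tokens, heavy branch only on routed tokens."""
    out = x + feedforward(x, *light)        # residual + light branch, every token
    idx, gate = route_top_k(x, u, k)
    # Heavy branch output is scaled by the routing score (an assumed gating form).
    out[idx] += gate[:, None] * feedforward(x[idx], *heavy)
    return out, idx


rng = np.random.default_rng(0)
n, d, k = 16, 8, 4                          # toy sizes, not the paper's settings
x = rng.normal(size=(n, d))
u = rng.normal(size=d)
light = (rng.normal(size=(d, d)), rng.normal(size=(d, d)))          # small hidden size
heavy = (rng.normal(size=(d, 4 * d)), rng.normal(size=(4 * d, d)))  # larger hidden size

y, routed = conditional_feedforward(x, u, k, light, heavy)
print("routed token indices:", routed)
```

Only the k routed tokens pay for the heavy branch, which is where the savings come from when k is much smaller than n. A conditional attention layer would follow the same pattern, with one routing step selecting query tokens and another selecting key-value tokens.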
Multi-query attention
- Conditional computation reduces computational cost of encoder.
- Majority of inference time is spent in decoder due to memory bandwidth constraints.
- Multi-query attention applied in cross-attention layers for faster inference.
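To make the memory-bandwidth argument concrete, here is a small NumPy sketch of multi-query attention, in which every query head attends over a single shared set of keys and values. The function names, shapes, and toy sizes are assumptions for illustration. In a cross-attention layer the keys and values would come from the encoder output.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_query_attention(q, k, v):
    """q: (heads, n_q, d_head); k, v: (n_kv, d_head), shared by every head."""
    d_head = q.shape[-1]
    scores = np.einsum("hqd,kd->hqk", q, k) / np.sqrt(d_head)
    return np.einsum("hqk,kd->hqd", softmax(scores), v)


heads, n_q, n_kv, d_head = 8, 1, 1024, 64   # toy sizes: a single decoding step
rng = np.random.default_rng(0)
q = rng.normal(size=(heads, n_q, d_head))
k = rng.normal(size=(n_kv, d_head))
v = rng.normal(size=(n_kv, d_head))

out = multi_query_attention(q, k, v)

# Key/value entries that have to be read from memory per decoding step.
multi_head_kv = 2 * heads * n_kv * d_head   # separate K and V for each head
multi_query_kv = 2 * n_kv * d_head          # one K and one V for all heads
print(out.shape, multi_head_kv // multi_query_kv)
```

With a shared key-value head, the amount of key-value data read per decoding step shrinks by a factor equal to the number of heads, which is why the technique helps when decoding is bound by memory bandwidth rather than compute.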
UL2
- UL2 pre-training objective combines denoising objectives
- UL2 leads to improved in-context learning
- CoLT5 is trained with UL2 instead of PEGASUS, giving it in-context learning capabilities
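For readers unfamiliar with denoising objectives, the sketch below shows two of the kinds that UL2 mixes: span corruption and a prefix-style sequential objective. It is a simplified illustration only. The helper names are hypothetical, the spans are chosen by hand rather than sampled, and the `<extra_id_N>` sentinel format follows the T5 convention rather than anything stated in this summary.

```python
def span_corrupt(tokens, spans):
    """Replace each (start, end) span with a sentinel; targets hold the removed text.

    spans must be sorted and non-overlapping; end is exclusive.
    """
    inputs, targets, cursor = [], [], 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs += tokens[cursor:start] + [sentinel]
        targets += [sentinel] + tokens[start:end]
        cursor = end
    inputs += tokens[cursor:]
    return inputs, targets


def prefix_lm(tokens, split):
    """Sequential denoising: the prefix is the input, the continuation is the target."""
    return tokens[:split], tokens[split:]


tokens = "the quick brown fox jumps over the lazy dog".split()
print(span_corrupt(tokens, [(1, 3), (6, 8)]))
print(prefix_lm(tokens, 5))
```

Varying the span lengths and the corruption rate, and mixing in the prefix-style task, gives a family of denoisers that a single model can be trained on.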
Experiments
- Compare CoLT5 and LongT5 on long input datasets
- Evaluate CoLT5 on extremely long inputs and compare scaling against LongT5
- Investigate how performance changes as input length and number of shots increase
- Perform ablations to understand effect of individual CoLT5 components
- Investigate empirical routing patterns
Experimental setup
- Configurations: CoLT5 is based on the T5.1.1 architecture
- Implemented with JAX, Flax, and Flaxformer
- Experiments with Base, Large, and XL model sizes
- Pre-trained for 1M steps on UL2 objective
- Evaluated on TriviaQA, arXiv, and SCROLLS benchmark
- Timing reported per sample per TPUv4 chip
Main results
- LongT5 and CoLT5 are compared in terms of quality-speed trade-off
- CoLT5 achieves better quality than LongT5 at any given speed
- CoLT5 has 35-75% training speedup and 50-100% inference speedup compared to LongT5
- CoLT5-XL achieves SOTA performance on the SCROLLS benchmark
Scaling to extremely long inputs
- We hypothesize that CoLT5 will be more effective with longer inputs
- CoLT5 achieves stronger performance and faster inference speed than LongT5 at all input lengths
- CoLT5 can effectively use extremely long inputs up to 64k tokens
- CoLT5 uses conditional computation for higher quality and faster speed
- CoLT5 has light and heavy branches, with important tokens selected by a learned router
- LongT5 does not see an unexpected quality gain from MQA
- Tokens routed by CoLT5 are more likely to be question and answer tokens
- CoLT5 uses fewer FLOPs than LongT5
- CoLT5 achieves better performance than LongT5 on a variety of long-input datasets
- CoLT5 can use its long-input capability to benefit from more examples for in-context learning