Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Natural language processing tasks benefit from long inputs.
  • Processing long documents with Transformers is expensive.
  • CoLT5 is a long-input Transformer model that uses conditional computation.
  • CoLT5 achieves stronger performance than LongT5 with faster training and inference.
  • CoLT5 achieves SOTA on the long-input SCROLLS benchmark.
  • CoLT5 can effectively and tractably make use of extremely long inputs.

Paper Content

Introduction

  • Natural language processing tasks require machine learning models to encode longform text.
  • Processing long documents with a Transformer model is computationally expensive.
  • Many “efficient Transformer” approaches have been proposed to reduce the cost of the attention mechanism over long inputs.
  • This paper presents COLT5, a new family of models that enables fast processing of long inputs.
  • COLT5 divides each feedforward layer and each attention layer into a light and heavy branch.
  • COLT5 also includes two other modifications to the LONGT5 architecture.
  • COLT5 performs faster finetuning and inference with similar or better model quality.

Background

  • Transformer FLOPs: COLT5 attempts to reduce the computational cost of Transformer models
  • Table 1 shows FLOPs for each component of a Transformer encoder layer
  • Self-attention mechanism scales quadratically in input length
  • Different approaches focus on reducing cost of different components
  • Conditional computation avoids applying all model parameters to entire input
  • Device utilization influences effective speed of operations
  • Training objectives can lead to improved performance
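The difference between linear and quadratic scaling in input length can be made concrete with a rough FLOP count. The sketch below is not a reproduction of Table 1 from the paper; it is a back-of-the-envelope estimate under stated assumptions (full self-attention, a plain two-matrix feedforward layer, 2 FLOPs per multiply-add, softmax and layer norm ignored). The function name `encoder_layer_flops` and its parameters are hypothetical.

```python
def encoder_layer_flops(n: int, d: int, d_ff: int) -> dict:
    """Rough FLOP estimate for one vanilla Transformer encoder layer.

    n: input length in tokens, d: model dimension, d_ff: feedforward dimension.
    Assumes 2 FLOPs per multiply-add and ignores softmax, layer norm, and biases.
    """
    projections = 2 * 4 * n * d * d      # Q, K, V, and output projections
    attention = 2 * 2 * n * n * d        # QK^T scores plus weighted sum of values
    feedforward = 2 * 2 * n * d * d_ff   # up-projection and down-projection
    return {
        "projections": projections,
        "attention": attention,
        "feedforward": feedforward,
    }


if __name__ == "__main__":
    for n in (1024, 4096, 16384):
        flops = encoder_layer_flops(n, d=1024, d_ff=4096)
        total = sum(flops.values())
        share = flops["attention"] / total
        print(f"n={n:>6}: attention share of layer FLOPs = {share:.2f}")
```

Doubling `n` doubles the projection and feedforward terms but quadruples the attention term, which is why approaches that target different components pay off at different input lengths.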

Conditional computation

  • A large proportion of Transformer FLOPs arise from feedforward and projection layers that scale with the length of the input sequence.
  • LONGT5 training and inference on long documents is expensive.
  • COLT5 reduces the cost of processing long documents through conditional computation.
  • COLT5 has three components: routing modules, conditional feedforward layers, and conditional attention layers.
  • Routing modules select important tokens from an input at each attention or feedforward layer.
  • Conditional feedforward layer applies an additional high-capacity feedforward layer to selected tokens.
  • Conditional attention layer applies an additional high-capacity attention layer that attends from selected query tokens to selected key-value tokens.
  • COLT5 has less than half the projection FLOPs of LONGT5 and an order-of-magnitude smaller quadratic length scaling.
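The routing module and conditional feedforward layer described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function names, the parameter dictionary, and the use of a sigmoid to turn router scores into gate values are all assumptions made for the example. What it does take from the summary is the overall shape: a light feedforward branch applied to every token, plus a heavy branch applied only to the tokens the router selects.

```python
import numpy as np


def ffn(x, w_in, w_out):
    """Plain two-matrix feedforward layer with ReLU."""
    return np.maximum(x @ w_in, 0.0) @ w_out


def route(x, u, k):
    """Score each token against a learned vector u and keep the top k.

    Returns the indices of the selected tokens and a gate value per token.
    The sigmoid gate is an assumption made for this sketch.
    """
    scores = x @ u                         # shape [n]
    top = np.argsort(-scores)[:k]          # indices of the k highest scores
    gates = 1.0 / (1.0 + np.exp(-scores[top]))
    return top, gates


def conditional_ffn(x, params, k):
    """Light branch on all tokens, heavy branch on routed tokens only."""
    light = ffn(x, params["light_in"], params["light_out"])
    top, gates = route(x, params["router"], k)
    heavy = np.zeros_like(x)
    heavy[top] = gates[:, None] * ffn(x[top], params["heavy_in"], params["heavy_out"])
    return x + light + heavy


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, d_light, d_heavy, k = 16, 8, 4, 32, 4
    params = {
        "router": rng.normal(size=d),
        "light_in": rng.normal(size=(d, d_light)),
        "light_out": rng.normal(size=(d_light, d)),
        "heavy_in": rng.normal(size=(d, d_heavy)),
        "heavy_out": rng.normal(size=(d_heavy, d)),
    }
    x = rng.normal(size=(n, d))
    print(conditional_ffn(x, params, k).shape)
```

Because the heavy branch only ever sees `k` tokens, its cost depends on `k` rather than on the full input length `n`. A conditional attention layer would follow the same pattern, with the heavy branch attending from routed query tokens to routed key-value tokens.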

Multi-query attention

  • Conditional computation reduces computational cost of encoder.
  • Majority of inference time is spent in decoder due to memory bandwidth constraints.
  • Multi-query attention applied in cross-attention layers for faster inference.
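As a rough illustration of why multi-query attention helps with memory bandwidth, the sketch below implements cross-attention in which all query heads share a single key head and a single value head. The function name, tensor layout, and shapes are assumptions chosen for the example and do not come from the paper.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_query_cross_attention(q, k, v):
    """Cross-attention with per-head queries and one shared key/value head.

    q: [m, h, d_head]  decoder queries, one set per head
    k: [n, d_head]     encoder keys, shared by all h heads
    v: [n, d_head]     encoder values, shared by all h heads
    Returns an array of shape [m, h, d_head].
    """
    d_head = q.shape[-1]
    logits = np.einsum("mhd,nd->hmn", q, k) / np.sqrt(d_head)
    weights = softmax(logits, axis=-1)     # normalize over encoder positions
    return np.einsum("hmn,nd->mhd", weights, v)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q = rng.normal(size=(5, 4, 8))    # 5 decoder positions, 4 heads
    k = rng.normal(size=(100, 8))     # 100 encoder positions
    v = rng.normal(size=(100, 8))
    print(multi_query_cross_attention(q, k, v).shape)
```

With standard multi-head attention the keys and values would have shape `[n, h, d_head]`, so the decoder would read `h` times more key-value data at every step. Sharing one key/value head shrinks that read, which matters most when `n` is large.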

UL2

  • UL2 pre-training objective combines denoising objectives
  • UL2 leads to improved in-context learning
  • COLT5 trained on UL2 instead of PEGASUS, giving it in-context learning capabilities

Experiments

  • Compare COLT5 and LONGT5 on long input datasets
  • Evaluate COLT5 on extremely long inputs and compare scaling against LONGT5
  • Investigate how performance changes as input length and number of shots increase
  • Perform ablations to understand effect of individual COLT5 components
  • Investigate empirical routing patterns

Experimental setup

  • Configurations: COLT5 is based on the T5.1.1 architecture
  • Implemented with JAX, Flax, and Flaxformer
  • Experiments with Base, Large, and XL model sizes
  • Pre-trained for 1M steps on UL2 objective
  • Evaluated on TriviaQA, arXiv, and SCROLLS benchmark
  • Timing reported per sample per TPUv4 chip

Main results

  • LONGT5 and COLT5 are compared in terms of quality-speed trade-off
  • COLT5 is better than LONGT5 at any speed
  • COLT5 has 35-75% training speedup and 50-100% inference speedup compared to LONGT5
  • COLT5-XL achieves SOTA performance on the SCROLLS benchmark

Scaling to extremely long inputs

  • We hypothesize that COLT5 will be more effective with longer inputs
  • COLT5 achieves stronger performance and faster inference speed than LONGT5 at all input lengths
  • COLT5 can effectively use extremely long inputs up to 64k tokens
  • COLT5 uses conditional computation for higher quality and faster speed
  • COLT5 has light and heavy branches, with important tokens selected by a learned router
  • LONGT5 does not achieve an unexpected quality gain from MQA
  • COLT5 routed tokens are more likely to be question and answer tokens
  • COLT5 uses fewer FLOPs than LONGT5
  • COLT5 achieves better performance than LONGT5 on a variety of long-input datasets
  • COLT5 can use its long-input capability to benefit from more examples for in-context learning