Abstract

  • Speculative sampling is an algorithm for accelerating transformer decoding.
  • It enables the generation of multiple tokens from each transformer call.
  • It uses a faster but less powerful draft model to generate short continuations.
  • A modified rejection sampling scheme is used to preserve the distribution of the target model.
  • Speculative sampling was benchmarked with a 70 billion parameter language model, resulting in a 2-2.5x decoding speedup.

Paper Content

Introduction

  • Scaling transformer models has led to improved performance on many tasks
  • Transformer decoding is costly and inefficient
  • Memory bandwidth is a limitation for transformer sampling
  • Model parallelism is needed for large language models
  • Speculative sampling (SpS) algorithm accelerates transformer sampling
  • SpS leads to 2-2.5x speedup when sampling from Chinchilla
  • Related work aims to improve the sampling latency of large transformers and other auto-regressive models
  • Quantisation and distillation of transformers reduce sampling latency
  • Model size contributes less to final performance than expected
  • Multi-query attention improves sampling performance
  • A combination of these techniques has been used to improve the serving latency and efficiency of PaLM 540B
  • An existing body of work exploits the efficiency of transformers and sequence models operating in parallel

Auto-regressive sampling

  • Transformers can be trained efficiently and in parallel on TPUs and GPUs
  • Auto-regressive sampling is memory bandwidth bound and cannot make effective use of modern accelerator hardware
  • Tokens are generated one at a time, so each additional token adds the latency of a full model call
  • As model size increases, parameters need to be divided across multiple accelerators, leading to communication overheads
  • For small batch sizes, scoring a short continuation in parallel should not be significantly slower than sampling a single token
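The memory-bandwidth argument above can be made concrete with a back-of-the-envelope estimate. The sketch below assumes that every parameter is read from memory once per generated token and ignores activations, caches and communication overhead. Only the 70 billion parameter count comes from the text; the weight precision, chip count and per-chip bandwidth are illustrative placeholders, and `min_ms_per_token` is a hypothetical helper name.

```python
def min_ms_per_token(n_params, bytes_per_param, n_chips, bandwidth_bytes_per_s):
    """Lower bound on per-token latency (ms) if all weights are read once per token."""
    total_bytes = n_params * bytes_per_param
    aggregate_bandwidth = n_chips * bandwidth_bytes_per_s
    return 1e3 * total_bytes / aggregate_bandwidth


# 70e9 parameters is from the text; every other value is an assumption.
estimate = min_ms_per_token(
    n_params=70e9,
    bytes_per_param=2,              # assumed 16-bit weights
    n_chips=16,                     # assumed number of accelerators
    bandwidth_bytes_per_s=1.2e12,   # hypothetical per-chip memory bandwidth
)
print(f"{estimate:.2f} ms/token lower bound")
```

Because the bound depends only on the bytes moved, scoring a handful of positions in one pass reads the same weights as sampling one token, which is why the two are expected to cost roughly the same at small batch sizes.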

Modified rejection sampling

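  • Introduce a modified rejection sampling scheme for the drafted tokens
  • Accept a drafted token x with probability min(1, q(x) / p(x)), where q is the target model's distribution and p is the draft model's
  • On the first rejection, resample that token from the normalised positive part of q − p and discard the remaining drafts
  • This recovers the target model's distribution for the emitted tokens
  • A maximum of K + 1 tokens is produced per loop, where K is the number of drafted tokens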
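The steps above can be sketched as a single draft-then-verify loop. This is a minimal illustration, not the paper's implementation: it uses the acceptance rule min(1, q(x) / p(x)) with q the target distribution and p the draft distribution, and K for the number of drafted tokens. The toy "models" are fixed distributions that ignore their context, and all function names (`sample`, `accept_or_resample`, `speculative_loop`) are hypothetical.

```python
import random


def sample(probs, rng):
    """Draw one token index from a categorical distribution (weights need not sum to 1)."""
    return rng.choices(range(len(probs)), weights=probs)[0]


def accept_or_resample(token, p, q, rng):
    """Apply the accept/reject test to one drafted token.

    p is the draft distribution and q the target distribution over the same
    vocabulary. Returns (emitted_token, accepted).
    """
    if rng.random() < min(1.0, q[token] / p[token]):
        return token, True
    # Rejected: resample from the normalised positive part of (q - p).
    residual = [max(0.0, qi - pi) for qi, pi in zip(q, p)]
    return sample(residual, rng), False


def speculative_loop(draft_fn, target_fn, prefix, K, rng):
    """Run one loop and return the prefix extended by 1 to K + 1 tokens."""
    context = list(prefix)
    drafted = []
    for _ in range(K):
        p = draft_fn(context)
        token = sample(p, rng)
        drafted.append((token, p))
        context.append(token)

    # A real system would score all K + 1 positions in one parallel target call;
    # this toy version calls the target once per position for readability.
    output = list(prefix)
    for token, p in drafted:
        q = target_fn(output)
        emitted, accepted = accept_or_resample(token, p, q, rng)
        output.append(emitted)
        if not accepted:
            return output  # remaining drafts are discarded
    # Every draft was accepted: sample one extra token from the target.
    output.append(sample(target_fn(output), rng))
    return output


# Hypothetical three-token vocabulary with context-independent toy models.
DRAFT = [0.6, 0.3, 0.1]
TARGET = [0.3, 0.3, 0.4]
rng = random.Random(0)
print(speculative_loop(lambda ctx: DRAFT, lambda ctx: TARGET, prefix=[0], K=4, rng=rng))
```

Running the loop many times and counting the first emitted token should reproduce the target distribution up to sampling noise, even though most proposals come from the draft.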

Choice of draft models

  • One option is to incorporate draft generation into the target model
  • Another is to use sequence-level distillation to generate a second model
  • The approach taken here is to use a smaller version of the target language model as the draft

Results

  • Trained a 4 billion parameter draft model on 16 TPU v4s
  • The draft model samples at 1.8 ms/token, compared with 14.1 ms/token for Chinchilla
  • Smaller models need fewer TPUs to achieve lowest sampling latency
  • Wider model with few layers minimises communication overhead

Evaluation on XSum and HumanEval

  • Evaluated speculative sampling with Chinchilla on two tasks
  • XSum benchmark: 11,305 sequences, max length 128
  • HumanEval task: 16,400 samples, max length 512
  • Obtained substantial speedup in both tasks, HumanEval reaching 2.5x speedup
  • Benchmark metrics show same underlying sample distribution

Acceptance rate changes per domain

  • Acceptance rate is dependent on application and decoding method
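  • HumanEval likely achieves the larger speedup because code contains many common sub-sequences and is often broken into shorter tokens
  • The temperature value sharpens both the draft and target logits, which can also raise the acceptance rate
  • The average time per loop increases approximately linearly with the number of draft model calls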
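The temperature effect can be illustrated numerically. Assuming the acceptance rule min(1, q(x) / p(x)) for a token drafted from p and scored by q, the chance that a draft is accepted works out to the sum over the vocabulary of min(p(x), q(x)). The sketch below evaluates that overlap at several temperatures; the logits are made up for illustration and the helper names are hypothetical.

```python
import math


def softmax(logits, temperature):
    """Temperature-scaled softmax, computed stably by subtracting the max."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def acceptance_probability(p, q):
    """Probability that a token drafted from p passes the min(1, q/p) test."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


# Hypothetical logits: both models prefer token 0 but with different confidence.
draft_logits = [2.0, 1.0, 0.0]
target_logits = [3.0, 1.0, 0.0]

for temperature in (1.0, 0.5, 0.1):
    p = softmax(draft_logits, temperature)
    q = softmax(target_logits, temperature)
    print(temperature, round(acceptance_probability(p, q), 3))
```

In this example the overlap grows as the temperature drops because both distributions concentrate on the same token. If the two models disagreed on the most likely token, sharpening could lower the acceptance rate instead, so the direction of the effect depends on how well the draft matches the target.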

Trade off between longer drafts and more frequent scoring

  • Increasing the number of tokens K sampled by the draft model per loop can lead to fewer scoring calls to the large model, potentially giving a speedup.
  • The total loop time increases with the larger number of draft model calls, plus a small increase in the scoring time.
  • Acceptance efficiency (the fraction of drafted tokens that are accepted) decreases as K increases.
  • The speedup may plateau or even degrade at larger K, depending on the domain.
  • Larger values of K may also increase the variance of the time to generate a full sequence.
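A simple cost model shows why the speedup can plateau and then degrade. The sketch below is a simplification rather than the paper's analysis: it assumes each drafted token is accepted independently with a fixed probability `alpha`, that every loop emits one extra token (the resampled or bonus token), and that scoring costs the same as one target model call. The per-token timings are the figures quoted in the Results section; `alpha = 0.8` is an arbitrary illustrative value.

```python
def expected_tokens_per_loop(alpha, K):
    """Expected tokens emitted per loop: 1 + alpha + alpha^2 + ... + alpha^K."""
    if alpha >= 1.0:
        return K + 1
    return (1 - alpha ** (K + 1)) / (1 - alpha)


def estimated_speedup(alpha, K, draft_ms, target_ms, score_ms=None):
    """Speedup over plain auto-regressive sampling under the assumptions above."""
    if score_ms is None:
        score_ms = target_ms  # assumption: scoring K + 1 positions costs one target call
    loop_ms = K * draft_ms + score_ms
    return expected_tokens_per_loop(alpha, K) * target_ms / loop_ms


for K in range(1, 9):
    print(K, round(estimated_speedup(0.8, K, draft_ms=1.8, target_ms=14.1), 2))
```

Under these assumptions the expected number of tokens per loop saturates at 1 / (1 - alpha) while the loop time keeps growing linearly in K, so beyond some draft length each additional drafted token costs more than it returns.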

Conclusion

  • A new algorithm and workflow are proposed to accelerate the decoding of language models
  • The technique does not require any modifications to the target language model’s parameters or architecture
  • It is provably lossless within numerics and scales well with the appropriate draft model
  • It yields a large speedup across benchmark tasks and common decoding methods
  • It is empirically verified to be lossless on the downstream tasks
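The losslessness claim can be checked for a single position with a short derivation. The notation is an assumption of this sketch: p is the draft distribution, q the target distribution, a token drafted from p is accepted with probability min(1, q(x) / p(x)), and on rejection a replacement is drawn from the normalised positive part of q − p, written (·)_+.

```latex
\begin{align*}
\Pr[X = x]
  &= \underbrace{p(x)\,\min\!\left(1, \frac{q(x)}{p(x)}\right)}_{\text{drafted and accepted}}
   + \underbrace{\Bigl(1 - \sum_{x'} \min\bigl(p(x'), q(x')\bigr)\Bigr)}_{\text{probability of rejection}}
     \cdot \frac{\bigl(q(x) - p(x)\bigr)_+}{\sum_{x'} \bigl(q(x') - p(x')\bigr)_+} \\
  &= \min\bigl(p(x), q(x)\bigr) + \bigl(q(x) - p(x)\bigr)_+ \\
  &= q(x).
\end{align*}
```

The second line uses the identity \sum_{x'} (q(x') - p(x'))_+ = \sum_{x'} (q(x') - \min(p(x'), q(x'))) = 1 - \sum_{x'} \min(p(x'), q(x')), so the rejection probability cancels the normaliser. When p = q no token is ever rejected and the second term vanishes. The "within numerics" qualifier reflects that real implementations evaluate these quantities in finite precision.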