Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Speculative sampling is an algorithm for accelerating transformer decoding.
- It enables the generation of multiple tokens from each transformer call.
- It uses a faster but less powerful draft model to generate short continuations.
- A modified rejection sampling scheme is used to preserve the distribution of the target model.
- Speculative sampling was benchmarked with a 70 billion parameter language model, resulting in a 2-2.5x decoding speedup.
Paper Content
Introduction
- Scaling transformer models has led to improved performance on many tasks
- Transformer decoding is costly and inefficient
- Memory bandwidth is a limitation for transformer sampling
- Model parallelism is needed for large language models
- Speculative sampling (SpS) algorithm accelerates transformer sampling
- SpS leads to 2-2.5x speedup when sampling from Chinchilla
Related work
- Improving sampling latency of large transformers and other auto-regressive models
- Quantisation and distillation of transformers to reduce sampling latency
- Model size contributes less to performance than expected
- Multi-query attention to improve sampling performance
- Combination of techniques to improve serving latency and efficiency of PaLM 540B
- Existing body of work to exploit efficiency of transformers and sequence models operating in parallel
Auto-regressive sampling
- Transformers can be trained efficiently and in parallel on TPUs and GPUs
- Auto-regressive sampling is memory bandwidth bound and cannot make effective use of modern accelerator hardware
- Each new token depends on the previous ones, so sampling a sequence requires many sequential transformer calls, which adds latency
- As model size increases, parameters need to be divided across multiple accelerators, leading to communication overheads
- For small batch sizes, scoring a short continuation in parallel should not be significantly slower than sampling a single token
Modified rejection sampling
- Introduce rejection sampling scheme for drafted tokens
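- Accept each drafted token x with probability min(1, q(x)/p(x)), where q is the target model's distribution and p is the draft model's
- On the first rejection, resample that token from the normalised residual distribution max(0, q(x) - p(x)) and discard the remaining drafted tokens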
- Recover target model distribution for accepted tokens
- Maximum of K + 1 tokens per loop, where K is the number of drafted tokens
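The accept/resample rule above can be sketched in a few lines of Python. This is a minimal illustration rather than the paper's implementation: the function name `speculative_step`, the list-of-probabilities representation, and the `rng` parameter are assumptions made for the example, and the draft and target distributions are taken as already computed.

```python
import random

def speculative_step(draft_probs, target_probs, drafted, rng=random):
    """One accept/reject pass over K drafted tokens (illustrative sketch).

    draft_probs[i]  : draft distribution p used to sample drafted[i]
    target_probs[i] : target distribution q at the same position (K + 1 entries)
    drafted         : the K token ids proposed by the draft model
    Returns the tokens emitted by this loop: between 1 and K + 1 of them.
    """
    out = []
    for i, tok in enumerate(drafted):
        p, q = draft_probs[i], target_probs[i]
        # Accept the drafted token with probability min(1, q(x) / p(x)).
        if rng.random() < min(1.0, q[tok] / p[tok]):
            out.append(tok)
            continue
        # Rejected: resample from the normalised residual max(0, q - p)
        # and drop every later drafted token.
        residual = [max(0.0, qi - pi) for qi, pi in zip(q, p)]
        out.append(rng.choices(range(len(q)), weights=residual)[0])
        return out
    # All K tokens accepted: sample one extra token from the target model.
    last = target_probs[len(drafted)]
    out.append(rng.choices(range(len(last)), weights=last)[0])
    return out
```

In a full sampler, `draft_probs` would come from K sequential draft-model calls and `target_probs` from a single parallel scoring call of the target model; both are left out here to keep the sketch self-contained.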
Choice of draft models
- Incorporate draft generation into target model
- Use sequence level distillation to generate second model
- Use smaller version of target language model as draft
Results
- Trained 4 billion parameter draft model on 16 TPU v4s
- Sampling speed of 1.8ms/token compared to 14.1ms/token for Chinchilla
- Smaller models need fewer TPUs to achieve lowest sampling latency
- Wider model with few layers minimises communication overhead
Evaluation on XSum and HumanEval
- Evaluated speculative sampling with Chinchilla on two tasks
- XSum benchmark: 11,305 sequences, max length 128
- HumanEval task: 16,400 samples, max length 512
- Obtained substantial speedup in both tasks, HumanEval reaching 2.5x speedup
- Benchmark metrics show same underlying sample distribution
Acceptance rate changes per domain
- Acceptance rate is dependent on application and decoding method
- HumanEval achieves a larger speedup because code contains many common sub-sequences and is often split into shorter tokens
- Temperature value sharpens draft and target logits
- Average time per loop increases roughly linearly with the number of drafted tokens, because each one requires an additional draft model call
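To illustrate why the decoding method affects the acceptance rate: under the rule "accept x with probability min(1, q(x)/p(x))", the chance that a single drafted token is accepted is the overlap sum over x of min(p(x), q(x)). The sketch below computes that overlap at a few temperatures. The logits are hypothetical values chosen for the example, not taken from the paper, and the helper names are assumptions.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def acceptance_probability(p, q):
    """Probability that one token drafted from p is accepted against target q.

    sum_x p(x) * min(1, q(x) / p(x)) simplifies to sum_x min(p(x), q(x)).
    """
    return sum(min(pi, qi) for pi, qi in zip(p, q))

# Hypothetical logits over a 3-token vocabulary; both models prefer token 0.
target_logits = [2.0, 1.0, 0.0]
draft_logits = [1.5, 1.2, 0.3]

for t in (1.0, 0.5, 0.1):
    p = softmax(draft_logits, t)
    q = softmax(target_logits, t)
    print(f"temperature={t}: acceptance={acceptance_probability(p, q):.3f}")
```

When the two models agree on the most likely token, as in this toy case, lowering the temperature concentrates both distributions on that token and the overlap grows. If they disagree on the top token, sharpening can reduce the overlap instead.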
Trade off between longer drafts and more frequent scoring
- Increasing the number of tokens sampled by the draft model can lead to fewer scoring calls from the large models, potentially giving a speedup.
- The total loop time increases with the larger number of draft model calls and small increases in the scoring time.
- Acceptance efficiency (the fraction of drafted tokens that are accepted) decreases as the number of drafted tokens increases.
- Speedup may plateau or degrade with a larger number of drafted tokens, depending on the domain.
- A larger number of drafted tokens may increase the variance of the time to generate a full sequence.
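The trade-off above can be made concrete with a simplified cost model. The sketch assumes that each drafted token is accepted independently with a fixed probability `alpha` (a hypothetical value, and a stronger assumption than the paper makes), that scoring K drafted tokens costs about one target model call, and it reuses the per-token timings quoted in the Results section (1.8 ms for the draft model, 14.1 ms for Chinchilla). The function names are illustrative.

```python
def expected_tokens_per_loop(alpha, k):
    """Expected tokens emitted per loop with i.i.d. acceptance probability alpha.

    Drafted token i is kept only if tokens 1..i were all accepted (alpha ** i),
    and every loop emits one extra token (the resample or the bonus token),
    giving 1 + alpha + alpha**2 + ... + alpha**k.
    """
    return sum(alpha ** i for i in range(k + 1))

def estimated_speedup(alpha, k, draft_ms=1.8, target_ms=14.1):
    """Rough speedup over plain sampling, which costs target_ms per token."""
    # Assumption: k sequential draft calls plus one parallel scoring call.
    loop_ms = k * draft_ms + target_ms
    return expected_tokens_per_loop(alpha, k) * target_ms / loop_ms

# alpha = 0.8 is a hypothetical acceptance rate used only for illustration.
for k in range(1, 11):
    print(f"K={k}: estimated speedup {estimated_speedup(0.8, k):.2f}x")
```

Under these assumptions the estimate rises for small K and then falls again, because each additional drafted token is less likely to survive while its draft call always costs time. The model ignores the small growth in scoring time and the variance effects noted in the bullets above.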
Conclusion
- A new algorithm and workflow are proposed to accelerate the decoding of language models
- The technique does not require any modifications to the target language model’s parameters or architecture
- It is provably lossless within numerics and scales well with the appropriate draft model
- It yields a large speedup across benchmark tasks and common decoding methods
- It is verified to be lossless empirically in its downstream tasks