Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Most language models are trained and applied in a left-to-right fashion.
- This paper proposes a new pre-training paradigm to improve training data efficiency and capabilities of language models.
- The proposed pre-training paradigm includes a training objective and a bidirectional inference procedure.
- Experiments show the effectiveness of the pre-training paradigm, outperforming strong baselines.
Paper Content
Preliminaries
- Notation is introduced to be used throughout the paper
- Notation is used to denote prefixes and suffixes of a sequence of tokens
- Dependence of models on learnable parameters is suppressed
- Arrows are used to distinguish the two models and their outputs
The infilling task
- The infilling task takes a sequence of tokens, an insertion position, and a length M
- Task is to generate a plausible sequence of tokens to fill the gap
- Finding the sequence that maximizes probability requires time exponential in M
- Fill in the Middle (FIM) approach allows LM to use context from both sides
- Advantages of FIM are that it can be applied to any pre-trained LM and is computationally efficient
- Drawbacks of FIM are unnatural contexts and difficulty balancing influence of prefix and suffix
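
The FIM rearrangement described above can be illustrated with a minimal sketch. The sentinel strings `<PRE>`, `<SUF>`, `<MID>` and the helper names `to_fim_example` and `fim_prompt` are assumptions made for illustration; actual FIM implementations define their own special tokens and may order the segments differently.

```python
# Hypothetical sentinel strings; real FIM tokenizers define their own special tokens.
PRE, SUF, MID = "<PRE>", "<SUF>", "<MID>"


def to_fim_example(tokens, start, length):
    """Move the span tokens[start:start+length] to the end of the sequence,
    so a left-to-right LM sees both prefix and suffix before predicting it."""
    if not (0 <= start <= start + length <= len(tokens)):
        raise ValueError("span out of range")
    prefix = tokens[:start]
    middle = tokens[start:start + length]
    suffix = tokens[start + length:]
    return [PRE] + prefix + [SUF] + suffix + [MID] + middle


def fim_prompt(prefix, suffix):
    """Build the inference-time prompt; the LM continues after MID."""
    return [PRE] + prefix + [SUF] + suffix + [MID]


tokens = ["def", "add", "(", "a", ",", "b", ")", ":", "return", "a", "+", "b"]
example = to_fim_example(tokens, start=3, length=3)
print(example)
```

The rearranged sequence places the suffix directly before the middle, which is the kind of unnatural context the drawbacks bullet refers to.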
Bidirectional language modeling
- Bidirectional language modeling has been used to train non-autoregressive LMs
- Autoregressive LMs produce better representations than non-autoregressive LMs
- Our model remains autoregressive and future tokens are used to regularize the model
- We do not attempt to produce a single probability for every token; instead, each token receives two probabilities, one from each model
Meet in the middle
- Pre-training involves training two models to predict the next token using different views of the input.
- The two models need to balance two goals: predicting the next token well and having the probability distributions assigned to the next token agree.
Pre-training
- Two decoder-only language models are used, which share all parameters
- The forward model predicts next tokens in the forward direction
- The backward model predicts previous tokens in the backward direction
- A regularizer is used to encourage the models to agree on their predictions
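
A rough sketch of how the three terms could combine into one per-token loss follows. It assumes total variation distance as the agreement measure and a weight `lam` for the regularizer; the function names, the dictionary representation of distributions, and the default `lam=0.1` are all illustrative assumptions, not the paper's implementation.

```python
import math


def nll(dist, target):
    """Negative log-likelihood of the target token under one distribution."""
    return -math.log(dist[target])


def total_variation(p, q):
    """Total variation distance between two distributions over tokens."""
    vocab = set(p) | set(q)
    return 0.5 * sum(abs(p.get(t, 0.0) - q.get(t, 0.0)) for t in vocab)


def mim_loss(forward_dists, backward_dists, targets, lam=0.1):
    """Average over positions of: forward NLL + backward NLL + lam * agreement.

    forward_dists[i] is the forward model's distribution for position i given
    the prefix; backward_dists[i] is the backward model's distribution for the
    same position given the suffix. lam is an assumed regularizer weight.
    """
    if not (len(forward_dists) == len(backward_dists) == len(targets)):
        raise ValueError("inputs must have the same length")
    total = 0.0
    for p_fwd, p_bwd, tok in zip(forward_dists, backward_dists, targets):
        total += nll(p_fwd, tok) + nll(p_bwd, tok)
        total += lam * total_variation(p_fwd, p_bwd)
    return total / len(targets)


uniform = {"a": 0.5, "b": 0.5}
loss = mim_loss([uniform], [uniform], ["a"])
print(loss)
```

When the two distributions are identical, the agreement term vanishes and only the two likelihood terms remain.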
Infilling
- Goal is to have an efficient and low-latency generation procedure
- Naive procedure would generate from two models until each meets a condition
- Proposed procedure interleaves generation and scoring and can terminate quickly
- Models start building a completion from their own side and try to meet in the middle
- For each generated token from one model, check if it is in the generated tokens from the other
- If the two models produce the same sequence, they have met in the middle
- Agreement regularizer helps if the two models produce different sequences
- Use n-grams instead of single tokens to reduce false positives
- Parallel verification procedure adapted from [GXS+22]
- If no partial match, return to autoregressive generation
- Trade off compatibility with autoregressive LMs for more powerful attention mechanism
- Modifications improve infilling metrics but come at the cost of incompatibility with existing autoregressive LMs
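
The interleaved generate-and-check loop can be sketched in simplified form. This sketch only covers the n-gram meeting test; it omits the parallel verification step and the optional attention modifications. The names `find_meeting_point`, `meet_in_the_middle`, `forward_step`, and `backward_step` are hypothetical, and the two step functions are stubs standing in for real models.

```python
def find_meeting_point(forward, backward, n=2):
    """Return (i, j) with forward[i:i+n] == backward[j:j+n], or None.

    Both lists are stored in reading (left-to-right) order.
    """
    for i in range(len(forward) - n + 1):
        gram = forward[i:i + n]
        for j in range(len(backward) - n + 1):
            if backward[j:j + n] == gram:
                return i, j
    return None


def meet_in_the_middle(forward_step, backward_step, max_steps=32, n=2):
    """Alternate one forward and one backward token until an n-gram matches.

    Returns the joined completion, or None so the caller can fall back to
    plain autoregressive generation.
    """
    forward, backward = [], []
    for _ in range(max_steps):
        forward.append(forward_step(forward))
        backward.insert(0, backward_step(backward))
        match = find_meeting_point(forward, backward, n)
        if match is not None:
            i, j = match
            # Keep forward tokens before the match, then the backward side.
            return forward[:i] + backward[j:]
    return None


# Stub "models" that both reproduce the same target completion.
target = ["x", "=", "a", "+", "b", ";"]
forward_step = lambda so_far: target[len(so_far)]
backward_step = lambda so_far: target[len(target) - 1 - len(so_far)]

completion = meet_in_the_middle(forward_step, backward_step)
print(completion)
```

In this toy run each side generates only four of the six tokens before the bigram `a +` appears on both sides, which is where the latency saving comes from.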
Experiments
- Pre-training experiments conducted
- Evaluation setup described
- Main results presented
- Ablation studies conducted
Data and models
- Pre-trained models on a large and diverse corpus of public code
- Python, Java, C++ are the dominant languages
- Corpus contains 300 billion tokens
- Six times larger than the pre-training dataset used in the original InCoder model
- Also pre-trained on natural language datasets
- Model sizes of 350M, 1.3B and 2.7B parameters
- Baseline models are pre-trained with FIM
Benchmarks and metrics
- FIM implementation outperforms InCoder models on all metrics and datasets
- FIM-2.7B surpasses other strong baselines
- MIM consistently outperforms FIM across all metrics and datasets
- MIM pre-training receives more dense supervision from agreement regularizer
- MIM pre-training leads to high quality left-to-right generative model
- MIM consistently outperforms FIM in infilling setting
- MIM consistently outperforms FIM in natural language datasets
Ablation study
- Ablation study conducted to assess effect of optional enhancements
- Comparing autoregressive MIM model with Synchronous Bidirectional Attention layer
- Perplexity used to select value of λ
- Bidirectional context models outperform unidirectional context models
- MIM inference faster than FIM baselines
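
Selecting λ by perplexity can be sketched as follows. The candidate values and the validation log-probabilities are made up purely for illustration, and `perplexity` and `select_lambda` are hypothetical helper names; none of the numbers come from the paper.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp of the mean negative log-probability per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def select_lambda(val_log_probs_by_lambda):
    """Pick the lambda whose model has the lowest validation perplexity."""
    return min(
        val_log_probs_by_lambda,
        key=lambda lam: perplexity(val_log_probs_by_lambda[lam]),
    )


# Made-up validation log-probabilities, for illustration only.
candidates = {
    0.0: [math.log(0.20), math.log(0.25)],
    0.1: [math.log(0.40), math.log(0.50)],
    1.0: [math.log(0.30), math.log(0.30)],
}
best = select_lambda(candidates)
print(best, perplexity(candidates[best]))
```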
Related work
- Bidirectional language modeling has been studied extensively
- Early models such as BERT use masked language modeling, while XLNet uses permutation language modeling to maximize likelihood of training sequences
- Representation learning and in-context learning can be difficult
- Two works train neural models using similar ideas
- One work regularizes forward and backward RNNs by requiring representations to be close in Euclidean distance
- Another work encourages agreement in probability space
- Our approach focuses on achieving better infilling accuracy than FIM while also reducing latency
Conclusion
- Addressed two challenges faced by large LMs
- Proposed “Meet in the Middle” method
- Uses both forward and backward LMs
- Parameters shared between LMs
- Inference procedure reduces latency by up to 50%
- Standard transformer-based autoregressive language models used
- Multi Query Attention used
- Bidirectional context used during pre-training
- Adam optimizer used
- Training done in two stages
- Pre-training data statistics given
- In-domain and out-of-domain perplexity results given
- Masked tokens and spans of tokens used