Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Language models (LMs) calculate representations of an already-seen context to predict the next word.
  • Retrieval-augmented LMs have been shown to improve over standard neural LMs.
  • This paper investigates why k-nearest neighbor language models (kNN-LMs) perform better than standard parametric LMs.
  • Three main reasons are identified: different input representation, approximate kNN search, and softmax temperature.
  • These insights are incorporated into the model architecture or training procedure of the standard parametric LM.

Paper Content

Introduction

  • Language modeling is the task of predicting the probability of a text
  • It has applications in natural language processing
  • Retrieval-augmented language models have shown impressive results
  • K-nearest neighbor language model (kNN-LM) is a simple and effective model
  • It improves the model's ability to fit the training data itself, rather than merely benefiting from additional data
  • kNN-LM is a linear interpolation between a base LM and a kNN model
  • It creates a datastore of key-value pairs
  • During inference, the parametric component of the LM generates the output distribution
  • The non-parametric component queries the datastore with the context representation
  • It computes a probability distribution over the nearest neighbors using the softmax of their negative distances
  • This distribution is interpolated with the parametric LM distribution
  • There are many design choices for kNN-LM-like models
  • The size of W_ds is very large
  • h_ds is the output from the multi-headed attention layer of the last transformer block
  • The similarity function is the negative squared L2 distance
  • The softmax temperature (τ) controls the scaling of the similarity scores
  • k is usually less than V, and most of the datastore entries are pruned out
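
The inference steps listed above can be sketched in a few lines of plain Python. This is a minimal illustration rather than the paper's implementation: the function name `knn_lm_probs`, the interpolation weight `lam` and its default, and the toy datastore are all assumptions made for the example, and exact search is used in place of an approximate index.

```python
import math

def knn_lm_probs(p_lm, query, keys, values, k=2, tau=1.0, lam=0.25):
    """Interpolate a parametric LM distribution with a kNN distribution.

    p_lm   : vocabulary probabilities from the base LM
    query  : context representation for the current position
    keys   : datastore keys (stored context representations)
    values : datastore values (index of the next token for each key)
    lam    : interpolation weight (hypothetical default, not from the paper)
    """
    # Similarity is the negative squared L2 distance between query and key.
    sims = [-sum((q - x) ** 2 for q, x in zip(query, key)) for key in keys]

    # Keep only the k nearest entries; all other entries are pruned out.
    nearest = sorted(range(len(keys)), key=lambda i: sims[i], reverse=True)[:k]

    # Softmax over the retained similarities, scaled by the temperature tau.
    top = max(sims[i] for i in nearest)
    weights = {i: math.exp((sims[i] - top) / tau) for i in nearest}
    total = sum(weights.values())

    # Aggregate neighbor weights by the token stored as each neighbor's value.
    p_knn = [0.0] * len(p_lm)
    for i, w in weights.items():
        p_knn[values[i]] += w / total

    # Linear interpolation between the kNN and parametric distributions.
    return [lam * pk + (1.0 - lam) * pl for pk, pl in zip(p_knn, p_lm)]


# Toy example: a vocabulary of 3 tokens and a datastore of 3 entries.
p_lm = [0.5, 0.3, 0.2]
keys = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
values = [2, 2, 0]
probs = knn_lm_probs(p_lm, query=[0.1, 0.0], keys=keys, values=values)
```

In this toy case both retrieved neighbors store token 2, so the kNN component moves probability mass toward that token while the result remains a valid distribution.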

Replacing the datastore with trainable embeddings

  • The choice of h_ds has a large impact on the performance of kNN-LM
  • Explore whether the improvements from kNN-LM come from different input representations
  • Remove the non-parametric datastore and replace W_ds with a randomly initialized word embedding matrix
  • Re-training W_ds is beneficial
  • Ensembling two predictors is useful
  • Parametric ensembles can achieve majority of gain from kNN-LM
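
A parametric ensemble of this kind might look like the sketch below. It assumes a dot-product similarity for both predictors and uses the hypothetical names `W_sm` and `h_sm` for the standard output embedding matrix and its input representation; the interpolation weight `lam` is also an assumption. Training of `W_ds` is omitted, so the matrices here are only randomly initialized.

```python
import math
import random

def softmax(scores):
    # Numerically stable softmax over a list of scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, h):
    # Dot product of every row of W with the vector h.
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def ensemble_probs(W_sm, h_sm, W_ds, h_ds, lam=0.25):
    # Standard predictor: softmax over the usual output embeddings.
    p_sm = softmax(matvec(W_sm, h_sm))
    # Second predictor: a trainable matrix W_ds stands in for the datastore
    # and is queried with a different input representation h_ds.
    p_ds = softmax(matvec(W_ds, h_ds))
    # Interpolate the two predictors, as kNN-LM does with its two components.
    return [lam * d + (1.0 - lam) * s for d, s in zip(p_ds, p_sm)]


rng = random.Random(0)   # fixed seed so the sketch is reproducible
V, D = 5, 4              # toy vocabulary size and hidden size

def random_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

W_sm, W_ds = random_matrix(V, D), random_matrix(V, D)
h_sm = [rng.gauss(0.0, 1.0) for _ in range(D)]
h_ds = [rng.gauss(0.0, 1.0) for _ in range(D)]
probs = ensemble_probs(W_sm, h_sm, W_ds, h_ds)
```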

Increasing the softmax capacity

  • kNN-LM works well due to the large datastore
  • Testing the effect of datastore size on kNN-LM performance
  • Perplexity decreases almost linearly as a larger percentage of the original datastore is used
  • Even with just 5% of the datastore size, kNN-LM still provides a benefit
  • Increasing the embedding matrix size does not always bring better performance
  • Over-parameterizing W_ds is not an effective method of achieving accuracy gains
  • Approximate kNN search is better than exact search
  • Approximate FAISS mask is better than ground truth nearest neighbors
  • Approximate score returned by FAISS is better than recomputing ground truth distances

Adding softmax temperature to kNN distribution

  • kNN distribution is peaky compared to standard LM output distribution
  • Temperature can be used to control peakiness of distribution
  • Temperature of 1 is close to optimal in original settings
  • Tuning temperature is important for optimal results
  • Using real mask and real score outperforms FAISS mask and FAISS score
  • As datastore size grows, using accurate distance values is better
  • Approximate search provides fuzziness that functions as a regularizer
  • Approximate kNN search prevents overfitting
  • Longer words and words that appear in many different contexts have better results with approximate nearest neighbors
  • Ensemble of two distance metrics is key to improvement
  • Improvement from kNN-LM is orthogonal to ensembling with a different base LM
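
The effect of temperature on peakiness can be illustrated with a small sketch. The scores below are made-up negative squared distances, and the helper names `softmax_with_temperature` and `entropy` are chosen for this example only.

```python
import math

def softmax_with_temperature(scores, tau):
    # Divide the scores by tau before the softmax; a larger tau flattens
    # the distribution, a smaller tau makes it peakier.
    scaled = [s / tau for s in scores]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    # Shannon entropy in nats; lower entropy means a peakier distribution.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Hypothetical negative squared L2 distances of three retrieved neighbors.
scores = [-1.0, -4.0, -9.0]
p_tau1 = softmax_with_temperature(scores, tau=1.0)
p_tau5 = softmax_with_temperature(scores, tau=5.0)
```

With τ = 1 nearly all of the mass sits on the closest neighbor, while τ = 5 spreads it more evenly, which is why the temperature is worth tuning alongside the retrieval settings.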

Sparsification

  • kNN retrieval induces sparsity in the distribution over the vocabulary
  • k is typically small compared to the size of the vocabulary
  • kNN-LM increases the probability of the top-k entries
  • Sparsifying the output probability of a standard LM and interpolating it with the original LM does not provide any benefits

Stolen probabilities

  • Stolen probabilities effect refers to words that can never be selected as the argmax word
  • kNN-LM solves the stolen probabilities problem by allowing the highest probability to be assigned to any word
  • The vectors in the embedding matrix are not located in the convex hull of the others
  • The kNN component provides memorization of the training set
  • kNN-LM’s improvement lies in reducing “over-correction” error
  • Standard LM cross-entropy training loss does not emphasize examples where base LM performs badly
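
The stolen probabilities effect itself can be demonstrated on a toy example. The 2D embeddings below are invented for illustration: four "corner" vectors and a fifth vector that lies strictly inside their convex hull. Under a dot-product softmax the interior word never attains the highest score, whereas under a negative squared L2 distance any stored vector can be the nearest match to a suitable query.

```python
import random

# Four corner embeddings and one embedding strictly inside their convex hull.
embeddings = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0], [0.2, 0.1]]
INTERIOR = 4  # index of the interior word

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def neg_sq_l2(a, b):
    return -sum((x - y) ** 2 for x, y in zip(a, b))

def argmax(scores):
    return max(range(len(scores)), key=lambda i: scores[i])

# Dot-product scoring: count how often the interior word is the argmax
# over many random context vectors.
rng = random.Random(0)
interior_wins = 0
for _ in range(1000):
    h = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    if argmax([dot(h, e) for e in embeddings]) == INTERIOR:
        interior_wins += 1

# L2 scoring: a query equal to the interior vector retrieves it first.
query = embeddings[INTERIOR]
l2_winner = argmax([neg_sq_l2(query, e) for e in embeddings])
```

The interior word's dot-product score is a weighted average of the corner scores, so it can never exceed the largest of them.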

Conclusion

  • kNN-LM improves perplexity by ensembling different input representations, using approximate nearest neighbor search, and tuning the softmax temperature.
  • Using 1 + log_b f_v for each word type v ∈ V, based on either the frequency or the total training loss of the word, does not make a significant difference in the final perplexity.
  • Mixture of Softmax (MoS) increases the performance of the language model marginally.
  • Compressing the datastore down to a similar-sized embedding matrix for softmax computation does not work.
  • The unexpected results across different approximate kNN retrieval settings arise because longer words, and words that appear in more diverse contexts, are predicted better with approximate neighbors than with ground truth ones.
  • The key to kNN-LM’s performance gain is the ensemble of two distance metrics: the standard dot product distance and the L2 distance.
  • Sparsifying the output probability with the tokens retrieved by the kNN search does not help.
  • The key in kNN-LM is that it selects “which tokens to include” in the kNN distribution, and not their distances.