Link to paper

The full paper is available online.

The paper is also listed on PapersWithCode.

Abstract

  • Transformers have become the state-of-the-art neural network architecture across many domains of machine learning.
  • Training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations.
  • When self-attention-only Transformers are trained on simple regression tasks, the models they learn in-context closely resemble those learned by gradient descent (GD); in some cases the trained weights match a simple weight construction that implements GD.
  • Trained Transformers thus become mesa-optimizers: they learn models by gradient descent in their forward pass.
  • Transformers surpass plain gradient descent by learning an iterative curvature correction.

Paper Content

Introduction

  • Transformers have been successful in many machine learning tasks
  • In-context learning, the ability to learn from examples given in the input sequence without updating weights, may help explain their success
  • The exact mechanisms of in-context learning are not fully understood
  • This paper aims to bridge the gap between in-context learning and meta-learning
  • A linear self-attention layer can induce an update identical to a single step of gradient descent on a regression loss
  • Trained linear self-attention-only Transformers can converge to this weight construction, or at least generate models that align closely with models trained by gradient descent
  • MLPs can be incorporated into the Transformer architecture to solve non-linear regression tasks
  • Learned Transformers encode incoming tokens into a format amenable to in-context gradient descent learning
  • This paper connects learning Transformer weights and meta-learning a learning algorithm
  • Transformers come in different shapes and sizes and exhibit varying forms of in-context learning
  • Hypothesis 1: In-context learning in the Transformer forward pass is implemented by gradient-based optimization of an implicit auto-regressive inner loss

Data transformations induced by gradient descent

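  • Introduce a reference linear model y = W x with weight matrix W
  • The goal of learning is to minimize the squared-error loss L(W) = 1/(2N) Σ_i ||W x_i - y_i||² over N training pairs (x_i, y_i)
  • One gradient descent step with learning rate η yields the weight change ΔW = -η ∇_W L(W) = -(η/N) Σ_i (W x_i - y_i) x_i^T
  • The loss after the step can be rewritten as L(W + ΔW) = 1/(2N) Σ_i ||W x_i - (y_i - ΔW x_i)||², so the GD step is equivalent to keeping W fixed and transforming the targets as y_i ← y_i - ΔW x_i
  • In the in-context learning problem, the model receives N context tokens e_i = (x_i, y_i) and an extra query token containing x_test, whose target must be predicted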
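The target-transformation view above can be checked numerically. The following is a minimal sketch using only the Python standard library. It assumes W_0 = 0, a hand-picked learning rate η = 0.1 and random toy data. The helper names matvec, residuals, loss and gd_delta are hypothetical. The sketch takes one GD step and compares two losses: the loss of the updated weights on the original targets, and the loss of the original weights on the transformed targets y_i - ΔW x_i.

```python
import random

def matvec(W, x):
    # Multiply a matrix stored as a list of rows by a vector
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def residuals(W, x, y):
    # W x - y for a single example
    return [p - t for p, t in zip(matvec(W, x), y)]

def loss(W, xs, ys):
    # L(W) = 1/(2N) * sum_i ||W x_i - y_i||^2
    return sum(sum(r * r for r in residuals(W, x, y)) for x, y in zip(xs, ys)) / (2 * len(xs))

def gd_delta(W, xs, ys, eta):
    # Delta W = -eta * grad L(W) = -(eta/N) * sum_i (W x_i - y_i) x_i^T
    N = len(xs)
    dW = [[0.0] * len(W[0]) for _ in W]
    for x, y in zip(xs, ys):
        r = residuals(W, x, y)
        for a in range(len(W)):
            for b in range(len(x)):
                dW[a][b] -= eta / N * r[a] * x[b]
    return dW

random.seed(0)
n_in, n_out, N, eta = 3, 1, 5, 0.1
xs = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(N)]
ys = [[random.gauss(0, 1) for _ in range(n_out)] for _ in range(N)]
W0 = [[0.0] * n_in for _ in range(n_out)]  # assumption: start from W_0 = 0

dW = gd_delta(W0, xs, ys, eta)
W1 = [[w + d for w, d in zip(r0, r1)] for r0, r1 in zip(W0, dW)]

# Keep W0 fixed but transform the targets: y_i <- y_i - Delta W x_i
ys_transformed = [[t - d for t, d in zip(y, matvec(dW, x))] for x, y in zip(xs, ys)]

loss_updated_weights = loss(W1, xs, ys)
loss_transformed_data = loss(W0, xs, ys_transformed)
print(loss_updated_weights, loss_transformed_data)
```

The identity (W + ΔW) x - y = W x - (y - ΔW x) holds term by term. The two printed losses therefore agree, and both are lower than the loss before the step.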

Transformations induced by gradient descent and a linear self-attention layer can be equivalent

  • Learning a linear model is recast as directly transforming the data (the targets) instead of computing and returning the weights
  • Through this data-transformation view, self-attention is connected to gradient descent
  • Tokens are updated by a linear self-attention (LSA) layer without softmax: e_j ← e_j + P W_V Σ_i e_i e_i^T W_K^T W_Q e_j
  • A single head of self-attention is sufficient to transform the training targets and the test prediction simultaneously
  • The construction is not claimed to be unique; other weight settings can implement the same update
  • The learning rate can be meta-learned to improve upon plain gradient descent
  • A self-attention layer can therefore encode the data transformations induced by gradient descent
  • The training objective is to minimize the expected squared prediction error on the query token across tasks
  • Data is generated by a linear teacher model whose parameters are sampled for each task
  • Dimensions are set to N = n_I = 10 and n_O = 1, i.e. 10 context examples with 10-dimensional inputs and scalar outputs
  • The teacher model is noiseless, so the tasks are analytically solvable
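The following sketch builds single-head LSA weights that reproduce one GD step. It then compares the resulting query prediction with an explicit GD update. The sketch assumes W_0 = 0, which gives:

- W_K = W_Q = [[I, 0], [0, 0]]
- W_V = [[0, 0], [0, -I]]
- P = (η/N) I

It also assumes Gaussian teacher weights and inputs drawn from [-1, 1]. The learning rate, seed and helper names (construct_gd_weights, lsa_layer) are illustrative choices, not values taken from the paper.

```python
import random

def matvec(M, v):
    # Matrix (list of rows) times vector
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def construct_gd_weights(n_in, n_out, eta, N):
    # Hypothetical helper; assumes W_0 = 0 to keep the construction short
    d = n_in + n_out
    # W_K = W_Q = [[I, 0], [0, 0]]: keys and queries see only the x part of a token
    W_KQ = [[1.0 if (r == c and r < n_in) else 0.0 for c in range(d)] for r in range(d)]
    # W_V = [[0, 0], [W_0, -I]] with W_0 = 0: values carry -y_i in the target slot
    W_V = [[-1.0 if (r == c and r >= n_in) else 0.0 for c in range(d)] for r in range(d)]
    P_scale = eta / N  # P = (eta / N) * I, stored as a scalar
    return W_KQ, W_KQ, W_V, P_scale

def lsa_layer(tokens, n_ctx, W_K, W_Q, W_V, P_scale):
    # e_j <- e_j + P * sum_{i <= n_ctx} (W_V e_i) (W_K e_i)^T (W_Q e_j)
    keys = [matvec(W_K, e) for e in tokens[:n_ctx]]
    values = [matvec(W_V, e) for e in tokens[:n_ctx]]
    updated = []
    for e in tokens:
        q = matvec(W_Q, e)
        delta = [0.0] * len(e)
        for k, v in zip(keys, values):
            score = dot(k, q)  # linear attention: raw dot product, no softmax
            delta = [dl + score * vl for dl, vl in zip(delta, v)]
        updated.append([el + P_scale * dl for el, dl in zip(e, delta)])
    return updated

random.seed(1)
n_in, n_out, N, eta = 10, 1, 10, 0.5
W_teacher = [[random.gauss(0, 1) for _ in range(n_in)]]  # assumed teacher distribution
xs = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(N)]
ys = [matvec(W_teacher, x) for x in xs]
x_test = [random.uniform(-1, 1) for _ in range(n_in)]

# Reference: one GD step from W_0 = 0 gives Delta W = (eta / N) * sum_i y_i x_i^T
dW = [[eta / N * sum(y[a] * x[b] for x, y in zip(xs, ys)) for b in range(n_in)]
      for a in range(n_out)]
pred_gd = matvec(dW, x_test)

# Tokens e_i = (x_i, y_i); the query token's target slot starts at -W_0 x_test = 0
tokens = [x + y for x, y in zip(xs, ys)] + [x_test + [0.0] * n_out]
W_K, W_Q, W_V, P_scale = construct_gd_weights(n_in, n_out, eta, N)
out = lsa_layer(tokens, N, W_K, W_Q, W_V, P_scale)
pred_lsa = [-v for v in out[-1][n_in:]]  # the query's target slot now holds -y_hat
print(pred_gd, pred_lsa)
```

The same layer also turns each context target y_j into y_j - ΔW x_j and leaves the inputs untouched. One head therefore handles the training targets and the test prediction at once.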

One step of gradient descent vs. a single trained self-attention layer

  • Investigated whether a single linear self-attention layer can be explained by the weight construction that implements one GD step
  • Compared the predictions of the LSA layer with trained weights against those with the constructed weights
  • Since both predictions are linear in x_test, their sensitivities ∂ŷ/∂x_test can also be compared
  • Compared predictions using the L2 norm of their difference
  • Compared sensitivities using cosine similarity
  • Sampled 10,000 validation tasks
  • Found excellent agreement between the two models
  • A simple scaling correction of the trained weights is enough to recover the weight construction implementing GD
  • Investigated how GD, the trained LSA layer and a model with weights interpolated between the two behave when in-context data is provided in different regimes
  • The GD learning rate was tuned for inputs in [-1, 1] and a single gradient step
  • Applying the LSA update repeatedly yielded the same loss decrease for GD and the Transformer
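The two alignment measures can be sketched as follows. The sketch assumes scalar outputs (n_O = 1) and predictions that are linear in x_test. Under that assumption, the sensitivity ∂ŷ/∂x_test can be read off by probing basis vectors. The paper averages such measures over many sampled tasks. This sketch instead evaluates a single hypothetical pair of models, f_gd and f_tf, which differ only by a scale factor.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sensitivity(f, n_in):
    # For a prediction that is linear in x_test (as in the LSA setting),
    # d y_hat / d x_test equals f(e_k) - f(0) for each basis vector e_k
    base = f([0.0] * n_in)
    grad = []
    for k in range(n_in):
        e_k = [0.0] * n_in
        e_k[k] = 1.0
        grad.append(f(e_k) - base)
    return grad

def prediction_difference(f1, f2, x_test):
    # L2 norm of the prediction difference; for n_O = 1 this is an absolute value
    return abs(f1(x_test) - f2(x_test))

def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

w_gd = [0.5, -1.0, 2.0]  # hypothetical weights of the model produced by one GD step

def f_gd(x):
    return dot(w_gd, x)

def f_tf(x):
    # Hypothetical trained layer: same direction as GD, different scale
    return 1.3 * f_gd(x)

x_test = [1.0, 1.0, 1.0]
diff = prediction_difference(f_gd, f_tf, x_test)
g_gd, g_tf = sensitivity(f_gd, 3), sensitivity(f_tf, 3)
cos = cosine_similarity(g_gd, g_tf)

# Rescaling f_tf by the least-squares factor recovers f_gd
c = dot(g_gd, g_tf) / dot(g_tf, g_tf)
diff_rescaled = abs(f_gd(x_test) - c * f_tf(x_test))
print(diff, cos, diff_rescaled)
```

In this toy case the sensitivities have cosine similarity 1 while the predictions still differ. This mirrors the observation that a scaling correction alone can close the gap between trained and constructed weights.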

Multiple steps of gradient descent vs. multiple layers of self-attention

  • Linear self-attention-only Transformers can be stacked over K layers, which can be compared with K steps of gradient descent
  • The alignment between the model generated in-context by the trained Transformer (TF) and the models produced by GD and GD++ is measured
  • TF aligns well with GD at the beginning of training but converges to be better aligned with GD++ after training
  • The performance of the trained Transformer mimics that of GD++ when tested on out-of-distribution tasks
  • K layers of Transformers generally outperform K steps of plain gradient descent
  • GD++ is a variant of gradient descent with a single extra parameter γ: at every step it transforms the inputs with H(X) = I - γXX^T, which acts as a curvature correction
  • Recurrent 2-layer LSA model surpasses plain gradient descent
  • The trained model realigns perfectly with GD++ while matching its performance on in- and out-of-distribution tasks
  • Non-recurrent 5-layer LSA-only Transformers with different parameters per layer outperform plain gradient descent
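The GD++ description above can be sketched in the token view. Each step applies a GD update to the (residual) targets and then transforms all inputs, including x_test, with H(X) = I - γXX^T. The sketch rests on several assumptions:

- W_0 = 0.
- Outputs are scalar.
- The input transformation follows the GD update within each step.

The function names gd_pp_predict and plain_gd_predict and all numeric settings are hypothetical. With γ = 0, the token view should reproduce explicit K-step GD on the weights, and the example checks this.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def apply_h(xs, v, gamma):
    # H(X) v = (I - gamma X X^T) v, where the columns of X are the context inputs x_i
    xxt_v = [sum(x[a] * dot(x, v) for x in xs) for a in range(len(v))]
    return [vi - gamma * p for vi, p in zip(v, xxt_v)]

def gd_pp_predict(xs, ys, x_test, eta, gamma, K):
    # Token view with W_0 = 0 and scalar targets (n_O = 1)
    N = len(xs)
    xs, ys, xq, yq = [list(x) for x in xs], list(ys), list(x_test), 0.0
    for _ in range(K):
        # One GD step on the current data: Delta W = (eta / N) * sum_i y_i x_i^T
        dw = [eta / N * sum(y * x[b] for x, y in zip(xs, ys)) for b in range(len(xq))]
        ys = [y - dot(dw, x) for x, y in zip(xs, ys)]  # targets become residuals
        yq -= dot(dw, xq)  # query slot accumulates -y_hat
        # Assumed GD++ step: transform every input with H(X), using the pre-update X
        xs, xq = [apply_h(xs, x, gamma) for x in xs], apply_h(xs, xq, gamma)
    return -yq

def plain_gd_predict(xs, ys, x_test, eta, K):
    # Explicit K-step GD on the weights of y = w^T x, starting from w = 0
    N, w = len(xs), [0.0] * len(x_test)
    for _ in range(K):
        grad = [sum((dot(w, x) - y) * x[b] for x, y in zip(xs, ys)) / N
                for b in range(len(w))]
        w = [wi - eta * g for wi, g in zip(w, grad)]
    return dot(w, x_test)

random.seed(2)
n_in, N, eta, K = 5, 20, 0.3, 4
w_true = [random.gauss(0, 1) for _ in range(n_in)]
xs = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(N)]
ys = [dot(w_true, x) for x in xs]
x_test = [random.uniform(-1, 1) for _ in range(n_in)]

p_plain = plain_gd_predict(xs, ys, x_test, eta, K)
p_tokens = gd_pp_predict(xs, ys, x_test, eta, 0.0, K)  # gamma = 0: reduces to plain GD
p_gdpp = gd_pp_predict(xs, ys, x_test, eta, 0.01, K)
print(p_plain, p_tokens, p_gdpp, dot(w_true, x_test))
```

For a small γ > 0, H maps an eigenvalue λ of XX^T to λ(1 - γλ)², which shrinks large-eigenvalue directions more than small ones. This is the sense in which the transformation acts as a curvature correction. The sketch makes no claim about which γ is best for a given task.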

Transformers solve nonlinear regression tasks by gradient descent on deep data representations

  • Transformers can learn linear models by gradient descent on deep representations
  • MLPs in Transformers process both inputs and targets
  • A Transformer with all weights learned is compared to a control Transformer whose final LSA weights are set to the GD construction
  • The alignment measures compare the first two terms of a Taylor approximation of the learned functions: their values and their gradients with respect to x_test
  • MLPs and self-attention layers can interplay to support nonlinear in-context learning
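To illustrate gradient descent on deep representations, the sketch below makes two substitutions. It replaces the MLP with a hypothetical fixed feature map phi(x) = [1, x, x², x³]. It then runs K plain GD steps on a linear readout of those features for the 1-D nonlinear task y = sin(3x). This is a deliberate simplification: in the paper the representation is learned by the MLPs and the GD steps are carried out by LSA layers, whereas here both are hand-written. The step size, number of steps and task are illustrative assumptions.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def phi(x):
    # Hypothetical fixed feature map standing in for the MLP layers
    return [1.0, x, x * x, x ** 3]

def mse(w, feats, ys):
    return sum((dot(w, f) - y) ** 2 for f, y in zip(feats, ys)) / (2 * len(ys))

def gd_on_features(feats, ys, eta, K):
    # K GD steps on a linear readout of the features, starting from w = 0
    w = [0.0] * len(feats[0])
    for _ in range(K):
        grad = [sum((dot(w, f) - y) * f[b] for f, y in zip(feats, ys)) / len(ys)
                for b in range(len(w))]
        w = [wi - eta * g for wi, g in zip(w, grad)]
    return w

random.seed(3)
xs = [random.uniform(-1, 1) for _ in range(20)]
ys = [math.sin(3 * x) for x in xs]  # a nonlinear 1-D regression task
feats = [phi(x) for x in xs]

w0 = [0.0] * 4
w = gd_on_features(feats, ys, eta=0.2, K=50)
loss_before, loss_after = mse(w0, feats, ys), mse(w, feats, ys)
x_test = 0.3
print(loss_before, loss_after, dot(w, phi(x_test)), math.sin(3 * x_test))
```

The learner stays linear in the representation. All of the nonlinearity lives in phi, which is the role the text assigns to the MLPs.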

Do self-attention layers build regression tasks?

  • Proposition 1 requires a particular token structure where input and output data are concatenated into a single token.
  • Proposition 2 allows a Transformer to build the required token construction on its own.
  • Evidence suggests that copying is performed in trained Transformers.
  • Softmax self-attention layers easily learn to copy, allowing the Transformer to emulate gradient-based learning.
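The following is a minimal sketch of the copying step. It assumes the tokens arrive as separate x and y tokens in alternating positions, with one-hot positional encodings. This layout is hypothetical and not necessarily the one used in Proposition 2. Under causal masking, each y-token attends with a large inverse temperature beta to the preceding x-token. It then adds that token's x part to itself, producing the concatenated (x_j, y_j) tokens that the GD construction expects. The query one_hot(t - 1) stands in for a learned query matrix that shifts positional codes by one.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot(i, n):
    return [1.0 if k == i else 0.0 for k in range(n)]

def copy_layer(xs, ys, beta=50.0):
    # Hypothetical layout: position 2j holds (x_j, 0), position 2j+1 holds (0, y_j)
    n_x, n_y, T = len(xs[0]), len(ys[0]), 2 * len(xs)
    seq = []
    for x, y in zip(xs, ys):
        seq.append(list(x) + [0.0] * n_y)
        seq.append([0.0] * n_x + list(y))
    pos = [one_hot(t, T) for t in range(T)]
    out = []
    for t in range(1, T, 2):  # the y-token of each pair gathers its x
        query = pos[t - 1]  # stands in for a query that shifts p_t back by one position
        # Causal mask: only positions s <= t are attended to
        logits = [beta * sum(q * k for q, k in zip(query, pos[s])) for s in range(t + 1)]
        attn = softmax(logits)
        # The value reads only the x part of each token; the residual keeps y_j in place
        copied = [sum(attn[s] * seq[s][k] for s in range(t + 1)) for k in range(n_x)]
        out.append([c + e for c, e in zip(copied + [0.0] * n_y, seq[t])])
    return out

xs = [[0.2, -0.4], [0.9, 0.1], [-0.5, 0.7]]
ys = [[1.0], [-2.0], [0.5]]
tokens = copy_layer(xs, ys)
print(tokens)
```

A large beta makes the softmax nearly one-hot, so the copy is close to exact. A linear attention layer could implement the same shift without the temperature.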

Discussion

  • Transformers show remarkable in-context learning behavior
  • Mechanisms based on attention, associative memory and copying by induction heads are leading explanations
  • The paper hypothesizes that Transformers' in-context learning is driven by gradient descent
  • Schlag et al. (2021) linked linear self-attention layers with (fast-)inner loop learning by the delta rule
  • The paper provides evidence that in-context learning is based on gradient descent when multi-layer self-attention-only Transformers are trained on simple regression tasks
  • The hypothesis still needs to incorporate noisy data and weight regularization
  • Investigate how to improve gradient descent based learning algorithms
  • Analyze in-context learning in larger models and language modeling
  • Analyze in-context learning in HyperTransformers
  • Compare GD and trained Transformer on other data distributions
  • Interpolate between the construction θ_GD and the trained Transformer weights θ
  • Construct the key, query and value matrices W_K, W_Q, W_V as well as the projection matrix P
  • Test the behavior when repeatedly applying a single LSA layer that was trained to lower the objective
  • Compare GD and self-attention layer when changing dampening strength