Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformers have become the state-of-the-art neural network architecture across many domains of machine learning.
- Training Transformers on auto-regressive tasks is related to gradient-based meta-learning formulations.
- Training self-attention-only Transformers on simple regression tasks shows similarity between models learned by GD and Transformers.
- Optimized Transformers implement gradient descent in their forward pass.
- Trained Transformers surpass plain gradient descent by learning an iterative curvature correction.
Paper Content
Introduction
- Transformers have been successful in many machine learning tasks
- In-context learning is a phenomenon that may explain their success
- The exact mechanisms of in-context learning are not fully understood
- This paper aims to bridge the gap between in-context and meta-learning
- Linear self-attention layers can induce an update identical to a single step of gradient descent
- When trained on simple regression tasks, linear self-attention-only Transformers either converge to this gradient descent construction or generate models that closely align with models trained by gradient descent
- MLPs can be incorporated into the Transformer architecture to solve non-linear regression tasks
- Learned Transformers encode incoming tokens into a format amenable to in-context gradient descent learning
- This paper connects learning Transformer weights and meta-learning a learning algorithm
- Transformers come in different shapes and sizes and exhibit varying forms of in-context learning
- Hypothesis 1: In-context learning in the Transformer forward pass is implemented by gradient-based optimization of an implicit auto-regressive inner loss
Data transformations induced by gradient descent
- Introduce a reference linear model with a weight matrix
- Goal of learning is to minimize squared-error loss
- Gradient descent step yields weight change
- Outcome of the gradient descent step can equivalently be viewed as an update to the regression targets, with the weights left unchanged
- In-context learning problem given N context tokens and an extra query token
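The steps above can be written out explicitly. What follows is a sketch of the standard derivation for a linear model with weight matrix W, assuming a learning rate η and N training pairs (x_i, y_i); the symbol names are chosen here for illustration.

```latex
\begin{align}
L(W) &= \frac{1}{2N} \sum_{i=1}^{N} \lVert W x_i - y_i \rVert^2 \\
\Delta W &= -\eta \nabla_W L(W) = -\frac{\eta}{N} \sum_{i=1}^{N} (W x_i - y_i)\, x_i^\top \\
L(W + \Delta W) &= \frac{1}{2N} \sum_{i=1}^{N} \lVert (W + \Delta W) x_i - y_i \rVert^2
  = \frac{1}{2N} \sum_{i=1}^{N} \lVert W x_i - (y_i - \Delta y_i) \rVert^2,
  \qquad \Delta y_i = \Delta W x_i
\end{align}
```

Read this way, a gradient step keeps W fixed and shifts every target y_i by Δy_i, which is the form that a layer acting on context tokens and a query token can reproduce.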
Transformations induced by gradient descent and a linear self-attention layer can be equivalent
- Learning a linear model is re-cast as directly modifying the data instead of computing and returning the weights
- Self-attention is connected to gradient descent
- Tokens are updated through a linear self-attention layer
- A single head of self-attention is sufficient to transform training targets and test prediction simultaneously
- Uniqueness of the construction is not required
- Learning rate can be meta-learned to improve upon plain gradient descent
- Self-attention layer can encode data transformations
- Training objective is to minimize expected squared prediction error
- Data is generated using a teacher model with parameters
- Dimensions are set to N = n_I = 10 and n_O = 1
- Teacher model is noiseless and tasks are analytically solvable
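The equivalence described in this section can be checked numerically. The following is a minimal sketch, not the paper's code: it assumes tokens of the form (x_i, y_i), an initial weight W_0 = 0, a scalar target, and one particular choice of key, query, value and projection matrices. The helper names (`lsa_layer`, `gd_construction`), the learning rate and the sampling distributions are assumptions made for illustration.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(M, v):
    return [dot(row, v) for row in M]

def lsa_layer(tokens, W_K, W_Q, W_V, P):
    """Linear self-attention (no softmax): e_j <- e_j + P @ sum_i v_i * (k_i . q_j)."""
    keys = [matvec(W_K, e) for e in tokens]
    queries = [matvec(W_Q, e) for e in tokens]
    values = [matvec(W_V, e) for e in tokens]
    updated = []
    for e_j, q_j in zip(tokens, queries):
        agg = [0.0] * len(e_j)
        for k_i, v_i in zip(keys, values):
            score = dot(k_i, q_j)
            agg = [a + score * v for a, v in zip(agg, v_i)]
        updated.append([e + u for e, u in zip(e_j, matvec(P, agg))])
    return updated

def gd_construction(n_in, n_ctx, eta):
    """Assumed weights for one GD step from W_0 = 0 with a scalar target."""
    d = n_in + 1
    # Keys and queries read only the input part of each token.
    W_KQ = [[1.0 if r == c and r < n_in else 0.0 for c in range(d)] for r in range(d)]
    # Values carry minus the target in the last coordinate.
    W_V = [[-1.0 if r == c == n_in else 0.0 for c in range(d)] for r in range(d)]
    # The projection applies the learning rate eta / N.
    P = [[eta / n_ctx if r == c else 0.0 for c in range(d)] for r in range(d)]
    return W_KQ, W_KQ, W_V, P

rng = random.Random(0)
n_in, n_ctx, eta = 10, 10, 0.1  # eta is an arbitrary illustrative value
# Assumed sampling: Gaussian teacher weights, inputs uniform in [-1, 1], no noise.
w_teacher = [rng.gauss(0.0, 1.0) for _ in range(n_in)]
xs = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_ctx)]
ys = [dot(w_teacher, x) for x in xs]
x_test = [rng.uniform(-1.0, 1.0) for _ in range(n_in)]

# Reference: one explicit GD step from W_0 = 0 gives delta_w = (eta / N) * sum_i y_i * x_i.
delta_w = [eta / n_ctx * sum(y * x[c] for x, y in zip(xs, ys)) for c in range(n_in)]
pred_gd = dot(delta_w, x_test)

# Same step through the layer: context tokens (x_i, y_i) and a query token (x_test, 0).
tokens = [x + [y] for x, y in zip(xs, ys)] + [x_test + [0.0]]
out = lsa_layer(tokens, *gd_construction(n_in, n_ctx, eta))
pred_lsa = -out[-1][n_in]  # the prediction is read out with a sign flip
```

In this sketch the layer leaves the input part of every token untouched, replaces each context target y_j by y_j - ΔW x_j, and writes -ΔW x_test into the query token, so a single head handles training targets and test prediction at once.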
One-step of gradient descent vs. a single trained self-attention layer
- Investigated whether a single, linear self-attention layer can be explained by weight construction that implements GD
- Compared predictions made by LSA layer with trained weights and constructed weights
- Predictions are linear in x_test
- Compared predictions via the L2 norm of their difference
- Compared sensitivities (derivatives of the prediction with respect to x_test) via cosine similarity
- Sampled 10,000 validation tasks
- Found excellent agreement between the two models
- Scaling correction on trained weights is enough to recover weight construction implementing GD
- Investigated how GD, trained LSA layer and interpolation behave when providing in-context data in different regimes
- Tuned learning rate for input range [-1, 1] and one gradient step
- Applied LSA update repeatedly, found same loss decrease for both GD and Transformer
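The two comparison measures used above can be sketched for the scalar-output case. Because the predictions are linear in x_test, each model reduces to an effective weight vector, and the sensitivity is that vector itself. All names and numbers below are hypothetical; the "trained" weights are simply a rescaled copy of the reference, chosen to illustrate why a scaling correction can matter.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def prediction_difference(w_a, w_b, x_test):
    """Gap between two scalar predictions (equals the L2 norm when n_O = 1)."""
    return abs(dot(w_a, x_test) - dot(w_b, x_test))

def sensitivity_cosine(w_a, w_b):
    """Cosine similarity of d y_hat / d x_test; for a linear predictor this is w."""
    return dot(w_a, w_b) / (math.sqrt(dot(w_a, w_a)) * math.sqrt(dot(w_b, w_b)))

rng = random.Random(0)
# Hypothetical effective weights of the GD construction.
w_gd = [rng.gauss(0.0, 1.0) for _ in range(10)]
# Hypothetical trained layer: same direction as w_gd, but a different scale.
scale = 1.3
w_trained = [scale * w for w in w_gd]
x_test = [rng.uniform(-1.0, 1.0) for _ in range(10)]

cos = sensitivity_cosine(w_gd, w_trained)
gap = prediction_difference(w_gd, w_trained, x_test)
# Dividing out the scale removes the prediction gap in this toy setting.
gap_rescaled = prediction_difference(w_gd, [w / scale for w in w_trained], x_test)
```

In this toy setting the cosine similarity is insensitive to scale while the prediction difference is not, which is why the two measures are reported side by side.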
Multiple steps of gradient descent vs. multiple layers of self-attention
- Deep linear self-attention-only Transformers can be stacked up over K layers
- Measured the alignment between models trained by GD and GD++ and the model generated by the trained Transformer (TF)
- TF aligns well with GD at the beginning of training and becomes better aligned with GD++ by the end of training
- Performance of the trained Transformer mimics that of GD++ when testing on out-of-distribution tasks
- K layers of Transformers generally outperform K steps of plain gradient descent
- GD++ variant of gradient descent is used, with single parameter γ defined through transformation function H(X)
- Recurrent 2-layer LSA model surpasses plain gradient descent
- Trained model again aligns perfectly with GD++ while matching its performance on in- and out-of-distribution tasks
- Non-recurrent 5-layer LSA-only Transformers with different parameters per layer outperform plain gradient descent
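The difference between K steps of plain GD and K steps of GD++ can be sketched in the data-transformation view. This is one plausible reading, not the paper's implementation: it assumes the transformation has the form H(X) = I - γXXᵀ, that it is applied to context and query inputs after every step, and that η and γ take arbitrary untuned values. The function names are hypothetical.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def plain_gd(xs, ys, x_test, eta, n_steps):
    """Reference: K explicit GD steps on the weights of a linear model, from w = 0."""
    w, n = [0.0] * len(x_test), len(xs)
    for _ in range(n_steps):
        errs = [dot(w, x) - y for x, y in zip(xs, ys)]
        w = [w[c] - eta / n * sum(e * x[c] for x, e in zip(xs, errs))
             for c in range(len(w))]
    return dot(w, x_test)

def precondition(v, ctx, gamma):
    """Assumed GD++ transformation: v <- (I - gamma * X X^T) v over the context inputs."""
    return [v[c] - gamma * sum(dot(x, v) * x[c] for x in ctx) for c in range(len(v))]

def run_steps(xs, ys, x_test, eta, gamma, n_steps):
    """K steps that modify the data instead of the weights; gamma = 0 is plain GD."""
    xs, residuals = [list(x) for x in xs], list(ys)
    x_q, pred, n = list(x_test), 0.0, len(xs)
    for _ in range(n_steps):
        # GD step from W = 0 on the current data: (eta / N) * sum_i r_i * x_i.
        delta_w = [eta / n * sum(r * x[c] for x, r in zip(xs, residuals))
                   for c in range(len(x_q))]
        pred += dot(delta_w, x_q)
        residuals = [r - dot(delta_w, x) for x, r in zip(xs, residuals)]
        # Curvature correction: transform the query first, then the context inputs.
        x_q = precondition(x_q, xs, gamma)
        xs = [precondition(x, xs, gamma) for x in xs]
    return pred, residuals

rng = random.Random(0)
n_in = n_ctx = 10
w_teacher = [rng.gauss(0.0, 1.0) for _ in range(n_in)]
xs = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_ctx)]
ys = [dot(w_teacher, x) for x in xs]
x_test = [rng.uniform(-1.0, 1.0) for _ in range(n_in)]

eta, gamma, K = 0.1, 0.01, 5  # hypothetical, untuned values
pred_weights = plain_gd(xs, ys, x_test, eta, K)
pred_data, res_gd = run_steps(xs, ys, x_test, eta, 0.0, K)
pred_gdpp, res_gdpp = run_steps(xs, ys, x_test, eta, gamma, K)
```

With γ = 0 the data view reproduces weight-space GD exactly. A nonzero γ changes the input covariance between steps, which is where a curvature correction can enter; whether it helps in practice depends on how η and γ are tuned.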
Transformers solve nonlinear regression tasks by gradient descent on deep data representations
- Transformers can learn linear models by gradient descent on deep representations
- MLPs in Transformers process both inputs and targets
- A Transformer model with all weights learned is compared to a control Transformer with the final LSA weights set to a construction
- Alignment measures represent the first two terms of the Taylor approximation of the obtained functions
- MLPs and self-attention layers can interplay to support nonlinear in-context learning
Do self-attention layers build regression tasks?
- Proposition 1 requires a particular token structure where input and output data are concatenated into a single token.
- Proposition 2 allows a Transformer to build the required token construction on its own.
- Evidence suggests that copying is performed in trained Transformers.
- Softmax self-attention layers easily learn to copy, allowing the Transformer to emulate gradient-based learning.
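The copying step can be illustrated with a toy softmax attention layer. The sketch below is an assumption-laden illustration, not the construction from the paper: inputs and targets arrive as separate tokens (x_i, 0) and (0, y_i), attention scores are set directly as if query and key were one-hot positional encodings matched with an offset of one, and `sharpness` is a hypothetical scaling that makes the softmax nearly one-hot.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def copy_layer(tokens, sharpness=50.0):
    """Each position attends to its predecessor and adds that token to its own."""
    n = len(tokens)
    out = []
    for p in range(n):
        # Scores as produced by one-hot positions: large for position p - 1, zero elsewhere.
        scores = [sharpness if i == p - 1 else 0.0 for i in range(n)]
        attn = softmax(scores)
        copied = [sum(a * tok[c] for a, tok in zip(attn, tokens))
                  for c in range(len(tokens[p]))]
        # Residual connection: the token keeps its own content and gains the copy.
        out.append([e + v for e, v in zip(tokens[p], copied)])
    return out

# Toy data: two input/target pairs presented as four separate tokens.
xs = [[0.5, -1.0], [0.25, 2.0]]
ys = [3.0, -4.0]
tokens = []
for x, y in zip(xs, ys):
    tokens.append(x + [0.0])             # input token (x_i, 0)
    tokens.append([0.0] * len(x) + [y])  # target token (0, y_i)

merged = copy_layer(tokens)[1::2]        # keep the positions that held the targets
expected = [x + [y] for x, y in zip(xs, ys)]
```

After this layer, every target position holds a concatenated (x_i, y_i) token, the structure that the gradient descent construction expects. The remaining positions hold mixtures that a later layer would need to ignore.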
Discussion
- Transformers show remarkable in-context learning behavior
- Mechanisms based on attention, associative memory and copying by induction heads are leading explanations
- Hypothesis that Transformer’s in-context learning is driven by gradient descent
- Schlag et al. (2021) linked linear self-attention layers with (fast-)inner loop learning by the delta rule
- Evidence that in-context learning is based on gradient descent when training multi-layer self-attention-only Transformers on simple regression tasks
- Need to incorporate noisy data and weight regularization into hypothesis
- Investigate how to improve gradient descent based learning algorithms
- Analyze in-context learning in larger models and language modeling
- Analyze in-context learning in HyperTransformers
- Compare GD and trained Transformer on other data distributions
- Interpolate between the construction θ_GD and the trained Transformer weights θ
- Construct key, query and value matrices W_K, W_Q, W_V as well as the projection matrix P
- Test behavior when repeating a single LSA-layer trained to lower objective
- Compare GD and self-attention layer when changing dampening strength