Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformer-based language models (LMs) create hidden representations of their inputs at every layer.
- A method is suggested to cast the hidden representations as final representations, bypassing the transformer computation in-between.
- This method produces more accurate approximations than the prevailing practice of inspecting hidden representations from all layers in the space of the final layer.
- The method allows “peeking” into early layer representations of GPT-2 and BERT.
- The method can be combined with early exit strategies, where it saves additional layers for GPT-2 and BERT compared with the baseline.
- The method can be extended to linearly approximate sub-modules, finding that attention is most tolerant to this change.
Paper Content
Introduction
- Transformer-based language models process input sequences of tokens by transforming them through attention and feed-forward network layers
- Loss minimization directly optimizes the final representations, while hidden representations are only optimized implicitly
- Utilizing hidden representations is desirable as it can shed light on the “decision-making process” and reduce compute resources
- Previous attempts to exploit hidden representations viewed the hidden representations of an input token as a sequence of approximations of its final representation
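- This work instead learns linear transformations across layers in the network, so that a hidden representation can be mapped into the space of a later layer
- Experiments show large accuracy gains in prediction estimation for the learned linear mapping (mat) over naive projection with the identity (id)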
- Mappings often produce correct predictions when applied to the very early layers in the network
- Method is useful in the setting of early exiting, allowing for the saving of compute time
- Linearly approximating attention results in the least reduction of precision
Background and notation
- Input to a transformer-based LM is a sequence of tokens from a vocabulary
- Tokens are represented as vectors using an embedding matrix
- Hidden representations are repeatedly transformed through L transformer blocks
- Each block is composed of two layers: a multi-head self-attention layer and a feed-forward network (FFN) layer
- Final representations are used to form various predictions
- Propose to cast hidden representations from earlier layers to succeeding layers using linear regression
Method
- Goal is to learn a mapping from hidden representations at one layer to those at another layer
- Collect a set of corresponding hidden representation pairs
- Run a set of input sequences through the model
- Extract the hidden representations from each input sequence
- Learn a matrix by fitting linear regression over the set of input sequences
- Define the mapping of a representation from one layer to another
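The fitting step above can be sketched with ordinary least squares. This is a minimal illustration, not the authors' code: the function names `fit_linear_map` and `apply_map` are hypothetical, the hidden representations are synthetic stand-ins for vectors extracted from a real model, and representations are treated as row vectors, so the mapping is applied as `h @ A`.

```python
import numpy as np

def fit_linear_map(src, tgt):
    """Least-squares fit of a matrix A such that src @ A approximates tgt.

    src, tgt: arrays of shape (n_samples, hidden_dim), one row per
    collected (source-layer, target-layer) representation pair.
    """
    A, *_ = np.linalg.lstsq(src, tgt, rcond=None)
    return A

def apply_map(A, h):
    """Cast source-layer representation(s) h into the target layer's space."""
    return h @ A

# Synthetic stand-in for collected pairs (assumption: real pairs would be
# extracted by running input sequences through the model).
rng = np.random.default_rng(0)
n, d = 500, 8
h_src = rng.normal(size=(n, d))
true_A = rng.normal(size=(d, d))
h_tgt = h_src @ true_A + 0.01 * rng.normal(size=(n, d))

A = fit_linear_map(h_src, h_tgt)
approx = apply_map(A, h_src)
max_err = float(np.abs(A - true_A).max())  # how closely the fit recovers true_A
```

Because the matrix is fitted once and then reused for every token and context, applying it costs a single matrix multiplication per representation.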
Baseline
- Evaluated method against prevalent approach of “reading” hidden representations directly
- The baseline propagates a hidden representation from one layer to another with the identity function (id)
- Commonly-used baseline assumes representations at different layers operate in same linear space
Quality of fit
- Evaluated method by measuring how well linear mappings approximate representations at target layer
- Used GPT-2 models with 12, 24, 36, and 48 layers and BERT models with 12 and 24 layers
- Used Wikipedia and news articles from Leipzig Corpora Collection for training and validation
- Used linear regression to fit mapping between layers
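- Evaluated the quality of the learned mapping and of the identity mapping using the r² coefficient (coefficient of determination)
- Results show that the learned mapping yields better approximations than the identity mapping
- This highlights the shortcomings of the existing practice of inspecting representations from different layers as if they shared the same linear space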
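A small sketch of the comparison described above, assuming an r² score pooled over all hidden dimensions (the summary does not say how the coefficient is aggregated). The data is synthetic and deliberately built so that the two "layers" live in different linear spaces. It illustrates the metric only and says nothing about the paper's actual numbers.

```python
import numpy as np

def r2_score(target, pred):
    """Coefficient of determination, pooled over all entries."""
    ss_res = np.sum((target - pred) ** 2)
    ss_tot = np.sum((target - target.mean(axis=0)) ** 2)
    return float(1.0 - ss_res / ss_tot)

rng = np.random.default_rng(1)
d = 6
# Toy target layer: a rotated and rescaled copy of the source layer plus noise.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
transform = 2.0 * Q
h_src_train = rng.normal(size=(400, d))
h_src_val = rng.normal(size=(200, d))
h_tgt_train = h_src_train @ transform + 0.1 * rng.normal(size=(400, d))
h_tgt_val = h_src_val @ transform + 0.1 * rng.normal(size=(200, d))

# Fit on the training split, score on the held-out validation split.
A, *_ = np.linalg.lstsq(h_src_train, h_tgt_train, rcond=None)
r2_mat = r2_score(h_tgt_val, h_src_val @ A)  # learned linear mapping
r2_id = r2_score(h_tgt_val, h_src_val)       # identity baseline
```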
Linear shortcut for language modeling
- Method approximates future hidden representations better than naive propagation
- Method translates to better predictive abilities from earlier layers
- Evaluated using two metrics: Precision@k and Surprisal
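The summary names the two metrics without defining them. One reading, assumed here, scores an estimated distribution against the final layer's top-1 token: Precision@k checks whether that token is among the estimate's top-k tokens, and Surprisal is the negative log-probability the estimate assigns to it. The function names and the toy distributions are illustrative.

```python
import numpy as np

def precision_at_k(p_final, p_est, k):
    """1.0 if the final layer's top-1 token is in the top-k of the estimate."""
    target = int(np.argmax(p_final))
    top_k = np.argsort(p_est)[::-1][:k]
    return float(target in top_k)

def surprisal(p_final, p_est):
    """Negative log-probability, under the estimate, of the final top-1 token."""
    target = int(np.argmax(p_final))
    return float(-np.log(p_est[target]))

# Toy distributions over a 4-token vocabulary.
p_final = np.array([0.1, 0.6, 0.2, 0.1])  # final-layer distribution
p_est = np.array([0.4, 0.3, 0.2, 0.1])    # estimate from an earlier layer

p_at_1 = precision_at_k(p_final, p_est, k=1)  # 0.0: token 1 is not the estimate's top-1
p_at_2 = precision_at_k(p_final, p_est, k=2)  # 1.0: token 1 is within the top-2
s = surprisal(p_final, p_est)                 # -log(0.3)
```

Higher Precision@k and lower Surprisal both indicate that the early-layer estimate agrees more closely with the model's final prediction.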
Next token prediction
- Auto-regressive LMs output a probability distribution over the vocabulary for the next token.
- Layer normalization is applied to the final layer representation before conversion.
- The goal is to measure how well the distribution can be estimated from intermediate representations.
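As a rough sketch of how a distribution could be read from an intermediate representation: map it to the final layer's space, apply the final layer normalization, project onto the vocabulary, and take a softmax. The projection with an output matrix (`unembed`) followed by softmax is an assumption based on the standard GPT-2 setup. All weights below are random placeholders, including the mapping `A`, which stands in for a pre-fitted matrix.

```python
import numpy as np

def layer_norm(h, gamma, beta, eps=1e-5):
    """Layer normalization over the last axis."""
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mean) / np.sqrt(var + eps) * gamma + beta

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def next_token_distribution(h, gamma, beta, unembed):
    """Layer norm, vocabulary projection, then softmax."""
    return softmax(layer_norm(h, gamma, beta) @ unembed)

rng = np.random.default_rng(2)
d, vocab = 8, 20
gamma, beta = np.ones(d), np.zeros(d)   # placeholder layer-norm parameters
unembed = rng.normal(size=(d, vocab))   # placeholder output projection
A = rng.normal(size=(d, d))             # placeholder for a fitted mapping
h_mid = rng.normal(size=d)              # placeholder intermediate representation

p_id = next_token_distribution(h_mid, gamma, beta, unembed)       # id: read directly
p_mat = next_token_distribution(h_mid @ A, gamma, beta, unembed)  # mat: map first
```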
Masked token prediction
- Conducted an experiment to predict masked tokens in an input
- Used BERT with a pretrained masked language model head
- Results showed trends similar to those observed for next token prediction in GPT-2
- mat outperformed id in terms of precision
- mat improved on the Precision@1 of id by more than 17%
- A manual analysis of 50 random sentences from the Leipzig dataset found a higher plausibility rate for mat: 85.36% compared to 52.8% for id
Implication to early exiting
- Early exit strategies can save computation time by deciding when to stop computation and read the prediction from the hidden representation of that layer.
- A confidence measure is used to decide when to stop the computation.
- Experiments were conducted using GPT-2 and BERT models.
- Using the proposed mapping method, early exit strategies can save up to 20% layers compared to the baseline.
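The early exit loop can be sketched as follows, assuming the confidence measure is the maximum softmax probability, which is one common choice and not necessarily the one used in the paper. The blocks, mappings, and threshold are toy values chosen so the behaviour is easy to trace by hand, and `early_exit` is a hypothetical helper.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def early_exit(h, blocks, maps, readout, threshold):
    """Run blocks in order; stop once the mapped readout is confident enough.

    maps[i] casts the output of block i + 1 into the final layer's space.
    Returns (predicted token id, number of blocks actually run).
    """
    n_layers = len(blocks)
    for layer, block in enumerate(blocks, start=1):
        h = block(h)
        if layer == n_layers:
            break  # the last layer is read directly, with no mapping
        probs = readout(maps[layer - 1](h))
        if probs.max() >= threshold:
            return int(np.argmax(probs)), layer
    return int(np.argmax(readout(h))), n_layers

# Toy model: every block pushes the state further toward token 0.
step = np.array([1.0, 0.0])
blocks = [lambda h: h + step] * 4
id_maps = [lambda h: h] * 3         # identity baseline
toy_maps = [lambda h: 2.0 * h] * 3  # stand-in for fitted mappings
h0 = np.zeros(2)

exit_id = early_exit(h0, blocks, id_maps, softmax, threshold=0.9)    # (0, 3)
exit_toy = early_exit(h0, blocks, toy_maps, softmax, threshold=0.9)  # (0, 2)
```

In this toy setup the mapped readout crosses the threshold one layer earlier than the identity readout, which is the kind of saving the experiments measure on real models.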
Linear shortcut across sub-modules
- Experiments show that transformer layers do not operate in the same linear space
- There is a gap in approximating future representations using an identity mapping
- Investigating whether discrepancies across layers result from specific sub-modules or are a general behaviour
- Extending approach to test how well particular components in transformer blocks can be linearly approximated
- Fitting linear regression to approximate output of sub-module given its input
- Applying mappings to disable contextualization between layers
- Evaluating Precision@k and Surprisal metrics for additional mappings
- Linear approximation of attention sub-modules is less harmful than FFN or layer normalization sub-modules
- Possibility of linear approximation permeates various transformer components
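The sub-module experiment can be illustrated on a toy residual block: fit a linear map from a sub-module's input to its output, then swap the sub-module for that map and compare block outputs. A `tanh` layer stands in for an attention or FFN sub-module here, which is an assumption made purely for illustration. Real sub-modules are far less linear than this example.

```python
import numpy as np

rng = np.random.default_rng(3)
W = np.array([[0.5, -0.2], [0.3, 0.8]])  # fixed toy weights

def sub_module(x):
    """Toy nonlinear sub-module standing in for attention or an FFN."""
    return np.tanh(x @ W)

def block(x):
    """Residual block: input plus sub-module output."""
    return x + sub_module(x)

x_train = 0.1 * rng.normal(size=(300, 2))
x_val = 0.1 * rng.normal(size=(100, 2))

# Fit a linear approximation of the sub-module from (input, output) pairs.
A, *_ = np.linalg.lstsq(x_train, sub_module(x_train), rcond=None)

def block_linear(x):
    """Same residual block with the sub-module replaced by its linear fit."""
    return x + x @ A

def rel_err(pred, target):
    return float(np.linalg.norm(pred - target) / np.linalg.norm(target))

err_linear = rel_err(block_linear(x_val), block(x_val))  # linear replacement
err_skip = rel_err(x_val, block(x_val))                  # sub-module dropped
```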
Related work
- Interest in using intermediate representations in transformer-based LMs for interpretability and efficiency
- Understanding prediction construction process of the model
- Viewing inference pass as a residual stream of information
- Probing to understand features stored in hidden representations
- Converting intermediate representations into a final-layer form
- Cutting computation at a dynamically-decided earlier stage
- Utilizing a fixed early stage network to parallelize inference
- Skipping transformer layers and analyzing linearity properties of transformer components
Conclusion and future work
- Present a method for inspection of hidden representations in transformer models
- Method uses pre-fitted context-free and token-uniform linear mappings
- Method consistently outperforms prevalent practice of interpreting representations in the final-layer space of the model
- Method improves computational efficiency when used for early exiting
- Method can be extended to sub-modules, resulting in small deterioration of prediction
- Experiments cover only English data
- Experiments performed on different data sources, model architectures and scales
- Results recorded for various sub-module linear shortcut mappings and the mapping mat_{ℓ→L} (from layer ℓ to the final layer L)
- Results recorded for various early exit methods and fixed exit methods
- Varying the confidence parameter λ
- Examples of top-5 predictions at layers 4, 12 and 24