Abstract

  • Transformer-based language models (LMs) create hidden representations of their inputs at every layer.
  • We suggest a simple method for casting hidden representations into final-layer representations with linear transformations, bypassing the transformer computation in between.
  • This method produces more accurate approximations than the prevailing practice of inspecting hidden representations from all layers in the space of the final layer.
  • The method allows “peeking” into early layer representations of GPT-2 and BERT.
  • The method can be used for early exit strategies, saving more layers than the baseline for both GPT-2 and BERT.
  • The method can be extended to linearly approximate sub-modules, finding that attention is most tolerant to this change.

Paper Content

Introduction

  • Transformer-based language models process input sequences of tokens by transforming them through attention and feed-forward network layers
  • Loss minimization directly optimizes the final representations, while hidden representations are only optimized implicitly
  • Utilizing hidden representations is desirable as it can shed light on the “decision-making process” and reduce compute resources
  • Previous attempts to exploit hidden representations viewed the hidden representations of an input token as a sequence of approximations of its final representation
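  • This work questions the assumption behind these attempts, that hidden representations already lie in the final layer's linear space, by learning linear transformations across layers in the network
  • Experiments show large accuracy gains in estimating predictions with the learned linear mappings (mat) over naive projections that read hidden representations directly (id)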
  • Mappings often produce correct predictions when applied to the very early layers in the network
  • The method is useful for early exiting, where it saves compute
  • Among transformer sub-modules, linearly approximating attention causes the smallest drop in precision

Background and notation

  • Input to a transformer-based LM is a sequence of tokens from a vocabulary
  • Tokens are represented as vectors using an embedding matrix
  • Hidden representations are repeatedly transformed through L transformer blocks
  • Each block is composed of two layers: multi-head self-attention layer and FFN layer
  • Final representations are used to form various predictions
  • Propose to cast hidden representations from earlier layers to succeeding layers using linear regression
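As a rough illustration of the notation above, the toy NumPy sketch below runs random token embeddings through L GPT-2-style (pre-layer-norm) blocks and keeps every hidden representation. It is a simplification under stated assumptions: single-head attention, no positional embeddings, no learned layer-norm parameters, and ReLU in place of GELU; BERT additionally places layer normalization after, rather than before, each sub-layer. All function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector to zero mean, unit variance (learned scale/bias omitted).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def causal_self_attention(x, Wq, Wk, Wv, Wo):
    # Single-head causal attention over x of shape (seq_len, d).
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(x.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # mask positions j > i
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v) @ Wo

def ffn(x, W1, W2):
    # Two-layer feed-forward network; ReLU stands in for GPT-2's GELU.
    return np.maximum(x @ W1, 0.0) @ W2

def hidden_states(tokens, E, blocks):
    # Returns [h^0, h^1, ..., h^L]: the embeddings followed by each block's output.
    h = E[tokens]
    states = [h]
    for p in blocks:
        h = h + causal_self_attention(layer_norm(h), *p["attn"])  # attention sub-layer
        h = h + ffn(layer_norm(h), *p["ffn"])                     # FFN sub-layer
        states.append(h)
    return states

rng = np.random.default_rng(0)
d, vocab, L = 16, 50, 4
E = rng.normal(size=(vocab, d))
blocks = [
    {"attn": [rng.normal(scale=0.1, size=(d, d)) for _ in range(4)],
     "ffn": [rng.normal(scale=0.1, size=(d, 4 * d)), rng.normal(scale=0.1, size=(4 * d, d))]}
    for _ in range(L)
]
states = hidden_states(np.array([3, 7, 1, 9]), E, blocks)
print(len(states), states[-1].shape)  # 5 (4, 16)
```

Because attention is causal, the first position's final representation does not depend on later tokens.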

Method

  • Goal is to learn a mapping from hidden representations at one layer to those at another layer
  • Collect a set of corresponding hidden representation pairs
  • Run a set of input sequences through the model
  • Extract the hidden representations at the source and target layers from each input sequence
  • Learn a matrix by fitting linear regression over the collected pairs
  • Define the mapping (mat) from one layer to another as multiplication of a representation by the learned matrix
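A minimal sketch of the fitting step, assuming hidden representations are already available as NumPy arrays (one list of per-layer arrays per input sequence). Taking one random position per sequence and fitting without a bias term are assumptions of this sketch, and `collect_pairs`, `fit_linear_map`, and `mat` are hypothetical names rather than the authors' code.

```python
import numpy as np

def collect_pairs(per_sequence_states, src, tgt, rng):
    # per_sequence_states: one list [h^0, ..., h^L] per sequence, each h^l of shape (seq_len, d).
    # Takes one randomly chosen position per sequence (assumed sampling scheme).
    X, Y = [], []
    for states in per_sequence_states:
        i = rng.integers(states[src].shape[0])
        X.append(states[src][i])
        Y.append(states[tgt][i])
    return np.stack(X), np.stack(Y)

def fit_linear_map(X, Y):
    # Least-squares fit of A minimizing sum_i ||A x_i - y_i||^2 (no bias term).
    B, *_ = np.linalg.lstsq(X, Y, rcond=None)  # solves X @ B ≈ Y
    return B.T  # so that mat(h) = A @ h for a column vector h

def mat(A, h):
    # Apply the fitted mapping to one vector (d,) or a batch (n, d).
    return h @ A.T

# Toy check: two "layers" related by a known matrix A_true.
rng = np.random.default_rng(0)
d = 8
A_true = rng.normal(size=(d, d))
seqs = []
for _ in range(200):
    h_src = rng.normal(size=(5, d))
    seqs.append([h_src, h_src @ A_true.T])
X, Y = collect_pairs(seqs, src=0, tgt=1, rng=rng)
A = fit_linear_map(X, Y)
print(np.allclose(A, A_true))  # True
```

With noiseless toy data the regression recovers the generating matrix exactly; real hidden representations are only approximately linearly related, which is what the quality-of-fit experiments measure.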

Baseline

  • Evaluated method against prevalent approach of “reading” hidden representations directly
  • In this baseline, a hidden representation is propagated from layer to layer by the identity function (id), i.e., it is read as-is
  • This commonly used baseline assumes that representations at different layers lie in the same linear space

Quality of fit

  • Evaluated method by measuring how well linear mappings approximate representations at target layer
  • Used GPT-2 with 12, 24, 36, and 48 layers and BERT with 12 and 24 layers
  • Used Wikipedia and news articles from Leipzig Corpora Collection for training and validation
  • Used linear regression to fit mapping between layers
  • Evaluated the quality of both the learned mapping (mat) and the identity mapping (id) using the r² coefficient
  • Results show mapping yields better approximations than identity mapping
  • Highlights shortcomings of the existing practice of inspecting representations from all layers in the same linear space
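The mat-versus-id comparison can be illustrated on synthetic data, where toy "source" and "target" layer vectors are related by a random linear map plus noise. Computing r² per output dimension and averaging uniformly is an assumption of this sketch; the paper's exact aggregation may differ, and `r2_score` here is a hypothetical helper, not a library call.

```python
import numpy as np

def r2_score(Y_true, Y_pred):
    # r^2 computed per output dimension, then averaged uniformly (assumed aggregation).
    ss_res = ((Y_true - Y_pred) ** 2).sum(axis=0)
    ss_tot = ((Y_true - Y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))

rng = np.random.default_rng(1)
d, n = 8, 2000
A_true = rng.normal(size=(d, d))
X = rng.normal(size=(n, d))                              # toy source-layer vectors
Y = X @ A_true.T + 0.1 * rng.normal(size=(n, d))         # toy target-layer vectors

train, val = slice(0, 1500), slice(1500, n)
B, *_ = np.linalg.lstsq(X[train], Y[train], rcond=None)  # fit mat on training pairs
r2_mat = r2_score(Y[val], X[val] @ B)                    # mat: learned linear map
r2_id = r2_score(Y[val], X[val])                         # id: read the source vector as-is
print(f"mat: {r2_mat:.3f}  id: {r2_id:.3f}")
```

Fitting on one split and scoring on another mirrors the training/validation separation described above.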

Linear shortcut for language modeling

  • Method approximates future hidden representations better than naive propagation
  • Method translates to better predictive abilities from earlier layers
  • Evaluated using two metrics: Precision@k and Surprisal
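The two metrics can be sketched as follows, assuming both compare an estimated next-token distribution against the token ranked first by the model's final layer: Precision@k checks whether that token is among the estimate's top-k, and Surprisal is the negative log-probability the estimate assigns to it. The natural log is an assumed choice of base, and the function names are hypothetical.

```python
import math
import numpy as np

def precision_at_k(p_final, q_est, k):
    # True if the final layer's top-1 token is among the top-k tokens of the estimate.
    target = int(np.argmax(p_final))
    top_k = np.argsort(-np.asarray(q_est))[:k]
    return bool(target in top_k)

def surprisal(p_final, q_est):
    # Negative log-probability (natural log) the estimate gives the final top-1 token.
    target = int(np.argmax(p_final))
    return -math.log(q_est[target])

p_final = np.array([0.1, 0.7, 0.2])  # distribution from the final layer
q_est = np.array([0.5, 0.3, 0.2])    # distribution estimated from an earlier layer
print(precision_at_k(p_final, q_est, 1), precision_at_k(p_final, q_est, 2))  # False True
print(round(surprisal(p_final, q_est), 4))  # 1.204
```

Higher Precision@k and lower Surprisal both indicate that the early estimate agrees more closely with the final prediction.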

Next token prediction

  • Auto-regressive LMs output a probability distribution over the vocabulary for the next token.
  • Layer normalization is applied to the final-layer representation before it is converted into a distribution over the vocabulary.
  • The goal is to measure how well the distribution can be estimated from intermediate representations.
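In GPT-2 the final layer norm is followed by an LM head that reuses the token embedding matrix. The sketch below, with random weights and hypothetical names, reads a next-token distribution from an intermediate representation either directly (id) or after applying a fitted matrix A (mat). It assumes a single vector per position and ignores batching.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def next_token_probs(h, E, gamma, beta, A=None):
    # h: hidden representation (d,) from some intermediate layer.
    # A: fitted matrix for mat; None reproduces the id baseline.
    if A is not None:
        h = A @ h
    return softmax(E @ layer_norm(h, gamma, beta))  # E: (vocab, d), tied embedding matrix

rng = np.random.default_rng(0)
d, vocab = 16, 50
E = rng.normal(size=(vocab, d))
gamma, beta = np.ones(d), np.zeros(d)
h = rng.normal(size=d)
A = rng.normal(size=(d, d))
p_id = next_token_probs(h, E, gamma, beta)
p_mat = next_token_probs(h, E, gamma, beta, A)
```

Keeping the final layer norm and LM head unchanged means mat only has to bring the hidden representation into the final layer's space.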

Masked token prediction

  • Conducted an experiment to predict masked tokens in an input
  • Used BERT with a pretrained masked language model head
  • Results showed trends similar to those observed for next token prediction in GPT-2
  • Mat outperformed id in terms of precision
  • Mat improved the Precision@1 of id by more than 17%
  • Manually analyzed 50 random sentences from the Leipzig dataset and found mat had a higher plausibility rate of 85.36% compared to 52.8% for id

Implication to early exiting

  • Early exit strategies can save computation time by deciding when to stop computation and read the prediction from the hidden representation of that layer.
  • A confidence measure is used to decide when to stop the computation.
  • Experiments were conducted using GPT-2 and BERT models.
  • Using the proposed mapping method, early exit strategies can save up to 20% of layers relative to the baseline.
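A minimal early-exit loop might look like the following, assuming per-layer next-token distributions have already been computed (for example through mat or id). The confidence measure used here, the gap between the two highest probabilities, is one common choice and is an assumption of this sketch, as is the threshold name `lam`; the function names are hypothetical.

```python
import numpy as np

def confidence(p):
    # Gap between the two highest probabilities (assumed confidence measure).
    top2 = np.sort(p)[-2:]
    return float(top2[1] - top2[0])

def early_exit(layer_probs, lam):
    # layer_probs[l-1]: next-token distribution read from layer l.
    # Exit at the first layer whose confidence reaches the threshold lam.
    for layer, p in enumerate(layer_probs, start=1):
        if confidence(p) >= lam:
            return layer, int(np.argmax(p))
    # No layer was confident enough: fall back to the final layer.
    return len(layer_probs), int(np.argmax(layer_probs[-1]))

layer_probs = [np.array([0.40, 0.35, 0.25]),
               np.array([0.60, 0.30, 0.10]),
               np.array([0.90, 0.05, 0.05])]
for lam in (0.2, 0.5, 0.95):
    layer, token = early_exit(layer_probs, lam)
    saved = 1 - layer / len(layer_probs)  # fraction of layers skipped
    print(lam, layer, token, round(saved, 3))
```

A lower threshold exits earlier and saves more layers at the risk of a less accurate prediction; better-calibrated early distributions (as mat aims to provide) let the loop stop sooner for the same threshold.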

Linear shortcut across sub-modules

  • Experiments show that transformer layers do not operate in the same linear space
  • There is a gap in approximating future representations using an identity mapping
  • Investigating whether discrepancies across layers result from specific sub-modules or are a general behaviour
  • Extending approach to test how well particular components in transformer blocks can be linearly approximated
  • Fitting linear regression to approximate output of sub-module given its input
  • Applying mappings to disable contextualization between layers
  • Evaluating Precision@k and Surprisal metrics for additional mappings
  • Linear approximation of attention sub-modules is less harmful than FFN or layer normalization sub-modules
  • Linear approximation is feasible, to varying degrees, across transformer components

Related work

  • Interest in using intermediate representations in transformer-based LMs for interpretability and efficiency
  • Understanding prediction construction process of the model
  • Viewing inference pass as a residual stream of information
  • Probing to understand features stored in hidden representations
  • Converting intermediate representations into a final-layer form
  • Cutting computation at a dynamically-decided earlier stage
  • Utilizing a fixed early stage network to parallelize inference
  • Skipping transformer layers and analyzing linearity properties of transformer components
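The sub-module linear shortcut described earlier (fitting a linear map from a sub-module's input to its output) can be sketched on a toy attention module. The single-head, non-causal attention with random weights is a stand-in for a real model's sub-module, and all names are hypothetical. Because the fitted map is applied to each token independently, it removes the exchange of information between positions that attention normally provides.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.normal(scale=0.5, size=(d, d)) for _ in range(3))

def attention(x):
    # Toy single-head (non-causal) self-attention: each output mixes all positions.
    scores = (x @ Wq) @ (x @ Wk).T / np.sqrt(d)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ (x @ Wv)

# Record (input, output) pairs of the sub-module, one row per token.
seqs = [rng.normal(size=(6, d)) for _ in range(300)]
inputs = np.concatenate(seqs)
outputs = np.concatenate([attention(x) for x in seqs])

# Fit the linear shortcut by least squares: outputs ≈ inputs @ A.T
B, *_ = np.linalg.lstsq(inputs, outputs, rcond=None)
A = B.T

def attention_shortcut(x):
    # Applied token by token, so positions no longer exchange information.
    return x @ A.T

x = rng.normal(size=(6, d))
x_changed = x.copy()
x_changed[0] += 1.0  # perturb the first token only
print(np.allclose(attention(x)[3], attention(x_changed)[3]))                    # False
print(np.allclose(attention_shortcut(x)[3], attention_shortcut(x_changed)[3]))  # True
```

The same recipe applies to FFN or layer-norm sub-modules by recording their inputs and outputs instead.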

Conclusion and future work

  • Present a method for inspection of hidden representations in transformer models
  • Method uses pre-fitted context-free and token-uniform linear mappings
  • Method consistently outperforms prevalent practice of interpreting representations in the final-layer space of the model
  • Method enables earlier exits during inference, saving compute
  • Method can be extended to sub-modules, resulting in small deterioration of prediction
  • Experiments cover only English data
  • Experiments performed on different data sources, model architectures and scales
  • Results recorded for various sub-module linear shortcut mappings and the mapping mat_{ℓ→L}
  • Results recorded for various early exit methods and fixed exit methods
  • Results recorded when varying the confidence parameter λ
  • Examples of top-5 predictions at layers 4, 12 and 24