Abstract

  • Interpretability research aims to build tools to understand ML models.
  • We propose to build transformer models manually as a testbed for interpretability research.
  • Tracr is a “compiler” that translates human-readable programs into weights of a transformer model.
  • Tracr is used to create ground truth transformers that implement programs such as computing token frequencies, sorting, and Dyck-n parenthesis checking.
  • Tracr is open-source and available at https://github.com/deepmind/tracr.

Paper Content

Introduction

  • Deep learning models are becoming more capable and increasingly deployed in production.
  • Mechanistic interpretability aims to understand how they make decisions.
  • Cammarata et al. (2020) explain a range of specific circuits in InceptionV1 (Szegedy et al., 2015).
  • Elhage et al. (2021) and Wang et al. (2022) achieve early success in interpreting transformer language models.
  • Toolbox of approaches for generating mechanistic explanations is small and poorly understood.
  • Evaluating mechanistic explanations requires creativity and effort by researchers.
  • Standard approach for evaluating mechanistic explanations combines evidence from many ad-hoc experiments.
  • Tracr is a proof-of-concept implementation of a compiler to create models which perform nontrivial computation with a known implementation.
  • Tracr can be used to evaluate interpretability tools by comparing the resulting explanation to the ground truth.

Background

  • Tracr compiles programs written in the RASP programming language into transformer models.
  • The compiled model for the input sequence “xacx” is shown with a full residual stream at each layer.
  • Section 4 of the paper discusses examples of the compiled model in more detail.

Transformer models

  • Transformer model consists of alternating MHA and MLP layers with residual connections
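  • Multi-headed attention (MHA) computes attention maps on sequences of length $N$; head $i$ computes the attention pattern $A^i = \mathrm{softmax}\big((x W_Q^i)(x W_K^i)^\top / \sqrt{d_k}\big)$ for an input $x$
  • MHA combines $H$ attention heads by computing $\mathrm{MHA}(x) = \mathrm{Concat}[A^1 (x W_V^1), \dots, A^H (x W_V^H)]\, W_O$
  • MLP layers compute $\mathrm{MLP}(x) = \sigma(x W_1)\, W_2$, where $\sigma$ is a nonlinearity such as ReLU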
  • Decoder-only transformer with GPT architecture consists of MHA, MLP, and layer normalization
  • Input to model is sum of learned embedding of sequence of input tokens and positional embedding
  • Model is trained to predict next token using gradient descent
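The layer structure above can be made concrete with a small sketch. It assumes the standard formulation of a softmax attention head and a two-layer ReLU MLP; the weights and inputs are made-up toy values, and the helper names are illustrative rather than taken from any library.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(scores):
    """Apply a numerically stable softmax to every row."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention_head(x, w_q, w_k, w_v):
    """Compute A (x W_V) with A = softmax((x W_Q)(x W_K)^T / sqrt(d_k))."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(w_q[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)

def mlp(x, w_1, w_2):
    """Compute ReLU(x W_1) W_2."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(x, w_1)]
    return matmul(hidden, w_2)

def add(a, b):
    """Residual connection: elementwise sum of two matrices."""
    return [[p + q for p, q in zip(r, s)] for r, s in zip(a, b)]

# Toy input: N = 2 positions, d = 2 residual dimensions (made-up values).
x0 = [[1.0, 0.0], [0.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]

# Zero query/key weights give uniform attention, so every position
# receives the mean of the value vectors.
head_out = attention_head(x0, zeros, zeros, identity)
x1 = add(x0, head_out)                     # attention layer plus residual
x2 = add(x1, mlp(x1, identity, identity))  # MLP layer plus residual
```

Each layer's output is added back to its input, which is why the result is usually described as a running sum rather than a chain of replacements.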

Transformer circuits

  • Transformers have residual connections at each attention and MLP layer
  • Elhage et al. (2021) consider the residual connections a core feature of the architecture
  • Transformers are described in terms of a residual stream that each layer reads from and writes to in sequence
  • MHA is parameterised by low-rank matrices and acts as a bilinear operator
  • Softmax is the only nonlinearity in an attention head
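To illustrate the bilinear view, the sketch below checks that the pre-softmax attention score between a query position and a key position can be written with a single low-rank matrix W_QK = W_Q W_K^T. The weights are made-up toy values and the function names are illustrative.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def bilinear(x, w, y):
    """Evaluate the bilinear form x W y^T for vectors x and y."""
    return sum(x[i] * w[i][j] * y[j] for i in range(len(x)) for j in range(len(y)))

# Toy weights with d = 3 residual dimensions and d_k = 1 (made-up values).
w_q = [[1.0], [0.0], [2.0]]
w_k = [[0.0], [1.0], [1.0]]

# The combined matrix is d x d but has rank at most d_k.
w_qk = matmul(w_q, transpose(w_k))

query_x = [1.0, 2.0, 3.0]
key_y = [4.0, 5.0, 6.0]

# Score computed the usual way: dot product of the query and key projections.
q = matmul([query_x], w_q)[0]
k = matmul([key_y], w_k)[0]
score_direct = sum(a * b for a, b in zip(q, k))

# Same score computed as a bilinear form on the residual stream.
score_bilinear = bilinear(query_x, w_qk, key_y)
```

Both routes give the same score, so the head can be analysed through W_QK alone without reference to the separate query and key projections.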

The RASP programming language

  • RASP is a domain-specific language for expressing transformer computations
  • Weiss et al. (2021) proposed RASP as a computational model and provided an interpreter
  • RASP programs can be seen as a computational graph
  • There are two basic node types: sequence operations and selectors
  • Elementwise operations correspond to MLP layers in transformers
  • Select-aggregate operations correspond to attention in transformers
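A toy interpreter makes the two node types concrete. The sketch below is plain Python and does not use the RASP or Tracr API; the function names (`select`, `aggregate_mean`, `elementwise`, `frac_prevs`) are hypothetical stand-ins. It computes the fraction of previous "x" tokens, the program shown in Figure 2 of the paper.

```python
def select(keys, queries, predicate):
    """Selector: entry [q][k] is True where predicate(key, query) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate_mean(selector, values, default=0.0):
    """Numerical aggregate: average the selected values for each query position."""
    out = []
    for row in selector:
        picked = [v for v, on in zip(values, row) if on]
        out.append(sum(picked) / len(picked) if picked else default)
    return out

def elementwise(fn, sop):
    """Elementwise sequence operation (the part that maps to an MLP layer)."""
    return [fn(v) for v in sop]

def frac_prevs(tokens, target="x"):
    """Fraction of tokens equal to `target` at or before each position."""
    indices = list(range(len(tokens)))
    is_target = elementwise(lambda t: 1.0 if t == target else 0.0, tokens)
    # Each query position selects all key positions up to and including itself.
    prevs = select(indices, indices, lambda k, q: k <= q)
    return aggregate_mean(prevs, is_target)

result = frac_prevs(list("xacx"))  # approximately [1.0, 0.5, 0.333, 0.5]
```

The selector plays the role of an attention pattern and the aggregate plays the role of the attention-weighted average, which is why the pair maps onto a single attention layer.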

Tracr: a transformer compiler for RASP

  • Describes how RASP maps to transformer architecture
  • Proposes modifications to RASP to make mapping more straightforward
  • Introduces craft, an “assembly language” for transformer models
  • Describes how Tracr translates RASP programs to transformer weights
  • Full open-source implementation of Tracr available on GitHub (https://github.com/deepmind/tracr)

Mapping RASP to transformers

  • RASP provides a computational model of transformers
  • RASP operations can be mapped to components of a transformer model
  • Tokens and positions can be embedded as categorical variables in orthogonal subspaces
  • MLP layers can approximate any function
  • RASP’s select-aggregate operations map to attention layers in transformer models
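As a rough illustration of the first two mappings, the sketch below embeds a token and a position as one-hot vectors in separate, labelled subspaces and then implements an elementwise function on a categorical variable as a lookup-table MLP. The basis labels, the `lookup_mlp` helper and the lookup-table construction are assumptions made for this example, not Tracr's actual code.

```python
# Hypothetical residual-stream layout: every basis direction has a label.
basis = ["tokens:a", "tokens:b", "tokens:x",
         "indices:0", "indices:1",
         "is_x:False", "is_x:True"]

def one_hot(active):
    return [1.0 if label == active else 0.0 for label in basis]

def embed(token, index):
    """Sum of a token embedding and a position embedding in disjoint subspaces."""
    tok, pos = one_hot(f"tokens:{token}"), one_hot(f"indices:{index}")
    return [t + p for t, p in zip(tok, pos)]

def lookup_mlp(in_var, out_var, fn, values):
    """Build (W1, W2) so that ReLU(x W1) W2 writes fn(value) into out_var's subspace."""
    # W1 copies each input value's direction to its own hidden neuron.
    w1 = [[1.0 if label == f"{in_var}:{v}" else 0.0 for v in values] for label in basis]
    # W2 maps each hidden neuron to the one-hot direction of the function's output.
    w2 = [[1.0 if label == f"{out_var}:{fn(v)}" else 0.0 for label in basis] for v in values]
    return w1, w2

def apply_mlp(x, w1, w2):
    hidden = [max(0.0, sum(xi * row[j] for xi, row in zip(x, w1)))
              for j in range(len(w1[0]))]
    return [sum(h * row[k] for h, row in zip(hidden, w2)) for k in range(len(w2[0]))]

w1, w2 = lookup_mlp("tokens", "is_x", lambda t: t == "x", ["a", "b", "x"])
x = embed("x", 1)
residual = [a + b for a, b in zip(x, apply_mlp(x, w1, w2))]
active = [label for label, v in zip(basis, residual) if v == 1.0]
```

Because the token, position and output variables occupy disjoint directions, the MLP can write its result without disturbing anything already stored in the residual stream.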

Modifications to RASP

  • RASP language needs to be modified to allow translating it to model weights
  • Disallow arbitrary selector combinations
  • Restrict RASP to selectors with only two input variables
  • Encode categorical and numerical variables in dedicated subspaces of the residual stream
  • Require each s-op to be either categorical or numerical
  • Make BOS token mandatory in RASP
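One way to see why a mandatory BOS token helps: a softmax attention row must sum to 1, so a selector row that selects nothing has no direct equivalent. The sketch below lets BOS act as a default position by giving it a score between "selected" and "not selected" and using a large softmax scale. The specific score of 0.5 and scale of 100 are assumptions chosen for illustration, not values confirmed by the source.

```python
import math

def attention_row(selected, scale=100.0, bos_score=0.5):
    """Softmax over [BOS] + keys, where selected keys score 1 and others 0.

    Returns the attention weights with the BOS weight in position 0.
    """
    scores = [bos_score] + [1.0 if s else 0.0 for s in selected]
    top = max(scores)
    exps = [math.exp(scale * (s - top)) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Two keys selected: attention splits between them and BOS gets almost nothing.
some = attention_row([True, False, True])

# Nothing selected: almost all attention falls back to BOS.
none = attention_row([False, False, False])
```

With this arrangement BOS only receives attention when no other position matches, so an empty selector row still has a well-defined attention pattern.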

Craft: an assembly language for transformers

  • Craft represents vector spaces with labelled basis dimensions and operations on them.
  • Craft abstracts away the need to keep track of padding in weight matrices.
  • Craft models are independent of concrete transformer implementations.
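The idea of labelled basis dimensions can be sketched in a few lines. The helpers below are hypothetical and are not the craft API; they only show how naming each direction removes the need to track zero padding by hand when a small linear map is embedded into a larger residual space.

```python
def join_spaces(*spaces):
    """Direct sum of labelled spaces; basis labels must not collide."""
    labels = [label for space in spaces for label in space]
    assert len(set(labels)) == len(labels), "basis labels must be unique"
    return labels

def pad_linear_map(entries, in_space, out_space, full_space):
    """Expand a map given as {(in_label, out_label): weight} to the full space.

    Rows index input directions and columns index output directions.
    Every entry that is not mentioned stays zero.
    """
    for src, dst in entries:
        assert src in in_space and dst in out_space, "label outside declared space"
    return [[entries.get((src, dst), 0.0) for dst in full_space] for src in full_space]

tokens = ["tokens:a", "tokens:x"]
flags = ["is_x"]
residual = join_spaces(tokens, flags)

# A map that reads the "tokens:x" direction and writes to the "is_x" direction.
w = pad_linear_map({("tokens:x", "is_x"): 1.0}, tokens, flags, residual)
```

The map is specified only in terms of labels, so the same description could be placed into a residual stream of any size or ordering.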

Compiler overview

  • Tracr is a compiler implemented in Python
  • Tracr takes RASP programs as input
  • Tracr simplifies RASP programs
  • Tracr translates RASP programs to transformer weights in 6 steps
  • Tracr's transformer implementation is based on the example decoder-only transformer from Haiku

Exploring compiled transformers

  • With the compiler described, the paper turns to compiling and inspecting example models
  • Examples are provided in Appendix D
  • Programs were compiled for tasks in Weiss et al. (2021)
  • Programs were modified to use features supported by Tracr

Example 1: counting tokens

  • One input dimension of the residual stream is fixed to 1
  • This dimension is used as a constant to add biases in MLP layers
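The constant dimension works because a bias can be folded into a weight matrix when the input always carries a 1. A minimal sketch with made-up weights and illustrative helper names:

```python
def affine_as_linear(weights, bias):
    """Fold a bias into the weight matrix by adding a row for the constant-1 input."""
    return weights + [bias]

def linear(x, w):
    """Row vector times matrix."""
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]

weights = [[2.0, 0.0],
           [0.0, 3.0]]   # made-up values
bias = [1.0, -1.0]

w_aug = affine_as_linear(weights, bias)

x = [4.0, 5.0]
# Appending the constant 1 to the input reproduces x W + b with a purely linear map.
y = linear(x + [1.0], w_aug)  # [9.0, 14.0]
```

This keeps every layer expressible as a plain matrix acting on the residual stream, with no separate bias parameters.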

Example 2: sorting

  • Figure 5 shows a program that sorts a sequence of unique tokens
  • selector_width is a primitive that compiles directly to an attention and MLP layer
  • Weiss et al. (2021) propose a sort program that can handle duplicates
  • An alternative implementation of sort is provided in Appendix D that handles duplicates by adding a small multiple of indices to the keys
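The duplicate-handling trick can be sketched in plain Python. The code below is not the Appendix D program. It assumes integer keys, so that adding index / (n + 1) breaks ties without changing the order of distinct keys, and it uses a hypothetical `selector_width` helper that counts selected keys per query.

```python
def selector_width(selector):
    """Number of selected keys for each query position."""
    return [sum(row) for row in selector]

def sort_by_key(values, keys):
    """Sort `values` by integer `keys`, keeping duplicates in input order."""
    n = len(keys)
    indices = list(range(n))
    # Tie-break: a small multiple of the index makes every key unique.
    unique_keys = [k + i / (n + 1) for k, i in zip(keys, indices)]
    # For each position, select all positions holding a strictly smaller key.
    smaller = [[k < q for k in unique_keys] for q in unique_keys]
    # The count of smaller keys is the position the element should move to.
    target_pos = selector_width(smaller)
    # Each output position picks the element whose target position matches it.
    out = []
    for q in indices:
        picked = [v for v, p in zip(values, target_pos) if p == q]
        out.append(picked[0])
    return out

tokens = [3, 1, 3, 2]
sorted_tokens = sort_by_key(tokens, tokens)  # [1, 2, 3, 3]
```

Without the tie-break, the two copies of 3 would both receive target position 2 and position 3 would be left empty.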

More examples

  • Tracr can compile RASP programs
  • Appendix D discusses more examples, including checking balanced parentheses
  • Tracr implementation contains a library of example programs

Compressing compiled transformers

  • Tracr models can be sparse and inefficient.
  • An experimental approach is proposed to make them more efficient.
  • Two case studies of compressing compiled models are presented.
  • Compressed models allow us to study how real neural networks might compress features.
  • Superposition has not been studied in models deeper than two layers.

Gradient descent based compression

  • Use a single linear projection to compress the residual stream
  • Train the projection using SGD
  • Minimise loss under the constraint that it implements the same computation as the original model
  • Regularization term incentivises the compressed model to match the per-layer outputs of the original model
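The objective can be sketched as follows. This is only an evaluation of the loss for a fixed projection, not the training loop, and several details are assumptions: layers read the residual stream through W^T and write through W, both loss terms use squared error, and `lam` weights the per-layer term. The single toy layer and the projection are made up for illustration.

```python
def vecmat(x, w):
    """Row vector times matrix."""
    return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]

def transpose(w):
    return [list(col) for col in zip(*w)]

def squared_error(a, b):
    return sum((p - q) ** 2 for p, q in zip(a, b))

def run_original(x, layers):
    """Residual stream of size D after each layer of the uncompressed model."""
    stream, outputs = list(x), []
    for layer in layers:
        stream = [s + u for s, u in zip(stream, layer(stream))]
        outputs.append(stream)
    return outputs

def run_compressed(x, layers, w):
    """Same layers, but the stream is stored in d dimensions via W (D x d)."""
    w_t = transpose(w)
    stream, outputs = vecmat(x, w), []
    for layer in layers:
        written = layer(vecmat(stream, w_t))           # read: d -> D
        update = vecmat(written, w)                    # write: D -> d
        stream = [s + u for s, u in zip(stream, update)]
        outputs.append(vecmat(stream, w_t))            # compare in D dimensions
    return outputs

def compression_loss(x, layers, w, lam=1.0):
    target, approx = run_original(x, layers), run_compressed(x, layers, w)
    out_loss = squared_error(target[-1], approx[-1])
    layer_loss = sum(squared_error(t, a) for t, a in zip(target, approx)) / len(layers)
    return out_loss + lam * layer_loss

# Toy model with D = 3: one layer that reads dimension 0 and writes to dimension 1.
layers = [lambda r: [0.0, 2.0 * r[0], 0.0]]
w_keep_two = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # d = 2, drops dimension 2

loss_unused = compression_loss([1.0, 0.0, 0.0], layers, w_keep_two)   # 0.0
loss_dropped = compression_loss([1.0, 0.0, 5.0], layers, w_keep_two)  # 50.0
```

Dropping a dimension that carries no information costs nothing, while dropping one that holds a value is penalised by both terms. In the paper the projection is learned rather than fixed by hand.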

What does the compression learn?

  • Model from Figure 2 computes fraction of token “x”
  • Embedding matrix reduces residual dimension from 14 to 6
  • The learned compression is compared to compressing the model with PCA
  • Interpreting the result relies on knowing which input tokens and variables are crucial for computing the fraction of tokens
  • Variables encode similar information

Do the compressed models still implement the same computation?

  • Compressed models can achieve low loss but need to be checked to see if they implement the same computation as the compiled models.
  • Evaluated average cosine similarity between output at each layer of the two models.
  • Compressed model can learn a different computation to be more efficient.
  • Compressed model not guaranteed to be faithful to the original RASP program.
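The faithfulness check can be sketched as an average of per-layer cosine similarities. The layer outputs below are made-up vectors, and the function names are illustrative.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def mean_layer_cosine(outputs_a, outputs_b):
    """Average cosine similarity between matching layer outputs of two models."""
    sims = [cosine(u, v) for u, v in zip(outputs_a, outputs_b)]
    return sum(sims) / len(sims)

# Made-up per-layer outputs for a compiled model and two compressed variants.
compiled = [[1.0, 0.0], [1.0, 2.0]]
faithful = [[2.0, 0.0], [0.5, 1.0]]   # same directions, different scale
different = [[0.0, 1.0], [1.0, 2.0]]  # same final output, different first layer

sim_faithful = mean_layer_cosine(compiled, faithful)    # close to 1.0
sim_different = mean_layer_cosine(compiled, different)  # close to 0.5
```

The second variant matches the compiled model at the final layer but not at the intermediate one, which is the situation a low final loss alone would not reveal.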

Discussion

  • Open-source implementation of Tracr available
  • Potential applications in interpretability research
  • Limitations of Tracr and how to address them

Applications of compiled models in interpretability research

  • Compilers like Tracr allow researchers to test hypotheses about computational structure of transformers
  • Compiled models can be used to test faithfulness of explanations given by interpretability techniques
  • Compiled models can be used to build libraries of test cases for interpretability tools
  • Compiled models can be used to evaluate understanding of how a model works by replacing parts of the model with hand-coded components

Limitations of RASP and Tracr

  • RASP and Tracr are limited in terms of expressivity, efficiency and realism compared to real transformer models
  • RASP is designed for algorithmic tasks that map an input sequence to a discrete output sequence
  • Current language models usually map a sequence of input tokens to a probability distribution over the next token
  • RASP only uses binary attention patterns
  • Tracr models store all variables in orthogonal subspaces of the residual stream
  • Tracr constructs layers from hand-coded parameter matrices
  • Tracr models align their features with the computational basis
  • Compiled models are an intermediate step between very simple toy models and real learned models
  • Compiled models will always be simpler than real models

Conclusion

  • Proposed manually constructing neural network weights and using them to develop and evaluate new interpretability tools
  • Developed Tracr, a tool for compiling human-readable code to the weights of a transformer model
  • Outlined vision for use of compiled models in interpretability

Appendix

  • Described how to construct MLP and attention blocks
  • Implemented selector width primitive
  • Extended RASP and Tracr to use causal attention
  • Implemented compression in Jax on top of Haiku transformer
  • Trained using AdamW optimizer
  • Studied how to implement combinations of selectors with more than two inputs
  • Interpreted matrix as a bilinear operator
  • Lemma 1: No ∧ operating over ( ⊕ ) × ( ⊕ )
  • Lemma 2: Construct ∧ operating over ( ⊗ ) × ( ⊗ )

Figures

  • Figure 1: Tracr allows us to create models that implement a known mechanism
  • Figure 2: Example RASP program that computes fraction of previous “x” tokens
  • Figure 3: Tracr translates RASP to craft and then to model weights
  • Figure 4: Schematic overview of how Tracr compiles frac_prevs program
  • Figure 6: Training setup for compressing a compiled transformer model
  • Figure 7: Loss of compressed Tracr models for frac_prevs program