Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Interpretability research aims to build tools to understand ML models.
- We propose to build transformer models manually as a testbed for interpretability research.
- Tracr is a “compiler” that translates human-readable programs into weights of a transformer model.
- Tracr is used to create ground truth transformers that implement programs such as computing token frequencies, sorting, and Dyck-n parenthesis checking.
- Tracr is open-source and available at https://github.com/deepmind/tracr.
Paper Content
Introduction
- Deep learning models are becoming more capable and increasingly deployed in production.
- Mechanistic interpretability aims to understand how they make decisions.
- Cammarata et al. (2020) explain a range of specific circuits in InceptionV1 (Szegedy et al., 2015).
- Elhage et al. (2021) and Wang et al. (2022) achieve early success in interpreting transformer language models.
- The toolbox of approaches for generating mechanistic explanations is small and poorly understood.
- Evaluating mechanistic explanations requires creativity and effort by researchers.
- The standard approach for evaluating mechanistic explanations combines evidence from many ad-hoc experiments.
- Tracr is a proof-of-concept implementation of a compiler to create models which perform nontrivial computation with a known implementation.
- Tracr can be used to evaluate interpretability tools by comparing the resulting explanation to the ground truth.
Background
- Tracr is a computer program that compiles a program written in the RASP programming language to a transformer model.
- The compiled model for the input sequence “xacx” is shown with a full residual stream at each layer.
- Section 4 of the paper discusses examples of the compiled model in more detail.
Transformer models
- Transformer model consists of alternating MHA and MLP layers with residual connections
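- MHA computes attention maps on sequences of length $N$; a single head $h$ computes the attention pattern $A^h = \mathrm{softmax}\left((x W_Q^h)(x W_K^h)^\top / \sqrt{d_k}\right)$ for an input $x \in \mathbb{R}^{N \times d}$
- MHA combines $H$ attention heads by computing $\mathrm{MHA}(x) = \mathrm{Concat}[A^1 x W_V^1, \dots, A^H x W_V^H]\, W_O$
- MLP layers compute $\mathrm{MLP}(x) = \sigma(x W_1) W_2$, where $\sigma$ is a nonlinearity such as GELU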
- Decoder-only transformer with GPT architecture consists of MHA, MLP, and layer normalization
- Input to model is sum of learned embedding of sequence of input tokens and positional embedding
- Model is trained to predict next token using gradient descent
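The layer structure described above can be sketched in plain Python. This is a minimal illustration, not the paper's implementation: it uses a single attention head, omits layer normalization, and uses ReLU in the MLP as an assumption. All function names here are hypothetical.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

def attention_head(x, w_q, w_k, w_v):
    """One head: pattern = softmax(Q K^T / sqrt(d_k)), output = pattern V."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(w_q[0])
    scores = matmul(q, transpose(k))
    pattern = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return pattern, matmul(pattern, v)

def mlp(x, w_1, w_2):
    """Two-layer MLP with a ReLU nonlinearity (an assumption of this sketch)."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(x, w_1)]
    return matmul(hidden, w_2)

def transformer_block(x, w_q, w_k, w_v, w_o, w_1, w_2):
    """Attention then MLP, each added back through a residual connection."""
    _, head_out = attention_head(x, w_q, w_k, w_v)
    x = add(x, matmul(head_out, w_o))
    return add(x, mlp(x, w_1, w_2))
```

Each row of the returned attention pattern sums to 1, and both sublayers add their output to the input rather than replacing it.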
Transformer circuits
- Transformers have residual connections at each attention and MLP layer
- Elhage et al. (2021) consider the residual connections a core feature of the architecture
- Transformers are described in terms of a residual stream that each layer reads from and writes to in sequence
- MHA is parameterised by low-rank matrices and acts as a bilinear operator
- Softmax is the only nonlinearity in an attention head
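The residual-stream view can be written out as follows. This is a sketch in the style of Elhage et al. (2021), with assumptions of its own: inputs are row vectors $x \in \mathbb{R}^{N \times d}$, $W_O^h$ denotes the slice of the output matrix belonging to head $h$, and the $1/\sqrt{d_k}$ scaling is absorbed into $W_{QK}^h$.

```latex
% Residual stream: each layer reads x_l and adds its output back.
x_{l+1} = x_l + \mathrm{Layer}_l(x_l)

% One attention head as a bilinear form in x, with merged low-rank matrices.
A^h = \mathrm{softmax}\!\left( x \, W_{QK}^h \, x^\top \right), \qquad
W_{QK}^h = W_Q^h (W_K^h)^\top, \qquad
W_{OV}^h = W_V^h W_O^h

% Heads contribute additively to the residual stream.
\mathrm{MHA}(x) = \sum_{h=1}^{H} A^h \, x \, W_{OV}^h
```

Because $W_{QK}^h$ and $W_{OV}^h$ are products of $d \times d_k$ and $d_k \times d$ factors, their rank is at most $d_k$, which is the sense in which the head is parameterised by low-rank matrices.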
The RASP programming language
- RASP is a domain-specific language for expressing transformer computations
- Weiss et al. (2021) proposed RASP as a computational model and provided an interpreter
- RASP programs can be seen as a computational graph
- There are two basic node types: sequence operations and selectors
- Elementwise operations correspond to MLP layers in transformers
- Select-aggregate operations correspond to attention in transformers
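To make the two node types concrete, here is a toy interpreter for select and aggregate in plain Python. It is a sketch and not the RASP interpreter or the Tracr API; the function names and the convention that predicates take `(key, query)` are assumptions of this example. It ends with a version of the fraction-of-previous-"x" program applied to the input "xacx".

```python
def select(keys, queries, predicate):
    """Selector: entry [q][k] is True when predicate(keys[k], queries[q]) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]

def aggregate(selector, values, default=0):
    """For each query position, average the values at the selected key positions."""
    out = []
    for row in selector:
        picked = [v for v, on in zip(values, row) if on]
        out.append(sum(picked) / len(picked) if picked else default)
    return out

def frac_prevs(tokens, target):
    """Fraction of tokens up to and including each position that equal target."""
    indices = list(range(len(tokens)))
    # Each position selects itself and all earlier positions.
    prevs = select(indices, indices, lambda k, q: k <= q)
    # Elementwise operation: 1 where the token matches, else 0.
    is_target = [1 if t == target else 0 for t in tokens]
    return aggregate(prevs, is_target)

result = frac_prevs(list("xacx"), "x")
# result is [1.0, 0.5, 1/3, 0.5]
```

The elementwise comparison plays the role of an MLP layer, and the select-aggregate pair plays the role of an attention layer that averages over the selected positions.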
Tracr: a transformer compiler for RASP
- Describes how RASP maps to transformer architecture
- Proposes modifications to RASP to make mapping more straightforward
- Introduces craft, an “assembly language” for transformer models
- Describes how Tracr translates RASP programs to transformer weights
- Full open-source implementation of Tracr available on GitHub
Mapping RASP to transformers
- RASP provides a computational model of transformers
- RASP operations can be mapped to components of a transformer model
- Tokens and positions can be embedded as categorical variables in orthogonal subspaces
- MLP layers can approximate any function, with accuracy depending on their width and depth
- RASP’s select-aggregate operations map to attention layers in transformer models
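The idea of embedding tokens and positions as categorical variables in orthogonal subspaces can be illustrated with one-hot vectors. This is a minimal sketch; the `name:value` dimension labels and the helper names are assumptions made for this example.

```python
def build_basis(vocab, max_seq_len):
    """Label one residual dimension per token value and one per position."""
    return ([f"tokens:{t}" for t in vocab]
            + [f"indices:{i}" for i in range(max_seq_len)])

def embed(token, position, basis):
    """Sum of a one-hot token embedding and a one-hot positional embedding."""
    vec = [0.0] * len(basis)
    vec[basis.index(f"tokens:{token}")] = 1.0
    vec[basis.index(f"indices:{position}")] = 1.0
    return vec

def read(vec, basis, name):
    """Recover a categorical variable by looking only at its own subspace."""
    for label, value in zip(basis, vec):
        if label.startswith(name + ":") and value == 1.0:
            return label.split(":", 1)[1]
    return None

basis = build_basis(["a", "c", "x"], max_seq_len=4)
vec = embed("x", 3, basis)
```

Because the token and position subspaces share no dimensions, either variable can be read back without interference from the other.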
Modifications to RASP
- RASP language needs to be modified to allow translating it to model weights
- Disallow arbitrary selector combinations
- Restrict RASP to selectors with only two input variables
- Encode categorical and numerical variables in dedicated subspaces of the residual stream
- Require each s-op to be either categorical or numerical
- Make BOS token mandatory in RASP
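One way to see why a mandatory BOS token helps: softmax attention must put its weight somewhere, even when a selector matches no position. The sketch below gives BOS an intermediate score so that it absorbs the attention only when nothing else is selected. The specific scores (1, 0.5, 0) and the scaling factor are illustrative assumptions, not values taken from the paper.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_with_bos(selector_row, scale=100.0):
    """Approximate a binary selector row with softmax attention.

    Returns weights over [BOS] + positions. Selected positions score 1,
    BOS scores 0.5, and unselected positions score 0, all multiplied by scale.
    """
    scores = [0.5] + [1.0 if on else 0.0 for on in selector_row]
    return softmax([scale * s for s in scores])

nothing_selected = attention_with_bos([False, False])
two_selected = attention_with_bos([True, False, True])
```

With no match, almost all weight falls on BOS; with two matches, the weight splits almost evenly between them and BOS receives almost none.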
Craft: an assembly language for transformers
- Craft represents vector spaces with labelled basis dimensions and operations on them.
- Craft abstracts away the need to keep track of padding in weight matrices.
- Craft models are independent of concrete transformer implementations.
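The benefit of labelled basis dimensions can be sketched with a sparse linear map keyed by dimension labels. This is a hypothetical illustration of the idea and not the craft API; the class and method names are assumptions.

```python
class LabelledLinear:
    """A linear map stored sparsely as {(input_label, output_label): weight}."""

    def __init__(self, weights):
        self.weights = dict(weights)

    def apply(self, vector):
        """Apply the map to a vector given as {label: value}."""
        out = {}
        for (src, dst), w in self.weights.items():
            out[dst] = out.get(dst, 0.0) + w * vector.get(src, 0.0)
        return out

    def to_matrix(self, basis):
        """Materialise a dense matrix over a full basis.

        Rows are input dimensions and columns are output dimensions;
        entries the map does not mention become zero padding.
        """
        return [[self.weights.get((src, dst), 0.0) for dst in basis]
                for src in basis]

# A component that copies the "tokens:x" indicator into a new "is_x" dimension.
mark_x = LabelledLinear({("tokens:x", "is_x"): 1.0})
dense = mark_x.to_matrix(["tokens:x", "indices:0", "is_x"])
```

The component is defined only in terms of the labels it touches, and the padding needed to fit it into a larger residual stream appears only when the dense matrix is built.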
Compiler overview
- Tracr is a compiler implemented in Python
- Tracr takes RASP programs as input
- Tracr applies basic simplifications to RASP programs before compiling them
- Tracr translates RASP programs to transformer weights in 6 steps
- Tracr is based on the example decoder-only transformer from Haiku
Exploring compiled transformers
- With the compiler in place, models can now be compiled and inspected
- Examples are provided in Appendix D
- Programs were compiled for tasks in Weiss et al. (2021)
- Programs were modified to use features supported by Tracr
Example 1: counting tokens
- The input space of the compiled model includes a dedicated dimension that is fixed to 1
- This dimension is used as a constant to add biases in MLP layers
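The role of a dimension fixed to 1 can be shown with a bias-free linear layer: the weights attached to the constant dimension act as the bias. This is a generic sketch with made-up numbers, not weights from a compiled model.

```python
def linear(x, w):
    """Bias-free linear layer: x is a vector, w has one row per input dimension."""
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*w)]

# Input layout (an assumption of this sketch): [value, one], with one fixed to 1.
weights = [
    [2.0],  # weight on the value dimension
    [3.0],  # weight on the constant dimension, which acts as the bias
]
y = linear([4.0, 1.0], weights)  # computes 2 * 4 + 3
```

Keeping the bias inside the weight matrix means every layer can be described as a pure linear map on the residual stream.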
Example 2: sorting
- Figure 5 shows a program that sorts a sequence of unique tokens
- selector_width is a primitive that compiles directly to an attention layer and an MLP layer
- Weiss et al. (2021) propose a sort program that can handle duplicates
- An alternative implementation of sort is provided in Appendix D that handles duplicates by adding a small multiple of indices to the keys
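The sorting idea can be sketched in plain Python: the width of a "smaller than" selector gives each key its target position, and a second select moves each value there. The helper names and the choice of `index / len(keys)` as the tie-breaking offset are assumptions of this example (the offset stays below 1, so it only works as written for integer keys).

```python
def select(keys, queries, predicate):
    """Selector: entry [q][k] is True when predicate(keys[k], queries[q]) holds."""
    return [[predicate(k, q) for k in keys] for q in queries]

def selector_width(selector):
    """Number of selected key positions for each query position."""
    return [sum(row) for row in selector]

def sort_unique(keys, values=None):
    """Sort values by keys, assuming all keys are distinct."""
    values = keys if values is None else values
    # Target position of each element = number of strictly smaller keys.
    smaller = select(keys, keys, lambda k, q: k < q)
    target_pos = selector_width(smaller)
    # Output position i picks the element whose target position is i.
    indices = list(range(len(keys)))
    pick = select(target_pos, indices, lambda k, q: k == q)
    return [values[row.index(True)] for row in pick]

def sort_with_duplicates(keys):
    """Break ties by adding a small multiple of the index to each key."""
    n = len(keys)
    adjusted = [k + i / n for i, k in enumerate(keys)]
    return sort_unique(adjusted, values=keys)

sorted_unique = sort_unique([3, 1, 2])
sorted_dupes = sort_with_duplicates([2, 1, 2, 1])
```

With duplicate keys, `sort_unique` alone would assign two elements the same target position and leave another position empty, which is why the adjusted keys are needed.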
More examples
- Tracr can compile a range of RASP programs beyond the two examples above
- Appendix D discusses more examples, including checking balanced parentheses
- Tracr implementation contains a library of example programs
Compressing compiled transformers
- Tracr models can be sparse and inefficient.
- An experimental approach is proposed to make them more efficient.
- Two case studies of compressing compiled models are presented.
- Compressed models allow us to study how real neural networks might compress features.
- Superposition has not been studied in models deeper than two layers.
Gradient descent based compression
- Use a single linear projection to compress the residual stream
- Train the projection using SGD
- Minimise loss under the constraint that it implements the same computation as the original model
- Regularization term incentivises the compressed model to match the per-layer outputs of the original model
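The training objective described above can be written roughly as follows. The notation is an assumption of this sketch: $W \in \mathbb{R}^{D \times d}$ is the learned projection with $d < D$, $h_i(x)$ is the output of layer $i$ in the original model, and $\hat{h}_{W,i}(x)$ is the corresponding output in the compressed model.

```latex
% The compressed model applies W when writing to the residual stream
% and W^\top when reading from it.
\mathcal{L}(W, x) = \mathcal{L}_{\mathrm{out}}(W, x) + \mathcal{L}_{\mathrm{layer}}(W, x)

% Regularization term: match the per-layer outputs of the original model.
\mathcal{L}_{\mathrm{layer}}(W, x) = \sum_{i} \left\lVert h_i(x) - \hat{h}_{W,i}(x) \right\rVert^2
```

Here $\mathcal{L}_{\mathrm{out}}$ measures how far the compressed model's final output is from the original model's output, and only $W$ is trained while the compiled weights stay fixed.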
What does the compression learn?
- Model from Figure 2 computes fraction of token “x”
- Embedding matrix reduces residual dimension from 14 to 6
- The learned compression is compared with using PCA to compress the model
- Input tokens, variables crucial for computing fraction of tokens
- Variables encode similar information
Do the compressed models still implement the same computation?
- Compressed models can achieve low loss but need to be checked to see if they implement the same computation as the compiled models.
- Evaluated average cosine similarity between output at each layer of the two models.
- Compressed model can learn a different computation to be more efficient.
- Compressed model not guaranteed to be faithful to the original RASP program.
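The faithfulness check can be sketched as an average cosine similarity per layer. This is a minimal illustration: it assumes the outputs of both models have already been mapped into the same space, and the nested-list layout (layers, then positions, then vector entries) is a choice made for this example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def mean_layer_similarity(original_layers, compressed_layers):
    """Average cosine similarity over positions, reported separately per layer."""
    per_layer = []
    for orig, comp in zip(original_layers, compressed_layers):
        sims = [cosine_similarity(o, c) for o, c in zip(orig, comp)]
        per_layer.append(sum(sims) / len(sims))
    return per_layer

# Two layers, two positions each: layer 0 matches, layer 1 is orthogonal.
original = [[[1.0, 0.0], [0.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]]]
compressed = [[[2.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
similarities = mean_layer_similarity(original, compressed)
```

A compressed model can reach low loss on the final output while scoring poorly on intermediate layers, which signals that it has learned a different computation.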
Discussion
- Open-source implementation of Tracr available
- Potential applications in interpretability research
- Limitations of Tracr and how to address them
Applications of compiled models in interpretability research
- Compilers like Tracr allow researchers to test hypotheses about computational structure of transformers
- Compiled models can be used to test faithfulness of explanations given by interpretability techniques
- Compiled models can be used to build libraries of test cases for interpretability tools
- Compiled models can be used to evaluate understanding of how a model works by replacing parts of the model with hand-coded components
Limitations of RASP and Tracr
- RASP and Tracr are limited in terms of expressivity, efficiency and realism compared to real transformer models
- RASP is designed for algorithmic tasks that map an input sequence to a discrete output sequence
- Current language models usually map a sequence of input tokens to a probability distribution over the next token
- RASP only uses binary attention patterns
- Tracr models store all variables in orthogonal subspaces of the residual stream
- Tracr constructs layers from hand-coded parameter matrices
- Tracr models align their features with the computational basis
- Compiled models are an intermediate step between very simple toy models and real learned models
- Compiled models will always be simpler than real models
Conclusion
- Proposed manually constructing neural network weights and using them to develop and evaluate new interpretability tools
- Developed Tracr, a tool for compiling human-readable code to the weights of a transformer model
- Outlined vision for use of compiled models in interpretability
Appendix
- Described how to construct MLP and attention blocks
- Implemented selector width primitive
- Extended RASP and Tracr to use causal attention
- Implemented compression in Jax on top of Haiku transformer
- Trained using AdamW optimizer
- Studied how to implement combinations of selectors with more than two inputs
- Interpreted matrix as a bilinear operator
- Lemma 1: No ∧ operating over ( ⊕ ) × ( ⊕ )
- Lemma 2: Construct ∧ operating over ( ⊗ ) × ( ⊗ )
Figures
- Figure 1: Tracr allows us to create models that implement a known mechanism
- Figure 2: Example RASP program that computes fraction of previous “x” tokens
- Figure 3: Tracr translates RASP to craft and then to model weights
- Figure 4: Schematic overview of how Tracr compiles frac_prevs program
- Figure 6: Training setup for compressing a compiled transformer model
- Figure 7: Loss of compressed Tracr models for frac_prevs program