Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transformers are the most popular models in machine learning and have achieved impressive performance.
- The theoretical understanding of transformer building blocks is limited.
- Dense Associative Memory models have a well-established theoretical foundation but have not achieved impressive practical results.
- Energy Transformer (ET) replaces the sequence of feedforward transformer blocks with a single large Associative Memory model.
- ET uses many of the familiar architectural primitives of the current generation of transformers.
- ET is designed to minimize a specifically engineered energy function.
- ET’s attention differs from the conventional attention mechanism.
- This paper introduces the theoretical foundations of ET, explores its empirical capabilities, and obtains strong quantitative results on the graph anomaly detection task.
Paper Content
Introduction
- Transformers are used in machine learning for language, vision, and audio processing.
- Transformers use four operations: attention, MLP, residual connection, and layer normalization.
- Variations of transformers are created by combining these operations in different ways.
- The search for new transformer architectures is mostly empirical.
- Hopfield Networks have been gaining popularity in the machine learning community.
- Both Transformers and Hopfield Networks are designed to denoise their input.
- Transformers are pre-trained on a masked-token task.
- Hopfield Networks are trained to complete incomplete patterns.
- There are similarities and differences between Transformers and Hopfield Networks.
- An Energy Transformer (ET) is proposed to route information between tokens.
- ET uses an energy function tailored to the problem of interest.
- ET is used for image completion and graph anomaly detection.
- ET performs on par with or better than recent methods on these benchmarks.
Energy transformer block
- ET network replaces the usual sequence of transformer blocks with a single recurrent ET block
- Input image is split into non-overlapping patches
- Each patch is represented by a D-dimensional vector
- ET block is described by a continuous time differential equation
- Goal is to allow masked particles to find an identity consistent with their locations
- Dynamical evolution is designed to minimize a global energy function
- Energy function is guided by two pieces of information: identities of open particles and general knowledge about possible patches
Layer norm
- Each token is represented by a vector in a D-dimensional space.
- ET block operations are defined using a normalized token representation.
- The scale γ and the bias δ are learnable parameters; ε is a small regularization constant.
- This operation can be viewed as an activation function for neurons.
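The normalization described above can be sketched as follows. This is a minimal illustration assuming the standard layer norm form; γ, δ, and ε appear as plain function arguments, and the function name is hypothetical rather than taken from the paper's code.

```python
import math

def layer_norm(x, gamma=1.0, delta=None, eps=1e-5):
    """Normalize one token vector x (length D) to zero mean and unit variance."""
    d = len(x)
    if delta is None:
        delta = [0.0] * d                      # bias vector, one entry per feature
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d  # variance across the D features
    scale = gamma / math.sqrt(var + eps)       # eps guards against division by zero
    return [scale * (v - mean) + b for v, b in zip(x, delta)]

# One token with D = 4 features (made-up values).
g = layer_norm([1.0, 2.0, 3.0, 4.0])
```

The output `g` plays the role of the normalized token representation on which the ET block operations are defined.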
Multi-head energy attention
- ET’s energy function exchanges information between particles (patches).
- ET has a query and key matrix, but no separate value matrix.
- The goal of the energy based attention is to evolve the tokens so that the keys of open patches align with the queries of masked patches.
- The energy function is minimal when the queries of each patch are aligned with the keys of a small number of other patches.
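A small sketch can make the alignment idea concrete. It assumes the attention energy is a log-sum-exp over the keys of all other tokens, summed over heads and query tokens; the function names, the weight layout (one Y x D matrix per head), and the toy inputs are illustrative assumptions, not the authors' implementation.

```python
import math

def matvec(w, v):
    """Multiply a Y x D matrix (list of rows) by a length-D vector."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

def attention_energy(g, w_q, w_k, beta=1.0):
    """g: N >= 2 normalized tokens of length D; w_q, w_k: one Y x D matrix per head."""
    energy = 0.0
    for wq_h, wk_h in zip(w_q, w_k):
        queries = [matvec(wq_h, tok) for tok in g]
        keys = [matvec(wk_h, tok) for tok in g]
        for c, q in enumerate(queries):
            # Scores of query c against the keys of every other token (b != c).
            scores = [beta * sum(k_a * q_a for k_a, q_a in zip(k, q))
                      for b, k in enumerate(keys) if b != c]
            top = max(scores)                  # subtract the max for numerical stability
            lse = top + math.log(sum(math.exp(s - top) for s in scores))
            energy -= lse / beta
    return energy

identity = [[[1.0, 0.0], [0.0, 1.0]]]          # one head, Y = D = 2 (hypothetical weights)
aligned = [[1.0, 0.0], [1.0, 0.0]]             # the query of each token matches the other's key
orthogonal = [[1.0, 0.0], [0.0, 1.0]]          # queries and keys do not overlap
e_aligned = attention_energy(aligned, identity, identity)
e_orthogonal = attention_energy(orthogonal, identity, identity)
```

In this toy setting the aligned pair has the lower energy. Note that there is no value matrix anywhere in the computation.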
Hopfield network module
- Hopfield Network (HN) ensures token representations are consistent with realistic images
- Energy of HN sub-block is defined by a set of learnable weights and an activation function
- Depending on activation function, HN can be viewed as classical or modern continuous Hopfield Network
- HN module is an MLP with shared weights applied recurrently
- Energy contribution of HN is low when token representations are aligned with memories
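As an illustration of the last point, the sketch below assumes a ReLU activation, for which the energy of each token is minus one half of the squared rectified overlap with each memory. The function name and the toy memories are hypothetical.

```python
def hopfield_energy(g, memories):
    """g: N normalized tokens of length D; memories: K learnable patterns of length D."""
    energy = 0.0
    for tok in g:
        for xi in memories:
            overlap = sum(m * v for m, v in zip(xi, tok))  # alignment of token and memory
            energy -= 0.5 * max(overlap, 0.0) ** 2         # ReLU keeps positive overlaps only
    return energy

memories = [[1.0, 0.0], [0.0, 1.0]]                        # two toy memories (made-up values)
e_match = hopfield_energy([[2.0, 0.0]], memories)          # token aligned with the first memory
e_miss = hopfield_energy([[-2.0, 0.0]], memories)          # token anti-aligned with it
```

Here `e_match` is lower than `e_miss`, so descending this energy pulls a token toward the stored memories. A more sharply peaked activation would replace the squared ReLU term.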
Dynamics of token updates
- ET network is described by a continuous time differential equation
- Minimizes sum of two energies
- First energy is low when queries are aligned with neighbors’ keys
- Second energy is low when patch has content consistent with expectations
- Dynamical system finds trade-off between two desirable properties
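- Temporal derivative of the energy is non-positive, because the matrix relating token updates to energy changes is positive semi-definite, so the energy never increases along the dynamics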
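The discretized dynamics can be sketched as repeated steps of the form x ← x − α ∂E/∂g with g = LayerNorm(x). The sketch below is a toy under stated assumptions: the gradient is taken by finite differences (a JAX implementation would use automatic differentiation), the layer norm has no learnable scale or bias, and a single-memory Hopfield energy stands in for the full sum of attention and Hopfield energies. All names and values are hypothetical.

```python
import math

def layer_norm(x, eps=1e-5):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def numerical_grad(energy_fn, g, h=1e-5):
    """Central-difference gradient of energy_fn with respect to every entry of g."""
    grad = [[0.0] * len(tok) for tok in g]
    for a, tok in enumerate(g):
        for i in range(len(tok)):
            old = tok[i]
            tok[i] = old + h
            e_plus = energy_fn(g)
            tok[i] = old - h
            e_minus = energy_fn(g)
            tok[i] = old                       # restore the entry before moving on
            grad[a][i] = (e_plus - e_minus) / (2 * h)
    return grad

def descend(x, energy_fn, alpha=0.1, steps=10):
    """Iterate x <- x - alpha * dE/dg; return the final tokens and the energy trace."""
    x = [list(tok) for tok in x]
    trace = []
    for _ in range(steps):
        g = [layer_norm(tok) for tok in x]
        trace.append(energy_fn(g))
        grad = numerical_grad(energy_fn, g)
        x = [[v - alpha * dv for v, dv in zip(tok, dtok)]
             for tok, dtok in zip(x, grad)]
    trace.append(energy_fn([layer_norm(tok) for tok in x]))
    return x, trace

memory = [1.0, -1.0, 0.0]                      # single stored pattern (made-up values)

def toy_energy(g):
    return sum(-0.5 * max(sum(m * v for m, v in zip(memory, tok)), 0.0) ** 2
               for tok in g)

x_final, trace = descend([[0.5, 0.0, -0.5]], toy_energy)
```

In this toy run every entry of `trace` is lower than the previous one. For the full energy, a small enough step size is needed for the discrete updates to follow the continuous-time descent.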
Relationship to Modern Hopfield Networks and conventional attention
- Design of energy attention mechanism and corresponding energy function
- Energy function similar to continuous Hopfield Network with softmax activation
- Keys in Modern Hopfield Networks are constant parameters, whereas in the energy attention network they are dynamical variables
- ET attention contribution to update dynamics given by two terms, one conventional and one new
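The two terms can be seen by differentiating a log-sum-exp attention energy directly. The derivation below is a sketch that assumes this form of the energy, with keys and queries computed from the normalized tokens g; the shorthand p for the softmax weights is introduced here for readability and is not the paper's notation.

```latex
\[
E^{\mathrm{ATT}} = -\frac{1}{\beta} \sum_{h} \sum_{C} \log \sum_{B \neq C}
  \exp\Big(\beta \sum_{\alpha} K_{\alpha h B}\, Q_{\alpha h C}\Big),
\qquad
K_{\alpha h B} = \sum_{j} W^{K}_{\alpha h j}\, g_{jB},
\quad
Q_{\alpha h C} = \sum_{j} W^{Q}_{\alpha h j}\, g_{jC}
\]

\[
p^{h}_{B \mid C} =
  \frac{\exp\big(\beta \sum_{\alpha} K_{\alpha h B}\, Q_{\alpha h C}\big)}
       {\sum_{B' \neq C} \exp\big(\beta \sum_{\alpha} K_{\alpha h B'}\, Q_{\alpha h C}\big)}
\]

\[
-\frac{\partial E^{\mathrm{ATT}}}{\partial g_{iA}} =
  \sum_{h} \sum_{C \neq A} \sum_{\alpha}
  \Big[\, W^{Q}_{\alpha h i}\, K_{\alpha h C}\; p^{h}_{C \mid A}
      \;+\; W^{K}_{\alpha h i}\, Q_{\alpha h C}\; p^{h}_{A \mid C} \,\Big]
\]
```

The first term arises because g_{iA} enters the query of token A: it is a softmax-weighted sum over the keys of other tokens, which resembles conventional attention with the keys also acting as values. The second term arises because g_{iA} also enters the key of token A, which every other token's query attends to; this is the term with no counterpart in conventional attention.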
Qualitative inspection of the ET framework on ImageNet
- Trained ET network on masked image completion task using ImageNet-1k dataset
- Images broken into non-overlapping patches of 16x16 RGB pixels
- Half of tokens “masked” and a distinct learnable position encoding vector added to each token
- ET block processes tokens recurrently for T steps
- Loss function is MSE loss on occluded patches
- Model learns to perform task well, but struggles to understand global structure
- Position encoding vector of each token is highly similar to those of other patches in the same row or column
- ET block evolves tokens within a single token space, from which the final fixed-point representation can be decoded back into the image plane
- Visualize weights of HN module directly in image plane
- Visualize gradients of energy function of ATTN and HN blocks
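The data preparation and loss described at the start of this section can be sketched as follows. This is an illustration only: it uses nested lists rather than arrays, assumes the image height and width are multiples of the patch size, and all function names are hypothetical.

```python
import random

def patchify(image, patch):
    """Split an H x W x C image (nested lists) into flattened non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            flat = [channel
                    for row in image[top:top + patch]
                    for pixel in row[left:left + patch]
                    for channel in pixel]
            patches.append(flat)
    return patches

def choose_mask(num_patches, fraction=0.5, seed=0):
    """Pick the indices of the patches to occlude."""
    rng = random.Random(seed)
    return set(rng.sample(range(num_patches), int(num_patches * fraction)))

def masked_mse(original, reconstructed, masked):
    """Mean squared error computed over the occluded patches only."""
    errors = [(o - r) ** 2
              for idx in masked
              for o, r in zip(original[idx], reconstructed[idx])]
    return sum(errors) / len(errors)

# Tiny 4 x 4 single-channel image with 2 x 2 patches (made-up values).
image = [[[float(r * 4 + c)] for c in range(4)] for r in range(4)]
patches = patchify(image, 2)
masked = choose_mask(len(patches))
loss = masked_mse(patches, patches, masked)    # a perfect reconstruction gives zero loss
```

With 16x16 RGB patches, each flattened patch has 16 * 16 * 3 = 768 values before it is projected into the token space.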
Graph anomaly detection
- ET network is used to evaluate performance on graph anomaly detection
- Anomalies are outliers that deviate from majority of samples
- Three types of graph anomalies: node, edge, subgraph
- Focused on node anomaly detection in attributed graphs
- Node attributes encoded in latent space and treated as token
- Graph Convolutional Networks (GCNs) are commonly used for this task
- Vanilla GCNs suffer from over-smoothing problem
- ET network uses energy based attention to route information between nodes
- Task is to predict label of node given graph structure and node’s features
- Anomaly detection is an imbalanced node classification task
- Feature vectors converted to token representation using linear embedding and positional embedding
- Output of network fed into MLP with sigmoid activation to compute anomaly probabilities
- Weighted cross entropy used to train network
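The last two steps can be sketched as a class-weighted binary cross entropy on sigmoid outputs. The weighting scheme here, a single `pos_weight` multiplier on the anomalous class, is an assumption for illustration; the summary does not specify how the weights are chosen.

```python
import math

def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def sigmoid(z):
    """Anomaly probability for one node, in a numerically stable form."""
    return math.exp(-softplus(-z))

def weighted_bce(logits, labels, pos_weight):
    """Mean binary cross entropy; anomalous nodes (label 1) are scaled by pos_weight."""
    total = 0.0
    for z, y in zip(logits, labels):
        # -log(sigmoid(z)) = softplus(-z) and -log(1 - sigmoid(z)) = softplus(z)
        total += pos_weight * y * softplus(-z) + (1 - y) * softplus(z)
    return total / len(labels)

# Three normal nodes and one anomaly, with hypothetical logits.
loss = weighted_bce([-2.0, -1.5, -3.0, 0.5], [0, 0, 0, 1], pos_weight=3.0)
```

Raising `pos_weight` makes a missed anomaly cost more than a false alarm, which counteracts the class imbalance.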
Experimental evaluation
- Four datasets used for graph anomaly detection experiments
- Graph treated as homogeneous graph, feature vector associated with each node
- Task is to predict label (anomaly status) of nodes
- Compare with state-of-the-art approaches for graph anomaly detection
- Evaluation metrics used are macro-F1 score and AUC
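Macro-F1 averages the per-class F1 scores with equal weight, which is why it suits imbalanced tasks. A minimal sketch, with hypothetical function names and made-up labels:

```python
def f1_for_class(y_true, y_pred, cls):
    """F1 score that treats cls as the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores over the classes in y_true."""
    classes = sorted(set(y_true))
    return sum(f1_for_class(y_true, y_pred, c) for c in classes) / len(classes)

# Nine normal nodes and one anomaly; the classifier predicts "normal" everywhere.
y_true = [0] * 9 + [1]
y_pred = [0] * 10
score = macro_f1(y_true, y_pred)
```

This classifier is 90% accurate, yet its macro-F1 is below 0.5 because the anomaly class scores zero.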
Discussion and conclusions
- Recent research has focused on understanding the analogy between Hopfield Networks and the attention mechanism in transformers.
- The transformer block can be viewed as a single large Hopfield Network.
- A novel energy function has been designed for dynamical information routing between tokens and representations.
- The attention mechanism contains an extra term compared to conventional attention.
- The network has been tested on image completion and node anomaly detection tasks.
Reproducibility statement
- Experiments were conducted to ensure reproducibility of results
- Training protocols and implementation details are described in Appendices A, B, and F
- Model and training code for images can be found at a given link
- Code for image reconstruction is written in JAX
- Training script sets a seed to recreate the same training setup
- No additional training data was used beyond the training set
- Code for Graph Anomaly Detection is written in PyTorch
- Results on graphs are reported with mean and standard deviations
- ET architecture can be built with HAMUX, a JAX-based Deep Learning library
Appendix A: Details of training on ImageNet
- Trained ET network on masked-image completion task on ImageNet-1k dataset
- Images split into non-overlapping patches of 16x16 RGB pixels
- Of the 100 patches selected for occlusion, 90 are replaced with a learnable MASK token and 10 are left untouched
- Tokens passed to Energy Transformer block recurred for T steps
- MSE Loss between original pixels and reconstructed pixels for 100 occluded patches
- Self attention formula for energy of multiheaded attention
- Step size of 0.1 provides smoother descent down energy function
- MSE loss must include some subset of un-occluded patches for HN to learn meaningful filters
- Setting the β parameter of the energy attention too high prevents the model from training
- Gradient clipping helps model train faster and at higher learning rates
- Features mapped into token space with linear projection and learnable positional embeddings
- Attention operation restricted to neighborhood of given node
- Forward pass of ET-block minimizes energy function
- Anomaly detection treated as semi-supervised learning task
- Adam optimizer with learning rate of 0.001 used to train models
- Ablation studies show that the attention (ATT) block performs most of the computation; removing it results in a more significant drop in performance