Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Transformers are the most popular models in machine learning and have achieved impressive performance.
  • The theoretical understanding of transformer building blocks is limited.
  • Dense Associative Memory models have a well-established theoretical foundation but have not achieved impressive practical results.
  • Energy Transformer (ET) replaces the sequence of feedforward transformer blocks with a single large Associative Memory model.
  • ET has many of the familiar architectural primitives used in the current generation of transformers.
  • ET is designed to minimize a specifically engineered energy function.
  • ET’s attention differs from the conventional attention mechanism.
  • This paper introduces the theoretical foundations of ET, explores its empirical capabilities, and obtains strong quantitative results on the graph anomaly detection task.

Paper Content

Introduction

  • Transformers are used in machine learning for language, vision, and audio processing.
  • Transformers use four operations: attention, MLP, residual connection, and layer normalization.
  • Variations of transformers are created by combining these operations in different ways.
  • The search for new transformer architectures is mostly empirical.
  • Hopfield Networks have been gaining popularity in the machine learning community.
  • Transformers and Hopfield Networks are designed for denoising of the input.
  • Transformers are pre-trained on a masked-token task.
  • Hopfield Networks are trained to complete incomplete patterns.
  • There are similarities and differences between Transformers and Hopfield Networks.
  • An Energy Transformer (ET) is proposed to route information between tokens.
  • ET uses an energy function tailored to the problem of interest.
  • ET is used for image completion and graph anomaly detection.
  • ET matches or outperforms recent state-of-the-art methods on these benchmarks.

Energy Transformer block

  • This section introduces the theoretical framework of the ET network
  • Input image is split into non-overlapping patches
  • Each patch is represented by a D-dimensional vector
  • ET block is described by a continuous time differential equation
  • Goal is to allow masked particles to find an identity consistent with their locations
  • Dynamical evolution is designed to minimize a global energy function
  • Energy function is guided by two pieces of information: identities of open particles and general knowledge about possible patches

Layer norm

  • Each token is represented by a vector in a D-dimensional space.
  • ET block operations are defined using a normalized token representation.
  • γ and δ are learnable parameters; ε is a small regularization constant.
  • This operation can be viewed as an activation function for neurons.
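
A minimal sketch of the normalization described above, assuming the standard layer-norm form with a scalar gain γ, a bias vector δ, and ε added inside the square root. The names `layer_norm` and `layer_norm_lagrangian` are illustrative and not taken from the paper's code. The second function is a scalar whose gradient with respect to x reproduces the normalized output, which is one way to read the "activation function" view.

```python
import numpy as np

def layer_norm(x, gamma=1.0, delta=None, eps=1e-5):
    """Normalize each token (last axis) to zero mean and roughly unit variance."""
    delta = np.zeros(x.shape[-1]) if delta is None else delta
    centered = x - x.mean(axis=-1, keepdims=True)
    var = (centered ** 2).mean(axis=-1, keepdims=True)
    return gamma * centered / np.sqrt(var + eps) + delta

def layer_norm_lagrangian(x, gamma=1.0, delta=None, eps=1e-5):
    """Scalar whose gradient w.r.t. x equals layer_norm(x) (assumed form)."""
    D = x.shape[-1]
    delta = np.zeros(D) if delta is None else delta
    centered = x - x.mean(axis=-1, keepdims=True)
    var = (centered ** 2).mean(axis=-1)
    return np.sum(D * gamma * np.sqrt(var + eps)) + np.sum(x * delta)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))   # 4 tokens in a D = 8 dimensional space
g = layer_norm(tokens)             # normalized token representation
```

With γ = 1 and δ = 0, every row of `g` has zero mean and a variance just below one (the small gap comes from ε).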

Multi-head energy attention

  • ET’s energy function exchanges information between particles (patches).
  • ET has a query and key matrix, but no separate value matrix.
  • The goal of the energy-based attention is to evolve the tokens so that the keys of open patches align with the queries of masked patches.
  • The energy function is minimal when the queries of each patch are aligned with the keys of a small number of other patches.
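
One energy with the properties listed above can be sketched as follows. The log-sum-exp form, the tensor layout, and the name `attention_energy` are assumptions made for illustration; only the use of key and query projections without a value matrix comes from the summary.

```python
import numpy as np

def attention_energy(g, W_K, W_Q, beta):
    """Log-sum-exp attention energy (assumed form).

    g:   [N, D] normalized tokens
    W_K: [H, Y, D] key projections, W_Q: [H, Y, D] query projections
    """
    K = np.einsum("hyd,nd->hny", W_K, g)             # keys    [H, N, Y]
    Q = np.einsum("hyd,nd->hny", W_Q, g)             # queries [H, N, Y]
    scores = beta * np.einsum("hby,hcy->hbc", K, Q)  # [H, key B, query C]
    N = g.shape[0]
    scores = np.where(np.eye(N, dtype=bool), -np.inf, scores)  # a token does not attend to itself
    m = scores.max(axis=1, keepdims=True)
    lse = m[:, 0, :] + np.log(np.exp(scores - m).sum(axis=1))  # stable log-sum-exp over keys
    return -lse.sum() / beta

beta = 2.0
W = np.eye(4)[None]                                  # one head, identity projections
aligned = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (4, 1))  # every query matches every key
orthogonal = np.eye(4)                               # no query matches another token's key
e_aligned = attention_energy(aligned, W, W, beta)
e_orth = attention_energy(orthogonal, W, W, beta)
```

In this toy setup the aligned configuration has the lower energy, consistent with the description that the energy is small when queries align with the keys of other patches.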

Hopfield Network module

  • Hopfield Network (HN) ensures token representations are consistent with realistic images
  • Energy of HN sub-block is defined by a set of learnable weights and an activation function
  • Depending on activation function, HN can be viewed as classical or modern continuous Hopfield Network
  • HN module is an MLP with shared weights applied recurrently
  • Energy contribution of HN is low when token representations are aligned with memories
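
A minimal sketch of the HN sub-block, assuming a ReLU activation so that the energy is minus one half of the squared rectified overlaps between tokens and memories. The names `hopfield_energy`, `hopfield_update`, and the memory matrix `xi` are illustrative. The update shows why the module reads as an MLP with shared weights: the same matrix is used for the hidden layer and, transposed, for the output layer.

```python
import numpy as np

def hopfield_energy(g, xi):
    """E_HN = -1/2 * sum over tokens and memories of relu(overlap)^2 (ReLU case, assumed)."""
    h = np.maximum(g @ xi.T, 0.0)   # [N, M] rectified overlap of each token with each memory
    return -0.5 * np.sum(h ** 2)

def hopfield_update(g, xi):
    """Negative gradient of hopfield_energy w.r.t. g: a two-layer MLP with tied weights."""
    h = np.maximum(g @ xi.T, 0.0)   # hidden layer uses xi
    return h @ xi                   # output layer reuses the same xi

xi = np.eye(3, 5)                                  # 3 toy memories in a 5-dimensional token space
aligned = np.array([[1.0, 0.0, 0.0, 0.0, 0.0]])    # matches the first memory
unaligned = np.array([[0.0, 0.0, 0.0, 0.0, 1.0]])  # orthogonal to every memory
e_aligned = hopfield_energy(aligned, xi)
e_unaligned = hopfield_energy(unaligned, xi)
```

The aligned token has the lower energy, matching the statement that the HN contribution is low when token representations are aligned with memories.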

Dynamics of token updates

  • ET network is described by a continuous time differential equation
  • Minimizes sum of two energies
  • First energy is low when queries are aligned with neighbors’ keys
  • Second energy is low when patch has content consistent with expectations
  • Dynamical system finds trade-off between two desirable properties
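  • Temporal derivative of the energy is non-positive (the layer-norm Jacobian entering it is positive semi-definite), so the energy never increases over time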
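
The dynamics can be sketched as a discretized gradient flow, assuming the update direction is the negative gradient of the energy with respect to the normalized tokens. For brevity this sketch uses only a ReLU Hopfield-style energy and omits the attention energy; the function names, the step size of 0.01, and the toy sizes are all illustrative assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    c = x - x.mean(axis=-1, keepdims=True)
    return c / np.sqrt((c ** 2).mean(axis=-1, keepdims=True) + eps)

def energy(g, xi):
    # Stand-in for the total energy: Hopfield term only (attention term omitted).
    return -0.5 * np.sum(np.maximum(g @ xi.T, 0.0) ** 2)

def energy_grad(g, xi):
    # Gradient of the energy above with respect to the normalized tokens g.
    return -np.maximum(g @ xi.T, 0.0) @ xi

def run_dynamics(x, xi, step=0.01, T=20):
    """Euler discretization: x <- x - step * dE/dg, with g = layer_norm(x)."""
    energies = []
    for _ in range(T):
        g = layer_norm(x)
        energies.append(energy(g, xi))
        x = x - step * energy_grad(g, xi)
    energies.append(energy(layer_norm(x), xi))
    return x, energies

rng = np.random.default_rng(0)
xi = rng.normal(size=(32, 16)) / 4.0   # 32 toy memories
x0 = rng.normal(size=(10, 16))         # 10 tokens, D = 16
x_final, energies = run_dynamics(x0, xi)
```

In continuous time the energy cannot increase; in a discretized version like this one, that property only holds approximately and depends on the step size being small enough.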

Relationship to modern Hopfield Networks and conventional attention

  • Design of energy attention mechanism and corresponding energy function
  • Energy function similar to continuous Hopfield Network with softmax activation
  • Keys in Modern Hopfield Networks are constant parameters, in energy attention network they are dynamical variables
  • ET attention contribution to update dynamics given by two terms, one conventional and one new
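
The two-term structure can be made concrete with a sketch that assumes a log-sum-exp energy over key-query scores, E = -(1/β) Σ_h Σ_C log Σ_{B≠C} exp(β K_B · Q_C). Differentiating with respect to the tokens gives one term in which a token acts as a query (it resembles conventional attention, with keys projected back through the query weights playing the role of values) and one term in which the token acts as a key. The function name and tensor layout are illustrative assumptions.

```python
import numpy as np

def attention_update_terms(g, W_K, W_Q, beta):
    """Return the two contributions to -dE/dg for the assumed log-sum-exp energy."""
    K = np.einsum("hyd,nd->hny", W_K, g)             # keys    [H, N, Y]
    Q = np.einsum("hyd,nd->hny", W_Q, g)             # queries [H, N, Y]
    scores = beta * np.einsum("hby,hcy->hbc", K, Q)  # [H, key B, query C]
    scores = np.where(np.eye(g.shape[0], dtype=bool), -np.inf, scores)
    scores = scores - scores.max(axis=1, keepdims=True)
    P = np.exp(scores)
    P = P / P.sum(axis=1, keepdims=True)             # softmax over keys B for each query C
    # Token a as a query: weighted sum of other tokens' keys, mapped back with W_Q.
    query_term = np.einsum("hyd,hba,hby->ad", W_Q, P, K)
    # Token a as a key: weighted sum of other tokens' queries, mapped back with W_K.
    key_term = np.einsum("hyd,hac,hcy->ad", W_K, P, Q)
    return query_term, key_term

rng = np.random.default_rng(0)
g = rng.normal(size=(5, 6))                # 5 tokens, D = 6
W_K = rng.normal(size=(2, 3, 6)) * 0.3     # 2 heads, Y = 3
W_Q = rng.normal(size=(2, 3, 6)) * 0.3
beta = 0.5
query_term, key_term = attention_update_terms(g, W_K, W_Q, beta)
```

The second term appears because keys are dynamical variables here; with constant keys, as in Modern Hopfield Networks, only the first term would remain.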

Qualitative inspection of the ET framework on ImageNet

  • Trained ET network on masked image completion task using ImageNet-1k dataset
  • Images broken into non-overlapping patches of 16x16 RGB pixels
  • Half of tokens “masked” and a distinct learnable position encoding vector added to each token
  • ET block processes tokens recurrently for T steps
  • Loss function is MSE loss on occluded patches
  • Model learns to perform task well, but struggles to understand global structure
  • Position encoding vectors associated with every token have high similarity values to other patches in same row/column
  • ET block moves tokens around within a single embedding space, so the final fixed-point representation can be decoded back into the image plane
  • Visualize weights of HN module directly in image plane
  • Visualize gradients of energy function of ATTN and HN blocks
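
The data preparation and loss described above can be sketched as follows. The 224x224 input size, the use of zeros in place of a learned MASK token, and the helper names are assumptions for illustration; only the 16x16 patch size, the masking of half the tokens, and the MSE loss on occluded patches come from the summary.

```python
import numpy as np

def patchify(img, p=16):
    """Split an [H, W, C] image into non-overlapping p x p patches, one row per patch."""
    H, W, C = img.shape
    x = img.reshape(H // p, p, W // p, p, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p * C)

def unpatchify(tokens, H, W, C=3, p=16):
    """Inverse of patchify: rebuild the [H, W, C] image from its patches."""
    x = tokens.reshape(H // p, W // p, p, p, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(H, W, C)

def masked_mse(pred, target, mask):
    """MSE over occluded patches only; mask is a boolean vector over patches."""
    return np.mean((pred[mask] - target[mask]) ** 2)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))           # assumed input size
patches = patchify(image)                   # [196, 768]
mask = np.zeros(len(patches), dtype=bool)
mask[rng.choice(len(patches), size=len(patches) // 2, replace=False)] = True
corrupted = np.where(mask[:, None], 0.0, patches)   # zeros stand in for a learned MASK token
loss = masked_mse(corrupted, patches, mask)
```

In a full pipeline, `corrupted` would be embedded, given position encodings, passed through the ET block for T steps, and decoded before the loss is computed.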

Graph anomaly detection

  • ET network is evaluated on graph anomaly detection
  • Anomalies are outliers that deviate from majority of samples
  • Three types of graph anomalies: node, edge, subgraph
  • Focused on node anomaly detection in attributed graphs
  • Node attributes encoded in latent space and treated as token
  • Graph Convolutional Networks (GCN) used for task
  • Vanilla GCNs suffer from over-smoothing problem
  • ET network uses energy-based attention to route information between nodes
  • Task is to predict label of node given graph structure and node’s features
  • Anomaly detection is an imbalanced node classification task
  • Feature vectors converted to token representation using linear embedding and positional embedding
  • Output of network fed into MLP with sigmoid activation to compute anomaly probabilities
  • Weighted cross entropy used to train network
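
Because anomalies are rare, a weighted cross entropy up-weights the positive class. The sketch below assumes binary labels, a sigmoid output, and a positive-class weight equal to the ratio of normal to anomalous nodes; that particular weighting and the function names are illustrative choices, not details from the paper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def weighted_bce(logits, labels, pos_weight):
    """Binary cross entropy with the anomaly (label 1) term scaled by pos_weight."""
    p = np.clip(sigmoid(logits), 1e-12, 1 - 1e-12)   # avoid log(0)
    per_node = -(pos_weight * labels * np.log(p) + (1 - labels) * np.log(1 - p))
    return per_node.mean()

labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=float)   # 2 anomalies out of 10 nodes
logits = np.array([-2.0, -1.5, -3.0, -2.2, -1.0, -2.5, -0.5, -1.8, 0.5, -0.3])
pos_weight = (labels == 0).sum() / (labels == 1).sum()   # assumed weighting: normal / anomalous
loss = weighted_bce(logits, labels, pos_weight)
```

Raising `pos_weight` makes a missed anomaly cost more than a false alarm, which counteracts the class imbalance.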

Experimental evaluation

  • Four datasets used for graph anomaly detection experiments
  • Graph treated as homogeneous graph, feature vector associated with each node
  • Task is to predict label (anomaly status) of nodes
  • Compare with state-of-the-art approaches for graph anomaly detection
  • Evaluation metrics used are macro-F1 score and AUC
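
For reference, both metrics can be computed from scratch as in the sketch below. It assumes binary labels and a 0.5 decision threshold for macro-F1; the function names and toy data are illustrative.

```python
import numpy as np

def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores for classes 0 and 1."""
    scores = []
    for cls in (0, 1):
        tp = np.sum((y_pred == cls) & (y_true == cls))
        fp = np.sum((y_pred == cls) & (y_true != cls))
        fn = np.sum((y_pred != cls) & (y_true == cls))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

def roc_auc(y_true, y_score):
    """Probability that a random anomaly scores higher than a random normal node (ties count half)."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

y_true = np.array([0, 0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.3])
y_pred = (y_score >= 0.5).astype(int)   # assumed threshold
f1 = macro_f1(y_true, y_pred)
auc = roc_auc(y_true, y_score)
```

Macro-F1 weights both classes equally regardless of their size, which is why it is informative on imbalanced data, while AUC is independent of the threshold.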

Discussion and conclusions

  • Recent research has focused on understanding the analogy between Hopfield Networks and the attention mechanism in transformers.
  • The transformer block can be viewed as a single large Hopfield Network.
  • A novel energy function has been designed for dynamical information routing between tokens and representations.
  • The attention mechanism contains an extra term compared to conventional attention.
  • The network has been tested on image completion and node anomaly detection tasks.

Reproducibility statement

  • Experiments were conducted to ensure reproducibility of results
  • Training protocols and implementation details are described in Appendices A, B, and F
  • Model and training code for images can be found at a given link
  • Code for image reconstruction is written in JAX
  • Training script sets a seed to recreate the same training setup
  • No additional training data was used beyond the training set
  • Code for Graph Anomaly Detection is written in PyTorch
  • Results on graphs are reported with mean and standard deviations
  • ET architecture can be built with HAMUX, a JAX-based Deep Learning library

A. Details of training on ImageNet

  • Trained ET network on masked-image completion task on ImageNet-1k dataset
  • Images split into non-overlapping patches of 16x16 RGB pixels
  • Of the 100 tokens selected for occlusion, 90 are replaced with a learnable MASK token and 10 are left untouched
  • Tokens are passed to the Energy Transformer block, which is applied recurrently for T steps
  • MSE loss is computed between original and reconstructed pixels for the 100 occluded patches
  • Self attention formula for energy of multiheaded attention
  • Step size of 0.1 provides smoother descent down energy function
  • MSE loss must include some subset of un-occluded patches for HN to learn meaningful filters
  • Setting the β parameter in energy attention too high prevents the model from training
  • Gradient clipping helps model train faster and at higher learning rates
  • Features mapped into token space with linear projection and learnable positional embeddings
  • Attention operation restricted to neighborhood of given node
  • Forward pass of ET-block minimizes energy function
  • Anomaly detection treated as semi-supervised learning task
  • Adam optimizer with learning rate of 0.001 used to train models
  • Ablation studies show that the attention (ATT) block performs most of the computation: removing it causes a larger drop in performance than removing the HN module
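
Restricting attention to a node's neighborhood can be sketched by masking key-query scores with the adjacency matrix. The sketch assumes a log-sum-exp energy over key-query scores, a boolean adjacency without self-loops, and at least one neighbor per node; the function name and the convention that `adj[B, C]` marks B as a neighbor of C are illustrative.

```python
import numpy as np

def graph_attention_energy(g, W_K, W_Q, adj, beta=1.0):
    """Attention energy where each query node only sees the keys of its neighbors.

    adj[B, C] is True when node B is a neighbor of node C (no self-loops).
    Every node needs at least one neighbor, otherwise its term is undefined.
    """
    K = np.einsum("hyd,nd->hny", W_K, g)             # keys    [H, N, Y]
    Q = np.einsum("hyd,nd->hny", W_Q, g)             # queries [H, N, Y]
    scores = beta * np.einsum("hby,hcy->hbc", K, Q)  # [H, key B, query C]
    scores = np.where(adj, scores, -np.inf)          # drop non-neighbors
    m = scores.max(axis=1, keepdims=True)
    lse = m[:, 0, :] + np.log(np.exp(scores - m).sum(axis=1))
    return -lse.sum() / beta

edges = [(0, 1), (1, 2), (3, 4)]                     # two disconnected components
adj = np.zeros((5, 5), dtype=bool)
for a, b in edges:
    adj[a, b] = adj[b, a] = True

rng = np.random.default_rng(0)
g = rng.normal(size=(5, 6))
W_K = rng.normal(size=(2, 3, 6)) * 0.3
W_Q = rng.normal(size=(2, 3, 6)) * 0.3
e_full = graph_attention_energy(g, W_K, W_Q, adj)
```

With this masking, disconnected components do not interact: the energy of the whole graph equals the sum of the energies of its components.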