Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Neural networks have achieved impressive results on many tasks.
  • This paper connects neural feature learning to the average gradient outer product.
  • The paper introduces Recursive Feature Machines (RFMs) which are kernel machines that learn features.
  • RFMs accurately capture features learned by deep fully connected neural networks.
  • RFMs close the gap between kernel machines and fully connected networks.
  • RFMs surpass a broad spectrum of models including neural networks on tabular data.
  • RFMs shed light on deep learning phenomena such as grokking, lottery tickets, simplicity biases, and spurious features.

Paper Content

Introduction

  • Modern neural networks have achieved major progress on a variety of applications
  • Feature learning is thought to be a central contributor to their superior performance
  • Identifying the component of neural networks associated with feature learning has been a challenge
  • Different lines of investigation connect feature learning to various aspects of neural network methodology
  • Isolate a key mechanism of feature learning in deep fully connected neural networks
  • Connect feature learning with a statistical estimator for feature selection
  • Develop a class of kernel machines, Recursive Feature Machines (RFMs), that learn features
  • RFMs accurately capture features learned by deep networks
  • Close the gap between kernel machines and neural networks
  • Achieve state-of-the-art performance on tabular data
  • Provide theoretical evidence for the connection between feature learning in fully connected networks and the expected gradient outer product
  • Features learned by RFMs and deep networks are highly correlated
  • Training RFMs on these learned features leads to improved predictive performance over deep neural networks

Neural feature learning and recursive feature machines

  • Connecting feature learning in neural networks to the average gradient outer product
  • Developing recursive feature machines, a class of kernel machines that learn features
  • Defining the feature matrix of a neural network
  • Claiming that the feature matrix and the average gradient outer product capture similar sets of features
  • Implication that features learned by estimating the average gradient outer product should match those learned by deep networks
  • Two lines of evidence: theoretical and empirical
  • Theoretical evidence: establishing the connection between the feature matrix and the average gradient outer product in certain settings
  • Empirical evidence: features learned by RFMs are highly correlated with those captured by the feature matrix of trained neural networks
  • Alternating between estimating a predictor and learning features using kernel machines
  • Using the Laplace kernel for RFMs
  • Diagonals of feature matrices from RFMs and deep networks are highly correlated
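
The alternation described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation. It assumes a scalar target and a Laplace kernel of the form exp(-||x - z||_M / L) with a learned feature matrix M. The gradients are evaluated at the training points, and the predictor is refit once with the final M. The function names (fit_rfm, average_gradient_outer_product) are hypothetical, and the defaults simply reuse the ridge and bandwidth values quoted in the experimental details.

```python
import numpy as np

def mahalanobis_distances(X, Z, M):
    """Pairwise sqrt((x - z)^T M (x - z)) between rows of X and rows of Z."""
    diff = X[:, None, :] - Z[None, :, :]            # shape (n, m, d)
    sq = np.einsum("ijk,ijk->ij", diff @ M, diff)   # shape (n, m)
    return np.sqrt(np.clip(sq, 0.0, None))

def laplace_kernel(X, Z, M, bandwidth):
    """Laplace kernel exp(-||x - z||_M / L) with feature matrix M."""
    return np.exp(-mahalanobis_distances(X, Z, M) / bandwidth)

def average_gradient_outer_product(X, alpha, M, bandwidth):
    """Average of grad f(x_j) grad f(x_j)^T over the training points,
    where f(x) = sum_i alpha_i K_M(x_i, x)."""
    n = X.shape[0]
    D = mahalanobis_distances(X, X, M)
    K = np.exp(-D / bandwidth)
    # K / D is undefined at zero distance (the kernel has a kink there),
    # so those terms are dropped from the gradient.
    ratio = np.divide(K, D, out=np.zeros_like(K), where=D > 0)
    W = alpha[:, None] * ratio                      # W[i, j] = alpha_i K_ij / D_ij
    # grad f(x_j) = -(1/L) * M @ sum_i W[i, j] * (x_j - x_i)
    weighted_diff = X * W.sum(axis=0)[:, None] - W.T @ X
    G = -(weighted_diff @ M) / bandwidth            # row j is grad f(x_j); M is symmetric
    return G.T @ G / n

def fit_rfm(X, y, iterations=5, ridge=1e-3, bandwidth=10.0):
    """Alternate kernel ridge regression and feature-matrix updates."""
    n, d = X.shape
    M = np.eye(d)                                   # start from the plain Laplace kernel
    for _ in range(iterations):
        K = laplace_kernel(X, X, M, bandwidth)
        alpha = np.linalg.solve(K + ridge * np.eye(n), y)
        M = average_gradient_outer_product(X, alpha, M, bandwidth)
    # Refit the predictor once with the final feature matrix.
    K = laplace_kernel(X, X, M, bandwidth)
    alpha = np.linalg.solve(K + ridge * np.eye(n), y)
    return alpha, M

def predict_rfm(X_train, alpha, M, X_new, bandwidth=10.0):
    """Evaluate f(x) = sum_i alpha_i K_M(x_i, x) at new points."""
    return laplace_kernel(X_new, X_train, M, bandwidth) @ alpha
```

When the target depends on only a few input coordinates, the diagonal of the returned M should concentrate on those coordinates, which is the sense in which the kernel machine "learns features". The pairwise-difference tensor makes this sketch suitable only for small datasets.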

Empirical results

  • RFMs accurately capture features learned by fully connected networks
  • RFMs close the gap between kernel machines and neural networks
  • RFMs provide state-of-the-art results on large tabular datasets
  • High correlation between feature matrices of deep networks and RFMs
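
One way to quantify the correlation between feature matrices is the Pearson correlation of their entries, or of their diagonals. The sketch below is an assumption about how such a comparison could be run, not a description of the paper's exact evaluation. It takes the network's feature matrix to be W^T W for the first-layer weight matrix W, and the function names are hypothetical.

```python
import numpy as np

def neural_feature_matrix(first_layer_weights):
    """W^T W for a first-layer weight matrix W of shape (hidden_units, input_dim)."""
    W = np.asarray(first_layer_weights, dtype=np.float64)
    return W.T @ W

def feature_matrix_correlation(A, B):
    """Pearson correlation between all entries of two feature matrices."""
    return float(np.corrcoef(A.ravel(), B.ravel())[0, 1])

def diagonal_correlation(A, B):
    """Pearson correlation between the diagonals of two feature matrices."""
    return float(np.corrcoef(np.diag(A), np.diag(B))[0, 1])
```

Pearson correlation is invariant to positive rescaling, so the two matrices do not need to be normalized to a common scale before they are compared.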

RFMs close the gap between kernels and fully connected networks.

  • RFMs accurately capture features learned by fully connected networks
  • RFMs match or outperform neural networks in predictive performance
  • RFMs and deep networks both learn to select the center columns of the image
  • RFMs and deep networks outperform previous kernel methods by 10%
  • RFMs match or outperform neural networks on other illustrative tasks
  • RFMs provide an improvement in error for large datasets with high data dimension

RFMs outperform neural networks and NTK on tabular data benchmarks.

  • RFMs outperform NTKs, variants of Laplace kernels, and all 179 other methods on both the subset of 90 small datasets and the full 121 datasets from the original benchmark
  • RFMs achieve highest average rank, average accuracy, P90/P95, and PMA across all datasets
  • RFMs outperform modern neural networks and tree-based models on large tabular tasks from [23]
  • RFMs achieve highest ADTM on large classification and regression tasks
  • RFMs outperform modern neural networks on medium size tabular datasets
  • Tree-based models slightly outperform RFMs on regression tasks with categorically encoded data

Deep learning phenomena through the prism of RFM

  • Deep neural networks have phenomena not seen in kernel methods
  • RFMs can reproduce these phenomena
  • Grokking
  • Lottery Ticket Hypothesis

Grokking in RFMs and deep networks.

  • Grokking is a phenomenon in which deep networks can increase test accuracy when training past 100% training accuracy.
  • Grokking can be seen in transformers, RFMs, and deep fully connected networks.
  • A dataset with a large class imbalance and a 5x5 pixel square in the upper left corner of each image can enable grokking.

Lottery tickets in RFMs and deep networks.

  • The lottery ticket hypothesis states that a randomly-initialized neural network contains a sub-network that can match or outperform the trained network when trained in isolation.
  • Visualizing the diagonals of the feature matrix shows sparsity, which provides evidence for the lottery ticket hypothesis.

Inductive biases of RFMs and deep networks.

  • Neural networks are over-parameterized but still yield improved performance
  • Recent works have analyzed inductive biases of deep networks
  • Simplicity bias is a form of inductive bias in deep networks
  • Neural networks and RFMs can accurately capture simplicity biases

RFMs capture spurious features and biases in deep networks.

  • Deep networks can use features that are not related to the object of interest, making them unreliable.
  • Adversarial examples are caused by these spurious features.
  • RFMs can be used to identify these spurious features used by deep networks.
  • Applying a small perturbation to these identified features can lead to a large decrease in accuracy.

Summary, discussion, and outlook

  • Isolated key mechanism of feature learning in deep fully connected neural networks
  • Proposed Neural Feature Ansatz - first layer of neural networks responsible for feature learning
  • Learned features closely related to average gradient outer product
  • Developed Recursive Feature Machines (RFMs) to learn features
  • RFMs accurately capture features learned by deep fully connected neural networks
  • RFMs close gap between kernel methods and fully connected networks
  • RFMs yield state-of-the-art performance on large tabular datasets
  • Many deep learning phenomena can be understood through RFMs
  • Network width modulates between two different regimes
  • Feature learning largely in first layer of deep networks
  • Transparency of RFMs can increase interpretability of machine learning models
  • RFMs can be viewed as a method for learning a data-dependent kernel
  • Feature learning does not always lead to improvements
  • Connections to metric and manifold learning, FisherFaces and EigenFaces, Debiasing, Expectation Maximization
  • Separating predictor and feature learning components of neural networks provides a modular path forward

B Background on kernel ridge regression

  • Kernel ridge regression is a non-parametric estimator used in computer science
  • It involves solving an infinite-dimensional optimization problem in a Reproducing Kernel Hilbert Space
  • For datasets with n ≤ 100,000, the problem can be solved in closed form
  • For larger datasets, the EigenPro solver is used to approximate the solution via early-stopped, preconditioned-SGD on the GPU
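
For reference, the closed-form solution mentioned above can be written out as follows. The symbols are assumptions chosen for this note rather than notation taken from the paper: K is the kernel, H its Reproducing Kernel Hilbert Space, K_n the n x n kernel matrix on the training inputs, y the vector of training targets, and lambda >= 0 the ridge parameter.

```latex
\hat{f} = \arg\min_{f \in \mathcal{H}} \sum_{i=1}^{n} \bigl(f(x_i) - y_i\bigr)^2 + \lambda \|f\|_{\mathcal{H}}^2,
\qquad
\hat{f}(x) = \sum_{i=1}^{n} \alpha_i K(x_i, x),
\qquad
\alpha = (K_n + \lambda I_n)^{-1} y,
\quad (K_n)_{ij} = K(x_i, x_j).
```

The infinite-dimensional problem thus reduces to an n x n linear system. A direct solve of that system becomes expensive as n grows, which is where an iterative solver becomes preferable.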

C Dataset and experimental details

  • Experiments with RFMs and fully connected networks on CelebA
  • Normalize all images to be on the unit sphere
  • Train 2-hidden layer ReLU networks with 1024 hidden units per layer using SGD for 500 epochs with a learning rate of 0.1 and a mini-batch size of 128
  • Train RFMs for 1 iteration, use a ridge regularization term of 10^-3, and average the gradient outer product of at most 20000 examples
  • All RFMs use Laplace kernels as the base kernel and use a bandwidth parameter of L = 10
  • Split available training data into 80% training and 20% validation for hyperparameter selection
  • Report accuracy on a held out test set provided by PyTorch
  • Ensure that the training set and test set are balanced by limiting the number of majority class samples to the same number of minority class samples
  • Limit the total number of training and validation examples per experiment to 50000 (25000 per class)
  • Train 2-hidden layer ReLU networks with 1024 hidden units per layer using SGD for 500 epochs with a learning rate of 0.1 and a mini-batch size of 100 for SVHN
  • Train RFMs for 5 iterations and average the gradient outer product of at most 20000 examples
  • RFMs and Laplace kernels used all have a bandwidth parameter of 10
  • Compare with the NTK of a 2-hidden layer ReLU network
  • Use 1000 examples for training and 10000 samples for testing for low rank polynomials
  • Train a 1-hidden layer neural network for 1000 epochs using full-batch gradient descent with a learning rate of 0.1 and initialize the first layer with standard deviation 10^-3
  • Train RFMs with no ridge term and set the base kernel function as the Laplace kernel with bandwidth 10
  • Grid search over ridge regularization from the set {10, 1, 0.1, 0.01, 0} with fixed bandwidth L = 10 on large regression datasets
  • Use EigenPro to train all kernel methods and RFMs on the largest dataset
  • Use 553 samples with 500 examples of airplanes and 53 examples of trucks for Grokking
  • Train a 2-hidden layer fully connected ReLU network using full-batch gradient descent with a learning rate of 0.1
  • Train RFMs for three iterations with ridge regularization of 10^-3 and using the Laplace kernel as the base kernel function with a bandwidth of 10
  • Compare to the metrics reported in [23] for the tabular data benchmark introduced there
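
The data preparation steps listed above (unit-sphere normalization, class balancing, and the 80/20 train/validation split) could look roughly like the following NumPy sketch. The function names and the random subsampling strategy are assumptions made for illustration. The summary does not say how the majority-class samples were chosen or how the split was drawn.

```python
import numpy as np

def normalize_to_unit_sphere(X, eps=1e-12):
    """Flatten each example and rescale it to unit Euclidean norm."""
    flat = np.asarray(X, dtype=np.float64).reshape(len(X), -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    return flat / np.maximum(norms, eps)   # eps guards against all-zero examples

def balance_classes(X, y, max_per_class=25000, rng=None):
    """Subsample every class down to the minority-class size (capped)."""
    rng = np.random.default_rng() if rng is None else rng
    classes = np.unique(y)
    per_class = min(min(int(np.sum(y == c)) for c in classes), max_per_class)
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=per_class, replace=False)
        for c in classes
    ])
    rng.shuffle(keep)                      # avoid class-ordered output
    return X[keep], y[keep]

def train_val_split(X, y, val_fraction=0.2, rng=None):
    """Random split into training and validation parts."""
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(len(X))
    n_val = int(round(val_fraction * len(X)))
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    return X[train_idx], y[train_idx], X[val_idx], y[val_idx]
```

With max_per_class=25000 and two classes, the balanced set holds at most 50000 examples, matching the cap quoted in the list.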