Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Neural networks have achieved impressive results on many tasks.
- This paper connects neural feature learning to the average gradient outer product.
- The paper introduces Recursive Feature Machines (RFMs) which are kernel machines that learn features.
- RFMs accurately capture features learned by deep fully connected neural networks.
- RFMs close the gap between kernel machines and fully connected networks.
- RFMs surpass a broad spectrum of models including neural networks on tabular data.
- RFMs shed light on deep learning phenomena such as grokking, lottery tickets, simplicity biases, and spurious features.
Paper Content
Introduction
- Modern neural networks have achieved major progress on a variety of applications
- Feature learning is thought to be a central contributor to their superior performance
- Identifying the component of neural networks associated with feature learning has been a challenge
- Different lines of investigation connect feature learning to various aspects of neural network methodology
- Isolate a key mechanism of feature learning in deep fully connected neural networks
- Connect feature learning with a statistical estimator for feature selection
- Develop a class of kernel machines, Recursive Feature Machines (RFMs), that learn features
- RFMs accurately capture features learned by deep networks
- Close the gap between kernel machines and neural networks
- Achieve state-of-the-art performance on tabular data
- Provide theoretical evidence for the connection between feature learning in fully connected networks and the expected gradient outer product
- Features learned by RFMs and deep networks are highly correlated
- Training RFMs on these learned features leads to improved predictive performance over deep neural networks
Neural feature learning and recursive feature machines
- Connecting feature learning in neural networks to the average gradient outer product
- Developing recursive feature machines, a class of kernel machines that learn features
- Defining the feature matrix of a neural network
- Claiming that the feature matrix and the average gradient outer product capture similar sets of features
- Implication that features learned by estimating the average gradient outer product should match those learned by deep networks
- Two lines of evidence: theoretical and empirical
- Theoretical evidence: establishing the connection between the feature matrix and the average gradient outer product in certain settings
- Empirical evidence: features learned by RFMs are highly correlated with those captured by the feature matrix of trained neural networks
- Alternating between estimating a predictor and learning features using kernel machines
- Using the Laplace kernel for RFMs
- Diagonals of feature matrices from RFMs and deep networks are highly correlated
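The alternation described above (fit a kernel predictor, then re-estimate features from its gradients) can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes a scalar target, a Laplace kernel of the form exp(-||x - z||_M / L) with a Mahalanobis distance defined by the feature matrix M, gradients evaluated at the training points, and a final refit of the coefficients with the last M. The function names (`rfm_fit`, `agop`, and so on) are hypothetical.

```python
import numpy as np

def mahalanobis_dists(X, Z, M):
    """Pairwise sqrt((x - z)^T M (x - z)); assumes M is symmetric PSD."""
    XM, ZM = X @ M, Z @ M
    sq = (XM * X).sum(1)[:, None] - 2.0 * XM @ Z.T + (ZM * Z).sum(1)[None, :]
    return np.sqrt(np.clip(sq, 0.0, None))

def laplace_kernel(X, Z, M, bandwidth):
    """Laplace kernel exp(-||x - z||_M / bandwidth)."""
    return np.exp(-mahalanobis_dists(X, Z, M) / bandwidth)

def fit_alpha(X, y, M, bandwidth, ridge):
    """Kernel ridge regression coefficients for the current feature matrix M."""
    K = laplace_kernel(X, X, M, bandwidth)
    return np.linalg.solve(K + ridge * np.eye(len(X)), y)

def agop(X, alpha, M, bandwidth):
    """Average gradient outer product of f(x) = sum_i alpha_i K_M(x_i, x),
    with gradients taken at the training points."""
    D = mahalanobis_dists(X, X, M)
    K = np.exp(-D / bandwidth)
    # C[j, i] = alpha_i * K(x_j, x_i) / ||x_j - x_i||_M, zero where the distance vanishes.
    C = np.divide(K * alpha[None, :], D, out=np.zeros_like(D), where=D > 1e-10)
    np.fill_diagonal(C, 0.0)  # the kernel is not differentiable at x = x_i
    XM = X @ M
    G = -(C.sum(1)[:, None] * XM - C @ XM) / bandwidth  # row j is grad f(x_j)
    return G.T @ G / len(X)

def rfm_fit(X, y, n_iters=3, bandwidth=10.0, ridge=1e-3):
    """Alternate between fitting the predictor and updating M, starting from M = I."""
    M = np.eye(X.shape[1])
    for _ in range(n_iters):
        alpha = fit_alpha(X, y, M, bandwidth, ridge)
        M = agop(X, alpha, M, bandwidth)
    alpha = fit_alpha(X, y, M, bandwidth, ridge)  # refit with the final M (a choice made here)
    return alpha, M

def rfm_predict(X_train, alpha, M, X_new, bandwidth=10.0):
    return laplace_kernel(X_new, X_train, M, bandwidth) @ alpha
```

For a target that depends on only a few coordinates, the diagonal of the returned `M` would be expected to concentrate on those coordinates, which is the sense in which the kernel machine "learns features".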
Empirical results
- RFMs accurately capture features learned by fully connected networks
- RFMs close the gap between kernel machines and neural networks
- RFMs provide state-of-the-art results on large tabular datasets
- High correlation between feature matrices of deep networks and RFMs
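One way to quantify "high correlation between feature matrices" is a Pearson correlation over matrix entries. The sketch below assumes the network's feature matrix is taken to be W1^T W1 for first-layer weights W1 (our reading of the paper's definition) and that both matrices have the same shape; the helper names are hypothetical.

```python
import numpy as np

def neural_feature_matrix(W1):
    """Feature matrix W1^T W1 for first-layer weights W1 of shape (width, input_dim)."""
    return W1.T @ W1

def matrix_correlation(A, B):
    """Pearson correlation between the entries of two equally shaped arrays."""
    a = A.ravel() - A.mean()
    b = B.ravel() - B.mean()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```

Passing `np.diag(A)` and `np.diag(B)` instead of the full matrices gives the correlation between diagonals.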
RFMs close the gap between kernels and fully connected networks.
- RFMs accurately capture features learned by fully connected networks
- RFMs match or outperform neural networks in predictive performance
- RFMs and deep networks both learn to select the center columns of the image
- RFMs and deep networks outperform previous kernel methods by 10%
- RFMs match or outperform neural networks on other illustrative tasks
- RFMs provide an improvement in error for large datasets with high data dimension
RFMs outperform neural networks and NTK on tabular data benchmarks.
- RFMs outperform NTKs, variants of Laplace kernels, and all 179 other methods, both on the subset of 90 small datasets and on the full 121 datasets from the original benchmark
- RFMs achieve highest average rank, average accuracy, P90/P95, and PMA across all datasets
- RFMs outperform modern neural networks and tree-based models on large tabular tasks from [23]
- RFMs achieve highest ADTM on large classification and regression tasks
- RFMs outperform modern neural networks on medium size tabular datasets
- Tree-based models slightly outperform RFMs on regression tasks with categorically encoded data
Deep learning phenomena through the prism of RFM
- Deep neural networks have phenomena not seen in kernel methods
- RFMs can reproduce these phenomena
- Grokking
- Lottery Ticket Hypothesis
Grokking in RFMs and deep networks.
- Grokking is a phenomenon in which deep networks can increase test accuracy when training past 100% training accuracy.
- Grokking can be seen in transformers, RFMs, and deep fully connected networks.
- A dataset with a large class imbalance and a 5x5 pixel square in the upper left corner of each image can enable grokking.
Lottery tickets in RFMs and deep networks.
- The lottery ticket hypothesis states that a randomly-initialized neural network contains a sub-network that can match or outperform the trained network when trained in isolation.
- Visualizing the diagonals of the feature matrix shows sparsity, which provides evidence for the lottery ticket hypothesis.
Inductive biases of RFMs and deep networks.
- Neural networks are over-parameterized but still yield improved performance
- Recent works have analyzed inductive biases of deep networks
- Simplicity bias is a form of inductive bias in deep networks
- Neural networks and RFMs can accurately capture simplicity biases
RFMs capture spurious features and biases in deep networks.
- Deep networks can use features that are not related to the object of interest, making them unreliable.
- Adversarial examples are caused by these spurious features.
- RFMs can be used to identify these spurious features used by deep networks.
- Applying a small perturbation to these identified features can lead to a large decrease in accuracy.
Summary, discussion, and outlook
- Isolated key mechanism of feature learning in deep fully connected neural networks
- Proposed Neural Feature Ansatz - first layer of neural networks responsible for feature learning
- Learned features closely related to average gradient outer product
- Developed Recursive Feature Machines (RFMs) to learn features
- RFMs accurately capture features learned by deep fully connected neural networks
- RFMs close gap between kernel methods and fully connected networks
- RFMs yield state-of-the-art performance on large tabular datasets
- Many deep learning phenomena can be understood through RFMs
- Network width modulates between two different regimes
- Feature learning largely in first layer of deep networks
- Transparency of RFMs can increase interpretability of machine learning models
- RFMs can be viewed as a method for learning a data-dependent kernel
- Feature learning does not always lead to improvements
- Connections to metric and manifold learning, FisherFaces and EigenFaces, Debiasing, Expectation Maximization
- Separating predictor and feature learning components of neural networks provides a modular path forward
B Background on kernel ridge regression
- Kernel ridge regression is a non-parametric estimator used in computer science
- It involves solving an infinite dimensional optimization problem in a Reproducing Kernel Hilbert Space
- For datasets with n ≤ 100,000, the problem can be solved in closed form
- For larger datasets, the EigenPro solver is used to approximate the solution via early-stopped, preconditioned-SGD on the GPU
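For reference, the closed-form solution mentioned above can be written out as follows. The notation is the standard one for kernel ridge regression and is assumed here rather than taken from the summary: K is the kernel, K_n is the n × n kernel matrix with entries K(x_i, x_j), λ ≥ 0 is the ridge parameter, and y is the vector of training targets.

```latex
\hat{f} = \arg\min_{f \in \mathcal{H}} \sum_{i=1}^{n} \bigl(f(x_i) - y_i\bigr)^2 + \lambda \|f\|_{\mathcal{H}}^2,
\qquad
\hat{f}(x) = \sum_{i=1}^{n} \alpha_i K(x_i, x),
\qquad
\alpha = (K_n + \lambda I_n)^{-1} y
```

Although the minimization is over an infinite-dimensional space, the solution reduces to a linear system of size n, which is why a direct solve is feasible for moderate n while iterative solvers are preferred for larger datasets.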
C Dataset and experimental details
- Experiments with RFMs and fully connected networks on CelebA
- Normalize all images to be on the unit sphere
- Train 2-hidden layer ReLU networks with 1024 hidden units per layer using SGD for 500 epochs with a learning rate of 0.1 and a mini-batch size of 128
- Train RFMs for 1 iteration, use a ridge regularization term of 10^-3, and average the gradient outer product of at most 20000 examples
- All RFMs use Laplace kernels as the base kernel and use a bandwidth parameter of L = 10
- Split available training data into 80% training and 20% validation for hyperparameter selection
- Report accuracy on a held out test set provided by PyTorch
- Ensure that the training set and test set are balanced by limiting the number of majority class samples to the same number of minority class samples
- Limit the total number of training and validation examples per experiment to 50000 (25000 per class)
- Train 2-hidden layer ReLU networks with 1024 hidden units per layer using SGD for 500 epochs with a learning rate of 0.1 and a mini-batch size of 100 for SVHN
- Train RFMs for 5 iterations and average the gradient outer product of at most 20000 examples
- RFMs and Laplace kernels used all have a bandwidth parameter of 10
- Compare with the NTK of a 2-hidden layer ReLU network
- For low-rank polynomials, use 1000 examples for training and 10000 examples for testing
- Train a 1-hidden-layer neural network for 1000 epochs using full-batch gradient descent with a learning rate of 0.1 and initialize the first layer with standard deviation 10^-3
- Train RFMs with no ridge term and set the base kernel function as the Laplace kernel with bandwidth 10
- Grid search over ridge regularization from the set {10, 1, 0.1, 0.01, 0} with fixed bandwidth L = 10 on large regression datasets
- Use EigenPro to train all kernel methods and RFMs on the largest dataset
- For the grokking experiment, use 553 samples: 500 examples of airplanes and 53 examples of trucks
- Train a two hidden layer fully connected ReLU network using full gradient descent with a learning rate of 0.1
- Train RFMs for three iterations with ridge regularization of 10^-3, using the Laplace kernel as the base kernel function with a bandwidth of 10
- For the tabular data benchmark from [23], compare to the metrics reported in [23]
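Several of the preprocessing steps listed above (normalizing images to the unit sphere, balancing the two classes with a per-class cap, and an 80%/20% train/validation split) can be sketched as follows. This is an illustrative sketch only: it assumes binary labels encoded as 0/1 and images stored as a NumPy array with one example per row or leading index, and the helper names are hypothetical.

```python
import numpy as np

def normalize_to_unit_sphere(X, eps=1e-12):
    """Flatten each example and rescale it to unit Euclidean norm."""
    flat = X.reshape(len(X), -1).astype(float)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    return flat / np.maximum(norms, eps)

def balance_binary(X, y, rng, max_per_class=25000):
    """Subsample so both classes (labels 0 and 1) have the same number of examples."""
    idx0, idx1 = np.flatnonzero(y == 0), np.flatnonzero(y == 1)
    k = min(len(idx0), len(idx1), max_per_class)
    keep = np.concatenate([rng.choice(idx0, k, replace=False),
                           rng.choice(idx1, k, replace=False)])
    rng.shuffle(keep)
    return X[keep], y[keep]

def train_val_split(X, y, rng, val_frac=0.2):
    """Random split into (train, validation) with val_frac held out for validation."""
    perm = rng.permutation(len(X))
    n_val = int(round(val_frac * len(X)))
    val, train = perm[:n_val], perm[n_val:]
    return (X[train], y[train]), (X[val], y[val])
```

The validation split would then be used to pick hyperparameters such as the ridge term, with the held-out test set left untouched until the final evaluation.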