Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Sharpness of minima can correlate with generalization in deep networks.
- Reparametrization-invariant sharpness definitions have been proposed.
- Open question: does reparametrization-invariant sharpness really capture generalization in modern practical settings?
- Observed that sharpness does not correlate well with generalization.
- Negative correlation of sharpness with out-of-distribution error.
- Right sharpness measure is highly data-dependent.
Paper Content
Introduction
- Sharpness of training objective has intuitive appeal and appears in generalization bounds
- Sharpness can correlate well with generalization in deep learning setups
- Training methods that minimize sharpness have had empirical success
- Many works suggest flatter minima should generalize better
- Standard sharpness definitions do not correlate well with generalization
- Different sharpness definitions can capture different trends
- Right sharpness measure is highly data-dependent
Related work
- Sharpness of minima is correlated with performance degradation of large-batch SGD
- Different generalization measures may explain generalization for deep networks
- Strong correlation between sharpness and generalization on a large set of CIFAR-10/SVHN models
- Reparametrization-invariant sharpness definitions exist
- Flat minima can be beneficial for generalization
- Different criteria optimize for more robust minima
- Maximum eigenvalue and trace of the Hessian are the focus of many works
- Focus on sharpness-related metrics to better understand generalization for deep networks
Background on sharpness
- Loss on a set of training points S is denoted $L_S(w)$
- Average-case and worst-case m-sharpness are defined
- Worst-case sharpness is correlated with generalization
- Adaptive sharpness is invariant under multiplicative reparametrizations
- Analytical expressions of standard sharpness for radius ρ → 0 depend on first- or second-order terms
- Strong hypothesis: sharpness is highly correlated with generalization
- Weak hypothesis: models with lowest sharpness generalize well
- Kendall rank correlation coefficient is used to detect correlation
- Adaptive sharpness is invariant for ResNets and ViTs
- Scale-sensitivity of classification losses is discussed
- Normalization of logits is proposed to fix scaling issue
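The invariance of adaptive sharpness under multiplicative reparametrizations can be checked numerically on a toy problem. The following is a minimal sketch, not the paper's implementation: the two-parameter loss, the function name `avg_sharpness`, the radius, and the sample count are illustrative assumptions, and the sampling of m training points is left out. It estimates average-case sharpness with Gaussian perturbations, once with a fixed scale (standard) and once with each perturbation scaled by |w| (elementwise adaptive).

```python
import random

def loss(w):
    # Hypothetical toy loss (an assumption): prediction w[0] * w[1], target 1.0.
    return (w[0] * w[1] - 1.0) ** 2

def avg_sharpness(w, rho, adaptive, n_samples=2000, seed=0):
    # Monte Carlo estimate of E[loss(w + delta) - loss(w)].
    # Standard: delta_i = rho * eps_i; adaptive: delta_i = rho * |w_i| * eps_i.
    rng = random.Random(seed)
    base = loss(w)
    total = 0.0
    for _ in range(n_samples):
        perturbed = []
        for wi in w:
            scale = abs(wi) if adaptive else 1.0
            perturbed.append(wi + rho * scale * rng.gauss(0.0, 1.0))
        total += loss(perturbed) - base
    return total / n_samples

w = [2.0, 0.5]                         # zero loss: 2.0 * 0.5 == 1.0
alpha = 10.0
w_rep = [alpha * w[0], w[1] / alpha]   # same product, different parametrization

std_a = avg_sharpness(w, 0.1, adaptive=False)
std_b = avg_sharpness(w_rep, 0.1, adaptive=False)
ada_a = avg_sharpness(w, 0.1, adaptive=True)
ada_b = avg_sharpness(w_rep, 0.1, adaptive=True)

print("standard:", std_a, std_b)   # changes with the parametrization
print("adaptive:", ada_a, ada_b)   # agrees up to floating-point rounding
```

Both calls use the same seed, so the two parametrizations see identical noise draws and the comparison isolates the effect of the rescaling.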
How to compute worst-case sharpness efficiently?
- Estimation of worst-case sharpness involves solving a constrained maximization problem.
- Auto-PGD is a hyperparameter-free method designed to accurately estimate adversarial robustness.
- Gradient steps are proportional to |w| to better take into account the geometry induced by elementwise adaptive sharpness.
- 20 steps are typically sufficient to converge with APGD.
- Correlation of sharpness with test error is either close to zero or negative.
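The constrained maximization behind worst-case sharpness can be sketched on a toy problem. This is a simplified illustration under stated assumptions, not Auto-PGD: it uses plain projected sign-gradient ascent with a fixed step, a hypothetical two-parameter loss with an analytic gradient, and no sampling of training points. What it keeps from the description above is the elementwise constraint |δ_i| ≤ ρ|w_i| and steps proportional to |w|.

```python
def loss(w):
    # Hypothetical toy loss (an assumption): prediction w[0] * w[1], target 1.0.
    return (w[0] * w[1] - 1.0) ** 2

def grad(w):
    # Analytic gradient of the toy loss.
    residual = 2.0 * (w[0] * w[1] - 1.0)
    return [residual * w[1], residual * w[0]]

def sign(x):
    return (x > 0) - (x < 0)

def worst_case_adaptive_sharpness(w, rho, steps=20):
    # Maximize loss(w + delta) - loss(w) subject to |delta_i| <= rho * |w_i|.
    bound = [rho * abs(wi) for wi in w]
    step = [b / 4.0 for b in bound]      # step size proportional to |w_i|
    delta = [0.0] * len(w)
    base = loss(w)
    best = 0.0
    for _ in range(steps):
        g = grad([wi + di for wi, di in zip(w, delta)])
        # Ascent step, then projection (clipping) back onto the feasible box.
        delta = [max(-b, min(b, d + s * sign(gi)))
                 for d, s, gi, b in zip(delta, step, g, bound)]
        best = max(best, loss([wi + di for wi, di in zip(w, delta)]) - base)
    return best

# A point close to, but not exactly at, a minimum, so the gradient is nonzero.
w = [2.0, 0.6]
sharpness = worst_case_adaptive_sharpness(w, rho=0.1)

# Rescaled parameters with the same product w[0] * w[1].
w_rep = [20.0, 0.06]
sharpness_rep = worst_case_adaptive_sharpness(w_rep, rho=0.1)
print(sharpness, sharpness_rep)
```

Starting from δ = 0 only works here because the gradient at `w` is nonzero; at an exact minimum a random start inside the feasible box would be needed, and a non-convex loss could leave this simple scheme at a local maximum.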
Sharpness vs. generalization: modern setup
- Current understanding of the relationship between sharpness and generalization is based on experiments with non-residual convolutional networks and small datasets
- Investigated both in-distribution and out-of-distribution generalization
- Focused on worst-case $\ell_\infty$ adaptive sharpness with low m (256)
- Experiments on ImageNet-1k from scratch and fine-tuning from CLIP and BERT
- Neither sharpness measure effectively distinguishes model performance
- Negative correlation between sharpness and test error on ID and OOD data
- Sharpness does not capture ranking by test error
- Weak correlation between sharpness and test error on MNLI and HANS datasets
Exploring the role of sharpness in a controlled setup
- Three potential explanations for why sharpness does not correlate well with generalization
- Trained 200 ResNet-18 models and 200 ViTs on CIFAR-10
- Evaluate sharpness only for models that reach at most 1% training error
- Use SGD with momentum and linearly decreasing learning rates
- 12 different sharpness definitions tested
- Strong correlation between sharpness and test error within each subgroup
- No consistent observation that models with lowest sharpness generalize best
- Negative correlation between sharpness and learning rate for ResNets and ViTs
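A rank correlation can be strong within each subgroup and still weak or negative once the subgroups are pooled. The sketch below illustrates this with the Kendall rank correlation (the tau-a variant, which assumes no ties). All numbers are made up for illustration and are not results from the paper; the subgroup names and the helper `kendall_tau` are hypothetical.

```python
from itertools import combinations

def kendall_tau(xs, ys):
    # Kendall tau-a: (concordant - discordant) / number of pairs.
    # Tied pairs count as neither concordant nor discordant.
    concordant = discordant = 0
    for (x1, y1), (x2, y2) in combinations(zip(xs, ys), 2):
        s = (x1 - x2) * (y1 - y2)
        if s > 0:
            concordant += 1
        elif s < 0:
            discordant += 1
    n_pairs = len(xs) * (len(xs) - 1) // 2
    return (concordant - discordant) / n_pairs

# Hypothetical (sharpness, test error) values for two subgroups of models.
group_a_sharpness = [5.0, 6.0, 7.0, 8.0]
group_a_error = [4.0, 4.2, 4.5, 4.9]
group_b_sharpness = [1.0, 2.0, 3.0, 4.0]
group_b_error = [8.0, 8.3, 8.5, 8.9]

tau_a = kendall_tau(group_a_sharpness, group_a_error)
tau_b = kendall_tau(group_b_sharpness, group_b_error)
tau_pooled = kendall_tau(group_a_sharpness + group_b_sharpness,
                         group_a_error + group_b_error)

print(tau_a, tau_b, tau_pooled)
```

In this made-up data every within-group pair is concordant, while every cross-group pair is discordant because group A is both sharper and more accurate than group B, so the pooled coefficient turns negative.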
Is sharpness the right quantity in the first place? Insights from simple models
- Diagonal linear networks are a simple model class that is linear in the input
- They are defined by two weight vectors, u and v, whose elementwise product gives the linear predictor β
- We consider an overparametrized sparse regression problem
- Local sharpness has a closed-form expression
- Adaptive sharpness is invariant under α-reparametrization
- Adaptive sharpness captures different properties of β
- Different sharpness definitions capture different trends
- Correlation between sharpness and generalization is not perfect
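The effect of an α-reparametrization on a diagonal linear network can be seen in a few lines. This is a sketch under assumptions: it takes β as the elementwise product of u and v, uses a scalar α > 0 (u is multiplied by α and v divided by it), and uses the per-coordinate quantity max_i (v_i² + u_i²) that this summary lists for the maximum Hessian eigenvalue as a stand-in for non-adaptive sharpness. The weight values and function names are hypothetical.

```python
def predictor(u, v):
    # Linear predictor beta as the elementwise product of u and v.
    return [ui * vi for ui, vi in zip(u, v)]

def max_eig_expression(u, v):
    # max over coordinates of v_i^2 + u_i^2 (assumed sharpness proxy).
    return max(vi ** 2 + ui ** 2 for ui, vi in zip(u, v))

def reparametrize(u, v, alpha):
    # Scalar alpha-reparametrization (an assumption): u -> alpha * u, v -> v / alpha.
    return [alpha * ui for ui in u], [vi / alpha for vi in v]

u = [1.0, 2.0, 0.5]
v = [3.0, 0.5, 0.0]
u_rep, v_rep = reparametrize(u, v, alpha=4.0)

beta = predictor(u, v)
beta_rep = predictor(u_rep, v_rep)
before = max_eig_expression(u, v)
after = max_eig_expression(u_rep, v_rep)

print(beta, beta_rep)   # same predictor
print(before, after)    # the proxy changes with the parametrization
```

The predictor, and therefore the test error, is unchanged by the rescaling, while the non-adaptive quantity grows; this is the motivation for definitions that are invariant under such reparametrizations.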
Conclusions
- Reparametrization-invariant sharpness is not a good indicator of generalization in the modern setting.
- Correlation between sharpness and generalization is positive in some restricted settings.
- Correlation is much lower for vision transformers.
- Flat minima do not always generalize better.
- Maximum eigenvalue of the Hessian is not always predictive of generalization.
- Edge of stability regime is investigated in many subsequent works.
- Definitions of adaptive sharpness measures include average-case and worst-case m-sharpness.
- For diagonal linear networks, the maximum eigenvalue is $\max_{1 \le i \le d} v_i^2 + u_i^2$.
- For ImageNet-1k models trained from scratch, sharpness variants are not predictive of performance on ImageNet or OOD datasets.
- For ImageNet-21k models, sharpness variants are not predictive of performance on ImageNet or OOD datasets.
- For combined analysis of ImageNet-1k and ImageNet-21k models, better-generalizing models have higher worst-case sharpness and roughly equal or higher logit-normalized average-case adaptive sharpness.
- For CLIP models fine-tuned on ImageNet, sharpness variants are not predictive of performance on ImageNet or ImageNet-V2.
- For BERT models fine-tuned on MNLI, sharpness variants are not predictive of generalization performance.
- For CIFAR-10 models, none of the considered sharpness definitions or radii correlate positively with generalization.