Abstract
- Investigates one-pass stochastic gradient descent (SGD) dynamics of a two-layer neural network
- Labels are generated by a similar target function (the teacher)
- Analyses the limiting dynamics via a deterministic, low-dimensional description
- Bridges different regimes of interest, such as gradient-flow, high-dimensional, and overparameterised
- In the high-dimensional limit, the dynamics remain close to a low-dimensional subspace spanned by the target principal directions
Paper Content
Introduction
- Stochastic gradient descent was introduced as a method for stochastic approximation
- It was applied to population risk minimization
- Its properties have been studied for finite learning rate and input dimension
- High-dimensional limits of SGD were studied for non-convex, single-index models
- SGD dynamics of two-layer neural networks was studied for synthetic data with simple target functions
- The mean-field limit of SGD dynamics of two-layer neural networks was studied, and global convergence was proven
- Dimension-free limits of mean-field equations were derived for low-dimensional target functions
- Global convergence of gradient flow dynamics at finite width was proven for orthogonal input data
- A unifying low-dimensional description of one-pass SGD dynamics is given as the learning rate and hidden-layer width scale with the diverging input dimension
Setting
- Supervised learning regression task with n independent samples
- Fully-connected two-layer neural network with trainable parameters Θ
- Mean-field normalization adopted
- Square loss used to penalize deviations from true labels
- Training performed via empirical risk minimization
- Stochastic gradient descent used to optimize training parameters
- Inputs x^ν are Gaussian; outputs y^ν are drawn from a generative model
- Teacher-student scenario provides data model for studying generalization
- SGD dynamics remain in a bounded subset of R^(p×d)
- Student and teacher activation functions twice differentiable and upper bounded by K
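The setting above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the paper's code: the tanh activation, the sizes d, p, k, the learning rate, the fixed (untrained) second layer, and the names `forward`, `one_pass_sgd` and `population_risk` are all choices made here for illustration.

```python
import numpy as np

def forward(W, X, act=np.tanh):
    """Mean-field two-layer network: the average of act(w_j . x / sqrt(d)) over hidden units."""
    d = W.shape[1]
    return act(X @ W.T / np.sqrt(d)).mean(axis=-1)

def one_pass_sgd(W, W_star, gamma, n_steps, rng, noise_std=0.0):
    """Run one-pass SGD on the square loss; every step draws a fresh sample."""
    p, d = W.shape
    W = W.copy()
    for _ in range(n_steps):
        x = rng.standard_normal(d)                 # Gaussian input, used exactly once
        y = forward(W_star, x) + noise_std * rng.standard_normal()  # teacher label (+ optional noise)
        lam = W @ x / np.sqrt(d)                   # student pre-activations, shape (p,)
        err = np.tanh(lam).mean() - y              # prediction error on this sample
        dact = 1.0 - np.tanh(lam) ** 2             # derivative of tanh
        W -= gamma * err * np.outer(dact, x) / (p * np.sqrt(d))  # gradient step on 0.5 * err**2
    return W

def population_risk(W, W_star, rng, n=20_000):
    """Monte Carlo estimate of 0.5 * E[(f(x) - f*(x))^2] on fresh Gaussian inputs."""
    X = rng.standard_normal((n, W.shape[1]))
    return 0.5 * np.mean((forward(W, X) - forward(W_star, X)) ** 2)

rng = np.random.default_rng(0)
d, p, k = 20, 4, 2                                 # illustrative sizes (assumed)
W_star = rng.standard_normal((k, d))               # teacher weights
W0 = rng.standard_normal((p, d))                   # student initialisation
W = one_pass_sgd(W0, W_star, gamma=1.0, n_steps=3000, rng=rng)
risk_before = population_risk(W0, W_star, rng)
risk_after = population_risk(W, W_star, rng)
print(risk_before, risk_after)
```

Because each sample is drawn fresh and discarded, the number of SGD steps equals the number of samples seen, which is what "one-pass" refers to.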
The three limit regimes and their dimensionless description
- Non-convex optimisation problem is introduced in Sec. 2
- Data dimension and number of parameters in neural networks are large
- Studying the evolution of the weights means tracking a non-linear, coupled, non-convex stochastic process of dimension p(d+1)
- Derive a tractable, low-dimensional description for SGD in different regimes of practical interest
Main concepts
- Performance of the predictor only depends on the statistics of the student and teacher pre-activations
- Pre-activations are jointly Gaussian vectors
- Equations are exact and allow for trading a high-dimensional process for a low-dimensional process
- Deterministic description is valid beyond the high-dimensional limit
- Theorem 3.1 is non-asymptotic in (d, p, γ)
- Provides a tractable, low-dimensional description of SGD
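The claim that the risk depends only on the statistics of the pre-activations can be checked numerically. The sketch below assumes a tanh activation for both student and teacher, a risk defined as half the mean squared error, and an overlap matrix named `Omega`; these names and choices are illustrative assumptions. It estimates the risk once by sampling d-dimensional inputs and once by sampling the (p+k)-dimensional pre-activations directly from a Gaussian with covariance `Omega`.

```python
import numpy as np

def overlaps(W, W_star):
    """Covariance of the stacked student and teacher pre-activations for standard Gaussian inputs."""
    d = W.shape[1]
    A = np.vstack([W, W_star])                     # shape (p + k, d)
    return A @ A.T / d                             # shape (p + k, p + k)

def risk_from_inputs(W, W_star, act, rng, n=100_000):
    """Monte Carlo risk using d-dimensional Gaussian inputs."""
    d = W.shape[1]
    X = rng.standard_normal((n, d))
    f = act(X @ W.T / np.sqrt(d)).mean(axis=1)
    f_star = act(X @ W_star.T / np.sqrt(d)).mean(axis=1)
    return 0.5 * np.mean((f - f_star) ** 2)

def risk_from_overlaps(Omega, p, act, rng, n=100_000):
    """Monte Carlo risk using only the overlap matrix: sample the pre-activations directly."""
    Z = rng.multivariate_normal(np.zeros(len(Omega)), Omega, size=n)
    f = act(Z[:, :p]).mean(axis=1)                 # first p coordinates: student
    f_star = act(Z[:, p:]).mean(axis=1)            # remaining k coordinates: teacher
    return 0.5 * np.mean((f - f_star) ** 2)

rng = np.random.default_rng(1)
d, p, k = 20, 4, 2                                 # illustrative sizes (assumed)
W = rng.standard_normal((p, d))
W_star = rng.standard_normal((k, d))
Omega = overlaps(W, W_star)
r_x = risk_from_inputs(W, W_star, np.tanh, rng)
r_omega = risk_from_overlaps(Omega, p, np.tanh, rng)
print(r_x, r_omega)
```

The two estimates should agree up to Monte Carlo error, even though the second never touches the d-dimensional weights. This is the sense in which a high-dimensional process can be traded for a low-dimensional one.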
The classical regime
- Classical regime of SGD converges to gradient flow on population risk
- Effective SGD noise is subleading in this limit
- ODE of dimension p(d+1) can be easily implemented and solved
- Alternative description of ODE in p(k+p) parameters
- ODE trajectories follow simulated ones
- Dimension independence property due to Gaussianity assumption
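A heuristic way to see why the noise is subleading is to split each SGD step into a drift and a zero-mean fluctuation. The notation is generic and is an assumption of this sketch rather than the paper's: W_ν for the weights after ν steps, γ for the learning rate, R for the population risk, ℓ_ν for the loss on the ν-th fresh sample.

```latex
W_{\nu+1} = W_\nu - \gamma\,\nabla R(W_\nu) - \gamma\,\varepsilon_\nu,
\qquad
\varepsilon_\nu = \nabla \ell_\nu(W_\nu) - \nabla R(W_\nu),
\qquad
\mathbb{E}\left[\varepsilon_\nu \mid W_\nu\right] = 0 .

% Over T/\gamma steps, at fixed d and p and assuming bounded noise variance:
\underbrace{\gamma \sum_{\nu < T/\gamma} \nabla R(W_\nu)}_{O(T)}
\qquad\text{versus}\qquad
\underbrace{\gamma \sum_{\nu < T/\gamma} \varepsilon_\nu}_{O(\sqrt{\gamma T})}
\qquad\Longrightarrow\qquad
\frac{\mathrm{d}W}{\mathrm{d}t} = -\nabla R(W), \quad t = \nu\gamma, \quad \gamma \to 0 .
```

The drift accumulates linearly in the number of steps while the uncorrelated fluctuations accumulate only as its square root, so the noise vanishes as γ → 0. The paper's exact time rescaling may differ from the one used in this sketch.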
The high-dimensional regime
- Modern machine learning practice involves high-dimensional data.
- SGD dynamics in high-dimension can yield a finite risk contribution at large times.
- Classical and high-dimensional regimes have major differences.
- SGD with a fixed learning rate converges to a stationary distribution with non-zero variance rather than to a single point
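A one-dimensional toy problem illustrates why a fixed learning rate leaves a residual variance. This is not the paper's model: it assumes the quadratic loss 0.5 * x**2 with additive Gaussian gradient noise, for which the variance of the iterate obeys an exact recursion. The function names are hypothetical.

```python
def stationary_variance(gamma, noise_var, n_steps=10_000):
    """Iterate the exact variance recursion of x <- x - gamma * (x + xi), xi ~ N(0, noise_var)."""
    v = 0.0                                        # start exactly at the minimiser x = 0
    for _ in range(n_steps):
        v = (1.0 - gamma) ** 2 * v + gamma ** 2 * noise_var
    return v

def closed_form(gamma, noise_var):
    """Fixed point of the recursion above, valid for 0 < gamma < 2."""
    return gamma * noise_var / (2.0 - gamma)

for gamma in (0.01, 0.1, 0.5):
    print(gamma, stationary_variance(gamma, 1.0), closed_form(gamma, 1.0))
```

In this toy, the stationary variance is proportional to the learning rate for small γ, so it disappears in the gradient-flow limit and survives whenever γ is held fixed.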
The overparametrised regime
- SGD dynamics rely on quantities which scale with the hidden-layer width
- Need to deal with wide, overparametrised networks
- Key idea is to define an empirical density over the weights
- In the infinite-width limit, the density evolves according to a partial differential equation
- Exploit symmetries of the problem to derive an approximation of constant dimension
- Decompose the dynamics into orthogonal projection and covariance matrix
- The mean-field ODEs and the corresponding risk are written out explicitly in the paper
- Theorem 3.2 states that the approximation is consistent
- Assumption A3 enforces orthogonal invariance
- High-dimensional limit of mean-field does not depend on auxiliary random variables
- Lemma 3.4 holds for any (M, q)
- Mean-field ODEs are approximated by high-dimensional equivalent
- Theorem 3.5 states that the approximation is consistent
- Square activation example is given
- Numerical experiments are shown in Figure 3
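For the square activation, Gaussian expectations have closed forms, which is what makes the example tractable. The sketch below is an illustration under assumptions, not the paper's derivation: it takes the risk to be half the mean squared error, both networks to use mean-field normalization with σ(z) = z², and the overlap blocks to be named Q (student-student), M (student-teacher) and P (teacher-teacher). It uses the Gaussian identity E[a² b²] = Var(a) Var(b) + 2 Cov(a, b)² for zero-mean jointly Gaussian a and b.

```python
import numpy as np

def square_risk(Q, M, P):
    """Closed-form 0.5 * E[(f - f*)^2] when both activations are z -> z**2."""
    def mean_sq_product(C, var_a, var_b):
        # Average over all pairs of E[a^2 b^2] = Var(a) Var(b) + 2 Cov(a, b)^2
        return np.mean(np.outer(var_a, var_b) + 2.0 * C ** 2)
    q_diag, p_diag = np.diag(Q), np.diag(P)
    student = mean_sq_product(Q, q_diag, q_diag)   # E[f^2]
    cross = mean_sq_product(M, q_diag, p_diag)     # E[f f*]
    teacher = mean_sq_product(P, p_diag, p_diag)   # E[f*^2]
    return 0.5 * (student - 2.0 * cross + teacher)

rng = np.random.default_rng(2)
d, p, k = 10, 4, 2                                 # illustrative sizes (assumed)
W = rng.standard_normal((p, d))
W_star = rng.standard_normal((k, d))
Q, M, P = W @ W.T / d, W @ W_star.T / d, W_star @ W_star.T / d
risk_exact = square_risk(Q, M, P)

# Monte Carlo cross-check on Gaussian inputs
X = rng.standard_normal((200_000, d))
f = ((X @ W.T / np.sqrt(d)) ** 2).mean(axis=1)
f_star = ((X @ W_star.T / np.sqrt(d)) ** 2).mean(axis=1)
risk_mc = 0.5 * np.mean((f - f_star) ** 2)
print(risk_exact, risk_mc)
```

Since the risk is a polynomial in the overlaps here, its gradient with respect to Q and M is also explicit.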
Conclusion
- Provides a comprehensive analysis of the one-pass SGD dynamics of two-layer neural networks
- Bridges different regimes of interest
- Offers a unifying picture of the limiting SGD dynamics
- Sheds light on the behavior of neural networks trained on synthetic data
- Provides a useful tool for further investigations
- A term from Ψ^(Var)_ij(Ω) can be kept to get better agreement between the ODEs and simulations in the presence of label noise
- Fluctuations visible in the final plateaus result from the stochastic process in Equation (9)
- Matching the equations shows that Theorem C.5 implies Theorem 3.2
- High-dimensional mean-field approximation provides several bounds on expectations of functions of Gaussians
- Lemma D.5 shows there exists a constant c ≥ 0 such that for any choice of λ
- Derives the differential equations for the dynamics when both σ and σ* are the square function