Abstract

  • Investigates the one-pass stochastic gradient descent (SGD) dynamics of a two-layer neural network
  • Inputs are Gaussian and labels are generated by a target function of similar form (a teacher network)
  • Analyses the limiting dynamics through a deterministic, low-dimensional description
  • Bridges different regimes of interest: gradient flow, high-dimensional, and overparametrised
  • In the high-dimensional limit, the dynamics remain close to a low-dimensional subspace spanned by the target's principal directions

Paper Content

Introduction

  • Stochastic gradient descent was introduced as a method for stochastic approximation
  • It was applied to population risk minimization
  • Its properties have been studied for finite learning rate and input dimension
  • High-dimensional limits of SGD were studied for non-convex, single-index models
  • The SGD dynamics of two-layer neural networks were studied for synthetic data with simple target functions
  • The mean-field limit of SGD for two-layer neural networks was studied, and global convergence was proved
  • Dimension-free limits of mean-field equations were derived for low-dimensional target functions
  • Global convergence of gradient flow dynamics at finite width was proven for orthogonal input data
  • This work gives a unifying, low-dimensional description of one-pass SGD dynamics as the learning rate and hidden-layer width scale with the diverging input dimension

Setting

  • Supervised learning regression task with n independent samples
  • Fully-connected two-layer neural network with trainable parameters Θ
  • Mean-field normalization adopted (a 1/p prefactor on the network output, p being the hidden-layer width)
  • Square loss used to penalize deviations from true labels
  • Training performed via empirical risk minimization
  • Stochastic gradient descent used to optimize training parameters
  • Inputs $x^\nu$ are Gaussian; labels $y^\nu$ are drawn from a generative model
  • The teacher-student scenario provides a data model for studying generalization
  • The SGD iterates are assumed to remain in a bounded subset of $\mathbb{R}^{p\times d}$
  • Student and teacher activation functions are assumed twice differentiable and upper bounded by K
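
A minimal sketch of the training loop described above, under assumptions that the summary does not fix: tanh activations for both networks (smooth and bounded), a trainable second layer, a teacher with the same mean-field form, and a plain SGD step `W - lr * grad` (the paper's exact learning-rate scaling in p and d is not reproduced). All function names are hypothetical.

```python
import numpy as np


def network(x, W, a):
    """Mean-field two-layer net: f(x) = (1/p) * sum_i a_i * tanh(w_i . x / sqrt(d)).

    tanh is an illustrative choice of a smooth, bounded activation.
    """
    p, d = W.shape
    return a @ np.tanh(W @ x / np.sqrt(d)) / p


def draw_sample(rng, W_star, a_star, noise_std=0.0):
    """Fresh Gaussian input and its teacher label; one-pass SGD never reuses it."""
    d = W_star.shape[1]
    x = rng.standard_normal(d)
    y = network(x, W_star, a_star) + noise_std * rng.standard_normal()
    return x, y


def loss_and_grads(x, y, W, a):
    """Square loss 0.5 * (f(x) - y)^2 and its gradients w.r.t. W and a."""
    p, d = W.shape
    h = np.tanh(W @ x / np.sqrt(d))            # hidden activations, shape (p,)
    err = a @ h / p - y                         # residual f(x) - y
    grad_a = err * h / p
    # Chain rule with tanh'(z) = 1 - tanh(z)^2
    grad_W = np.outer(err * a * (1.0 - h**2) / p, x / np.sqrt(d))
    return 0.5 * err**2, grad_W, grad_a


def one_pass_sgd(rng, W, a, W_star, a_star, lr, steps, noise_std=0.0):
    """Run `steps` SGD updates, each on a brand-new sample (so n = steps)."""
    for _ in range(steps):
        x, y = draw_sample(rng, W_star, a_star, noise_std)
        _, gW, ga = loss_and_grads(x, y, W, a)
        W, a = W - lr * gW, a - lr * ga
    return W, a
```

Because every step consumes a new sample, the number of samples n equals the number of SGD steps; this is what distinguishes one-pass SGD from repeatedly cycling over a fixed training set.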

The three limit regimes and their dimensionless description

  • The non-convex optimisation problem is introduced in Sec. 2
  • Both the data dimension and the number of network parameters are large
  • Following the weights directly means tracking a non-linear, coupled stochastic process in p(d+1) dimensions on a non-convex landscape
  • Goal: derive a tractable, low-dimensional description of SGD in different regimes of practical interest
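
For a sense of scale, take the hypothetical sizes d = 1000, p = 10 and k = 2, where k denotes the number of teacher hidden units (it enters the p(k+p)-parameter description of the classical regime). The raw weight process and the reduced description then compare as:

```latex
p(d+1) = 10 \times 1001 = 10\,010
\qquad \text{versus} \qquad
p(k+p) = 10 \times (2 + 10) = 120
```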

Main concepts

  • Performance of the predictor only depends on the statistics of the student and teacher pre-activations
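  • Since the inputs are Gaussian, the student and teacher pre-activations are jointly Gaussian vectors, so their statistics reduce to second moments (overlaps)
  • The resulting equations are exact and trade a high-dimensional process for a low-dimensional one
  • The deterministic description is valid beyond the high-dimensional limit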
  • Theorem 3.1 is non-asymptotic in (d, p, γ)
  • Provides a tractable, low-dimensional description of SGD
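
To make "performance depends only on the pre-activation statistics" concrete, here is a sketch for one special case chosen because the Gaussian expectation is explicit: square activations $\sigma(z)=\sigma_*(z)=z^2$, all second-layer weights fixed to 1, and mean-field prefactors 1/p and 1/k. The overlap names Q, M, P follow common teacher-student convention and, like the function names, are assumptions of this illustration. The closed form uses $\mathbb{E}[(x^\top C x)^2] = (\operatorname{tr} C)^2 + 2\operatorname{tr}(C^2)$ for $x \sim \mathcal{N}(0, I_d)$.

```python
import numpy as np


def overlaps(W, W_star):
    """Overlap matrices Q = W W^T/d (p x p), M = W W*^T/d (p x k), P = W* W*^T/d (k x k)."""
    d = W.shape[1]
    return W @ W.T / d, W @ W_star.T / d, W_star @ W_star.T / d


def risk_from_overlaps(Q, M, P):
    """Population risk 0.5 * E[(f(x) - f*(x))^2] computed from overlaps only.

    Assumes sigma(z) = sigma*(z) = z^2, unit second-layer weights and
    1/p, 1/k prefactors, so that f(x) - f*(x) = x^T C x / d.
    """
    p, k = M.shape
    tau = np.trace(Q) / p - np.trace(P) / k                 # tr(C) / d
    tr_c2 = (np.trace(Q @ Q) / p**2                         # tr(C^2) / d^2
             - 2 * np.sum(M**2) / (p * k)
             + np.trace(P @ P) / k**2)
    return 0.5 * (tau**2 + 2 * tr_c2)


def risk_monte_carlo(W, W_star, n, rng):
    """The same risk estimated directly from n fresh Gaussian inputs."""
    d = W.shape[1]
    X = rng.standard_normal((n, d))
    f = np.mean((X @ W.T / np.sqrt(d)) ** 2, axis=1)
    f_star = np.mean((X @ W_star.T / np.sqrt(d)) ** 2, axis=1)
    return 0.5 * np.mean((f - f_star) ** 2)
```

Since the risk is a function of (Q, M, P) alone, rotating all weights by the same orthogonal matrix leaves it unchanged, and the Monte Carlo estimate should agree with the closed form up to sampling error.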

The classical regime

  • In the classical regime (learning rate γ → 0 at fixed d and p), SGD converges to gradient flow on the population risk
  • The effective SGD noise is subleading in this limit
  • The resulting ODE of dimension p(d+1) is easy to implement and solve
  • Equivalently, the ODE can be written in terms of p(k+p) overlap parameters
  • ODE trajectories closely follow simulated SGD trajectories
  • The reduced description is independent of d, a consequence of the Gaussian-input assumption
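
The closure of gradient flow on a finite set of overlaps can be checked in the same illustrative special case (square activations, second layer frozen at 1, so the weight process has p·d rather than p(d+1) coordinates; these are assumptions, not the paper's general setup). The sketch below derives the flow for $(Q, M)$ by hand and verifies it against the flow in weight space: only the p(p+k) entries of Q and M evolve, P stays fixed, and d appears only as an overall 1/d time scale.

```python
import numpy as np


def population_risk(W, W_star):
    """0.5 * E[(f(x) - f*(x))^2] for x ~ N(0, I_d), square activations,
    unit second-layer weights and 1/p, 1/k prefactors (illustrative assumptions)."""
    p, d = W.shape
    k = W_star.shape[0]
    C = W.T @ W / p - W_star.T @ W_star / k     # f - f* = x^T C x / d
    return (np.trace(C) ** 2 + 2 * np.trace(C @ C)) / (2 * d**2)


def grad_population_risk(W, W_star):
    """Exact gradient of population_risk with respect to W."""
    p, d = W.shape
    k = W_star.shape[0]
    C = W.T @ W / p - W_star.T @ W_star / k
    G = np.trace(C) * np.eye(d) + 2 * C
    return 2.0 / (p * d**2) * W @ G


def overlap_ode(Q, M, P, d):
    """Gradient flow dW/dt = -grad R rewritten for Q = W W^T/d and M = W W*^T/d."""
    p, k = M.shape
    tau = np.trace(Q) / p - np.trace(P) / k
    dQ = -4.0 / (p * d) * (tau * Q + 2 * (Q @ Q / p - M @ M.T / k))
    dM = -2.0 / (p * d) * (tau * M + 2 * (Q @ M / p - M @ P / k))
    return dQ, dM


def integrate_overlaps(Q, M, P, d, dt, steps):
    """Forward-Euler integration of the overlap ODE (P is constant)."""
    for _ in range(steps):
        dQ, dM = overlap_ode(Q, M, P, d)
        Q, M = Q + dt * dQ, M + dt * dM
    return Q, M
```

The consistency check is the chain rule $\dot Q = (\dot W W^\top + W \dot W^\top)/d$ and $\dot M = \dot W W_*^\top / d$ evaluated at $\dot W = -\nabla R$; forward Euler is used only for simplicity, and any ODE solver would do.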

The high-dimensional regime

  • Modern machine learning practice involves high-dimensional data.
  • In high dimensions, the SGD noise can contribute a finite term to the risk at large times.
  • The classical and high-dimensional regimes differ substantially.
  • SGD with a fixed learning rate converges to a stationary distribution with non-vanishing variance.
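
The finite large-time risk can be seen in a much simpler toy than the paper's two-layer network, swapped in here purely for illustration: one-pass SGD on noisy linear regression with inputs $x \sim \mathcal{N}(0, I_d/d)$ and label noise of standard deviation σ. The exact second-moment recursion for $v = w - w_*$ has the fixed point $\lVert v\rVert^2/d = \gamma\sigma^2 / (2 - \gamma(d+2)/d)$, so at fixed γ the excess risk plateaus at a value of order γσ² that does not vanish as d grows. Function names are hypothetical.

```python
import numpy as np


def sgd_plateau(d, lr, noise_std, burn_in, window, seed=0):
    """Time-averaged excess risk of one-pass SGD on noisy linear regression.

    Toy stand-in, not the paper's model: y = w_star . x + noise with
    x ~ N(0, I_d / d) and a fresh sample at every step. The excess risk
    E[0.5 (w.x - y)^2] - 0.5 * noise_std^2 equals ||w - w_star||^2 / (2 d).
    """
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal(d)
    w = np.zeros(d)
    total = 0.0
    for t in range(burn_in + window):
        x = rng.standard_normal(d) / np.sqrt(d)
        y = w_star @ x + noise_std * rng.standard_normal()
        w = w - lr * (w @ x - y) * x            # gradient of 0.5 * (w.x - y)^2
        if t >= burn_in:
            total += np.sum((w - w_star) ** 2) / (2 * d)
    return total / window


def predicted_plateau(d, lr, noise_std):
    """Fixed point of the exact second-moment recursion (requires lr < 2d/(d+2))."""
    return 0.5 * lr * noise_std**2 / (2.0 - lr * (d + 2) / d)
```

For instance, comparing `sgd_plateau(100, lr, 0.5, 10_000, 20_000)` for lr = 0.2 and lr = 0.8 against `predicted_plateau` shows a plateau that grows with the learning rate; with `noise_std=0` the plateau disappears.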

The overparametrised regime

  • The SGD dynamics depend on quantities that scale with the hidden-layer width
  • This requires dealing with wide, overparametrised networks
  • The key idea is to define an empirical density over the weights
  • As the width grows, this empirical density converges to the solution of a partial differential equation (PDE)
  • Symmetries of the problem are exploited to derive an approximation of constant dimension
  • The dynamics are decomposed into an orthogonal projection and a covariance matrix
  • Mean-field ODEs read
  • Risk is computed as
  • Theorem 3.2 states that the approximation is consistent
  • Assumption A3 enforces orthogonal invariance
  • The high-dimensional limit of the mean-field equations does not depend on auxiliary random variables
  • Lemma 3.4 holds for any (M, q)
  • The mean-field ODEs are approximated by their high-dimensional equivalent
  • Theorem 3.5 states that the approximation is consistent
  • An example with square activation is worked out
  • Numerical experiments are shown in Figure 3
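
The empirical-density idea can be illustrated with a sketch of our own, assuming square activation, unit second-layer weights and first-layer weights drawn i.i.d. from $\mathcal{N}(0, I_d)$ (a hypothetical choice of density). The mean-field output is an average over the empirical density, so it is insensitive to the ordering of the neurons and concentrates around its population value $\lVert x\rVert^2/d$ as the width p grows.

```python
import numpy as np


def mean_field_output(x, W, a):
    """f(x) = (1/p) * sum_i a_i * sigma(w_i . x / sqrt(d)) with sigma(z) = z^2.

    This is the integral of a * sigma(w . x / sqrt(d)) against the empirical
    density rho_p = (1/p) * sum_i delta_(a_i, w_i), so f depends on the
    neurons only through rho_p (in particular, not on their ordering).
    """
    d = W.shape[1]
    return np.mean(a * (W @ x / np.sqrt(d)) ** 2)


def sample_neurons(rng, p, d):
    """Draw p neurons i.i.d. from N(0, I_d) with a_i = 1 (hypothetical density)."""
    return rng.standard_normal((p, d)), np.ones(p)


rng = np.random.default_rng(0)
d = 10
x = rng.standard_normal(d)
limit = x @ x / d    # E_{w ~ N(0, I_d)}[(w . x / sqrt(d))^2] = ||x||^2 / d
for p in (10, 1_000, 100_000):
    W, a = sample_neurons(rng, p, d)
    print(p, mean_field_output(x, W, a), limit)
```

The printed outputs approach `limit` as p increases, which is the law-of-large-numbers mechanism behind replacing the p individual neurons by a density.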

Conclusion

  • Our work provides a comprehensive analysis of the one-pass SGD dynamics of two-layer neural networks
  • Bridges different regimes of interest
  • Offers a unifying picture of the limiting SGD dynamics
  • Sheds light on the behavior of neural networks trained on synthetic data
  • Provides a useful tool for further investigations
  • A term from $\Psi^{(\mathrm{Var})}_{ij}(\Omega)$ can be kept to obtain better agreement between the ODEs and simulations in the presence of label noise
  • Fluctuations visible in the final plateaus result from the stochastic process in Equation (9)
  • The equations are matched to show that Theorem C.5 implies Theorem 3.2
  • High-dimensional mean-field approximation provides several bounds on expectations of functions of Gaussians
  • Lemma D.5 shows there exists a constant c ≥ 0 such that for any choice of λ
  • Differential equations are derived for the dynamics when both σ and σ* are the square function