Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
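- Revisits the problem of learning a single neuron with ReLU activation under Gaussian input with square loss.
- Focuses on the over-parameterization setting, where the student network has $n\ge 2$ neurons.
- Proves global convergence of randomly initialized gradient descent at an $O\left(T^{-3}\right)$ rate. This is the first global convergence result for this problem beyond the exact-parameterization setting ($n = 1$), in which gradient descent enjoys an $\exp(-\Omega(T))$ rate.
- Presents an $\Omega\left(T^{-3}\right)$ lower bound for randomly initialized gradient flow in the over-parameterization setting.
- Together, the two bounds characterize the convergence rate exactly and show that over-parameterization can exponentially slow down convergence ($T^{-3}$ versus $\exp(-\Omega(T))$).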
Paper Content
Introduction
- Gradient descent (GD) is used to train deep neural networks
- Over-parameterization plays a key role in successful training
- Over-parameterization can slow down convergence of GD
- Two-layer ReLU networks with $n$ neurons and input dimension $d$ are studied
- A student network is trained to learn a ground-truth teacher network
- The square loss is considered
- The teacher network consists of a single neuron
- The student network is randomly initialized from a Gaussian distribution
- Empirically, over-parameterization exponentially slows down convergence
- Theorem 1 shows global convergence of randomly initialized GD at an $O(T^{-3})$ rate
- Theorem 2 provides a matching $\Omega(T^{-3})$ lower bound for randomly initialized gradient flow
- Together, these results show that over-parameterization can slow down gradient-based methods
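The teacher-student setting above can be made concrete with a small Monte Carlo sketch. It assumes that the student outputs the plain sum $\sum_i \sigma(w_i^\top x)$, with all second-layer weights fixed to 1, and that the square loss carries a factor $\frac12$. Both are assumptions of this illustration. The function names (`mc_loss` and so on) are hypothetical. Because ReLU is positively homogeneous, a single student equal to $v$ and two students each equal to $v/2$ both reach zero loss. This is one way to see that the over-parameterized student has many global minima.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def student_output(W, X):
    """Student network: sum of n ReLU neurons (rows of W); output weights assumed fixed to 1."""
    return relu(X @ W.T).sum(axis=1)


def teacher_output(v, X):
    """Teacher network: a single ReLU neuron with weight vector v."""
    return relu(X @ v)


def mc_loss(W, v, num_samples=200_000, seed=0):
    """Monte Carlo estimate of L(w) = 1/2 E_x[(student(x) - teacher(x))^2], x ~ N(0, I_d)."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((num_samples, v.shape[0]))
    residual = student_output(W, X) - teacher_output(v, X)
    return 0.5 * np.mean(residual ** 2)
```

With $v = e_1$ in $d = 5$, both `mc_loss(v[None, :], v)` and `mc_loss(np.stack([v / 2, v / 2]), v)` evaluate to 0. An all-zero student gives approximately $\frac12\mathbb{E}[\sigma(x_1)^2] = 0.25$.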
Related works
- The problem of learning a single neuron is well understood
- Classical algorithms for single-index models can solve it
- GD can converge when learning a single neuron
- GD can converge when learning a convolutional filter
- The over-parameterization setting has also been studied
- The optimization landscape has been studied
- GD can converge to global minimum with tensor initialization
- Over-parameterization can eliminate certain types of spurious local minima
- Neural tangent kernel connects training of ultra-wide neural networks with kernel methods
- Mean-field analysis studies training of infinite-width neural networks
- GD can learn two-layer networks better than kernel methods
- GD can converge globally
Technical overview
- The convergence analysis is divided into three phases
- $\theta_i$ (the angle between student neuron $w_i$ and the teacher $v$) represents the tangential difference between teacher and students, while $H$ represents the radial difference
- When the initialization is small enough, in Phase 1 $\theta_i$ decreases while $\|w_i\|$ remains small
- In Phase 2, $\theta_i$ remains bounded while $H$ decreases exponentially
- In Phase 3, two properties are established: a gradient lower bound and a regularity condition on the student neurons
- The optimization landscape is different and harder to analyze when the network is over-parameterized
- The loss function is not smooth when student neurons are close to 0
- GD implicitly regularizes the student neurons and keeps them away from the non-smooth region near 0
- The gradient lower bound is improved from $\Omega(L(w))$ to $\Omega(L^{2/3}(w))$
- A non-degeneracy condition is established to prove the lower bound on the convergence rate
Preliminaries
- Bold-faced letters denote vectors
- $[n]$ denotes $\{1, 2, \dots, n\}$
- $\theta(w, v)$ denotes the angle between two nonzero vectors $w$ and $v$
- $\nabla_i$ denotes the gradient with respect to the $i$-th student neuron $w_i$
- $w(t)$ denotes the value of a variable $w$ at the $t$-th iteration
- Expectations taken with respect to the standard Gaussian input are abbreviated
- The length of the projection of $w_i$ onto $v$ is defined as $h_i = \|w_i\|\cos\theta_i$
- Closed-form expressions of $L(w)$ and $\nabla L(w)$ can be obtained
- A proof sketch for Theorem 1 is provided
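The closed forms rest on a standard Gaussian identity. For $\phi = \theta(a, b)$,

$\mathbb{E}[\sigma(a^\top x)\,\sigma(b^\top x)] = \frac{\|a\|\|b\|}{2\pi}\big(\sin\phi + (\pi-\phi)\cos\phi\big)$.

The gradient of this expression with respect to $a$ is $\frac{1}{2\pi}\big((\pi-\phi)\,b + \sin\phi\,\|b\|\,a/\|a\|\big)$. The sketch below expands $L(w)$ with the identity and assembles $\nabla_i L$ from the kernel gradient. It assumes unit second-layer weights and a $\frac12$-scaled square loss. The function names are hypothetical, not the paper's. The test checks the gradient against central finite differences.

```python
import numpy as np


def _angle(a, b):
    """Angle between two nonzero vectors, with cos clipped to [-1, 1] for numerical safety."""
    cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.arccos(np.clip(cos, -1.0, 1.0))


def relu_kernel(a, b):
    """E[relu(a.x) relu(b.x)], x ~ N(0, I): |a||b| (sin t + (pi - t) cos t) / (2 pi)."""
    t = _angle(a, b)
    return np.linalg.norm(a) * np.linalg.norm(b) * (np.sin(t) + (np.pi - t) * np.cos(t)) / (2 * np.pi)


def relu_kernel_grad(a, b):
    """Gradient of relu_kernel(a, b) with respect to its first argument a."""
    t = _angle(a, b)
    return ((np.pi - t) * b + np.sin(t) * np.linalg.norm(b) * a / np.linalg.norm(a)) / (2 * np.pi)


def loss(W, v):
    """L(w) = 1/2 E[(sum_i relu(w_i.x) - relu(v.x))^2], expanded with relu_kernel."""
    n = W.shape[0]
    total = sum(relu_kernel(W[i], W[j]) for i in range(n) for j in range(n))
    total -= 2 * sum(relu_kernel(W[i], v) for i in range(n))
    total += relu_kernel(v, v)
    return 0.5 * total


def loss_grad(W, v):
    """Row i holds nabla_i L = sum_j d/dw_i k(w_i, w_j) - d/dw_i k(w_i, v)."""
    n = W.shape[0]
    G = np.zeros_like(W)
    for i in range(n):
        G[i] = sum(relu_kernel_grad(W[i], W[j]) for j in range(n)) - relu_kernel_grad(W[i], v)
    return G
```

The formulas divide by $\|w_i\|$, so they are only valid for nonzero student neurons. This matches the corner case $w_i = 0$ that the analysis excludes.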
Initialization
- With high probability, the random initialization $w_i(0)$ satisfies the initialization condition (Lemma 3)
- The norms $\|w_i(0)\|$ have upper and lower bounds
- Initially, each $\theta_i$ falls in the interval $[\pi/3, 2\pi/3]$
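A small sketch shows why these properties hold with high probability in high dimension. It computes $\theta_i$ and the projection lengths $h_i = \|w_i\|\cos\theta_i$ from a random small initialization. The initialization $w_i(0) \sim \mathcal{N}(0, \frac{\sigma^2}{d} I_d)$ is a hypothetical choice made for illustration, and it may differ from the paper's exact scaling. Under it, $\cos\theta_i$ concentrates around 0 with standard deviation about $1/\sqrt{d}$, and $\|w_i(0)\|$ concentrates around $\sigma$.

```python
import numpy as np


def angle_to_teacher(W, v):
    """theta_i: angle between student neuron w_i (row i of W) and the teacher v."""
    cos = W @ v / (np.linalg.norm(W, axis=1) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))


def projection_lengths(W, v):
    """h_i = ||w_i|| cos(theta_i) = <w_i, v> / ||v||, the length of w_i's projection onto v."""
    return W @ v / np.linalg.norm(v)


def small_gaussian_init(n, d, sigma, seed=0):
    """Hypothetical init: w_i(0) ~ N(0, sigma^2 / d * I_d), so ||w_i(0)|| is close to sigma for large d."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((n, d)) * sigma / np.sqrt(d)
```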
Phase 1
- Theorem 4 states that there are upper and lower bounds for $\|w_i\|$
- The gradient norm is used to bound the dynamics of $\theta_i$
- $\theta_i$ is small at the end of Phase 1
- The projections of the student neurons onto the teacher neuron are balanced
- Proving the upper bound on $\|w_i\|$ is straightforward
- Proving the lower bound on $\|w_i\|$ is more difficult because of a small perturbation term
- $\theta_i$ remains small in the second interval of Phase 1
- The increases of $h_i$ in the second interval are balanced
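The start of Phase 1 can be observed numerically with the closed-form gradient. As before, this assumes unit output weights and a $\frac12$-scaled loss. The step size, initialization scale, dimension, and number of steps are arbitrary illustrative choices, not the values assigned in the paper. With a tiny initialization, the gradient of each neuron is dominated by the teacher term $-\frac{1}{2\pi}\big((\pi-\theta_i)v + \sin\theta_i\,\|v\|\,w_i/\|w_i\|\big)$. Subtracting this term pushes $w_i$ toward $v$, so every $\theta_i$ shrinks within a few steps while $\|w_i\|$ stays small.

```python
import numpy as np


def kernel_grad(a, b):
    """d/da E[relu(a.x) relu(b.x)] for x ~ N(0, I): ((pi - t) b + sin(t) |b| a/|a|) / (2 pi)."""
    cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    t = np.arccos(np.clip(cos, -1.0, 1.0))
    return ((np.pi - t) * b + np.sin(t) * np.linalg.norm(b) * a / np.linalg.norm(a)) / (2 * np.pi)


def gd_step(W, v, eta):
    """One GD step on L(w) = 1/2 E[(sum_i relu(w_i.x) - relu(v.x))^2]."""
    G = np.array([
        sum(kernel_grad(W[i], W[j]) for j in range(len(W))) - kernel_grad(W[i], v)
        for i in range(len(W))
    ])
    return W - eta * G


def angles(W, v):
    """theta_i for every student neuron (rows of W)."""
    cos = W @ v / (np.linalg.norm(W, axis=1) * np.linalg.norm(v))
    return np.arccos(np.clip(cos, -1.0, 1.0))
```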
Phase 2
- Phase 2 starts at time $T_1 + 1$ and ends at time $T_2$
- Theorem 5 states that the projections $h_i$ remain balanced in Phase 2
- Equation (11) shows that the $h_i$ remain balanced, (12) bounds the dynamics of $H(t)$, (13) gives upper and lower bounds for $h_i$, and (14) shows that $\theta_i$ remains upper bounded by a small term $\epsilon_2$ in Phase 2
- The gradient (3) is analyzed using the properties $\max_i \theta_i \le \epsilon_2$ and $\max_i \|w_i\| = O(\|v\|/n)$
- $H(t)$ decreases exponentially, which yields upper and lower bounds
- The dynamics of $\cos\theta_i$ are calculated and a potential function $V(t)$ is defined
- An upper bound on $V$ yields the final upper bound on $\theta_i$
Phase 3
- Theorem 6 states that the desired $O(1/T^3)$ convergence rate can be achieved.
- Prove a gradient lower bound $\|\nabla L(w)\| \ge \mathrm{poly}(n^{-1}, \|v\|^{-1})\, L^{2/3}(w)$.
- Prove that the loss function is smooth and Lipschitz along the gradient trajectory.
- This is an improved version of Theorem 3 in Zhou et al. [2021].
- Pick a global minimum $w_1^*, w_2^*, \dots, w_n^*$ and lower bound the projection of the gradient onto the direction $w_i - w_i^*$.
- Upper bound $\|w_i - w_i^*\|$.
- The gradient lower bound scales with $L^{2/3}$.
- Show that GD implicitly regularizes $w_i$.
- The improved gradient lower bound is crucial for bounding the movement of the student neurons.
- Combining the gradient lower bound with smoothness gives the convergence rate $L(w(T)) = O(1/T^3)$.
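The gradient-flow idealization ($\dot w = -\nabla L(w)$) shows most simply why the exponent $2/3$ yields a $T^{-3}$ rate and why it also controls how far the neurons move. The sketch below writes $c$ for the $\mathrm{poly}(n^{-1}, \|v\|^{-1})$ factor and assumes the lower bound holds along the whole trajectory. It is a continuous-time sketch only. The discrete GD argument in Phase 3 additionally relies on the smoothness property.

```latex
% Rate: d/dt L = -||grad L||^2 and ||grad L|| >= c L^{2/3}
\frac{d}{dt} L(t) = -\|\nabla L(w(t))\|^2 \le -c^2 L(t)^{4/3}
\;\Longrightarrow\;
\frac{d}{dt} L(t)^{-1/3} = -\tfrac{1}{3} L(t)^{-4/3} \frac{d}{dt} L(t) \ge \frac{c^2}{3}
\;\Longrightarrow\;
L(T) \le \Big(L(0)^{-1/3} + \tfrac{c^2}{3}\, T\Big)^{-3} = O(T^{-3}).

% Movement of the student neurons along the trajectory
\int_0^T \|\nabla L(w(t))\|\,dt
= \int_0^T \frac{\|\nabla L\|^2}{\|\nabla L\|}\,dt
\le \int_0^T \frac{-\frac{d}{dt} L(t)}{c\, L(t)^{2/3}}\,dt
= \frac{3}{c}\Big(L(0)^{1/3} - L(T)^{1/3}\Big).
```

With the weaker bound $\|\nabla L\| \ge c\,L$, the same steps give $L(T) \le (L(0)^{-1} + c^2 T)^{-1} = O(1/T)$. The movement bound becomes $\frac{1}{c}\log\frac{L(0)}{L(T)}$, which grows without limit as $L(T) \to 0$.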
Main theorem
- Theorem 1 is formally stated and proven
- Parameters in Theorems 4, 5, and 6 are assigned values (listed in Appendix E.2)
- Combining the initialization condition (Lemma 3) with the three phases of analysis (Theorems 4, 5, and 6) proves Theorem 9
- The corner case $w_i = 0$ is excluded from the analysis (discussed in Appendix E.3)
Proof overview: convergence rate lower bound
- Gradient flow (gradient descent with infinitesimal step size) is considered
- Other settings (network architecture, initialization scheme, etc.) are kept unchanged
- Over-parameterization causes a significant change of the convergence rate
- Several toy cases are investigated to understand why
Case study
- $w_1$ and $w_2$ are reflection-symmetric with respect to $v$
- Gradient descent preserves the symmetry of $w_1$ and $w_2$
- $\lambda_1$ and $\lambda_2$ converge to 0 at the rate $\lambda_2(t) \sim t^{-1}$
- All student neurons are parallel to the teacher neuron: the convergence rate is linear
- All student neurons are equal: training is equivalent to learning one teacher neuron with one student neuron, with the step size multiplied by a factor of $n$
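The cases listed above can be checked numerically with the closed-form gradient. This assumes unit output weights, a $\frac12$-scaled square loss, and $\|v\| = 1$. The step sizes and starting points are arbitrary.

For parallel students $w_i = a_i v$, the gradient reduces to $\nabla_i L = \frac12\big(\sum_j a_j - 1\big)v$. The sum $s = \sum_j a_j$ therefore follows $s - 1 \leftarrow (1 - n\eta/2)(s - 1)$, which is a linear rate.

For $n$ equal students $w_i = u$, the network computes $\sigma(n u^\top x)$. Each neuron receives the single-neuron gradient evaluated at $n u$, so $n u$ evolves like one neuron trained with step size $n\eta$.

```python
import numpy as np


def kernel_grad(a, b):
    """d/da E[relu(a.x) relu(b.x)] for x ~ N(0, I)."""
    cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    t = np.arccos(np.clip(cos, -1.0, 1.0))
    return ((np.pi - t) * b + np.sin(t) * np.linalg.norm(b) * a / np.linalg.norm(a)) / (2 * np.pi)


def gd_step(W, v, eta):
    """One GD step on L(w) = 1/2 E[(sum_i relu(w_i.x) - relu(v.x))^2]."""
    G = np.array([
        sum(kernel_grad(W[i], W[j]) for j in range(len(W))) - kernel_grad(W[i], v)
        for i in range(len(W))
    ])
    return W - eta * G


def run(W, v, eta, steps):
    """Run `steps` GD iterations from the initial students W (one per row)."""
    for _ in range(steps):
        W = gd_step(W, v, eta)
    return W
```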
Non-degeneracy
- Theorem 9 gives a worst-case optimal convergence rate.
- Theorem 12 gives an average-case lower bound for the convergence rate.
- Theorem 12 requires the random initialization to be non-degenerate.
- Theorem 12 depends on 1/κ-2 max (0).
Proof sketch
- Theorem 12 is proven by considering a potential function Z(t)
- The potential function explains why the convergence rate is different for n = 1 and n ≥ 2
- The potential function also explains why the convergence rates in two counter-examples are linear
- Several technical properties of the gradient flow trajectory are needed
- Lower bounding $\kappa_{\max}$ is the most non-trivial step
- Lemma 15 states that when vectors are too close in direction, gradient flow will automatically separate them
- Remark 16 states that in toy case 3, all $z_i$'s remain parallel and are not separated
- Closed-form expressions for $L$ and $\nabla L$ are presented
- Global convergence is shown in Theorem 4
C Global convergence: Phase 2
- Theorem 5 states that if the initial condition in Lemma 3 holds, then (40), (41), (42) and (43) hold.
- Proof of (40) involves computing a bound using (27) and (23).
- Proof of (41) involves computing a bound using (35) and (37).
- Proof of (42) involves computing a bound using (35) and (46).
- Proof of (43) involves computing a bound using (29), (31) and (34).
- Lemma 17 states that at the start of Phase 3, the conditions (40), (41), (42) and (43) in Theorem 5 are satisfied.
- Proof of Lemma 17 involves computing a bound using (34) and (37).
D.2 Proofs for the gradient lower bound
- Lemma 8 introduces the idea of residual decomposition
- Lemma 18 provides a bound on $\theta_i$
- Lemma 19 provides a bound on $r$
- Theorem 7 states that if certain conditions are met, then the global minimum can be found
D.3 Handling non-smoothness
- Set $\epsilon_2 = O(n^{-14})$
- Set 84.5 = $O(n^{-14})$
- Set $n\eta = O(n^{-14})$
- Lemma 21 ensures smoothness of L when student neurons are regularized
- Lemma 22 ensures student neurons won’t move too far in phase 3
- Theorem 6 sets parameters according to Theorem 4 and Theorem 5
- Theorem 7 bounds decrease of loss at time t
- Theorem 8 ensures that these properties hold with probability at least $1 - \delta$
E.3 Non-degeneracy of student neurons
- Technical issue: a student neuron $w_i$ can degenerate to $w_i = 0$, where the loss function $L(w)$ is not differentiable
- The proof shows that student neurons are always non-degenerate
- The norm $\|w_i\|$ of each student neuron is lower-bounded in all three phases of the analysis
- A corollary states the non-degeneracy of student neurons
- The assumption on the initialization condition is necessary; otherwise, counter-examples exist
- Technical preparations for proving Theorem 12
- Gradient flow version of Theorem 6 and Corollary 23
- Lemma 27: given a non-degenerate initialization, at least one student neuron $z_i(t) = 0$
- Lemma 28: given a non-degenerate initialization, at least one of two conditions must hold
- Lemma 29: given a non-degenerate initialization, $Z(t) \ge \Omega\big(\kappa_{\max}(0) \max_{i\in[n]} \|z_i(t)\|\big)$
- Corollary 30: under these conditions, $Z(t) > 0$
- Lemma 31: under these conditions, $\frac{\partial}{\partial t} Z(t) \ge -O\big(n^2 \|v\|\, \theta_{\max}^2(t)\big)$
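Non-differentiability at the corner case $w_i = 0$ from E.3 is easiest to see for $n = 1$ through one-sided directional derivatives of $L$ at $w = 0$. Take a unit vector $u$ and let $\phi = \theta(u, v)$. The closed form then gives

$L(s u) = \frac12\Big(\frac{s^2}{2} - \frac{s\|v\|}{\pi}\big(\sin\phi + (\pi-\phi)\cos\phi\big) + \frac{\|v\|^2}{2}\Big)$,

so the one-sided slope at $s = 0^+$ is $-\frac{\|v\|}{2\pi}\big(\sin\phi + (\pi-\phi)\cos\phi\big)$. This slope equals $-\|v\|/2$ along $v$ and $0$ along $-v$. A differentiable function would give slopes of opposite sign. The sketch below checks this with finite differences. It uses the same assumed $\frac12$-scaled loss with unit output weights, and its function names are hypothetical.

```python
import numpy as np


def relu_kernel(a, b):
    """E[relu(a.x) relu(b.x)] for x ~ N(0, I); equals 0 when either vector is 0."""
    na, nb = np.linalg.norm(a), np.linalg.norm(b)
    if na == 0.0 or nb == 0.0:
        return 0.0
    t = np.arccos(np.clip(a @ b / (na * nb), -1.0, 1.0))
    return na * nb * (np.sin(t) + (np.pi - t) * np.cos(t)) / (2 * np.pi)


def loss_single(w, v):
    """L(w) for a single student neuron: 1/2 E[(relu(w.x) - relu(v.x))^2]."""
    return 0.5 * (relu_kernel(w, w) - 2 * relu_kernel(w, v) + relu_kernel(v, v))


def one_sided_derivative(w, u, v, s=1e-7):
    """(L(w + s u) - L(w)) / s: approximates the directional derivative of L at w along u, s > 0."""
    return (loss_single(w + s * u, v) - loss_single(w, v)) / s
```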