Abstract
- Stochastic gradient descent (SGD) plays a fundamental role in nearly all deep learning applications.
- Loss functions for large networks with large amounts of data are non-convex.
- Loss for fixed mini-batches can be modeled by a quadratic function.
- A simple model with a geometric interpretation explains the relationship between mini-batch gradients and the full-batch gradient.
- Averaging just two points many steps apart can improve accuracy over the baseline.
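The two-point averaging claim can be made concrete with a small toy simulation. The following Python/NumPy code is a minimal sketch under our own assumptions, not the paper's setup. Each mini-batch loss is $\tfrac12\|A_x\theta - b_x\|^2$, with $A_x$ drawn i.i.d. from $N(0, 1/m)$ so that $\mathbb{E}[A_x^\top A_x] = I$, and $b_x$ is pure label noise. This places the full-batch minimum at $\theta = 0$ and makes the excess full-batch loss equal to $\tfrac12\|\theta\|^2$. The sizes `d` and `m`, the learning rate and the gap values are hypothetical.

```python
import numpy as np

def sgd_trajectory(lr, steps, d=20, m=10, noise=1.0, seed=0):
    """Run SGD on the toy loss 0.5*||A_x theta - b_x||^2 and return all iterates.

    Assumed toy model (not the paper's): A_x has i.i.d. N(0, 1/m) entries, so
    E[A_x^T A_x] = I, and b_x is pure label noise, so the full-batch minimum is 0.
    """
    rng = np.random.default_rng(seed)
    theta = np.zeros(d)
    traj = np.empty((steps, d))
    for t in range(steps):
        A = rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, d))  # fresh mini-batch
        b = noise * rng.normal(size=m)
        theta = theta - lr * A.T @ (A @ theta - b)  # stochastic gradient step
        traj[t] = theta
    return traj

def two_point_excess_loss(traj, gap, burn_in=1_000):
    """Mean excess full-batch loss 0.5*||.||^2 of (theta_t + theta_{t-gap}) / 2."""
    late = traj[burn_in:]
    avg = 0.5 * (late[gap:] + late[:-gap])
    return 0.5 * float(np.mean(np.sum(avg**2, axis=1)))

traj = sgd_trajectory(lr=0.05, steps=20_000)
single = 0.5 * float(np.mean(np.sum(traj[1_000:] ** 2, axis=1)))  # no averaging
near = two_point_excess_loss(traj, gap=1)    # two neighboring iterates
far = two_point_excess_loss(traj, gap=500)   # two iterates many steps apart
print(f"single: {single:.4f}  gap=1: {near:.4f}  gap=500: {far:.4f}")
```

In this toy, $\mathbb{E}[\theta_t^\top\theta_{t-k}] = (1-\lambda)^k\,\mathbb{E}\|\theta\|^2$. The expected ratio of two-point to single-iterate excess loss is therefore $(1 + (1-\lambda)^k)/2$. It is close to 1 for neighboring iterates and close to $1/2$ when the gap $k$ is much larger than $1/\lambda$.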
Paper Content
Introduction
- Stochastic Gradient Descent (SGD) played a key role in the success of deep learning
- Learning rate is a crucial hyper-parameter that affects training trajectory
- Research continues to explore novel learning rate schedules
- Impact of batch noise on optimization process is not well understood
- Loss on the training batch differs from loss on a held-out batch
- Popular averaging techniques have equivalent learning rate schedules
- Model weights quickly converge to a quasi-stationary distribution
- Averaging moves the trajectory into the interior of the ellipsoid around the minimum
- Distance in parameter space between two trajectories decreases over time
Related work
- Averaging intermediate SGD iterates in the weight space was first analyzed in Polyak & Juditsky (1992).
- Mandt et al. (2017) established a connection between SGD and Bayesian inference.
- Stochastic Weight Averaging (SWA) by Izmailov et al. (2018) uses a single trajectory with cyclic learning rates and averaging every c steps.
- He et al. (2019) showed that averaging introduces bias towards flatter part of the basin, which results in better generalization.
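Each of the averaging schemes above can be written as a running update. The code below is a generic NumPy sketch, not code from any of the cited papers, and it covers three schemes:

- Polyak averaging, as a uniform mean of all iterates.
- An exponential moving average (EMA) with a hypothetical decay `beta`.
- SWA, as a running mean of every c-th iterate.

The cyclic schedule helper follows our reading of the linear within-cycle decay described by Izmailov et al. (2018). Treat its exact form as an assumption.

```python
import numpy as np

def polyak_average(iterates):
    """Uniform running mean of all iterates (Polyak & Juditsky style)."""
    avg = np.zeros_like(iterates[0], dtype=float)
    for t, theta in enumerate(iterates, start=1):
        avg += (theta - avg) / t  # incremental mean update
    return avg

def ema_average(iterates, beta=0.99):
    """Exponential moving average; the first iterate initializes the average."""
    avg = np.array(iterates[0], dtype=float)
    for theta in iterates[1:]:
        avg = beta * avg + (1.0 - beta) * theta
    return avg

def swa_average(iterates, c):
    """Running mean of the iterates at steps c, 2c, ... (None if there are none)."""
    avg, n_models = None, 0
    for t, theta in enumerate(iterates, start=1):
        if t % c == 0:
            n_models += 1
            avg = np.array(theta, dtype=float) if avg is None else avg + (theta - avg) / n_models
    return avg

def swa_cyclic_lr(i, c, lr_max, lr_min):
    """Assumed SWA schedule: linear decay from lr_max to lr_min within each cycle (i >= 1)."""
    t = ((i - 1) % c + 1) / c
    return (1.0 - t) * lr_max + t * lr_min

iterates = [np.array([float(k)]) for k in range(1, 11)]
print(polyak_average(iterates), swa_average(iterates, c=5), ema_average(iterates, beta=0.9))
```

The incremental-mean form avoids storing the whole trajectory. Each scheme keeps only one extra copy of the weights.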
Empirical observations
- Goal is to explore properties of SGD that lead to theoretical model in Section 4
- Experiments use the ResNet-34 architecture and ImageNet
- Standard SGD with momentum 0.9 used
- Experiments involve starting with partially trained architectures
- Question: what does the loss landscape look like when a single batch is applied with varying learning rates?
- Loss on both the training batch and the held-out batch is smooth and can be approximated by a low-degree polynomial
- A single step with a sufficiently large learning rate brings the training-batch loss below the average loss of the fully trained model
- Loss on the held-out batch behaves differently: only small learning rates lead to improvement
- The analytical model should reproduce this contrast between training-batch and held-out loss
- Multi-step side-trips check whether the same minimum basin is reached regardless of the starting point
- Side-trip using same fixed batch leads to same basin
- Held-out batch loss increases but stays in same basin as main trajectory
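The following code is a toy analogue of the single-batch learning-rate sweep. It starts from a hypothetical partially trained point `theta0`, takes one gradient step of size $\eta$ on a fixed training batch, and evaluates the loss on that batch and on a held-out batch. It uses our own assumed mini-batch least-squares model, not ResNet-34 on ImageNet.

In this model both curves are exact quadratics in $\eta$. The $\eta$ that minimizes the training-batch loss has the closed form $\|g\|^2/\|A g\|^2$. Whether the held-out loss improves at that $\eta$ depends on the random draw, so the script only prints it.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, noise = 20, 10, 1.0  # hypothetical: parameters, batch size, label-noise scale

def sample_batch():
    """Assumed toy batch: A with i.i.d. N(0, 1/m) entries and pure-noise targets b."""
    A = rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, d))
    b = noise * rng.normal(size=m)
    return A, b

def batch_loss(theta, A, b):
    r = A @ theta - b
    return 0.5 * float(r @ r)

theta0 = rng.normal(size=d)   # hypothetical stand-in for a partially trained model
A_tr, b_tr = sample_batch()   # fixed training batch
A_ho, b_ho = sample_batch()   # held-out batch
g = A_tr.T @ (A_tr @ theta0 - b_tr)  # training-batch gradient at theta0

etas = np.linspace(0.0, 1.5, 31)
train = np.array([batch_loss(theta0 - eta * g, A_tr, b_tr) for eta in etas])
held = np.array([batch_loss(theta0 - eta * g, A_ho, b_ho) for eta in etas])

# One step along a fixed direction makes both losses exact quadratics in eta here.
train_fit = np.polyfit(etas, train, deg=2)
held_fit = np.polyfit(etas, held, deg=2)

# Closed-form step size minimizing the training-batch loss: ||g||^2 / ||A_tr g||^2.
Ag = A_tr @ g
eta_star = float(g @ g) / float(Ag @ Ag)
best_train = batch_loss(theta0 - eta_star * g, A_tr, b_tr)
held_at_star = batch_loss(theta0 - eta_star * g, A_ho, b_ho)
print(f"eta*={eta_star:.3f}  train {train[0]:.2f} -> {best_train:.2f}")
print(f"held-out {held[0]:.2f} -> {held_at_star:.2f}")
```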
Analytical model
- Proposed a model to describe the behavior of a high dimensional weight vector
- Used two points along an ImageNet training trajectory
- Performed 9 steps of gradient descent on a fixed batch
- Loss on the training batch reaches nearly 0 in 3 steps
- Loss on the held-out batch is similar to that on the training batch at step 0
- Model is similar to one proposed in Schaul et al. (2013)
- Both the stochastic (mini-batch) gradient and the full gradient can be written explicitly in this model
- Estimate how close to 0 one can get with a fixed learning rate
- Trajectory will traverse an ellipsoid around the minimum
- Change in norm becomes zero when $A_x \theta_t = 0$
- Expected change in norm becomes zero when $\mathbb{E}_{x\sim\mathcal{D}}\, A_x^\top A_x = I$
- Distance to the global minimum stabilizes at a value proportional to $\sqrt{\lambda}$, where $\lambda$ is the learning rate
- Weight averaging has an equivalent learning rate along the trajectory
- Loss match can be achieved instead of matching the full trajectory
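The $\sqrt{\lambda}$ scaling can be checked in a toy model where everything is computable. The model is our assumption, not the paper's exact one: batch loss $\tfrac12\|A_x\theta - b_x\|^2$, with $A_x$ entries i.i.d. $N(0, 1/m)$ and label noise of scale $\sigma$.

For Gaussian $A_x$ one has $\mathbb{E}[(A_x^\top A_x)^2] = cI$ with $c = 1 + (d+1)/m$. The fixed point of $\mathbb{E}\|\theta_t\|^2$ is then $\lambda d\sigma^2/(2 - c\lambda) \approx \lambda d\sigma^2/2$ for small $\lambda$, so the distance to the minimum grows like $\sqrt{\lambda}$. The simulation compares the time-averaged squared distance with this prediction. The model is isotropic, so here the ellipsoid is a sphere.

```python
import numpy as np

def stationary_sq_dist(lr, d=20, m=10, noise=1.0, steps=20_000, burn_in=2_000, seed=0):
    """Time-averaged ||theta_t||^2 after burn-in for SGD on the assumed toy model."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(d)
    total = 0.0
    for t in range(steps):
        A = rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, d))
        b = noise * rng.normal(size=m)
        theta = theta - lr * A.T @ (A @ theta - b)
        if t >= burn_in:
            total += float(theta @ theta)
    return total / (steps - burn_in)

def predicted_sq_dist(lr, d=20, m=10, noise=1.0):
    """Fixed point of E||theta||^2: lr*d*noise^2 / (2 - c*lr), with c = 1 + (d+1)/m."""
    c = 1.0 + (d + 1) / m  # E[(A^T A)^2] = c*I for Gaussian A with N(0, 1/m) entries
    return lr * d * noise**2 / (2.0 - c * lr)

results = {lr: (stationary_sq_dist(lr), predicted_sq_dist(lr)) for lr in (0.01, 0.04)}
for lr, (measured, predicted) in results.items():
    print(f"lr={lr}: mean ||theta||^2 measured {measured:.4f}, predicted {predicted:.4f}")
```

Quadrupling the learning rate roughly doubles the stationary distance, with a small correction from the $c\lambda$ term.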
Experiments
- Experimented with ImageNet, CIFAR-10 and CIFAR-100
- Used ResNet-34 and SGD with momentum 0.9
- Showed learning rate schedules match stochastic averaging
- Included results in Figures 5, 6 and 10 of supplementary materials
- Demonstrated gradient alignment and divergence of trajectories
Synthetic model experiments
- Weight averaging can be used to simulate a properly chosen learning rate schedule.
- Weight averaging can improve model performance while still allowing for a large learning rate.
- Weight averaging is sensitive to the time scale of the underlying weight evolution.
- Weight averaging can reduce the stationary loss without sacrificing the speed of convergence.
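The following code is a toy check of the stationary-loss part of these claims. It uses the same kind of assumed mini-batch least-squares model and compares the late-training excess loss of three runs:

- Raw SGD at a large learning rate.
- An EMA of that same large-learning-rate trajectory.
- Raw SGD at a small learning rate.

The values 0.1, 0.01 and $\beta = 0.99$ are hypothetical. For this model, the EMA of the $\lambda = 0.1$ trajectory should land near the stationary loss of $\lambda = 0.01$, echoing the averaging/learning-rate equivalence. The sketch does not test convergence speed.

```python
import numpy as np

def stationary_excess_loss(lr, beta=None, d=20, m=10, noise=1.0,
                           steps=30_000, burn_in=3_000, seed=0):
    """Time-averaged excess full-batch loss 0.5*||w||^2 after burn-in.

    w is the raw SGD iterate if beta is None, else an EMA of the iterates with
    decay beta. Assumed toy batch loss: 0.5*||A_x theta - b_x||^2 (minimum at 0).
    """
    rng = np.random.default_rng(seed)
    theta = np.zeros(d)
    ema = np.zeros(d)
    total = 0.0
    for t in range(steps):
        A = rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, d))
        b = noise * rng.normal(size=m)
        theta = theta - lr * A.T @ (A @ theta - b)
        if beta is not None:
            ema = beta * ema + (1.0 - beta) * theta  # averaging never feeds back into theta
        w = theta if beta is None else ema
        if t >= burn_in:
            total += 0.5 * float(w @ w)
    return total / (steps - burn_in)

big = stationary_excess_loss(0.1)                  # large LR, no averaging
big_ema = stationary_excess_loss(0.1, beta=0.99)   # same large-LR run, EMA-averaged
small = stationary_excess_loss(0.01)               # small LR, no averaging
print(f"lr=0.1: {big:.4f}  lr=0.1 + EMA: {big_ema:.4f}  lr=0.01: {small:.4f}")
```

The underlying trajectory keeps the large learning rate throughout. Only the averaged copy of the weights is evaluated.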
Open question and conclusions
- Explored properties of learning rate in stochastic gradient descent
- Demonstrated connection between iterate averaging and learning rate schedules
- Observed in theoretical models and large-scale training on multiple datasets
- Hope this work paves the way to a better understanding of the role of the learning rate in training
Appendix A: Weight averaging
- Evolution of θ is governed by a simplified finite-batch version of equation 5
- Coordinate transformation that whitens the batch noise
- Moving average of the training trajectory
- Derive a simple equation governing it and characterize the resulting converged steady-state distribution
- Stochastic gradient descent trajectories depend on both A and C matrices
- Change of coordinates $\theta_t = Q u_t$ removes the dependence on one of the matrices
- Solution of equation 15 can be obtained recursively
- Expression for $X_\infty$ connects the covariance of the stationary distribution in the weight space to the autocorrelation of the averaging kernel $\mu_t$
- Expressions for $G_1$ and $G_2$ fully define the covariance $F$ for arbitrary values of $\lambda$, $\Delta$ and an arbitrary $\Omega$
- Effective learning rate is decreased by approximately a factor of n
- Equivalent learning rate schedules for different averaging methods
- Proof of Lemma 2
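The following code is a scalar analogue of how an averaging kernel and the iterate autocorrelation combine. It rests on a simpler assumption than the paper's: 1D SGD on $\tfrac12\theta^2$ with additive gradient noise $\sigma\xi_t$, so that $\theta_{t+1} = (1-\lambda)\theta_t - \lambda\sigma\xi_t$.

In this setting $\operatorname{Var}\theta = \lambda\sigma^2/(2-\lambda)$ and $\operatorname{Cov}(\theta_t, \theta_{t-k}) = \operatorname{Var}\theta\,(1-\lambda)^{|k|}$. The variance of any averaged iterate $\sum_k \mu_k\theta_{T-k}$ is therefore a quadratic form in the kernel $\mu$. Inverting the raw-variance formula gives a constant equivalent learning rate for that kernel at stationarity.

This does not reproduce the paper's $G_1$, $G_2$ expressions or its factor-of-$n$ statement. `averaged_variance` and `equivalent_lr` are hypothetical helper names.

```python
import numpy as np

def averaged_variance(lr, sigma, mu):
    """Stationary variance of sum_k mu[k] * theta_{T-k} for the assumed 1D additive-noise SGD."""
    mu = np.asarray(mu, dtype=float)
    v = lr * sigma**2 / (2.0 - lr)  # stationary variance of a single iterate
    k = np.arange(len(mu))
    autocorr = (1.0 - lr) ** np.abs(k[:, None] - k[None, :])  # Corr(theta_{T-i}, theta_{T-j})
    return v * float(mu @ autocorr @ mu)

def equivalent_lr(lr, sigma, mu):
    """Constant learning rate whose raw stationary variance matches the averaged iterate."""
    var = averaged_variance(lr, sigma, mu)
    # Invert var = lr_eq * sigma^2 / (2 - lr_eq)  =>  lr_eq = 2 * var / (sigma^2 + var)
    return 2.0 * var / (sigma**2 + var)

beta = 0.9
ema_kernel = (1 - beta) * beta ** np.arange(200)  # EMA weights, truncated at 200 terms
uniform_kernel = np.full(50, 1 / 50)              # uniform average of the last 50 iterates
for name, mu in [("none", [1.0]), ("EMA", ema_kernel), ("uniform", uniform_kernel)]:
    print(name, equivalent_lr(0.1, 1.0, mu))
```

With the trivial kernel $\mu = [1]$ the equivalent learning rate is $\lambda$ itself, which serves as a sanity check. Any nontrivial averaging kernel lowers it.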