Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Stochastic gradient descent is important for deep learning applications.
  • Loss functions for large networks with large amounts of data are non-convex.
  • Loss for fixed mini-batches can be modeled by a quadratic function.
  • A simple model and geometric interpretation can explain the relationship between gradients of mini-batches and full batches.
  • Averaging two points a few steps apart can improve accuracy.

Paper Content

Introduction

  • Stochastic Gradient Descent (SGD) played a key role in the success of deep learning
  • Learning rate is a crucial hyper-parameter that affects training trajectory
  • Research continues to explore novel learning rate schedules
  • Impact of batch noise on optimization process is not well understood
  • Loss on the training batch differs from the loss on a held-out batch
  • Popular averaging techniques have equivalent learning rate schedules
  • Model weights quickly converge to a quasi-stationary distribution
  • Averaging moves trajectory to inside of an ellipsoid
  • Distance in parameter space between two trajectories decreases over time
  • Averaging intermediate SGD iterates in the weight space was first analyzed in Polyak & Juditsky (1992).
  • Mandt et al. (2017) established a connection between SGD and Bayesian inference.
  • Stochastic Weight Averaging (SWA) by Izmailov et al. (2018) uses a single trajectory with cyclic learning rates and averaging every c steps.
  • He et al. (2019) showed that averaging introduces a bias towards the flatter part of the basin, which results in better generalization.
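
To make the averaging schemes mentioned above concrete, here is a minimal NumPy sketch of averaging iterates in weight space. It is an illustration only: the function names `running_average` and `average_every_c` and the toy `trajectory` are hypothetical, and the cyclic learning rate used by SWA is omitted, so this is not the implementation from any of the cited papers.

```python
import numpy as np

def running_average(iterates):
    """Running mean of all iterates (Polyak-style), updated online."""
    avg = np.zeros_like(iterates[0], dtype=float)
    for k, theta in enumerate(iterates, start=1):
        avg += (theta - avg) / k  # incremental mean update
    return avg

def average_every_c(iterates, c):
    """SWA-like average that samples only every c-th iterate."""
    return running_average(iterates[c - 1::c])

# Hypothetical trajectory: 12 iterates of a 3-dimensional weight vector.
trajectory = np.arange(36, dtype=float).reshape(12, 3)
polyak = running_average(trajectory)
swa_like = average_every_c(trajectory, c=4)
```

The online update avoids storing the whole trajectory, which matters when each iterate is a full set of network weights.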

Empirical observations

  • Goal is to explore properties of SGD that lead to theoretical model in Section 4
  • Experiments use the ResNet-34 architecture and ImageNet
  • Standard SGD with momentum 0.9 used
  • Experiments involve starting with partially trained architectures
  • Question: what does the loss landscape look like when the gradient of a single batch is applied with a varying learning rate?
  • Loss for both curves is smooth and can be approximated by low-degree polynomial
  • Loss on training batch reaches value below average loss of fully trained model in single step
  • Loss on held-out batch behaves differently, only small learning rate leads to improvement
  • Analytical model should have this property
  • A multi-step side trip checks whether the trajectory reaches the same minimum basin regardless of starting point
  • Side-trip using same fixed batch leads to same basin
  • Held-out batch loss increases but stays in same basin as main trajectory
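
The single-step experiment above can be mimicked on a toy problem. The sketch below assumes a least-squares loss on random data rather than ResNet-34 on ImageNet, so it only illustrates the procedure: take the gradient of one fixed batch, move along it with a range of learning rates, and fit a low-degree polynomial to the loss on that batch and on a held-out batch. All names (`make_batch`, `loss`, `grad`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, batch_size = 10, 32

def make_batch():
    """Draw a hypothetical regression batch (inputs X, targets y)."""
    X = rng.standard_normal((batch_size, dim))
    y = rng.standard_normal(batch_size)
    return X, y

def loss(theta, batch):
    """Mean squared error (halved) of a linear model on one batch."""
    X, y = batch
    r = X @ theta - y
    return 0.5 * np.mean(r ** 2)

def grad(theta, batch):
    """Gradient of `loss` with respect to theta."""
    X, y = batch
    return X.T @ (X @ theta - y) / len(y)

theta = rng.standard_normal(dim)
train_batch, held_out_batch = make_batch(), make_batch()
g = grad(theta, train_batch)  # gradient of the fixed training batch

# Loss after one step of size lr along -g, on both batches.
lrs = np.linspace(0.0, 1.0, 21)
train_curve = np.array([loss(theta - lr * g, train_batch) for lr in lrs])
held_out_curve = np.array([loss(theta - lr * g, held_out_batch) for lr in lrs])

# Degree-2 polynomial fits of both curves as a function of the learning rate.
train_fit = np.polyval(np.polyfit(lrs, train_curve, 2), lrs)
held_out_fit = np.polyval(np.polyfit(lrs, held_out_curve, 2), lrs)
```

For this toy loss the quadratic fit is exact by construction; for a real network the summary only states that a low-degree polynomial approximates the curves.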

Analytical model

  • Proposed a model to describe the behavior of a high dimensional weight vector
  • Used two points along an ImageNet training trajectory
  • Performed 9 steps of gradient descent on a fixed batch
  • Loss on the training batch reaches nearly 0 in 3 steps
  • Loss on held out batch is similar to training batch at step 0
  • Model is similar to one proposed in Schaul et al. (2013)
  • Stochastic gradient and full gradient can be written
  • Estimate how close to 0 can be reached for a fixed learning rate
  • Trajectory will traverse an ellipsoid around the minimum
  • Change in norm becomes zero when $A_x \theta_t = 0$
  • Expected change in norm becomes zero when $\mathbb{E}_{x \sim D}\, A_x^T A_x = I$
  • Distance to global minimum stabilizes at a value proportional to $\sqrt{\lambda}$
  • Weight averaging has an equivalent learning rate along the trajectory
  • The loss can be matched without matching the full trajectory
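
The square-root dependence of the stationary distance on the learning rate can be checked numerically on a simplified stand-in for the model. The sketch below assumes a noisy quadratic whose stochastic gradient is the full gradient `theta` plus unit Gaussian noise; this is an assumption made for illustration, not the paper's exact model, and `stationary_distance` is a hypothetical helper.

```python
import numpy as np

def stationary_distance(lr, dim=50, steps=20000, burn_in=2000, seed=0):
    """RMS distance to the minimum (at 0) after SGD has reached steady state."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(dim)
    sq_dists = []
    for t in range(steps):
        # Assumed stochastic gradient: full gradient theta plus batch noise.
        noisy_grad = theta + rng.standard_normal(dim)
        theta = theta - lr * noisy_grad
        if t >= burn_in:
            sq_dists.append(theta @ theta)
    return np.sqrt(np.mean(sq_dists))

d_small = stationary_distance(lr=0.01)
d_large = stationary_distance(lr=0.04)
ratio = d_large / d_small  # close to sqrt(0.04 / 0.01) = 2 for small lr
```

In this toy setting, quadrupling the learning rate roughly doubles the distance at which the iterates settle.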

Experiments

  • Experimented with ImageNet, CIFAR-10 and CIFAR-100
  • Used ResNet-34 and SGD with momentum 0.9
  • Showed learning rate schedules match stochastic averaging
  • Included results in Figures 5, 6 and 10 of supplementary materials
  • Demonstrated gradient alignment and divergence of trajectories

Synthetic model experiments

  • Weight averaging can be used to simulate a properly chosen learning rate schedule.
  • Weight averaging can improve model performance while still allowing for a large learning rate.
  • Weight averaging is sensitive to the time scale of the underlying weight evolution.
  • Weight averaging can reduce the stationary loss without sacrificing the speed of convergence.
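
A small simulation can illustrate the last point. The sketch below assumes the same kind of noisy quadratic toy loss (minimum at 0, unit Gaussian gradient noise) and compares the mean loss of individual iterates with the loss of their average. The iterates are sampled `gap` steps apart so that they are nearly uncorrelated; all parameter values are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, lr = 100, 0.1
n_avg, gap, burn_in = 10, 50, 500

def loss(theta):
    """Toy quadratic loss with its minimum at the origin."""
    return 0.5 * theta @ theta

theta = 5.0 * np.ones(dim)  # start far from the minimum
samples = []
for t in range(burn_in + n_avg * gap):
    # Assumed stochastic gradient: full gradient theta plus batch noise.
    theta = theta - lr * (theta + rng.standard_normal(dim))
    if t >= burn_in and (t - burn_in) % gap == gap - 1:
        samples.append(theta.copy())  # keep every gap-th iterate after burn-in

samples = np.array(samples)
raw_loss = np.mean([loss(s) for s in samples])  # typical loss of single iterates
averaged_loss = loss(samples.mean(axis=0))      # loss of the averaged weights
```

Because the averaging is applied after the fact, the underlying trajectory still runs at the large learning rate and approaches the basin just as fast. In this toy setting the averaged weights land much closer to the minimum than any single iterate.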

Open question and conclusions

  • Explored properties of learning rate in stochastic gradient descent
  • Demonstrated connection between iterate averaging and learning rate schedules
  • Observed in theoretical models and large-scale training on multiple datasets
  • Hope this work paves the way to a further understanding of the role of the learning rate in training

A. Weight averaging

  • Evolution of θ governed by a simplified version of the finite-batch version of equation 5
  • Coordinate transformation that whitens the batch noise
  • Moving average of the training trajectory
  • Derive a simple equation governing it and characterize the resulting converged steady-state distribution
  • Stochastic gradient descent trajectories depend on both A and C matrices
  • Change of coordinates $\theta_t = Q u_t$ removes dependence on one of the matrices
  • Solution of equation 15 can be obtained recursively
  • Expression for $X_\infty$ connects the covariance of the stationary distribution in the weight space to the autocorrelation of the averaging kernel $\mu_t$
  • Expressions for $G_1$ and $G_2$ fully define the covariance $F$ for arbitrary values of $\lambda$, $\Delta$ and an arbitrary $\Omega$
  • Effective learning rate is decreased by approximately a factor of n
  • Equivalent learning rate schedules for different averaging methods
  • Proof of Lemma 2