Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Neural networks need plasticity to be adaptable and robust.
  • The mechanisms behind plasticity loss are not well understood.
  • This paper investigates the mechanisms that drive plasticity loss.
  • Plasticity loss is closely connected to changes in the curvature of the loss landscape.
  • Plasticity loss often occurs in the absence of saturated units or divergent gradient norms, so these cannot be its sole cause.
  • Layer normalization can help preserve plasticity.

Paper Content

Introduction

  • Neural networks trained on a sequence of learning objectives gradually lose the ability to fit new tasks
  • When inputs and prediction targets change over time, networks must learn to ‘overwrite’ prior predictions
  • Deep reinforcement learning agents are trained on nonstationary targets, which makes them prone to plasticity loss
  • Understanding and mitigating plasticity loss is important for developing deep RL agents
  • Existing mitigation methods each target a hypothesized mechanism of plasticity loss
  • Loss landscape curvature is a crucial factor in determining plasticity

Background

  • Training a network on one task and then a second can reduce performance on the first task (catastrophic forgetting).
  • Training a network on a series of tasks can result in worse performance on later tasks than a randomly initialized network trained on those tasks would achieve.

Preliminaries

  • Temporal difference (TD) learning is a family of reinforcement learning methods that train value estimates to match bootstrapped targets built from their own predictions
  • TD learning uses networks trained with sampled transitions from an agent’s interaction with the environment
  • Loss landscape analysis studies the structure of the loss landscape traversed by an optimization algorithm
  • The Hessian of the loss with respect to the network parameters is used to measure the sharpness of the loss landscape
  • The gradient covariance is used to measure interference between inputs
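
To make the bootstrapped target concrete, here is a minimal sketch of a one-step, DQN-style TD target and squared TD loss in plain Python. The function names and toy numbers are illustrative assumptions, not the paper's code. The point is that the regression target depends on the target network's own Q-values, so every refresh of the target network changes the learning problem the online network is fitting.

```python
def td_targets(rewards, next_q_values, dones, gamma=0.99):
    """One-step TD targets r + gamma * max_a' Q_target(s', a') for a batch.

    next_q_values[i] holds the target network's Q-values for next state i;
    terminal transitions (done=True) do not bootstrap.
    """
    return [
        r + (0.0 if done else gamma * max(q))
        for r, q, done in zip(rewards, next_q_values, dones)
    ]


def td_loss(q_taken, targets):
    """Mean squared TD error between predicted Q(s, a) and fixed targets."""
    return sum((q - y) ** 2 for q, y in zip(q_taken, targets)) / len(targets)


targets = td_targets([1.0, 0.0], [[0.5, 2.0], [3.0, 1.0]], [False, True], gamma=0.5)
print(targets)                       # [2.0, 0.0]
print(td_loss([1.5, 0.5], targets))  # 0.25
```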

Defining plasticity

  • Plasticity has been studied in neuroscience for decades and is a recent topic of interest in deep learning.
  • Classical notions of complexity evaluate whether a hypothesis class contains functions that capture arbitrary patterns, but don’t consider the ability of a search algorithm to find these functions.
  • Plasticity is a problem-dependent property that captures the interaction between the network state, optimization process, and training data.
  • Capacity is a fixed property of a network architecture.
  • Plasticity is measured by the ability of a network to update its predictions in response to a wide array of possible learning signals.
  • Plasticity is defined as a baseline value minus the expected final loss that an optimization process reaches when started from the network's current parameters on a learning problem sampled from a fixed distribution.
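
In symbols, the definition can be written as below. The notation (optimizer 𝒪, distribution of learning problems ℒ, baseline b, parameters θ) is shorthand for the quantities named in the bullets and may not match the paper's exact symbols.

```latex
P(\theta) = b - \mathbb{E}_{\ell \sim \mathcal{L}}\!\left[\ell\!\left(\theta^{*}_{\ell}\right)\right],
\qquad \theta^{*}_{\ell} = \mathcal{O}(\theta, \ell)

\Delta P(t, t+k) = P(\theta_{t}) - P(\theta_{t+k})
```

A positive ΔP means the parameters at step t+k fit freshly sampled learning problems worse than the parameters at step t did, which is what "plasticity loss" refers to.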

Two simple studies on plasticity

  • Plasticity loss can occur even in simple learning problems.
  • Optimizer design can interact with nonstationarity to cause instabilities.

Optimizer instability and non-stationarity

  • Deep learning methods have been widely adopted due to the robustness of existing optimizers.
  • The Adam optimizer can yield reasonable initial results, but can diverge catastrophically when its assumptions of stationarity no longer hold.
  • An example of this is shown in Figure 1, where a two-hidden-layer fully-connected neural network is trained to memorize random labels of MNIST images.
  • Increasing Adam's ε and setting a more aggressive decay rate for the second-moment estimate can avoid catastrophic instability.
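
One way to see how Adam can misbehave under nonstationarity is to compute the size of its first update after the gradient jumps from near zero to a larger value, as can happen when targets suddenly change. The sketch below is an illustrative assumption about the mechanism, not the paper's experiment, and `adam_step_after_shift` is a hypothetical helper. Because the second-moment estimate adapts slowly when β2 is close to 1, the first post-shift step is roughly (1 − β1)/√(1 − β2) ≈ 3.16 times the learning rate rather than about 1 times; a smaller β2 or a larger ε shrinks that spike.

```python
import math


def adam_step_after_shift(grad_before, grad_after, steps_before, lr=1e-3,
                          beta1=0.9, beta2=0.999, eps=1e-8):
    """Magnitude of a scalar Adam update on the first step after the gradient
    switches from grad_before (held for steps_before steps) to grad_after."""
    m = v = 0.0
    t = 0
    # Let the moment estimates settle on the old gradient.
    for _ in range(steps_before):
        t += 1
        m = beta1 * m + (1 - beta1) * grad_before
        v = beta2 * v + (1 - beta2) * grad_before ** 2
    # First step after the shift.
    t += 1
    m = beta1 * m + (1 - beta1) * grad_after
    v = beta2 * v + (1 - beta2) * grad_after ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return lr * abs(m_hat) / (math.sqrt(v_hat) + eps)


print(adam_step_after_shift(1.0, 1.0, 10_000))                # ~1e-3 (steady state)
print(adam_step_after_shift(1e-6, 1.0, 10_000))               # ~3.16e-3 (spike)
print(adam_step_after_shift(1e-6, 1.0, 10_000, beta2=0.9))    # ~3.16e-4
print(adam_step_after_shift(1e-6, 1.0, 10_000, eps=1.0))      # ~9.7e-5
```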

Loss landscape evolution under non-stationarity

  • Prior work observes reductions in network plasticity
  • Causes of this phenomenon are difficult to determine
  • The experiment compares the evolution of two coupled updating procedures: one driven by gradient descent and one driven by random (Brownian-motion) perturbations of the parameters
  • The Hessian eigenvalue distribution and the gradient covariance structure are evaluated
  • Gradient descent introduces a bias that pushes parameters towards regions of parameter space where the loss landscape is harder to navigate
  • Commonly proposed explanations of plasticity loss do not identify robust causal relationships
  • Plasticity loss may arise due to changes in network’s loss landscape
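
The two landscape diagnostics above can be illustrated on a model where both have closed forms. This sketch assumes a linear model with squared loss (the paper studies deep networks, where these quantities must be estimated); all function names are hypothetical. The normalized gradient covariance is the matrix of cosine similarities between per-example gradients, so a negative entry means a step that lowers the loss on one input raises it on another. Sharpness is taken as the top Hessian eigenvalue, found here by power iteration.

```python
import math


def per_example_grads(w, xs, ys):
    """Gradient of (w.x - y)^2 with respect to w for each example: 2 (w.x - y) x."""
    grads = []
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        grads.append([2.0 * err * xi for xi in x])
    return grads


def gradient_cosine_matrix(grads):
    """Normalized gradient covariance: entry (i, j) is cos(g_i, g_j).
    A zero gradient gets norm 1.0 so its row is all zeros instead of NaN."""
    norms = [math.sqrt(sum(v * v for v in g)) or 1.0 for g in grads]
    return [[sum(a * b for a, b in zip(gi, gj)) / (ni * nj)
             for gj, nj in zip(grads, norms)]
            for gi, ni in zip(grads, norms)]


def mse_hessian(xs):
    """Hessian of the mean squared error of a linear model: (2/n) X^T X."""
    n, d = len(xs), len(xs[0])
    return [[2.0 / n * sum(x[i] * x[j] for x in xs) for j in range(d)]
            for i in range(d)]


def top_eigenvalue(matrix, iters=200):
    """Power iteration for the largest eigenvalue of a symmetric PSD matrix."""
    v = [1.0] * len(matrix)
    lam = 0.0
    for _ in range(iters):
        mv = [sum(a * b for a, b in zip(row, v)) for row in matrix]
        norm = math.sqrt(sum(x * x for x in mv))
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in mv]
        # Rayleigh quotient v^T M v with unit-norm v.
        lam = sum(vi * sum(a * b for a, b in zip(row, v))
                  for vi, row in zip(v, matrix))
    return lam


# Two inputs with identical features but opposite targets interfere fully.
g = per_example_grads([0.0, 0.0], [[1.0, 0.0], [1.0, 0.0]], [1.0, -1.0])
print(gradient_cosine_matrix(g)[0][1])                       # -1.0
print(top_eigenvalue(mse_hessian([[1.0, 0.0], [0.0, 2.0]])))  # ~4.0
```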

Experimental setting

  • Constructed a simple MDP analogue of image classification
  • Constructed three variants of a block MDP whose observations are images from the CIFAR-10 or MNIST datasets
  • Reward and transition dynamics depend on the action taken by the agent
  • The reward functions differ in how well they align with the network’s inductive bias, allowing tasks of varying difficulty to be compared
  • Trained DQN agents on each combination of environment and observation space
  • Evaluated the ability of each network to fit a randomly generated set of target functions
  • Updated target network every 1,000 steps
  • Logged loss after 2,000 steps of optimization
  • Considered two network architectures: MLP and CNN
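
The plasticity probe described above (start from a checkpoint, fit a randomly generated set of targets, and log the loss after 2,000 optimization steps) can be sketched as follows. This is a simplified stand-in under stated assumptions: a linear model trained with full-batch gradient descent on Gaussian random targets, whereas the paper uses MLPs and CNNs. The helpers `fit_random_targets` and `estimate_plasticity` and the default baseline of 1.0 are hypothetical.

```python
import random


def fit_random_targets(w0, xs, targets, steps=2000, lr=0.05):
    """Full-batch gradient descent on MSE from w0; returns the final loss."""
    w = list(w0)
    n = len(xs)
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in zip(xs, targets):
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for k, xk in enumerate(x):
                grad[k] += 2.0 * err * xk / n
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return sum((sum(wi * xi for wi, xi in zip(w, x)) - y) ** 2
               for x, y in zip(xs, targets)) / n


def estimate_plasticity(w0, xs, n_tasks=8, baseline=1.0, seed=0, **fit_kwargs):
    """P(w0) ~= baseline - mean final loss over randomly sampled target sets."""
    rng = random.Random(seed)
    losses = []
    for _ in range(n_tasks):
        targets = [rng.gauss(0.0, 1.0) for _ in xs]  # hypothetical target distribution
        losses.append(fit_random_targets(w0, xs, targets, **fit_kwargs))
    return baseline - sum(losses) / len(losses)


identity_inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(estimate_plasticity([0.0, 0.0, 0.0], identity_inputs))     # ~1.0: every target set is fit
print(estimate_plasticity([0.0, 0.0], [[1.0, 0.0], [1.0, 0.0]]))  # < 1.0: collinear inputs
```

Because the linear model's loss is convex, its estimate hardly depends on the starting parameters; the phenomenon studied in the paper is that, for deep networks, this estimate drops for checkpoints taken later in training.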

Falsification of prior hypotheses

  • Prior work has proposed explanations for why neural networks may exhibit reduced ability to fit new targets over time
  • Explanatory power of these hypotheses has not been rigorously tested
  • To test them, 128 DQN agents were trained across a range of tasks, observation spaces, optimizers, and seeds
  • Several statistics of the parameters and activations were logged, along with each network’s plasticity
  • Scatterplots show the relationship between plasticity and each statistic
  • The correlation between plasticity loss and each statistic is weak or nonexistent
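
The falsification test amounts to computing one statistic per agent and checking how strongly it correlates with measured plasticity. The sketch below assumes the fraction of dead ReLU units as the example statistic (one of the saturation-style explanations mentioned in the abstract) and uses Pearson correlation; the exact statistics and correlation measure in the paper may differ, and the function names are hypothetical.

```python
import math


def dead_unit_fraction(activations):
    """Fraction of ReLU units that output zero on every input in the batch.
    activations[i][j] is unit j's post-ReLU activation on input i."""
    n_units = len(activations[0])
    dead = sum(1 for j in range(n_units)
               if all(row[j] == 0.0 for row in activations))
    return dead / n_units


def pearson_r(xs, ys):
    """Pearson correlation between a per-agent statistic and per-agent plasticity."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


print(dead_unit_fraction([[0.0, 1.0], [0.0, 0.0]]))           # 0.5
print(pearson_r([1.0, 2.0, 3.0, 4.0], [1.0, -1.0, -1.0, 1.0]))  # 0.0: no linear relationship
```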

Loss landscape evolution during training

  • Plasticity loss is characterized by reduced ability to fit arbitrary new targets
  • Learning curves convey ease or difficulty of navigating the loss landscape
  • Networks from early training checkpoints quickly attain low losses on new targets, but learning curves from later checkpoints show increasing variance
  • The increasing difficulty of navigating the loss landscape drives plasticity loss
  • Scaling the architecture reduces plasticity loss, but cannot completely eliminate it

Solutions

  • Neural networks can lose plasticity even on a task as simple as classifying MNIST digits
  • Section 5.1 evaluates whether scaling can reduce plasticity loss
  • Section 5.2 evaluates the effect of interventions on plasticity
  • Section 5.3 tests these findings on larger-scale tasks

The role of scaling on plasticity

  • One natural hypothesis is that plasticity loss can be solved by increasing the size of a network.
  • Scaling a CNN to the limit of a single GPU’s memory is not enough to eliminate plasticity loss.
  • Plasticity loss is unlikely to be the limiting factor for sufficiently large networks on simple tasks.

Interventions in toy problems

  • Evaluated effect of interventions on plasticity loss
  • Task: 100 iterations of 1,000 optimization steps each
  • 4 architectures: MLP, CNN, ResNet-18, Vision Transformer
  • Interventions: reset last layer, reset optimizer, layer normalization, Shrink and Perturb, spectral normalization, weight decay
  • Interventions that smooth the loss landscape are the most effective at preserving plasticity
  • Two-hot encoding reduces plasticity loss but affects the stability of the policy
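
Two-hot encoding turns a scalar regression target into a categorical distribution over a fixed set of bins, so the network can be trained with a classification-style loss. The sketch below assumes a sorted, user-chosen support of bin centres and clips out-of-range values; the paper's bin layout is not given here, and `two_hot` and `expected_value` are hypothetical names. Mass is split between the two bins that bracket the value so that the expectation of the distribution recovers it exactly.

```python
import bisect


def two_hot(value, support):
    """Encode a scalar as probabilities over sorted bin centres `support`,
    splitting mass between the two neighbouring bins."""
    value = min(max(value, support[0]), support[-1])  # clip to the support range
    probs = [0.0] * len(support)
    k = bisect.bisect_right(support, value) - 1  # index of the bin at or below value
    if k >= len(support) - 1:
        probs[-1] = 1.0  # value sits on the last bin
        return probs
    lo, hi = support[k], support[k + 1]
    w_hi = (value - lo) / (hi - lo)
    probs[k] = 1.0 - w_hi
    probs[k + 1] = w_hi
    return probs


def expected_value(probs, support):
    """Decode a categorical prediction back to a scalar."""
    return sum(p * z for p, z in zip(probs, support))


support = [0.0, 1.0, 2.0, 3.0]
print(two_hot(1.25, support))                          # [0.0, 0.75, 0.25, 0.0]
print(expected_value(two_hot(2.6, support), support))  # ~2.6
```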

Application to larger benchmarks

  • Layer normalization improves performance on large-scale benchmarks
  • Double DQN, RMSProp, ε-greedy exploration, and frame stacking are used
  • Layer normalization reduces plasticity loss and improves robustness of RL agents
  • Layer normalization induces weaker gradient correlation in environments where it significantly improves performance
  • Initialization and architecture design can affect gradients in neural networks
  • ResNets bias each layer’s mapping towards the identity function to improve gradient propagation
  • Mean-field analysis, information propagation, and deep kernel shaping have been used to study trainability
  • Loss landscape smoothness affects generalization and performance
  • The early phase of training can be chaotic, a phenomenon that has been studied through linear mode connectivity
  • Loss of plasticity can limit deep reinforcement learning
  • Resetting parameters and distillation have been studied as ways to counteract plasticity loss and affect agent performance
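
For reference, layer normalization rescales each feature vector to zero mean and unit variance before a learned elementwise scale and shift. The sketch below is the standard per-vector computation in plain Python; where the paper places the normalization inside the Q-network is not specified here, and `layer_norm` is a hypothetical helper, not a library API. Because the output is invariant to the scale of the input, it keeps pre-activations in a bounded range as parameters drift, which is one intuition for why it smooths the landscape.

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance,
    then apply an elementwise scale (gamma) and shift (beta)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]
    gamma = gamma or [1.0] * n  # default: identity scale
    beta = beta or [0.0] * n    # default: no shift
    return [g * v + b for g, v, b in zip(gamma, normed, beta)]


print(layer_norm([1.0, 2.0, 3.0]))     # ~[-1.2247, 0.0, 1.2247]
print(layer_norm([10.0, 20.0, 30.0]))  # nearly identical: input scale is removed
```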

Conclusions

  • Divide between curriculum learning and foundation models
  • Plasticity loss is not a limiting factor in network performance for small environments
  • Stabilizing the loss landscape is a crucial step towards promoting plasticity
  • Smoother loss landscape is easier to optimize and has better generalization
  • Optimizer instability studied on MNIST dataset
  • Brownian motion studied on easy classification MDP with MNIST observation space
  • MLP, CNN, ResNet, and Vision Transformer architectures studied
  • Interventions studied: reset last layer, weight decay, spectral normalization, layer normalization, Shrink and Perturb
  • Categorical output representation found to have beneficial effects
  • Interventions interfere with learning on primary task
  • Learning curves of different networks studied
  • Gradient covariance and Hessian eigenvalue density studied
  • Positive and negative correlations between variables and plasticity found