Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Training stability is important for Transformers.
  • Examined the evolution of attention layers.
  • Tracked attention entropy for each attention head during training.
  • Low attention entropy is accompanied by high training instability.
  • Proposed $\sigma$Reparam to prevent entropy collapse.
  • Experiments show $\sigma$Reparam provides stability and robustness.

Paper Content

Introduction

  • Transformers are state-of-the-art models in many application domains
  • The original Transformer paper relies on residual connections and Layer Normalization (LN)
  • Various works have attempted to promote better training stability and robustness
  • Attention entropy is tightly correlated with a model’s stability and convergence
  • Small attention entropy often leads to slow convergence, fluctuations in training loss and divergence
  • Modifying the temperature of the Transformer gives direct control over the attention entropy
  • Sharpness of the Hessian is related to training stability
  • Entropy collapse can be prevented by controlling the spectral norms of the query and key projections
  • σReparam reparameterizes all weight matrices to update smoothly and in a controlled way
  • Entropy collapse is commonly observed in baseline models of various benchmarks
  • Transformers use LNs to achieve training stability
  • Entropy collapse happens even with extensive use of normalization layers
  • σReparam does not rely on specific normalization layers and can work without them
  • Weight reparameterization has been adopted in deep learning
  • σReparam is the first simple reparameterization technique that provides competitive performance
  • SpectralNorm explicitly constrains the model’s capacity
  • Rank collapse of Transformer training was first identified by Dong et al.
  • Entropy collapse characterizes a different failure pattern than rank collapse
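
The link between temperature and attention entropy noted above can be illustrated numerically. The following is a minimal sketch, not the paper's code: it assumes standard scaled dot-product attention, random inputs and weights, and a hypothetical `temperature` argument that divides the logits. Entropy of row $i$ is taken as $-\sum_j A_{ij} \log A_{ij}$.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def attention_entropy(X, W_q, W_k, temperature=1.0):
    """Mean row entropy of A = softmax(X W_q (X W_k)^T / (sqrt(d_k) * temperature))."""
    d_k = W_q.shape[1]
    logits = (X @ W_q) @ (X @ W_k).T / (np.sqrt(d_k) * temperature)
    A = softmax(logits)
    # Ent(A_i) = -sum_j A_ij log A_ij; the small constant guards against log(0).
    row_entropy = -(A * np.log(A + 1e-12)).sum(axis=-1)
    return row_entropy.mean()


rng = np.random.default_rng(0)
T, d = 16, 32  # hypothetical sequence length and model width
X = rng.normal(size=(T, d))
W_q = rng.normal(size=(d, d)) / np.sqrt(d)
W_k = rng.normal(size=(d, d)) / np.sqrt(d)

# Lower temperature sharpens the attention rows and lowers their entropy.
entropies = {t: attention_entropy(X, W_q, W_k, temperature=t) for t in (4.0, 1.0, 0.25)}
for t, h in entropies.items():
    print(f"temperature={t}: mean entropy={h:.3f} (upper limit log T = {np.log(T):.3f})")
```

Dividing the logits by a smaller temperature has the same effect as scaling up the query and key weights, which is one way to see why the spectral norms of those projections matter for entropy.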

Method

Attention entropy

  • At the core of Transformers is dot-product attention
  • Input sequence is denoted by $X \in \mathbb{R}^{T \times d}$
  • Attention layer computes $\mathrm{Att}(X) = A X W_V$
  • Attention entropy of a row $i$ of $A$ is $\mathrm{Ent}(A_i)$
  • Goal is to alleviate entropy collapse problem and achieve smooth evolution of attention entropy
  • Attention entropy is connected to the spectral norms of the query and key projection matrices
  • Inputs and weights exist for which lower bound of entropy is tight
  • Minimum attainable entropy decreases exponentially with spectral norm
  • Transformers are hard to train and can exhibit instability
  • Instability and attention entropy collapse appear in tandem
  • Preventing attention collapse might prevent instability
  • σReparam reparameterizes weights to prevent entropy collapse
  • σReparam decouples update rate of spectral norm from dimensionality of weights
  • σReparam leaves representational capacity of network intact
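
The reparameterization described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes the effective weight has the form $\hat{W} = \frac{\gamma}{\sigma(W)} W$, where $\sigma(W)$ is the spectral norm and $\gamma$ is a scalar that would be learned during training. The helper names `spectral_norm` and `sigma_reparam` are hypothetical, and the spectral norm is estimated with plain power iteration.

```python
import numpy as np


def spectral_norm(W, n_iters=200, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    # With u and v close to the top singular vectors, u^T W v approximates sigma(W).
    return float(u @ W @ v)


def sigma_reparam(W, gamma=1.0):
    """Effective weight (gamma / sigma(W)) * W; gamma is a fixed scalar in this sketch."""
    return (gamma / spectral_norm(W)) * W


rng = np.random.default_rng(1)
W = rng.normal(size=(32, 48))  # hypothetical weight matrix
W_hat = sigma_reparam(W, gamma=1.0)

print("spectral norm of W:    ", np.linalg.norm(W, 2))
print("spectral norm of W_hat:", np.linalg.norm(W_hat, 2))
# Rescaling W leaves the effective weight unchanged: only gamma sets its spectral norm.
print("invariant to scaling W:", np.allclose(sigma_reparam(10.0 * W), W_hat))
```

In this form the spectral norm of the effective weight equals $\gamma$ regardless of the size or scale of $W$, while the direction of $W$ is untouched, which is consistent with the claims that the update rate of the spectral norm is decoupled from dimensionality and that representational capacity is preserved.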

Experiments

  • σReparam improves robustness with respect to hyperparameters
  • σReparam enables simplified framework for training ViT-B, ViT-L and ViT-H models
  • σReparam enables SGD training via LARS
  • σReparam maintains lower spectral norms for the attention weight matrices
  • σReparam presents competitive results against ViT-B/16 models trained on larger datasets

Self-supervised training of visual representations

  • SSL has been effective in computer vision
  • Most progress has been made using convolutional architectures
  • ViTs often require specialized training recipes
  • ViTs suffer from training instabilities in SSL tasks
  • Instabilities can be remedied with frozen patch embedders, initialization schemes, and longer learning rate warmups
  • σReparam is a ViT SSL stabilizer
  • σReparam + pre-LN is the best overall method for stable training of SimCLR ViTs
  • σReparam + pre-LN produces highest ImageNet1k linear probe performance at 100 and 300 epochs

Machine translation

  • Machine translation (MT) is an active research area
  • Vanishing gradients problem has been reported
  • Solutions to the problem include rescaling residual connections
  • σReparam resolves entropy collapse in deep models
  • σReparam bounds attention entropy for post-LN and DeepNorm models
  • σReparam matches baselines for 6L-6L and is in the same ballpark for 18L-18L, but is inferior for 50L-50L and 100L-100L

Speech recognition and language modeling

  • σReparam stabilizes training of post-LN models
  • σReparam improves robustness with respect to hyperparameters

Conclusion

  • Transformer training stability is an unsolved problem
  • Attention entropy collapse is a commonly observed failure pattern
  • σReparam is a reparameterization of the weights that can address the entropy collapse problem
  • It is unclear if there is a causal relationship between entropy collapse and training instability
  • σReparam is not a panacea; other techniques can be used in combination with it

Appendix

  • Theorem 3.1 states that there is a lower bound on attention entropy
  • Proposition A.1 states that the entropy is low when temperature is low
  • Experiments were done on ImageNet1k
  • Temperature interventions were done at epochs 20, 30, 50 and 80
  • Temperature interventions lower attention entropy but do not cause instability
  • Experiments were done on LibriSpeech dataset
  • Vanilla Transformer model was used
  • SpecAugment was used
  • Adagrad was used
  • Learning rate decayed when WER reached a plateau
  • Dynamic batching was used
  • Training was done on 8 Ampere A100 GPUs
  • No weight decay was used