Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Training stability is important for Transformers.
- Examined the evolution of attention layers.
- Tracked attention entropy for each attention head during training.
- Low attention entropy is accompanied by high training instability.
- Proposed $\sigma$Reparam to prevent entropy collapse.
- Experiments show $\sigma$Reparam provides stability and robustness.
Paper Content
Introduction
- Transformers are state-of-the-art models in many application domains
- Residual connections and Layer Normalizations are used in the original paper
- Various works have attempted to promote better training stability and robustness
- Attention entropy is tightly correlated with a model's stability and convergence
- Small attention entropy is often accompanied by slow convergence, fluctuations in training loss and, in the worst case, divergence
- Modifying the temperature of the Transformer gives direct control over the attention entropy
- Sharpness of the Hessian is related to training stability
- Entropy collapse can be prevented by controlling the spectral norms of the query and key projections
- σReparam reparameterizes all weight matrices to update smoothly and in a controlled way
- Entropy collapse is commonly observed in baseline models of various benchmarks
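
The link between temperature and attention entropy mentioned above can be illustrated with a small sketch. It assumes the standard Shannon entropy of a single softmax attention row and a temperature that simply divides the logits. The function names and the logits are made up for illustration and do not come from the paper.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over one row of logits, each divided by an assumed temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def attention_entropy(probs):
    """Shannon entropy (in nats) of one attention row."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# Hypothetical attention logits for a single query over four keys.
logits = [2.0, 1.0, 0.5, -1.0]

entropies = {t: attention_entropy(softmax(logits, t)) for t in (0.1, 1.0, 10.0)}
max_entropy = math.log(len(logits))  # entropy of a uniform row
for t, h in entropies.items():
    print(f"temperature={t:>4}: entropy={h:.4f} (uniform row: {max_entropy:.4f})")
```

A lower temperature sharpens the row towards a one-hot distribution and pushes the entropy towards zero, while a higher temperature flattens it towards the uniform maximum of log T for T keys.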
Related works
- Transformers use LNs to achieve training stability
- Entropy collapse happens even with extensive use of normalization layers
- σReparam does not rely on specific normalization layers and can work without them
- Weight reparameterization has been adopted in deep learning
- σReparam is the first simple reparameterization technique that provides competitive performance
- SpectralNorm explicitly constrains the model’s capacity
- Rank collapse of Transformer training was first identified by Dong et al.
- Entropy collapse characterizes a different failure pattern than rank collapse
Method
Attention entropy
- At the core of Transformers is dot-product attention
- Input sequence is denoted by $X \in \mathbb{R}^{T \times d}$
- Attention layer computes $\mathrm{Att}(X) = A X W_V$
- Attention entropy of a row $i$ of $A$ is $\mathrm{Ent}(A_i)$
- Goal is to alleviate entropy collapse problem and achieve smooth evolution of attention entropy
- Entropy is connected to spectral norm of matrix
- Inputs and weights exist for which lower bound of entropy is tight
- Minimum attainable entropy decreases exponentially with spectral norm
- Transformers are hard to train and can exhibit instability
- Instability and attention entropy collapse appear in tandem
- Preventing attention collapse might prevent instability
- σReparam reparameterizes weights to prevent entropy collapse
- σReparam decouples update rate of spectral norm from dimensionality of weights
- σReparam leaves representational capacity of network intact
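
The reparameterization idea above can be sketched as follows. This is a minimal illustration, assuming the effective weight takes the form (γ / σ(W)) · W, where σ(W) is the spectral norm of W and γ is a learnable scalar. That form is not spelled out in this summary, and the power-iteration routine, the iteration count, and all names are illustrative choices rather than the authors' implementation.

```python
import numpy as np

def spectral_norm(W, n_iters=100, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    return float(u @ W @ v)

def sigma_reparam(W, gamma=1.0):
    """Effective weight (gamma / sigma(W)) * W; gamma is assumed learnable."""
    return (gamma / spectral_norm(W)) * W

rng = np.random.default_rng(1)
W = rng.standard_normal((64, 32)) * 3.0  # made-up weight matrix

# With gamma = 1 the effective weight has spectral norm close to 1.
W_hat = sigma_reparam(W, gamma=1.0)
sigma_hat = np.linalg.norm(W_hat, ord=2)

# With gamma = sigma(W) the original matrix is recovered, so no capacity is lost.
W_back = sigma_reparam(W, gamma=spectral_norm(W))

print(f"spectral norm of W:     {np.linalg.norm(W, ord=2):.4f}")
print(f"spectral norm of W_hat: {sigma_hat:.4f}")
print(f"W recovered:            {np.allclose(W_back, W)}")
```

Because γ is a single scalar, the spectral norm of the effective weight is controlled by one parameter regardless of the matrix dimensions, and any original weight can still be expressed by choosing γ appropriately.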
Experiments
- σReparam improves robustness to hyperparameter choices
- σReparam enables simplified framework for training ViT-B, ViT-L and ViT-H models
- σReparam enables SGD training via LARS
- σReparam maintains lower spectral norms for the attention weight matrices
- σReparam achieves competitive results against ViT-B/16 models trained on larger datasets
Self-supervised training of visual representations
- SSL has been effective in computer vision
- Most progress has been made using convolutional architectures
- ViTs often require specialized training recipes
- ViTs suffer from training instabilities in SSL tasks
- Instabilities can be remedied with frozen patch embedders, initialization schemes, and longer learning rate warmups
- σReparam is a ViT SSL stabilizer
- σReparam + pre-LN is the best overall method for stable training of SimCLR ViTs
- σReparam + pre-LN produces highest ImageNet1k linear probe performance at 100 and 300 epochs
Machine translation
- Machine translation (MT) is an active research area
- Vanishing gradients problem has been reported
- Solutions to the problem include rescaling residual connections
- σReparam resolves entropy collapse in deep models
- σReparam bounds attention entropy for post-LN and DeepNorm models
- σReparam matches baselines for 6L-6L and is in the same ballpark for 18L-18L, but is inferior for 50L-50L and 100L-100L
Speech recognition and language modeling
- σReparam stabilizes training of post-LN models
- σReparam improves robustness with respect to hyperparameters
Conclusion
- Transformer training stability is an unsolved problem
- Attention entropy collapse is a commonly observed failure pattern
- σReparam is a reparameterization of the weights that can address the entropy collapse problem
- It is unclear if there is a causal relationship between entropy collapse and training instability
- σReparam is not a panacea, other techniques can be used in combination
- Theorem 3.1 states that there is a lower bound on attention entropy
- Proposition A.1 states that the entropy is low when temperature is low
- Experiments were done on ImageNet1k
- Temperature interventions were done at epochs 20, 30, 50 and 80
- Temperature interventions lower attention entropy but do not cause instability
- Experiments were done on LibriSpeech dataset
- Vanilla Transformer model was used
- SpecAugment was used
- Adagrad was used
- Learning rate decayed when WER reached a plateau
- Dynamic batching was used
- Training was done on 8 Ampere A100 GPUs
- No weight decay was used