Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- RNNs offer fast inference on long sequences but are hard to optimize and slow to train.
- Deep state-space models (SSMs) have recently been shown to perform well on long sequence modeling tasks and have the added benefits of fast parallelizable training and RNN-like fast inference.
- This paper shows that careful design of deep RNNs can recover the impressive performance of deep SSMs on long-range reasoning tasks, while also matching their training speed.
- The paper introduces an RNN block called the Linear Recurrent Unit (LRU) that matches both the performance of deep SSMs on the Long Range Arena benchmark and their computational efficiency.
Paper Content
Introduction
- RNNs have been used since the early days of deep learning
- RNNs can be hard to train in practice due to the vanishing and exploding gradient problem
- Techniques have been developed to mitigate this issue, such as orthogonal/unitary RNNs and gating mechanisms
- Transformers have gained prominence for sequence modelling tasks
- Transformers do not suffer from the vanishing gradient problem
- Transformers can be expensive to deploy on long sequences
- S4 model is a deep state-space model that achieves remarkable performance on tasks requiring long-range reasoning
- S4 is equivalent to an RNN during inference
- S4 is parameterized as a discretization of a latent continuous-time system of differential equations
- Can the performance and efficiency of deep SSMs be matched using deep RNNs?
- Linear Recurrent Unit (LRU) can match the performance and efficiency of deep SSMs on the Long Range Arena (LRA) benchmark
- Steps towards crafting performant and efficient RNN models: linear recurrences, complex diagonal recurrent matrices, stable exponential parameterization, normalization
Preliminaries
- RNNs and SSMs are key architectural components studied in this work
- Methodology and experimental setup are described
- Related work is discussed in §B
Recap of recurrent block structures
- RNNs and S4-like deep SSMs have major differences
- RNN layer computes a sequence of outputs from a sequence of inputs
- Simplified version of S4 recurrence introduced in Gu et al. (2021a)
- Input is seen as the result of sampling a latent continuous-time signal
- Output sequence is sampled from a signal computed by a continuous-time state-space model
- Computation on the right-hand-side is linear in the hidden state and in the input
- Most parameters are complex valued
- Transition matrix is structured and initialized deterministically
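- Because the recurrence is linear, the computation can be efficiently parallelized up to the final time step of the sequence (e.g., with a parallel scan), unlike a nonlinear RNN, which must be evaluated sequentially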
- Learning is performed on the continuous-time parameters Ã, B, C, D, Δ
- Most SSMs use complex-valued diagonal recurrent matrices initialized deterministically
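The parallelization claim above rests on the fact that composing the affine updates of a linear recurrence is associative. The following is a minimal sketch of that idea for a single diagonal entry, in plain Python; the function names (`combine`, `scan_sequential`, `scan_doubling`) are illustrative and not taken from the paper, and the "parallel" rounds are only simulated with list comprehensions.

```python
# Sketch: the diagonal linear recurrence x_k = lam * x_{k-1} + b_k as a scan.
# Each step is the affine map x -> a * x + b, stored as the pair (a, b).

def combine(first, second):
    """Compose two affine maps: apply `first`, then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)

def scan_sequential(lam, inputs):
    """Reference: plain left-to-right recurrence starting from x = 0."""
    x, states = 0j, []
    for b in inputs:
        x = lam * x + b
        states.append(x)
    return states

def scan_doubling(lam, inputs):
    """Inclusive prefix scan by recursive doubling.

    In each round, element i is combined with element i - offset. The
    updates within one round do not depend on each other, so on parallel
    hardware they could run at once; about log2(len(inputs)) rounds are needed.
    """
    elems = [(lam, b) for b in inputs]
    offset = 1
    while offset < len(elems):
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(len(elems))
        ]
        offset *= 2
    # Applying each composed map to the zero initial state leaves only b.
    return [b for _, b in elems]

lam = complex(0.8, 0.3)
inputs = [1.0, 2.0, -1.0, 0.5, 3.0]
assert all(
    abs(s - p) < 1e-12
    for s, p in zip(scan_sequential(lam, inputs), scan_doubling(lam, inputs))
)
```

A nonlinear recurrence such as a tanh RNN offers no such associative `combine`, which is why it has to be evaluated step by step.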
Experimental setup
- Long Range Arena benchmark tests models’ ability to do long-range sequence modelling
- Transformers fail to perform well on most of these tasks
- Deep SSMs have shown remarkable performance on these tasks
- Deep RNNs are used to explore long-range modelling capabilities
- 6 layers with residual connections and layer/batch normalization are used
- AdamW optimizer is used with a smaller learning rate and no weight decay on recurrent parameters
Designing performant deep RNNs
- Use linear recurrences to improve performance
- Speed up training and inference without affecting expressivity
- Change parameterization and initialization distribution to make RNN stable and improve long-range modeling
- Use normalization strategy for hidden activations to match performance of deep SSMs
Linear RNN layers are performant
- Linear RNN layers, when stacked with nonlinear blocks, can perform as well as or better than nonlinear RNN variants.
- Removing the nonlinearity can improve test accuracy.
- Linear RNNs can match performance of deep SSMs on some tasks.
- Interleaving linear RNN layers with nonlinear feedforward blocks can approximate highly nonlinear systems.
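The structure described above, a linear recurrence followed by a position-wise nonlinearity, can be sketched as follows. This is a minimal illustration in plain Python, not the paper's implementation: the helper names are hypothetical, and `tanh` with a skip connection stands in for whatever nonlinear feedforward block a real model would use.

```python
import math

def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def linear_rnn(A, B, C, D, inputs):
    """x_k = A x_{k-1} + B u_k,  y_k = C x_k + D u_k, with zero initial state."""
    x = [0.0] * len(A)
    outputs = []
    for u in inputs:
        x = [ax + bu for ax, bu in zip(matvec(A, x), matvec(B, u))]
        y = [cx + du for cx, du in zip(matvec(C, x), matvec(D, u))]
        outputs.append(y)
    return outputs

def block(A, B, C, D, inputs):
    """Linear recurrence, then a position-wise nonlinearity and a skip connection."""
    ys = linear_rnn(A, B, C, D, inputs)
    return [
        [u_i + math.tanh(y_i) for u_i, y_i in zip(u, y)]
        for u, y in zip(inputs, ys)
    ]

# Toy sizes: state dimension 2, input/output dimension 1.
A = [[0.5, 0.0], [0.0, 0.9]]
B = [[1.0], [1.0]]
C = [[1.0, 1.0]]
D = [[0.0]]
print(block(A, B, C, D, [[1.0], [0.0], [0.0]]))
```

The recurrence itself is linear in the inputs (doubling the inputs doubles the outputs of `linear_rnn`); all nonlinearity lives in the step that follows it, and stacking several such blocks is what lets the overall model represent nonlinear sequence maps.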
Using complex diagonal recurrent matrices is efficient
- Deep linear RNNs can be sped up without losing performance by using complex-valued diagonal recurrent matrices.
- Diagonalizing linear systems is a common feature of deep SSMs.
- The recurrence equation can be unrolled using the assumption that $x_{-1} = 0 \in \mathbb{R}^N$.
- The hidden state of the linear variant can potentially explode or vanish exponentially as the time index $k$ increases.
- Eigenvalue analysis can be used to understand this phenomenon.
- Complex numbers provide a convenient and compact representation of non-symmetric matrices in diagonal form.
- Learning recurrent linear systems in diagonal form provides substantial computational speedups.
- An equivalent initialization can be used to keep the eigenvalue spectrum of the recurrence unchanged.
- Diagonalizing the recurrence improves accuracy and reduces training and inference time.
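To make the point about complex numbers concrete, here is a small sketch under simplifying assumptions: a 2x2 real, non-symmetric recurrent matrix (a scaled rotation, chosen for illustration) is compared with a single complex eigenvalue $\lambda = r e^{i\theta}$. The function names are hypothetical, and the input is assumed to drive only the first coordinate.

```python
import cmath
import math

def dense_real_recurrence(r, theta, inputs):
    """2x2 real recurrence with A = r * rotation(theta); u_k drives coordinate 1."""
    a11, a12 = r * math.cos(theta), -r * math.sin(theta)
    a21, a22 = r * math.sin(theta), r * math.cos(theta)
    x1 = x2 = 0.0
    states = []
    for u in inputs:
        x1, x2 = a11 * x1 + a12 * x2 + u, a21 * x1 + a22 * x2
        states.append((x1, x2))
    return states

def complex_diag_recurrence(r, theta, inputs):
    """Same dynamics as one complex scalar: z_k = lam * z_{k-1} + u_k."""
    lam = r * cmath.exp(1j * theta)
    z = 0j
    states = []
    for u in inputs:
        z = lam * z + u
        states.append(z)
    return states

inputs = [1.0, 0.5, -0.3, 0.0, 2.0]
dense = dense_real_recurrence(0.95, 0.4, inputs)
diag = complex_diag_recurrence(0.95, 0.4, inputs)
# The complex state z = x1 + i * x2 reproduces the dense real state.
assert all(
    abs(x1 - z.real) < 1e-12 and abs(x2 - z.imag) < 1e-12
    for (x1, x2), z in zip(dense, diag)
)
```

For an impulse input the complex state has magnitude $r^k$ after $k$ further steps, which is the exponential vanishing ($r < 1$) or explosion ($r > 1$) that the eigenvalue analysis describes. The diagonal form also replaces a matrix-vector product with an element-wise multiplication per step.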
Benefits of stable exponential parameterization
- Moving to complex diagonal recurrences is computationally efficient
- Learning the diagonal model can be more unstable than learning the dense model
- Eigenvalues need to have magnitude close to 1 to learn long-range dependencies
- Parameterizing the RNN as $\Lambda = \mathrm{diag}(\exp(-\nu + i\theta))$ decouples the eigenvalue magnitudes (controlled by $\nu$) from the oscillation frequencies (controlled by $\theta$)
- Enforcing stability on the eigenvalues is beneficial
- Stable parameterization helps on most LRA tasks
- Initializing eigenvalues to have a small phase helps with long-range reasoning tasks
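The bullets above can be sketched in code. The snippet assumes that the stable variant takes the form $\lambda_j = \exp(-\exp(\nu^{\log}_j) + i \exp(\theta^{\log}_j))$ and that magnitudes are sampled on a ring between `r_min` and `r_max`; the function names, default values, and the `max_phase` argument are illustrative choices, not a definitive reproduction of the paper's code.

```python
import cmath
import math
import random

def init_ring(n, r_min=0.9, r_max=0.999, max_phase=2 * math.pi, seed=0):
    """Sample (nu_log, theta_log) so that r_min <= |lam| <= r_max and
    0 < phase <= max_phase. Requires 0 <= r_min < r_max < 1."""
    rng = random.Random(seed)
    nu_log, theta_log = [], []
    for _ in range(n):
        # 1 - random() lies in (0, 1], which keeps both logarithms finite.
        u1, u2 = 1.0 - rng.random(), 1.0 - rng.random()
        nu = -0.5 * math.log(u1 * (r_max**2 - r_min**2) + r_min**2)
        nu_log.append(math.log(nu))
        theta_log.append(math.log(max_phase * u2))
    return nu_log, theta_log

def eigenvalues(nu_log, theta_log):
    """lam_j = exp(-exp(nu_log_j) + i * exp(theta_log_j)).

    The magnitude exp(-exp(nu_log_j)) is below 1 for any real nu_log_j,
    so gradient updates cannot push an eigenvalue outside the unit disk.
    """
    return [
        cmath.exp(complex(-math.exp(n), math.exp(t)))
        for n, t in zip(nu_log, theta_log)
    ]

# Small-phase initialization, as suggested for long-range tasks.
nu_log, theta_log = init_ring(8, max_phase=math.pi / 10)
for lam in eigenvalues(nu_log, theta_log):
    print(round(abs(lam), 4), round(cmath.phase(lam), 4))
```

Changing only `nu_log` moves an eigenvalue radially, and changing only `theta_log` rotates it, which is the magnitude and phase decoupling that the exponential parameterization provides.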
Additional considerations for long-range reasoning tasks
- With the modifications so far, the model did not succeed in learning the PathX task
- Initializing eigenvalues close to the boundary of the unit disk (magnitude near 1) improves performance on long-range tasks
- The forward pass can blow up when eigenvalue magnitudes are close to 1
- Normalization scheme for hidden activations to tackle this problem
- Restricting eigenvalue phase at initialization helps solve PathX
- Normalization and restricted phase at initialization are both necessary
- LRU model provides flexible, interpretable, and principled framework
- Matches performance and efficiency of deep SSMs across all LRA tasks
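Why the forward pass blows up, and how a normalization factor counteracts it, can be checked numerically. The sketch below assumes the normalization multiplies the input of each diagonal entry by $\gamma_j = \sqrt{1 - |\lambda_j|^2}$; the function names and the number of steps are illustrative.

```python
import math

def state_energy(lam, gamma=1.0, steps=20000):
    """Sum of squared impulse-response magnitudes of
    x_k = lam * x_{k-1} + gamma * u_k.

    For zero-mean, unit-variance, uncorrelated inputs this sum equals the
    stationary variance of the hidden state.
    """
    energy, power = 0.0, 1.0  # power tracks |lam|^(2k)
    mag2 = abs(lam) ** 2
    for _ in range(steps):
        energy += gamma**2 * power
        power *= mag2
    return energy

def gamma_for(lam):
    """Normalization factor sqrt(1 - |lam|^2), assumed form."""
    return math.sqrt(1.0 - abs(lam) ** 2)

for lam in (0.9, 0.99, 0.999):
    raw = state_energy(lam)
    normalized = state_energy(lam, gamma=gamma_for(lam))
    print(lam, round(raw, 2), round(normalized, 6))
```

Without the factor, the state variance grows like $1 / (1 - |\lambda|^2)$ as $|\lambda|$ approaches 1; with it, the variance stays close to 1 regardless of how near the eigenvalue is to the boundary of the unit disk.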
Insights on S4 and variants
- The ablations provide insight into the mechanisms underlying the success of deep SSMs
- Diagonal SSMs are instantiated and parameterized through discretization of a latent continuous-time model
- Matrix exponentials make training easier
- To enforce stability, the real part of the continuous-time transition matrix is often fed into a positive nonlinearity
- Power of exponential parameterization is not necessarily attributable to accurate integration
- Magnitude-phase decoupling on the recurrence and learning in diagonalized space make training with Adam easier
- HiPPO theory is not the main source of S4 success
- Discretization changes initialization spectrum
- Discretization performs normalization
- Parameter sharing is not necessary
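The two discretization bullets above can be illustrated with the standard zero-order-hold (ZOH) rule applied to one diagonal entry. This is a sketch under the assumption that ZOH is the discretization in use; the continuous-time values and the name `zoh_discretize` are made up for illustration.

```python
import cmath

def zoh_discretize(a_cont, b_cont, delta):
    """Zero-order-hold discretization of the scalar system x' = a x + b u.

    Returns the discrete eigenvalue lam = exp(delta * a) and the discrete
    input coefficient (lam - 1) / a * b.
    """
    lam = cmath.exp(delta * a_cont)
    b_disc = (lam - 1.0) / a_cont * b_cont
    return lam, b_disc

a_cont = complex(-0.5, 3.0)  # hypothetical continuous-time eigenvalue
for delta in (1.0, 0.1, 0.01):
    lam, b_disc = zoh_discretize(a_cont, 1.0, delta)
    print(delta, round(abs(lam), 4), round(cmath.phase(lam), 4), round(abs(b_disc), 4))
```

As the step size shrinks, the discrete eigenvalue moves toward magnitude 1 with a small phase, so the choice of step size reshapes the spectrum at initialization. At the same time the input coefficient shrinks roughly in proportion to the step size, which acts as a rescaling of the input to the recurrence.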
Conclusion
- Introduction of a new RNN layer called the Linear Recurrent Unit (LRU)
- LRU can be used as core layers of deep sequence models
- Theoretical insights and extensive ablations on a series of modifications of a vanilla RNN to improve performance
- Design does not rely on discretization of a latent continuous-time system or on structured transition matrices
- Improvements directly follow from initialization and forward pass analysis arguments standard in the deep learning community
- Matches performance of modern deep state-space models (e.g. S4 or S5) on all LRA tasks
Related work
- Standard RNN-based approaches for sequence-to-sequence modeling
- Historical overview on the progress of the literature stemming from the S4 paper
- RNNs have been widely used in natural language processing applications
- Issue of vanishing or exploding gradients limits the ability of RNNs to learn
- Introduction of gating mechanisms such as the Long Short-Term Memory (LSTM)
- Mitigating the vanishing gradient problem with orthogonal and unitary RNNs
- Deep state-space models (SSMs)
- Gu et al. (2020) provided an alternative view on the vanishing gradient problem
- A discretized structured state-space model (i.e., one using structured HiPPO matrices) served as a starting point for the design and initialization of a novel gated RNN
- Gu et al. (2021a) scaled up this idea into a deep architecture
- Gu et al. (2022a) showed that, to recover most of the performance of S4, one can simply initialize the transition matrix in diagonal form
- Smith et al. (2022) found that one can also depart from the formal one-dimensional discretization structure of S4
- Applications of DSS and S4 in language, vision and audio
- Ma et al. (2022) simplified S4 to a diagonal SSM
- Li et al. (2022a) leveraged the convolutional interpretation of SSMs
- Gupta et al. (2022b) pointed out that, after numerical integration, diagonal state-space models and linear RNNs share the same function approximation class
- Applications of models inspired by the S4 architecture in language modeling, video/audio understanding and generation, biology and time series forecasting