Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Transfer learning is the dominant paradigm of machine learning.
- Pre-trained models can be fine-tuned for downstream tasks with fewer labelled examples.
- Modular deep learning is a promising solution to challenges of positive transfer and systematic generalisation.
- Modular architectures provide a unified view of research that evolved independently.
- Modularity can be used for scaling language models, causal inference, programme induction, and planning in reinforcement learning.
- Modularity has been successfully deployed in cross-lingual and cross-modal knowledge transfer.
Paper Content
Introduction and motivation
- Transfer learning is a popular machine learning technology
- Pre-training a model on raw data and fine-tuning it for new tasks is successful but has limitations
- Modularity in artificial and biological systems is beneficial for evolvability, adaptability, and resilience
- Vanilla neural networks can learn a modular pattern
- Modular neural architectures have advantages such as positive transfer, compositionality, and parameter efficiency
- Modules can be implemented in different ways
- Routing functions control the flow of information to the modules
- Aggregation strategies are used when multiple modules are selected
- Modules can be trained in different ways
- Modularity is beneficial in transfer learning, causal inference, and programme simulation
- Modular deep learning can be used in natural language processing, computer vision, and speech processing
Modular deep learning
- Focuses on modular deep learning models composed of modules
- Modules are autonomous computation functions referred to as adapters, options, or experts
- Distinguished from routing function which controls information flow
- Aggregation function aggregates outputs
- Modules can be combined with fully shared parameters
- Taxonomy of four dimensions of variation: computation, routing, aggregation, and training
Taxonomy
- A module may consist of any component of a neural architecture
- Routing function is either fixed or learned
- Aggregation function is either deterministic or learnable
- Modules can be trained jointly or post-hoc
Notation
- Neural network decomposed into graph of sub-functions
- In the simplest case, the graph is a linear chain
- Sub-functions refer to model’s layers with unique indexed parameters
- Parameters can be further decomposed recursively
- A module can modify a sub-function through parameter composition, concatenation with its input, or function composition (feeding the output of one function into another)
- Routing function returns score for each module conditioned on data or metadata
- Routing scores can be hard (a binary selection), soft (a probability distribution), or an unnormalised score vector
- Output of each module combined through aggregation function
- Aggregation function can be deterministic or learnable neural network
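The interplay of modules, routing scores, and aggregation can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the three toy modules, the score values, and the helper names (`softmax`, `hard_top1`, `aggregate`) are all assumptions made for the example.

```python
import math

def softmax(scores):
    """Soft routing: turn unnormalised scores into a probability distribution."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def hard_top1(scores):
    """Hard routing: a one-hot vector that selects the best-scoring module."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]

def aggregate(module_outputs, alpha):
    """Deterministic aggregation: score-weighted sum of the module outputs."""
    dim = len(module_outputs[0])
    return [sum(a * out[d] for a, out in zip(alpha, module_outputs))
            for d in range(dim)]

# Three toy modules (hypothetical), each a function of the input vector x.
modules = [
    lambda x: [2.0 * v for v in x],
    lambda x: [v + 1.0 for v in x],
    lambda x: [-v for v in x],
]

x = [1.0, 2.0]
raw_scores = [0.5, 2.0, -1.0]   # unnormalised routing scores
outputs = [m(x) for m in modules]

soft_out = aggregate(outputs, softmax(raw_scores))    # mixes all three modules
hard_out = aggregate(outputs, hard_top1(raw_scores))  # uses only the second module
```

With hard routing only the selected module contributes to the result, whereas soft routing blends every module in proportion to its score.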
Computation function
- Computation function determines module design
- Different module architectures proposed
- Most often modules integrated into base architecture with shared parameters
- Three core methods merge a single module with the corresponding sub-function: parameter composition, input composition, and function composition
Routing function
- Fixed routing is a method of generating modular parameters conditioned on metadata.
- Fixed routing is applicable to function composition.
- There are three methods of instantiating modules: parameter composition, input composition, and function composition.
- Examples of the three computation functions are provided in Figure 2 and discussed in Table 3.
Parameter composition
- Parameter composition methods augment the weights of a base model with module parameters, which makes adaptation parameter-efficient.
- Sparsity is a common inductive bias based on the assumptions that only a small number of parameters are relevant for a particular task and that similar tasks share similar sub-networks.
- Pruning is the most widespread method to induce sparsity, which can be interpreted as the application of a binary mask.
- Iterative pruning is carried out over multiple stages and often retains or surpasses the performance of the original dense model.
- Winning tickets have been shown to exist in RL, NLP, and computer vision.
- Pruning based on first-order (gradient-based) information better captures the task-specific relevance of each weight.
- Sparsification techniques can be employed for adaptation, such as Sparse Fine-Tuning and Diff pruning.
- Low-rank modules can also be used to save space and time.
- Parameter composition methods are very parameter-efficient and often require updating less than 0.5% of a model’s parameters.
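Two of the ideas above, a binary mask over the weights and a low-rank update, can be illustrated with plain lists. This is a sketch under assumptions: the values, the shapes, and the helper names `sparse_compose` and `low_rank_delta` are hypothetical and do not reproduce any specific method from the paper.

```python
def sparse_compose(theta, phi, mask):
    """Sparse composition: theta' = theta + mask * phi, element-wise."""
    return [t + m * p for t, p, m in zip(theta, phi, mask)]

def low_rank_delta(B, A):
    """Low-rank update: delta_W = B @ A, with B (d_out x r) and A (r x d_in)."""
    r = len(A)
    return [[sum(B[i][k] * A[k][j] for k in range(r)) for j in range(len(A[0]))]
            for i in range(len(B))]

theta = [0.5, -1.0, 2.0, 0.0]   # frozen base weights
phi = [0.1, 0.3, -0.2, 0.4]     # module (difference) parameters
mask = [0, 1, 0, 0]             # binary mask: only one weight is adapted
adapted = sparse_compose(theta, phi, mask)

# Rank-1 update of a 3x3 weight matrix: 6 module parameters instead of 9.
B = [[1.0], [0.0], [2.0]]
A = [[0.5, 0.0, -1.0]]
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
delta = low_rank_delta(B, A)
W_adapted = [[w + d for w, d in zip(w_row, d_row)]
             for w_row, d_row in zip(W, delta)]
```

In both cases the base weights stay frozen and only the small set of module parameters would be trained.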
Input composition
- Input composition methods add a parameter vector to a function’s input.
- Common strategy is to add the parameter vector to the model’s first layer.
- A task-specific text prompt in natural language, once embedded, corresponds to the modular parameters.
- Continuous prompt vector can be learned directly.
- Module vectors can be learned for each layer of the model.
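A continuous prompt can be pictured as extra vectors placed in front of the embedded input. The sketch below assumes a toy "model" that simply averages its input sequence; `prepend_prompt`, `mean_pool`, and all values are hypothetical stand-ins chosen for illustration.

```python
def prepend_prompt(prompt_vectors, input_embeddings):
    """Input composition: concatenate module parameters with the input sequence."""
    return prompt_vectors + input_embeddings

def mean_pool(sequence):
    """Toy stand-in for a frozen model: average the vectors in the sequence."""
    dim = len(sequence[0])
    return [sum(vec[d] for vec in sequence) / len(sequence) for d in range(dim)]

input_embeddings = [[1.0, 0.0], [0.0, 1.0]]   # two embedded input tokens
prompt_vectors = [[4.0, 4.0]]                 # one learnable prompt vector

plain = mean_pool(input_embeddings)
prompted = mean_pool(prepend_prompt(prompt_vectors, input_embeddings))
```

The frozen function is unchanged; only its input grows by the prompt vectors, which are the sole trainable parameters in this setup.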
Function composition
- Parameter composition deals with individual weights
- Input composition methods act on a function’s input
- Function composition methods add new task-specific sub-functions
- Parameter sharing models in multi-task learning consist of shared layers
- Multi-task architecture can be obtained by tying sets of parameters between models
- Cross-stitch unit linearly combines inputs at every layer
- Sluice networks extend cross-stitch units to multiple modules per layer
- Adapter layers are an alternative to parameter sharing
- Adapter layers are composed with a pre-trained model
- Adapter layers are modality-specific
- Adapter layers can be routed sequentially or in parallel
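A bottleneck adapter composed with a frozen layer, in both sequential and parallel placement, might look like the following. The bottleneck-plus-residual design, the weights, and the names `adapter` and `frozen_layer` are assumptions for illustration rather than details taken from the summary.

```python
def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def relu(x):
    return [max(0.0, v) for v in x]

def adapter(h, W_down, W_up):
    """Bottleneck adapter with a residual connection: h + W_up relu(W_down h)."""
    update = matvec(W_up, relu(matvec(W_down, h)))
    return [a + b for a, b in zip(h, update)]

def frozen_layer(x):
    """Toy stand-in for a frozen pre-trained sub-function."""
    return [2.0 * v for v in x]

W_down = [[1.0, -1.0, 0.0]]      # projects 3 dimensions down to 1
W_up = [[0.5], [0.0], [1.0]]     # projects 1 dimension back up to 3

x = [3.0, 1.0, 0.0]

# Sequential: the adapter transforms the frozen layer's output.
sequential = adapter(frozen_layer(x), W_down, W_up)

# Parallel: the adapter update is computed from the layer's input
# and added to the layer's output.
parallel_update = matvec(W_up, relu(matvec(W_down, x)))
parallel = [a + b for a, b in zip(frozen_layer(x), parallel_update)]
```

The two placements generally produce different outputs because the adapter sees different inputs in each case.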
Hypernetworks
- Hypernetworks are a third kind of routing with unnormalised routing scores.
- Parameters for a task are generated by a linear function.
- Task embedding is the output of a task-level routing function with unnormalised scores.
- Generator is a matrix of module parameters stacked column-wise.
- Parameters are a linear combination of the columns of the linear generator.
- Hypernetworks learn both sets of parameters jointly.
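The linear generator described above reduces to a matrix-vector product. In this sketch the generator matrix, the task embedding, and the function name `generate_parameters` are hypothetical values chosen to keep the arithmetic easy to follow.

```python
def generate_parameters(generator, task_embedding):
    """phi = W @ alpha: a linear combination of the generator's columns."""
    return [sum(w * a for w, a in zip(row, task_embedding)) for row in generator]

# Each column of the generator holds the parameters of one module
# (3 parameters per module, 2 modules).
generator = [
    [1.0, 0.0],
    [0.0, 2.0],
    [1.0, 1.0],
]
task_embedding = [0.5, -1.0]   # unnormalised routing scores for one task

phi_task = generate_parameters(generator, task_embedding)
```

Because the scores are unnormalised, they can be negative or exceed one, so the generated parameters are a free linear combination of the columns rather than an average.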
Unifying parameter, input, and function composition
- All modular computation functions can be reduced to function composition.
- Input composition methods use a weighted addition, while parameter and function composition use an unweighted addition.
- Different methods have different trade-offs in terms of capacity, memory footprint, and performance.
- Routing, aggregation, and training settings for the modules are discussed in the following sections.
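As a hedged illustration of the difference between a weighted and an unweighted addition, write h for a hidden representation and Δh for the update contributed by a module. The gate λ ∈ [0, 1] is notation assumed here, not taken from the summary.

```latex
\begin{align}
\text{weighted addition:}   \quad & h \leftarrow (1 - \lambda)\, h + \lambda\, \Delta h \\
\text{unweighted addition:} \quad & h \leftarrow h + \Delta h
\end{align}
```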
Routing function
- A decision-making process is required to determine which modules are active in a modular neural architecture.
- This process is implemented through a routing function.
- When metadata is available, the routing decision can be made a priori.
- When no prior information is available, the routing function needs to be learned.
- Learning-to-route can be split into hard routing and soft routing.
Fixed routing
- Making routing decisions based on metadata is referred to as fixed routing
- Fixed routing simplifies the routing function to selecting a subset of modules
- An example of fixed routing is when all parameters, except the final classification layer, are shared among all tasks
- Methods that adapt pre-trained models towards individual tasks also route representations through a newly introduced module
- Fixed routing can select separate language and task components
- Fixed routing can also be based on other metadata such as language, domain, or modality information
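With fixed routing, module selection amounts to a lookup keyed by metadata. The sketch below assumes a hypothetical inventory of language and task modules; the dictionary keys, the toy `scale` and `shift` modules, and `fixed_route` are all made up for the example.

```python
def scale(factor):
    """Toy module that rescales a hidden representation."""
    return lambda h: [factor * v for v in h]

def shift(offset):
    """Toy module that shifts a hidden representation."""
    return lambda h: [v + offset for v in h]

# Hypothetical module inventory, keyed by metadata values.
language_modules = {"en": scale(1.0), "de": scale(2.0)}
task_modules = {"ner": shift(1.0), "qa": shift(-1.0)}

def fixed_route(h, language, task):
    """Metadata alone decides which modules run; nothing is learned here."""
    h = language_modules[language](h)
    return task_modules[task](h)

out = fixed_route([1.0, 2.0], language="de", task="ner")
```

Keeping language and task components separate means a new language and task pair can be served by recombining existing modules.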
Learned routing
- Routing function can be implemented as a learnable neural network
- Learning the routing function implies that the specialisation of each module is unknown
- Training instability, module collapse, and overfitting are challenges
- Challenges are caused by need to balance between exploration and exploitation and sharing modules across examples or tasks
- Training instability can be mitigated by curriculum learning or training the router parameters with a different learning rate
- Module collapse can be avoided by ε-greedy routing, auxiliary losses, or intrinsic rewards
- Overfitting can be avoided by routing conditioned on metadata or favouring combinatorial behaviour of modules
- Hard routing can be learned via reinforcement learning, evolutionary algorithms, or stochastic re-parameterisation
- Soft routing uses a mixture of experts
- Token-level routing uses top-k selection to load balance computation
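Two of the mechanisms listed above, top-k selection and exploration during hard routing, can be sketched as follows. The router logits are hard-coded here, whereas in practice they would come from a learned router; the function names and the renormalising softmax are assumptions for illustration.

```python
import math
import random

def top_k_gate(logits, k):
    """Keep the k largest router logits and renormalise them with a softmax."""
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    shift = max(logits[i] for i in chosen)
    exps = {i: math.exp(logits[i] - shift) for i in chosen}
    total = sum(exps.values())
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]

def epsilon_greedy_route(logits, epsilon, rng):
    """Hard routing with exploration: a random module with probability epsilon."""
    if rng.random() < epsilon:
        return rng.randrange(len(logits))
    return max(range(len(logits)), key=lambda i: logits[i])

router_logits = [0.1, 2.0, -0.5, 2.0]   # hypothetical scores for four modules
gate = top_k_gate(router_logits, k=2)   # non-zero weight for two modules only

rng = random.Random(0)
greedy_choice = epsilon_greedy_route(router_logits, epsilon=0.0, rng=rng)
explored_choice = epsilon_greedy_route(router_logits, epsilon=1.0, rng=rng)
```

Occasional random choices give rarely selected modules a chance to receive gradient updates, which is one way to counter module collapse.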
Level of routing
- Routing can be done globally, per layer, or hierarchically.
- Allowing for different decisions per layer is more challenging as the space of potential architectures grows exponentially.
- Routing scores are sometimes used to select a subset of modules and to aggregate their outputs.
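The exponential growth mentioned above is easy to quantify for the simplest case, where each routing decision picks exactly one module. The module and layer counts below are arbitrary example values, and `num_hard_routings` is a hypothetical helper.

```python
def num_hard_routings(num_modules, num_layers, per_layer):
    """Count hard routings that pick exactly one module per decision point."""
    decisions = num_layers if per_layer else 1
    return num_modules ** decisions

# One global decision shared by all layers versus one decision per layer.
global_options = num_hard_routings(num_modules=8, num_layers=12, per_layer=False)
per_layer_options = num_hard_routings(num_modules=8, num_layers=12, per_layer=True)
```

With 8 modules and 12 layers, a single global decision has 8 outcomes, while independent per-layer decisions have 8 to the power of 12.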
Aggregation function
- Routing and aggregation of modules are performed simultaneously
- Strategies for aggregating functions are similar to the taxonomy discussed for computation functions
- Aggregation of modular components can be realised on the parameter level
- Output level aggregation is also discussed
Parameter aggregation
- Interpolating module weights can have catastrophic consequences
- Linear mode connectivity suggests that interpolation between multiple models is possible under certain conditions
- Mode paths are not usually linear
- Linear mode connectivity is linked to the Lottery Ticket Hypothesis
- Interpolation is connected to the flatness of the loss landscape
- Success of interpolation is connected to the optimiser used
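A one-parameter toy loss shows why naive weight interpolation can fail when two solutions are not linearly connected. The loss function and the two "solutions" below are invented for illustration and are not drawn from the paper.

```python
def interpolate(theta_a, theta_b, weight):
    """Linear interpolation: (1 - weight) * theta_a + weight * theta_b."""
    return [(1.0 - weight) * a + weight * b for a, b in zip(theta_a, theta_b)]

def loss(theta):
    """Toy non-convex loss with minima at theta = [-1] and theta = [1]."""
    return (theta[0] ** 2 - 1.0) ** 2

theta_a, theta_b = [-1.0], [1.0]   # two solutions with zero loss
path = [interpolate(theta_a, theta_b, step / 4) for step in range(5)]
losses = [loss(theta) for theta in path]
```

Both endpoints have zero loss, yet the loss rises along the straight line between them and peaks at the midpoint, which is the kind of barrier that linear mode connectivity rules out.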
Representation aggregation
- Representation aggregation is equivalent to parameter aggregation if the functions are linear.
- This equivalence does not hold for non-linear functions.
- Representation aggregation involves interpolating the outputs of individual modules.
- Representation aggregation can be done by learning weights to interpolate hidden representations.
- Attention-based aggregation functions take into account the information content of the hidden representations.
- Top-k hard routing is more efficient for weighted averaging and attention-based aggregation.
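The relationship between parameter and representation aggregation can be checked numerically on scalar modules. This sketch assumes one scalar weight per module and equal aggregation weights; all names and values are hypothetical.

```python
def linear(w, x):
    return w * x

def relu_unit(w, x):
    return max(0.0, w * x)

def aggregate_parameters(f, weights, alpha, x):
    """Mix the parameters first, then run a single forward pass."""
    mixed = sum(a * w for a, w in zip(alpha, weights))
    return f(mixed, x)

def aggregate_representations(f, weights, alpha, x):
    """Run every module, then mix their outputs."""
    return sum(a * f(w, x) for a, w in zip(alpha, weights))

weights = [2.0, -4.0]   # one scalar parameter per module
alpha = [0.5, 0.5]      # aggregation weights
x = 1.0

linear_gap = (aggregate_parameters(linear, weights, alpha, x)
              - aggregate_representations(linear, weights, alpha, x))
relu_gap = (aggregate_parameters(relu_unit, weights, alpha, x)
            - aggregate_representations(relu_unit, weights, alpha, x))
```

For the linear module the two strategies agree exactly, while for the ReLU module they differ, so the order of mixing and applying the function matters once a non-linearity is involved.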
Input aggregation
- Input aggregation is used to create adapters such as prompts or prefix tuning
- Hypernetworks can combine different embeddings in the input to the parameter generator
- Embeddings can represent the position of the generated parameters in the neural architecture
Function aggregation
- Aggregation can be achieved on the function level
- Different aggregation methods infer either a sequence or a tree structure
- Forward pass through multiple modules transforms hidden representations
- Pfeiffer et al. propose a two-stage setup for zero-shot cross-lingual transfer
- Stickland et al. perform function composition for multilingual multi-domain machine translation
- Neural Module Networks use a semantic parse to infer a graphical structure for module aggregation
Training setting
- All modules can be trained together for multi-task learning
- Modules can be introduced at different stages during continual learning
- Transfer learning involves adding modules post-hoc after pre-training
Joint multitask learning
- Joint multi-task learning has two main settings
- Task-specific parameterised components can be integrated into shared neural network architectures
- Alternative is to have a fully modular architecture, sharing only the parameters for learned routing
- Joint training can also be performed before post-hoc training
Continual learning
- Multi-task learning and continual learning aim to prevent catastrophic forgetting
- New layers can be added to the network and trained on new data while the existing layers are kept untouched
- Progressive Networks scale the model capacity linearly with the number of tasks
- Separate experts can be trained for each task
- Subnetworks of the model can be identified and updated without affecting previously learned knowledge
- Supermasks can be used to extend to a vast number of tasks during continual learning
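The idea of adding task-specific capacity while leaving earlier parameters untouched can be sketched with a small container class. `ModularContinualLearner` and its methods are hypothetical and only mirror the general pattern; they do not implement Progressive Networks or supermasks.

```python
class ModularContinualLearner:
    """Keeps one parameter vector per task; earlier tasks are never overwritten."""

    def __init__(self, num_features):
        self.num_features = num_features
        self.task_modules = {}

    def add_task(self, task_id):
        # New capacity for the new task; existing modules stay untouched.
        self.task_modules[task_id] = [0.0] * self.num_features

    def train_step(self, task_id, gradient, learning_rate=0.1):
        # Only the module of the current task receives the update.
        module = self.task_modules[task_id]
        for i, g in enumerate(gradient):
            module[i] -= learning_rate * g

    def num_parameters(self):
        # Capacity grows linearly with the number of tasks.
        return len(self.task_modules) * self.num_features

learner = ModularContinualLearner(num_features=3)
learner.add_task("task_a")
learner.train_step("task_a", gradient=[1.0, 0.0, -1.0])
snapshot = list(learner.task_modules["task_a"])

learner.add_task("task_b")
learner.train_step("task_b", gradient=[5.0, 5.0, 5.0])
```

Training on the second task leaves the first task's module identical to its snapshot, at the cost of a parameter count that grows with every new task.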
Parameter-efficient transfer learning
- Transfer learning is the dominant strategy for achieving state-of-the-art results
- Models are pre-trained on large amounts of data and then fine-tuned on target tasks
- Parameter-efficient fine-tuning strategies exist
- Modularity can be achieved through parameter, input and function composition
- Hypernetworks can be used to generate parameters of modules
- All methods share the same functional form