Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Data multiplexing is a method to improve a model’s inference efficiency.
- Prior work on data multiplexing only used task-specific Transformers without pre-training.
- This paper develops pre-trained multiplexed language models (MUX-PLMs).
- MUX-PLMs can be fine-tuned on any downstream task.
- MUX-PLMs introduce a three-stage training procedure and novel multiplexing and demultiplexing modules.
- MUX-BERT and MUX-ELECTRA models achieve a 2x/5x inference speedup with a 2-4% drop in performance on GLUE and a 1-2% drop on token-level tasks.
Paper Content
Introduction
- Pre-trained language models (PLMs) are getting larger
- Need for efficient inference techniques to make use of PLMs in high-volume and time-sensitive applications
- Network pruning, knowledge distillation and quantization compress a fully trained model, for example by making it sparser, reducing its parameter count, or lowering its numeric precision
- Data multiplexing can make inputs to a neural network denser and improve inference efficiency
- Data multiplexing adapted to obtain pre-trained language models (MUX-PLMs)
- MUX-PLMs process multiple inputs in parallel with a single forward pass
- MUX-PLMs do not require a priori access to task-specific data: multiplexing is learned during pre-training, and the resulting model is then fine-tuned like a standard PLM
- New demultiplexing module and Contextual Multiplexer introduced to improve performance
- MUX-PLMs evaluated on GLUE benchmark and token classification tasks
- MUX-PLMs achieve close to state-of-the-art scores with multi-fold throughput increase
Related work
Methodology
- Transformer model introduced in DataMUX (Murahari et al., 2022) is denoted T-MUX
- Data multiplexing allows parallel processing of multiple sequences with a single forward or backward pass through the model
- The multiplexer combines an ordered set of N inputs into a single superimposed representation
- A vector $v^i \in \mathbb{R}^d$ is sampled from a standard multivariate Gaussian and applied to the corresponding input representation $x^i$
- The model processes the multiplexed representation and emits a multiplexed hidden state $h^{\text{MUX}}$
- The demultiplexer separates the superimposed output $h^{\text{MUX}}$ into N output representations, one per input
- A vector $p^i \in \mathbb{R}^d$ is generated dynamically for each instance $i$ with the help of a prefix
- MUX-PLMs apply data multiplexing during pre-training for both the BERT and ELECTRA training objectives
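- The model is trained in three stages: priming on a token retrieval task, pre-training with the BERT or ELECTRA objective, and fine-tuning on downstream tasks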
- Contextual multiplexer aggregates context from tokens in the same instance and tokens in the same position of other instances
- Keys ($v^i$ and $k^i$) are randomly initialized and learned to demultiplex the output representation
- Models are compared against T-MUX, ELECTRA and BERT across three model sizes (SMALL, BASE and LARGE)
- Results are reported across 5 random seeds
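The multiplexing and demultiplexing steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes each Gaussian vector $v^i$ is applied by elementwise product and the N results are averaged, it skips the Transformer encoder (so the multiplexed input stands in for $h^{\text{MUX}}$), and it replaces the learned demultiplexer with a single untrained ReLU layer over the concatenation of $h^{\text{MUX}}$ and a per-instance vector (playing the role of $p^i$ or $k^i$). The names `gaussian_vectors`, `multiplex`, `relu_layer` and `demultiplex` are hypothetical.

```python
import random


def gaussian_vectors(n: int, d: int, seed: int = 0) -> list[list[float]]:
    """Draw n vectors of dimension d from a standard multivariate Gaussian."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


def multiplex(xs: list[list[float]], vs: list[list[float]]) -> list[float]:
    """Superimpose N inputs: (1/N) * sum_i (v^i * x^i), with * taken elementwise.

    Assumption: elementwise product followed by averaging.
    """
    n, d = len(xs), len(xs[0])
    return [sum(vs[i][k] * xs[i][k] for i in range(n)) / n for k in range(d)]


def relu_layer(w: list[list[float]], x: list[float]) -> list[float]:
    """Hypothetical stand-in for the demultiplexer MLP: one untrained ReLU layer."""
    return [max(0.0, sum(wij * xj for wij, xj in zip(row, x))) for row in w]


def demultiplex(h_mux: list[float], keys: list[list[float]],
                w: list[list[float]]) -> list[list[float]]:
    """Produce one output per instance from the concatenation [h_MUX ; key_i].

    key_i plays the role of the prefix vector p^i or a learned key k^i.
    """
    return [relu_layer(w, h_mux + key) for key in keys]


N, d = 2, 4
rng = random.Random(1)
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]  # toy inputs x^i
vs = gaussian_vectors(N, d, seed=0)      # multiplexing vectors v^i
h_mux = multiplex(xs, vs)                # the Transformer encoder is skipped here
keys = gaussian_vectors(N, d, seed=2)    # hypothetical per-instance keys
w = [[rng.gauss(0.0, 0.1) for _ in range(2 * d)] for _ in range(d)]
outs = demultiplex(h_mux, keys, w)
print(len(outs), len(outs[0]))  # 2 4
```

Because the demultiplexing layer here is untrained, its outputs do not recover the original inputs; in the models described above the keys and demultiplexer are learned so that each output corresponds to its own instance.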
Results
- MUX-BERT and MUX-ELECTRA outperform T-MUX on all levels of multiplexing
- MUX-PLMs have higher throughput than T-MUX
- MUX-PLMs provide a significant boost in throughput compared to PLMs without a significant loss in performance
- As the multiplexing level N increases, MUX-PLMs' throughput improves but performance can degrade
- MUX-PLMs achieve competitive performance and throughput improvement without additional data or task-specific data
Effect of varying model size
- MUX-BERT performs close to BERT while achieving significantly better throughput (2x).
- Performance drops with MUX-BERT are 1.6 and 1.7 points on GLUE for SMALL and LARGE respectively.
- MUX-PLMs offer a performance-computational efficiency trade-off, with larger N leading to better throughput but lower performance.
- All multiplexed models lie either on or very close to the Pareto frontier.
- Ensembling improves performance for all models, with larger gains at larger N.
- Non-ensemble variants are faster but perform slightly worse, while the ensemble variant performs better but is slower.
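The statement that multiplexed models lie on or near the Pareto frontier can be made concrete with a small helper that keeps every model not dominated on (throughput, score). This is a sketch: the function name and the toy points are hypothetical and are not numbers from the paper.

```python
def pareto_frontier(points: list[tuple[str, float, float]]) -> list[str]:
    """Return names of points that no other point dominates.

    Each point is (name, throughput, score); higher is better on both axes.
    A point is dominated if another point is at least as good on both axes
    and strictly better on at least one.
    """
    frontier = []
    for name, thr, score in points:
        dominated = any(
            t2 >= thr and s2 >= score and (t2 > thr or s2 > score)
            for _, t2, s2 in points
        )
        if not dominated:
            frontier.append(name)
    return frontier


# Toy values for illustration only: (name, relative throughput, average score).
toy = [("A", 1.0, 80.0), ("B", 2.0, 78.0), ("C", 1.5, 77.0), ("D", 5.0, 70.0)]
print(pareto_frontier(toy))  # ['A', 'B', 'D']
```

This exact version returns only strictly non-dominated points; judging which models are "very close to" the frontier would need an added tolerance.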
Ablation study
- The ablation compares two MUX-PLM variants, one using prefix demultiplexing and one using contextual multiplexing
- Variant 1, using prefix demultiplexing, performs worse than MUX-BERT, except for N = 2
- Variant 2, using contextual multiplexing, performs better than the non-contextual multiplexer on token-level tasks but worse on GLUE tasks
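A toy version of the contextual multiplexer can show the two directions in which context is gathered before the inputs are superimposed. This is an assumption-laden sketch rather than the actual module: mean pooling with a residual add is swapped in for whatever learned aggregation the real multiplexer uses, the final step assumes the same elementwise-product-and-average superposition as T-MUX, and the names `mean`, `add` and `contextual_multiplex` are hypothetical.

```python
def mean(vectors: list[list[float]]) -> list[float]:
    """Elementwise mean of a non-empty list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def add(a: list[float], b: list[float]) -> list[float]:
    """Elementwise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]


def contextual_multiplex(xs: list[list[list[float]]],
                         vs: list[list[float]]) -> list[list[float]]:
    """Toy contextual multiplexer.

    xs[i][t] is the d-dim embedding of token t in instance i (N instances, L tokens).
    vs[i] is the multiplexing vector for instance position i.
    Mean pooling stands in for the learned context aggregation.
    Returns L multiplexed token vectors.
    """
    n, seq_len, d = len(xs), len(xs[0]), len(vs[0])
    # 1) Context from tokens in the same instance.
    within = [[add(tok, mean(inst)) for tok in inst] for inst in xs]
    # 2) Context from tokens at the same position in the other instances.
    across = []
    for i in range(n):
        row = []
        for t in range(seq_len):
            others = [within[j][t] for j in range(n) if j != i]
            row.append(add(within[i][t], mean(others)) if others else within[i][t])
        across.append(row)
    # 3) Position-wise superposition: (1/N) * sum_i v^i * x^i, elementwise.
    return [[sum(vs[i][k] * across[i][t][k] for i in range(n)) / n for k in range(d)]
            for t in range(seq_len)]


xs = [[[1.0], [3.0]], [[5.0], [7.0]]]  # N=2 instances, L=2 tokens, d=1
vs = [[1.0], [1.0]]
print(contextual_multiplex(xs, vs))  # [[14.0], [18.0]]
```

Unlike the non-contextual multiplexer, each token vector here already mixes in information from its own instance and from the other instances at the same position before superposition.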
Muxology: analyzing hidden representations of multiplexed models
Effect of data sampling strategies during inference
- MUX-PLMs sample N instances uniformly at random from the evaluation set.
- Other data-sampling strategies such as clustering similar instances based on word-overlap could improve performance.
- Sensitivity to data sampling is measured as the performance difference between the best- and worst-performing ticket, where a ticket is one random grouping of evaluation instances.
- An improved data-sampling strategy could therefore lead to performance gains.
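Uniform random grouping of evaluation instances, and the best-minus-worst spread across tickets, can be sketched as follows. The sketch assumes a ticket is one random grouping of the evaluation set identified by a seed; scoring each ticket with an actual model is left out, and `make_ticket` and `ticket_spread` are hypothetical names.

```python
import random


def make_ticket(num_examples: int, n: int, seed: int) -> list[list[int]]:
    """One ticket: a random grouping of evaluation indices into groups of n.

    Each group would be multiplexed into a single forward pass. Assumes
    num_examples is divisible by n (pad the evaluation set otherwise).
    """
    if num_examples % n != 0:
        raise ValueError("num_examples must be divisible by n")
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # uniform random order, fixed per seed
    return [indices[k:k + n] for k in range(0, num_examples, n)]


def ticket_spread(scores_by_ticket: dict[int, float]) -> float:
    """Best-minus-worst score across tickets (sensitivity to data sampling)."""
    scores = scores_by_ticket.values()
    return max(scores) - min(scores)


tickets = {seed: make_ticket(num_examples=8, n=2, seed=seed) for seed in range(3)}
print(tickets[0])
```

A clustering-based strategy such as the word-overlap grouping mentioned above would replace the shuffle in `make_ticket` with a rule that places similar instances in the same group.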
Conclusion
- Proposed methods to adapt data multiplexing for pretraining language models
- Demonstrated method on BERT and ELECTRA pre-training schemes
- MUX-BERT and MUX-ELECTRA models achieved close to state-of-the-art performance on several downstream sequence classification and token classification tasks
- Multi-fold inference throughput speedup
- The new demultiplexing module is faster and performs better
- New contextual multiplexing module performs better on token classification tasks
- Models are primed with a token retrieval task
- Pre-trained on the Wikipedia and BooksCorpus datasets
- Hyper-parameter search over two learning rates
- Fine-tuning experiments run on a single V100 GPU
- Throughput measured on a single V100 GPU
- Compared MUX-BERT with state-of-the-art pruning and distillation methods
- MUX-BERT achieved competitive performance and throughput without using any additional data
- Activation norms tend to spike for MUX-BERT in the last layer
- Entropy for MUX-BERT is lower than for BERT in higher layers
- GLUE results for MUX-BERT and MUX-ELECTRA reported
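Throughput in this summary means instances processed per second, and the reason multiplexing raises it is that each forward pass carries N instances per input. A simple measurement helper is sketched below. It is not the authors' benchmarking script: `measure_throughput` is a hypothetical helper, `forward` stands for any model call, and a speedup would be the multiplexed model's throughput divided by the baseline's on the same data and hardware.

```python
import time
from typing import Callable, Sequence


def measure_throughput(forward: Callable[[Sequence], object],
                       instances: Sequence,
                       batch_size: int,
                       n_mux: int = 1,
                       timer: Callable[[], float] = time.perf_counter) -> float:
    """Return instances processed per second.

    With multiplexing, each of the batch_size inputs in a forward pass carries
    n_mux instances, so one call consumes batch_size * n_mux instances.
    """
    per_call = batch_size * n_mux
    start = timer()
    for k in range(0, len(instances), per_call):
        forward(instances[k:k + per_call])
    elapsed = timer() - start
    return len(instances) / elapsed
```

On a GPU, `forward` should block until results are ready (for example, by synchronizing the device) so that the timer measures computation rather than asynchronous kernel launches. The optional `timer` argument lets the helper be tested with a fake clock.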