Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Data multiplexing is a method to improve a model’s inference efficiency.
  • Prior work on data multiplexing only used task-specific Transformers without pre-training.
  • This paper develops pre-trained multiplexed language models (MUX-PLMs).
  • MUX-PLMs can be fine-tuned on any downstream task.
  • MUX-PLMs include a three-stage training procedure and novel multiplexing and demultiplexing modules.
  • MUX-BERT and MUX-ELECTRA models achieve 2x/5x inference speedup with a 2-4% drop in performance on GLUE and 1-2% drop on token-level tasks.

Paper Content

Introduction

  • Pre-trained language models (PLMs) are getting larger
  • Need for efficient inference techniques to make use of PLMs in high-volume and time-sensitive applications
  • Network pruning, knowledge distillation and quantization aim to make a fully trained model sparser with fewer parameters
  • Data multiplexing can make inputs to a neural network denser and improve inference efficiency
  • Data multiplexing adapted to obtain pre-trained language models (MUX-PLMs)
  • MUX-PLMs process multiple inputs in parallel with a single forward pass
  • MUX-PLMs do not require fine-tuning or a priori access to task-specific data
  • New demultiplexing module and Contextual Multiplexer introduced to improve performance
  • MUX-PLMs evaluated on GLUE benchmark and token classification tasks
  • MUX-PLMs achieve close to state-of-the-art scores with multi-fold throughput increase

Methodology

  • Transformer model introduced in DataMUX (Murahari et al., 2022) is denoted T-MUX
  • Data multiplexing allows parallel processing of multiple sequences with a single forward or backward pass through the model
  • Multiplexer combines an ordered set of multiple inputs
  • A vector $v_i \in \mathbb{R}^d$ is sampled from a standard multivariate Gaussian and applied to the corresponding input representation $x_i$
  • Model processes the multiplexed representation and emits a multiplexed hidden state $h_{\text{MUX}}$
  • Demultiplexer separates the superimposed output $h_{\text{MUX}}$ into N output representations corresponding to the inputs
  • A vector $p_i \in \mathbb{R}^d$ is dynamically generated for each instance $i$ with the help of a prefix
  • MUX-PLMs apply data multiplexing during pre-training for both the BERT and ELECTRA training objectives
  • Model is trained in three stages
  • Contextual multiplexer aggregates context from tokens in the same instance and tokens in the same position of other instances
  • Keys ($v_i$ and $k_i$) are randomly initialized and learned to demultiplex the output representation
  • Models are compared against T-MUX, ELECTRA and BERT across three different model configurations
  • Results are reported across 5 random seeds
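
The multiplexing and demultiplexing steps listed above can be sketched in a few lines of NumPy. This is a rough illustration under stated assumptions, not the authors' implementation: the element-wise product followed by averaging, the small MLP applied to the concatenation of the multiplexed hidden state and a key, the identity stand-in for the Transformer, and every name and size below are hypothetical choices made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, d = 2, 4, 8  # instances per group, sequence length, hidden size (illustrative)

# Multiplexer: one Gaussian vector v_i per instance slot.
v = rng.standard_normal((N, d))

def multiplex(x):
    """x: (N, L, d) token representations -> (L, d) multiplexed representation."""
    # Assumed combination rule: element-wise product with v_i, then average over instances.
    return (x * v[:, None, :]).mean(axis=0)

# Demultiplexer: keys k_i plus a shared MLP.
# Random weights stand in for parameters that would be learned in practice.
k = rng.standard_normal((N, d))
W1 = rng.standard_normal((2 * d, d)) / np.sqrt(2 * d)
W2 = rng.standard_normal((d, d)) / np.sqrt(d)

def demultiplex(h_mux):
    """h_mux: (L, d) -> (N, L, d), one output representation per input instance."""
    outs = []
    for i in range(N):
        key = np.broadcast_to(k[i], h_mux.shape)    # repeat k_i at every position: (L, d)
        z = np.concatenate([h_mux, key], axis=-1)   # (L, 2d)
        outs.append(np.maximum(z @ W1, 0.0) @ W2)   # two-layer ReLU MLP
    return np.stack(outs)

x = rng.standard_normal((N, L, d))
x_mux = multiplex(x)
h_mux = x_mux  # placeholder for the Transformer forward pass
h = demultiplex(h_mux)
```

The point of the sketch is the shapes: N sequences go in, a single sequence-sized tensor passes through the model, and N representations come back out, which is where the throughput gain comes from.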

Results

  • MUX-BERT and MUX-ELECTRA outperform T-MUX on all levels of multiplexing
  • MUX-PLMs have faster throughput than T-MUX
  • MUX-PLMs provide a significant boost in throughput compared to PLMs without a significant loss in performance
  • As multiplexing level increases, MUX-PLMs’ throughput is better but performance can degrade
  • MUX-PLMs achieve competitive performance and throughput improvement without additional data or task-specific data

Effect of varying model size

  • MUX-BERT has close performance to BERT while having significantly better throughput (2x).
  • Performance drops with MUX-BERT are 1.6 and 1.7 points on GLUE for SMALL and LARGE respectively.
  • MUX-PLMs offer a performance-computational efficiency trade-off, with larger N leading to better throughput but lower performance.
  • All multiplexed models lie either on or very close to the Pareto frontier.
  • Ensembling improves performance for all models, with gains increasing with increasing N.
  • Non-ensemble variants are faster but perform slightly worse, while the ensemble variant performs better but is slower.
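
To make the Pareto-frontier claim above concrete, the following sketch checks which models are not dominated on both throughput and score. The model names and numbers are hypothetical placeholders, not results from the paper.

```python
def pareto_frontier(points):
    """Return names of models not dominated in (throughput, score); higher is better for both."""
    frontier = []
    for name, thr, score in points:
        # A model is dominated if another is at least as good on both axes
        # and strictly better on at least one.
        dominated = any(
            (t2 >= thr and s2 >= score) and (t2 > thr or s2 > score)
            for _, t2, s2 in points
        )
        if not dominated:
            frontier.append(name)
    return frontier

# Hypothetical (throughput multiplier, score) values, not taken from the paper.
models = [
    ("baseline", 1.0, 85.0),
    ("mux_n2", 2.0, 83.0),
    ("mux_n5", 4.5, 80.0),
    ("slow_and_worse", 1.5, 79.0),
]
frontier = pareto_frontier(models)
```

With these made-up numbers, the first three models are on the frontier because each trades score for throughput, while the last one is beaten on both axes by `mux_n2`.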

Ablation study

  • MUX-PLMs have two variants, one using prefix demultiplexing and one using contextual multiplexing
  • Variant 1, using prefix demultiplexing, performs worse than MUX-BERT, except for N = 2
  • Variant 2, using contextual multiplexing, performs better than non-contextual for TOKEN tasks but worse for GLUE tasks

Muxology: analyzing hidden representations of multiplexed models

Effect of data sampling strategies during inference

  • MUX-PLMs sample N instances uniformly at random from the evaluation set.
  • Other data-sampling strategies such as clustering similar instances based on word-overlap could improve performance.
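  • The gap between the best- and worst-performing ticket (random seed, which fixes how instances are grouped) is used to gauge how much performance depends on the composition of the multiplexed instances.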
  • Improved data sampling strategy could lead to improvements.
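
The two grouping strategies mentioned above can be contrasted with a small sketch. Uniform random grouping follows the description in the text. The greedy word-overlap grouping is only one hypothetical way to cluster similar instances, using Jaccard similarity over lowercased words as an assumed similarity measure. Whether such a grouping actually helps is left open by the summary.

```python
import random

def random_groups(instances, n, seed=0):
    """Shuffle uniformly at random and cut into groups of n."""
    order = list(instances)
    random.Random(seed).shuffle(order)
    return [order[i:i + n] for i in range(0, len(order), n)]

def jaccard(a, b):
    """Word-overlap similarity between two sentences."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    return len(wa & wb) / len(wa | wb) if wa | wb else 0.0

def overlap_groups(instances, n):
    """Greedy grouping: take the first unused instance, then add its n-1 most similar."""
    remaining = list(instances)
    groups = []
    while remaining:
        anchor = remaining.pop(0)
        remaining.sort(key=lambda s: jaccard(anchor, s), reverse=True)
        groups.append([anchor] + remaining[:n - 1])
        remaining = remaining[n - 1:]
    return groups

# Toy sentences invented for the example.
sentences = [
    "the movie was great",
    "stock prices fell sharply",
    "the movie was terrible",
    "stock prices rose sharply",
]
grouped = overlap_groups(sentences, n=2)
shuffled = random_groups(sentences, n=2, seed=0)
```

Here `overlap_groups` pairs the two movie sentences and the two stock sentences, whereas `random_groups` pairs them according to the seed, which is the kind of seed-dependent composition the best-versus-worst comparison above is probing.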

Conclusion

  • Proposed methods to adapt data multiplexing for pretraining language models
  • Demonstrated method on BERT and ELECTRA pre-training schemes
  • MUX-BERT and MUX-ELECTRA models achieved close to state-of-the-art performance on several downstream sequence classification and token classification tasks
  • Models deliver a multi-fold inference throughput speedup
  • New demultiplexing module is faster and performs better
  • New contextual multiplexing module performs better on token classification tasks
  • Models are primed with a token retrieval task
  • Models are trained on the Wikipedia and BooksCorpus datasets
  • Hyper-parameter search over two learning rates
  • Fine-tuning experiments trained on 1 V100 GPU
  • Calculated throughput on single V100 GPU
  • Compared MUX-BERT with state-of-the-art pruning and distillation methods
  • MUX-BERT achieved competitive performance and throughput without using any additional data
  • Activation norms tend to spike for MUX-BERT in the last layer
  • Entropy of MUX-BERT is lower than BERT for higher layers
  • GLUE results for MUX-BERT and MUX-ELECTRA reported