Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Foundation models are changing how AI systems are built
  • Practitioners follow a standard procedure: download a pre-trained foundation model, then fine-tune its weights on the target task
  • The Internet is full of foundation models fine-tuned on many diverse tasks
  • These individual fine-tunings lack strong generalization and exist in isolation
  • Model recycling leverages multiple fine-tunings of the same foundation model on diverse tasks
  • Model recycling maximizes model diversity and achieves a new state of the art on the DomainBed benchmark

Paper Content

Introduction

  • The framework of foundation models is fueling the adoption of machine learning for real-world applications
  • Pre-trained models are easy to adapt to downstream tasks
  • Two-step transfer learning strategy is followed
  • Practitioner downloads a copy of foundation model from authorities
  • Practitioner fine-tunes weights on target task with limited in-house data
  • Fine-tuned model risks latching onto specific patterns from the limited training data
  • Such short-sighted models fail to generalize to out-of-distribution (OOD) test examples
  • These failures can negatively impact human lives
  • OOD accuracy is measured on the DomainBed benchmark
  • Different training strategies are discussed in the paper
  • Model recycling proposed
  • Compute parallelism throughout training
  • Maximizes diversity in predictions
  • State-of-the-art performance in DomainBed
  • No inference or training overhead
  • Increased OOD generalization enables responsible use of machine learning
  • Averaging neural networks’ weights inspires modern fine-tuning approaches

Fine-tuning for out-of-distribution generalization

  • Learning setup involves deep learning model with two parts: featurizer and classifier
  • Model is parametrized by weights θ = (w, φ)
  • Aim is to maximize test accuracy acc_te(θ) for out-of-distribution (OOD) generalization

Vanilla fine-tuning

  • Fine-tuning is a simple recipe for transferring knowledge from pre-trained models to target tasks.
  • Initializing the classifier by linear probing (training it on the target task with the featurizer frozen) improves results.
  • Fine-tuning is the standard practice on datasets that exhibit distribution shift.
  • Successful fine-tuning depends on the scale and diversity of the pre-training corpus.
  • Public repositories of high-quality pre-trained models are available.

Weight averaging over epochs

  • Vanilla fine-tuning was the main strategy to train robust models until weight averaging techniques were introduced.
  • Weight averaging involves saving checkpoints during fine-tuning and building a final model with the average of those checkpoints.
  • Linear mode connectivity (LMC) states that the accuracy of linearly interpolated weights is at least roughly as high as the same interpolation of the endpoint accuracies.
  • Weight averaging strategies improved performance in OOD classification.
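
The checkpoint-averaging step described above can be sketched in a few lines of Python. This is a minimal illustration, assuming each checkpoint is a plain dictionary mapping a parameter name to a flat list of floats; the `Weights` alias, the `average_weights` helper and the toy values are hypothetical and do not come from the paper.

```python
from typing import Dict, List

# Hypothetical flat representation: parameter name -> list of values.
Weights = Dict[str, List[float]]


def average_weights(checkpoints: List[Weights]) -> Weights:
    """Uniformly average checkpoints that share the same parameter names and shapes."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    averaged: Weights = {}
    for name in checkpoints[0]:
        # Group the i-th value of this parameter across all checkpoints.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(col) / n for col in columns]
    return averaged


# Three toy checkpoints saved along one fine-tuning run (made-up numbers).
checkpoints = [
    {"w": [1.0, 2.0], "phi": [0.0]},
    {"w": [3.0, 4.0], "phi": [1.0]},
    {"w": [5.0, 6.0], "phi": [2.0]},
]
final = average_weights(checkpoints)
```

The averaged model has the same size as a single checkpoint, so it costs no more at inference time than any one of the models it was built from.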

Weight averaging over runs

  • Weight averaging techniques suggest that pre-trained weights guide optimization to a flat basin of the loss landscape.
  • Interpolating two random solutions from the same basin can produce solutions with better generalization performance.
  • LMC holds between two instances of models trained from pre-trained weights.
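
To make the LMC condition concrete, the sketch below checks it on a toy problem. It assumes a two-dimensional linear classifier and a hand-made dataset rather than a neural network, so it only illustrates the inequality acc(λθ_a + (1 − λ)θ_b) ≥ λ·acc(θ_a) + (1 − λ)·acc(θ_b); the names `DATA`, `accuracy`, `interpolate` and `lmc_holds` are hypothetical.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]

# Toy dataset (an assumption for illustration): the label is 1 when x0 + x1 > 0.
DATA: List[Tuple[Point, int]] = [
    ((2.0, 1.0), 1), ((1.0, 2.0), 1), ((-2.0, -1.0), 0), ((-1.0, -2.0), 0),
    ((2.0, -1.0), 1), ((-1.0, 2.0), 1), ((-2.0, 1.0), 0), ((1.0, -2.0), 0),
]


def accuracy(theta: Sequence[float]) -> float:
    """Accuracy of the linear classifier that predicts 1 when theta . x > 0."""
    hits = 0
    for (x0, x1), label in DATA:
        pred = 1 if theta[0] * x0 + theta[1] * x1 > 0 else 0
        hits += int(pred == label)
    return hits / len(DATA)


def interpolate(theta_a: Sequence[float], theta_b: Sequence[float], lam: float) -> List[float]:
    """Linear interpolation lam * theta_a + (1 - lam) * theta_b."""
    return [lam * a + (1 - lam) * b for a, b in zip(theta_a, theta_b)]


def lmc_holds(theta_a: Sequence[float], theta_b: Sequence[float],
              lams: Sequence[float] = (0.0, 0.25, 0.5, 0.75, 1.0),
              tol: float = 1e-9) -> bool:
    """Check that interpolated weights are at least as accurate as interpolated accuracies."""
    acc_a, acc_b = accuracy(theta_a), accuracy(theta_b)
    return all(
        accuracy(interpolate(theta_a, theta_b, lam)) + tol >= lam * acc_a + (1 - lam) * acc_b
        for lam in lams
    )


theta_a = [1.0, 0.0]  # relies only on the first feature
theta_b = [0.0, 1.0]  # relies only on the second feature
midpoint_acc = accuracy(interpolate(theta_a, theta_b, 0.5))
```

In this toy case each endpoint misclassifies two different points, while their midpoint classifies every point correctly. For deep networks, LMC is an empirical observation that has to be measured on held-out data in the same way.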

Weight averaging over tasks

  • Pre-trained models can be fine-tuned on external datasets to learn richer features
  • Inter-training fine-tunes the pre-trained model on an auxiliary task before tackling the target task, but can lead to catastrophic forgetting
  • Fusing strategies accommodate multiple auxiliary tasks, but only provide marginal gains
  • Model recycling proposes one target fine-tuning per set of auxiliary weights, and averages weights only as the very last step

Model recycling

  • Model recycling is a proposal to leverage diverse auxiliary fine-tunings of the same pre-trained model
  • Model recycling is a 5-step and parallelizable procedure
  • Model recycling does not rely on intrinsically good auxiliary tasks
  • Model recycling connects some of the fine-tuning strategies from Section 2 while overcoming their limitations
  • Model recycling requires Hypotheses 1 and 2 to hold for successful generalization
  • Model recycling is expected to be a high performer in out-of-distribution generalization benchmarks
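
The overall flow of recycling can be sketched as follows. This is a deliberately simplified illustration, not the authors' implementation: it compresses the procedure into three stages (auxiliary fine-tunings, one target fine-tuning per set of auxiliary weights, a final average) and omits every detail not stated in this summary. The `fine_tune` function is a hypothetical stand-in that moves weights part of the way toward a made-up task optimum; all names and numbers are assumptions.

```python
from typing import List, Sequence

Vector = List[float]


def fine_tune(init: Vector, task_optimum: Vector, rate: float = 0.5) -> Vector:
    """Toy stand-in for fine-tuning: move weights part of the way toward a task optimum."""
    return [w + rate * (o - w) for w, o in zip(init, task_optimum)]


def average(weights: Sequence[Vector]) -> Vector:
    """Uniform average of several weight vectors of equal length."""
    return [sum(col) / len(weights) for col in zip(*weights)]


def recycle(pretrained: Vector, auxiliary_tasks: Sequence[Vector], target_task: Vector) -> Vector:
    # Stage 1: one auxiliary fine-tuning per auxiliary task, all from the same pre-trained weights.
    auxiliary_weights = [fine_tune(pretrained, task) for task in auxiliary_tasks]
    # Stage 2: one target fine-tuning per set of auxiliary weights (independent, so parallelizable).
    target_weights = [fine_tune(aux, target_task) for aux in auxiliary_weights]
    # Stage 3: weights are averaged only as the very last step.
    return average(target_weights)


pretrained = [0.0, 0.0]
auxiliary_tasks = [[4.0, 0.0], [0.0, 4.0]]  # made-up optima of two auxiliary tasks
target_task = [2.0, 2.0]                    # made-up optimum of the target task
recycled = recycle(pretrained, auxiliary_tasks, target_task)
```

Because the target fine-tunings never communicate before the final average, they can run on separate machines, and the result is a single set of weights.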

Experiments

  • Implemented numerical experiments to support 3 claims
  • Showcased SoTA performance of model recycling in DomainBed
  • Illustrated how performance gains arise from increased diversity across averaged models
  • Provided empirical support for Hypotheses 1 and 2
  • Discussed application of recycling for ID tasks
  • Code released on GitHub

Recycling achieves state-of-the-art OOD performance

  • Model recycling is evaluated on computer vision tasks, leveraging fine-tunings on multiple datasets for OOD generalization.
  • Experiments were conducted on five computer vision datasets: PACS, VLCS, OfficeHome, TerraIncognita and DomainNet.
  • Model Soups/DiWA and ERM (Empirical Risk Minimization) serve as baseline approaches in the experiments.
  • Model recycling, inter-training, fusing and ensembling strategies were also evaluated.
  • Results showed that model recycling achieved a new SoTA on DomainBed.
  • Model recycling was found to be robust to the choice of auxiliary tasks.

Recycling increases model diversity

  • Investigated how diversity across models fine-tuned on target task influences OOD performance
  • Measured diversity with prediction q-diversity
  • Target task is OfficeHome, with “Art” as test OOD domain
  • Trained on “ClipArt”, “Product” and “Photo” domains
  • Inter-training influences final models
  • Diversity gain comes from initialization and remains along fine-tuning
  • Diversity is positively and linearly correlated with OOD generalization
  • Best performance usually obtained around µ ≈ 0.5
  • Auxiliary task fosters learning of diverse features
  • Recycling increases diversity and improves performance
  • Hypotheses 1 and 2 validated
  • Two exceptions to Hypothesis 2 due to lack of similarity between auxiliary/target task and pre-training task
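
This summary does not define the prediction q-diversity. As a rough illustration of how prediction diversity between two models can be quantified, the sketch below computes the classical Q-statistic of Kuncheva and Whitaker from per-example correctness; whether the paper's measure is exactly this quantity or a transformation of it is an assumption. The function name and the toy correctness vectors are hypothetical.

```python
from typing import Sequence


def q_statistic(correct_a: Sequence[bool], correct_b: Sequence[bool]) -> float:
    """Q-statistic between two classifiers, computed from per-example correctness.

    Values near 1 mean the models succeed and fail on the same examples;
    lower values mean their errors fall on different examples.
    """
    if len(correct_a) != len(correct_b):
        raise ValueError("both models must be evaluated on the same examples")
    pairs = list(zip(correct_a, correct_b))
    n11 = sum(1 for a, b in pairs if a and b)          # both correct
    n00 = sum(1 for a, b in pairs if not a and not b)  # both wrong
    n10 = sum(1 for a, b in pairs if a and not b)      # only model A correct
    n01 = sum(1 for a, b in pairs if not a and b)      # only model B correct
    denominator = n11 * n00 + n01 * n10
    if denominator == 0:
        raise ValueError("Q-statistic is undefined for these predictions")
    return (n11 * n00 - n01 * n10) / denominator


# Made-up correctness of two models on six examples.
model_a = [True, True, True, False, False, True]
model_b = [True, True, False, True, False, True]
q_ab = q_statistic(model_a, model_b)
q_aa = q_statistic(model_a, model_a)
```

Here the two models disagree on two of the six examples, which pulls the statistic below the value obtained when a model is compared with itself.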

Discussion: towards updatable machine learning

  • Model recycling is a paradigm for developing machine learning systems that can be incrementally improved and recombined
  • Networks are considered as pieces of software, allowing for open-source development
  • Weights are communicated and averaged only at the end of the learning process
  • Advanced merging operations consider weighted interpolations or neuron permutations
  • Collaborative repositories of fine-tuned models can be built in a decentralized way
  • Practices such as unit tests can be used to specify and test neural networks
  • Datasets can be used as test certificates
  • LMC usually holds in ID, but with smaller gains than in OOD
  • ID and OOD accuracies are not correlated
  • Model recycling sets a new SoTA by leveraging auxiliary tasks’ diversity
  • If two initializations satisfy the LMC, then so do the two weights fine-tuned from them