Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Foundation models are changing how AI systems are built
- Practitioners use a standard procedure to build machine learning solutions
- The Internet is full of foundation models fine-tuned on many tasks
- These individual fine-tunings lack strong generalization and exist in isolation
- Model recycling leverages multiple fine-tunings of the same foundation model on diverse tasks
- Model recycling maximizes model diversity and achieves a new state of the art on the DomainBed benchmark
Paper Content
Introduction
- The foundation-model framework is fueling the adoption of machine learning for real-world applications
- Pre-trained models are easy to adapt to downstream tasks
- A two-step transfer learning strategy is typically followed
- First, the practitioner downloads a copy of a foundation model released by a central authority
- Second, the practitioner fine-tunes the weights on the target task with limited in-house data
- Fine-tuning risks latching onto patterns that are specific to the training data
- Such short-sighted models fail to generalize to out-of-distribution (OOD) test examples
- These failures can have a negative impact on human lives
- DomainBed OOD accuracy benchmark
- Different training strategies discussed in paper
- Model recycling proposed
- Compute parallelism throughout training
- Maximizes diversity in predictions
- State-of-the-art performance in DomainBed
- No inference or training overhead
- Increased OOD generalization enables responsible use of machine learning
- Averaging neural networks’ weights inspires modern fine-tuning approaches
Fine-tuning for out-of-distribution generalization
- The learning setup involves a deep learning model with two parts: a featurizer and a classifier
- The model is parametrized by weights θ = (w, φ), with w the classifier and φ the featurizer
- The aim is to maximize the test accuracy acc_te(θ) for out-of-distribution (OOD) generalization
Vanilla fine-tuning
- Fine-tuning is a simple recipe for transferring knowledge from pre-trained models to target tasks.
- Initializing the classifier by linear probing (training it on the target task with the featurizer frozen) improves results.
- Fine-tuning is the standard practice for datasets that exhibit distribution shift.
- Successful fine-tuning depends on the scale and diversity of the pre-training corpus.
- Public repositories of high-quality pre-trained models are available.
Weight averaging over epochs
- Vanilla fine-tuning was the main strategy to train robust models until weight averaging techniques were introduced.
- Weight averaging involves saving checkpoints during fine-tuning and building a final model with the average of those checkpoints.
- Linear mode connectivity (LMC) holds when the accuracy of the interpolated weights is at least as high as the interpolation of the individual accuracies.
- Weight averaging strategies improved performance in OOD classification.
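The checkpoint-averaging recipe described above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the `Weights` type (a plain dictionary of float lists standing in for a model state dict), the `average_weights` helper and the toy checkpoint values are all assumptions introduced here.

```python
from typing import Dict, List

# Hypothetical stand-in for a model state dict: parameter name -> flat values.
Weights = Dict[str, List[float]]


def average_weights(checkpoints: List[Weights]) -> Weights:
    """Return the element-wise uniform average of several checkpoints."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    averaged: Weights = {}
    for name in checkpoints[0]:
        # Group the i-th value of this parameter across all checkpoints.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(col) / len(checkpoints) for col in columns]
    return averaged


# Three toy checkpoints saved at different points of one fine-tuning run.
checkpoints = [
    {"w": [1.0, 2.0], "phi": [0.0]},
    {"w": [2.0, 4.0], "phi": [3.0]},
    {"w": [3.0, 6.0], "phi": [6.0]},
]
final_model = average_weights(checkpoints)
```

With a real framework, the same idea would apply tensor-wise to the saved state dicts; the uniform average is only one possible choice of coefficients.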
Weight averaging over runs
- Weight averaging techniques suggest that pre-trained weights guide optimization to a flat basin of the loss landscape.
- Interpolating two random solutions from the same basin can produce solutions with better generalization performance.
- LMC holds between two instances of models trained from pre-trained weights.
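The LMC property referred to above can be written as an inequality. As a sketch in the notation of this summary, where θ_a and θ_b denote two fine-tuned weights and λ an interpolation coefficient (these three symbols are introduced here for illustration):

```latex
\mathrm{acc}_{te}\big((1-\lambda)\,\theta_a + \lambda\,\theta_b\big)
\;\gtrsim\;
(1-\lambda)\,\mathrm{acc}_{te}(\theta_a) + \lambda\,\mathrm{acc}_{te}(\theta_b),
\qquad \lambda \in [0, 1]
```

The left-hand side is the accuracy of the interpolated weights; the right-hand side is the interpolation of the two individual accuracies.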
Weight averaging over tasks
- Pre-trained models can be fine-tuned on external datasets to learn richer features
- Inter-training fine-tunes a pre-trained model on an auxiliary task before tackling the target task, but can lead to catastrophic forgetting
- Fusing strategies accommodate multiple auxiliary tasks, but only provide marginal gains
- Model recycling runs one target fine-tuning per set of auxiliary weights, and averages the resulting weights only as the very last step
Model recycling
- Model recycling is a proposal to leverage diverse auxiliary fine-tunings of the same pre-trained model
- Model recycling is a parallelizable 5-step procedure
- Model recycling does not rely on intrinsically good auxiliary tasks
- Model recycling connects some of the fine-tuning strategies from Section 2 while overcoming their limitations
- Model recycling requires Hypotheses 1 and 2 to hold for successful generalization
- Model recycling is expected to be a high performer in out-of-distribution generalization benchmarks
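The overall flow (auxiliary fine-tunings from shared pre-trained weights, one target fine-tuning per auxiliary initialization, averaging at the end) can be illustrated with a toy sketch. Everything below is an assumption made for illustration: weights are plain lists, `fine_tune` is a hypothetical helper that runs gradient descent on a quadratic loss, each "task" is represented only by the optimum of that loss, and the classifier handling (linear probing) is left out.

```python
from typing import List

Vector = List[float]


def fine_tune(init: Vector, task_optimum: Vector,
              steps: int = 5, lr: float = 0.2) -> Vector:
    """Toy fine-tuning: gradient descent on 0.5 * ||theta - task_optimum||^2."""
    theta = list(init)
    for _ in range(steps):
        theta = [t - lr * (t - o) for t, o in zip(theta, task_optimum)]
    return theta


def average(models: List[Vector]) -> Vector:
    """Uniform element-wise average of several weight vectors."""
    return [sum(col) / len(models) for col in zip(*models)]


def recycle(pretrained: Vector, auxiliary_tasks: List[Vector],
            target_task: Vector) -> Vector:
    # One auxiliary fine-tuning per task, all from the same pre-trained weights.
    auxiliary_weights = [fine_tune(pretrained, aux) for aux in auxiliary_tasks]
    # One target fine-tuning per auxiliary initialization; the runs are
    # independent, so they could be executed in parallel.
    target_weights = [fine_tune(init, target_task) for init in auxiliary_weights]
    # Weights are averaged only as the very last step.
    return average(target_weights)


pretrained = [0.0, 0.0]
auxiliary_tasks = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
target_task = [2.0, 2.0]
recycled = recycle(pretrained, auxiliary_tasks, target_task)
```

In this toy setting the loss is convex, so it says nothing about whether averaging helps in practice; it only shows the order of operations and why the target fine-tunings never need to communicate before the final average.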
Experiments
- Implemented numerical experiments to support 3 claims
- Showcased SoTA performance of model recycling in DomainBed
- Illustrated how performance gains arise from increased diversity across averaged models
- Provided empirical support for Hypotheses 1 and 2
- Discussed the application of recycling to in-distribution (ID) tasks
- Code released on GitHub
Recycling achieves state-of-the-art OOD performance
- Model recycling is evaluated on computer vision tasks, where it uses multiple datasets to improve OOD generalization.
- Experiments were conducted on five computer vision datasets: PACS, VLCS, OfficeHome, TerraIncognita and DomainNet.
- Model Soups/DiWA and ERM (Empirical Risk Minimization) are two approaches used in the experiments.
- Model recycling, inter-training, fusing and ensembling strategies were also used.
- Results showed that model recycling achieved a new SoTA.
- Model recycling was found to be robust to the choice of auxiliary tasks.
Recycling increases model diversity
- Investigated how diversity across models fine-tuned on target task influences OOD performance
- Measured diversity with prediction q-diversity
- Target task is OfficeHome, with “Art” as test OOD domain
- Trained on “ClipArt”, “Product” and “Photo” domains
- Inter-training influences final models
- Diversity gain comes from initialization and remains along fine-tuning
- Diversity is positively and linearly correlated with OOD generalization
- Best performance usually obtained around µ ≈ 0.5
- Auxiliary task fosters learning of diverse features
- Recycling increases diversity and improves performance
- Hypotheses 1 and 2 validated
- Two exceptions to Hypothesis 2 due to lack of similarity between auxiliary/target task and pre-training task
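To make the notion of prediction diversity more concrete, here is a sketch of one classical pairwise statistic from the ensemble literature, Yule's Q statistic, computed from per-example correctness of two models. Whether the q-diversity used in the paper matches this statistic exactly (or a transformation of it) is an assumption; the function name and toy inputs are hypothetical.

```python
from typing import Sequence


def q_statistic(correct_a: Sequence[bool], correct_b: Sequence[bool]) -> float:
    """Yule's Q statistic between two models' per-example correctness.

    Q is close to 1 when both models succeed and fail on the same examples,
    and becomes negative when they tend to fail on different examples.
    """
    if len(correct_a) != len(correct_b):
        raise ValueError("both models must be evaluated on the same examples")
    n11 = n00 = n10 = n01 = 0
    for a, b in zip(correct_a, correct_b):
        if a and b:
            n11 += 1  # both correct
        elif not a and not b:
            n00 += 1  # both wrong
        elif a:
            n10 += 1  # only the first model is correct
        else:
            n01 += 1  # only the second model is correct
    denominator = n11 * n00 + n01 * n10
    if denominator == 0:
        raise ValueError("Q is undefined for these counts")
    return (n11 * n00 - n01 * n10) / denominator


# Two models that make exactly the same mistakes (no diversity).
same_errors = q_statistic([True, True, False, False],
                          [True, True, False, False])
# Two models that fail on different examples (high diversity).
different_errors = q_statistic([True, True, False, True],
                               [True, False, True, True])
```

Under this statistic, lower values indicate that the models fail on different examples, which is the situation in which averaging or ensembling has the most room to help.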
Discussion: towards updatable machine learning
- Model recycling is a paradigm for developing machine learning systems that can be incrementally improved and recombined
- Networks are considered as pieces of software, allowing for open-source development
- Weights are communicated and averaged only at the end of the learning process
- Advanced merging operations consider weighted interpolations or neuron permutations
- Collaborative repositories of fine-tuned models can be built in a decentralized way
- Practices such as unit tests can be used to specify and test neural networks
- Datasets can be used as test certificates
- LMC usually holds in ID, but with smaller gains than in OOD
- ID and OOD accuracies are not correlated
- Model recycling sets a new SoTA by leveraging auxiliary tasks’ diversity
- If two initializations satisfy the LMC, then so do the two weights fine-tuned from them