Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Few-shot learning involves learning an effective model from only a few labeled datapoints.
- FLAD is a training paradigm that uses auxiliary data to improve generalization.
- Automated sampling strategies are related to the explore-exploit dilemma.
- Two algorithms are proposed and compared with methods that either explore or exploit.
- Using the proposed algorithms yields a 9% absolute improvement.
Paper Content
Introduction
- Few-shot learning is an attractive learning setting due to efficiency and data availability.
- Few-shot learning is challenging and requires a balance between learning and preventing overfitting.
- FLAD is one approach to improving generalizability of models in the few-shot setting.
- FLAD methods can introduce their own challenges, including increased complexity.
- Manually designing the curriculum for training on large quantities of auxiliary data is not feasible.
- Delegating design choices to an algorithm can lead to better solutions.
- FLAD is related to the exploration-exploitation trade-off in multi-armed bandit settings.
- EXP3-FLAD and UCB1-FLAD are two efficient algorithms designed to enhance few-shot generalization.
- EXP3-FLAD and UCB1-FLAD outperform baselines by up to 9.1%.
- Case studies show why EXP3-FLAD and UCB1-FLAD outperform baselines.
Background
- Describes the few-shot learning with auxiliary data (FLAD) problem
- Presents the goals of the multi-armed bandit setting
- Connects the adversarial bandit setting to FLAD
- Reviews the EXP3 algorithm
- Presents the UCB1 algorithm
Few-shot learning with auxiliary data
- FLAD aims to improve generalization when training on a small quantity of target data
- FLAD uses a much larger quantity of related auxiliary data
- Prior work sets up training in stages, using all of the auxiliary data in the first stage
- Other approaches use retrieval-augmented models, scale the loss of auxiliary datasets, or include unsupervised auxiliary data
- FLAD refers specifically to the setting where the final stage of training includes simultaneous training on both supervised auxiliary and target data
Multi-armed bandits
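- The multi-armed bandit (MAB) problem is a sequential decision problem from machine learning in which a learner repeatedly chooses among a fixed set of arms (actions)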
- Learner interacts with environment over N rounds by following a policy
- Goal is to adopt policy that leads to largest cumulative reward
- Adversarial MAB setting assumes reward-generating process is controlled by an adversary
- EXP3 algorithm targets adversarial MAB problem
- UCB1 algorithm assigns each arm a value called the upper confidence bound
- UCB1 suggests learner should play arms with large expected reward or not well explored
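To make the two bandit algorithms above concrete, here is a minimal sketch of textbook UCB1 and EXP3 in plain Python. It follows the standard formulations rather than the paper's adapted versions; the function names and the list-based state (`counts`, `mean_rewards`, `weights`) are illustrative assumptions, and rewards are assumed to lie in [0, 1] as in the classical setting.

```python
import math


def ucb1_select(counts, mean_rewards, t):
    """Return the arm with the largest upper confidence bound at round t (t >= 1)."""
    # Play every arm once before trusting the bound.
    for arm, n in enumerate(counts):
        if n == 0:
            return arm
    # Bound = estimated mean + bonus that shrinks as an arm is played more often.
    bounds = [
        mean_rewards[a] + math.sqrt(2.0 * math.log(t) / counts[a])
        for a in range(len(counts))
    ]
    return max(range(len(bounds)), key=bounds.__getitem__)


def exp3_probabilities(weights, gamma):
    """Mix the normalised weights with a uniform distribution (exploration rate gamma)."""
    k = len(weights)
    total = sum(weights)
    return [(1.0 - gamma) * w / total + gamma / k for w in weights]


def exp3_update(weights, probs, arm, reward, gamma):
    """Importance-weight the observed reward and update only the played arm."""
    k = len(weights)
    estimate = reward / probs[arm]  # unbiased estimate of the arm's reward
    new_weights = list(weights)
    new_weights[arm] *= math.exp(gamma * estimate / k)
    return new_weights
```

UCB1 is deterministic given its statistics and favours arms that are either promising or under-explored, while EXP3 keeps a probability distribution over arms so that an adversary cannot predict the next choice.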
From MAB to FLAD
- FLAD is formulated as a multi-armed bandit problem.
- At each round, a batch of data is sampled from a single dataset.
- Model parameters are updated after every G rounds.
- Reward function is based on cosine similarity between gradients of sampled auxiliary dataset batch and the whole target dataset.
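The formulation above can be sketched as follows: each round yields one auxiliary-batch gradient, its reward is the cosine similarity with the target gradient, and the gradients are summed so that a single parameter update happens after G rounds. This is only an illustration under stated assumptions: gradients are represented as flat Python lists, the helper names are hypothetical, and including the target gradient in the summed update is an assumption rather than something stated in the bullets.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two flat gradient vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # assumption: treat a zero gradient as carrying no signal
    return dot / (norm_u * norm_v)


def accumulate_rounds(sampled_grads, target_grad):
    """Process G rounds and return (rewards, summed gradient for one update).

    sampled_grads holds one gradient per round, each computed on a batch
    drawn from the single auxiliary dataset chosen in that round.
    """
    rewards = []
    summed = list(target_grad)  # assumption: the target gradient joins the update
    for grad in sampled_grads:
        rewards.append(cosine_similarity(grad, target_grad))
        summed = [s + g for s, g in zip(summed, grad)]
    return rewards, summed
```

A reward near 1 means the auxiliary batch pushes the parameters in the same direction as the target data, a reward near 0 means it is unrelated, and a negative reward means it pulls in the opposite direction.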
Modelling assumptions, implications, and consequences
- Environment modelled as adversary due to non-convex loss landscape, SGD, and gradient alignment reward function
- Parameters shared across datasets
- Mini-batch used to approximate gradient of full dataset
- Success of policy measured by reward, not necessarily metric being optimized
Adapting MAB algorithms for FLAD
- EXP3 and UCB1 algorithms are modified for FLAD setting
- Assumption made that helpfulness/harmfulness of training on a particular auxiliary dataset will be consistent throughout training
- Exploration rate decays over time
- Negative rewards are possible due to cosine similarity
- Learner can push down on probability of “bad” datasets
- UCB1-FLAD includes exponential moving average when estimating mean reward
- Estimated rewards initialized with larger data quantities
- Upper confidence bounds calculated for each auxiliary dataset
- Gradients calculated for target dataset and all auxiliary datasets
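The modifications listed above can be sketched in a few small functions: an exploration rate that decays with the round number, a sampling distribution that tolerates negative rewards, an exponential moving average for the mean reward, and one upper confidence bound per auxiliary dataset. The decay schedule, the EMA convention (whether the smoothing factor weights the new or the old value), and all function names are assumptions made for illustration and may differ from the paper's exact definitions.

```python
import math


def decayed_exploration_rate(num_arms, t):
    """Assumed schedule: capped at 1/K and shrinking like sqrt(ln K / (K t))."""
    return min(1.0 / num_arms, math.sqrt(math.log(num_arms) / (num_arms * t)))


def exp3_flad_probabilities(cumulative_rewards, rate):
    """Softmax over cumulative rewards, mixed with a floor of `rate` per dataset.

    Negative cumulative rewards simply lower a dataset's probability, so the
    learner can push down on datasets whose gradients oppose the target.
    """
    k = len(cumulative_rewards)
    shift = max(rate * r for r in cumulative_rewards)  # numerical stability
    scores = [math.exp(rate * r - shift) for r in cumulative_rewards]
    total = sum(scores)
    return [(1.0 - k * rate) * s / total + rate for s in scores]


def ema_update(estimate, reward, beta):
    """Assumed convention: beta weights the newly observed reward."""
    return (1.0 - beta) * estimate + beta * reward


def ucb1_flad_bounds(mean_rewards, counts, t):
    """One upper confidence bound per auxiliary dataset.

    Assumes every count is at least 1, e.g. after an initialisation pass.
    """
    return [
        m + math.sqrt(2.0 * math.log(t) / n)
        for m, n in zip(mean_rewards, counts)
    ]
```

With the moving average, recent gradient alignments dominate the estimate, which suits a setting where the usefulness of a dataset can drift as the model parameters change.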
Experimental setup
- Experiments utilize encoder-decoder models from the T5 family
- Models include T5-LM and T0
- T5-LM is trained on C4 dataset for 100,000 steps
- T0 is trained on a mixture of prompted datasets
- Experiments repeated with T5-XL and T0-3B models
- Datasets obtained from Hugging Face Datasets
- Evaluated on held-out datasets covering 4 tasks
- 5 few-shot splits sampled from training data
- Compare performance using 2 sets of auxiliary data
Training details
- Learning rate of 1e-4 used for all methods except target-only fine-tuning
- Batch sizes of 32 and 128 used for target-only fine-tuning and FLAD baselines
- Mini-batches of 8 samples used for EXP3-FLAD and UCB1-FLAD
- Adafactor optimizer used
- Validation-based early stopping used
- Best performance achieved when rewards are computed from the gradients of the language modelling head parameters
Findings and analyses
- EXP3-FLAD and UCB1-FLAD algorithms are compared to other FLAD methods
- FLAD helps models generalize
- All FLAD methods provide significant improvement in few-shot generalization over target-only fine-tuning
- EXP3-FLAD and UCB1-FLAD consistently outperform the explore-only and exploit-only FLAD methods
- EXP3-FLAD and UCB1-FLAD show a notable improvement of 2.6-3.5% when leveraging additional auxiliary data
- UCB1-FLAD outperforms EXP3-FLAD
- UCB1-FLAD has a bimodal sampling distribution, while EXP3-FLAD has a flatter distribution
- EXP3-FLAD rarely assigns an auxiliary dataset a probability < 1%
- T0 models trained using FLAD-based methods outperform all other methods
- FLAD-based methods provide robust performance across datasets
Related work
- Data selection involves selecting specific datapoints based on criteria related to the model being trained.
- FLAD is similar to transfer-learning, meta-learning, and multi-task learning.
- MAB has been used to improve supervised learning.
Conclusion
- Proposed methods demonstrate effectiveness of simultaneous training on auxiliary and target datasets in few-shot settings
- Continuous updating of beliefs by exploring and exploiting auxiliary data
- Framing FLAD as an MAB problem
- FLAD is not a strong adversary
- Gradient alignments maintain a weak form of stationarity
- Useful information in the noisy signal that models can utilize
- Improving reward function and reducing space complexity
- Validation-based early stopping, max 10,000 gradient update steps
- Smoothing factor β = 0.9
- Computing rewards from LM head gradients provides the best performance
- G + 1 gradients stored at any time
- UCB1-FLAD outperforms EXP3-FLAD on RTE
- EXP3-FLAD outperforms UCB1-FLAD on COPA
- Gradient alignment signal is particularly noisy for EXP3-FLAD
- EXP3 never makes large separations in empirical sampling distribution
- Detailed results from main experiment including direct fine-tuning, exploration-only, exploitation-only baselines and proposed methods