Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Few-shot learning involves learning an effective model from only a few labeled datapoints.
  • FLAD is a training paradigm that uses auxiliary data to improve generalization.
  • Automated sampling strategies are related to the explore-exploit dilemma.
  • Two algorithms are proposed and compared with methods that either explore or exploit.
  • Using the proposed algorithms yields a 9% absolute improvement.

Paper Content

Introduction

  • Few-shot learning is an attractive learning setting due to efficiency and data availability.
  • Few-shot learning is challenging and requires a balance between learning and preventing overfitting.
  • FLAD is one approach to improving generalizability of models in the few-shot setting.
  • FLAD methods can introduce their own challenges, including increased complexity.
  • Manually designing the curriculum for training on large quantities of auxiliary data is not feasible.
  • Delegating design choices to an algorithm can lead to better solutions.
  • FLAD is related to the exploration-exploitation trade-off in multi-armed bandit settings.
  • EXP3-FLAD and UCB1-FLAD are two efficient algorithms designed to enhance few-shot generalization.
  • EXP3-FLAD and UCB1-FLAD outperform baselines by up to 9.1%.
  • Case studies show why EXP3-FLAD and UCB1-FLAD outperform baselines.

Background

  • Describes the few-shot learning with auxiliary data (FLAD) problem
  • Presents the goals of the multi-armed bandit setting
  • Connects the adversarial bandit setting to FLAD
  • Reviews the EXP3 algorithm
  • Presents the UCB1 algorithm

Few-shot learning with auxiliary data

  • FLAD aims to improve generalization when training on a small quantity of target data
  • FLAD uses a much larger quantity of related auxiliary data
  • Prior work sets up training in stages, using all of the auxiliary data in the first stage
  • Other approaches use retrieval-augmented models, scale the loss of auxiliary datasets, or include unsupervised auxiliary data
  • FLAD refers specifically to the setting where the final stage of training includes simultaneous training on both supervised auxiliary and target data

Multi-armed bandits

  • The multi-armed bandit (MAB) is a sequential decision-making problem studied in machine learning
  • Learner interacts with environment over N rounds by following a policy
  • Goal is to adopt policy that leads to largest cumulative reward
  • Adversarial MAB setting assumes reward-generating process is controlled by an adversary
  • EXP3 algorithm targets adversarial MAB problem
  • UCB1 algorithm assigns each arm a value called the upper confidence bound
  • UCB1 directs the learner to play arms that either have a large estimated reward or have not yet been well explored
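
The two bandit algorithms named above can be sketched in their textbook forms. This is a minimal illustration of standard UCB1 and EXP3 on a toy problem, not the paper's FLAD variants. The function names, the Bernoulli reward simulation, and the parameter `gamma` are assumptions made for this example.

```python
import math
import random


def ucb1_select(counts, means, t):
    """Pick the arm with the largest upper confidence bound at round t (t >= 1)."""
    # Play every arm once before trusting the bounds.
    for arm, n in enumerate(counts):
        if n == 0:
            return arm
    # Bound = estimated mean reward + bonus that shrinks as an arm is played more.
    bounds = [m + math.sqrt(2.0 * math.log(t) / n) for m, n in zip(means, counts)]
    return max(range(len(bounds)), key=bounds.__getitem__)


def run_ucb1(reward_probs, rounds, seed=0):
    """Run UCB1 on hypothetical Bernoulli arms; return play counts and mean rewards."""
    rng = random.Random(seed)
    counts = [0] * len(reward_probs)
    means = [0.0] * len(reward_probs)
    for t in range(1, rounds + 1):
        arm = ucb1_select(counts, means, t)
        reward = 1.0 if rng.random() < reward_probs[arm] else 0.0
        counts[arm] += 1
        means[arm] += (reward - means[arm]) / counts[arm]  # running average
    return counts, means


def exp3_probs(weights, gamma):
    """Mix the weight-proportional distribution with uniform exploration."""
    total = sum(weights)
    k = len(weights)
    return [(1.0 - gamma) * w / total + gamma / k for w in weights]


def exp3_update(weights, probs, arm, reward, gamma):
    """Update the played arm's weight with an importance-weighted reward."""
    k = len(weights)
    estimated = reward / probs[arm]  # unbiased estimate of the arm's reward
    weights[arm] *= math.exp(gamma * estimated / k)


if __name__ == "__main__":
    counts, means = run_ucb1([0.2, 0.8], rounds=500)
    print("UCB1 play counts:", counts)
```

UCB1 is deterministic given its statistics and was designed for stochastic rewards, while EXP3 keeps a sampling distribution and makes no assumption about how rewards are generated, which is why it suits the adversarial setting.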

From MAB to FLAD

  • FLAD is formulated as a multi-armed bandit problem.
  • At each round, a batch of data is sampled from a single dataset.
  • Model parameters are updated after every G rounds.
  • Reward function is based on cosine similarity between gradients of sampled auxiliary dataset batch and the whole target dataset.
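
The gradient-alignment reward described above can be illustrated with a small sketch. It assumes a toy linear model with mean squared error in place of a language model, and all function names are hypothetical. The point is only that the reward is the cosine similarity between an auxiliary batch gradient and the target gradient.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def mse_gradient(weights, batch):
    """Gradient of mean squared error for a toy linear model y = w . x."""
    grad = [0.0] * len(weights)
    for x, y in batch:
        error = sum(w * xi for w, xi in zip(weights, x)) - y
        for i, xi in enumerate(x):
            grad[i] += 2.0 * error * xi / len(batch)
    return grad


def alignment_reward(weights, aux_batch, target_data):
    """Reward in [-1, 1]: how well the auxiliary gradient agrees with the target gradient."""
    aux_grad = mse_gradient(weights, aux_batch)
    target_grad = mse_gradient(weights, target_data)
    return cosine_similarity(aux_grad, target_grad)


if __name__ == "__main__":
    weights = [0.0, 0.0]
    target = [([1.0, 0.0], 1.0)]
    helpful = [([2.0, 0.0], 2.0)]   # pushes the weights the same way as the target
    harmful = [([1.0, 0.0], -1.0)]  # pushes the weights the opposite way
    print(alignment_reward(weights, helpful, target))
    print(alignment_reward(weights, harmful, target))
```

A reward near 1 means a step on the auxiliary batch moves the shared parameters in roughly the same direction as a step on the target data, while a negative reward signals a conflicting update.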

Modelling assumptions, implications, and consequences

  • Environment modelled as adversary due to non-convex loss landscape, SGD, and gradient alignment reward function
  • Parameters shared across datasets
  • Mini-batch used to approximate gradient of full dataset
  • Success of policy measured by reward, not necessarily metric being optimized

Adapting MAB algorithms for FLAD

  • EXP3 and UCB1 algorithms are modified for FLAD setting
  • Assumption made that helpfulness/harmfulness of training on a particular auxiliary dataset will be consistent throughout training
  • Exploration rate decays over time
  • Negative rewards are possible due to cosine similarity
  • Learner can lower the sampling probability of “bad” datasets
  • UCB1-FLAD includes exponential moving average when estimating mean reward
  • Estimated rewards initialized with larger data quantities
  • Upper confidence bounds calculated for each auxiliary dataset
  • Gradients calculated for target dataset and all auxiliary datasets
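
The adaptations listed above can be sketched as follows. This is an illustration under stated assumptions, not the paper's exact algorithms: the decay schedule in `exploration_rate`, the softmax form of the EXP3-style probabilities, and the convention that `beta` weights the previous estimate are all choices made for this example, and every function name is hypothetical.

```python
import math


def exploration_rate(t, k):
    """Assumed decaying schedule: min(1/K, sqrt(ln K / (K * t))) for round t >= 1."""
    return min(1.0 / k, math.sqrt(math.log(k) / (k * t)))


def exp3_flad_probs(cum_rewards, t):
    """EXP3-style sampling distribution over K auxiliary datasets.

    cum_rewards holds importance-weighted cumulative rewards. They may be
    negative because cosine similarity lies in [-1, 1], which lowers the
    probability of datasets whose gradients conflict with the target.
    """
    k = len(cum_rewards)
    eps = exploration_rate(t, k)
    scaled = [eps * r for r in cum_rewards]
    shift = max(scaled)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    # Every dataset keeps at least probability eps.
    return [(1.0 - k * eps) * e / total + eps for e in exps]


def ema_update(old_estimate, reward, beta=0.9):
    """Exponential moving average of rewards (assumes beta weights the history)."""
    return beta * old_estimate + (1.0 - beta) * reward


def ucb1_flad_select(est_rewards, counts, t):
    """Pick the dataset with the largest upper confidence bound.

    Assumes every count is at least 1, e.g. after an initialization pass
    that computes one reward per auxiliary dataset.
    """
    bounds = [r + math.sqrt(2.0 * math.log(t) / n)
              for r, n in zip(est_rewards, counts)]
    return max(range(len(bounds)), key=bounds.__getitem__)


if __name__ == "__main__":
    print(exp3_flad_probs([5.0, -5.0, 0.0], t=100))
    print(ucb1_flad_select([0.5, 0.4], [100, 1], t=101))
```

The moving average lets the reward estimate track a model whose gradients change during training, and the exploration floor in the EXP3-style distribution means no dataset is ever ruled out completely.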

Experimental setup

  • Experiments utilize encoder-decoder models from the T5 family
  • Models include T5-LM and T0
  • T5-LM is trained on C4 dataset for 100,000 steps
  • T0 is trained on a mixture of prompted datasets
  • Experiments repeated with T5-XL and T0-3B models
  • Datasets obtained from Hugging Face Datasets
  • Evaluated on held-out datasets covering 4 tasks
  • 5 few-shot splits sampled from training data
  • Compare performance using 2 sets of auxiliary data

Training details

  • Learning rate of 1e-4 used for all methods except target-only fine-tuning
  • Batch sizes of 32 and 128 used for target-only fine-tuning and FLAD baselines
  • Mini-batches of 8 samples used for EXP3-FLAD and UCB1-FLAD
  • Adafactor optimizer used
  • Validation-based early stopping used
  • Best performance achieved when rewards are computed using gradients of the language modelling head parameters

Findings and analyses

  • EXP3-FLAD and UCB1-FLAD algorithms are compared to other FLAD methods
  • FLAD helps models generalize
  • All FLAD methods provide significant improvement in few-shot generalization over target-only fine-tuning
  • EXP3-FLAD and UCB1-FLAD consistently outperform the explore-only and exploit-only FLAD methods
  • EXP3-FLAD and UCB1-FLAD show a notable improvement of 2.6-3.5% when leveraging additional auxiliary data
  • UCB1-FLAD outperforms EXP3-FLAD
  • UCB1-FLAD has a bimodal sampling distribution, while EXP3-FLAD has a flatter distribution
  • EXP3-FLAD rarely assigns an auxiliary dataset a probability < 1%
  • T0 models trained using FLAD-based methods outperform all other methods
  • FLAD-based methods provide robust performance across datasets
  • Data selection involves selecting specific datapoints based on criteria related to the model being trained.
  • FLAD is similar to transfer learning, meta-learning, and multi-task learning.
  • MAB has been used to improve supervised learning.

Conclusion

  • Proposed methods demonstrate effectiveness of simultaneous training on auxiliary and target datasets in few-shot settings
  • Continuous updating of beliefs by exploring and exploiting auxiliary data
  • Framing FLAD as an MAB problem
  • FLAD is not a strong adversary
  • Gradient alignments maintain a weak form of stationarity
  • Useful information in the noisy signal that models can utilize
  • Improving reward function and reducing space complexity
  • Validation-based early stopping, max 10,000 gradient update steps
  • Smoothing factor β = 0.9
  • Computing rewards using gradients from the LM head provides the best performance
  • G + 1 gradients stored at any time
  • UCB1-FLAD outperforms EXP3-FLAD on RTE
  • EXP3-FLAD outperforms UCB1-FLAD on COPA
  • Gradient alignment signal is particularly noisy for EXP3-FLAD
  • EXP3 never makes large separations in empirical sampling distribution
  • Detailed results from main experiment including direct fine-tuning, exploration-only, exploitation-only baselines and proposed methods