Abstract
- Learning from multiple tasks sequentially
- HyperTransformer (HT) generates specialized task-specific CNN weights from support set
- Re-use generated weights as input to HT for next task
- Continual HyperTransformer method with prototypical loss can learn and retain knowledge about past tasks
Paper Content
Introduction
- Involves learning from a continuous stream of tasks with a small number of examples
- Retains previously learned information
- Useful for classifying a large number of classes with limited observations
- Can enable robots to adapt to changing environments
- Can allow for privacy-preserving learning
- Uses HyperTransformer (HT) to meta-learn from episodes
- HT decouples domain knowledge model from learner
- Proposed modification is Continual HyperTransformer (CHT)
- CHT updates CNN weights with new task information while retaining knowledge of previous tasks
- Uses prototypical loss to maintain and update set of prototypes
- Evaluated in three realistic scenarios
- Does not suffer from catastrophic forgetting
- Performance on a given task can improve as weights are regenerated for later tasks (positive backward transfer)
- Can be used to learn a larger number of tasks than originally trained for
Related work
- Few-shot learning can be divided into two categories: metric-based learning and optimization-based learning
- Metric-based methods treat every task independently and have no memory of the past tasks
- Optimization-based methods learn an initial fixed embedding, which is later adapted to a specific task
- Continual learning methods can be grouped into three categories: rehearsal, regularization and architectural
- Rehearsal methods inject replay data from past tasks or distill a part of a network
- Regularization methods introduce an explicit regularization function when learning new tasks
- Architectural methods modify the network architecture with additional task-specific modules
- Our approach reuses the same principle that made HT work: decoupling the specialized representation model from the domain-aware Transformer model
- Incremental few-shot learning adapts a few-shot task to an existing base classifier without forgetting the original data
Continual few-shot learning
- Problem of continual few-shot learning
- Given a series of tasks, each task has K classes and N labeled examples for each class
- Classes for different tasks can be chosen in different ways
- Learner needs to produce CNN weights based on support set and previously generated weights
- Task-incremental learning: identify class attribute given sample and task attribute
- Class-incremental learning: identify class and task attributes of samples
- Test performance of trained model on episodes from holdout set of classes
- HyperTransformer (HT) uses self-attention mechanism to generate CNN weights
- Continual HyperTransformer (CHT) extends HT to handle continual stream of tasks
- CHT uses generated weights from past tasks as input weight embeddings
- CHT generates weights suited for all tasks up to current task
- CHT can be trained on T tasks and run for any number of tasks τ < T
- CHT can be applied recurrently to generate weights for more tasks than it was trained on
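The recurrence described above can be sketched as follows. This is only an illustration of the data flow, not the authors' implementation: `generate_weights` is a hypothetical stand-in for the trained CHT (here a toy blend of the previous weights with a support-set statistic) so that the loop is runnable. The part taken from the text is that the weights produced for task t-1 are fed back in together with the support set of task t.

```python
import numpy as np

def generate_weights(support_x, support_y, prev_weights):
    """Hypothetical stand-in for the CHT weight generator.

    The real model is a Transformer that consumes the support set and the
    previously generated weights; this toy only mimics that interface.
    """
    new_info = support_x.mean(axis=0)            # toy summary of the new task
    if prev_weights is None:                     # first task: nothing to reuse
        return new_info
    return 0.5 * prev_weights + 0.5 * new_info   # toy update rule (assumption)

def run_continual(tasks):
    """Feed each support set plus the previous weights back into the generator."""
    weights = None
    history = []
    for support_x, support_y in tasks:
        weights = generate_weights(support_x, support_y, weights)
        history.append(weights)   # weights after task t are meant to serve tasks 0..t
    return history

rng = np.random.default_rng(0)
# Three toy tasks, each with K=5 classes, N=1 example per class, 8 features.
tasks = [(rng.normal(size=(5, 8)), np.arange(5)) for _ in range(3)]
history = run_continual(tasks)
```

Because the generator takes its own previous output as input, the same loop can in principle keep running past the number of tasks seen during training, which is the recurrent use mentioned above.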
Prototypical loss
- Algorithm 1 is used to perform class-incremental learning using HyperTransformer with Prototypical Loss
- Cross-entropy loss is not suitable for continual learning
- Reusing the same label set for every task (domain-incremental learning) is an option, but classes that share a label collide
- Accuracy decreases dramatically with increasing number of tasks
- Loss function needs to decouple class predictions while keeping embedding space fixed
- Solution is to learn location of class prototypes from support set
- Task-incremental and class-incremental learning scenarios
- Loss is the negative log-probability of the correct class under a softmax over negative distances to the prototypes
- Prototypes are frozen after being computed
- Privacy-preserving scenario possible with weights and prototypes
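A minimal NumPy sketch of the ideas in this list, under stated assumptions: embeddings are taken as given (in CHT they would come from the generated CNN), squared Euclidean distance is assumed as the metric, and the support samples are reused as queries for brevity. The function names are illustrative and do not come from the paper. Prototypes are stored once per class and never recomputed, and the `allowed` argument switches between task-incremental and class-incremental prediction.

```python
import numpy as np

def compute_prototypes(embeddings, labels):
    """Return {class_id: mean embedding} for the classes in one support set."""
    return {int(c): embeddings[labels == c].mean(axis=0) for c in np.unique(labels)}

def _sq_dists(query_emb, prototypes, class_ids):
    protos = np.stack([prototypes[c] for c in class_ids])               # (C, D)
    return ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(-1)  # (Q, C)

def prototypical_loss(query_emb, query_labels, prototypes):
    """Mean negative log-softmax over negative squared distances to prototypes."""
    class_ids = sorted(prototypes)
    logits = -_sq_dists(query_emb, prototypes, class_ids)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    target = np.array([class_ids.index(int(y)) for y in query_labels])
    return -log_probs[np.arange(len(target)), target].mean()

def predict(query_emb, prototypes, allowed=None):
    """Nearest-prototype prediction.

    allowed=None       -> class-incremental: compare against every stored class.
    allowed=[ids, ...] -> task-incremental: compare only within the given task.
    """
    class_ids = sorted(prototypes if allowed is None else allowed)
    d2 = _sq_dists(query_emb, prototypes, class_ids)
    return np.array(class_ids)[d2.argmin(axis=1)]

rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0], [4.0, 4.0]])

def toy_task(class_ids, n=5):
    """Toy 2-D 'embeddings' clustered around fixed centers (an assumption)."""
    x = np.concatenate([centers[c] + 0.1 * rng.normal(size=(n, 2)) for c in class_ids])
    return x, np.repeat(class_ids, n)

all_prototypes = {}
x0, y0 = toy_task([0, 1])
all_prototypes.update(compute_prototypes(x0, y0))   # task 0 prototypes, frozen from here on
x1, y1 = toy_task([2, 3])
all_prototypes.update(compute_prototypes(x1, y1))   # task 1 adds new classes only

loss = prototypical_loss(x1, y1, all_prototypes)
class_inc = predict(x0, all_prototypes)                 # no task id needed
task_inc = predict(x0, all_prototypes, allowed=[0, 1])  # task id known
```

Only the prototypes and the generated weights need to be kept between tasks, not the support samples themselves, which is what makes the privacy-preserving scenario possible.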
Connection between prototypical loss and MAML
- The core idea behind the prototypical loss is a special case of a 1-step MAML-like learning algorithm.
- The model adjusts the final logits layer to adapt the pretrained embedding to a new task.
- A single step of the gradient descent results in logits weight and bias updates that rely on computing a dot-product of the pretrained embedding with “prototypes”.
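One way to make this concrete is a small numerical check. The assumptions here are not taken from the paper: the logits layer starts at zero, the embedding is frozen, classes are balanced (N examples each), and the loss is mean cross-entropy. Under those assumptions a single gradient step with learning rate `lr` gives weights equal to `(lr / K) * (prototype_k - mean embedding)` and leaves the bias at zero, so the resulting logits are dot-products of the embedding with mean-centered prototypes. The paper's own derivation may differ in detail.

```python
import numpy as np

def one_step_logits(embeddings, labels, num_classes, lr):
    """One gradient-descent step on a zero-initialised logits layer."""
    n = len(labels)
    onehot = np.eye(num_classes)[labels]                   # (n, K)
    probs = np.full((n, num_classes), 1.0 / num_classes)   # softmax of all-zero logits
    grad_logits = (probs - onehot) / n                     # dL/dz for mean cross-entropy
    W = -lr * grad_logits.T @ embeddings                   # (K, D) weight after one step
    b = -lr * grad_logits.sum(axis=0)                      # (K,) bias after one step
    return W, b

rng = np.random.default_rng(0)
K, N, D, lr = 5, 3, 16, 0.1
labels = np.repeat(np.arange(K), N)         # balanced: N examples per class
embeddings = rng.normal(size=(K * N, D))    # stand-in for pretrained embeddings

W, b = one_step_logits(embeddings, labels, K, lr)

# Class prototypes (per-class mean embeddings), centered by the overall mean.
prototypes = np.stack([embeddings[labels == k].mean(axis=0) for k in range(K)])
centered = prototypes - embeddings.mean(axis=0)

weights_match = np.allclose(W, lr / K * centered)
bias_is_zero = np.allclose(b, 0.0)
```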
Experiments
- Experiments conducted using Omniglot and tieredImageNet datasets
- Generated CNN composed of four convolutional blocks and a single dense layer
- For Omniglot, 8 filters for convolutional layers and 20-dim FC layer used
- For tieredImageNet, 64 filters for convolutional and 40-dim for the FC layer used
- Models trained in episodic fashion, accuracy calculated from 1024 random episodic evaluations
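The episodic evaluation protocol can be sketched as below. This is a generic illustration on synthetic data, not the paper's pipeline: the sampler, the nearest-class-mean learner, and the data are all assumptions introduced for the example. Only the structure (sample a K-class, N-shot episode, evaluate on held-out queries, average over 1024 episodes) follows the text.

```python
import numpy as np

def sample_episode(labels, num_classes, num_shots, num_queries, rng):
    """Pick K classes, then disjoint support and query indices for each class."""
    classes = rng.choice(np.unique(labels), size=num_classes, replace=False)
    support, query = [], []
    for c in classes:
        idx = rng.permutation(np.flatnonzero(labels == c))
        support.extend(idx[:num_shots])
        query.extend(idx[num_shots:num_shots + num_queries])
    return np.array(support), np.array(query)

def nearest_mean_accuracy(x, labels, support, query):
    """Toy learner (an assumption): nearest class mean in raw feature space."""
    classes = np.unique(labels[support])
    sx, sy = x[support], labels[support]
    protos = np.stack([sx[sy == c].mean(axis=0) for c in classes])
    d2 = ((x[query][:, None, :] - protos[None, :, :]) ** 2).sum(-1)
    pred = classes[d2.argmin(axis=1)]
    return (pred == labels[query]).mean()

rng = np.random.default_rng(0)
# Synthetic holdout set: 20 well-separated classes, 10 samples each, 4 features.
labels = np.repeat(np.arange(20), 10)
x = 10.0 * labels[:, None] + 0.5 * rng.normal(size=(len(labels), 4))

accuracies = []
for _ in range(1024):                      # number of evaluation episodes
    support, query = sample_episode(labels, num_classes=5, num_shots=1,
                                    num_queries=5, rng=rng)
    accuracies.append(nearest_mean_accuracy(x, labels, support, query))
mean_accuracy = float(np.mean(accuracies))
```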
Learning from mini-batches
- Three models were compared using a set of four 5-way 1-shot support set batches
- Cross-entropy loss was used since the label set was the same for all batches
- Each support set batch contained a single example per class
- Test accuracies for the three models were equal to 67.9%, 68.0% and 68.3% respectively
- HT trained with just one or two samples per class performed significantly worse, reaching test accuracies of 56.2% and 62.9% respectively
- Updating generated CNN weights using information from multiple support set batches can achieve performance comparable to processing all samples in a single pass with HT
Learning from tasks within a single domain
- Considered a scenario where tasks consist of classes from a single overall distribution
- Presented results of two models: one trained on 20-way, 1-shot tasks with classes sampled from Omniglot dataset, and another trained on 5-way, 5-shot tasks with classes sampled from tieredImageNet dataset
- Compared performance of CHT to two baseline models: Constant ProtoNet (ConstPN) and Merged HyperTransformer (MergedHT)
- For best results on Constant ProtoNet, had to increase number of classes by a factor of 5
- For task-incremental learning, saw positive backward knowledge transfer, where performance on past tasks improved as more weights were generated
- For class-incremental learning, accuracy of all models decreased as more tasks were included in the prediction
- CHT had better results than Constant ProtoNet up to the number of tasks it was trained for
- CHT was able to outperform Merged HyperTransformer for Omniglot, even though MergedHT had access to information about all tasks from the beginning
Learning from tasks across multiple domains
- Previous experiments used tasks from the same general distribution and image domain
- ConstPN would suffer in accuracy if tasks were from different distributions and image domains
- CHT can adapt sample representations for different image domains
- Multi-domain episode generator used tasks from various image datasets
- ConstPN accuracy of 53%, 52.8% and 50.8% on task 0, task 1 and both tasks combined, respectively
- CHT accuracy of 56.2%, 55.2% and 53.8% on task 0, task 1 and both tasks combined, respectively
- CHT better at adapting to multi-domain task distribution
Conclusions
- Proposed Continual HyperTransformer model has efficient few-shot learner capabilities
- Can generate CNN weights on the fly with no training required
- Can update weights with information from new tasks
- Learning occurs without catastrophic forgetting
- Prototypical loss optimizes location of prototypes of all classes of every task
- Single trained CHT model can be used in task-incremental and class-incremental scenarios
- Cross-entropy loss leads to collisions when more tasks are added
- Class-incremental objective performs better on class-incremental problems
- UMAP projections of embeddings show tasks are well separated
- ConstPN model does not make distinction between tasks
- CHT model performs better than ConstPN
- Omniglot embeddings overlap, but accuracy is still high
- CHT model can handle larger number of tasks
- MergedHT and CHT models have significant difference in performance
- Figures 12-15 show additional experiments with Omniglot dataset