Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Learning from multiple tasks sequentially
  • HyperTransformer (HT) generates specialized task-specific CNN weights from support set
  • Re-use generated weights as input to HT for next task
  • Continual HyperTransformer method with prototypical loss can learn and retain knowledge about past tasks

Paper Content

Introduction

  • Involves learning from a continuous stream of tasks with a small number of examples
  • Retains previously learned information
  • Useful for classifying a large number of classes with limited observations
  • Can enable robots to adapt to changing environments
  • Can allow for privacy-preserving learning
  • Uses HyperTransformer (HT) to meta-learn from episodes
  • HT decouples domain knowledge model from learner
  • Proposed modification is Continual HyperTransformer (CHT)
  • CHT updates CNN weights with new task information while retaining knowledge of previous tasks
  • Uses prototypical loss to maintain and update set of prototypes
  • Evaluated in three realistic scenarios
  • Does not suffer from catastrophic forgetting
  • Performance on a given task improves as the weights are updated with more tasks
  • Can be used to learn a larger number of tasks than originally trained for
  • Few-shot learning can be divided into two categories: metric-based learning and optimization-based learning
  • Metric-based methods treat every task independently and have no memory of the past tasks
  • Optimization-based methods learn an initial fixed embedding, which is later adapted to a specific task
  • Continual learning methods can be grouped into three categories: rehearsal, regularization and architectural
  • Rehearsal methods inject replay data from past tasks or distill a part of a network
  • Regularization methods introduce an explicit regularization function when learning new tasks
  • Architectural methods modify the network architecture with additional task-specific modules
  • Our approach reuses the same principle that made HT work: decoupling the specialized representation model from the domain-aware Transformer model
  • Incremental few-shot learning adapts a few-shot task to an existing base classifier without forgetting the original data

Continual few-shot learning

  • Problem of continual few-shot learning
  • Given a series of tasks, each task has K classes and N labeled examples for each class
  • Classes for different tasks can be chosen in different ways
  • Learner needs to produce CNN weights based on support set and previously generated weights
  • Task-incremental learning: identify class attribute given sample and task attribute
  • Class-incremental learning: identify class and task attributes of samples
  • Test performance of trained model on episodes from holdout set of classes
  • HyperTransformer (HT) uses self-attention mechanism to generate CNN weights
  • Continual HyperTransformer (CHT) extends HT to handle continual stream of tasks
  • CHT uses generated weights from past tasks as input weight embeddings
  • CHT generates weights suited for all tasks up to current task
  • CHT can be trained on T tasks and run for any number of tasks τ < T
  • CHT can be applied recurrently to generate weights for additional tasks beyond the ones it was trained on
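
The recurrence described above (weights for the current task are produced from the new support set plus the previously generated weights) can be sketched as follows. This is a minimal illustration, not the paper's implementation: `toy_weight_generator` is a hypothetical stand-in for the Transformer-based generator, and the flat weight vector, the `mix` parameter and the random support sets are assumptions made only to keep the example runnable.

```python
import numpy as np


def toy_weight_generator(support_x, prev_weights, mix=0.5):
    """Hypothetical stand-in for the CHT weight generator (NOT the real model).

    It blends the previous weights with a statistic of the new support set,
    so the output depends on both inputs, as the text describes.
    """
    update = support_x.mean(axis=0)
    return (1.0 - mix) * prev_weights + mix * update


def run_continual(support_sets, init_weights, generator=toy_weight_generator):
    """Apply the generator recurrently: theta_t = generator(S_t, theta_{t-1})."""
    weights = init_weights
    history = []
    for support_x in support_sets:
        weights = generator(support_x, weights)
        history.append(weights.copy())
    return history


rng = np.random.default_rng(0)
dim = 4
# Six tasks, each with a support set of five (already flattened) samples.
support_sets = [rng.normal(size=(5, dim)) for _ in range(6)]
history = run_continual(support_sets, init_weights=np.zeros(dim))
```

Because the same function is applied at every step, the loop can run for fewer or more tasks than were used in training; `history[t]` holds the weights intended to serve all tasks up to and including task `t`.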

Prototypical loss

  • Algorithm 1 is used to perform class-incremental learning using HyperTransformer with Prototypical Loss
  • Cross-entropy loss is not suitable for continual learning
  • Domain-incremental learning is an option, but it causes collisions
  • Accuracy decreases dramatically with increasing number of tasks
  • Loss function needs to decouple class predictions while keeping embedding space fixed
  • Solution is to learn location of class prototypes from support set
  • Task-incremental and class-incremental learning scenarios
  • Loss is the negative log-probability of the correct class under a softmax over distances to the prototypes
  • Prototypes are frozen after being computed
  • Privacy-preserving scenario possible with weights and prototypes
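
A minimal sketch of the prototype-based loss and of the two prediction modes follows. It assumes squared Euclidean distance, as in standard Prototypical Networks, and works on ready-made embedding vectors; the function names and the toy data are illustrative and not taken from the paper.

```python
import numpy as np


def class_prototypes(embeddings, labels, num_classes):
    """Mean support embedding per class; shape (num_classes, dim)."""
    return np.stack([embeddings[labels == k].mean(axis=0)
                     for k in range(num_classes)])


def squared_distances(query_emb, prototypes):
    """Pairwise squared Euclidean distances; shape (num_queries, num_protos)."""
    diffs = query_emb[:, None, :] - prototypes[None, :, :]
    return np.sum(diffs ** 2, axis=-1)


def prototypical_loss(query_emb, query_labels, prototypes):
    """Mean negative log-probability of the true class under a softmax
    over negative squared distances to the prototypes."""
    logits = -squared_distances(query_emb, prototypes)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(query_labels)), query_labels].mean())


def predict(query_emb, prototypes, allowed=None):
    """Nearest-prototype prediction.

    allowed=None  : class-incremental (compare against every stored prototype).
    allowed=[...] : task-incremental (only the known task's class indices).
    """
    dists = squared_distances(query_emb, prototypes)
    if allowed is None:
        return dists.argmin(axis=1)
    allowed = np.asarray(allowed)
    return allowed[dists[:, allowed].argmin(axis=1)]


# Task 0: two classes; its prototypes are computed once and then frozen.
emb0 = np.array([[-1.0, 0.0], [1.0, 0.0], [3.0, 0.0], [5.0, 0.0]])
protos0 = class_prototypes(emb0, np.array([0, 0, 1, 1]), num_classes=2)
# Task 1: two new classes, stored after the frozen ones (global labels 2, 3).
emb1 = np.array([[-1.0, 4.0], [1.0, 4.0], [3.0, 4.0], [5.0, 4.0]])
protos1 = class_prototypes(emb1, np.array([0, 0, 1, 1]), num_classes=2)
all_protos = np.concatenate([protos0, protos1])

query = np.array([[3.5, 3.0]])
loss = prototypical_loss(query, np.array([3]), all_protos)
class_incremental = predict(query, all_protos)
task_incremental = predict(query, all_protos, allowed=[0, 1])
```

Only the stored prototypes and the generated weights are needed at prediction time, which is what makes the privacy-preserving scenario possible: the support samples themselves do not have to be kept.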

Connection between prototypical loss and MAML

  • The core idea behind the prototypical loss is a special case of a 1-step MAML-like learning algorithm.
  • The model adjusts the final logits layer to adapt the pretrained embedding to a new task.
  • A single step of the gradient descent results in logits weight and bias updates that rely on computing a dot-product of the pretrained embedding with “prototypes”.
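
The claim can be checked numerically under some simplifying assumptions that are not spelled out in the summary: the logits layer starts at zero, the loss is softmax cross-entropy summed over the support set, and every class has the same number of shots. Under those assumptions, one gradient step leaves each weight row proportional to the class prototype minus the mean prototype, and the bias stays at zero.

```python
import numpy as np


def one_step_logits_layer(embeddings, labels, num_classes, lr):
    """One gradient-descent step on a zero-initialized logits layer (W, b)
    under summed softmax cross-entropy. Zero init is an assumption here."""
    dim = embeddings.shape[1]
    W = np.zeros((num_classes, dim))
    b = np.zeros(num_classes)
    logits = embeddings @ W.T + b                       # all zeros
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    onehot = np.eye(num_classes)[labels]
    grad_logits = probs - onehot                        # dL/dlogits per sample
    W = W - lr * grad_logits.T @ embeddings
    b = b - lr * grad_logits.sum(axis=0)
    return W, b


rng = np.random.default_rng(0)
num_classes, shots, dim, lr = 3, 4, 5, 0.1
labels = np.repeat(np.arange(num_classes), shots)
embeddings = rng.normal(size=(num_classes * shots, dim))

W, b = one_step_logits_layer(embeddings, labels, num_classes, lr)

# Prototypes are the per-class means of the (pretrained) embeddings.
prototypes = np.stack([embeddings[labels == k].mean(axis=0)
                       for k in range(num_classes)])
# Closed form of the update for balanced classes.
expected_W = lr * shots * (prototypes - prototypes.mean(axis=0))

queries = rng.normal(size=(10, dim))
pred_one_step = (queries @ W.T + b).argmax(axis=1)
pred_dot_product = (queries @ prototypes.T).argmax(axis=1)
```

Subtracting the mean prototype shifts every logit of a query by the same amount, so the one-step classifier ranks classes exactly as the dot product with the prototypes does.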

Experiments

  • Experiments conducted using Omniglot and tieredImageNet datasets
  • Generated network composed of four convolutional blocks and a single dense layer
  • For Omniglot, 8 filters for convolutional layers and 20-dim FC layer used
  • For tieredImageNet, 64 filters for convolutional layers and 40-dim FC layer used
  • Models trained in episodic fashion, accuracy calculated from 1024 random episodic evaluations
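
Episodic evaluation can be sketched as below: sample a random N-way K-shot episode, classify its queries, and average accuracy over 1024 episodes. Everything here is illustrative: the data are synthetic Gaussian clusters rather than Omniglot or tieredImageNet, the embedding is the identity instead of a generated CNN, and the nearest-prototype classifier is just a convenient placeholder.

```python
import numpy as np


def sample_episode(rng, class_means, n_way, n_shot, n_query, noise=0.1):
    """Draw a random episode; returns ((support_x, support_y), (query_x, query_y))."""
    classes = rng.choice(len(class_means), size=n_way, replace=False)
    dim = class_means.shape[1]

    def draw(n):
        x = np.concatenate([class_means[c] + noise * rng.normal(size=(n, dim))
                            for c in classes])
        y = np.repeat(np.arange(n_way), n)  # labels are local to the episode
        return x, y

    return draw(n_shot), draw(n_query)


def episode_accuracy(support, query, n_way):
    """Accuracy of a nearest-prototype classifier on one episode."""
    (sx, sy), (qx, qy) = support, query
    protos = np.stack([sx[sy == k].mean(axis=0) for k in range(n_way)])
    dists = ((qx[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return float((dists.argmin(axis=1) == qy).mean())


rng = np.random.default_rng(0)
# 50 synthetic "classes", each a well-separated cluster in 16 dimensions.
class_means = 5.0 * rng.normal(size=(50, 16))

accs = [episode_accuracy(*sample_episode(rng, class_means, 5, 1, 5), n_way=5)
        for _ in range(1024)]
mean_acc = float(np.mean(accs))
```

Averaging over many independently sampled episodes reduces the variance that any single small episode would show.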

Learning from mini-batches

  • Three models were compared using a set of four 5-way 1-shot support set batches
  • Cross-entropy loss was used since the label set was the same for all batches
  • Each support set batch contained a single example per class
  • Test accuracies for the three models were equal to 67.9%, 68.0% and 68.3% respectively
  • HT trained with just one or two samples per class performed significantly worse, reaching test accuracies of 56.2% and 62.9% respectively
  • Updating generated CNN weights using information from multiple support set batches can achieve performance comparable to processing all samples in a single pass with HT

Learning from tasks within a single domain

  • Considered a scenario where tasks consist of classes from a single overall distribution
  • Presented results of two models: one trained on 20-way, 1-shot tasks with classes sampled from Omniglot dataset, and another trained on 5-way, 5-shot tasks with classes sampled from tieredImageNet dataset
  • Compared performance of CHT to two baseline models: Constant ProtoNet (ConstPN) and Merged HyperTransformer (MergedHT)
  • For best results on Constant ProtoNet, had to increase number of classes by a factor of 5
  • For task-incremental learning, saw positive backward knowledge transfer, where performance on past tasks improved as more weights were generated
  • For class-incremental learning, accuracy of all models decreased as more tasks were included in the prediction
  • CHT had better results than Constant ProtoNet up to the number of tasks it was trained for
  • CHT was able to outperform Merged HyperTransformer for Omniglot, even though MergedHT had access to information about all tasks from the beginning

Learning from tasks across multiple domains

  • Experiments used same general distribution and image domain
  • ConstPN would suffer in accuracy if tasks were from different distributions and image domains
  • CHT can adapt sample representations for different image domains
  • Multi-domain episode generator used tasks from various image datasets
  • ConstPN accuracy of 53%, 52.8% and 50.8% on task 0, task 1 and the combined tasks, respectively
  • CHT accuracy of 56.2%, 55.2% and 53.8% on task 0, task 1 and the combined tasks, respectively
  • CHT better at adapting to multi-domain task distribution

Conclusions

  • Proposed Continual HyperTransformer model has efficient few-shot learner capabilities
  • Can generate CNN weights on the fly with no training required
  • Can update weights with information from new tasks
  • Learning occurs without catastrophic forgetting
  • Prototypical loss optimizes location of prototypes of all classes of every task
  • Single trained CHT model can be used in task-incremental and class-incremental scenarios
  • Cross-entropy loss leads to collisions when more tasks are added
  • Class-incremental objective performs better on class-incremental problems
  • UMAP projections of embeddings show tasks are well separated
  • ConstPN model does not make distinction between tasks
  • CHT model performs better than ConstPN
  • Omniglot embeddings overlap, but accuracy is still high
  • CHT model can handle larger number of tasks
  • MergedHT and CHT models show a significant difference in performance
  • Figures 12-15 show additional experiments with Omniglot dataset