Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Traditional approaches to RL focus on learning decision policies from episodic decisions
  • Some approaches refine representations via auxiliary self-supervised losses while learning decision policies
  • Supervised language model cascades can adapt to many diverse manifolds
  • Transfer methods for language models require human supervision
  • We propose a self-supervised loss policy called contrastive distillation
  • Contrastive distillation outperforms common methods of transfer learning
  • Contrastive distillation is improved through sampling from memory

Paper Content

Introduction

  • Humans excel at generating contrastive self-learning data in low-data regimes
  • Humans manifest latent axes of variation as sequences
  • Transfer learning and a diversity of online learning algorithms are needed to build more causal representations
  • Current machine learning approaches to transfer learning are less data-efficient than humans
  • The current paradigm is unsupervised pre-training followed by online supervised transfer
  • Auxiliary losses require view augmentations to be specified beforehand
  • Self-learning approaches have been shown to be effective at improving in-domain accuracy
  • Self-learning of sequences to model distribution shifts has been relatively unexplored
  • A self-supervised loss surface, created by generating auxiliary sequences with high mutual information, improves transfer learning
  • There is a tradeoff between manifesting latent axes of variation and the additional compute and training time this requires
  • Contrastive distillation improves transfer learning, and sampling from memory further improves contrastive distillation

Core contributions

  • Synthesize four literatures into a unifying model for self-learning
  • Introduce contrastive distillation to improve transfer learning
  • Show that sampling from episodic memory improves contrastive distillation
  • Propose an algorithm for sampling negative examples for contrastive losses

Self-learning via expert iteration

  • Self-learning of large language models via chain-of-thought (CoT) rollouts and majority voting has shown promising results (Huang et al., 2022; Wei et al., 2022).
  • STaR (Zelikman et al., 2022) uses a rationalization bootstrapping technique to sample posterior updates given data, but does not consider contrastive trajectories or episodic memory.
  • Micheli et al. (2022) cluster continuous observations to form a vocabulary for a world model.
  • Polu et al. (2022) use expert iteration and a curriculum of starting proof contexts to self-train GPT-f to solve IMO math problems.

Self-learning via in-context learning

  • Self question-asking has been shown to improve language model performance and has been used to answer questions in robotics.
  • Hand-designed actor-critic-like transition functions have been applied to text editing and story-writing.

Self-learning via contrastive methods

  • Deng et al. (2022) and Li et al. (2022) show that contrasting likely sequences from paired sequence models can manifest latent axes of variation.
  • Brown et al. (2020) and Radford et al. demonstrate that large language models can do few-shot learning after pre-training on natural language data.
  • Chan et al. (2022) show that data distributions similar to first-person datasets and natural language can induce few-shot learning in Transformer architectures.

Memory

  • Several architectures have introduced per-cycle “working memory” for tasks requiring short context windows
  • End-to-end memory networks, Universal Transformers, and Neural Turing Machines have been used to improve performance on NLP tasks
  • Most approaches do not leverage latent natural language policies and use memory that is in scope for only a single pass
  • Persistent memory architectures have been used to introduce episodic memory in RL agents
  • Goyal et al. (2022) augment a MuZero agent with a retrieval-only episodic memory
  • BlenderBot 2.0 uses a supervised summarization scheme to persist important summaries into a long-term memory
  • Our approach is similar to MERLIN and BlenderBot 2.0, but is framed in the context of transfer learning and is less hand-engineered

Self-learning decision-making process

  • Formalize a discrete self-learning decision-making process
  • Proposal policy generates prediction task and expected observation
  • Solver policy generates selection-inference chain of adaptive length
  • Verifier produces a noisy judgement of consistency
  • Update policy generates updates, memory and action
  • Reduces to standard supervised cross-entropy loss
  • Self-I/O constraint: output updates can be used as solver inputs or training data
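
The four policies above can be sketched as a single loop. This is a minimal, hypothetical sketch: the function names, the threshold, the choice to append every update to memory, and the toy policies are assumptions for illustration, not the paper's implementation. The verifier here is a plain callable (exact match in the toy run) rather than BLEURT.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, List, Tuple

@dataclass
class Step:
    task: str          # prediction task from the proposal policy
    expected: str      # expected observation from the proposal policy
    prediction: str    # output of the solver policy
    consistent: bool   # thresholded verifier judgement
    update: str        # sequence produced by the update policy

def self_learning_loop(
    propose: Callable[[], Iterable[Tuple[str, str]]],
    solve: Callable[[str, List[str]], str],
    verify: Callable[[str, str], float],
    make_update: Callable[[str, str, str], str],
    threshold: float = 0.5,
) -> Tuple[List[Step], List[str], List[Tuple[str, str]]]:
    memory: List[str] = []                     # updates reusable as solver inputs
    training_data: List[Tuple[str, str]] = []  # updates reusable as training targets
    steps: List[Step] = []
    for task, expected in propose():
        prediction = solve(task, memory)       # solver may read earlier updates
        consistent = verify(prediction, expected) >= threshold
        update = make_update(task, expected, prediction)
        memory.append(update)                  # self-I/O: an output becomes a future input
        training_data.append((task, f"{update} {expected}"))
        steps.append(Step(task, expected, prediction, consistent, update))
    return steps, memory, training_data

# Toy policies (hypothetical) to exercise the loop.
def toy_propose():
    yield ("Where is Mary?", "kitchen")
    yield ("Where is John?", "garden")

def toy_solve(task, memory):
    return "kitchen"                           # a fixed guess that ignores memory

def exact_match(prediction, expected):
    return float(prediction == expected)

def toy_update(task, expected, prediction):
    return f"The answer to '{task}' is {expected} because of the story."

steps, memory, training_data = self_learning_loop(toy_propose, toy_solve, exact_match, toy_update)
```

Passing the growing memory list back into the solver is what makes this sketch satisfy the self-I/O constraint: every update is available both as a later solver input and as a training target.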

Connection to linguistics: a (type 0) grammar that learns a posterior grammar conditioned on evidence

  • Self-learning process can be interpreted as a distribution over tokens
  • Self-learning process collects evidence from a continuous memory and applies contrastive updates using its own likely sequences

Constraints on update representations from multi-task self-learning

  • Multi-task self-learning has additional constraints
  • Constraints include task-adaptive length, source and target mutual information, information bottleneck, position-invariance, and iterative compositionality
  • Method demonstrated to fulfill these constraints where fine-tuning and other adaptation approaches fail

Contrastive distillation

  • Contrastive distillation is a method of conditionally sampling a sequence of tokens from a task observation.
  • It is a form of rationalization which satisfies the target mutual information constraint.
  • Contrastive distillation manifests the latent axes of variation of source and target distributions from weights to tokens.
  • It is contrasted with direct weight updates which do not fulfill the target mutual information constraint.

Contrastive distillation through fine-tuning

  • Update used as self-learning sequence
  • Includes x_t and u_t in later iterations
  • Referred to as contrastive distillation through fine-tuning

Contrastive distillation through memory

  • Contrastive distillation is stored in memory
  • Memory allows recurrent control flow
  • Memory lookups can bring a combination of contrastive distillations into scope simultaneously

Bayesian contrastive distillation

  • Contrastive distillation can be used to generate negative samples for contrastive losses.
  • Samples are taken from task-adaptive prior distributions.
  • Contrastive distillation can be used for data augmentation and compression.
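
As a generic illustration of how sampled negatives enter a contrastive loss, the sketch below draws negatives from a given prior over a candidate pool and scores them with an InfoNCE objective. The prior array stands in for a task-adaptive prior distribution; InfoNCE, cosine similarity, the temperature, and all names are assumptions, not the paper's stated algorithm.

```python
import numpy as np

def sample_negatives(prior, positive_idx, k, rng):
    """Draw k distinct negative indices from a prior over a candidate pool,
    never returning the positive itself."""
    p = np.asarray(prior, dtype=float).copy()
    p[positive_idx] = 0.0                 # exclude the positive
    p /= p.sum()                          # renormalise the remaining mass
    return rng.choice(len(p), size=k, replace=False, p=p)

def info_nce(anchor, positive, negatives, temperature=0.1):
    """InfoNCE loss: negative log-probability of picking the positive among
    {positive} plus negatives, under a softmax over cosine similarities."""
    def unit(x):
        return x / np.linalg.norm(x, axis=-1, keepdims=True)
    candidates = unit(np.vstack([positive[None, :], negatives]))
    logits = candidates @ unit(anchor) / temperature
    logits = logits - logits.max()        # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return float(-log_probs[0])           # the positive sits at index 0

rng = np.random.default_rng(0)
pool = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [-1.0, 0.0]])
prior = np.array([0.4, 0.3, 0.2, 0.1])    # hypothetical task-adaptive prior
neg_idx = sample_negatives(prior, positive_idx=0, k=2, rng=rng)
loss = info_nce(np.array([1.0, 0.0]), pool[0], pool[neg_idx])
```

Sampling without replacement keeps the negatives distinct, and zeroing the positive's mass guarantees it never appears as its own negative.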

Experiment: contrastive distillation improves NLP transfer learning

  • Experimented with update policy applied to transfer learning
  • Used teacher oracle (GPT-3) to test contrastive distillation and memory mechanics
  • Future experiments will use self-generation with larger student models

Datasets

  • Trained and tested on low-data configurations of bAbI and Com2Sense datasets
  • bAbI tests reasoning on episodic memory, Com2Sense tests reasoning on semantic memory
  • bAbI requires de novo reasoning, Com2Sense requires copying with adaptation
  • Both forms of reasoning are necessary for a successful self-learning agent

Proposal policy

  • Implemented proposal policy as a deterministic iterator
  • Yields 5 examples for each of the 20 bAbI tasks
  • Yields 100 datapoints for the singleton Com2Sense task
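
A deterministic proposal iterator of this kind can be sketched with generators. The dict-of-lists layout for bAbI, the (prompt, label) tuples, and the function names are assumptions for illustration; only the counts (5 examples for each of the 20 bAbI tasks, 100 Com2Sense datapoints) come from the text.

```python
from itertools import islice
from typing import Dict, Iterator, List, Tuple

Example = Tuple[str, str]  # (prompt, label); hypothetical format

def babi_proposals(tasks: Dict[int, List[Example]], per_task: int = 5) -> Iterator[Tuple[int, Example]]:
    """Yield the first per_task examples of each bAbI task, in task-id order."""
    for task_id in sorted(tasks):
        for example in islice(tasks[task_id], per_task):
            yield task_id, example

def com2sense_proposals(examples: List[Example], n: int = 100) -> Iterator[Example]:
    """Yield the first n datapoints of the single Com2Sense task."""
    yield from islice(examples, n)

# Toy data standing in for the real datasets.
toy_babi = {t: [(f"story {t}.{i}", f"answer {t}.{i}") for i in range(10)] for t in range(1, 21)}
toy_c2s = [(f"claim {i}", "True") for i in range(150)]
babi_batch = list(babi_proposals(toy_babi))      # 20 tasks x 5 = 100 proposals
c2s_batch = list(com2sense_proposals(toy_c2s))   # 100 proposals
```

Because the iterators only slice fixed lists in a fixed order, every epoch sees the same proposals, which matches the "deterministic" description.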

Solver policy

  • Implemented solver policy as single-hop solver
  • Fine-tuned from T5-3B (Raffel et al., 2020)
  • Various update policies used

Memory policy

  • Memory is implemented using FAISS
  • Memories are split into context embedding and updates
  • Updates are prefixed with a position ordinal
  • Updates are constrained to be less than 200 tokens long
  • Queries are sampled by combining a key embedding with PCA decomposition
  • During training, kq random queries are selected
  • During eval, the top kq-1 PCA queries + the key embedding are selected
  • Updates are injected FiD-style as a prefix of the solver model prompt
  • Symbolic updates are re-encoded using T5 embeddings at every epoch
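
The memory mechanics above can be sketched with NumPy. Brute-force inner-product search stands in for FAISS (faiss.IndexFlatIP computes the same exact search), whitespace tokens stand in for T5 tokens, and the embeddings would come from a T5 encoder that is not shown. The text does not specify how the key embedding and the PCA decomposition are combined, so the candidate-query construction here (the key embedding followed by the principal directions of the stored context embeddings), the top-1 lookup per query, and the ordinal prefix format are all assumptions.

```python
import numpy as np

class EpisodicMemory:
    """Hypothetical memory policy: context embeddings paired with text updates,
    searched by exact inner product (what faiss.IndexFlatIP computes)."""

    def __init__(self, dim: int, max_tokens: int = 200):
        self.keys = np.zeros((0, dim), dtype=np.float32)
        self.updates: list = []
        self.max_tokens = max_tokens

    def add(self, context_embedding, update: str) -> None:
        # Whitespace tokens stand in for T5 tokens; keep fewer than max_tokens.
        tokens = update.split()[: self.max_tokens - 1]
        row = np.asarray(context_embedding, dtype=np.float32)[None, :]
        self.keys = np.vstack([self.keys, row])
        self.updates.append(" ".join(tokens))

    def candidate_queries(self, key) -> np.ndarray:
        # ASSUMPTION: the key embedding followed by the principal directions of
        # the stored context embeddings, ordered by explained variance (via SVD).
        centered = self.keys - self.keys.mean(axis=0, keepdims=True)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        return np.vstack([np.asarray(key, dtype=np.float32)[None, :], vt])

    def retrieve(self, key, kq: int, training: bool, rng=None) -> list:
        candidates = self.candidate_queries(key)
        if training:
            # Training: kq queries chosen at random from the candidates.
            rng = rng if rng is not None else np.random.default_rng()
            idx = rng.choice(len(candidates), size=min(kq, len(candidates)), replace=False)
            queries = candidates[idx]
        else:
            # Eval: the key embedding plus the top kq-1 principal directions.
            queries = candidates[:kq]
        scores = queries @ self.keys.T              # (n_queries, n_memories)
        retrieved = []
        for i in scores.argmax(axis=1):             # top-1 memory per query
            if self.updates[i] not in retrieved:    # drop duplicate hits
                retrieved.append(self.updates[i])
        return retrieved

def fid_inputs(prompt: str, retrieved: list) -> list:
    # FiD-style: one encoder input per update, each prefixed to the solver prompt
    # with a position ordinal (hypothetical "(n)" format).
    return [f"({n}) {u} | {prompt}" for n, u in enumerate(retrieved, start=1)]

memory = EpisodicMemory(dim=3)
memory.add([1.0, 0.0, 0.0], "Mary went to the kitchen.")
memory.add([0.0, 1.0, 0.0], "John picked up the milk.")
memory.add([0.0, 0.0, 1.0], "The bathroom is south of the kitchen.")
inputs = fid_inputs("Where is Mary?", memory.retrieve([0.9, 0.1, 0.0], kq=1, training=False))
```

Random query selection at training time injects the retrieval noise that the results section credits with reducing overfitting; at evaluation time the query set is fixed. Note that principal directions have an arbitrary sign, so which memory a PCA query hits can flip between runs of the decomposition.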

Verifier policy

  • Use BLEURT metric to score predictions against known labels
  • BLEURT metric correlates well with accuracy of held-out datapoints
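
A thresholded verifier can be sketched as a score function plus a cutoff. Running BLEURT needs a downloaded checkpoint, so this sketch swaps in token-level F1 as a stand-in score; a BLEURT scorer would be passed as score_fn instead, with a threshold tuned to its scale. The threshold value and function names are assumptions.

```python
from collections import Counter
from typing import Callable, Tuple

def token_f1(prediction: str, reference: str) -> float:
    """Token-overlap F1 (stand-in for BLEURT in this sketch)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    common = sum((Counter(pred) & Counter(ref)).values())
    if common == 0:
        return 0.0
    precision, recall = common / len(pred), common / len(ref)
    return 2 * precision * recall / (precision + recall)

def verify(prediction: str, label: str,
           score_fn: Callable[[str, str], float] = token_f1,
           threshold: float = 0.5) -> Tuple[float, bool]:
    """Score a prediction against its known label and threshold the score."""
    score = score_fn(prediction, label)
    return score, score >= threshold
```

Returning the raw score alongside the boolean keeps the option of checking how well the score tracks held-out accuracy, as the text reports for BLEURT.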

Update policy

  • Generate updates for datapoints using “why” prompts
  • Randomly sample an update and add as target prefix during fine-tuning
  • Index updates in memory with a total memory size of 1,000 sequences
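
The update policy's data handling can be sketched as follows. The wording of the "why" prompt, the "Answer:" target format, and first-in-first-out eviction once the 1,000-sequence memory is full are all assumptions; the text only states that updates come from "why" prompts, that a randomly sampled update is used as a target prefix, and that memory holds 1,000 sequences.

```python
import random
from collections import deque

def why_prompt(question: str, answer: str) -> str:
    # Hypothetical template asking the teacher to explain a known label.
    return f"{question}\nAnswer: {answer}\nWhy?"

def build_target(answer: str, updates: list, rng: random.Random) -> str:
    # Prefix the label with one randomly sampled update (hypothetical format).
    return f"{rng.choice(updates)} Answer: {answer}"

class UpdateMemory:
    """Bounded update store; the oldest update is evicted first (assumed policy)."""

    def __init__(self, capacity: int = 1000):
        self.items = deque(maxlen=capacity)

    def add(self, update: str) -> None:
        self.items.append(update)

    def __len__(self) -> int:
        return len(self.items)

memory = UpdateMemory()
for i in range(1005):
    memory.add(f"update {i}")
target = build_target("kitchen", ["Mary went back to the kitchen."], random.Random(0))
```

A deque with maxlen evicts from the opposite end automatically, so the capacity bound needs no explicit bookkeeping.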

Source environment

  • Randomly sample updates from teacher model at each epoch
  • Baseline case: no updates sampled
  • Contrastive distillation models: updates added as prefixes of proposed target y
  • Loss over the full string is weighted using the target loss to prioritize outputting a well-formatted final answer
  • Contrastive distillation loss function used
  • Baseline model trained using unweighted language modeling loss
  • Contrastive distillation agent with memory: updates also added as memory examples
  • All models fine-tuned on source task until validation error plateaus
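
One reading of the weighted objective is a per-token negative log-likelihood over the full string (update prefix plus answer) with extra weight on the final-answer tokens. The text does not give the weighting scheme, so the sketch below assumes a single scalar weight on answer tokens; the names and the weight value are hypothetical. Setting the weight to 1.0 recovers the unweighted loss used for the baseline.

```python
import numpy as np

def weighted_lm_loss(log_probs, target_ids, answer_mask, answer_weight=2.0):
    """Token-weighted language modelling loss.

    log_probs:   (T, V) log-softmax outputs of the decoder
    target_ids:  (T,) gold token ids for the full string (update prefix + answer)
    answer_mask: (T,) True on tokens of the final answer
    ASSUMPTION: answer tokens get answer_weight, update-prefix tokens get 1.0.
    """
    target_ids = np.asarray(target_ids)
    nll = -log_probs[np.arange(len(target_ids)), target_ids]  # per-token NLL
    weights = np.where(np.asarray(answer_mask), answer_weight, 1.0)
    return float((weights * nll).sum() / weights.sum())

probs = np.array([[0.5, 0.25, 0.25],    # token 1: p(gold) = 0.5  (prefix)
                  [0.25, 0.25, 0.5]])   # token 2: p(gold) = 0.25 (answer)
loss = weighted_lm_loss(np.log(probs), [0, 0], [False, True], answer_weight=3.0)
baseline = weighted_lm_loss(np.log(probs), [0, 0], [False, True], answer_weight=1.0)
```

Here the weighted loss is (ln 2 + 3 ln 4) / 4 = 1.75 ln 2, versus 1.5 ln 2 unweighted, so the harder answer token dominates more of the objective.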

Target environment

  • Trained networks are decoded in target environment without fine-tuning
  • Generations are parsed into answers by matching regex
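
Regex-based answer parsing might look like the sketch below. The "Answer:" marker is an assumed output format; the actual pattern used in the experiments is not given.

```python
import re
from typing import Optional

# Case-insensitive "answer:" marker; capture the rest of that line, trimmed.
ANSWER_RE = re.compile(r"answer:\s*(.+?)\s*$", re.IGNORECASE | re.MULTILINE)

def parse_answer(generation: str) -> Optional[str]:
    """Return the last marked answer in a generation, or None if there is none."""
    matches = ANSWER_RE.findall(generation)
    return matches[-1] if matches else None
```

Taking the last match lets a generation carry intermediate reasoning lines before its final answer; generations with no marker parse to None and can be scored as incorrect.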

Baselines

  • Tested T5-3B without contrastive distillation
  • Experimented with in-context few-shot learning for T5-3B, but performance was not above noise

Results

  • The bathroom is south of the kitchen
  • Mary is carrying a bag of milk

Manifesting latent variables from weights to tokens improves transfer at the cost of compute

  • Contrastive distillation models achieve comparable performance to baseline on source task, but substantially outperform on target task
  • Contrastive distillation has inductive biases which better fulfill constraints of multitask transfer learning
  • This comes at a cost in training and inference time

Contrastive retrievals improve transfer

  • Transfer using distillation through memory improves over distillation which uses fine-tuning only
  • Ability to incorporate snippets of partially relevant information as explicit contrastive examples
  • Memory contents generated by a teacher model, rather than humans
  • Memory lookups on the target task have no obvious connection to the generated updates
  • Overfitting is a more serious problem in limited data regimes
  • Adding randomness to memory retrievals at training time significantly improved source task performance

Error cases

  • Generated selection-inference chains are not always autoregressively coherent
  • Sampling without filtering from the teacher oracle can lead to updates which diverge from the label
  • Filtering these updates is an important avenue for improvement
  • Cleaning the updates data yielded the best BLEURT score for transferring from bAbI to Com2Sense

Conclusion

  • Designing loss policies for decision-making agents to adapt to distribution shift is a challenging problem.
  • We present a promising approach via iteratively manifesting latent variables from weights into tokens.
  • This suggests a design tradeoff between generalizability and compute for language models used as self-learning decision-making agents.