Abstract
- Traditional approaches to RL focus on learning decision policies from episodic decisions
- Some approaches refine representations via auxiliary self-supervised losses while learning decision policies
- Supervised language model cascades can adapt to many diverse manifolds
- Transfer methods for language models require human supervision
- Proposed self-supervised loss policy called contrastive distillation
- Contrastive distillation outperforms common methods of transfer learning
- Contrastive distillation is improved through sampling from memory
Paper Content
Introduction
- Humans excel at generating contrastive self-learning data in regimes of limited data
- Humans manifest latent axes of variation as sequences
- Transfer learning and a diversity of online learning algorithms are needed to build more causal representations
- Current machine learning approaches to transfer learning are less data-efficient
- Unsupervised pre-training then online supervised transfer is the current paradigm
- Auxiliary losses require view augmentations to be specified beforehand
- Self-learning approaches have been shown to be effective at improving in-domain accuracy
- Self-learning of sequences to model distribution shifts has been relatively unexplored
- Self-supervised loss surface created via generation of auxiliary sequences with high mutual information improves transfer learning
- Tradeoff between manifesting latent axes of variation and additional compute and training time
- Contrastive distillation improves transfer learning, and sampling from memory further improves contrastive distillation
Core contributions
- Synthesize four literatures into a unifying model for self-learning
- Introduce contrastive distillation to improve transfer learning
- Sampling from episodic memory improves contrastive distillation
- Algorithm for sampling negative examples for contrastive losses
Related works
Self-learning via expert iteration
- Self-learning of large language models via CoT rollouts and majority voting has been shown to have promising results (Huang et al., 2022; Wei et al., 2022).
- STaR (Zelikman et al., 2022) uses a rationalization bootstrapping technique to sample posterior updates given data, but does not consider contrastive trajectories and episodic memory.
- Micheli et al. (2022) clusters continuous observations to form a vocabulary for a world model.
- Polu et al. (2022) use expert iteration and a curriculum of starting proof contexts to self-train GPT-f to solve IMO math problems.
Self-learning via in-context learning
- Self question-asking has been shown to improve language model performance and answer questions for robotics.
- Hand-designed actor-critic-like transition functions have been applied to text editing and story-writing.
Self-learning via contrastive methods
- Deng et al. (2022) and Li et al. (2022) show that contrasting likely sequences from paired sequence models can manifest latent axes of variation.
- Brown et al. (2020) and Radford et al. demonstrate that large language models can do few-shot learning after pre-training on natural language data.
- Chan et al. (2022) show that data distributions similar to 1st-person datasets and natural language can induce few-shot learning on Transformer architectures.
Memory
- Several architectures have introduced per-cycle “working memory” for tasks requiring short context windows
- End-to-end memory networks, Universal Transformers, and Neural Turing Machines have been used to improve performance on NLP tasks
- Most approaches do not leverage latent natural language policies and use memory which is in-scope for a single pass
- Persistent memory architectures have been used to introduce episodic memory in RL agents
- Goyal et al. (2022) augment a MuZero agent with a retrieval-only episodic memory
- BlenderBot 2.0 uses a supervised summarization scheme to persist important summaries into a long-term memory
- Our approach is similar to MERLIN and BlenderBot 2.0, but is framed in the context of transfer learning and is less hand-engineered
Self-learning decision-making process
- Formalize a discrete self-learning decision-making process
- Proposal policy generates prediction task and expected observation
- Solver policy generates selection-inference chain of adaptive length
- Verifier discriminates noisy judgement of consistency
- Update policy generates updates, memory and action
- Reduces to standard supervised cross-entropy loss
- Self-I/O constraint: output updates can be used as solver inputs or training data
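The four policies above can be sketched as one step of a loop. This is a minimal illustration only: the `Agent` container, the function signatures, and the use of a plain list as memory are all assumptions made for the example, not the paper's implementation.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class Agent:
    """Hypothetical container for the four policies named above."""
    propose: Callable[[], Tuple[str, str]]    # -> (task, expected observation)
    solve: Callable[[str, List[str]], str]    # (task, updates in scope) -> prediction
    verify: Callable[[str, str], float]       # (prediction, expected) -> consistency score
    update: Callable[[str, str, float], str]  # (task, prediction, score) -> update text
    memory: List[str] = field(default_factory=list)


def self_learning_step(agent: Agent) -> Tuple[str, float]:
    """Run proposal -> solver -> verifier -> update once."""
    task, expected = agent.propose()
    prediction = agent.solve(task, agent.memory)
    score = agent.verify(prediction, expected)
    new_update = agent.update(task, prediction, score)
    # Self-I/O constraint: the emitted update is stored so that it can be
    # fed back to the solver (or used as training data) on later steps.
    agent.memory.append(new_update)
    return prediction, score
```

Each policy is passed in as a plain callable, so a learned model and a hand-written stub are interchangeable in this sketch.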
Connection to linguistics: a (type 0) grammar that learns a posterior grammar conditioned on evidence
- Self-learning process can be interpreted as a distribution over tokens
- Self-learning process collects evidence from a continuous memory and applies contrastive updates using its own likely sequences
Constraints on update representations from multi-task self-learning
- Multi-task self-learning has additional constraints
- Constraints include task-adaptive length, source and target mutual information, information bottleneck, position-invariance, and iterative compositionality
- Method demonstrated to fulfill these constraints where fine-tuning and other adaptation approaches fail
Contrastive distillation
- Contrastive distillation is a method of conditionally sampling a sequence of tokens from a task observation.
- It is a form of rationalization which satisfies the target mutual information constraint.
- Contrastive distillation manifests the latent axes of variation of source and target distributions from weights to tokens.
- It is contrasted with direct weight updates which do not fulfill the target mutual information constraint.
Contrastive distillation through fine-tuning
- Update used as self-learning sequence
- Includes x_t and u_t in later iterations
- Referred to as contrastive distillation through fine-tuning
Contrastive distillation through memory
- Contrastive distillation is stored in memory
- Memory allows recurrent control flow
- Memory lookups can bring a combination of contrastive distillations into scope simultaneously
Bayesian contrastive distillation
- Contrastive distillation can be used to generate negative samples for contrastive losses.
- Samples are taken from task-adaptive prior distributions.
- Contrastive distillation can be used for data augmentation and compression.
Experiment: contrastive distillation improves NLP transfer learning
- Experimented with update policy applied to transfer learning
- Used teacher oracle (GPT-3) to test contrastive distillation and memory mechanics
- Future experiments will use self-generation with larger student models
Datasets
- Trained and tested on low-data configurations of bAbI and Com2Sense datasets
- bAbI tests reasoning on episodic memory, Com2Sense tests reasoning on semantic memory
- bAbI requires de novo reasoning, Com2Sense requires copying with adaptation
- Both forms of reasoning are necessary for a successful self-learning agent
Proposal policy
- Implemented proposal policy as a deterministic iterator
- Yields 5 examples for each of the 20 bAbI tasks
- Yields 100 datapoints for the singleton Com2Sense task
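A deterministic iterator of this kind might look like the sketch below. The task dictionary layout, the `(input, label)` tuple shape, and the sorted task order are assumptions for illustration; the source only states the per-task counts.

```python
from typing import Dict, Iterator, List, Tuple

Example = Tuple[str, str]  # (input text, label); hypothetical layout


def proposal_policy(
    tasks: Dict[str, List[Example]], per_task: int
) -> Iterator[Tuple[str, Example]]:
    """Yield the first `per_task` examples of every task in a fixed order."""
    for name in sorted(tasks):  # sorted names keep the order reproducible
        for example in tasks[name][:per_task]:
            yield name, example
```

With 20 bAbI tasks and `per_task=5` this yields 100 items, and with a single Com2Sense task and `per_task=100` it also yields 100, matching the counts listed above.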
Solver policy
- Implemented solver policy as single-hop solver
- Fine-tuned from T5-3B (Raffel et al., 2020)
- Various update policies used
Memory policy
- Memory is implemented using FAISS
- Memories are split into context embedding and updates
- Updates are prefixed with a position ordinal
- Updates are constrained to be less than 200 tokens long
- Queries are sampled by combining a key embedding with PCA decomposition
- During training, k_q random queries are selected
- During eval, the top k_q-1 PCA queries + the key embedding are selected
- Updates are injected FiD-style as a prefix of the solver model prompt
- Symbolic updates are re-encoded using T5 embeddings at every epoch
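The retrieval flow above can be sketched in plain Python. Several things here are assumptions: a brute-force inner-product search stands in for FAISS, the PCA queries are taken as precomputed vectors ordered by importance, the training pool is assumed to be the key plus those PCA queries, whitespace tokens stand in for model tokens, and the prompt is built by simple string concatenation rather than true FiD encoding.

```python
import random
from typing import List, Sequence

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def select_queries(key: Vector, pca_queries: Sequence[Vector], k_q: int,
                   training: bool, rng: random.Random) -> List[Vector]:
    """Pick k_q queries: random during training, key + top PCA queries at eval."""
    if training:
        return rng.sample([key, *pca_queries], k_q)
    return [key, *pca_queries[: k_q - 1]]


def retrieve(queries: Sequence[Vector], memory_keys: Sequence[Vector],
             memory_updates: Sequence[str], top_k: int = 1) -> List[str]:
    """Brute-force inner-product lookup (a stand-in for a FAISS index)."""
    hits: List[str] = []
    for query in queries:
        ranked = sorted(range(len(memory_keys)),
                        key=lambda i: dot(query, memory_keys[i]), reverse=True)
        for i in ranked[:top_k]:
            if memory_updates[i] not in hits:
                hits.append(memory_updates[i])
    return hits


def build_prompt(question: str, updates: Sequence[str], max_tokens: int = 200) -> str:
    """Prefix the question with ordinal-numbered updates, each under max_tokens."""
    lines = []
    for position, update in enumerate(updates, start=1):
        clipped = " ".join(update.split()[: max_tokens - 1])
        lines.append(f"{position}. {clipped}")
    return "\n".join(lines + [question])
```

Selecting random queries only at training time mirrors the listed behaviour; the exact way the key embedding and PCA decomposition are combined is not specified in the summary, so it is left outside this sketch.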
Verifier policy
- Use BLEURT metric to score predictions against known labels
- BLEURT metric correlates well with accuracy of held-out datapoints
Update policy
- Generate updates for datapoints using “why” prompts
- Randomly sample an update and add as target prefix during fine-tuning
- Index updates in memory with a total memory size of 1,000 sequences
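As a rough sketch of these three steps: the prompt template, the `Answer:` target format, and the first-in-first-out eviction used to cap memory at 1,000 sequences are all hypothetical choices for illustration, since the summary does not specify them.

```python
import random
from collections import deque
from typing import Deque, Iterable, Sequence

MEMORY_SIZE = 1000  # total memory size stated above


def why_prompt(question: str, answer: str) -> str:
    """Hypothetical 'why' prompt sent to the teacher model."""
    return f"{question}\nAnswer: {answer}\nWhy?"


def make_target(updates: Sequence[str], answer: str, rng: random.Random) -> str:
    """Randomly pick one update and place it before the answer as a target prefix."""
    return f"{rng.choice(updates)} Answer: {answer}"


def index_updates(memory: Deque[str], updates: Iterable[str]) -> None:
    """Append updates; a bounded deque drops the oldest entries (assumed policy)."""
    memory.extend(updates)


memory: Deque[str] = deque(maxlen=MEMORY_SIZE)
```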
Source environment
- Randomly sample updates from teacher model at each epoch
- Baseline case: no updates sampled
- Contrastive distillation models: updates added as prefixes of proposed target y
- Full string weighted using target loss to prioritize outputting well-formatted final answer
- Contrastive distillation loss function used
- Baseline model trained using unweighted language modeling loss
- Contrastive distillation agent with memory: updates also added as memory examples
- All models fine-tuned on source task until validation error plateaus
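One way to picture the weighting of the full target string is a token-level negative log-likelihood in which final-answer tokens count more than update-prefix tokens. This is an illustrative assumption, not the paper's contrastive distillation loss: the weights and the per-token answer mask are hypothetical.

```python
import math
from typing import Sequence


def weighted_nll(token_probs: Sequence[float], is_answer_token: Sequence[bool],
                 answer_weight: float = 2.0, update_weight: float = 1.0) -> float:
    """Weighted mean of -log p over target tokens.

    token_probs[i] is the model probability of the i-th target token;
    is_answer_token[i] marks tokens that belong to the final answer.
    """
    weights = [answer_weight if flag else update_weight for flag in is_answer_token]
    total = sum(w * -math.log(p) for w, p in zip(weights, token_probs))
    return total / sum(weights)
```

Setting `answer_weight` equal to `update_weight` recovers the unweighted language modeling loss used for the baseline.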
Target environment
- Trained networks are decoded in target environment without fine-tuning
- Generations are parsed into answers by matching regex
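Regex-based answer parsing could be as simple as the sketch below. The `Answer:` marker is an assumed output format chosen for illustration; the actual pattern used in the paper is not given in this summary.

```python
import re
from typing import Optional

# Hypothetical format: the final answer follows an "Answer:" marker on its own line.
ANSWER_PATTERN = re.compile(r"answer\s*:\s*(.+?)\s*$", re.IGNORECASE | re.MULTILINE)


def parse_answer(generation: str) -> Optional[str]:
    """Return the text after the last 'Answer:' marker, or None if absent."""
    matches = ANSWER_PATTERN.findall(generation)
    return matches[-1] if matches else None
```

Taking the last match means that an update prefix which itself mentions an answer does not shadow the final one.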
Baselines
- Tested T5-3B without contrastive distillation
- Experimented with in-context few-shot learning for T5-3B, but performance was not above noise
Results
- The bathroom is south of the kitchen
- Mary is carrying a bag of milk
Manifesting latent variables from weights to tokens improves transfer at the cost of compute
- Contrastive distillation models achieve comparable performance to baseline on source task, but substantially outperform on target task
- Contrastive distillation has inductive biases which better fulfill constraints of multitask transfer learning
- Comes with cost in terms of training and inference time
Contrastive retrievals improve transfer
- Transfer using distillation through memory improves over distillation which uses fine-tuning only
- Ability to incorporate snippets of partially relevant information as explicit contrastive examples
- Memory contents generated by a teacher model, rather than humans
- Memory lookups on the target task have no obvious connection to the generated updates
- Overfitting is a more serious problem in limited data regimes
- Adding randomness to memory retrievals at training time significantly improved source task performance
Error cases
- Generated selection-inference chains are not always autoregressively coherent
- Sampling without filtering from the teacher oracle can lead to updates which diverge from the label
- Filtering these updates is an important avenue for improvement
- Cleaning the updates data yielded the best BLEURT score for transferring bAbI -> Com2Sense
Conclusion
- Designing loss policies for decision-making agents to adapt to distribution shift is a challenging problem.
- We present a promising approach via iteratively manifesting latent variables from weights into tokens.
- This suggests a design tradeoff between generalizability and compute for language models used as self-learning decision-making agents.