Link to paper

The full paper is available here.

You can also find the paper on PapersWithCode here.

Abstract

  • Proposed a model called Meta-World Conditional Neural Processes (MW-CNP)
  • Model enables an agent to sample from its own “hallucination”
  • Aim is to reduce the agent’s interaction with the target environment at test time
  • Obtain a latent representation of the transition dynamics from a single rollout from the test environment
  • Few-shot learning by interacting with the “hallucination” generated by the meta-world model
  • Agent can adapt to an unseen target environment with fewer samples than baselines
  • Agent does not have access to the task parameters throughout training and testing
  • MW-CNP is trained on offline interaction data logged during meta-training

Paper Content

Introduction

  • Fast adaptation under uncertainty has various applications in the real world
  • Robustness in these settings requires a diverse set of skills
  • World model generation is an integral part of open-ended learning
  • Meta-reinforcement learning is a promising direction for increasing robustness
  • Propose generating world models from the agent’s experience to improve sample efficiency
  • Meta-learning provides a framework for learning new tasks more efficiently from limited experience
  • Meta-RL framework assumes tasks share a common structure but differ in transition dynamics
  • Context-based meta-RL methods assume each task can be represented by a low-dimensional context variable
  • Memory-augmented models and expert demonstrations are used to learn a contextual representation
  • MAML is a gradient-based meta-learning algorithm
  • NORML is an extension of MAML-RL for settings where the environment dynamics change
  • CNP model is used to represent a family of functions using Bayesian Inference

Preliminaries

Conditional neural processes

  • Parameter-sharing encoders and a decoder (query network) make up the architecture of CNPs
  • Random input and true output pairs are sampled from a function
  • Encoder networks turn each pair into a latent representation
  • The representations are averaged so that the result is invariant to the order of the pairs
  • Target input query is concatenated with the average representation and fed to the query network
  • Query network outputs predicted mean and standard deviation for the queried input
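The steps above can be sketched as a single forward pass. This is a minimal NumPy illustration with untrained random weights, not the paper's implementation: the layer sizes, the softplus used to keep the standard deviation positive, and the names `init_mlp`, `mlp`, and `cnp_forward` are all assumptions made for the example.

```python
import numpy as np

def init_mlp(rng, sizes):
    """Create (weight, bias) pairs for a small MLP (hypothetical helper)."""
    return [(rng.normal(0.0, 0.1, size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    """Apply the MLP with ReLU on every layer except the last."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x

def cnp_forward(enc, dec, ctx_x, ctx_y, query_x):
    """Predict mean and std at query_x from context pairs (ctx_x, ctx_y)."""
    # Shared encoder: every (x, y) pair goes through the same weights.
    r_i = mlp(enc, np.concatenate([ctx_x, ctx_y], axis=-1))
    # Averaging makes the summary independent of the context ordering.
    r = r_i.mean(axis=0)
    # Concatenate the averaged representation with each query input.
    r_tiled = np.broadcast_to(r, (query_x.shape[0], r.shape[0]))
    out = mlp(dec, np.concatenate([query_x, r_tiled], axis=-1))
    mean, raw_std = np.split(out, 2, axis=-1)
    # Softplus (an assumption here) keeps the standard deviation positive.
    std = np.logaddexp(0.0, raw_std) + 1e-6
    return mean, std

rng = np.random.default_rng(0)
x_dim, y_dim, r_dim = 1, 1, 16
enc = init_mlp(rng, [x_dim + y_dim, 32, r_dim])
dec = init_mlp(rng, [x_dim + r_dim, 32, 2 * y_dim])

ctx_x = rng.uniform(-1, 1, size=(5, x_dim))   # random inputs
ctx_y = np.sin(3 * ctx_x)                     # true outputs of a toy function
query_x = np.linspace(-1, 1, 4).reshape(-1, 1)
mean, std = cnp_forward(enc, dec, ctx_x, ctx_y, query_x)
```

Because the context representations are averaged, shuffling the context pairs leaves the prediction unchanged, which is the invariance the list refers to.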

No-reward meta learning

  • NORML learns a pseudo-advantage function to guide meta-policy adaptation
  • The pseudo-advantage function is used to compute task-specific parameters from a set of state transitions without reward signal
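A reward-free adaptation step of this kind can be sketched as a policy-gradient update in which a learned score replaces the usual reward-based advantage. The sketch below is an illustration under strong assumptions, not NORML itself: the linear Gaussian policy, the linear `pseudo_advantage`, the step size, and all function names are hypothetical, and any further learned terms of the actual NORML update are omitted.

```python
import numpy as np

def grad_log_gaussian_policy(W, s, a, sigma=0.5):
    """Gradient of log N(a; W @ s, sigma^2 I) with respect to W."""
    return np.outer((a - W @ s) / sigma**2, s)

def pseudo_advantage(psi, s, a, s_next):
    """Hypothetical learned scalar score of a transition; no reward is used."""
    return float(psi @ np.concatenate([s, a, s_next]))

def adapt(W, psi, transitions, lr=0.1):
    """One adaptation step computed from (s, a, s_next) tuples only."""
    step = np.zeros_like(W)
    for s, a, s_next in transitions:
        weight = pseudo_advantage(psi, s, a, s_next)
        step += weight * grad_log_gaussian_policy(W, s, a)
    return W + lr * step / len(transitions)

rng = np.random.default_rng(0)
W = np.zeros((2, 2))            # meta-policy parameters (toy)
psi = rng.normal(size=6)        # pseudo-advantage parameters (toy, untrained)
transitions = [(rng.normal(size=2), rng.normal(size=2), rng.normal(size=2))
               for _ in range(10)]
W_task = adapt(W, psi, transitions)  # task-specific parameters
```

The point of the sketch is only that the update consumes state transitions and never a reward signal.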

Proposed method: meta-world conditional neural process (MW-CNP)

  • Problem setup: few-shot learning, where the goal is to adapt quickly to an unseen target task using few labeled samples from the target environment
  • Meta-World Conditional Neural Processes (MW-CNP): reduce the number of samples required from the target environment by generating world models from fewer samples and using them to obtain inexpensive rollouts for finetuning at test time
  • Offline meta-learning: transitions for each task are stored as a set of observations, without the task parameter
  • MW-CNP training: uses unlabeled batches of Markov Decision Process (MDP) tuples collected during online meta-learning
  • MW-CNP inference: obtains a latent representation of the hidden environment transition function, concatenates it with the query state-action pair [s_q, a_q] to predict the distribution parameters of the next state, and uses these predictions to generate rollouts
  • Finetuning: experiences from the generated world model are fed to the learned advantage function; the meta-policy is then finetuned for fast adaptation to the target task using the estimated advantage values and the combined set of MW-CNP-generated rollouts and the single target-environment rollout
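The rollout-generation step can be illustrated with a toy loop that samples next states from a model's predicted distribution and then mixes the generated rollouts with the single real one. Everything here is a stand-in chosen for the example: `toy_predict` replaces the trained MW-CNP with a simple drift estimate, and `toy_policy`, the horizon, the noise level, and the number of generated rollouts are arbitrary assumptions.

```python
import numpy as np

def hallucinate_rollout(predict, policy, context, s0, horizon, rng):
    """Roll the policy forward inside the learned model, not the real env."""
    s, rollout = s0, []
    for _ in range(horizon):
        a = policy(s)
        mean, std = predict(context, s, a)   # next-state distribution parameters
        s_next = rng.normal(mean, std)       # sample the hallucinated next state
        rollout.append((s, a, s_next))
        s = s_next
    return rollout

def toy_predict(context, s, a):
    """Stand-in for the trained model: infer a constant drift from the context."""
    drift = np.mean([sn - (sc + ac) for sc, ac, sn in context], axis=0)
    return s + a + drift, np.full_like(s, 0.01)

def toy_policy(s, goal=np.array([1.0, 0.0])):
    """Hypothetical policy that moves a fraction of the way to the goal."""
    return 0.2 * (goal - s)

rng = np.random.default_rng(0)

# One real rollout from a toy target environment with a hidden drift.
true_drift = np.array([0.0, -0.05])
s, real_rollout = np.zeros(2), []
for _ in range(10):
    a = toy_policy(s)
    s_next = s + a + true_drift
    real_rollout.append((s, a, s_next))
    s = s_next

# The real rollout serves as context; the rest of the data is generated.
generated = [hallucinate_rollout(toy_predict, toy_policy, real_rollout,
                                 np.zeros(2), 10, rng) for _ in range(4)]
mixed = [real_rollout] + generated   # combined set used for finetuning
```

In this sketch only the ten real transitions touch the target environment; every other transition in `mixed` comes from the model.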

Experiments

  • MW-CNP requires less interaction with the environment than NORML
  • MW-CNP can adapt quickly to unseen tasks

2D point agent with unknown artificial force field

  • Goal of the point agent is to move to position [x=1,y=0] on a 2D plane
  • Meta-RL setting used with same reward function across multiple tasks
  • Different tasks created by generating different artificial force fields
  • 5000 tasks defined over [-π, π] interval
  • Agent initially trained across distribution of 5000 tasks
  • Tested in unseen target task
  • Oracle agent uses 25 actual rollouts from target environment
  • NORML and MW-CNP use 1 rollout for fine-tuning meta-policy
  • MW-CNP outperforms NORML when both use the same number of actual rollouts
  • MW-CNP finetunes on 25 mixed rollouts (the single actual rollout combined with generated ones), matching the number of rollouts used by the Oracle
  • Samples generated from MW-CNP can be used for finetuning meta-policy
  • Results not symmetric across meta-tasks due to gradient bias
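A task distribution of this shape can be sketched as follows. The summary does not specify how the force field acts on the agent, so the dynamics below are purely an assumption for illustration: the task parameter is treated as an angle that rotates the commanded move, tasks are drawn uniformly from [-π, π], and the reward is the negative distance to the goal. The names `sample_tasks`, `step`, and `reward` are hypothetical.

```python
import numpy as np

GOAL = np.array([1.0, 0.0])  # target position [x=1, y=0]

def sample_tasks(n_tasks, rng):
    """Draw one hidden task parameter per task from [-pi, pi]."""
    return rng.uniform(-np.pi, np.pi, size=n_tasks)

def step(pos, action, task_angle):
    """Assumed dynamics: the force field rotates the commanded move."""
    c, s = np.cos(task_angle), np.sin(task_angle)
    rot = np.array([[c, -s], [s, c]])
    return pos + rot @ action

def reward(pos):
    """Same reward for every task: negative distance to the goal."""
    return -np.linalg.norm(pos - GOAL)

rng = np.random.default_rng(0)
tasks = sample_tasks(5000, rng)          # meta-training task distribution
pos = step(np.zeros(2), np.array([0.1, 0.0]), tasks[0])
r = reward(pos)
```

The agent never observes `task_angle` directly; it only sees the resulting transitions, which is what makes the task parameter hidden.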

Walker-2D randomized agent dynamics parameters

  • Evaluated MW-CNP in Walker-2D-Rand-Params environment
  • Dynamics parameters are sampled uniformly from a given range
  • 40 tasks sampled for meta-training, 100 for meta-testing
  • Figure 9 shows post-update reward in meta-testing
  • Table 2 shows increased sample efficiency and meta-test adaptation performance

Conclusion

  • MW-CNP framework can be used to collect meaningful hallucinated rollouts
  • MW-CNP performance matched the Oracle and outperformed it in complex tasks
  • MW-CNP model generated samples from fewer MDP tuples, increasing sample efficiency
  • Using generated data for meta-updates increases sample efficiency and can yield superior performance
  • Extension of work to high-dimensional sensorimotor spaces is an interesting research direction