Decoupling Representation Learning from Reinforcement Learning


Adam Stooke
UC Berkeley
Kimin Lee
UC Berkeley
Pieter Abbeel
UC Berkeley
Misha Laskin
UC Berkeley




Summary

In an effort to overcome limitations of reward-driven feature learning in deep reinforcement learning (RL) from images, we propose decoupling representation learning from policy learning. To this end, we introduce a new unsupervised learning (UL) task, called Augmented Temporal Contrast (ATC), which trains a convolutional encoder by pairing augmented versions of observations separated by a short time difference, using a contrastive loss.
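The contrastive pairing described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it assumes an InfoNCE-style loss with a bilinear similarity matrix `W` (a CURL-style choice, used here as an assumption). Row `i` of the positives, the code of the observation `k` steps after anchor `i`, is that anchor's positive, and every other item in the batch serves as a negative. The functions `info_nce_loss` and `sample_temporal_pairs` are hypothetical helpers. The sketch omits the augmentation step, the encoder architecture, episode-boundary handling, and any additional networks the full method uses.

```python
import numpy as np

def info_nce_loss(anchors, positives, W):
    """Contrastive (InfoNCE) loss with a bilinear similarity.

    anchors:   (B, D) codes of augmented observations o_t
    positives: (B, D) codes of augmented observations o_{t+k}
    W:         (D, D) bilinear weight matrix (assumed learnable)
    """
    logits = anchors @ W @ positives.T                     # (B, B) similarity scores
    logits = logits - logits.max(axis=1, keepdims=True)    # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy where the matching time pair (the diagonal) is the target class.
    return -np.mean(np.diag(log_probs))

def sample_temporal_pairs(obs, batch_size, k, rng):
    """Sample (o_t, o_{t+k}) pairs from one trajectory.

    Hypothetical helper: assumes `obs` is a single episode, so it
    ignores episode boundaries.
    """
    t = rng.integers(0, len(obs) - k, size=batch_size)  # upper bound is exclusive
    return obs[t], obs[t + k]

# Illustrative usage with random stand-ins for observations and the encoder.
rng = np.random.default_rng(0)
obs = rng.normal(size=(100, 8))   # stand-in for flattened observations
proj = rng.normal(size=(8, 4))    # hypothetical linear "encoder"
W = np.eye(4)
a, p = sample_temporal_pairs(obs, batch_size=16, k=3, rng=rng)
loss = info_nce_loss(a @ proj, p @ proj, W)
```

Minimizing this loss pushes each anchor's code to score higher against its own future observation than against the other observations in the batch. This is the sense in which the encoder learns from temporal contrast without any reward signal.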

In online RL experiments, we show that training the encoder exclusively using ATC matches or outperforms end-to-end RL in most environments. Additionally, we benchmark several leading UL algorithms by pre-training encoders on expert demonstrations and using them, with weights frozen, in RL agents; we find that agents using ATC-trained encoders outperform all others.

We also train multi-task encoders on data from multiple environments and show generalization to different downstream RL tasks. Finally, we ablate components of ATC and introduce a new data augmentation to enable replay of (compressed) latent images from pre-trained encoders when RL requires augmentation. Our experiments span visually diverse RL benchmarks in DeepMind Control, DeepMind Lab, and Atari.
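The latent augmentation mentioned above (latent random shift) can be sketched under one assumption: it mirrors the familiar pixel-space random shift, but acts on the encoder's spatial feature maps. Each latent image is edge-padded and a random window of the original size is cropped back out. This is an illustrative sketch, not the authors' code. The helper name `random_shift`, the padding size of 1, and the latent shape `(32, 64, 7, 7)` in the usage line are hypothetical values chosen for the example.

```python
import numpy as np

def random_shift(x, pad, rng):
    """Randomly shift a batch of (latent) images by up to `pad` pixels.

    x: (B, C, H, W) array. Each image is edge-padded by `pad` pixels on
    every side, then a random H x W window is cropped back out, so content
    moves by at most `pad` pixels vertically and horizontally.
    """
    B, C, H, W = x.shape
    # mode="edge" replicates border values instead of inserting zeros.
    padded = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    out = np.empty_like(x)
    for i in range(B):
        # Independent random offset per image in the batch.
        dy, dx = rng.integers(0, 2 * pad + 1, size=2)
        out[i] = padded[i, :, dy:dy + H, dx:dx + W]
    return out

# Illustrative usage: augment a replayed batch of stored latent images.
rng = np.random.default_rng(0)
latents = rng.normal(size=(32, 64, 7, 7))  # hypothetical latent shape
shifted = random_shift(latents, pad=1, rng=rng)
```

Because the frozen encoder runs only once per observation, a replay buffer can store these small latent images instead of raw frames. The augmentation is then applied cheaply at sampling time.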


Augmented Temporal Contrast

Our method, Augmented Temporal Contrast (ATC), trains the CNN encoder with an unsupervised objective and then learns RL policies on top of the extracted features. The unsupervised objective is a temporal contrast between observations. We also introduce a novel augmentation on the encoder's latent images, latent random shift, to regularize the latent features that the RL algorithm learns from. The full architecture is summarized in the image below:



Online RL Results

We show that ATC matches or outperforms state-of-the-art end-to-end RL on all DMControl & DMLab environments and most Atari environments. In Atari, we also show that initializing the CNN encoder with unsupervised pretraining on random data improves performance.



Benchmarking UL for RL

We also benchmarked a variety of unsupervised objectives for learning features. For these experiments, we use a setup similar to how unsupervised representations are evaluated in computer vision. We collect a dataset of expert demonstrations and learn features from this data. We then freeze the CNN encoder and train RL on top of the learned features for a downstream task. This mirrors how state-of-the-art unsupervised learning methods in vision (SimCLR, MoCo, BYOL) evaluate the quality of their representations. We extensively benchmark ATC features against those learned with the RL objective, contrastive learning (CURL), a variational autoencoder, and inverse dynamics, among others. We find that agents using ATC-trained encoders outperform all others across all three benchmarks (DMControl, DMLab, and Atari).



ATC Features for Multi-task Learning

We also show that unsupervised pre-training with ATC yields features that are useful for multi-task learning. In these experiments, we collect a dataset of demonstrations across 4 DMControl tasks and train a single encoder on it. We then learn RL policies on top of the frozen encoder features for 8 DMControl tasks. Surprisingly, features learned by ATC enable efficient learning in both the training environments and the test environments, even though the test environments were not included during encoder training.



For more information, check out our paper and codebase.


