# Temporal Predictive Coding For Model-Based Planning In Latent Space

This is an implementation of the paper Temporal Predictive Coding For Model-Based Planning In Latent Space in TensorFlow 2. We propose TPC, an information-theoretic representation-learning framework for reinforcement learning from high-dimensional observations. TPC employs temporal predictive coding to encode the elements of the environment that are predictable across time. It learns these representations jointly with a recurrent state-space model, which allows planning directly in latent space. Experiments show that TPC outperforms existing methods on several DMControl tasks.
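To give a feel for what "encoding elements that can be predicted across time" means in practice, here is a minimal contrastive (InfoNCE-style) sketch in NumPy: a latent predicted from the past should match its own future latent, with latents from other time steps serving as negatives. The function name `temporal_nce_loss` and the dot-product similarity are illustrative assumptions, not the paper's exact objective or this repo's TensorFlow code.

```python
import numpy as np

def temporal_nce_loss(pred_latents, true_latents):
    """Contrastive predictive loss over a batch of time steps.

    pred_latents[i] is the latent predicted from the past for step i;
    true_latents[i] is the encoded observation at step i. Each row of
    the similarity matrix is treated as a classification problem whose
    positive class is the matching (diagonal) entry.
    """
    # logits[i, j] = similarity between prediction i and latent j
    logits = pred_latents @ true_latents.T
    # Row-wise log-softmax (shifted by the row max for numerical stability)
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Negative log-probability of the positive (diagonal) pairs
    return -np.mean(np.diag(log_probs))
```

When predictions align with their own targets, the loss approaches zero; for unrelated predictions it sits near the uniform baseline of log(batch size).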

Details of the model architecture and experimental results can be found in the following paper:

```
@InProceedings{pmlr-v139-nguyen21a,
  title     = {Temporal Predictive Coding For Model-Based Planning In Latent Space},
  author    = {Nguyen, Tung and Shu, Rui and Pham, Tuan and Bui, Hung and Ermon, Stefano},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  year      = {2021},
  volume    = {139},
  series    = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
}
```

Please cite our paper whenever this repository is used to produce published results or is incorporated into other software.

## Installing

First, clone the repository:

```
git clone https://github.com/tung-nd/TPC-tensorflow.git
```

Then install the dependencies listed in `tpc.yml` and activate the environment:

```
conda env create -f tpc.yml
conda activate tpc
```

## Training

```
python tpc.py --task dmc_cartpole_swingup --logdir ./logdir/standard/cartpole_swingup/TPC/1 --img_source_type none --seed 1
```

The above command trains the TPC agent on the cartpole swingup task in the standard setting (i.e., no background distractions) and saves the results to `./logdir/standard/cartpole_swingup/TPC/1`. To run the agent in the natural background setting, first download the source videos, then set `--img_source_type video`. If `--random_bg` is also set to `True`, the background is randomized at every time step.

## Download background videos

There are two sources of background videos. The first is the kinetics400 dataset, which can be downloaded by running `python download_videos.py`. The second is simplistic random backgrounds, which can be generated by running `python gen_rain.py`.
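To make the "simplistic random background" setting concrete, the toy generator below produces an RGB noise frame that can be resampled at every time step. The function name, frame size, and noise distribution are illustrative assumptions; this is not the actual code in `gen_rain.py`, which generates rain-like backgrounds.

```python
import numpy as np

def random_background(height=64, width=64, rng=None):
    """Return one random RGB background frame (uint8, values in 0..255).

    Resampling this every step mimics the --random_bg behavior, where
    the background changes at each time step. Illustrative only.
    """
    rng = np.random.default_rng() if rng is None else rng
    return rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)
```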

## Plotting

Visualize losses and videos of the agent during training:

```
tensorboard --logdir ./logdir/standard/cartpole_swingup/TPC/1 --port 6010
```

Generate plots:

```
python plotting.py --indir ./logdir/standard --outdir ./plots/standard --xaxis step --yaxis test/return --bins 3e4 --cols 3
```

## Acknowledgement

This codebase is largely based on the TensorFlow 2 implementation of Dreamer. In addition, many thanks to Amy Zhang for helping with the implementation of the natural background setting.