This repository contains the implementation of the PTR algorithm described in the paper: Pre-Training for Robots: Leveraging Diverse Multitask Data via Offline Reinforcement Learning.
data_preprocessing/generate_numpy.py: contains the code to preprocess the data.
jaxrl2/agents/cql_encoder_sep_parallel: contains our parallelized implementation of the PTR algorithm. This code builds on the ideas introduced in the JAX TPU Colab. Note that the parallelization is adaptive and also works on a single GPU or CPU; see the sketch after this list.
jaxrl2/utils: contains the code for the environment and dataset wrappers.
examples/configs: contains the config files for the datasets.
examples/scripts: contains the scripts used to run the experiments.
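
The parallelized agent follows the standard JAX data-parallel pattern. Below is a minimal sketch of that pattern, not the repository's actual code: the update step is pmapped over however many local devices are present, so the same code runs on a TPU slice, a single GPU, or the CPU. The loss and all names here are placeholders rather than the PTR/CQL losses.

```python
import functools
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a single CPU or GPU

def shard(batch):
    """Split the leading batch axis across local devices."""
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_devices, -1) + x.shape[1:]), batch)

@functools.partial(jax.pmap, axis_name='devices')
def update_step(params, obs, target):
    # Toy quadratic loss standing in for the CQL critic/actor losses.
    def loss_fn(p):
        return jnp.mean((obs @ p - target) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average gradients across replicas so the parameter copies stay in sync.
    grads = jax.lax.pmean(grads, axis_name='devices')
    return params - 1e-3 * grads, loss

# Replicate parameters across devices and run one sharded update.
params = jnp.zeros((8,))
replicated = jax.device_put_replicated(params, jax.local_devices())
batch = {'obs': jnp.ones((32, 8)), 'target': jnp.zeros((32,))}
sharded = shard(batch)
new_params, loss = update_step(replicated, sharded['obs'], sharded['target'])
```

Because pmap maps over the leading axis, the same script adapts automatically: with one device the batch is simply reshaped to (1, batch_size, ...), and with eight devices it is split into eight shards.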
The jax_tpu.yml file is located in the root directory of this repository.
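Assuming jax_tpu.yml is a conda environment specification (inferred from the name and extension; check the file before use), the environment can typically be created with `conda env create -f jax_tpu.yml`.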
Public Datasets
You can find the datasets that were used for this paper here.
Acknowledgements
Our repository is based on the JAX RL2 repository, and we thank the authors for making their code public. Our experiments used an earlier, private version of that repository; we have since made the necessary changes for compatibility with its latest version.