Constrained Policy Optimization with JAX

Constrained Policy Optimization (CPO) is a safe reinforcement learning algorithm that enforces safety by solving constrained Markov decision processes. Our implementation is a port of the original OpenAI implementation to JAX.

Install

First, make sure you have Python 3.10.12 installed.

Using Poetry

poetry install

Check out the optional installation groups in pyproject.toml for additional functionality.
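For instance, a group would be installed like this (the group name dev here is an assumption; check pyproject.toml for the actual group names):

poetry install --with dev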

Without Poetry

You have two options: clone the repository (for example, for local development and hacking), or install it directly from GitHub.

  1. Clone: git clone https://github.com/lasgroup/jax-cpo.git, then cd jax-cpo and pip install -e .; or
  2. pip install git+https://git@github.com/lasgroup/jax-cpo

Usage

Via Trainer class

This is the easiest entry point for running experiments; see the usage example in the repository.

With your own training loop

If you just want to use our implementation with a different training/evaluation setup, you can use the CPO class directly. The only required interface is the __call__(observation: np.ndarray, train: bool) -> np.ndarray function, which does the following (see the sketch after this list):

  • Observes the state (provided by the environment) and puts it in an episodic buffer for the next policy update.
  • At each time step, uses the current policy to return an action.
  • Whenever the train flag is true and the buffer is full, triggers a policy update.
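
As a minimal sketch of such a custom loop (the import path, environment choice, and CPO constructor arguments are assumptions; only the __call__ interface above is documented):

import gymnasium as gym  # any environment with the standard reset/step API works
import numpy as np
from jax_cpo import CPO  # import path assumed from the package name

env = gym.make("Pendulum-v1")  # hypothetical environment choice
agent = CPO(...)  # constructor arguments depend on configs.yaml

for episode in range(10):
    observation, _ = env.reset()
    done = False
    while not done:
        # The agent stores the observation in its episodic buffer and
        # returns an action from the current policy; with train=True it
        # triggers a policy update internally once the buffer is full.
        action = agent(np.asarray(observation), train=True)
        observation, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated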

Consult configs.yaml for hyper-parameters.
