Skip to content

Commit

Permalink
Release v1.6.0 and bug fix for TRPO (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Jul 12, 2022
1 parent db4c011 commit 087951d
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 11 deletions.
6 changes: 3 additions & 3 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ dependencies:
- cpuonly=1.0=0
- pip=21.1
- python=3.7
- pytorch=1.8.1=py3.7_cpu_0
- pytorch=1.11=py3.7_cpu_0
- pip:
- gym>=0.17.2
- gym==0.21
- cloudpickle
- opencv-python-headless
- pandas
Expand All @@ -17,5 +17,5 @@ dependencies:
- sphinx_autodoc_typehints
- stable-baselines3>=1.3.0
- sphinx>=4.2
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0
- sphinx_copybutton
13 changes: 12 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
except ImportError:
enable_spell_check = False

# Optional dependency: the copy button is only switched on when the
# ``sphinx_copybutton`` package can be imported in the docs environment.
try:
    import sphinx_copybutton  # noqa: F401
except ImportError:
    enable_copy_button = False
else:
    enable_copy_button = True

# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))

Expand Down Expand Up @@ -51,7 +59,7 @@ def __getattr__(cls, name):
# -- Project information -----------------------------------------------------

project = "Stable Baselines3 - Contrib"
copyright = "2020, Stable Baselines3"
copyright = "2022, Stable Baselines3"
author = "Stable Baselines3 Contributors"

# The short X.Y version
Expand Down Expand Up @@ -83,6 +91,9 @@ def __getattr__(cls, name):
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")

if enable_copy_button:
extensions.append("sphinx_copybutton")

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

Expand Down
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 1.5.1a9 (WIP)
Release 1.6.0 (2022-07-11)
-------------------------------

**Add RecurrentPPO (aka PPO LSTM)**
Expand All @@ -25,8 +25,9 @@ New Features:
- Added ``RecurrentPPO`` (aka PPO LSTM)

Bug Fixes:
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
^^^^^^^^^^
- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
- Fixed a bug in ``TRPO`` where KL divergence was not implemented for ``MultiDiscrete`` spaces

Deprecations:
^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions sb3_contrib/trpo/trpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
from stable_baselines3.common.utils import explained_variance
from torch import nn
from torch.distributions import kl_divergence
from torch.nn import functional as F

from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
Expand Down Expand Up @@ -279,7 +279,7 @@ def train(self) -> None:
policy_objective = (advantages * ratio).mean()

# KL divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
kl_div = kl_divergence(distribution, old_distribution).mean()

# Surrogate & KL gradient
self.policy.optimizer.zero_grad()
Expand Down Expand Up @@ -332,7 +332,7 @@ def train(self) -> None:
new_policy_objective = (advantages * ratio).mean()

# New KL-divergence
kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
kl_div = kl_divergence(distribution, old_distribution).mean()

# Constraint criteria:
# we need to improve the surrogate policy objective
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.5.1a9
1.6.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.5.1a7",
"stable_baselines3>=1.6.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
38 changes: 38 additions & 0 deletions tests/test_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest
from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

from sb3_contrib import QRDQN, TRPO

DIM = 4


@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Train ``model_class`` on an identity environment and check that it solves it.

    The model is trained for a small, fixed budget of timesteps, evaluated with
    ``evaluate_policy`` against a mean-reward threshold, and finally the shape of
    a predicted action is compared with the shape of the observation.

    :param model_class: Algorithm class under test (``QRDQN`` or ``TRPO``).
    :param env: Identity environment instance (Discrete, MultiDiscrete or MultiBinary).
    """
    # QRDQN only supports Discrete action spaces: nothing to check for the other envs
    if model_class is QRDQN and isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
        return

    vec_env = DummyVecEnv([lambda: env])
    # Training budget. Deliberately NOT named ``n_steps`` so that it cannot be
    # confused with the TRPO ``n_steps`` (rollout length) keyword argument below.
    total_timesteps = 1500

    kwargs = {}
    if model_class is QRDQN:
        kwargs = dict(
            learning_starts=0,
            policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
            target_update_interval=10,
            train_freq=2,
            batch_size=256,
        )
    elif model_class is TRPO:
        # The branch condition used to compare the timestep budget (an int) with
        # the TRPO class, so these settings were silently never applied.
        kwargs = dict(n_steps=256, cg_max_steps=5)

    model = model_class("MlpPolicy", vec_env, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(total_timesteps)

    evaluate_policy(model, vec_env, n_eval_episodes=20, reward_threshold=90, warn=False)
    obs = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)

0 comments on commit 087951d

Please sign in to comment.