From 087951d34bb3f1f3f04bb15208b67fb25c093c00 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Tue, 12 Jul 2022 23:12:24 +0200
Subject: [PATCH] Release v1.6.0 and bug fix for TRPO (#84)

---
 docs/conda_env.yml       |  6 +++---
 docs/conf.py             | 13 ++++++++++++-
 docs/misc/changelog.rst  |  5 +++--
 sb3_contrib/trpo/trpo.py |  6 +++---
 sb3_contrib/version.txt  |  2 +-
 setup.py                 |  2 +-
 tests/test_identity.py   | 38 ++++++++++++++++++++++++++++++++++++++
 7 files changed, 61 insertions(+), 11 deletions(-)
 create mode 100644 tests/test_identity.py

diff --git a/docs/conda_env.yml b/docs/conda_env.yml
index 467a6a97..64b1aaf1 100644
--- a/docs/conda_env.yml
+++ b/docs/conda_env.yml
@@ -6,9 +6,9 @@ dependencies:
   - cpuonly=1.0=0
   - pip=21.1
   - python=3.7
-  - pytorch=1.8.1=py3.7_cpu_0
+  - pytorch=1.11=py3.7_cpu_0
   - pip:
-    - gym>=0.17.2
+    - gym==0.21
     - cloudpickle
     - opencv-python-headless
     - pandas
@@ -17,5 +17,5 @@ dependencies:
     - sphinx_autodoc_typehints
     - stable-baselines3>=1.3.0
     - sphinx>=4.2
-    # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
     - sphinx_rtd_theme>=1.0
+    - sphinx_copybutton
diff --git a/docs/conf.py b/docs/conf.py
index 85c9c16e..d62bfc9e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -24,6 +24,14 @@
 except ImportError:
     enable_spell_check = False
 
+# Try to enable copy button
+try:
+    import sphinx_copybutton  # noqa: F401
+
+    enable_copy_button = True
+except ImportError:
+    enable_copy_button = False
+
 # source code directory, relative to this file, for sphinx-autobuild
 sys.path.insert(0, os.path.abspath(".."))
 
@@ -51,7 +59,7 @@ def __getattr__(cls, name):
 # -- Project information -----------------------------------------------------
 
 project = "Stable Baselines3 - Contrib"
-copyright = "2020, Stable Baselines3"
+copyright = "2022, Stable Baselines3"
 author = "Stable Baselines3 Contributors"
 
 # The short X.Y version
@@ -83,6 +91,9 @@ def __getattr__(cls, name):
 if enable_spell_check:
     extensions.append("sphinxcontrib.spelling")
 
+if enable_copy_button:
+    extensions.append("sphinx_copybutton")
+
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ["_templates"]
 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 82766aa4..5c2a865a 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.5.1a9 (WIP)
+Release 1.6.0 (2022-07-11)
 -------------------------------
 
 **Add RecurrentPPO (aka PPO LSTM)**
@@ -25,8 +25,9 @@ New Features:
 - Added ``RecurrentPPO`` (aka PPO LSTM)
 
 Bug Fixes:
-- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
 ^^^^^^^^^^
+- Fixed a bug in ``RecurrentPPO`` when calculating the masked loss functions (@rnederstigt)
+- Fixed a bug in ``TRPO`` where the KL divergence was not implemented for ``MultiDiscrete`` spaces
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py
index f20a9ae7..cb35d730 100644
--- a/sb3_contrib/trpo/trpo.py
+++ b/sb3_contrib/trpo/trpo.py
@@ -6,12 +6,12 @@
 import numpy as np
 import torch as th
 from gym import spaces
+from stable_baselines3.common.distributions import kl_divergence
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule
 from stable_baselines3.common.utils import explained_variance
 from torch import nn
-from torch.distributions import kl_divergence
 from torch.nn import functional as F
 
 from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad
@@ -279,7 +279,7 @@ def train(self) -> None:
             policy_objective = (advantages * ratio).mean()
 
             # KL divergence
-            kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
+            kl_div = kl_divergence(distribution, old_distribution).mean()
 
             # Surrogate & KL gradient
             self.policy.optimizer.zero_grad()
@@ -332,7 +332,7 @@ def train(self) -> None:
                     new_policy_objective = (advantages * ratio).mean()
 
                     # New KL-divergence
-                    kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean()
+                    kl_div = kl_divergence(distribution, old_distribution).mean()
 
                     # Constraint criteria:
                     # we need to improve the surrogate policy objective
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 125ec275..dc1e644a 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-1.5.1a9
+1.6.0
diff --git a/setup.py b/setup.py
index aede425b..bcc2c3c4 100644
--- a/setup.py
+++ b/setup.py
@@ -65,7 +65,7 @@
     packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
     package_data={"sb3_contrib": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.5.1a7",
+        "stable_baselines3>=1.6.0",
     ],
     description="Contrib package of Stable Baselines3, experimental code.",
     author="Antonin Raffin",
diff --git a/tests/test_identity.py b/tests/test_identity.py
new file mode 100644
index 00000000..6ad03174
--- /dev/null
+++ b/tests/test_identity.py
@@ -0,0 +1,38 @@
+import numpy as np
+import pytest
+from stable_baselines3.common.envs import IdentityEnv, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.vec_env import DummyVecEnv
+
+from sb3_contrib import QRDQN, TRPO
+
+DIM = 4
+
+
+@pytest.mark.parametrize("model_class", [QRDQN, TRPO])
+@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
+def test_discrete(model_class, env):
+    env_ = DummyVecEnv([lambda: env])
+    kwargs = {}
+    n_steps = 1500
+    if model_class == QRDQN:
+        kwargs = dict(
+            learning_starts=0,
+            policy_kwargs=dict(n_quantiles=25, net_arch=[32]),
+            target_update_interval=10,
+            train_freq=2,
+            batch_size=256,
+        )
+        n_steps = 1500
+        # QR-DQN only supports discrete actions
+        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
+            return
+    elif model_class == TRPO:
+        kwargs = dict(n_steps=256, cg_max_steps=5)
+
+    model = model_class("MlpPolicy", env_, learning_rate=1e-3, gamma=0.4, seed=1, **kwargs).learn(n_steps)
+
+    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
+    obs = env.reset()
+
+    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
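
A brief sketch of what the TRPO change above relies on (assuming the ``kl_divergence`` helper and ``MultiCategoricalDistribution`` wrapper exported by ``stable_baselines3.common.distributions`` in SB3 1.6): for a ``MultiDiscrete`` action space the wrapper's ``.distribution`` attribute is a plain Python list of ``Categorical`` objects, which ``torch.distributions.kl_divergence`` cannot consume, so the KL must be computed on the SB3 wrappers themselves.

import torch as th
from stable_baselines3.common.distributions import MultiCategoricalDistribution, kl_divergence

# Distributions for a MultiDiscrete([3, 4]) action space: logits of size 3 + 4 = 7,
# for a batch of 8 observations.
old_dist = MultiCategoricalDistribution(action_dims=[3, 4]).proba_distribution(th.randn(8, 7))
new_dist = MultiCategoricalDistribution(action_dims=[3, 4]).proba_distribution(th.randn(8, 7))

# The SB3 helper dispatches on the wrapper type and sums the KL over the sub-distributions,
# giving one value per batch element; TRPO averages this into its trust-region constraint.
kl = kl_divergence(new_dist, old_dist)
print(kl.shape)  # torch.Size([8])

# The previous code operated on the wrapped objects directly; for MultiDiscrete these form
# a list of Categorical distributions, so the following call would fail:
# th.distributions.kl_divergence(new_dist.distribution, old_dist.distribution)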