Skip to content

Commit

Permalink
Release v2.1.0 (#204)
Browse files Browse the repository at this point in the history
* Release v2.1.0

* Fix mypy

* Fix warnings in tests
  • Loading branch information
araffin authored Aug 17, 2023
1 parent dfa23bd commit 67d3eef
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SHELL=/bin/bash
LINT_PATHS=sb3_contrib/ tests/ setup.py
LINT_PATHS=sb3_contrib/ tests/ setup.py docs/conf.py

pytest:
./scripts/run_tests.sh
Expand Down
21 changes: 3 additions & 18 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#
import os
import sys
from unittest.mock import MagicMock
from typing import Dict

# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
Expand All @@ -36,21 +36,6 @@
sys.path.insert(0, os.path.abspath(".."))


class Mock(MagicMock):
__subclasses__ = []

@classmethod
def __getattr__(cls, name):
return MagicMock()


# Mock modules that requires C modules
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
# DO not mock modules for now, we will need to do that for read the docs later
MOCK_MODULES = []
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../sb3_contrib", "version.txt")
with open(version_file) as file_handler:
Expand All @@ -59,7 +44,7 @@ def __getattr__(cls, name):
# -- Project information -----------------------------------------------------

project = "Stable Baselines3 - Contrib"
copyright = "2022, Stable Baselines3"
copyright = "2023, Stable Baselines3"
author = "Stable Baselines3 Contributors"

# The short X.Y version
Expand Down Expand Up @@ -171,7 +156,7 @@ def setup(app):

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
latex_elements: Dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
Expand Down
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@ Changelog
==========


Release 2.1.0a0 (WIP)
Release 2.1.0 (2023-08-17)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed Python 3.7 support
- SB3 now requires PyTorch > 1.13
- Upgraded to Stable-Baselines3 >= 2.1.0

New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support

Bug Fixes:
^^^^^^^^^^
- Fixed MaskablePPO ignoring stats_window_size argument
- Fixed MaskablePPO ignoring ``stats_window_size`` argument

Deprecations:
^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/common/recurrent/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def predict(
observation, vectorized_env = self.obs_to_tensor(observation)

if isinstance(observation, dict):
n_envs = observation[list(observation.keys())[0]].shape[0]
n_envs = observation[next(iter(observation.keys()))].shape[0]
else:
n_envs = observation.shape[0]
# state : (n_layers, n_envs, dim)
Expand Down
1 change: 0 additions & 1 deletion sb3_contrib/common/vec_env/async_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import multiprocessing
import multiprocessing as mp
from collections import defaultdict
from typing import Callable, List, Optional, Tuple, Union
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def predict(
if not deterministic and np.random.rand() < self.exploration_rate:
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
n_batch = observation[next(iter(observation.keys()))].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0a0
2.1.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.0.0",
"stable_baselines3>=2.1.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
7 changes: 1 addition & 6 deletions tests/test_dict_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def __init__(
# Add dictionary observation inside observation space
self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)})

def seed(self, seed=None):
if seed is not None:
self.observation_space.seed(seed)

def step(self, action):
reward = 0.0
done = truncated = False
Expand Down Expand Up @@ -103,8 +99,7 @@ def test_consistency(model_class):
dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True)
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10)
obs, _ = dict_env.reset()
obs, _ = dict_env.reset(seed=10)

kwargs = {}
n_steps = 256
Expand Down
2 changes: 1 addition & 1 deletion tests/wrappers/test_action_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ def test_action_masks_returns_expected_result():

# Only one valid action expected
masks = env.action_masks()
masks[env.state] = not masks[env.state] # Bit-flip the one expected valid action
masks[env.unwrapped.state] = not masks[env.unwrapped.state] # Bit-flip the one expected valid action
assert all([not mask for mask in masks])

0 comments on commit 67d3eef

Please sign in to comment.