Skip to content

Commit

Permalink
Bugfix/ppo mask stats window size (#199)
Browse files Browse the repository at this point in the history
* fixes issue #198 regarding stats_window_size

* updates changelog

* updates test of stats_window_size

* updates test using maskable env

* removes print statement
  • Loading branch information
PatrickHelm authored Aug 1, 2023
1 parent 35f0625 commit dfa23bd
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Fixed MaskablePPO ignoring stats_window_size argument

Deprecations:
^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions sb3_contrib/ppo_mask/ppo_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def _setup_learn(
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
self.ep_success_buffer = deque(maxlen=self._stats_window_size)

if reset_num_timesteps:
self.num_timesteps = 0
Expand Down
5 changes: 3 additions & 2 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,13 @@ def test_advantage_normalization(normalize_advantage):
model.learn(64)


@pytest.mark.parametrize("algo", [TRPO, QRDQN])
@pytest.mark.parametrize("algo", [TRPO, QRDQN, MaskablePPO])
@pytest.mark.parametrize("stats_window_size", [1, 42])
def test_ep_buffers_stats_window_size(algo, stats_window_size):
"""Set stats_window_size for logging to non-default value and check if
ep_info_buffer and ep_success_buffer are initialized to the correct length"""
model = algo("MlpPolicy", "CartPole-v1", stats_window_size=stats_window_size)
env = InvalidActionEnvDiscrete() if algo == MaskablePPO else "CartPole-v1"
model = algo("MlpPolicy", env, stats_window_size=stats_window_size)
model.learn(total_timesteps=10)
assert model.ep_info_buffer.maxlen == stats_window_size
assert model.ep_success_buffer.maxlen == stats_window_size

0 comments on commit dfa23bd

Please sign in to comment.