From 6dc34b320172ba1e3b4ed761b2dcb36819cf8877 Mon Sep 17 00:00:00 2001 From: Ruo Yu Tao Date: Thu, 9 May 2024 17:50:59 -0400 Subject: [PATCH] fix: pacman ghost valid action calculations result in NaNs --- README.md | 2 +- docs/environments/pac_man.md | 2 +- jumanji/__init__.py | 6 ++++-- jumanji/environments/routing/pac_man/utils.py | 2 +- jumanji/training/configs/env/pac_man.yaml | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ab6472ee4..4866657e3 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ problems. | 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) | | 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) | | Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) | -| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pac_man/) | [doc](https://instadeepai.github.io/jumanji/environments/pac_man/) +| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pac_man/) | [doc](https://instadeepai.github.io/jumanji/environments/pac_man/) | 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |

Installation 🎬

diff --git a/docs/environments/pac_man.md b/docs/environments/pac_man.md index b90b0ac09..ca19bcd5a 100644 --- a/docs/environments/pac_man.md +++ b/docs/environments/pac_man.md @@ -62,4 +62,4 @@ Eating a ghost when scatter mode is enabled also awards +200 points but, points ## Registered Versions 📖 -- `PacMan-v0`, PacMan in a 31x28 map with simple grid observations. +- `PacMan-v1`, PacMan in a 31x28 map with simple grid observations. diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 5e04ad474..1beecead2 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -132,8 +132,10 @@ # Sokoban with deepmind dataset generator register(id="Sokoban-v0", entry_point="jumanji.environments:Sokoban") -# Pacman - minimal version of Atarti Pacman game -register(id="PacMan-v0", entry_point="jumanji.environments:PacMan") + +# Pacman - minimal version of Atari Pacman game +register(id="PacMan-v1", entry_point="jumanji.environments:PacMan") + # SlidingTilePuzzle - A sliding tile puzzle environment with the default grid size of 5x5. register( id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle" diff --git a/jumanji/environments/routing/pac_man/utils.py b/jumanji/environments/routing/pac_man/utils.py index 95216f37d..ef49f7741 100644 --- a/jumanji/environments/routing/pac_man/utils.py +++ b/jumanji/environments/routing/pac_man/utils.py @@ -325,7 +325,7 @@ def get_valid_positions(pos: chex.Array) -> Any: # Get distances of valid locations valid_no_back_d = valid_no_back * ghost_dist invert_mask = valid_no_back != 1 - invert_mask = invert_mask * jnp.inf + invert_mask = jnp.where(invert_mask, jnp.inf, invert_mask) # Set distance of all invalid areas to infinity valid_no_back_d = valid_no_back_d + invert_mask masked_dist = valid_no_back_d diff --git a/jumanji/training/configs/env/pac_man.yaml b/jumanji/training/configs/env/pac_man.yaml index 3c0f7b78e..e5ce83bce 100644 --- a/jumanji/training/configs/env/pac_man.yaml +++ b/jumanji/training/configs/env/pac_man.yaml @@ -1,5 +1,5 @@ name: pac_man -registered_version: PacMan-v0 +registered_version: PacMan-v1 network: num_channels: [4,4,1]