
fix: pacman ghost valid action calculations result in NaNs #241

Merged · 1 commit into instadeepai:main on May 11, 2024

Conversation

@taodav (Contributor) commented May 9, 2024

In the PacMan environment, when calculating all the valid actions a ghost can take (in check_ghost_wall_collisions in pac_man/utils.py), the invert_mask * jnp.inf call was producing NaNs wherever invert_mask == 0, since 0 * inf is NaN under IEEE-754. This led to all actions being treated as valid for ghosts.

Instead, this line should use a jnp.where call that conditionally replaces all 1s in invert_mask with jnp.inf and leaves the 0s finite.
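The difference between the two expressions can be shown outside the environment. The sketch below uses plain NumPy (jnp.where and jnp arithmetic behave the same way here) with a hypothetical invert_mask, not the actual array from pac_man/utils.py:

```python
import numpy as np

# Hypothetical wall mask: 1 marks a blocked direction, 0 an open one.
invert_mask = np.array([1.0, 0.0, 1.0, 0.0])

# Buggy version: multiplying by inf gives 0 * inf = nan at the open entries.
buggy = invert_mask * np.inf
print(buggy)  # [inf nan inf nan]

# Fixed version: conditionally replace the 1s with inf, leaving the 0s finite.
fixed = np.where(invert_mask == 1, np.inf, 0.0)
print(fixed)  # [inf  0. inf  0.]
```

The jnp.where form makes the intent explicit and never evaluates 0 * inf, so no NaNs can appear regardless of jit.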

@CLAassistant commented May 9, 2024

CLA assistant check
All committers have signed the CLA.

@clement-bonnet (Collaborator) commented

Hi, thank you for spotting this bug!
If this is indeed the case and this change fixes it, we then need to bump the version of PacMan from PacMan-v0 to PacMan-v1. Could you please make the corresponding changes to the registry and documentation?

@taodav (Contributor, Author) commented May 11, 2024

I've updated my commit to bump the version of PacMan to v1. Let me know if I've missed anything!

@clement-bonnet (Collaborator) commented

That's perfect, thanks!
I'm struggling to reproduce the issue as I am not finding NaNs in different tests I've done. Would you have a small reproduction of the NaNs that you would be able to share?
Thank you!

[screenshot attached]

@taodav (Contributor, Author) commented May 11, 2024

I did a bit of digging: if the function is jitted, then jnp.inf * False returns 0 (which happens to match what PacMan needs), whereas it should return NaN:

jax-ml/jax#12233 (comment)

The NaNs don't show up in the action_mask itself, but they effectively zero out the action_mask for ghosts. I put a breakpoint at the return statement of check_ghost_wall_collisions to see the NaNs in invert_mask. They only show up if you set the environment variable JAX_DISABLE_JIT=1, which turns jit off.

@taodav (Contributor, Author) commented May 11, 2024

Here is the script that I run, with the environment variable JAX_DISABLE_JIT=1:

import jax
from jumanji.environments.routing.pac_man import PacMan


if __name__ == "__main__":
    # Note: jax.disable_jit is a context manager, so this bare call does not
    # disable jit by itself; the JAX_DISABLE_JIT=1 environment variable does.
    jax.disable_jit(True)

    seed = 2024
    key = jax.random.PRNGKey(seed)
    reset_key, key = jax.random.split(key)

    env = PacMan()

    state, tstep = env.reset(reset_key)

    next_state, tstep = env.step(state, 1)

@clement-bonnet (Collaborator) commented

When I train with and without the fix, I get the exact same learning curves (same loss at every step), hinting that the behavior has not changed. I wonder if the NaN behavior depends on the version of JAX?
I'm happy to merge this change as it is a cleaner implementation but if the environment behavior has not changed, then we should probably not bump the version. Would you have a way to show that the environment produces NaNs before the fix and not after this change?

@taodav (Contributor, Author) commented May 11, 2024

Yes, the behavior would be the same: according to this thread, XLA returns 0 instead of NaN for False * jnp.inf when things are jitted, which just so happens to be the intended behavior in the code. The issue arises when debugging: with JAX_DISABLE_JIT=1, you get weird issues such as ghosts walking through walls. Here's an animated example:

[animated GIF: unjit_pacman]

If this doesn't warrant a version bump, then I'm more than happy to change the version back to v0.
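The jit/no-jit divergence discussed above can be reproduced with a toy function. This is a sketch, not code from the PacMan environment; the jitted result depends on XLA's simplification behavior and may vary by JAX version (see jax-ml/jax#12233):

```python
import jax
import jax.numpy as jnp

def masked_inf(mask):
    # 0 * inf is NaN under IEEE-754, so the open (0) entries become NaN.
    return mask * jnp.inf

mask = jnp.array([1.0, 0.0])

# Un-jitted execution follows IEEE-754 arithmetic: [inf, nan].
with jax.disable_jit():
    eager = masked_inf(mask)

# Under jit, XLA may simplify the multiply and return 0 instead of NaN
# (version-dependent), which is what masked the bug in PacMan.
jitted = jax.jit(masked_inf)(mask)
print(eager, jitted)
```

Because the jitted path can silently produce the intended 0, the bug only surfaces when jit is disabled, which is exactly the debugging scenario described above.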

@clement-bonnet (Collaborator) commented

Oh that makes complete sense. Since the behavior of the non-jitted environment changes, let's then bump the version.

Thank you for your contribution!

@clement-bonnet (Collaborator) left a review comment:

LGTM

@clement-bonnet clement-bonnet merged commit fd511b4 into instadeepai:main May 11, 2024
3 checks passed