Skip to content

Commit

Permalink
fix: default value for obs in extras (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-bonnet authored Mar 11, 2024
1 parent ce8b873 commit 8b10543
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 36 deletions.
44 changes: 24 additions & 20 deletions jumanji/training/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def close(self) -> None:
def upload_checkpoint(self) -> None:
"""Uploads a checkpoint when exiting the logger."""

def is_loggable(self, value: Any) -> bool:
    """Return True if ``value`` is a scalar that the loggers can record.

    Accepted values are Python ``float``/``int`` objects, NumPy numeric or boolean
    scalar objects (e.g. ``np.float32``, ``np.int64``), and zero-dimensional
    ``jnp.ndarray``/``np.ndarray`` instances. Anything else (strings, dicts,
    arrays with one or more dimensions, ``None``) is reported as not loggable.

    Args:
        value: the candidate metric value.

    Returns:
        True if the value is a loggable scalar, False otherwise.
    """
    # NumPy scalar objects such as np.float32 / np.int64 are neither `float`/`int`
    # subclasses nor `np.ndarray` instances, so they need to be listed explicitly;
    # otherwise such metrics would be silently skipped by the loggers.
    if isinstance(value, (float, int, np.number, np.bool_)):
        return True
    if isinstance(value, (jnp.ndarray, np.ndarray)):
        # Only zero-dimensional arrays behave like scalars.
        return bool(value.ndim == 0)
    return False

def __enter__(self) -> Logger:
logging.info("Starting logger.")
self._variables_enter = self._get_variables()
Expand Down Expand Up @@ -134,8 +142,9 @@ def __init__(
def _format_values(self, data: Dict[str, Any]) -> str:
return " | ".join(
f"{key.replace('_', ' ').title()}: "
f"{(f'{value:.3f}' if isinstance(value, (float, jnp.ndarray)) else f'{value:,}')}"
f"{(f'{value:,}' if isinstance(value, int) else f'{value:.3f}')}"
for key, value in sorted(data.items())
if self.is_loggable(value)
)

def write(
Expand Down Expand Up @@ -166,7 +175,8 @@ def write(
env_steps: Optional[int] = None,
) -> None:
for key, value in data.items():
self.history[key].append(value)
if self.is_loggable(value):
self.history[key].append(value)


class TensorboardLogger(Logger):
Expand All @@ -191,15 +201,12 @@ def write(
self._env_steps = env_steps
prefix = label and f"{label}/"
for key, metric in data.items():
if np.ndim(metric) == 0:
if not np.isnan(metric):
self.writer.add_scalar(
tag=f"{prefix}/{key}",
scalar_value=metric,
global_step=int(self._env_steps),
)
else:
raise ValueError(f"Expected metric {key} to be a scalar, got {metric}.")
if self.is_loggable(metric) and not np.isnan(metric):
self.writer.add_scalar(
tag=f"{prefix}/{key}",
scalar_value=metric,
global_step=int(self._env_steps),
)

def close(self) -> None:
self.writer.close()
Expand Down Expand Up @@ -232,15 +239,12 @@ def write(
self._env_steps = env_steps
prefix = label and f"{label}/"
for key, metric in data.items():
if np.ndim(metric) == 0:
if not np.isnan(metric):
self.run[f"{prefix}/{key}"].log(
float(metric),
step=int(self._env_steps),
wait=True,
)
else:
raise ValueError(f"Expected metric {key} to be a scalar, got {metric}.")
if self.is_loggable(metric) and not np.isnan(metric):
self.run[f"{prefix}/{key}"].log(
float(metric),
step=int(self._env_steps),
wait=True,
)

def close(self) -> None:
self.run.stop()
Expand Down
56 changes: 43 additions & 13 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,22 +363,20 @@ def render(self, state: State) -> Any:
NEXT_OBS_KEY_IN_EXTRAS = "next_obs"


def _obs_in_extras(
state: State, timestep: TimeStep[Observation]
) -> Tuple[State, TimeStep[Observation]]:
def add_obs_to_extras(timestep: TimeStep[Observation]) -> TimeStep[Observation]:
"""Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS].
Used when auto-resetting to store the observation from the terminal TimeStep.
Used when auto-resetting to store the observation from the terminal TimeStep (useful for
e.g. truncation).
Args:
state: State object containing the dynamics of the environment.
timestep: TimeStep object containing the timestep returned by the environment.
Returns:
(state, timestep): where the observation is placed in timestep.extras["next_obs"].
timestep where the observation is placed in timestep.extras["next_obs"].
"""
extras = timestep.extras
extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep.observation
return state, timestep.replace(extras=extras) # type: ignore
return timestep.replace(extras=extras) # type: ignore


class AutoResetWrapper(Wrapper):
Expand All @@ -392,6 +390,21 @@ class AutoResetWrapper(Wrapper):
being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead.
"""

def __init__(self, env: Environment, next_obs_in_extras: bool = False):
    """Wrap an environment so that it is reset automatically when an episode terminates.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: if True, observations are copied into the timestep extras
            through `add_obs_to_extras`; if False (default), timesteps are passed
            through unchanged. Useful for e.g. truncation.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras
    # Choose the hook once here so the rest of the wrapper can call it unconditionally.
    self._maybe_add_obs_to_extras = (
        add_obs_to_extras if next_obs_in_extras else (lambda timestep: timestep)
    )

def _auto_reset(
self, state: State, timestep: TimeStep[Observation]
) -> Tuple[State, TimeStep[Observation]]:
Expand All @@ -410,15 +423,17 @@ def _auto_reset(
state, reset_timestep = self._env.reset(key)

# Place original observation in extras.
state, timestep = _obs_in_extras(state, timestep)
timestep = self._maybe_add_obs_to_extras(timestep)

# Replace observation with reset observation.
timestep = timestep.replace(observation=reset_timestep.observation) # type: ignore

return state, timestep

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
return _obs_in_extras(*super().reset(key))
state, timestep = super().reset(key)
timestep = self._maybe_add_obs_to_extras(timestep)
return state, timestep

def step(
self, state: State, action: chex.Array
Expand All @@ -430,7 +445,7 @@ def step(
state, timestep = jax.lax.cond(
timestep.last(),
self._auto_reset,
_obs_in_extras,
lambda s, t: (s, self._maybe_add_obs_to_extras(t)),
state,
timestep,
)
Expand All @@ -450,6 +465,21 @@ class VmapAutoResetWrapper(Wrapper):
NOTE: When `next_obs_in_extras=True`, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"]; with the default (`False`) it is not stored.
"""

def __init__(self, env: Environment, next_obs_in_extras: bool = False):
    """Wrap an environment to vmap it and reset it automatically when an episode terminates.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: whether observations should be copied into the timestep
            extras (via `add_obs_to_extras`). When False (default) the timestep is
            left untouched. Useful for e.g. truncation.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras

    # Default to the identity (no-op) and only switch to the real hook when requested.
    hook = lambda timestep: timestep  # noqa: E731
    if next_obs_in_extras:
        hook = add_obs_to_extras
    self._maybe_add_obs_to_extras = hook

def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""Resets a batch of environments to initial states.
Expand All @@ -468,7 +498,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
environments,
"""
state, timestep = jax.vmap(self._env.reset)(key)
state, timestep = _obs_in_extras(state, timestep)
timestep = self._maybe_add_obs_to_extras(timestep)
return state, timestep

def step(
Expand Down Expand Up @@ -516,7 +546,7 @@ def _auto_reset(
state, reset_timestep = self._env.reset(key)

# Place original observation in extras.
state, timestep = _obs_in_extras(state, timestep)
timestep = self._maybe_add_obs_to_extras(timestep)

# Replace observation with reset observation.
timestep = timestep.replace( # type: ignore
Expand All @@ -532,7 +562,7 @@ def _maybe_reset(
state, timestep = jax.lax.cond(
timestep.last(),
self._auto_reset,
_obs_in_extras,
lambda s, t: (s, self._maybe_add_obs_to_extras(t)),
state,
timestep,
)
Expand Down
4 changes: 2 additions & 2 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ class TestAutoResetWrapper:
def fake_auto_reset_environment(
self, fake_environment: Environment
) -> AutoResetWrapper:
return AutoResetWrapper(fake_environment)
return AutoResetWrapper(fake_environment, next_obs_in_extras=True)

@pytest.fixture
def fake_state_and_timestep(
Expand Down Expand Up @@ -602,7 +602,7 @@ class TestVmapAutoResetWrapper:
def fake_vmap_auto_reset_environment(
self, fake_environment: FakeEnvironment
) -> VmapAutoResetWrapper:
return VmapAutoResetWrapper(fake_environment)
return VmapAutoResetWrapper(fake_environment, next_obs_in_extras=True)

@pytest.fixture
def action(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ exclude =
.cache,
.eggs
max-line-length=100
max-cognitive-complexity=10
max-cognitive-complexity=14
import-order-style = google
application-import-names = jumanji
doctests = True
Expand Down

0 comments on commit 8b10543

Please sign in to comment.