From c08fc625b4ed08f23c121786ca5d837c10d9fffe Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Mon, 11 Mar 2024 09:25:28 +0200 Subject: [PATCH] chore: remove jnp.where from _get_ones_like_expanded_block method --- jumanji/environments/packing/flat_pack/env.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/jumanji/environments/packing/flat_pack/env.py b/jumanji/environments/packing/flat_pack/env.py index f9be67d9b..f40ac668d 100644 --- a/jumanji/environments/packing/flat_pack/env.py +++ b/jumanji/environments/packing/flat_pack/env.py @@ -378,15 +378,9 @@ def _is_legal_action( return legal def _get_ones_like_expanded_block(self, grid_block: chex.Array) -> chex.Array: - """Makes a grid of zeroes with ones where the block is placed. + """Makes a grid of zeroes with ones where the block is placed.""" - Args: - grid_with_ones: block placed on a grid of zeroes. - """ - - grid_with_ones = jnp.where(grid_block != 0, 1, 0) - - return grid_with_ones + return (grid_block != 0).astype(jnp.int32) def _expand_block_to_grid( self,