diff --git a/api/wrappers/index.html b/api/wrappers/index.html index 59327e8b0..29c32f4e9 100644 --- a/api/wrappers/index.html +++ b/api/wrappers/index.html @@ -2199,7 +2199,7 @@

-__init__(self, env: Environment, reward_aggregator: Callable = <function sum at 0x7f157acde700>, discount_aggregator: Callable = <function amax at 0x7f157acdeee0>) +__init__(self, env: Environment, reward_aggregator: Callable = <function sum at 0x7fe959e73700>, discount_aggregator: Callable = <function amax at 0x7fe959e73ee0>) special @@ -2233,14 +2233,14 @@

Callable

a function to aggregate all agents rewards into a single scalar value, e.g. sum.

- <function sum at 0x7f157acde700> + <function sum at 0x7fe959e73700> discount_aggregator Callable

a function to aggregate all agents discounts into a single scalar value, e.g. max.

- <function amax at 0x7f157acdeee0> + <function amax at 0x7fe959e73ee0> diff --git a/search/search_index.json b/search/search_index.json index 03db6f660..e74ec0bef 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"Environments | Installation | Quickstart | Training | Citation | Docs Welcome to the Jungle! \ud83c\udf34 # Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX. Jumanji is helping pioneer a new wave of hardware-accelerated research and development in the field of RL. Jumanji's high-speed environments enable faster iteration and large-scale experimentation while simultaneously reducing complexity. Originating in the Research Team at InstaDeep , Jumanji is now developed jointly with the open-source community. To join us in these efforts, reach out, raise issues and read our contribution guidelines or just star \ud83c\udf1f to stay up to date with the latest developments! Goals \ud83d\ude80 # Provide a simple, well-tested API for JAX-based environments. Make research in RL more accessible. Facilitate the research on RL for problems in the industry and help close the gap between research and industrial applications. Provide environments whose difficulty can be scaled to be arbitrarily hard. Overview \ud83e\udd9c # \ud83e\udd51 Environment API : core abstractions for JAX-based environments. \ud83d\udd79\ufe0f Environment Suite : a collection of RL environments ranging from simple games to NP-hard combinatorial problems. \ud83c\udf6c Wrappers : easily connect to your favourite RL frameworks and libraries such as Acme , Stable Baselines3 , RLlib , OpenAI Gym and DeepMind-Env through our dm_env and gym wrappers. \ud83c\udf93 Examples : guides to facilitate Jumanji's adoption and highlight the added value of JAX-based environments. \ud83c\udfce\ufe0f Training: example agents that can be used as inspiration for the agents one may implement in their research. Environments \ud83c\udf0d Jumanji provides a diverse range of environments ranging from simple games to NP-hard combinatorial problems. Environment Category Registered Version(s) Source Description \ud83d\udd22 Game2048 Logic Game2048-v1 code doc \ud83c\udfa8 GraphColoring Logic GraphColoring-v0 code doc \ud83d\udca3 Minesweeper Logic Minesweeper-v0 code doc \ud83c\udfb2 RubiksCube Logic RubiksCube-v0 RubiksCube-partly-scrambled-v0 code doc \u270f\ufe0f Sudoku Logic Sudoku-v0 Sudoku-very-easy-v0 code doc \ud83d\udce6 BinPack (3D BinPacking Problem) Packing BinPack-v2 code doc \ud83c\udfed JobShop (Job Shop Scheduling Problem) Packing JobShop-v0 code doc \ud83c\udf92 Knapsack Packing Knapsack-v1 code doc \u2592 Tetris Packing Tetris-v0 code doc \ud83e\uddf9 Cleaner Routing Cleaner-v0 code doc Connector Routing Connector-v2 code doc \ud83d\ude9a CVRP (Capacitated Vehicle Routing Problem) Routing CVRP-v1 code doc \ud83d\ude9a MultiCVRP (Multi-Agent Capacitated Vehicle Routing Problem) Routing MultiCVRP-v0 code doc Maze Routing Maze-v0 code doc RobotWarehouse Routing RobotWarehouse-v0 code doc \ud83d\udc0d Snake Routing Snake-v1 code doc \ud83d\udcec TSP (Travelling Salesman Problem) Routing TSP-v1 code doc Multi Minimum Spanning Tree Problem Routing MMST-v0 code doc Installation \ud83c\udfac You can install the latest release of Jumanji from PyPI: 1 pip install jumanji Alternatively, you can install the latest development version directly from GitHub: 1 pip install git+https://github.com/instadeepai/jumanji.git Jumanji has been tested on Python 3.8 and 3.9. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide ). Rendering: Matplotlib is used for rendering all the environments. To visualize the environments you will need a GUI backend. For example, on Linux, you can install Tk via: apt-get install python3-tk , or using conda: conda install tk . Check out Matplotlib backends for a list of backends you can use. Quickstart \u26a1 RL practitioners will find Jumanji's interface familiar as it combines the widely adopted OpenAI Gym and DeepMind Environment interfaces. From OpenAI Gym, we adopted the idea of a registry and the render method, while our TimeStep structure is inspired by DeepMind Environment. Basic Usage \ud83e\uddd1\u200d\ud83d\udcbb # 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import jax import jumanji # Instantiate a Jumanji environment using the registry env = jumanji . make ( 'Snake-v1' ) # Reset your (jit-able) environment key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) # (Optional) Render the env state env . render ( state ) # Interact with the (jit-able) environment action = env . action_spec () . generate_value () # Action selection (dummy value here) state , timestep = jax . jit ( env . step )( state , action ) # Take a step and observe the next state and time step state represents the internal state of the environment: it contains all the information required to take a step when executing an action. This should not be confused with the observation contained in the timestep , which is the information perceived by the agent. timestep is a dataclass containing step_type , reward , discount , observation and extras . This structure is similar to dm_env.TimeStep except for the extras field that was added to allow users to log environments metrics that are neither part of the agent's observation nor part of the environment's internal state. Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c # Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of a more advanced usage in the advanced usage guide . Registry and Versioning \ud83d\udcd6 # Like OpenAI Gym, Jumanji keeps a strict versioning of its environments for reproducibility reasons. We maintain a registry of standard environments with their configuration. For each environment, a version suffix is appended, e.g. Snake-v1 . When changes are made to environments that might impact learning results, the version number is incremented by one to prevent potential confusion. For a full list of registered versions of each environment, check out the documentation . Training \ud83c\udfce\ufe0f To showcase how to train RL agents on Jumanji environments, we provide a random agent and a vanilla actor-critic (A2C) agent. These agents can be found in jumanji/training/ . Because the environment framework in Jumanji is so flexible, it allows pretty much any problem to be implemented as a Jumanji environment, giving rise to very diverse observations. For this reason, environment-specific networks are required to capture the symmetries of each environment. Alongside the A2C agent implementation, we provide examples of such environment-specific actor-critic networks in jumanji/training/networks . \u26a0\ufe0f The example agents in jumanji/training are only meant to serve as inspiration for how one can implement an agent. Jumanji is first and foremost a library of environments - as such, the agents and networks will not be maintained to a production standard. For more information on how to use the example agents, see the training guide . Contributing \ud83e\udd1d # Contributions are welcome! See our issue tracker for good first issues . Please read our contributing guidelines for details on how to submit pull requests, our Contributor License Agreement, and community guidelines. Citing Jumanji \u270f\ufe0f If you use Jumanji in your work, please cite the library using: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 @misc{bonnet2023jumanji, title={Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX}, author={ Cl\u00e9ment Bonnet and Daniel Luo and Donal Byrne and Shikha Surana and Vincent Coyette and Paul Duckworth and Laurence I. Midgley and Tristan Kalloniatis and Sasha Abramowitz and Cemlyn N. Waters and Andries P. Smit and Nathan Grinsztajn and Ulrich A. Mbou Sob and Omayma Mahjoub and Elshadai Tegegn and Mohamed A. Mimouni and Raphael Boige and Ruan de Kock and Daniel Furelos-Blanco and Victor Le and Arnu Pretorius and Alexandre Laterre }, year={2023}, eprint={2306.09884}, url={https://arxiv.org/abs/2306.09884}, archivePrefix={arXiv}, primaryClass={cs.LG} } See Also \ud83d\udd0e # Other works have embraced the approach of writing RL environments in JAX. In particular, we suggest users check out the following sister repositories: \ud83e\udd16 Qdax is a library to accelerate Quality-Diversity and neuro-evolution algorithms through hardware accelerators and parallelization. \ud83c\udf33 Evojax provides tools to enable neuroevolution algorithms to work with neural networks running across multiple TPU/GPUs. \ud83e\uddbe Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. \ud83c\udfcb\ufe0f\u200d Gymnax implements classic environments including classic control, bsuite, MinAtar and a collection of meta RL tasks. \ud83c\udfb2 Pgx provides classic board game environments like Backgammon, Shogi, and Go. Acknowledgements \ud83d\ude4f # The development of this library was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) \ud83c\udf24.","title":"Home"},{"location":"#welcome-to-the-jungle","text":"Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX. Jumanji is helping pioneer a new wave of hardware-accelerated research and development in the field of RL. Jumanji's high-speed environments enable faster iteration and large-scale experimentation while simultaneously reducing complexity. Originating in the Research Team at InstaDeep , Jumanji is now developed jointly with the open-source community. To join us in these efforts, reach out, raise issues and read our contribution guidelines or just star \ud83c\udf1f to stay up to date with the latest developments!","title":"Welcome to the Jungle! \ud83c\udf34"},{"location":"#goals","text":"Provide a simple, well-tested API for JAX-based environments. Make research in RL more accessible. Facilitate the research on RL for problems in the industry and help close the gap between research and industrial applications. Provide environments whose difficulty can be scaled to be arbitrarily hard.","title":"Goals \ud83d\ude80"},{"location":"#overview","text":"\ud83e\udd51 Environment API : core abstractions for JAX-based environments. \ud83d\udd79\ufe0f Environment Suite : a collection of RL environments ranging from simple games to NP-hard combinatorial problems. \ud83c\udf6c Wrappers : easily connect to your favourite RL frameworks and libraries such as Acme , Stable Baselines3 , RLlib , OpenAI Gym and DeepMind-Env through our dm_env and gym wrappers. \ud83c\udf93 Examples : guides to facilitate Jumanji's adoption and highlight the added value of JAX-based environments. \ud83c\udfce\ufe0f Training: example agents that can be used as inspiration for the agents one may implement in their research.","title":"Overview \ud83e\udd9c"},{"location":"#basic-usage","text":"1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import jax import jumanji # Instantiate a Jumanji environment using the registry env = jumanji . make ( 'Snake-v1' ) # Reset your (jit-able) environment key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) # (Optional) Render the env state env . render ( state ) # Interact with the (jit-able) environment action = env . action_spec () . generate_value () # Action selection (dummy value here) state , timestep = jax . jit ( env . step )( state , action ) # Take a step and observe the next state and time step state represents the internal state of the environment: it contains all the information required to take a step when executing an action. This should not be confused with the observation contained in the timestep , which is the information perceived by the agent. timestep is a dataclass containing step_type , reward , discount , observation and extras . This structure is similar to dm_env.TimeStep except for the extras field that was added to allow users to log environments metrics that are neither part of the agent's observation nor part of the environment's internal state.","title":"Basic Usage \ud83e\uddd1\u200d\ud83d\udcbb"},{"location":"#advanced-usage","text":"Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of a more advanced usage in the advanced usage guide .","title":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c"},{"location":"#registry-and-versioning","text":"Like OpenAI Gym, Jumanji keeps a strict versioning of its environments for reproducibility reasons. We maintain a registry of standard environments with their configuration. For each environment, a version suffix is appended, e.g. Snake-v1 . When changes are made to environments that might impact learning results, the version number is incremented by one to prevent potential confusion. For a full list of registered versions of each environment, check out the documentation .","title":"Registry and Versioning \ud83d\udcd6"},{"location":"#contributing","text":"Contributions are welcome! See our issue tracker for good first issues . Please read our contributing guidelines for details on how to submit pull requests, our Contributor License Agreement, and community guidelines.","title":"Contributing \ud83e\udd1d"},{"location":"#see-also","text":"Other works have embraced the approach of writing RL environments in JAX. In particular, we suggest users check out the following sister repositories: \ud83e\udd16 Qdax is a library to accelerate Quality-Diversity and neuro-evolution algorithms through hardware accelerators and parallelization. \ud83c\udf33 Evojax provides tools to enable neuroevolution algorithms to work with neural networks running across multiple TPU/GPUs. \ud83e\uddbe Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. \ud83c\udfcb\ufe0f\u200d Gymnax implements classic environments including classic control, bsuite, MinAtar and a collection of meta RL tasks. \ud83c\udfb2 Pgx provides classic board game environments like Backgammon, Shogi, and Go.","title":"See Also \ud83d\udd0e"},{"location":"#acknowledgements","text":"The development of this library was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) \ud83c\udf24.","title":"Acknowledgements \ud83d\ude4f"},{"location":"api/env/","text":"Environment ( ABC , Generic ) # Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. The API is inspired by brax . unwrapped : Environment property readonly # reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> Spec # Returns the observation spec. Returns: Type Description observation_spec a NestedSpec tree of spec. action_spec ( self ) -> Spec # Returns the action spec. Returns: Type Description action_spec a NestedSpec tree of spec. reward_spec ( self ) -> Array # Describes the reward returned by the environment. By default, this is assumed to be a single float. Returns: Type Description reward_spec a specs.Array spec. discount_spec ( self ) -> BoundedArray # Describes the discount returned by the environment. By default, this is assumed to be a single float between 0 and 1. Returns: Type Description discount_spec a specs.BoundedArray spec. render ( self , state : ~ State ) -> Any # Render frames of the environment for a given state. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. __enter__ ( self ) -> Environment special # __exit__ ( self , * args : Any ) -> None special # Calls :meth: close() .","title":"Base"},{"location":"api/env/#jumanji.env.Environment","text":"Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. The API is inspired by brax .","title":"Environment"},{"location":"api/env/#jumanji.env.Environment.unwrapped","text":"","title":"unwrapped"},{"location":"api/env/#jumanji.env.Environment.reset","text":"Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment,","title":"reset()"},{"location":"api/env/#jumanji.env.Environment.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment,","title":"step()"},{"location":"api/env/#jumanji.env.Environment.observation_spec","text":"Returns the observation spec. Returns: Type Description observation_spec a NestedSpec tree of spec.","title":"observation_spec()"},{"location":"api/env/#jumanji.env.Environment.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a NestedSpec tree of spec.","title":"action_spec()"},{"location":"api/env/#jumanji.env.Environment.reward_spec","text":"Describes the reward returned by the environment. By default, this is assumed to be a single float. Returns: Type Description reward_spec a specs.Array spec.","title":"reward_spec()"},{"location":"api/env/#jumanji.env.Environment.discount_spec","text":"Describes the discount returned by the environment. By default, this is assumed to be a single float between 0 and 1. Returns: Type Description discount_spec a specs.BoundedArray spec.","title":"discount_spec()"},{"location":"api/env/#jumanji.env.Environment.render","text":"Render frames of the environment for a given state. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required","title":"render()"},{"location":"api/env/#jumanji.env.Environment.close","text":"Perform any necessary cleanup.","title":"close()"},{"location":"api/env/#jumanji.env.Environment.__enter__","text":"","title":"__enter__()"},{"location":"api/env/#jumanji.env.Environment.__exit__","text":"Calls :meth: close() .","title":"__exit__()"},{"location":"api/types/","text":"types # StepType ( int8 ) # Defines the status of a TimeStep within a sequence. First: 0 Mid: 1 Last: 2 TimeStep ( Generic , Mapping ) dataclass # Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806). A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type , an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount . The first TimeStep in a sequence will have StepType.FIRST . The final TimeStep will have StepType.LAST . All other TimeStep s in a sequence will have `StepType.MID. Attributes: Name Type Description step_type StepType A StepType enum value. reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST , i.e. at the start of a sequence. discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1] , or None if step_type is StepType.FIRST , i.e. at the start of a sequence. observation ~Observation A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. step_type : StepType dataclass-field # reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] dataclass-field # discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] dataclass-field # observation : ~ Observation dataclass-field # extras : Optional [ Dict ] dataclass-field # __eq__ ( self , other ) special # __init__ ( self , step_type : StepType , reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , extras : Optional [ Dict ] = None ) -> None special # __repr__ ( self ) special # __getitem__ ( self , x ) special # __len__ ( self ) special # __iter__ ( self ) special # first ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # mid ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # last ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # from_tuple ( args ) # to_tuple ( self ) # replace ( self , ** kwargs ) # __getstate__ ( self ) special # __setstate__ ( self , state ) special # restart ( observation : ~ Observation , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.FIRST . Parameters: Name Type Description Default observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a reset. transition ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] = None , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.MID . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a transition. termination ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the termination of an episode. truncation ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] = None , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the truncation of an episode. get_valid_dtype ( dtype : Union [ numpy . dtype , type ]) -> dtype # Cast a dtype taking into account the user type precision. E.g., if 64 bit is not enabled, jnp.dtype(jnp.float_) is still float64. By passing the given dtype through jnp.empty we get the supported dtype of float32. Parameters: Name Type Description Default dtype Union[numpy.dtype, type] jax numpy dtype or string specifying the array dtype. required Returns: Type Description dtype dtype converted to the correct type precision.","title":"Types"},{"location":"api/types/#jumanji.types","text":"","title":"types"},{"location":"api/types/#jumanji.types.StepType","text":"Defines the status of a TimeStep within a sequence. First: 0 Mid: 1 Last: 2","title":"StepType"},{"location":"api/types/#jumanji.types.TimeStep","text":"Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806). A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type , an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount . The first TimeStep in a sequence will have StepType.FIRST . The final TimeStep will have StepType.LAST . All other TimeStep s in a sequence will have `StepType.MID. Attributes: Name Type Description step_type StepType A StepType enum value. reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST , i.e. at the start of a sequence. discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1] , or None if step_type is StepType.FIRST , i.e. at the start of a sequence. observation ~Observation A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.","title":"TimeStep"},{"location":"api/types/#jumanji.types.restart","text":"Returns a TimeStep with step_type set to StepType.FIRST . Parameters: Name Type Description Default observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a reset.","title":"restart()"},{"location":"api/types/#jumanji.types.transition","text":"Returns a TimeStep with step_type set to StepType.MID . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a transition.","title":"transition()"},{"location":"api/types/#jumanji.types.termination","text":"Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the termination of an episode.","title":"termination()"},{"location":"api/types/#jumanji.types.truncation","text":"Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the truncation of an episode.","title":"truncation()"},{"location":"api/types/#jumanji.types.get_valid_dtype","text":"Cast a dtype taking into account the user type precision. E.g., if 64 bit is not enabled, jnp.dtype(jnp.float_) is still float64. By passing the given dtype through jnp.empty we get the supported dtype of float32. Parameters: Name Type Description Default dtype Union[numpy.dtype, type] jax numpy dtype or string specifying the array dtype. required Returns: Type Description dtype dtype converted to the correct type precision.","title":"get_valid_dtype()"},{"location":"api/wrappers/","text":"wrappers # Wrapper ( Environment , Generic ) # Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 unwrapped : Environment property readonly # Returns the wrapped env. __init__ ( self , env : Environment ) special # reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> Spec # Returns the observation spec. action_spec ( self ) -> Spec # Returns the action spec. render ( self , state : ~ State ) -> Any # Compute render frames during initialisation of the environment. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits. __enter__ ( self ) -> Wrapper special # __exit__ ( self , * args : Any ) -> None special # JumanjiToDMEnvWrapper ( Environment ) # A wrapper that converts Environment to dm_env.Environment. unwrapped : Environment property readonly # __init__ ( self , env : Environment , key : Optional [ jax . _src . prng . PRNGKeyArray ] = None ) special # Create the wrapped environment. Parameters: Name Type Description Default env Environment Environment to wrap to a dm_env.Environment . required key Optional[jax._src.prng.PRNGKeyArray] optional key to initialize the Environment with. None reset ( self ) -> TimeStep # Starts a new sequence and returns the first TimeStep of this sequence. Returns: Type Description A `TimeStep` namedtuple containing step_type: A StepType of FIRST . reward: None , indicating the reward is undefined. discount: None , indicating the discount is undefined. observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec() . step ( self , action : ndarray ) -> TimeStep # Updates the environment according to the action and returns a TimeStep . If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step will start a new sequence and action will be ignored. This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored. Parameters: Name Type Description Default action ndarray A NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec() . required Returns: Type Description A `TimeStep` namedtuple containing step_type: A StepType value. reward: Reward at this timestep, or None if step_type is StepType.FIRST . Must conform to the specification returned by reward_spec() . discount: A discount in the range [0, 1], or None if step_type is StepType.FIRST . Must conform to the specification returned by discount_spec() . observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec() . observation_spec ( self ) -> Array # Returns the dm_env observation spec. action_spec ( self ) -> Array # Returns the dm_env action spec. MultiToSingleWrapper ( Wrapper ) # A wrapper that converts a multi-agent Environment to a single-agent Environment. __init__ ( self , env : Environment , reward_aggregator : Callable = < function sum at 0x7f157acde700 > , discount_aggregator : Callable = < function amax at 0x7f157acdeee0 > ) special # Create the wrapped environment. Parameters: Name Type Description Default env Environment Environment to wrap to a dm_env.Environment . required reward_aggregator Callable a function to aggregate all agents rewards into a single scalar value, e.g. sum. discount_aggregator Callable a function to aggregate all agents discounts into a single scalar value, e.g. max. reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Run one timestep of the environment's dynamics. The rewards are aggregated into a single value based on the given reward aggregator. The discount value is set to the largest discount of all the agents. This essentially means that if any single agent is alive, the discount value won't be zero. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, VmapWrapper ( Wrapper ) # Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: - observation_spec - action_spec - reward_spec - discount_spec reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Resets the environment to an initial state. The first dimension of the key will dictate the number of concurrent environments. To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments. Parameters: Name Type Description Default key PRNGKeyArray random keys used to reset the environments where the first dimension is the number of desired environments. required Returns: Type Description state State object corresponding to the new state of the environments, timestep: TimeStep object corresponding the first timesteps returned by the environments, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Run one timestep of the environment's dynamics. The first dimension of the state will dictate the number of concurrent environments. See VmapWrapper.reset for more details on how to get a state of concurrent environments. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environments. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take. required Returns: Type Description state State object corresponding to the next states of the environments, timestep: TimeStep object corresponding the timesteps returned by the environments, render ( self , state : ~ State ) -> Any # Render the first environment state of the given batch. The remaining elements of the batched state are ignored. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required AutoResetWrapper ( Wrapper ) # Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper ), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead. step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Step the environment, with automatic resetting if the episode terminates. JumanjiToGymWrapper ( Env ) # A wrapper that converts a Jumanji Environment to one that follows the gym.Env API. unwrapped : Environment property readonly # Returns the base non-wrapped environment. Returns: Type Description Env The base non-wrapped gym.Env instance __init__ ( self , env : Environment , seed : int = 0 , backend : Optional [ str ] = None ) special # Create the Gym environment. Parameters: Name Type Description Default env Environment Environment to wrap to a gym.Env . required seed int the seed that is used to initialize the environment's PRNG. 0 backend Optional[str] the XLA backend. None reset ( self , * , seed : Optional [ int ] = None , return_info : bool = False , options : Optional [ dict ] = None ) -> Union [ Any , Tuple [ Any , Union [ Any ]]] # Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence. Returns: Type Description obs an element of the environment's observation_space. info (optional): contains supplementary information such as metrics. step ( self , action : ndarray ) -> Tuple [ Any , float , bool , Optional [ Any ]] # Updates the environment according to the action and returns an Observation . Parameters: Name Type Description Default action ndarray A NumPy array representing the action provided by the agent. required Returns: Type Description observation an element of the environment's observation_space. reward: the amount of reward returned as a result of taking the action. terminated: whether a terminal state is reached. info: contains supplementary information such as metrics. seed ( self , seed : int = 0 ) -> None # Function which sets the seed for the environment's random number generator(s). Parameters: Name Type Description Default seed int the seed value for the random number generator(s). 0 render ( self , mode : str = 'human' ) -> Any # Renders the environment. Parameters: Name Type Description Default mode str currently not used since Jumanji does not currently support modes. 'human' close ( self ) -> None # Closes the environment, important for rendering where pygame is imported. jumanji_to_gym_obs ( observation : ~ Observation ) -> Any # Convert a Jumanji observation into a gym observation. Parameters: Name Type Description Default observation ~Observation JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods implemented. required Returns: Type Description Any Numpy array or nested dictionary of numpy arrays.","title":"Wrappers"},{"location":"api/wrappers/#jumanji.wrappers","text":"","title":"wrappers"},{"location":"api/wrappers/#jumanji.wrappers.Wrapper","text":"Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72","title":"Wrapper"},{"location":"api/wrappers/#jumanji.wrappers.JumanjiToDMEnvWrapper","text":"A wrapper that converts Environment to dm_env.Environment.","title":"JumanjiToDMEnvWrapper"},{"location":"api/wrappers/#jumanji.wrappers.MultiToSingleWrapper","text":"A wrapper that converts a multi-agent Environment to a single-agent Environment.","title":"MultiToSingleWrapper"},{"location":"api/wrappers/#jumanji.wrappers.VmapWrapper","text":"Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: - observation_spec - action_spec - reward_spec - discount_spec","title":"VmapWrapper"},{"location":"api/wrappers/#jumanji.wrappers.AutoResetWrapper","text":"Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper ), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead.","title":"AutoResetWrapper"},{"location":"api/wrappers/#jumanji.wrappers.JumanjiToGymWrapper","text":"A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.","title":"JumanjiToGymWrapper"},{"location":"api/wrappers/#jumanji.wrappers.jumanji_to_gym_obs","text":"Convert a Jumanji observation into a gym observation. Parameters: Name Type Description Default observation ~Observation JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods implemented. required Returns: Type Description Any Numpy array or nested dictionary of numpy arrays.","title":"jumanji_to_gym_obs()"},{"location":"api/environments/bin_pack/","text":"BinPack ( Environment ) # Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this problem. An EMS is a 3D-rectangular space that lives inside the container and has the following Properties It does not intersect any items, and it is not fully included into any other EMSs. It is defined by 2 3D-points, hence 6 coordinates (x1, x2, y1, y2, z1, z2), the first point corresponding to its bottom-left location while the second defining its top-right corner. observation: Observation ems: EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,), coordinates of all EMSs at the current timestep. ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid. items: Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,), characteristics of all items for this instance. items_mask: jax array (bool) of shape (max_num_items,) indicates the items that are valid. items_placed: jax array (bool) of shape (max_num_items,) indicates the items that have been placed so far. action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) mask of the joint action space: True if the action (ems_id, item_id) is valid. action: MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). ems_id: int between 0 and obs_num_ems - 1 (included). item_id: int between 0 and max_num_items - 1 (included). reward: jax array (float) of shape (), could be either: dense: increase in volume utilization of the container due to packing the chosen item. sparse: volume utilization of the container at the end of the episode. episode termination: if no action can be performed, i.e. no items fit in any EMSs, or all items have been packed. if an invalid action is taken, i.e. an item that does not fit in an EMS or one that is already packed. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). 1 2 3 4 5 6 7 8 from jumanji.environments import BinPack env = BinPack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . bin_pack . generator . Generator ] = None , obs_num_ems : int = 40 , reward_fn : Optional [ jumanji . environments . packing . bin_pack . reward . RewardFn ] = None , normalize_dimensions : bool = True , debug : bool = False , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . bin_pack . types . State ]] = None ) special # Instantiates a BinPack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.bin_pack.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator , ToyGenerator , CSVGenerator ]. Defaults to RandomGenerator that generates up to 20 items maximum and that can handle 40 EMSs. None obs_num_ems int number of EMSs (possible spaces in which to place an item) to show to the agent. If obs_num_ems is smaller than generator.max_num_ems , the first obs_num_ems largest EMSs (in terms of volume) will be returned in the observation. The good number heavily depends on the number of items (given by the instance generator). Default to 40 EMSs observable. 40 reward_fn Optional[jumanji.environments.packing.bin_pack.reward.RewardFn] compute the reward based on the current state, the chosen action, the next state, whether the transition is valid and if it is terminal. Implemented options are [ DenseReward , SparseReward ]. In each case, the total return at the end of an episode is the volume utilization of the container. Defaults to DenseReward . None normalize_dimensions bool if True, the observation is normalized (float) along each dimension into a unit cubic container. If False, the observation is returned in millimeters, i.e. integers (for both items and EMSs). Default to True. True debug bool if True, will add to timestep.extras an invalid_ems_from_env field that checks if an invalid EMS was created by the environment, which should not happen. Computing this metric slows down the environment. Default to False. False viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.bin_pack.types.State]] Viewer used for rendering. Defaults to BinPackViewer with \"human\" render mode. None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . bin_pack . types . Observation ] # Specifications of the observation of the BinPack environment. Returns: Type Description Spec for the `Observation` whose fields are ems: if normalize_dimensions: tree of BoundedArray (float) of shape (obs_num_ems,). else: tree of BoundedArray (int32) of shape (obs_num_ems,). ems_mask: BoundedArray (bool) of shape (obs_num_ems,). items: if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,). else: tree of BoundedArray (int32) of shape (max_num_items,). items_mask: BoundedArray (bool) of shape (max_num_items,). items_placed: BoundedArray (bool) of shape (max_num_items,). action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items). action_spec ( self ) -> MultiDiscreteArray # Specifications of the action expected by the BinPack environment. Returns: Type Description MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). - ems_id int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included). reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . bin_pack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . bin_pack . types . Observation ]] # Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. Also contains the following metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of active EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . bin_pack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . bin_pack . types . Observation ]] # Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates. Parameters: Name Type Description Default state State State object containing the data of the current instance. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given item at the location of the given EMS. If the action is not valid, the flag invalid_action will be set to True in timestep.extras and the episode terminates. required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. Also contains metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"BinPack"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack","text":"Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this problem. An EMS is a 3D-rectangular space that lives inside the container and has the following Properties It does not intersect any items, and it is not fully included into any other EMSs. It is defined by 2 3D-points, hence 6 coordinates (x1, x2, y1, y2, z1, z2), the first point corresponding to its bottom-left location while the second defining its top-right corner. observation: Observation ems: EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,), coordinates of all EMSs at the current timestep. ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid. items: Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,), characteristics of all items for this instance. items_mask: jax array (bool) of shape (max_num_items,) indicates the items that are valid. items_placed: jax array (bool) of shape (max_num_items,) indicates the items that have been placed so far. action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) mask of the joint action space: True if the action (ems_id, item_id) is valid. action: MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). ems_id: int between 0 and obs_num_ems - 1 (included). item_id: int between 0 and max_num_items - 1 (included). reward: jax array (float) of shape (), could be either: dense: increase in volume utilization of the container due to packing the chosen item. sparse: volume utilization of the container at the end of the episode. episode termination: if no action can be performed, i.e. no items fit in any EMSs, or all items have been packed. if an invalid action is taken, i.e. an item that does not fit in an EMS or one that is already packed. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). 1 2 3 4 5 6 7 8 from jumanji.environments import BinPack env = BinPack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"BinPack"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.__init__","text":"Instantiates a BinPack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.bin_pack.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator , ToyGenerator , CSVGenerator ]. Defaults to RandomGenerator that generates up to 20 items maximum and that can handle 40 EMSs. None obs_num_ems int number of EMSs (possible spaces in which to place an item) to show to the agent. If obs_num_ems is smaller than generator.max_num_ems , the first obs_num_ems largest EMSs (in terms of volume) will be returned in the observation. The good number heavily depends on the number of items (given by the instance generator). Default to 40 EMSs observable. 40 reward_fn Optional[jumanji.environments.packing.bin_pack.reward.RewardFn] compute the reward based on the current state, the chosen action, the next state, whether the transition is valid and if it is terminal. Implemented options are [ DenseReward , SparseReward ]. In each case, the total return at the end of an episode is the volume utilization of the container. Defaults to DenseReward . None normalize_dimensions bool if True, the observation is normalized (float) along each dimension into a unit cubic container. If False, the observation is returned in millimeters, i.e. integers (for both items and EMSs). Default to True. True debug bool if True, will add to timestep.extras an invalid_ems_from_env field that checks if an invalid EMS was created by the environment, which should not happen. Computing this metric slows down the environment. Default to False. False viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.bin_pack.types.State]] Viewer used for rendering. Defaults to BinPackViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.observation_spec","text":"Specifications of the observation of the BinPack environment. Returns: Type Description Spec for the `Observation` whose fields are ems: if normalize_dimensions: tree of BoundedArray (float) of shape (obs_num_ems,). else: tree of BoundedArray (int32) of shape (obs_num_ems,). ems_mask: BoundedArray (bool) of shape (obs_num_ems,). items: if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,). else: tree of BoundedArray (int32) of shape (max_num_items,). items_mask: BoundedArray (bool) of shape (max_num_items,). items_placed: BoundedArray (bool) of shape (max_num_items,). action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items).","title":"observation_spec()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.action_spec","text":"Specifications of the action expected by the BinPack environment. Returns: Type Description MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). - ems_id int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included).","title":"action_spec()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.reset","text":"Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. Also contains the following metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of active EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.","title":"reset()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.step","text":"Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates. Parameters: Name Type Description Default state State State object containing the data of the current instance. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given item at the location of the given EMS. If the action is not valid, the flag invalid_action will be set to True in timestep.extras and the episode terminates. required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. Also contains metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.","title":"step()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current dynamics of the environment. required","title":"render()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.close","text":"Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"close()"},{"location":"api/environments/cleaner/","text":"Cleaner ( Environment ) # A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: (int32) the number of step since the beginning of the episode. action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left) reward: jax array (float) of shape () +1 every time a tile is cleaned and a configurable penalty (-0.5 by default) for each timestep. episode termination: All tiles are clean. The number of steps is greater than the limit. An invalid action is selected for any of the agents. state: State grid: jax array (int32) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: jax array (int32) of shape () the number of steps since the beginning of the episode. key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic. 1 2 3 4 5 6 7 8 from jumanji.environments import Cleaner env = Cleaner () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . cleaner . generator . Generator ] = None , time_limit : Optional [ int ] = None , penalty_per_timestep : float = 0.5 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . cleaner . types . State ]] = None ) -> None special # Instantiates a Cleaner environment. Parameters: Name Type Description Default num_agents number of agents. Defaults to 3. required time_limit Optional[int] max number of steps in an episode. Defaults to num_rows * num_cols . None generator Optional[jumanji.environments.routing.cleaner.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 , num_cols=10 and num_agents=3 . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]] Viewer used for rendering. Defaults to CleanerViewer with \"human\" render mode. None penalty_per_timestep float the penalty returned at each timestep in the reward. 0.5 __repr__ ( self ) -> str special # observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . cleaner . types . Observation ] # Specification of the observation of the Cleaner environment. Returns: Type Description Spec for the `Observation`, consisting of the fields grid: BoundedArray (int32) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive). agent_locations_spec: BoundedArray (int32) of shape (num_agents, 2). Maximum value for the first column is num_rows, and maximum value for the second is num_cols. action_mask: BoundedArray (bool) of shape (num_agent, 4). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Specification of the action for the Cleaner environment. Returns: Type Description action_spec a specs.MultiDiscreteArray spec. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . cleaner . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cleaner . types . Observation ]] # Reset the environment to its initial state. All the tiles except upper left are dirty, and the agents start in the upper left corner of the grid. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . cleaner . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cleaner . types . Observation ]] # Run one timestep of the environment's dynamics. If an action is invalid, the corresponding agent does not move and the episode terminates. Parameters: Name Type Description Default state State current environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Jax array of shape (num_agents,). Each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left). required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required animate ( self , states : Sequence [ jumanji . environments . routing . cleaner . types . State ], interval : int = 200 , save_path : Optional [ str ] = None ) -> FuncAnimation # Creates an animated gif of the Cleaner environment based on the sequence of states. Parameters: Name Type Description Default states Sequence[jumanji.environments.routing.cleaner.types.State] sequence of environment states corresponding to consecutive timesteps. required interval int delay between frames in milliseconds, default to 200. 200 save_path Optional[str] the path where the animation file should be saved. If it is None, the plot will not be saved. None Returns: Type Description animation.FuncAnimation the animation object that was created. close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"Cleaner"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner","text":"A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: (int32) the number of step since the beginning of the episode. action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left) reward: jax array (float) of shape () +1 every time a tile is cleaned and a configurable penalty (-0.5 by default) for each timestep. episode termination: All tiles are clean. The number of steps is greater than the limit. An invalid action is selected for any of the agents. state: State grid: jax array (int32) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: jax array (int32) of shape () the number of steps since the beginning of the episode. key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic. 1 2 3 4 5 6 7 8 from jumanji.environments import Cleaner env = Cleaner () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Cleaner"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.__init__","text":"Instantiates a Cleaner environment. Parameters: Name Type Description Default num_agents number of agents. Defaults to 3. required time_limit Optional[int] max number of steps in an episode. Defaults to num_rows * num_cols . None generator Optional[jumanji.environments.routing.cleaner.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 , num_cols=10 and num_agents=3 . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]] Viewer used for rendering. Defaults to CleanerViewer with \"human\" render mode. None penalty_per_timestep float the penalty returned at each timestep in the reward. 0.5","title":"__init__()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.__repr__","text":"","title":"__repr__()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.observation_spec","text":"Specification of the observation of the Cleaner environment. Returns: Type Description Spec for the `Observation`, consisting of the fields grid: BoundedArray (int32) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive). agent_locations_spec: BoundedArray (int32) of shape (num_agents, 2). Maximum value for the first column is num_rows, and maximum value for the second is num_cols. action_mask: BoundedArray (bool) of shape (num_agent, 4). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.action_spec","text":"Specification of the action for the Cleaner environment. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.reset","text":"Reset the environment to its initial state. All the tiles except upper left are dirty, and the agents start in the upper left corner of the grid. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.","title":"reset()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.step","text":"Run one timestep of the environment's dynamics. If an action is invalid, the corresponding agent does not move and the episode terminates. Parameters: Name Type Description Default state State current environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Jax array of shape (num_agents,). Each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left). required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required","title":"render()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.animate","text":"Creates an animated gif of the Cleaner environment based on the sequence of states. Parameters: Name Type Description Default states Sequence[jumanji.environments.routing.cleaner.types.State] sequence of environment states corresponding to consecutive timesteps. required interval int delay between frames in milliseconds, default to 200. 200 save_path Optional[str] the path where the animation file should be saved. If it is None, the plot will not be saved. None Returns: Type Description animation.FuncAnimation the animation object that was created.","title":"animate()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.close","text":"Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"close()"},{"location":"api/environments/connector/","text":"Connector ( Environment ) # The Connector environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets. observation - Observation action mask: jax array (bool) of shape (num_agents, 5). step_count: jax array (int32) of shape () the current episode step. grid: jax array (int32) of shape (grid_size, grid_size) with 2 agents you might have a grid like this: 4 0 1 5 0 1 6 3 2 which means agent 1 has moved from the top right of the grid down and is currently in the bottom right corner and is aiming to get to the middle bottom cell. Agent 2 started in the top left and moved down once towards its target in the bottom left. action: jax array (int32) of shape (num_agents,): can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left]. each value in the array corresponds to an agent's action. reward: jax array (float) of shape (): dense: reward is 1 for each successful connection on that step. Additionally, each pair of points that have not connected receives a penalty reward of -0.03. episode termination: all agents either can't move (no available actions) or have connected to their target. the time limit is reached. state: State: key: jax PRNG key used to randomly spawn agents and targets. grid: jax array (int32) of shape (grid_size, grid_size) giving the observation. step_count: jax array (int32) of shape () number of steps elapsed in the current episode. 1 2 3 4 5 6 7 8 from jumanji.environments import Connector env = Connector () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . connector . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . connector . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the connector grid. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the initial environment timestep. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . connector . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . connector . types . Observation ]] # Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take for each agent. - 0 no op - 1 move up - 2 move right - 3 move down - 4 move left required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . connector . types . Observation ] # Specifications of the observation of the Connector environment. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (int32) of shape (grid_size, grid_size). action_mask: BoundedArray (bool) of shape (num_agents, 5). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec for the Connector environment. 5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,). Returns: Type Description observation_spec MultiDiscreteArray of shape (num_agents,).","title":"Connector"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector","text":"The Connector environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets. observation - Observation action mask: jax array (bool) of shape (num_agents, 5). step_count: jax array (int32) of shape () the current episode step. grid: jax array (int32) of shape (grid_size, grid_size) with 2 agents you might have a grid like this: 4 0 1 5 0 1 6 3 2 which means agent 1 has moved from the top right of the grid down and is currently in the bottom right corner and is aiming to get to the middle bottom cell. Agent 2 started in the top left and moved down once towards its target in the bottom left. action: jax array (int32) of shape (num_agents,): can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left]. each value in the array corresponds to an agent's action. reward: jax array (float) of shape (): dense: reward is 1 for each successful connection on that step. Additionally, each pair of points that have not connected receives a penalty reward of -0.03. episode termination: all agents either can't move (no available actions) or have connected to their target. the time limit is reached. state: State: key: jax PRNG key used to randomly spawn agents and targets. grid: jax array (int32) of shape (grid_size, grid_size) giving the observation. step_count: jax array (int32) of shape () number of steps elapsed in the current episode. 1 2 3 4 5 6 7 8 from jumanji.environments import Connector env = Connector () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Connector"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the connector grid. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the initial environment timestep.","title":"reset()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.step","text":"Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take for each agent. - 0 no op - 1 move up - 2 move right - 3 move down - 4 move left required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment.","title":"step()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required","title":"render()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.observation_spec","text":"Specifications of the observation of the Connector environment. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (int32) of shape (grid_size, grid_size). action_mask: BoundedArray (bool) of shape (num_agents, 5). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.action_spec","text":"Returns the action spec for the Connector environment. 5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,). Returns: Type Description observation_spec MultiDiscreteArray of shape (num_agents,).","title":"action_spec()"},{"location":"api/environments/cvrp/","text":"CVRP ( Environment ) # Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (float) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). unvisited_nodes: jax array (bool) of shape (num_nodes + 1,) indicates nodes that remain to be visited. position: jax array (int32) of shape () the index of the last visited node. trajectory: jax array (int32) of shape (2 * num_nodes,) array of node indices defining the route (set to DEPOT_IDX if not filled yet). capacity: jax array (float) of shape () the current capacity of the vehicle. action_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> invalid/valid action). action: jax array (int32) of shape () [0, ..., num_nodes] -> node to visit. 0 corresponds to visiting the depot. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. episode termination: if no action can be performed, i.e. all nodes have been visited. if an invalid action is taken, i.e. a previously visited city other than the depot is chosen. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). [1] Toth P., Vigo D. (2014). \"Vehicle routing: problems, methods, and applications\". 1 2 3 4 5 6 7 8 from jumanji.environments import CVRP env = CVRP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . cvrp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . cvrp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . cvrp . types . State ]] = None ) special # Instantiates a CVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates CVRP instances with 20 cities sampled from a uniform distribution, a maximum vehicle capacity of 30, and a maximum city demand of 10. None reward_fn Optional[jumanji.environments.routing.cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cvrp.types.State]] Viewer used for rendering. Defaults to CVRPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cvrp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cvrp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] jax array (int32) of shape () containing the index of the next node to visit. required Returns: Type Description state, timestep next state of the environment and timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . cvrp . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_nodes + 1, 2). demands: BoundedArray (float) of shape (num_nodes + 1,). unvisited_nodes: BoundedArray (bool) of shape (num_nodes + 1,). position: DiscreteArray (num_values = num_nodes + 1) of shape (). trajectory: BoundedArray (int32) of shape (2 * num_nodes,). capacity: BoundedArray (float) of shape (). action_mask: BoundedArray (bool) of shape (num_nodes + 1,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"CVRP"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP","text":"Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (float) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). unvisited_nodes: jax array (bool) of shape (num_nodes + 1,) indicates nodes that remain to be visited. position: jax array (int32) of shape () the index of the last visited node. trajectory: jax array (int32) of shape (2 * num_nodes,) array of node indices defining the route (set to DEPOT_IDX if not filled yet). capacity: jax array (float) of shape () the current capacity of the vehicle. action_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> invalid/valid action). action: jax array (int32) of shape () [0, ..., num_nodes] -> node to visit. 0 corresponds to visiting the depot. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. episode termination: if no action can be performed, i.e. all nodes have been visited. if an invalid action is taken, i.e. a previously visited city other than the depot is chosen. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). [1] Toth P., Vigo D. (2014). \"Vehicle routing: problems, methods, and applications\". 1 2 3 4 5 6 7 8 from jumanji.environments import CVRP env = CVRP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"CVRP"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.__init__","text":"Instantiates a CVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates CVRP instances with 20 cities sampled from a uniform distribution, a maximum vehicle capacity of 30, and a maximum city demand of 10. None reward_fn Optional[jumanji.environments.routing.cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cvrp.types.State]] Viewer used for rendering. Defaults to CVRPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] jax array (int32) of shape () containing the index of the next node to visit. required Returns: Type Description state, timestep next state of the environment and timestep to be observed.","title":"step()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_nodes + 1, 2). demands: BoundedArray (float) of shape (num_nodes + 1,). unvisited_nodes: BoundedArray (bool) of shape (num_nodes + 1,). position: DiscreteArray (num_values = num_nodes + 1) of shape (). trajectory: BoundedArray (int32) of shape (2 * num_nodes,). capacity: BoundedArray (float) of shape (). action_mask: BoundedArray (bool) of shape (num_nodes + 1,).","title":"observation_spec()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/game_2048/","text":"Game2048 ( Environment ) # Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile with twice the value, until the player at least creates a tile with the value 2048 to consider it a win. observation: Observation board: jax array (int32) of shape (board_size, board_size) the current state of the board. An empty tile is represented by zero whereas a non-empty tile is an exponent of 2, e.g. 1, 2, 3, 4, ... (corresponding to 2, 4, 8, 16, ...). action_mask: jax array (bool) of shape (4,) indicates which actions are valid in the current state of the environment. action: jax array (int32) of shape (). Is in [0, 1, 2, 3] representing the actions up, right, down, and left, respectively. reward: jax array (float) of shape (). The reward is 0 except when the player combines tiles to create a new tile with twice the value. In this case, the reward is the value of the new tile. episode termination: if no more valid moves exist (this can happen when the board is full). state: State board: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. action_mask: same as observation. score: jax array (int32) of shape (), the sum of all tile values on the board. key: jax array (uint32) of shape (2,) random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Game2048 env = Game2048 () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , board_size : int = 4 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . game_2048 . types . State ]] = None ) -> None special # Initialize the 2048 game. Parameters: Name Type Description Default board_size int size of the board. Defaults to 4. 4 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.game_2048.types.State]] Viewer used for rendering. Defaults to Game2048Viewer . None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . game_2048 . types . Observation ] # Specifications of the observation of the Game2048 environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: Array (jnp.int32) of shape (board_size, board_size). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec DiscreteArray spec object. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . game_2048 . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . game_2048 . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random number generator key. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . game_2048 . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . game_2048 . types . Observation ]] # Updates the environment state after the agent takes an action. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"Game2048"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048","text":"Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile with twice the value, until the player at least creates a tile with the value 2048 to consider it a win. observation: Observation board: jax array (int32) of shape (board_size, board_size) the current state of the board. An empty tile is represented by zero whereas a non-empty tile is an exponent of 2, e.g. 1, 2, 3, 4, ... (corresponding to 2, 4, 8, 16, ...). action_mask: jax array (bool) of shape (4,) indicates which actions are valid in the current state of the environment. action: jax array (int32) of shape (). Is in [0, 1, 2, 3] representing the actions up, right, down, and left, respectively. reward: jax array (float) of shape (). The reward is 0 except when the player combines tiles to create a new tile with twice the value. In this case, the reward is the value of the new tile. episode termination: if no more valid moves exist (this can happen when the board is full). state: State board: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. action_mask: same as observation. score: jax array (int32) of shape (), the sum of all tile values on the board. key: jax array (uint32) of shape (2,) random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Game2048 env = Game2048 () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Game2048"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.__init__","text":"Initialize the 2048 game. Parameters: Name Type Description Default board_size int size of the board. Defaults to 4. 4 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.game_2048.types.State]] Viewer used for rendering. Defaults to Game2048Viewer . None","title":"__init__()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.observation_spec","text":"Specifications of the observation of the Game2048 environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: Array (jnp.int32) of shape (board_size, board_size). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.action_spec","text":"Returns the action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec DiscreteArray spec object.","title":"action_spec()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random number generator key. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.step","text":"Updates the environment state after the agent takes an action. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"step()"},{"location":"api/environments/graph_coloring/","text":"GraphColoring ( Environment ) # Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. observation: Observation adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. current_node_index: integer representing the current node being colored. action: int, the color to be assigned to the current node (0 to num_nodes - 1) reward: float, a sparse reward is provided at the end of the episode. Equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors. episode termination: if all nodes have been assigned a color or if an invalid action is taken. state: State adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned. current_node_index: jax array (int) with shape (), index of the current node. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import GraphColoring env = GraphColoring () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . graph_coloring . generator . Generator ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . graph_coloring . types . State ]] = None ) special # Instantiate a GraphColoring environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.graph_coloring.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator which generates graphs with 20 num_nodes and edge_probability equal to 0.8. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] environment viewer for rendering. Defaults to GraphColoringViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . graph_coloring . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . graph_coloring . types . Observation ]] # Resets the environment to an initial state. Returns: Type Description Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] The initial state and timestep. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . graph_coloring . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . graph_coloring . types . Observation ]] # Updates the environment state after the agent takes an action. Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . graph_coloring . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state. colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node. current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node. action_spec ( self ) -> DiscreteArray # Specification of the action for the GraphColoring environment. Returns: Type Description action_spec specs.DiscreteArray object","title":"GraphColoring"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring","text":"Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. observation: Observation adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. current_node_index: integer representing the current node being colored. action: int, the color to be assigned to the current node (0 to num_nodes - 1) reward: float, a sparse reward is provided at the end of the episode. Equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors. episode termination: if all nodes have been assigned a color or if an invalid action is taken. state: State adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned. current_node_index: jax array (int) with shape (), index of the current node. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import GraphColoring env = GraphColoring () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"GraphColoring"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.__init__","text":"Instantiate a GraphColoring environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.graph_coloring.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator which generates graphs with 20 num_nodes and edge_probability equal to 0.8. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] environment viewer for rendering. Defaults to GraphColoringViewer . None","title":"__init__()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.reset","text":"Resets the environment to an initial state. Returns: Type Description Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] The initial state and timestep.","title":"reset()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.step","text":"Updates the environment state after the agent takes an action. Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"step()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state. colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node. current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node.","title":"observation_spec()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.action_spec","text":"Specification of the action for the GraphColoring environment. Returns: Type Description action_spec specs.DiscreteArray object","title":"action_spec()"},{"location":"api/environments/job_shop/","text":"JobShop ( Environment ) # The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given num_jobs jobs, each consisting of at most max_num_ops ops, which need to be processed on num_machines machines. Each operation (op) has a specific machine that it needs to be processed on and a duration (which must be less than or equal to max_duration_op ). The goal is to minimise the total length of the schedule, also known as the makespan. [1] https://developers.google.com/optimization/scheduling/job_shop. observation: Observation ops_machine_ids: jax array (int32) of (num_jobs, max_num_ops) id of the machine each operation must be processed on. ops_durations: jax array (int32) of (num_jobs, max_num_ops) processing time of each operation. ops_mask: jax array (bool) of (num_jobs, max_num_ops) indicating which operations have yet to be scheduled. machines_job_ids: jax array (int32) of shape (num_machines,) id of the job (or no-op) that each machine is processing. machines_remaining_times: jax array (int32) of shape (num_machines,) specifying, for each machine, the number of time steps until available. action_mask: jax array (bool) of shape (num_machines, num_jobs + 1) indicates which job(s) (or no-op) can legally be scheduled on each machine. action: jax array (int32) of shape (num_machines,). reward: jax array (float) of shape (). A reward of -1 is given each time step. If all machines are simultaneously idle or the agent selects an invalid action, the agent is given a large penalty of -num_jobs * max_num_ops * max_op_duration which is an upper bound on the makespan. episode termination: Finished schedule: all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. state: State ops_machine_ids: same as observation. ops_durations: same as observation. ops_mask: same as observation. machines_job_ids: same as observation. machines_remaining_times: same as observation. action_mask: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. scheduled_times: jax array (int32) of shape (num_jobs, max_num_ops), specifying the timestep at which every op (scheduled so far) was scheduled. 1 2 3 4 5 6 7 8 from jumanji.environments import JobShop env = JobShop () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . job_shop . generator . Generator ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . job_shop . types . State ]] = None ) special # Instantiate a JobShop environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.job_shop.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are ['ToyGenerator', 'RandomGenerator']. Defaults to RandomGenerator with 20 jobs, 10 machines, up to 8 ops for any given job, and a max operation duration of 6. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.job_shop.types.State]] Viewer used for rendering. Defaults to JobShopViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . job_shop . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . job_shop . types . Observation ]] # Resets the environment by creating a new problem instance and initialising the state and timestep. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state the environment state after the reset. timestep: the first timestep returned by the environment after the reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . job_shop . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . job_shop . types . Observation ]] # Updates the status of all machines, the status of the operations, and increments the time step. It updates the environment state and the timestep (which contains the new observation). It calculates the reward based on the three terminal conditions: - The action provided by the agent is invalid. - The schedule has finished. - All machines do a no-op that leads to all machines being simultaneously idle. Parameters: Name Type Description Default state State the environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action to take. required Returns: Type Description state the updated environment state. timestep: the updated timestep. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . job_shop . types . Observation ] # Specifications of the observation of the JobShop environment. Returns: Type Description Spec containing the specifications for all the `Observation` fields ops_machine_ids: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_durations: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_mask: BoundedArray (bool) of shape (num_jobs, max_num_ops). machines_job_ids: BoundedArray (int32) of shape (num_machines,). machines_remaining_times: BoundedArray (int32) of shape (num_machines,). action_mask: BoundedArray (bool) of shape (num_machines, num_jobs + 1). action_spec ( self ) -> MultiDiscreteArray # Specifications of the action in the JobShop environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"JobShop"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop","text":"The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given num_jobs jobs, each consisting of at most max_num_ops ops, which need to be processed on num_machines machines. Each operation (op) has a specific machine that it needs to be processed on and a duration (which must be less than or equal to max_duration_op ). The goal is to minimise the total length of the schedule, also known as the makespan. [1] https://developers.google.com/optimization/scheduling/job_shop. observation: Observation ops_machine_ids: jax array (int32) of (num_jobs, max_num_ops) id of the machine each operation must be processed on. ops_durations: jax array (int32) of (num_jobs, max_num_ops) processing time of each operation. ops_mask: jax array (bool) of (num_jobs, max_num_ops) indicating which operations have yet to be scheduled. machines_job_ids: jax array (int32) of shape (num_machines,) id of the job (or no-op) that each machine is processing. machines_remaining_times: jax array (int32) of shape (num_machines,) specifying, for each machine, the number of time steps until available. action_mask: jax array (bool) of shape (num_machines, num_jobs + 1) indicates which job(s) (or no-op) can legally be scheduled on each machine. action: jax array (int32) of shape (num_machines,). reward: jax array (float) of shape (). A reward of -1 is given each time step. If all machines are simultaneously idle or the agent selects an invalid action, the agent is given a large penalty of -num_jobs * max_num_ops * max_op_duration which is an upper bound on the makespan. episode termination: Finished schedule: all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. state: State ops_machine_ids: same as observation. ops_durations: same as observation. ops_mask: same as observation. machines_job_ids: same as observation. machines_remaining_times: same as observation. action_mask: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. scheduled_times: jax array (int32) of shape (num_jobs, max_num_ops), specifying the timestep at which every op (scheduled so far) was scheduled. 1 2 3 4 5 6 7 8 from jumanji.environments import JobShop env = JobShop () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"JobShop"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.__init__","text":"Instantiate a JobShop environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.job_shop.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are ['ToyGenerator', 'RandomGenerator']. Defaults to RandomGenerator with 20 jobs, 10 machines, up to 8 ops for any given job, and a max operation duration of 6. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.job_shop.types.State]] Viewer used for rendering. Defaults to JobShopViewer . None","title":"__init__()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.reset","text":"Resets the environment by creating a new problem instance and initialising the state and timestep. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state the environment state after the reset. timestep: the first timestep returned by the environment after the reset.","title":"reset()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.step","text":"Updates the status of all machines, the status of the operations, and increments the time step. It updates the environment state and the timestep (which contains the new observation). It calculates the reward based on the three terminal conditions: - The action provided by the agent is invalid. - The schedule has finished. - All machines do a no-op that leads to all machines being simultaneously idle. Parameters: Name Type Description Default state State the environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action to take. required Returns: Type Description state the updated environment state. timestep: the updated timestep.","title":"step()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.observation_spec","text":"Specifications of the observation of the JobShop environment. Returns: Type Description Spec containing the specifications for all the `Observation` fields ops_machine_ids: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_durations: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_mask: BoundedArray (bool) of shape (num_jobs, max_num_ops). machines_job_ids: BoundedArray (int32) of shape (num_machines,). machines_remaining_times: BoundedArray (int32) of shape (num_machines,). action_mask: BoundedArray (bool) of shape (num_machines, num_jobs + 1).","title":"observation_spec()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.action_spec","text":"Specifications of the action in the JobShop environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/knapsack/","text":"Knapsack ( Environment ) # Knapsack environment as described in [1]. observation: Observation weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. action_mask: jax array (bool) of shape (num_items,) binary mask denoting which items can be packed into the knapsack. action: jax array (int32) of shape () [0, ..., num_items - 1] -> item to pack. reward: jax array (float) of shape (), could be either: dense: the value of the item to pack at the current timestep. sparse: the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. episode termination: if no action can be performed, i.e. all items are packed or each remaining item's weight is larger than the bag capacity. if an invalid action is taken, i.e. the chosen item is already packed or has a weight larger than the bag capacity. state: State weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. remaining_budget: jax array (float) the budget currently remaining. [1] https://arxiv.org/abs/2010.16011 1 2 3 4 5 6 7 8 from jumanji.environments import Knapsack env = Knapsack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . knapsack . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . packing . knapsack . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . knapsack . types . State ]] = None ) special # Instantiates a Knapsack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.knapsack.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'RandomGenerator' which samples Knapsack instances with 50 items and a total budget of 12.5. None reward_fn Optional[jumanji.environments.packing.knapsack.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]] Viewer used for rendering. Defaults to KnapsackViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . knapsack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . knapsack . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the weights and values of the items. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . packing . knapsack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . knapsack . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] index of next item to take. required Returns: Type Description state next state of the environment. timestep: the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . knapsack . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for each field in the Observation weights: BoundedArray (float) of shape (num_items,). values: BoundedArray (float) of shape (num_items,). packed_items: BoundedArray (bool) of shape (num_items,). action_mask: BoundedArray (bool) of shape (num_items,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"Knapsack"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack","text":"Knapsack environment as described in [1]. observation: Observation weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. action_mask: jax array (bool) of shape (num_items,) binary mask denoting which items can be packed into the knapsack. action: jax array (int32) of shape () [0, ..., num_items - 1] -> item to pack. reward: jax array (float) of shape (), could be either: dense: the value of the item to pack at the current timestep. sparse: the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. episode termination: if no action can be performed, i.e. all items are packed or each remaining item's weight is larger than the bag capacity. if an invalid action is taken, i.e. the chosen item is already packed or has a weight larger than the bag capacity. state: State weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. remaining_budget: jax array (float) the budget currently remaining. [1] https://arxiv.org/abs/2010.16011 1 2 3 4 5 6 7 8 from jumanji.environments import Knapsack env = Knapsack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Knapsack"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.__init__","text":"Instantiates a Knapsack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.knapsack.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'RandomGenerator' which samples Knapsack instances with 50 items and a total budget of 12.5. None reward_fn Optional[jumanji.environments.packing.knapsack.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]] Viewer used for rendering. Defaults to KnapsackViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the weights and values of the items. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] index of next item to take. required Returns: Type Description state next state of the environment. timestep: the timestep to be observed.","title":"step()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for each field in the Observation weights: BoundedArray (float) of shape (num_items,). values: BoundedArray (float) of shape (num_items,). packed_items: BoundedArray (bool) of shape (num_items,). action_mask: BoundedArray (bool) of shape (num_items,).","title":"observation_spec()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/macvrp/","text":"MultiCVRP ( Environment ) # Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). reward: jax array (float32) this global reward is provided to each agent. The reward is equal to the negative sum of the distances between consecutive nodes at the end of the episode over all agents. All time penalties are also added to the reward. observation and state: the observation and state variable types are defined in: jumanji/environments/routing/multi_cvrp/types.py [1] Zhang et al. (2020). \"Multi-Vehicle Routing Problems with Soft Time Windows: A Multi-Agent Reinforcement Learning Approach\". __init__ ( self , generator : Optional [ jumanji . environments . routing . multi_cvrp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . multi_cvrp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer ] = None ) special # Instantiates a MultiCVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.multi_cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ UniformRandomGenerator ]. Defaults to UniformRandomGenerator with num_customers=20 and num_vehicles=2 . None reward_fn Optional[jumanji.environments.routing.multi_cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state and whether the environment is done. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer] Viewer used for rendering. Defaults to MultiCVRPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . multi_cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . multi_cvrp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the start node. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . multi_cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . multi_cvrp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next nodes to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . multi_cvrp . types . Observation ] # Returns the observation spec. Returns: Type Description observation_spec a Tuple containing the spec for each of the constituent fields of an observation. action_spec ( self ) -> BoundedArray # Returns the action spec. Returns: Type Description action_spec a specs.BoundedArray spec.","title":"Macvrp"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP","text":"Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). reward: jax array (float32) this global reward is provided to each agent. The reward is equal to the negative sum of the distances between consecutive nodes at the end of the episode over all agents. All time penalties are also added to the reward. observation and state: the observation and state variable types are defined in: jumanji/environments/routing/multi_cvrp/types.py [1] Zhang et al. (2020). \"Multi-Vehicle Routing Problems with Soft Time Windows: A Multi-Agent Reinforcement Learning Approach\".","title":"MultiCVRP"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.__init__","text":"Instantiates a MultiCVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.multi_cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ UniformRandomGenerator ]. Defaults to UniformRandomGenerator with num_customers=20 and num_vehicles=2 . None reward_fn Optional[jumanji.environments.routing.multi_cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state and whether the environment is done. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer] Viewer used for rendering. Defaults to MultiCVRPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the start node. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next nodes to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed.","title":"step()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.observation_spec","text":"Returns the observation spec. Returns: Type Description observation_spec a Tuple containing the spec for each of the constituent fields of an observation.","title":"observation_spec()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.BoundedArray spec.","title":"action_spec()"},{"location":"api/environments/maze/","text":"Maze ( Environment ) # A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. observation: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise. episode termination (if any): agent reaches the target position. the time_limit is reached. state: State: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. key: random key (uint) of shape (2,). 1 2 3 4 5 6 7 8 from jumanji.environments import Maze env = Maze () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . maze . generator . Generator ] = None , time_limit : Optional [ int ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . maze . types . State ]] = None ) -> None special # Instantiates a Maze environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.maze.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ ToyGenerator , RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 and num_cols=10 . None time_limit Optional[int] the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] Viewer used for rendering. Defaults to MazeEnvViewer with \"human\" render mode. None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . maze . types . Observation ] # Specifications of the observation of the Maze environment. Returns: Type Description Spec for the `Observation` whose fields are agent_position: tree of BoundedArray (int32) of shape (). target_position: tree of BoundedArray (int32) of shape (). walls: BoundedArray (bool) of shape (num_rows, num_cols). step_count: Array (int32) of shape (). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec discrete action space with 4 values. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . maze . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . maze . types . Observation ]] # Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . maze . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . maze . types . Observation ]] # Run one timestep of the environment's dynamics. If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] (int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. required Returns: Type Description state the next state of the environment. timestep: the next timestep to be observed.","title":"Maze"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze","text":"A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. observation: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise. episode termination (if any): agent reaches the target position. the time_limit is reached. state: State: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. key: random key (uint) of shape (2,). 1 2 3 4 5 6 7 8 from jumanji.environments import Maze env = Maze () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Maze"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.__init__","text":"Instantiates a Maze environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.maze.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ ToyGenerator , RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 and num_cols=10 . None time_limit Optional[int] the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] Viewer used for rendering. Defaults to MazeEnvViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.observation_spec","text":"Specifications of the observation of the Maze environment. Returns: Type Description Spec for the `Observation` whose fields are agent_position: tree of BoundedArray (int32) of shape (). target_position: tree of BoundedArray (int32) of shape (). walls: BoundedArray (bool) of shape (num_rows, num_cols). step_count: Array (int32) of shape (). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.action_spec","text":"Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec discrete action space with 4 values.","title":"action_spec()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.reset","text":"Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset.","title":"reset()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.step","text":"Run one timestep of the environment's dynamics. If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] (int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. required Returns: Type Description state the next state of the environment. timestep: the next timestep to be observed.","title":"step()"},{"location":"api/environments/minesweeper/","text":"Minesweeper ( Environment ) # A JAX implementation of the minesweeper game. observation: Observation board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask: jax array (bool) of shape (num_rows, num_cols): indicates which actions are valid (not yet explored squares). num_mines: jax array (int32) of shape () , indicates the number of mines to locate. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the square to explore (row and col). reward: jax array (float32): Configurable function of state and action. By default: 1 for every timestep where a valid action is chosen that doesn't reveal a mine, 0 for revealing a mine or selecting an already revealed square (and terminate the episode). episode termination: Configurable function of state, next_state, and action. By default: Stop the episode if a mine is explored, an invalid action is selected (exploring an already explored square), or the board is solved. state: State board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. flat_mine_locations: jax array (int32) of shape (num_rows * num_cols,): indicates the (flat) locations of all the mines on the board. Will be of length num_mines. key: jax array (int32) of shape (2,) used for seeding the sampling of mine placement on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Minesweeper env = Minesweeper () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . minesweeper . generator . Generator ] = None , reward_function : Optional [ jumanji . environments . logic . minesweeper . reward . RewardFn ] = None , done_function : Optional [ jumanji . environments . logic . minesweeper . done . DoneFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . minesweeper . types . State ]] = None ) special # Instantiate a Minesweeper environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.minesweeper.generator.Generator] Generator to generate problem instances on environment reset. Implemented options are [ SamplingGenerator ]. Defaults to SamplingGenerator . The generator will have attributes: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines generated. Defaults to 10. None reward_function Optional[jumanji.environments.logic.minesweeper.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [ DefaultRewardFn ]. Defaults to DefaultRewardFn , giving a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and 0.0 for an invalid action (selecting an already revealed square). None done_function Optional[jumanji.environments.logic.minesweeper.done.DoneFn] DoneFn whose __call__ method computes the done signal given the current state, action taken, and next state. Implemented options are [ DefaultDoneFn ]. Defaults to DefaultDoneFn , ending the episode on solving the board, revealing a mine, or picking an invalid action. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.minesweeper.types.State]] Viewer to support rendering and animation methods. Implemented options are [ MinesweeperViewer ]. Defaults to MinesweeperViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . minesweeper . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . minesweeper . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for placing mines. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . minesweeper . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . minesweeper . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the row and column of the square to be explored. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . minesweeper . types . Observation ] # Specifications of the observation of the Minesweeper environment. Returns: Type Description Spec for the `Observation` whose fields are board: BoundedArray (int32) of shape (num_rows, num_cols). action_mask: BoundedArray (bool) of shape (num_rows, num_cols). num_mines: BoundedArray (int32) of shape (). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action consists of the height and width of the square to be explored. Returns: Type Description action_spec specs.MultiDiscreteArray object.","title":"Minesweeper"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper","text":"A JAX implementation of the minesweeper game. observation: Observation board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask: jax array (bool) of shape (num_rows, num_cols): indicates which actions are valid (not yet explored squares). num_mines: jax array (int32) of shape () , indicates the number of mines to locate. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the square to explore (row and col). reward: jax array (float32): Configurable function of state and action. By default: 1 for every timestep where a valid action is chosen that doesn't reveal a mine, 0 for revealing a mine or selecting an already revealed square (and terminate the episode). episode termination: Configurable function of state, next_state, and action. By default: Stop the episode if a mine is explored, an invalid action is selected (exploring an already explored square), or the board is solved. state: State board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. flat_mine_locations: jax array (int32) of shape (num_rows * num_cols,): indicates the (flat) locations of all the mines on the board. Will be of length num_mines. key: jax array (int32) of shape (2,) used for seeding the sampling of mine placement on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Minesweeper env = Minesweeper () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Minesweeper"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.__init__","text":"Instantiate a Minesweeper environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.minesweeper.generator.Generator] Generator to generate problem instances on environment reset. Implemented options are [ SamplingGenerator ]. Defaults to SamplingGenerator . The generator will have attributes: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines generated. Defaults to 10. None reward_function Optional[jumanji.environments.logic.minesweeper.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [ DefaultRewardFn ]. Defaults to DefaultRewardFn , giving a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and 0.0 for an invalid action (selecting an already revealed square). None done_function Optional[jumanji.environments.logic.minesweeper.done.DoneFn] DoneFn whose __call__ method computes the done signal given the current state, action taken, and next state. Implemented options are [ DefaultDoneFn ]. Defaults to DefaultDoneFn , ending the episode on solving the board, revealing a mine, or picking an invalid action. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.minesweeper.types.State]] Viewer to support rendering and animation methods. Implemented options are [ MinesweeperViewer ]. Defaults to MinesweeperViewer . None","title":"__init__()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for placing mines. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the row and column of the square to be explored. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.observation_spec","text":"Specifications of the observation of the Minesweeper environment. Returns: Type Description Spec for the `Observation` whose fields are board: BoundedArray (int32) of shape (num_rows, num_cols). action_mask: BoundedArray (bool) of shape (num_rows, num_cols). num_mines: BoundedArray (int32) of shape (). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.action_spec","text":"Returns the action spec. An action consists of the height and width of the square to be explored. Returns: Type Description action_spec specs.MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/mmst/","text":"MMST ( Environment ) # The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes). Note: routing problems are randomly generated and may not be solvable! Requirements: The total number of nodes should be at least 20% more than the number of nodes we want to connect to guarantee we have enough remaining nodes to create a path with all the nodes we want to connect. An exception will be raised if the number of nodes is not greater than (0.8 x num_agents x num_nodes_per_agent). observation: Observation node_types: jax array (int) of shape (num_nodes): the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. positions: jax array (int) of shape (num_agents,): the index of the last visited node. step_count: jax array (int) of shape (): integer to keep track of the number of steps. action_mask: jax array (bool) of shape (num_agent, num_nodes): binary mask (False/True <--> invalid/valid action). reward: float action: jax array (int) of shape (num_agents,): [0,1,..., num_nodes-1] Each agent selects the next node to which it wants to connect. state: State node_type: jax array (int) of shape (num_nodes,). the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. connected_nodes: jax array (int) of shape (num_agents, time_limit). we only count each node visit once. connected_nodes_index: jax array (int) of shape (num_agents, num_nodes). position_index: jax array (int) of shape (num_agents,). node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes). positions: jax array (int) of shape (num_agents,). the index of the last visited node. action_mask: jax array (bool) of shape (num_agent, num_nodes). binary mask (False/True <--> invalid/valid action). finished_agents: jax array (bool) of shape (num_agent,). nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent). step_count: step counter. time_limit: the number of steps allowed before an episode terminates. key: PRNG key for random sample. constants definitions: Nodes INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if its has no edge with the current node or if it is a utility node already selected by another agent. UTILITY_NODE = -1: utility node (belongs to no agent). EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent. Hence it has size equal to the step limit. We use this constant to initialise this array since 0 represents the first node. DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node. Edges EMPTY_EDGE = -1: used for masking edges array. state.node_edges is the graph's adjacency matrix, but we don't represent it using 0s and 1s, we use the node values instead, i.e A_ij = j or A_ij = -1 . Also edges are masked when utility nodes are selected by an agent to make it unaccessible by other agents. Actions encoding INVALID_CHOICE = -1 INVALID_TIE_BREAK = -2 INVALID_ALREADY_TRAVERSED = -3 __init__ ( self , generator : Optional [ jumanji . environments . routing . mmst . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . mmst . reward . RewardFn ] = None , time_limit : int = 70 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . mmst . types . State ]] = None ) special # Create the MMST environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.mmst.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ SplitRandomGenerator ]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit) . None reward_fn Optional[jumanji.environments.routing.mmst.reward.RewardFn] class of type RewardFn , whose __call__ is used as a reward function. Implemented options are [ DenseRewardFn ]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)) . None time_limit int the number of steps allowed before an episode terminates. Defaults to 70. 70 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]] Viewer used for rendering. Defaults to MMSTViewer None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . mmst . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . mmst . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the different start nodes. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . mmst . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . mmst . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next node to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed. action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.MultiDiscreteArray spec. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . mmst . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are node_types: BoundedArray (int32) of shape (num_nodes,). adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. positions: BoundedArray (int32) of shape (num_agents). Current node position of agent. action_mask: BoundedArray (bool) of shape (num_agents, num_nodes,). Represents the valid actions in the current state. render ( self , state : State ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # Render the environment for a given state. Returns: Type Description Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of rgb pixel values in the shape (width, height, rgb).","title":"MMST"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST","text":"The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes). Note: routing problems are randomly generated and may not be solvable! Requirements: The total number of nodes should be at least 20% more than the number of nodes we want to connect to guarantee we have enough remaining nodes to create a path with all the nodes we want to connect. An exception will be raised if the number of nodes is not greater than (0.8 x num_agents x num_nodes_per_agent). observation: Observation node_types: jax array (int) of shape (num_nodes): the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. positions: jax array (int) of shape (num_agents,): the index of the last visited node. step_count: jax array (int) of shape (): integer to keep track of the number of steps. action_mask: jax array (bool) of shape (num_agent, num_nodes): binary mask (False/True <--> invalid/valid action). reward: float action: jax array (int) of shape (num_agents,): [0,1,..., num_nodes-1] Each agent selects the next node to which it wants to connect. state: State node_type: jax array (int) of shape (num_nodes,). the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. connected_nodes: jax array (int) of shape (num_agents, time_limit). we only count each node visit once. connected_nodes_index: jax array (int) of shape (num_agents, num_nodes). position_index: jax array (int) of shape (num_agents,). node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes). positions: jax array (int) of shape (num_agents,). the index of the last visited node. action_mask: jax array (bool) of shape (num_agent, num_nodes). binary mask (False/True <--> invalid/valid action). finished_agents: jax array (bool) of shape (num_agent,). nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent). step_count: step counter. time_limit: the number of steps allowed before an episode terminates. key: PRNG key for random sample. constants definitions: Nodes INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if its has no edge with the current node or if it is a utility node already selected by another agent. UTILITY_NODE = -1: utility node (belongs to no agent). EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent. Hence it has size equal to the step limit. We use this constant to initialise this array since 0 represents the first node. DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node. Edges EMPTY_EDGE = -1: used for masking edges array. state.node_edges is the graph's adjacency matrix, but we don't represent it using 0s and 1s, we use the node values instead, i.e A_ij = j or A_ij = -1 . Also edges are masked when utility nodes are selected by an agent to make it unaccessible by other agents. Actions encoding INVALID_CHOICE = -1 INVALID_TIE_BREAK = -2 INVALID_ALREADY_TRAVERSED = -3","title":"MMST"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.__init__","text":"Create the MMST environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.mmst.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ SplitRandomGenerator ]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit) . None reward_fn Optional[jumanji.environments.routing.mmst.reward.RewardFn] class of type RewardFn , whose __call__ is used as a reward function. Implemented options are [ DenseRewardFn ]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)) . None time_limit int the number of steps allowed before an episode terminates. Defaults to 70. 70 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]] Viewer used for rendering. Defaults to MMSTViewer None","title":"__init__()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the different start nodes. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next node to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed.","title":"step()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are node_types: BoundedArray (int32) of shape (num_nodes,). adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. positions: BoundedArray (int32) of shape (num_agents). Current node position of agent. action_mask: BoundedArray (bool) of shape (num_agents, num_nodes,). Represents the valid actions in the current state.","title":"observation_spec()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.render","text":"Render the environment for a given state. Returns: Type Description Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of rgb pixel values in the shape (width, height, rgb).","title":"render()"},{"location":"api/environments/rubiks_cube/","text":"RubiksCube ( Environment ) # A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. observation: Observation cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the move to perform (face, depth, and direction). reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0. episode termination: if either the cube is solved or a time limit is reached. state: State cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import RubiksCube env = RubiksCube () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . rubiks_cube . generator . Generator ] = None , time_limit : int = 200 , reward_fn : Optional [ jumanji . environments . logic . rubiks_cube . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . rubiks_cube . types . State ]] = None ) special # Instantiate a RubiksCube environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] Generator used to generate problem instances on environment reset. Implemented options are [ ScramblingGenerator ]. Defaults to ScramblingGenerator , with 100 scrambles on reset. The generator will contain an attribute cube_size , corresponding to the number of cubies to an edge, and defaulting to 3. None time_limit int the number of steps allowed before an episode terminates. Defaults to 200. 200 reward_fn Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] RewardFn whose __call__ method computes the reward given the new state. Implemented options are [ SparseRewardFn ]. Defaults to SparseRewardFn , giving a reward of 1.0 if the cube is solved or otherwise 0.0. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] Viewer to support rendering and animation methods. Implemented options are [ RubiksCubeViewer ]. Defaults to RubiksCubeViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . rubiks_cube . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . rubiks_cube . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for scramble. required Returns: Type Description state State corresponding to the new state of the environment. timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . rubiks_cube . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . rubiks_cube . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of shape (3,) indicating the face to move, depth of the move, and the amount to move by. required Returns: Type Description next_state State corresponding to the next state of the environment. next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . rubiks_cube . types . Observation ] # Specifications of the observation of the RubiksCube environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size). step_count: BoundedArray (jnp.int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: Type Description action_spec MultiDiscreteArray object.","title":"RubiksCube"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube","text":"A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. observation: Observation cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the move to perform (face, depth, and direction). reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0. episode termination: if either the cube is solved or a time limit is reached. state: State cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import RubiksCube env = RubiksCube () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"RubiksCube"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.__init__","text":"Instantiate a RubiksCube environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] Generator used to generate problem instances on environment reset. Implemented options are [ ScramblingGenerator ]. Defaults to ScramblingGenerator , with 100 scrambles on reset. The generator will contain an attribute cube_size , corresponding to the number of cubies to an edge, and defaulting to 3. None time_limit int the number of steps allowed before an episode terminates. Defaults to 200. 200 reward_fn Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] RewardFn whose __call__ method computes the reward given the new state. Implemented options are [ SparseRewardFn ]. Defaults to SparseRewardFn , giving a reward of 1.0 if the cube is solved or otherwise 0.0. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] Viewer to support rendering and animation methods. Implemented options are [ RubiksCubeViewer ]. Defaults to RubiksCubeViewer . None","title":"__init__()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for scramble. required Returns: Type Description state State corresponding to the new state of the environment. timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of shape (3,) indicating the face to move, depth of the move, and the amount to move by. required Returns: Type Description next_state State corresponding to the next state of the environment. next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.observation_spec","text":"Specifications of the observation of the RubiksCube environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size). step_count: BoundedArray (jnp.int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.action_spec","text":"Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: Type Description action_spec MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/rware/","text":"RobotWarehouse ( Environment ) # A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. Creates a grid world where multiple agents (robots) are supposed to collect shelves, bring them to a goal and then return them. Below is an example warehouse floor grid: the grid layout is instantiated using three arguments shelf_rows: number of vertical shelf clusters shelf_columns: odd number of horizontal shelf clusters column_height: height of each cluster A cluster is a set of grouped shelves (two cells wide) represented below as 1 XX Shelf cluster -> XX (this cluster is of height 3) XX Grid Layout: 1 2 3 4 shelf columns (here set to 3, i.e. v v v shelf_columns=3, must be an odd number) ---------- > -XX-XX-XX- ^ Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. > -XX-XX-XX- v column_height=3) ---------- -XX----XX- < -XX----XX- <- Shelf Row 2 (here set to 2, i.e. -XX----XX- < shelf_rows=2) ---------- ----GG---- G: is the goal positions where agents are rewarded if they successfully deliver a requested shelf (i.e toggle the load action inside the goal position while carrying a requested shelf). The final grid size will be - height: (column_height + 1) * shelf_rows + 2 - width: (2 + 1) * shelf_columns + 1 The bottom-middle column is removed to allow for agents to queue in front of the goal positions action: jax array (int) of shape (num_agents,) containing the action for each agent. (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) reward: jax array (int) of shape (), global reward shared by all agents, +1 for every successful delivery of a requested shelf to the goal position. episode termination: The number of steps is greater than the limit. Any agent selects an action which causes two agents to collide. state: State grid: an array representing the warehouse floor as a 2D grid with two separate channels one for the agents, and one for the shelves agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] request_queue: the queue of requested shelves (by ID). step_count: an integer representing the current step of the episode. action_mask: an array of shape (num_agents, 5) containing the valid actions for each agent. key: a pseudorandom number generator key. [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms in Cooperative Tasks (2021) 1 2 3 4 5 6 7 8 from jumanji.environments import RobotWarehouse env = RobotWarehouse () key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . robot_warehouse . generator . Generator ] = None , time_limit : int = 500 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . robot_warehouse . types . State ]] = None ) special # Instantiates an RobotWarehouse environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.robot_warehouse.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator with parameters: shelf_rows = 2 , shelf_columns = 3 , column_height = 8 , num_agents = 4 , sensor_range = 1 , request_queue_size = 8 . None time_limit int the maximum step limit allowed within the environment. Defaults to 500. 500 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.robot_warehouse.types.State]] viewer to render the environment. Defaults to RobotWarehouseViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . robot_warehouse . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . robot_warehouse . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . robot_warehouse . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . robot_warehouse . types . Observation ]] # Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. - 0 no op - 1 move forward - 2 turn left - 3 turn right - 4 toggle load required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . robot_warehouse . types . Observation ] # Specification of the observation of the RobotWarehouse environment. Returns: Type Description Spec for the `Observation`, consisting of the fields agents_view: Array (int32) of shape (num_agents, num_obs_features). action_mask: BoundedArray (bool) of shape (num_agent, 5). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,).","title":"Rware"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse","text":"A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. Creates a grid world where multiple agents (robots) are supposed to collect shelves, bring them to a goal and then return them. Below is an example warehouse floor grid: the grid layout is instantiated using three arguments shelf_rows: number of vertical shelf clusters shelf_columns: odd number of horizontal shelf clusters column_height: height of each cluster A cluster is a set of grouped shelves (two cells wide) represented below as 1 XX Shelf cluster -> XX (this cluster is of height 3) XX Grid Layout: 1 2 3 4 shelf columns (here set to 3, i.e. v v v shelf_columns=3, must be an odd number) ---------- > -XX-XX-XX- ^ Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. > -XX-XX-XX- v column_height=3) ---------- -XX----XX- < -XX----XX- <- Shelf Row 2 (here set to 2, i.e. -XX----XX- < shelf_rows=2) ---------- ----GG---- G: is the goal positions where agents are rewarded if they successfully deliver a requested shelf (i.e toggle the load action inside the goal position while carrying a requested shelf). The final grid size will be - height: (column_height + 1) * shelf_rows + 2 - width: (2 + 1) * shelf_columns + 1 The bottom-middle column is removed to allow for agents to queue in front of the goal positions action: jax array (int) of shape (num_agents,) containing the action for each agent. (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) reward: jax array (int) of shape (), global reward shared by all agents, +1 for every successful delivery of a requested shelf to the goal position. episode termination: The number of steps is greater than the limit. Any agent selects an action which causes two agents to collide. state: State grid: an array representing the warehouse floor as a 2D grid with two separate channels one for the agents, and one for the shelves agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] request_queue: the queue of requested shelves (by ID). step_count: an integer representing the current step of the episode. action_mask: an array of shape (num_agents, 5) containing the valid actions for each agent. key: a pseudorandom number generator key. [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms in Cooperative Tasks (2021) 1 2 3 4 5 6 7 8 from jumanji.environments import RobotWarehouse env = RobotWarehouse () key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"RobotWarehouse"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.__init__","text":"Instantiates an RobotWarehouse environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.robot_warehouse.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator with parameters: shelf_rows = 2 , shelf_columns = 3 , column_height = 8 , num_agents = 4 , sensor_range = 1 , request_queue_size = 8 . None time_limit int the maximum step limit allowed within the environment. Defaults to 500. 500 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.robot_warehouse.types.State]] viewer to render the environment. Defaults to RobotWarehouseViewer . None","title":"__init__()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.step","text":"Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. - 0 no op - 1 move forward - 2 turn left - 3 turn right - 4 toggle load required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment.","title":"step()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.observation_spec","text":"Specification of the observation of the RobotWarehouse environment. Returns: Type Description Spec for the `Observation`, consisting of the fields agents_view: Array (int32) of shape (num_agents, num_obs_features). action_mask: BoundedArray (bool) of shape (num_agent, 5). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.action_spec","text":"Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,).","title":"action_spec()"},{"location":"api/environments/snake/","text":"Snake ( Environment ) # A JAX implementation of the 'Snake' game. observation: Observation grid: jax array (float) of shape (num_rows, num_cols, 5) feature maps that include information about the fruit, the snake head, its body and tail. body: 2D map with 1. where a body cell is present, else 0. head: 2D map with 1. where the snake's head is located, else 0. tail: 2D map with 1. where the snake's tail is located, else 0. fruit: 2D map with 1. where the fruit is located, else 0. norm_body_state: 2D map with a float between 0. and 1. for each body cell in the decreasing order from head to tail. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. action: jax array (int32) of shape() [0,1,2,3] -> [Up, Right, Down, Left]. reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0. episode termination: if no action can be performed, i.e. the snake is surrounded. if the time limit is reached. if an invalid action is taken, the snake exits the grid or bumps into itself. state: State body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells. body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail. head_position: Position (int32) of shape () position of the snake's head on the 2D grid. tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail. fruit_position: Position (int32) of shape () position of the fruit on the 2D grid. length: jax array (int32) of shape () current length of the snake. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Snake env = Snake () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , num_rows : int = 12 , num_cols : int = 12 , time_limit : int = 4000 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . snake . types . State ]] = None ) special # Instantiates a Snake environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 12. 12 num_cols int number of columns of the 2D grid. Defaults to 12. 12 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000. 4000 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] Viewer used for rendering. Defaults to SnakeViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . snake . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . snake . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to sample the snake and fruit positions. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . snake . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . snake . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the action to take: - 0: move up. - 1: move to the right. - 2: move down. - 3: move to the left. required Returns: Type Description state, timestep next state of the environment and timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . snake . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (float) of shape (num_rows, num_cols, 5). step_count: DiscreteArray (num_values = time_limit) of shape (). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"Snake"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake","text":"A JAX implementation of the 'Snake' game. observation: Observation grid: jax array (float) of shape (num_rows, num_cols, 5) feature maps that include information about the fruit, the snake head, its body and tail. body: 2D map with 1. where a body cell is present, else 0. head: 2D map with 1. where the snake's head is located, else 0. tail: 2D map with 1. where the snake's tail is located, else 0. fruit: 2D map with 1. where the fruit is located, else 0. norm_body_state: 2D map with a float between 0. and 1. for each body cell in the decreasing order from head to tail. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. action: jax array (int32) of shape() [0,1,2,3] -> [Up, Right, Down, Left]. reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0. episode termination: if no action can be performed, i.e. the snake is surrounded. if the time limit is reached. if an invalid action is taken, the snake exits the grid or bumps into itself. state: State body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells. body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail. head_position: Position (int32) of shape () position of the snake's head on the 2D grid. tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail. fruit_position: Position (int32) of shape () position of the fruit on the 2D grid. length: jax array (int32) of shape () current length of the snake. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Snake env = Snake () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Snake"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.__init__","text":"Instantiates a Snake environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 12. 12 num_cols int number of columns of the 2D grid. Defaults to 12. 12 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000. 4000 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] Viewer used for rendering. Defaults to SnakeViewer . None","title":"__init__()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to sample the snake and fruit positions. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the action to take: - 0: move up. - 1: move to the right. - 2: move down. - 3: move to the left. required Returns: Type Description state, timestep next state of the environment and timestep to be observed.","title":"step()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (float) of shape (num_rows, num_cols, 5). step_count: DiscreteArray (num_values = time_limit) of shape (). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.action_spec","text":"Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/sudoku/","text":"Sudoku ( Environment ) # A JAX implementation of the sudoku game. observation: Observation board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid. action: multi discrete array containing the square to write a digit, and the digits to input. reward: jax array (float32): 1 at the end of the episode if the board is valid 0 otherwise state: State board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits). key: jax array (int32) of shape (2,) used for seeding initial sudoku configuration. 1 2 3 4 5 6 7 8 from jumanji.environments import Sudoku env = Sudoku () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . sudoku . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . logic . sudoku . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . sudoku . types . State ]] = None ) special # reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . sudoku . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . sudoku . types . Observation ]] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . sudoku . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . sudoku . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . sudoku . types . Observation ] # Returns the observation spec containing the board and action_mask arrays. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: BoundedArray (jnp.int8) of shape (9,9). action_mask: BoundedArray (bool) of shape (9,9,9). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: Type Description action_spec MultiDiscreteArray object.","title":"Sudoku"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku","text":"A JAX implementation of the sudoku game. observation: Observation board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid. action: multi discrete array containing the square to write a digit, and the digits to input. reward: jax array (float32): 1 at the end of the episode if the board is valid 0 otherwise state: State board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits). key: jax array (int32) of shape (2,) used for seeding initial sudoku configuration. 1 2 3 4 5 6 7 8 from jumanji.environments import Sudoku env = Sudoku () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Sudoku"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.__init__","text":"","title":"__init__()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.reset","text":"Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment,","title":"reset()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment,","title":"step()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.observation_spec","text":"Returns the observation spec containing the board and action_mask arrays. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: BoundedArray (jnp.int8) of shape (9,9). action_mask: BoundedArray (bool) of shape (9,9,9).","title":"observation_spec()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.action_spec","text":"Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: Type Description action_spec MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/tetris/","text":"Tetris ( Environment ) # RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) representing the current state of the grid. tetromino: jax array (int32) of shape (4, 4) representing the current tetromino sampled from the tetromino list. action_mask: jax array (bool) of shape (4, num_cols). For each tetromino there are 4 rotations, each one corresponds to a line in the action_mask. Mask of the joint action space: True if the action (x_position and rotation degree) is feasible for the current tetromino and grid state. action: multi discrete array of shape (2,) rotation_index: The degree index determines the rotation of the tetromino: 0 corresponds to 0 degrees, 1 corresponds to 90 degrees, 2 corresponds to 180 degrees, and 3 corresponds to 270 degrees. x_position: int between 0 and num_cols - 1 (included). reward: The reward is 0 if no lines was cleared by the action and a convex function of the number of cleared lines otherwise. episode termination: if the tetromino cannot be placed anymore (i.e., it hits the top of the grid). 1 2 3 4 5 6 7 8 from jumanji.environments import Tetris env = Tetris () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , num_rows : int = 10 , num_cols : int = 10 , time_limit : int = 400 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . tetris . types . State ]] = None ) -> None special # Instantiates a Tetris environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 10. 10 num_cols int number of columns of the 2D grid. Defaults to 10. 10 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 400. 400 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.tetris.types.State]] Viewer used for rendering. Defaults to TetrisViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . tetris . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . tetris . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for generating new tetrominoes. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . tetris . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . tetris . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] chex.Array containing the rotation_index and x_position of the tetromino. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . tetris . types . Observation ] # Specifications of the observation of the Tetris environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields grid: BoundedArray (jnp.int32) of shape (num_rows, num_cols). tetromino: BoundedArray (bool) of shape (4, 4). action_mask: BoundedArray (bool) of shape (NUM_ROTATIONS, num_cols). step_count: DiscreteArray (num_values = time_limit) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. Returns: Type Description MultiDiscreteArray The action spec, which is a specs.MultiDiscreteArray object.","title":"Tetris"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris","text":"RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) representing the current state of the grid. tetromino: jax array (int32) of shape (4, 4) representing the current tetromino sampled from the tetromino list. action_mask: jax array (bool) of shape (4, num_cols). For each tetromino there are 4 rotations, each one corresponds to a line in the action_mask. Mask of the joint action space: True if the action (x_position and rotation degree) is feasible for the current tetromino and grid state. action: multi discrete array of shape (2,) rotation_index: The degree index determines the rotation of the tetromino: 0 corresponds to 0 degrees, 1 corresponds to 90 degrees, 2 corresponds to 180 degrees, and 3 corresponds to 270 degrees. x_position: int between 0 and num_cols - 1 (included). reward: The reward is 0 if no lines was cleared by the action and a convex function of the number of cleared lines otherwise. episode termination: if the tetromino cannot be placed anymore (i.e., it hits the top of the grid). 1 2 3 4 5 6 7 8 from jumanji.environments import Tetris env = Tetris () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Tetris"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.__init__","text":"Instantiates a Tetris environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 10. 10 num_cols int number of columns of the 2D grid. Defaults to 10. 10 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 400. 400 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.tetris.types.State]] Viewer used for rendering. Defaults to TetrisViewer . None","title":"__init__()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for generating new tetrominoes. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] chex.Array containing the rotation_index and x_position of the tetromino. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.observation_spec","text":"Specifications of the observation of the Tetris environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields grid: BoundedArray (jnp.int32) of shape (num_rows, num_cols). tetromino: BoundedArray (bool) of shape (4, 4). action_mask: BoundedArray (bool) of shape (NUM_ROTATIONS, num_cols). step_count: DiscreteArray (num_values = time_limit) of shape ().","title":"observation_spec()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.action_spec","text":"Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. Returns: Type Description MultiDiscreteArray The action spec, which is a specs.MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/tsp/","text":"TSP ( Environment ) # Traveling Salesman Problem (TSP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: jax array (int32) of shape () the index corresponding to the last visited city. trajectory: jax array (int32) of shape (num_cities,) array of city indices defining the route (-1 --> not filled yet). action_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> illegal/legal <--> cannot be visited/can be visited). action: jax array (int32) of shape () [0, ..., num_cities - 1] -> city to visit. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. episode termination: if no action can be performed, i.e. all cities have been visited. if an invalid action is taken, i.e. an already visited city is chosen. state: State coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: int32 the identifier (index) of the last visited city. visited_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (num_cities,) the identifiers of the cities that have been visited (-1 means that no city has been visited yet at that time in the sequence). num_visited: int32 number of cities that have been visited. [1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). \"POMO: Policy Optimization with Multiple Optima for Reinforcement Learning\". 1 2 3 4 5 6 7 8 from jumanji.environments import TSP env = TSP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . tsp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . tsp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . tsp . types . State ]] = None ) special # Instantiates a TSP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.tsp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates TSP instances with 20 cities sampled from a uniform distribution. None reward_fn Optional[jumanji.environments.routing.tsp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] Viewer used for rendering. Defaults to TSPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . tsp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . tsp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . tsp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . tsp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the index of the next position to visit. required Returns: Type Description state the next state of the environment. timestep: the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . tsp . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_cities,). position: DiscreteArray (num_values = num_cities) of shape (). trajectory: BoundedArray (int32) of shape (num_cities,). action_mask: BoundedArray (bool) of shape (num_cities,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"TSP"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP","text":"Traveling Salesman Problem (TSP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: jax array (int32) of shape () the index corresponding to the last visited city. trajectory: jax array (int32) of shape (num_cities,) array of city indices defining the route (-1 --> not filled yet). action_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> illegal/legal <--> cannot be visited/can be visited). action: jax array (int32) of shape () [0, ..., num_cities - 1] -> city to visit. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. episode termination: if no action can be performed, i.e. all cities have been visited. if an invalid action is taken, i.e. an already visited city is chosen. state: State coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: int32 the identifier (index) of the last visited city. visited_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (num_cities,) the identifiers of the cities that have been visited (-1 means that no city has been visited yet at that time in the sequence). num_visited: int32 number of cities that have been visited. [1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). \"POMO: Policy Optimization with Multiple Optima for Reinforcement Learning\". 1 2 3 4 5 6 7 8 from jumanji.environments import TSP env = TSP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"TSP"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.__init__","text":"Instantiates a TSP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.tsp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates TSP instances with 20 cities sampled from a uniform distribution. None reward_fn Optional[jumanji.environments.routing.tsp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] Viewer used for rendering. Defaults to TSPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the index of the next position to visit. required Returns: Type Description state the next state of the environment. timestep: the timestep to be observed.","title":"step()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_cities,). position: DiscreteArray (num_values = num_cities) of shape (). trajectory: BoundedArray (int32) of shape (num_cities,). action_mask: BoundedArray (bool) of shape (num_cities,).","title":"observation_spec()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"environments/bin_pack/","text":"BinPack Environment # We provide here an implementation of the 3D bin packing problem . In this problem, the goal of the agent is to efficiently pack a set of boxes (items) of different sizes into a single container with as little empty space as possible. Since there is only 1 bin, this formulation is equivalent to the 3D-knapsack problem. Observation # The observation given to the agent provides information on the available empty space (called EMSs), the items that still need to be packed, and information on what actions are valid at this point. The full observation is as follows: ems : EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,) , coordinates of all EMSs at the current timestep. ems_mask : jax array (bool) of shape (obs_num_ems,) , indicates the EMSs that are valid. items : Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,) , characteristics of all items for this instance. items_mask : jax array (bool) of shape (max_num_items,) , indicates the items that are valid. items_placed : jax array (bool) of shape (max_num_items,) , indicates the items that have been placed so far. action_mask : jax array (bool) of shape (obs_num_ems, max_num_items) , mask of the joint action space: True if the action [ems_id, item_id] is valid. Action # The action space is a MultiDiscreteArray of 2 integer values representing the ID of an EMS (space) and the ID of an item. For instance, [1, 5] will place item 5 in EMS 1. Reward # The reward could be either: Dense : normalized volume (relative to the container volume) of the item packed by taking the chosen action. The computed reward is equivalent to the increase in volume utilization of the container due to packing the chosen item. If the action is invalid, the reward is 0.0 instead. Sparse : computed only at the end of the episode (otherwise, returns 0.0). Returns the volume utilization of the container (between 0.0 and 1.0). If the action is invalid, the action is ignored and the reward is still returned as the current container utilization. Registered Versions \ud83d\udcd6 # BinPack-v2 , 3D bin-packing problem with a solvable random generator that generates up to 20 items maximum, that can handle 40 EMSs maximum that are given in the observation.","title":"BinPack"},{"location":"environments/bin_pack/#binpack-environment","text":"We provide here an implementation of the 3D bin packing problem . In this problem, the goal of the agent is to efficiently pack a set of boxes (items) of different sizes into a single container with as little empty space as possible. Since there is only 1 bin, this formulation is equivalent to the 3D-knapsack problem.","title":"BinPack Environment"},{"location":"environments/bin_pack/#observation","text":"The observation given to the agent provides information on the available empty space (called EMSs), the items that still need to be packed, and information on what actions are valid at this point. The full observation is as follows: ems : EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,) , coordinates of all EMSs at the current timestep. ems_mask : jax array (bool) of shape (obs_num_ems,) , indicates the EMSs that are valid. items : Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,) , characteristics of all items for this instance. items_mask : jax array (bool) of shape (max_num_items,) , indicates the items that are valid. items_placed : jax array (bool) of shape (max_num_items,) , indicates the items that have been placed so far. action_mask : jax array (bool) of shape (obs_num_ems, max_num_items) , mask of the joint action space: True if the action [ems_id, item_id] is valid.","title":"Observation"},{"location":"environments/bin_pack/#action","text":"The action space is a MultiDiscreteArray of 2 integer values representing the ID of an EMS (space) and the ID of an item. For instance, [1, 5] will place item 5 in EMS 1.","title":"Action"},{"location":"environments/bin_pack/#reward","text":"The reward could be either: Dense : normalized volume (relative to the container volume) of the item packed by taking the chosen action. The computed reward is equivalent to the increase in volume utilization of the container due to packing the chosen item. If the action is invalid, the reward is 0.0 instead. Sparse : computed only at the end of the episode (otherwise, returns 0.0). Returns the volume utilization of the container (between 0.0 and 1.0). If the action is invalid, the action is ignored and the reward is still returned as the current container utilization.","title":"Reward"},{"location":"environments/bin_pack/#registered-versions","text":"BinPack-v2 , 3D bin-packing problem with a solvable random generator that generates up to 20 items maximum, that can handle 40 EMSs maximum that are given in the observation.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/cleaner/","text":"Cleaner Environment # We provide here a JAX jit-able implementation of the Multi-Agent Cleaning environment. In this environment, multiple agents must cooperatively clean the floor of a room with complex indoor barriers (black). At the beginning of an episode, the whole floor is dirty (green). Every time an agent (red) visits a dirty tile, it is cleaned (white). The goal is to clean as many tiles as possible in a given time budget. A new maze is randomly generated using a recursive division method for each new episode. Agents always start in the top left corner of the maze. Observation # The observation seen by the agent is a NamedTuple containing the following: grid : jax array (int) of shape (num_rows, num_cols) , array representing the grid, each tile is either dirty (0), clean (1), or a wall (2). agents_locations : jax array (int) of shape (num_agents, 2) , array specifying the x and y coordinates of every agent. action_mask : jax array (bool) of shape (num_agents, 4) , array specifying, for each agent, which action (up, right, down, left) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3] for each agent. Each agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). The episode terminates if any agent meets one of the following conditions: An invalid action is taken, or An action is blocked by a wall. In both cases, the agent's position remains unchanged. Reward # The reward is global and shared among the agents. It is equal to the number of tiles which were cleaned during the time step, minus a penalty (0.5 by default) to encourage agents to clean the maze faster. Registered Versions \ud83d\udcd6 # Cleaner-v0 , a room of size 10x10 with 3 agents.","title":"Cleaner"},{"location":"environments/cleaner/#cleaner-environment","text":"We provide here a JAX jit-able implementation of the Multi-Agent Cleaning environment. In this environment, multiple agents must cooperatively clean the floor of a room with complex indoor barriers (black). At the beginning of an episode, the whole floor is dirty (green). Every time an agent (red) visits a dirty tile, it is cleaned (white). The goal is to clean as many tiles as possible in a given time budget. A new maze is randomly generated using a recursive division method for each new episode. Agents always start in the top left corner of the maze.","title":"Cleaner Environment"},{"location":"environments/cleaner/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: grid : jax array (int) of shape (num_rows, num_cols) , array representing the grid, each tile is either dirty (0), clean (1), or a wall (2). agents_locations : jax array (int) of shape (num_agents, 2) , array specifying the x and y coordinates of every agent. action_mask : jax array (bool) of shape (num_agents, 4) , array specifying, for each agent, which action (up, right, down, left) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode.","title":"Observation"},{"location":"environments/cleaner/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3] for each agent. Each agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). The episode terminates if any agent meets one of the following conditions: An invalid action is taken, or An action is blocked by a wall. In both cases, the agent's position remains unchanged.","title":"Action"},{"location":"environments/cleaner/#reward","text":"The reward is global and shared among the agents. It is equal to the number of tiles which were cleaned during the time step, minus a penalty (0.5 by default) to encourage agents to clean the maze faster.","title":"Reward"},{"location":"environments/cleaner/#registered-versions","text":"Cleaner-v0 , a room of size 10x10 with 3 agents.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/connector/","text":"Connector Environment # The Connector environment contains multiple agents spawned in a grid world with each agent representing a start and end position that need to be connected. The main goal of the environment is to connect each start and end position in as few steps as possible. However, when an agent moves it leaves behind a path, which is impassable by all agents. Thus, agents need to cooperate in order to allow each other to connect to their own targets without overlapping. An episode ends when all agents have connected to their targets or no agents can make any further moves due to being blocked. Observation # At each step observation contains 3 items: a grid, an action mask for each agent and the episode step count. grid : jax array (int32) of shape (grid_size, grid_size) , a 2D matrix that represents pairs of points that need to be connected. Each agent has three types of points: position , target and path which are represented by different numbers on the grid. The position of an agent has to connect to its target , leaving a path behind it as it moves across the grid forming its route. Each agent connects to only 1 target. action_mask : jax array (bool) of shape (num_agents, 5) , indicates which actions each agent can take. step_count : jax array (int32) of shape () , represents how many steps have been taken in the environment since the last reset. Encoding # Each agent has 3 components represented in the observation space: position , target , and path . Each agent in the environment will have an integer representing their components. Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, \u2026 Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, \u2026 Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, \u2026 Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), \u2026 Example: 1 2 3 Agent1[path=1, position=2, target=3] Agent2[path=4, position=5, target=6] Agent3[path=7, position=8, target=9] For example, on a 6x6 grid, a possible observation is shown below. 1 2 3 4 5 6 [[ 2 0 3 0 0 0] [ 1 0 4 4 4 0] [ 1 0 5 9 0 0] [ 1 0 0 0 0 0] [ 0 0 0 8 0 0] [ 0 0 6 7 7 7]] Action # The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, 4] . Each value corresponds to an agent moving in 1 of 4 cardinal directions or taking the no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left]. Reward # The reward is dense : +1.0 per agent that connects at that step and -0.03 per agent that has not connected yet. Rewards are provided in the shape (num_agents,) so that each agent can have a reward. Registered Versions \ud83d\udcd6 # Connector-v2 , grid size of 10 and 10 agents.","title":"Connector"},{"location":"environments/connector/#connector-environment","text":"The Connector environment contains multiple agents spawned in a grid world with each agent representing a start and end position that need to be connected. The main goal of the environment is to connect each start and end position in as few steps as possible. However, when an agent moves it leaves behind a path, which is impassable by all agents. Thus, agents need to cooperate in order to allow each other to connect to their own targets without overlapping. An episode ends when all agents have connected to their targets or no agents can make any further moves due to being blocked.","title":"Connector Environment"},{"location":"environments/connector/#observation","text":"At each step observation contains 3 items: a grid, an action mask for each agent and the episode step count. grid : jax array (int32) of shape (grid_size, grid_size) , a 2D matrix that represents pairs of points that need to be connected. Each agent has three types of points: position , target and path which are represented by different numbers on the grid. The position of an agent has to connect to its target , leaving a path behind it as it moves across the grid forming its route. Each agent connects to only 1 target. action_mask : jax array (bool) of shape (num_agents, 5) , indicates which actions each agent can take. step_count : jax array (int32) of shape () , represents how many steps have been taken in the environment since the last reset.","title":"Observation"},{"location":"environments/connector/#encoding","text":"Each agent has 3 components represented in the observation space: position , target , and path . Each agent in the environment will have an integer representing their components. Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, \u2026 Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, \u2026 Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, \u2026 Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), \u2026 Example: 1 2 3 Agent1[path=1, position=2, target=3] Agent2[path=4, position=5, target=6] Agent3[path=7, position=8, target=9] For example, on a 6x6 grid, a possible observation is shown below. 1 2 3 4 5 6 [[ 2 0 3 0 0 0] [ 1 0 4 4 4 0] [ 1 0 5 9 0 0] [ 1 0 0 0 0 0] [ 0 0 0 8 0 0] [ 0 0 6 7 7 7]]","title":"Encoding"},{"location":"environments/connector/#action","text":"The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, 4] . Each value corresponds to an agent moving in 1 of 4 cardinal directions or taking the no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left].","title":"Action"},{"location":"environments/connector/#reward","text":"The reward is dense : +1.0 per agent that connects at that step and -0.03 per agent that has not connected yet. Rewards are provided in the shape (num_agents,) so that each agent can have a reward.","title":"Reward"},{"location":"environments/connector/#registered-versions","text":"Connector-v2 , grid size of 10 and 10 agents.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/cvrp/","text":"Capacitated Vehicle Routing Problem (CVRP) Environment # We provide here a Jax JIT-able implementation of the capacitated vehicle routing problem (CVRP) which is a specific type of VRP . CVRP is a classic combinatorial optimization problem. Given a set of nodes with specific demands, a depot node, and a vehicle with limited capacity, the goal is to determine the shortest route between the nodes such that each node (excluding depot) is visited exactly once and has its demand covered. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution between 0 and 1, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand (which is a parameter of the CVRP environment). The number of nodes with demand is a parameter of the environment. Observation # The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position of the agent as well as the current capacity. coordinates : jax array (float) of shape (num_nodes + 1, 2) , array of coordinates of each city node and the depot node. demands : jax array (float) of shape (num_nodes + 1,) , array of the demands of each city node and the depot node whose demand is set to 0. unvisited_nodes : jax array (bool) of shape (num_nodes + 1,) , array denoting which nodes remain to be visited. position : jax array (int32) of shape () , identifier (index) of the current visited node (city or depot). trajectory : jax array (int32) of shape (2 * num_nodes,) , identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). capacity : jax array (float) of shape () , current capacity of the vehicle. action_mask : jax array (bool) of shape (num_nodes + 1,) , array denoting which actions are possible (True) and which are not (False). Action # The action space is a DiscreteArray of integer values in the range of [0, num_nodes] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot. Reward # The reward could be either: Dense : the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. Registered Versions \ud83d\udcd6 # CVRP-v1 : CVRP problem with 20 randomly generated nodes, a maximum capacity of 30, a maximum demand for each node of 10 and a dense reward function.","title":"CVRP"},{"location":"environments/cvrp/#capacitated-vehicle-routing-problem-cvrp-environment","text":"We provide here a Jax JIT-able implementation of the capacitated vehicle routing problem (CVRP) which is a specific type of VRP . CVRP is a classic combinatorial optimization problem. Given a set of nodes with specific demands, a depot node, and a vehicle with limited capacity, the goal is to determine the shortest route between the nodes such that each node (excluding depot) is visited exactly once and has its demand covered. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution between 0 and 1, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand (which is a parameter of the CVRP environment). The number of nodes with demand is a parameter of the environment.","title":"Capacitated Vehicle Routing Problem (CVRP) Environment"},{"location":"environments/cvrp/#observation","text":"The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position of the agent as well as the current capacity. coordinates : jax array (float) of shape (num_nodes + 1, 2) , array of coordinates of each city node and the depot node. demands : jax array (float) of shape (num_nodes + 1,) , array of the demands of each city node and the depot node whose demand is set to 0. unvisited_nodes : jax array (bool) of shape (num_nodes + 1,) , array denoting which nodes remain to be visited. position : jax array (int32) of shape () , identifier (index) of the current visited node (city or depot). trajectory : jax array (int32) of shape (2 * num_nodes,) , identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). capacity : jax array (float) of shape () , current capacity of the vehicle. action_mask : jax array (bool) of shape (num_nodes + 1,) , array denoting which actions are possible (True) and which are not (False).","title":"Observation"},{"location":"environments/cvrp/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_nodes] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot.","title":"Action"},{"location":"environments/cvrp/#reward","text":"The reward could be either: Dense : the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again.","title":"Reward"},{"location":"environments/cvrp/#registered-versions","text":"CVRP-v1 : CVRP problem with 20 randomly generated nodes, a maximum capacity of 30, a maximum demand for each node of 10 and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/game_2048/","text":"2048 Environment # We provide here a Jax JIT-able implementation of the game 2048 . 2048 is a popular single-player puzzle game that is played on a 4x4 grid. The game board consists of cells, each containing a power of 2, and the objective is to reach a score of at least 2048 by merging cells together. The player can shift the entire grid in one of the four directions (up, down, right, left) to combine cells of the same value. When two adjacent cells have the same value, they merge into a single cell with a value equal to the sum of the two cells. The game ends when the player is no longer able to make any further moves. The ultimate goal is to achieve the highest-valued tile possible, with the hope of surpassing 2048. With each move, the player must carefully plan and strategize to reach the highest score possible. Observation # The observation in the game 2048 includes information about the board, the action mask, and the step count. board : jax array (int32) of shape (board_size, board_size) , representing the current game state. Each nonzero element in the array corresponds to a game tile and holds an exponent of 2. The actual value of the tile is obtained by raising 2 to the power of said exponent. Here is an example of a random observation of the game board: 1 2 3 4 [[ 2 0 1 4] [ 5 3 0 2] [ 0 2 3 2] [ 1 2 0 0]] This array can be converted into the actual game board: 1 2 3 4 [[ 4 0 2 16] [ 32 8 0 4] [ 0 4 8 4] [ 2 4 0 0]] action_mask : jax array (bool) of shape (4,) , indicating which actions are valid in the current state of the environment. The actions include moving the tiles up, right, down, or left. For example, an action mask [False, True, False, False] means that the only valid action is to move the tiles rightward. Action # The action space is a DiscreteArray of integer values in [0, 1, 2, 3] . Specifically, these four actions correspond to: up (0), right (1), down (2), or left (3). Reward # Taking an action in 2048 only returns a reward when two tiles of equal value are merged into a new tile containing their sum (i.e. twice each of their values). The cumulative reward in an episode is the sum of the values of all newly created tiles. For example, if a player merges two 512-value tiles to create a new 1024-value tile, and then merges two 256-value tiles to create a new 512-value tile, the total reward from these actions is 1536 (i.e., 1024 + 512). Registered Versions \ud83d\udcd6 # Game2048-v1 , the default settings for 2048 with a board of size 4x4.","title":"Game2048"},{"location":"environments/game_2048/#2048-environment","text":"We provide here a Jax JIT-able implementation of the game 2048 . 2048 is a popular single-player puzzle game that is played on a 4x4 grid. The game board consists of cells, each containing a power of 2, and the objective is to reach a score of at least 2048 by merging cells together. The player can shift the entire grid in one of the four directions (up, down, right, left) to combine cells of the same value. When two adjacent cells have the same value, they merge into a single cell with a value equal to the sum of the two cells. The game ends when the player is no longer able to make any further moves. The ultimate goal is to achieve the highest-valued tile possible, with the hope of surpassing 2048. With each move, the player must carefully plan and strategize to reach the highest score possible.","title":"2048 Environment"},{"location":"environments/game_2048/#observation","text":"The observation in the game 2048 includes information about the board, the action mask, and the step count. board : jax array (int32) of shape (board_size, board_size) , representing the current game state. Each nonzero element in the array corresponds to a game tile and holds an exponent of 2. The actual value of the tile is obtained by raising 2 to the power of said exponent. Here is an example of a random observation of the game board: 1 2 3 4 [[ 2 0 1 4] [ 5 3 0 2] [ 0 2 3 2] [ 1 2 0 0]] This array can be converted into the actual game board: 1 2 3 4 [[ 4 0 2 16] [ 32 8 0 4] [ 0 4 8 4] [ 2 4 0 0]] action_mask : jax array (bool) of shape (4,) , indicating which actions are valid in the current state of the environment. The actions include moving the tiles up, right, down, or left. For example, an action mask [False, True, False, False] means that the only valid action is to move the tiles rightward.","title":"Observation"},{"location":"environments/game_2048/#action","text":"The action space is a DiscreteArray of integer values in [0, 1, 2, 3] . Specifically, these four actions correspond to: up (0), right (1), down (2), or left (3).","title":"Action"},{"location":"environments/game_2048/#reward","text":"Taking an action in 2048 only returns a reward when two tiles of equal value are merged into a new tile containing their sum (i.e. twice each of their values). The cumulative reward in an episode is the sum of the values of all newly created tiles. For example, if a player merges two 512-value tiles to create a new 1024-value tile, and then merges two 256-value tiles to create a new 512-value tile, the total reward from these actions is 1536 (i.e., 1024 + 512).","title":"Reward"},{"location":"environments/game_2048/#registered-versions","text":"Game2048-v1 , the default settings for 2048 with a board of size 4x4.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/graph_coloring/","text":"Graph Coloring Environment # We provide here a Jax JIT-able implementation of the Graph Coloring environment. Graph coloring is a combinatorial optimization problem where the objective is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. The GraphColoring environment is an episodic, single-agent setting that allows for the exploration of graph coloring algorithms and reinforcement learning methods. Observation # The observation in the GraphColoring environment includes information about the graph, the colors assigned to the vertices, the action mask, and the current node index. graph : jax array (bool) of shape (num_nodes, num_nodes) , representing the adjacency matrix of the graph. For example, a random observation of the graph adjacency matrix: 1 2 3 4 ```[[False, True, False, True], [ True, False, True, False], [False, True, False, True], [ True, False, True, False]]``` colors : a JAX array (int32) of shape (num_nodes,) , representing the current color assignments for the vertices. Initially, all elements are set to -1, indicating that no colors have been assigned yet. For example, an initial color assignment: [-1, -1, -1, -1] action_mask : a JAX array of boolean values, shaped (num_colors,) , which indicates the valid actions in the current state of the environment. Each position in the array corresponds to a color. True at a position signifies that the corresponding color can be used to color a node, while False indicates the opposite. For example, for 4 number of colors available: [True, False, True, False] current_node_index : an integer representing the current node being colored. For example, an initial current_node_index might be 0. Action # The action space is a DiscreteArray of integer values in [0, 1, ..., num_colors - 1] . Each action corresponds to assigning a color to the current node. Reward # The reward in the GraphColoring environment is given as follows: sparse reward : a reward is provided at the end of the episode and equals the negative of the number of unique colors used to color all vertices in the graph. The agent's goal is to find a valid coloring using as few colors as possible while avoiding conflicts with adjacent nodes. Episode Termination # The goal of the agent is to find a valid coloring using as few colors as possible. An episode in the graph coloring environment can terminate under two conditions: All nodes have been assigned a color: the environment iteratively assigns colors to nodes. When all nodes have a color assigned (i.e., there are no nodes with a color value of -1), the episode ends. This is the natural termination condition and ideally the one we'd like the agent to achieve. Invalid action is taken: an action is considered invalid if it tries to assign a color to a node that is not within the allowed color set for that node at that time. The allowed color set for each node is updated after every action. If an invalid action is attempted, the episode immediately terminates and the agent receives a large negative reward. This encourages the agent to learn valid actions and discourages it from making invalid actions. Registered Versions \ud83d\udcd6 # GraphColoring-v0 : The default settings for the GraphColoring problem with a configurable number of nodes and edge_probability. The default number of nodes is 20, and the default edge probability is 0.8.","title":"GraphColoring"},{"location":"environments/graph_coloring/#graph-coloring-environment","text":"We provide here a Jax JIT-able implementation of the Graph Coloring environment. Graph coloring is a combinatorial optimization problem where the objective is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. The GraphColoring environment is an episodic, single-agent setting that allows for the exploration of graph coloring algorithms and reinforcement learning methods.","title":"Graph Coloring Environment"},{"location":"environments/graph_coloring/#observation","text":"The observation in the GraphColoring environment includes information about the graph, the colors assigned to the vertices, the action mask, and the current node index. graph : jax array (bool) of shape (num_nodes, num_nodes) , representing the adjacency matrix of the graph. For example, a random observation of the graph adjacency matrix: 1 2 3 4 ```[[False, True, False, True], [ True, False, True, False], [False, True, False, True], [ True, False, True, False]]``` colors : a JAX array (int32) of shape (num_nodes,) , representing the current color assignments for the vertices. Initially, all elements are set to -1, indicating that no colors have been assigned yet. For example, an initial color assignment: [-1, -1, -1, -1] action_mask : a JAX array of boolean values, shaped (num_colors,) , which indicates the valid actions in the current state of the environment. Each position in the array corresponds to a color. True at a position signifies that the corresponding color can be used to color a node, while False indicates the opposite. For example, for 4 number of colors available: [True, False, True, False] current_node_index : an integer representing the current node being colored. For example, an initial current_node_index might be 0.","title":"Observation"},{"location":"environments/graph_coloring/#action","text":"The action space is a DiscreteArray of integer values in [0, 1, ..., num_colors - 1] . Each action corresponds to assigning a color to the current node.","title":"Action"},{"location":"environments/graph_coloring/#reward","text":"The reward in the GraphColoring environment is given as follows: sparse reward : a reward is provided at the end of the episode and equals the negative of the number of unique colors used to color all vertices in the graph. The agent's goal is to find a valid coloring using as few colors as possible while avoiding conflicts with adjacent nodes.","title":"Reward"},{"location":"environments/graph_coloring/#episode-termination","text":"The goal of the agent is to find a valid coloring using as few colors as possible. An episode in the graph coloring environment can terminate under two conditions: All nodes have been assigned a color: the environment iteratively assigns colors to nodes. When all nodes have a color assigned (i.e., there are no nodes with a color value of -1), the episode ends. This is the natural termination condition and ideally the one we'd like the agent to achieve. Invalid action is taken: an action is considered invalid if it tries to assign a color to a node that is not within the allowed color set for that node at that time. The allowed color set for each node is updated after every action. If an invalid action is attempted, the episode immediately terminates and the agent receives a large negative reward. This encourages the agent to learn valid actions and discourages it from making invalid actions.","title":"Episode Termination"},{"location":"environments/graph_coloring/#registered-versions","text":"GraphColoring-v0 : The default settings for the GraphColoring problem with a configurable number of nodes and edge_probability. The default number of nodes is 20, and the default edge probability is 0.8.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/job_shop/","text":"JobShop Environment # We provide here a JAX jit-able implementation of the job shop scheduling problem . It is NP-hard and one of the most well-known combinatorial optimisation problems. The problem formulation is: N jobs , each consisting of a sequence of operations , need to be scheduled on M machines. For each job, its operations must be processed in order . This is called the precedence constraints . Only one operation in a job can be processed at any given time. A machine can only work on one operation at a time. Once started, an operation must run to completion. The goal of the agent is to determine the schedule that minimises the time needed to process all the jobs. The length of the schedule is also known as its makespan . Observation # The observation seen by the agent is a NamedTuple containing the following: ops_machine_ids : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the machine each op must be processed on. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_durations : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the processing time of each operation. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_mask : jax array (bool) of shape (num_jobs, max_num_ops) . For each job, indicates which operations remain to be scheduled. False if the op has been scheduled or if the op was added for padding, True otherwise. The first True in each row (i.e. each job) identifies the next operation for that job. machines_job_ids : jax array (int32) of shape (num_machines,) . For each machine, it specifies the job currently being processed. Note that -1 means no-op in which case the remaining time until available is always 0. machines_remaining_times : jax array (int32) of shape (num_machines,) . For each machine, it specifies the number of time steps until available. action_mask : jax array (bool) of (num_machines, num_jobs + 1) . For each machine, it indicates which jobs (or no-op) can legally be scheduled. The last column corresponds to no-op. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, ..., num_jobs] for each machine. Thus, an action consists of the following: for each machine, decide which job (or no-op) to schedule at the current time step. The action is represented as a 1-dimensional array of length num_machines . For example, suppose we have M=5 machines and there are N=10 jobs. A legal action might be 1 action = [ 4 , 7 , 0 , 10 , 10 ] This action represents scheduling Job 4 on Machine 0, Job 7 on Machine 1, Job 0 on Machine 2, No-op on Machine 3, No-op on Machine 4. As such, the action is multidimensional and can be thought of as each machine (each agent) deciding which job (or no-op) to schedule. Importantly, the action space is a product of the marginal action space of each agent (machine). The rationale for having a no-op is the following: A machine might be busy processing an operation, in which case a no-op is the only allowed action for that machine. There might not be any jobs that can be scheduled on a machine. There may be scenarios where waiting to schedule a job via one or more no-op(s) ultimately minimizes the makespan. Reward # The reward setting is dense: a reward of -1 is given each time step if none of the termination criteria are met. An episode will terminate in any of the three scenarios below: Finished schedule : all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. If all machines are simultaneously idle or the agent selects an invalid action, this is reflected in a large penalty in the reward. This would be -num_jobs * max_num_ops * max_op_duration which is a upper bound on the makespan, corresponding to if every job had max_num_ops operations and every operation had a processing time of max_op_duration . Registered Versions \ud83d\udcd6 # JobShop-v0 : job-shop scheduling problem with 20 jobs, 10 machines, a maximum of 8 operations per job, and a max operation duration of 6 timesteps per operation.","title":"JobShop"},{"location":"environments/job_shop/#jobshop-environment","text":"We provide here a JAX jit-able implementation of the job shop scheduling problem . It is NP-hard and one of the most well-known combinatorial optimisation problems. The problem formulation is: N jobs , each consisting of a sequence of operations , need to be scheduled on M machines. For each job, its operations must be processed in order . This is called the precedence constraints . Only one operation in a job can be processed at any given time. A machine can only work on one operation at a time. Once started, an operation must run to completion. The goal of the agent is to determine the schedule that minimises the time needed to process all the jobs. The length of the schedule is also known as its makespan .","title":"JobShop Environment"},{"location":"environments/job_shop/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: ops_machine_ids : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the machine each op must be processed on. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_durations : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the processing time of each operation. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_mask : jax array (bool) of shape (num_jobs, max_num_ops) . For each job, indicates which operations remain to be scheduled. False if the op has been scheduled or if the op was added for padding, True otherwise. The first True in each row (i.e. each job) identifies the next operation for that job. machines_job_ids : jax array (int32) of shape (num_machines,) . For each machine, it specifies the job currently being processed. Note that -1 means no-op in which case the remaining time until available is always 0. machines_remaining_times : jax array (int32) of shape (num_machines,) . For each machine, it specifies the number of time steps until available. action_mask : jax array (bool) of (num_machines, num_jobs + 1) . For each machine, it indicates which jobs (or no-op) can legally be scheduled. The last column corresponds to no-op.","title":"Observation"},{"location":"environments/job_shop/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, ..., num_jobs] for each machine. Thus, an action consists of the following: for each machine, decide which job (or no-op) to schedule at the current time step. The action is represented as a 1-dimensional array of length num_machines . For example, suppose we have M=5 machines and there are N=10 jobs. A legal action might be 1 action = [ 4 , 7 , 0 , 10 , 10 ] This action represents scheduling Job 4 on Machine 0, Job 7 on Machine 1, Job 0 on Machine 2, No-op on Machine 3, No-op on Machine 4. As such, the action is multidimensional and can be thought of as each machine (each agent) deciding which job (or no-op) to schedule. Importantly, the action space is a product of the marginal action space of each agent (machine). The rationale for having a no-op is the following: A machine might be busy processing an operation, in which case a no-op is the only allowed action for that machine. There might not be any jobs that can be scheduled on a machine. There may be scenarios where waiting to schedule a job via one or more no-op(s) ultimately minimizes the makespan.","title":"Action"},{"location":"environments/job_shop/#reward","text":"The reward setting is dense: a reward of -1 is given each time step if none of the termination criteria are met. An episode will terminate in any of the three scenarios below: Finished schedule : all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. If all machines are simultaneously idle or the agent selects an invalid action, this is reflected in a large penalty in the reward. This would be -num_jobs * max_num_ops * max_op_duration which is a upper bound on the makespan, corresponding to if every job had max_num_ops operations and every operation had a processing time of max_op_duration .","title":"Reward"},{"location":"environments/job_shop/#registered-versions","text":"JobShop-v0 : job-shop scheduling problem with 20 jobs, 10 machines, a maximum of 8 operations per job, and a max operation duration of 6 timesteps per operation.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/knapsack/","text":"Knapskack Environment # We provide here a Jax JIT-able implementation of the knapskack problem . The knapsack problem is a famous problem in combinatorial optimization. The goal is to determine, given a set of items, each with a weight and a value, which items to include in a collection so that the total weight is less than or equal to a given limit and the total value is as large as possible. The decision problem form of the knapsack problem is NP-complete, thus there is no known algorithm both correct and fast (polynomial-time) in all cases. When the environment is reset, a new problem instance is generated, by sampling weights and values from a uniform distribution between 0 and 1. The weight limit of the knapsack is a parameter of the environment. A trajectory terminates when no further item can be added to the knapsack or the chosen action is invalid. Observation # The observation given to the agent provides information regarding the weights and the values of all the items, as well as, which items have been packed into the knapsack. weights : jax array (float) of shape (num_items,) , array of weights of the items to be packed into the knapsack. values : jax array (float) of shape (num_items,) , array of values of the items to be packed into the knapsack. packed_items : jax array (bool) of shape (num_items,) , array of binary values denoting which items are already packed into the knapsack. action_mask : jax array (bool) of shape (num_items,) , array of binary values denoting which items can be packed into the knapsack. Action # The action space is a DiscreteArray of integer values in the range of [0, num_items-1] . An action is the index of the next item to pack. Reward # The reward can be either: Dense : the value of the item to pack at the current timestep. Sparse : the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. Registered Versions \ud83d\udcd6 # Knapsack-v1 : Knapsack problem with 50 randomly generated items, a total budget of 12.5 and a dense reward function.","title":"Knapsack"},{"location":"environments/knapsack/#knapskack-environment","text":"We provide here a Jax JIT-able implementation of the knapskack problem . The knapsack problem is a famous problem in combinatorial optimization. The goal is to determine, given a set of items, each with a weight and a value, which items to include in a collection so that the total weight is less than or equal to a given limit and the total value is as large as possible. The decision problem form of the knapsack problem is NP-complete, thus there is no known algorithm both correct and fast (polynomial-time) in all cases. When the environment is reset, a new problem instance is generated, by sampling weights and values from a uniform distribution between 0 and 1. The weight limit of the knapsack is a parameter of the environment. A trajectory terminates when no further item can be added to the knapsack or the chosen action is invalid.","title":"Knapskack Environment"},{"location":"environments/knapsack/#observation","text":"The observation given to the agent provides information regarding the weights and the values of all the items, as well as, which items have been packed into the knapsack. weights : jax array (float) of shape (num_items,) , array of weights of the items to be packed into the knapsack. values : jax array (float) of shape (num_items,) , array of values of the items to be packed into the knapsack. packed_items : jax array (bool) of shape (num_items,) , array of binary values denoting which items are already packed into the knapsack. action_mask : jax array (bool) of shape (num_items,) , array of binary values denoting which items can be packed into the knapsack.","title":"Observation"},{"location":"environments/knapsack/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_items-1] . An action is the index of the next item to pack.","title":"Action"},{"location":"environments/knapsack/#reward","text":"The reward can be either: Dense : the value of the item to pack at the current timestep. Sparse : the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity.","title":"Reward"},{"location":"environments/knapsack/#registered-versions","text":"Knapsack-v1 : Knapsack problem with 50 randomly generated items, a total budget of 12.5 and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/maze/","text":"Maze Environment # We provide here a Jax JIT-able implementation of a 2D maze problem. The maze is a size-configurable 2D matrix where each cell represents either free space (white) or wall (black). The goal is for the agent (green) to reach the single target cell (red). It is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target. The agent may choose to move one space up, right, down, or left: (\"N\", \u201cE\u201d, \"S\", \"W\"). If the way is blocked by a wall, it will remain at the same position. Each maze is randomly generated using a recursive division function. By default, a new maze, initial agent position and target position are generated each time the environment is reset. Observation # As an observation, the agent has access to the current maze configuration in the array named walls . It also has access to its current position agent_position , the target's target_position , the number of steps step_count elapsed in the current episode and the action mask action_mask . agent_position : Position(row, col) (int32) each of shape () , agent position in the maze. target_position : Position(row, col) (int32) each of shape () , target position in the maze. walls : jax array (bool) of shape (num_rows, num_cols) , indicates whether a grid cell is a wall. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. action_mask : jax array (bool) of shape (4,) , binary values denoting whether each action is possible. An example 5x5 observation walls array, is shown below. 1 represents a wall, and 0 represents free space. 1 2 3 4 5 [0, 1, 0, 0, 0], [0, 1, 0, 1, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0] Action # The action space is a DiscreteArray of integer values in the range of [0, 3]. I.e. the agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). If an invalid action is taken, or an action is blocked by a wall, a no-op is performed and the agent's position remains unchanged. Reward # Maze is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target position. An episode ends when the agent reaches the target position, or after a set number of steps (by default, this is twice the number of cells in the maze, i.e. step_limit=2*num_rows*num_cols ). Registered Versions \ud83d\udcd6 # Maze-v0 , maze with 10 rows and 10 cols.","title":"Maze"},{"location":"environments/maze/#maze-environment","text":"We provide here a Jax JIT-able implementation of a 2D maze problem. The maze is a size-configurable 2D matrix where each cell represents either free space (white) or wall (black). The goal is for the agent (green) to reach the single target cell (red). It is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target. The agent may choose to move one space up, right, down, or left: (\"N\", \u201cE\u201d, \"S\", \"W\"). If the way is blocked by a wall, it will remain at the same position. Each maze is randomly generated using a recursive division function. By default, a new maze, initial agent position and target position are generated each time the environment is reset.","title":"Maze Environment"},{"location":"environments/maze/#observation","text":"As an observation, the agent has access to the current maze configuration in the array named walls . It also has access to its current position agent_position , the target's target_position , the number of steps step_count elapsed in the current episode and the action mask action_mask . agent_position : Position(row, col) (int32) each of shape () , agent position in the maze. target_position : Position(row, col) (int32) each of shape () , target position in the maze. walls : jax array (bool) of shape (num_rows, num_cols) , indicates whether a grid cell is a wall. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. action_mask : jax array (bool) of shape (4,) , binary values denoting whether each action is possible. An example 5x5 observation walls array, is shown below. 1 represents a wall, and 0 represents free space. 1 2 3 4 5 [0, 1, 0, 0, 0], [0, 1, 0, 1, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0]","title":"Observation"},{"location":"environments/maze/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, 3]. I.e. the agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). If an invalid action is taken, or an action is blocked by a wall, a no-op is performed and the agent's position remains unchanged.","title":"Action"},{"location":"environments/maze/#reward","text":"Maze is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target position. An episode ends when the agent reaches the target position, or after a set number of steps (by default, this is twice the number of cells in the maze, i.e. step_limit=2*num_rows*num_cols ).","title":"Reward"},{"location":"environments/maze/#registered-versions","text":"Maze-v0 , maze with 10 rows and 10 cols.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/minesweeper/","text":"Minesweeper Environment # We provide here a Jax JIT-able implementation of the Minesweeper game. Observation # The observation given to the agent consists of: board : jax array (int32) of shape (num_rows, num_cols) : each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask : jax array (bool) of shape (num_rows, num_cols) : indicates which actions are valid (not yet explored squares). This can also be determined from the board which will have an entry of -1 in all of these positions. num_mines : jax array (int32) of shape () , indicates the number of mines to locate. step_count : jax array (int32) of shape () : specifies how many timesteps have elapsed since environment reset. Action # The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore, e.g. [3, 6] for the cell located on the third row and sixth column. If either a mined square or an already explored square is selected, the episode terminates (the latter are termed invalid actions ). Also, exploring a square will reveal only the contents of that square. This differs slightly from the usual implementation of the game, which automatically and recursively reveals neighbouring squares if there are no adjacent mines. Reward # The reward is configurable, but default to +1 for exploring a new square that does not contain a mine, and 0 otherwise (which also terminates the episode). The episode also terminates if the board is solved. Registered Versions \ud83d\udcd6 # Minesweeper-v0 , the classic game on a 10x10 grid with 10 mines to locate.","title":"Minesweeper"},{"location":"environments/minesweeper/#minesweeper-environment","text":"We provide here a Jax JIT-able implementation of the Minesweeper game.","title":"Minesweeper Environment"},{"location":"environments/minesweeper/#observation","text":"The observation given to the agent consists of: board : jax array (int32) of shape (num_rows, num_cols) : each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask : jax array (bool) of shape (num_rows, num_cols) : indicates which actions are valid (not yet explored squares). This can also be determined from the board which will have an entry of -1 in all of these positions. num_mines : jax array (int32) of shape () , indicates the number of mines to locate. step_count : jax array (int32) of shape () : specifies how many timesteps have elapsed since environment reset.","title":"Observation"},{"location":"environments/minesweeper/#action","text":"The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore, e.g. [3, 6] for the cell located on the third row and sixth column. If either a mined square or an already explored square is selected, the episode terminates (the latter are termed invalid actions ). Also, exploring a square will reveal only the contents of that square. This differs slightly from the usual implementation of the game, which automatically and recursively reveals neighbouring squares if there are no adjacent mines.","title":"Action"},{"location":"environments/minesweeper/#reward","text":"The reward is configurable, but default to +1 for exploring a new square that does not contain a mine, and 0 otherwise (which also terminates the episode). The episode also terminates if the board is solved.","title":"Reward"},{"location":"environments/minesweeper/#registered-versions","text":"Minesweeper-v0 , the classic game on a 10x10 grid with 10 mines to locate.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/mmst/","text":"MMST Environment # The multi minimum spanning tree (mmst) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes) in the shortest time possible. An episode ends when all group of nodes are connected or the maximum number of steps is reached. Note: This environment can be treated as a multi agent problem with each agent atempting to connect one group of node. In this implementation, we treat the problem as single agent that outputs multiple actions per nodes. Observation # At each step observation contains 4 items: a node_types, an adjacency matrix for the graph, an action mask for each group of nodes (agent) and current node positon of each agent. node_types : Array representing the types of nodes in the problem. For example, if we have 12 nodes, their indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11. Let's consider we have 2 agents. Agent 0 wants to connect nodes (0, 1, 9), and agent 1 wants to connect nodes (3, 5, 8). The remaining nodes are considered utility nodes. Therefore, in the state view, the node_types are represented as [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1]. When generating the problem, each agent starts from one of its nodes. So, if agent 0 starts on node 1 and agent 1 on node 3, the connected_nodes array will have values [1, -1, ...] and [3, -1, ...] respectively. The agent's observation is represented using the following rules: - Each agent should see its connected nodes on the path as 0. - Nodes that the agent still needs to connect are represented as 1. - The next agent's nodes are represented by 2 and 3, the next by 4 and 5, and so on. - Utility unconnected nodes are represented by -1. For the 12 node example mentioned above, the expected observation view node_types will have the following values: node_types = jnp.array( [ [1, 0, -1, 2, -1, 3, 1, -1, 3, 1, -1, -1], [3, 2, -1, 0, -1, 1, 3, -1, 1, 3, -1, -1], ], dtype=jnp.int32, ) Note: to make the environment single agent, we use the first agent's observation. adj_matrix : Adjacency matrix representing the connections between nodes. positions : Current node positions of the agents. In our current problem, this will be represented as jnp.array([1, 3]). step_count : integer to keep track of the number of steps. action_mask : Binary mask indicating the validity of each action. Given the current node on which the agent is located, this mask determines if there is a valid edge to every other node. Action # The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, num_nodes-1] . During every step, an agent picks the next node it wants to move to. An action is invalid if the agent picks a node it has no edge to or the node is a utility node already been used by another agent. Reward # At every step, an agent receives a reward of 10.0 if it gets a valid connection, a reward of -1.0 if it does not connect and an extra penalty of -1.0 if it chooses an invalid action. The total step reward is the sum of rewards per agent. Registered Versions \ud83d\udcd6 # MMST-v0 , 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent and step limit of 70.","title":"MMST"},{"location":"environments/mmst/#mmst-environment","text":"The multi minimum spanning tree (mmst) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes) in the shortest time possible. An episode ends when all group of nodes are connected or the maximum number of steps is reached. Note: This environment can be treated as a multi agent problem with each agent atempting to connect one group of node. In this implementation, we treat the problem as single agent that outputs multiple actions per nodes.","title":"MMST Environment"},{"location":"environments/mmst/#observation","text":"At each step observation contains 4 items: a node_types, an adjacency matrix for the graph, an action mask for each group of nodes (agent) and current node positon of each agent. node_types : Array representing the types of nodes in the problem. For example, if we have 12 nodes, their indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11. Let's consider we have 2 agents. Agent 0 wants to connect nodes (0, 1, 9), and agent 1 wants to connect nodes (3, 5, 8). The remaining nodes are considered utility nodes. Therefore, in the state view, the node_types are represented as [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1]. When generating the problem, each agent starts from one of its nodes. So, if agent 0 starts on node 1 and agent 1 on node 3, the connected_nodes array will have values [1, -1, ...] and [3, -1, ...] respectively. The agent's observation is represented using the following rules: - Each agent should see its connected nodes on the path as 0. - Nodes that the agent still needs to connect are represented as 1. - The next agent's nodes are represented by 2 and 3, the next by 4 and 5, and so on. - Utility unconnected nodes are represented by -1. For the 12 node example mentioned above, the expected observation view node_types will have the following values: node_types = jnp.array( [ [1, 0, -1, 2, -1, 3, 1, -1, 3, 1, -1, -1], [3, 2, -1, 0, -1, 1, 3, -1, 1, 3, -1, -1], ], dtype=jnp.int32, ) Note: to make the environment single agent, we use the first agent's observation. adj_matrix : Adjacency matrix representing the connections between nodes. positions : Current node positions of the agents. In our current problem, this will be represented as jnp.array([1, 3]). step_count : integer to keep track of the number of steps. action_mask : Binary mask indicating the validity of each action. Given the current node on which the agent is located, this mask determines if there is a valid edge to every other node.","title":"Observation"},{"location":"environments/mmst/#action","text":"The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, num_nodes-1] . During every step, an agent picks the next node it wants to move to. An action is invalid if the agent picks a node it has no edge to or the node is a utility node already been used by another agent.","title":"Action"},{"location":"environments/mmst/#reward","text":"At every step, an agent receives a reward of 10.0 if it gets a valid connection, a reward of -1.0 if it does not connect and an extra penalty of -1.0 if it chooses an invalid action. The total step reward is the sum of rewards per agent.","title":"Reward"},{"location":"environments/mmst/#registered-versions","text":"MMST-v0 , 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent and step limit of 70.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/multi_cvrp/","text":"Multi Agent Capacitated Vehicle Routing Problem - MultiCVRP Environment # We provide here a Jax JIT-able implementation of the multi-agent capacitated vehicle routing problem (MultiCVRP) which is specified in MVRPSTW . This environment introduces the problem of routing multiple agents in a coordinated manner, specifically in the context of collecting items from various locations. Each agent controls one vehicle. The problem, called the multi-agent capacitated vehicle routing problem (MultiCVRP), entails directing a group of agents to different locations on a map. They need to collectively go to each node and return items to the depot location. To make the problem a bit more realistic we consider the multi-vehicle routing problem with soft time windows (MVRPSTW). In this formulation, each location on the map also has a soft time window in which the items must be collected. If the items are collected outside this window a penalty is provided to the agents. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution inside the map boundries, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand. The number of nodes with demand is a parameter of the environment. Observation # Each agent receives information on the coordinates, demands, time windows and penalty coefficients of all the customer nodes. Futhermore the agents receive positions, local times and vehicle capacity information on all vehicles. Lastly an action mask is also provided to each agent. node_coordinates : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the coordinates of each customer node and the depot node. node_demands : jax array (int16) of shape (num_vehicles, num_customers + 1,) , shows an array of the demands of each city node (and depot node where the demand is set to 0). node_time_windows : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the early and late time cutoffs for each customer. node_penalty_coefs : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows the early and late penalty coefficients for arriving early or late at a customer's location. other_vehicles_position : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the positions of all other vehicles. other_vehicles_local_times : jax array (float32) of shape (num_vehicles, num_vehicles - 1) , shows the local times of all other vehicles. other_vehicles_capacities : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the capacities of all other vehicles. vehicle_position : jax array (int16) of shape (num_vehicles) , shows the positions of the vehicles controlled by the agents. vehicle_local_time : jax array (float32) of shape (num_vehicles) , shows the local times of the vehicles controlled by the agents. vehicle_capacity : jax array (int16) of shape (num_vehicles) , shows the capacity of the vehicles controlled by the agents. action_mask : jax array (bool) of shape (num_vehicles, num_customers + 1,) , denoting which actions are possible (True) and which are not (False). Action # Each agent's action space is a BoundedArray of integer values in the range of [0, num_customers] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot. Reward # Dense : The reward is equal to the sum of negative distances of the current location and next location of all the vehicles. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Sparse : The reward is 0 at every step but the last, where the reward is the negative of the length of the path chosen by all the agents combined. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Registered Versions \ud83d\udcd6 # MultiCVRP-v0 : MultiCVRP problem with 20 customers (randomly generated), maximum capacity of 20, and maximum demand of 10 with two vehicles.","title":"MultiCVRP"},{"location":"environments/multi_cvrp/#multi-agent-capacitated-vehicle-routing-problem-multicvrp-environment","text":"We provide here a Jax JIT-able implementation of the multi-agent capacitated vehicle routing problem (MultiCVRP) which is specified in MVRPSTW . This environment introduces the problem of routing multiple agents in a coordinated manner, specifically in the context of collecting items from various locations. Each agent controls one vehicle. The problem, called the multi-agent capacitated vehicle routing problem (MultiCVRP), entails directing a group of agents to different locations on a map. They need to collectively go to each node and return items to the depot location. To make the problem a bit more realistic we consider the multi-vehicle routing problem with soft time windows (MVRPSTW). In this formulation, each location on the map also has a soft time window in which the items must be collected. If the items are collected outside this window a penalty is provided to the agents. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution inside the map boundries, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand. The number of nodes with demand is a parameter of the environment.","title":"Multi Agent Capacitated Vehicle Routing Problem - MultiCVRP Environment"},{"location":"environments/multi_cvrp/#observation","text":"Each agent receives information on the coordinates, demands, time windows and penalty coefficients of all the customer nodes. Futhermore the agents receive positions, local times and vehicle capacity information on all vehicles. Lastly an action mask is also provided to each agent. node_coordinates : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the coordinates of each customer node and the depot node. node_demands : jax array (int16) of shape (num_vehicles, num_customers + 1,) , shows an array of the demands of each city node (and depot node where the demand is set to 0). node_time_windows : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the early and late time cutoffs for each customer. node_penalty_coefs : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows the early and late penalty coefficients for arriving early or late at a customer's location. other_vehicles_position : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the positions of all other vehicles. other_vehicles_local_times : jax array (float32) of shape (num_vehicles, num_vehicles - 1) , shows the local times of all other vehicles. other_vehicles_capacities : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the capacities of all other vehicles. vehicle_position : jax array (int16) of shape (num_vehicles) , shows the positions of the vehicles controlled by the agents. vehicle_local_time : jax array (float32) of shape (num_vehicles) , shows the local times of the vehicles controlled by the agents. vehicle_capacity : jax array (int16) of shape (num_vehicles) , shows the capacity of the vehicles controlled by the agents. action_mask : jax array (bool) of shape (num_vehicles, num_customers + 1,) , denoting which actions are possible (True) and which are not (False).","title":"Observation"},{"location":"environments/multi_cvrp/#action","text":"Each agent's action space is a BoundedArray of integer values in the range of [0, num_customers] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot.","title":"Action"},{"location":"environments/multi_cvrp/#reward","text":"Dense : The reward is equal to the sum of negative distances of the current location and next location of all the vehicles. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Sparse : The reward is 0 at every step but the last, where the reward is the negative of the length of the path chosen by all the agents combined. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred.","title":"Reward"},{"location":"environments/multi_cvrp/#registered-versions","text":"MultiCVRP-v0 : MultiCVRP problem with 20 customers (randomly generated), maximum capacity of 20, and maximum demand of 10 with two vehicles.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/robot_warehouse/","text":"RobotWarehouse Environment # We provide a JAX jit-able implementation of the Robotic Warehouse environment. The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. Real-world applications inspire the simulator, in which robots pick up shelves and deliver them to a workstation. Humans access the content of a shelf, and then robots can return them to empty shelf locations. The goal is to successfully deliver as many requested shelves in a given time budget. Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse. Observation # The observation seen by the agent is a NamedTuple containing the following: agents_view : jax array (int32) of shape (num_agents, num_obs_features) , array representing the agent's view of other agents and shelves. action_mask : jax array (bool) of shape (num_agents, 5) , array specifying, for each agent, which action (noop, forward, left, right, toggle_load) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3, 4] for each agent. Each agent can take one of five actions: noop ( 0 ), forward ( 1 ), turn left ( 2 ), turn right ( 3 ), or toggle_load ( 4 ). The episode terminates under the following conditions: An invalid action is taken, or An agent collides with another agent. Reward # The reward is global and shared among the agents. It is equal to the number of shelves which were delivered successfully during the time step (i.e., +1 for each shelf). Registered Versions \ud83d\udcd6 # RobotWarehouse-v0 , a warehouse with 4 agents each with a sensor range of 1, a warehouse floor with 2 shelf rows, 3 shelf columns, a column height of 8, and a shelf request queue of 8.","title":"RobotWarehouse"},{"location":"environments/robot_warehouse/#robotwarehouse-environment","text":"We provide a JAX jit-able implementation of the Robotic Warehouse environment. The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. Real-world applications inspire the simulator, in which robots pick up shelves and deliver them to a workstation. Humans access the content of a shelf, and then robots can return them to empty shelf locations. The goal is to successfully deliver as many requested shelves in a given time budget. Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse.","title":"RobotWarehouse Environment"},{"location":"environments/robot_warehouse/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: agents_view : jax array (int32) of shape (num_agents, num_obs_features) , array representing the agent's view of other agents and shelves. action_mask : jax array (bool) of shape (num_agents, 5) , array specifying, for each agent, which action (noop, forward, left, right, toggle_load) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode.","title":"Observation"},{"location":"environments/robot_warehouse/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3, 4] for each agent. Each agent can take one of five actions: noop ( 0 ), forward ( 1 ), turn left ( 2 ), turn right ( 3 ), or toggle_load ( 4 ). The episode terminates under the following conditions: An invalid action is taken, or An agent collides with another agent.","title":"Action"},{"location":"environments/robot_warehouse/#reward","text":"The reward is global and shared among the agents. It is equal to the number of shelves which were delivered successfully during the time step (i.e., +1 for each shelf).","title":"Reward"},{"location":"environments/robot_warehouse/#registered-versions","text":"RobotWarehouse-v0 , a warehouse with 4 agents each with a sensor range of 1, a warehouse floor with 2 shelf rows, 3 shelf columns, a column height of 8, and a shelf request queue of 8.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/rubiks_cube/","text":"Rubik's Cube Environment # We provide here a Jax JIT-able implementation of the Rubik's cube . The environment contains an implementation of the classic 3x3x3 cube by default, and configurably other sizes. The goal of the agent is to match all stickers on each face to a single colour. On resetting the environment the cube will be randomly scrambled with a configurable number of turns (by default 100). Observation # The observation given to the agent gives a view of the current state of the cube, cube : jax array (int8) of shape (6, cube_size, cube_size) whose values are in [0, 1, 2, 3, 4, 5] (corresponding to the different sticker colors). The indices of the array specify the sticker position - first the face (in the order up , front , right , back , left , down ) and then the row and column. Note that the orientation of each face is as follows: UP: LEFT face on the left and BACK face pointing up FRONT: LEFT face on the left and UP face pointing up RIGHT: FRONT face on the left and UP face pointing up BACK: RIGHT face on the left and UP face pointing up LEFT: BACK face on the left and UP face pointing up DOWN: LEFT face on the left and FRONT face pointing up step_count : jax array (int32) of shape () , representing the number of steps in the episode thus far. Action # The action space is a MultiDiscreteArray , specifically a tuple of an index between 0 and 5 (since there are 6 faces), an index between 0 and cube_size//2 (the number of possible depths), and an index between 0 and 2 (3 possible directions). An action thus consists of three pieces of information: Face to turn, Depth of the turn (possible depths are between 0 representing the outer layer and cube_size//2 representing the layer closest to the middle), Direction of turn (possible directions are clockwise, anti-clockwise, or a half turn). Reward # The reward function is configurable, but by default is the fully sparse reward giving +1 for solving the cube and otherwise 0 . The episode terminates if either the cube is solved or a configurable horizon (by default 200 ) is reached. Registered Versions \ud83d\udcd6 # RubiksCube-v0 , the standard Rubik's Cube puzzle with faces of size 3x3. RubiksCube-partly-scrambled-v0 , an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only 7 scrambles at reset time, making it technically maximum 7 actions away from the solution.","title":"RubiksCube"},{"location":"environments/rubiks_cube/#rubiks-cube-environment","text":"We provide here a Jax JIT-able implementation of the Rubik's cube . The environment contains an implementation of the classic 3x3x3 cube by default, and configurably other sizes. The goal of the agent is to match all stickers on each face to a single colour. On resetting the environment the cube will be randomly scrambled with a configurable number of turns (by default 100).","title":"Rubik's Cube Environment"},{"location":"environments/rubiks_cube/#observation","text":"The observation given to the agent gives a view of the current state of the cube, cube : jax array (int8) of shape (6, cube_size, cube_size) whose values are in [0, 1, 2, 3, 4, 5] (corresponding to the different sticker colors). The indices of the array specify the sticker position - first the face (in the order up , front , right , back , left , down ) and then the row and column. Note that the orientation of each face is as follows: UP: LEFT face on the left and BACK face pointing up FRONT: LEFT face on the left and UP face pointing up RIGHT: FRONT face on the left and UP face pointing up BACK: RIGHT face on the left and UP face pointing up LEFT: BACK face on the left and UP face pointing up DOWN: LEFT face on the left and FRONT face pointing up step_count : jax array (int32) of shape () , representing the number of steps in the episode thus far.","title":"Observation"},{"location":"environments/rubiks_cube/#action","text":"The action space is a MultiDiscreteArray , specifically a tuple of an index between 0 and 5 (since there are 6 faces), an index between 0 and cube_size//2 (the number of possible depths), and an index between 0 and 2 (3 possible directions). An action thus consists of three pieces of information: Face to turn, Depth of the turn (possible depths are between 0 representing the outer layer and cube_size//2 representing the layer closest to the middle), Direction of turn (possible directions are clockwise, anti-clockwise, or a half turn).","title":"Action"},{"location":"environments/rubiks_cube/#reward","text":"The reward function is configurable, but by default is the fully sparse reward giving +1 for solving the cube and otherwise 0 . The episode terminates if either the cube is solved or a configurable horizon (by default 200 ) is reached.","title":"Reward"},{"location":"environments/rubiks_cube/#registered-versions","text":"RubiksCube-v0 , the standard Rubik's Cube puzzle with faces of size 3x3. RubiksCube-partly-scrambled-v0 , an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only 7 scrambles at reset time, making it technically maximum 7 actions away from the solution.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/snake/","text":"Snake Environment \ud83d\udc0d # We provide here an implementation of the Snake environment from (Bonnet et al., 2021) . The goal of the agent is to navigate in a grid world (by default of size 12x12) to collect as many fruits as possible without colliding with its own body (i.e. looping on itself). Observation # grid : jax array (float) of shape (num_rows, num_cols, 5) , feature maps (image) that include information about the fruit, the snake head, its body and tail. step_count : jax array (int32) of shape () , current number of steps in the episode. action_mask : jax array (bool) of shape (4,) , array specifying which directions the snake can move in from its current position. Action # The action space is a DiscreteArray of integer values: [0,1,2,3] -> [Up, Right, Down, Left] . Reward # The reward is +1 upon collection of a fruit and 0 otherwise. Registered Versions \ud83d\udcd6 # Snake-v1 : Snake game on a board of size 12x12 with a time limit of 4000 .","title":"Snake"},{"location":"environments/snake/#snake-environment","text":"We provide here an implementation of the Snake environment from (Bonnet et al., 2021) . The goal of the agent is to navigate in a grid world (by default of size 12x12) to collect as many fruits as possible without colliding with its own body (i.e. looping on itself).","title":"Snake Environment \ud83d\udc0d"},{"location":"environments/snake/#observation","text":"grid : jax array (float) of shape (num_rows, num_cols, 5) , feature maps (image) that include information about the fruit, the snake head, its body and tail. step_count : jax array (int32) of shape () , current number of steps in the episode. action_mask : jax array (bool) of shape (4,) , array specifying which directions the snake can move in from its current position.","title":"Observation"},{"location":"environments/snake/#action","text":"The action space is a DiscreteArray of integer values: [0,1,2,3] -> [Up, Right, Down, Left] .","title":"Action"},{"location":"environments/snake/#reward","text":"The reward is +1 upon collection of a fruit and 0 otherwise.","title":"Reward"},{"location":"environments/snake/#registered-versions","text":"Snake-v1 : Snake game on a board of size 12x12 with a time limit of 4000 .","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/sudoku/","text":"Sudoku Environment # We provide here a Jax JIT-able implementation of the Sudoku puzzle game. Observation # The observation given to the agent consists of: board : jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask : jax array (bool) of shape (9,9,9): indicates which actions are valid. Action # The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore and the digits to write in the cell, e.g. [3, 6, 8] for writing the digit 9 in the cell located on the fourth row and seventh column. Reward # The reward is 1 at the end of the episode if the board is correctly solved, and 0 in every other case. Termination # An episode terminates when there are no more legal actions available, this could happen if the board is solved or if the agent finds itself in a dead-end. Registered Versions \ud83d\udcd6 # Sudoku-v0 , the classic game on a 9x9 grid, 10000 random puzzles with mixed difficulty are included by default. Sudoku-very-easy-v0 , the classic game on a 9x9 grid, only 1000 very-easy random puzzles (>46 clues) included by default. Using custom puzzle instances # If one wants to include its own database of puzzles, the DatabaseGenerator can be initialized with any collection of puzzles using the argument custom_boards . Some references for databases of puzzle of various difficulties: - https://www.kaggle.com/datasets/rohanrao/sudoku - https://www.kaggle.com/datasets/informoney/4-million-sudoku-puzzles-easytohard Difficulty level as a function of number of clues # Adapted from An Algorithm for Generating only Desired Permutations for Solving Sudoku Puzzle .","title":"Sudoku"},{"location":"environments/sudoku/#sudoku-environment","text":"We provide here a Jax JIT-able implementation of the Sudoku puzzle game.","title":"Sudoku Environment"},{"location":"environments/sudoku/#observation","text":"The observation given to the agent consists of: board : jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask : jax array (bool) of shape (9,9,9): indicates which actions are valid.","title":"Observation"},{"location":"environments/sudoku/#action","text":"The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore and the digits to write in the cell, e.g. [3, 6, 8] for writing the digit 9 in the cell located on the fourth row and seventh column.","title":"Action"},{"location":"environments/sudoku/#reward","text":"The reward is 1 at the end of the episode if the board is correctly solved, and 0 in every other case.","title":"Reward"},{"location":"environments/sudoku/#termination","text":"An episode terminates when there are no more legal actions available, this could happen if the board is solved or if the agent finds itself in a dead-end.","title":"Termination"},{"location":"environments/sudoku/#registered-versions","text":"Sudoku-v0 , the classic game on a 9x9 grid, 10000 random puzzles with mixed difficulty are included by default. Sudoku-very-easy-v0 , the classic game on a 9x9 grid, only 1000 very-easy random puzzles (>46 clues) included by default.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/sudoku/#using-custom-puzzle-instances","text":"If one wants to include its own database of puzzles, the DatabaseGenerator can be initialized with any collection of puzzles using the argument custom_boards . Some references for databases of puzzle of various difficulties: - https://www.kaggle.com/datasets/rohanrao/sudoku - https://www.kaggle.com/datasets/informoney/4-million-sudoku-puzzles-easytohard","title":"Using custom puzzle instances"},{"location":"environments/sudoku/#difficulty-level-as-a-function-of-number-of-clues","text":"Adapted from An Algorithm for Generating only Desired Permutations for Solving Sudoku Puzzle .","title":"Difficulty level as a function of number of clues"},{"location":"environments/tetris/","text":"Tetris Environment # We provide here a Jax JIT-able implementation of the game Tetris. Tetris is a popular single-player game that is played on a 2D grid by fitting falling blocks of various Tetrominoes together to create horizontal lines without any gaps. As each line is completed, it disappears, and the player earns points. If the stack of blocks reaches the top of the game grid, the game ends. The objective of Tetris is to score as many points as possible before the game ends, by clearing as many lines as possible. Tetris consists of 7 types of Tetrominoes, which are shapes that represent the letters \"I\", \"O\", \"S\", \"Z\", \"L\", \"J\", and \"T\" as shown in the image below. Observation # The observation in Tetris includes information about the grid, the Tetromino and the action mask. grid : jax array (int32) of shape (num_rows, num_cols) , representing the current grid state. The grid is filled with zeros for the empty cells and with ones for the filled cells. Here is an example of a random observation of the grid: 1 2 3 4 5 6 7 8 9 [ [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0, 1, 1], [0, 1, 1, 1, 0, 1], [0, 1, 0, 1, 1, 1], [1, 1, 0, 1, 1, 1], ] tetromino : jax array (int32) of shape (4, 4) , where a value of 1 indicates a filled cell and a value of 0 indicates an empty cell. Here is an example of an I tetromino: 1 2 3 4 5 6 [ [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], ] action_mask : jax array (bool) of shape (4, num_cols) , indicating which actions are valid in the current state of the environment. Each row in the action mask corresponds to a Tetromino for a certain rotation (example: the first row for 0 degrees rotation, the second row for 90 degrees rotation, and so on). Here is an example of an action mask that corresponds to the same grid and the tetromino examples: 1 2 3 4 5 6 [ [ True, False, True, True, False, False], [ True, True, False, False, False, False], [ True, False, True, True, False, False], [ True, True, False, False, False, False], ] - step_count : jax array (int32) of shape () , integer to keep track of the number of steps. Action # The action space in Tetris is represented as a MultiDiscreteArray of two integer values. The first integer value corresponds to the selected X-position where the Tetromino will be placed, and the second integer value represents the index for the rotation degree. The rotation degree index can take four possible values: 0 for \"0 degrees\", 1 for \"90 degrees\", 2 for \"180 degrees\", and 3 for \"270 degrees\". For example, an action of [7, 2] means placing the Tetromino in the seventh column with a rotation of 180 degrees. Reward # Dense: the reward is based on the number of lines cleared and the reward_list [0, 40, 100, 300, 1200] . If no lines are cleared, the reward is 0. As the number of cleared lines increases, so does the reward, with the maximum reward of 1200 being awarded for clearing four lines at once. Registered Versions \ud83d\udcd6 # Tetris-v0 , the default settings for tetris with a grid of size 10x10.","title":"Tetris"},{"location":"environments/tetris/#tetris-environment","text":"We provide here a Jax JIT-able implementation of the game Tetris. Tetris is a popular single-player game that is played on a 2D grid by fitting falling blocks of various Tetrominoes together to create horizontal lines without any gaps. As each line is completed, it disappears, and the player earns points. If the stack of blocks reaches the top of the game grid, the game ends. The objective of Tetris is to score as many points as possible before the game ends, by clearing as many lines as possible. Tetris consists of 7 types of Tetrominoes, which are shapes that represent the letters \"I\", \"O\", \"S\", \"Z\", \"L\", \"J\", and \"T\" as shown in the image below.","title":"Tetris Environment"},{"location":"environments/tetris/#observation","text":"The observation in Tetris includes information about the grid, the Tetromino and the action mask. grid : jax array (int32) of shape (num_rows, num_cols) , representing the current grid state. The grid is filled with zeros for the empty cells and with ones for the filled cells. Here is an example of a random observation of the grid: 1 2 3 4 5 6 7 8 9 [ [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0, 1, 1], [0, 1, 1, 1, 0, 1], [0, 1, 0, 1, 1, 1], [1, 1, 0, 1, 1, 1], ] tetromino : jax array (int32) of shape (4, 4) , where a value of 1 indicates a filled cell and a value of 0 indicates an empty cell. Here is an example of an I tetromino: 1 2 3 4 5 6 [ [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], ] action_mask : jax array (bool) of shape (4, num_cols) , indicating which actions are valid in the current state of the environment. Each row in the action mask corresponds to a Tetromino for a certain rotation (example: the first row for 0 degrees rotation, the second row for 90 degrees rotation, and so on). Here is an example of an action mask that corresponds to the same grid and the tetromino examples: 1 2 3 4 5 6 [ [ True, False, True, True, False, False], [ True, True, False, False, False, False], [ True, False, True, True, False, False], [ True, True, False, False, False, False], ] - step_count : jax array (int32) of shape () , integer to keep track of the number of steps.","title":"Observation"},{"location":"environments/tetris/#action","text":"The action space in Tetris is represented as a MultiDiscreteArray of two integer values. The first integer value corresponds to the selected X-position where the Tetromino will be placed, and the second integer value represents the index for the rotation degree. The rotation degree index can take four possible values: 0 for \"0 degrees\", 1 for \"90 degrees\", 2 for \"180 degrees\", and 3 for \"270 degrees\". For example, an action of [7, 2] means placing the Tetromino in the seventh column with a rotation of 180 degrees.","title":"Action"},{"location":"environments/tetris/#reward","text":"Dense: the reward is based on the number of lines cleared and the reward_list [0, 40, 100, 300, 1200] . If no lines are cleared, the reward is 0. As the number of cleared lines increases, so does the reward, with the maximum reward of 1200 being awarded for clearing four lines at once.","title":"Reward"},{"location":"environments/tetris/#registered-versions","text":"Tetris-v0 , the default settings for tetris with a grid of size 10x10.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/tsp/","text":"Traveling Salesman Problem (TSP) Environment # We provide here a Jax JIT-able implementation of the traveling salesman problem (TSP) . TSP is a well-known combinatorial optimization problem. Given a set of cities and the distances between them, the goal is to determine the shortest route that visits each city exactly once and finishes in the starting city. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. When the environment is reset, a new problem instance is generated by sampling coordinates (a pair for each city) from a uniform distribution between 0 and 1. The number of cities is a parameter of the environment. A trajectory terminates when no new cities can be visited or the last action was invalid (i.e., the agent attempted to revisit a city). Observation # The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position (city) of the agent. coordinates : jax array (float) of shape (num_cities, 2) , array of coordinates of each city. position : jax array (int32) of shape () , identifier (index) of the last visited city. trajectory : jax array (int32) of shape (num_cities,) , city indices defining the route ( -1 --> not filled yet). action_mask : jax array (bool) of shape (num_cities,) , binary values denoting whether a city can be visited. Action # The action space is a DiscreteArray of integer values in the range of [0, num_cities-1] . An action is the index of the next city to visit. Reward # The reward could be either: Dense : the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. Registered Versions \ud83d\udcd6 # TSP-v1 : TSP problem with 20 randomly generated cities and a dense reward function.","title":"TSP"},{"location":"environments/tsp/#traveling-salesman-problem-tsp-environment","text":"We provide here a Jax JIT-able implementation of the traveling salesman problem (TSP) . TSP is a well-known combinatorial optimization problem. Given a set of cities and the distances between them, the goal is to determine the shortest route that visits each city exactly once and finishes in the starting city. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. When the environment is reset, a new problem instance is generated by sampling coordinates (a pair for each city) from a uniform distribution between 0 and 1. The number of cities is a parameter of the environment. A trajectory terminates when no new cities can be visited or the last action was invalid (i.e., the agent attempted to revisit a city).","title":"Traveling Salesman Problem (TSP) Environment"},{"location":"environments/tsp/#observation","text":"The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position (city) of the agent. coordinates : jax array (float) of shape (num_cities, 2) , array of coordinates of each city. position : jax array (int32) of shape () , identifier (index) of the last visited city. trajectory : jax array (int32) of shape (num_cities,) , city indices defining the route ( -1 --> not filled yet). action_mask : jax array (bool) of shape (num_cities,) , binary values denoting whether a city can be visited.","title":"Observation"},{"location":"environments/tsp/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_cities-1] . An action is the index of the next city to visit.","title":"Action"},{"location":"environments/tsp/#reward","text":"The reward could be either: Dense : the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again.","title":"Reward"},{"location":"environments/tsp/#registered-versions","text":"TSP-v1 : TSP problem with 20 randomly generated cities and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"guides/advanced_usage/","text":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c # Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of this below, where we use jax.vmap and jax.lax.scan to generate a batch of rollouts in the Snake environment. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import jax import jumanji from jumanji.wrappers import AutoResetWrapper env = jumanji . make ( \"Snake-v1\" ) # Create a Snake environment env = AutoResetWrapper ( env ) # Automatically reset the environment when an episode terminates batch_size = 7 rollout_length = 5 num_actions = env . action_spec () . num_values random_key = jax . random . PRNGKey ( 0 ) key1 , key2 = jax . random . split ( random_key ) def step_fn ( state , key ): action = jax . random . randint ( key = key , minval = 0 , maxval = num_actions , shape = ()) new_state , timestep = env . step ( state , action ) return new_state , timestep def run_n_steps ( state , key , n ): random_keys = jax . random . split ( key , n ) state , rollout = jax . lax . scan ( step_fn , state , random_keys ) return rollout # Instantiate a batch of environment states keys = jax . random . split ( key1 , batch_size ) state , timestep = jax . vmap ( env . reset )( keys ) # Collect a batch of rollouts keys = jax . random . split ( key2 , batch_size ) rollout = jax . vmap ( run_n_steps , in_axes = ( 0 , 0 , None ))( state , keys , rollout_length ) # Shape and type of given rollout: # TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)","title":"Advanced Usage"},{"location":"guides/advanced_usage/#advanced-usage","text":"Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of this below, where we use jax.vmap and jax.lax.scan to generate a batch of rollouts in the Snake environment. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import jax import jumanji from jumanji.wrappers import AutoResetWrapper env = jumanji . make ( \"Snake-v1\" ) # Create a Snake environment env = AutoResetWrapper ( env ) # Automatically reset the environment when an episode terminates batch_size = 7 rollout_length = 5 num_actions = env . action_spec () . num_values random_key = jax . random . PRNGKey ( 0 ) key1 , key2 = jax . random . split ( random_key ) def step_fn ( state , key ): action = jax . random . randint ( key = key , minval = 0 , maxval = num_actions , shape = ()) new_state , timestep = env . step ( state , action ) return new_state , timestep def run_n_steps ( state , key , n ): random_keys = jax . random . split ( key , n ) state , rollout = jax . lax . scan ( step_fn , state , random_keys ) return rollout # Instantiate a batch of environment states keys = jax . random . split ( key1 , batch_size ) state , timestep = jax . vmap ( env . reset )( keys ) # Collect a batch of rollouts keys = jax . random . split ( key2 , batch_size ) rollout = jax . vmap ( run_n_steps , in_axes = ( 0 , 0 , None ))( state , keys , rollout_length ) # Shape and type of given rollout: # TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)","title":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c"},{"location":"guides/registration/","text":"Environment Registry # Jumanji adopts the convention defined in Gym of having an environment registry and a make function to instantiate environments. Create an environment # To instantiate a Jumanji registered environment, we provide the convenient function jumanji.make . It can be used as follows: 1 2 3 4 5 6 import jax import jumanji env = jumanji . make ( 'BinPack-v1' ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) The environment ID is composed of two parts, the environment name and its version. To get the full list of registered environments, you can use the registered_environments util. \u26a0\ufe0f Warning 1 2 3 4 Users can provide additional key-word arguments in the call to `jumanji.make(env_id, ...)`. These are then passed to the class constructor. Because they can be used to overwrite the intended configuration of the environment when registered, we discourage users to do so. However, we are mindful of particular use cases that might require this flexibility. Although the make function provides a unified way to instantiate environments, users can always instantiate them by importing the corresponding environment class. Register your environment # In addition to the environments available in Jumanji, users can register their custom environment and access them through the familiar jumanji.make function. Assuming you created an environment by subclassing Jumanji Environment base class, you can register it as follows: 1 2 3 4 5 6 7 from jumanji import register register ( id = \"CustomEnv-v0\" , # format: (env_name)-v(version) entry_point = \"path.to.your.package:CustomEnv\" , # class constructor kwargs = { ... }, # environment configuration ) To successfully register your environment, make sure to provide the right path to your class constructor. The kwargs argument is there to configurate the environment and allow you to register scenarios with a specific set of arguments. The environment ID must respect the format (EnvName)-v(version) , where the version number starts at v0 . For examples on how to register environments, please see our jumanji/__init__.py file. 1 Note that Jumanji doesn't allow users to overwrite the registration of an existing environment. To verify that your custom environment has been registered correctly, you can inspect the listing of registered environments using the registered_environments util.","title":"Registration"},{"location":"guides/registration/#environment-registry","text":"Jumanji adopts the convention defined in Gym of having an environment registry and a make function to instantiate environments.","title":"Environment Registry"},{"location":"guides/registration/#create-an-environment","text":"To instantiate a Jumanji registered environment, we provide the convenient function jumanji.make . It can be used as follows: 1 2 3 4 5 6 import jax import jumanji env = jumanji . make ( 'BinPack-v1' ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) The environment ID is composed of two parts, the environment name and its version. To get the full list of registered environments, you can use the registered_environments util. \u26a0\ufe0f Warning 1 2 3 4 Users can provide additional key-word arguments in the call to `jumanji.make(env_id, ...)`. These are then passed to the class constructor. Because they can be used to overwrite the intended configuration of the environment when registered, we discourage users to do so. However, we are mindful of particular use cases that might require this flexibility. Although the make function provides a unified way to instantiate environments, users can always instantiate them by importing the corresponding environment class.","title":"Create an environment"},{"location":"guides/registration/#register-your-environment","text":"In addition to the environments available in Jumanji, users can register their custom environment and access them through the familiar jumanji.make function. Assuming you created an environment by subclassing Jumanji Environment base class, you can register it as follows: 1 2 3 4 5 6 7 from jumanji import register register ( id = \"CustomEnv-v0\" , # format: (env_name)-v(version) entry_point = \"path.to.your.package:CustomEnv\" , # class constructor kwargs = { ... }, # environment configuration ) To successfully register your environment, make sure to provide the right path to your class constructor. The kwargs argument is there to configurate the environment and allow you to register scenarios with a specific set of arguments. The environment ID must respect the format (EnvName)-v(version) , where the version number starts at v0 . For examples on how to register environments, please see our jumanji/__init__.py file. 1 Note that Jumanji doesn't allow users to overwrite the registration of an existing environment. To verify that your custom environment has been registered correctly, you can inspect the listing of registered environments using the registered_environments util.","title":"Register your environment"},{"location":"guides/training/","text":"Training # Jumanji provides a training script train.py to train an online agent on a specified Jumanji environment given an environment-specific network. Agents # Jumanji provides two example agents in jumanji/training/agents/ to get you started with training on Jumanji environments: Random agent: uses the action mask to randomly sample valid actions. A2C agent: online advantage actor-critic agent that follows from [Mnih et al., 2016] . Configuration # In each environment-specific config YAML file, you will see a \"training\" section like below: 1 2 3 4 5 training : num_epochs : 1000 num_learner_steps_per_epoch : 50 n_steps : 20 total_batch_size : 64 Here, num_epochs corresponds to the number of data points in your plots. An epoch can be thought as an iteration. num_learner_steps_per_epoch is the number of learner steps that happen in each epoch. After every learner step, the A2C agent's policy is updated. n_steps is the sequence length (consecutive environment steps in a batch). total_batch_size is the number of environments that are run in parallel. So in the above example, 64 environments are running in parallel. Each of these 64 environments run 20 environment steps. After this, the agent's policy is updated via SGD. This constitutes a single learner step. 50 such learner steps are done for the epoch in question. After this, evaluation is done using the updated policy. The above procedure is done for 1000 epochs. Evaluation # Two types of evaluation are recorded: Stochastic evaluation (same policy used during training) Greedy evaluation (argmax over the action logits)","title":"Training"},{"location":"guides/training/#training","text":"Jumanji provides a training script train.py to train an online agent on a specified Jumanji environment given an environment-specific network.","title":"Training"},{"location":"guides/training/#agents","text":"Jumanji provides two example agents in jumanji/training/agents/ to get you started with training on Jumanji environments: Random agent: uses the action mask to randomly sample valid actions. A2C agent: online advantage actor-critic agent that follows from [Mnih et al., 2016] .","title":"Agents"},{"location":"guides/training/#configuration","text":"In each environment-specific config YAML file, you will see a \"training\" section like below: 1 2 3 4 5 training : num_epochs : 1000 num_learner_steps_per_epoch : 50 n_steps : 20 total_batch_size : 64 Here, num_epochs corresponds to the number of data points in your plots. An epoch can be thought as an iteration. num_learner_steps_per_epoch is the number of learner steps that happen in each epoch. After every learner step, the A2C agent's policy is updated. n_steps is the sequence length (consecutive environment steps in a batch). total_batch_size is the number of environments that are run in parallel. So in the above example, 64 environments are running in parallel. Each of these 64 environments run 20 environment steps. After this, the agent's policy is updated via SGD. This constitutes a single learner step. 50 such learner steps are done for the epoch in question. After this, evaluation is done using the updated policy. The above procedure is done for 1000 epochs.","title":"Configuration"},{"location":"guides/training/#evaluation","text":"Two types of evaluation are recorded: Stochastic evaluation (same policy used during training) Greedy evaluation (argmax over the action logits)","title":"Evaluation"},{"location":"guides/wrappers/","text":"Wrappers # The Wrapper interface is used for extending Jumanji Environment to add features like auto reset and vectorised environments. Jumanji provides wrappers to convert a Jumanji Environment to a DeepMind or Gym environment. Jumanji to DeepMind Environment # We can also convert our Jumanji environments to a DeepMind environment: 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) dm_env = jumanji . wrappers . JumanjiToDMEnvWrapper ( env ) timestep = dm_env . reset () action = dm_env . action_spec () . generate_value () next_timestep = dm_env . step ( action ) ... Jumanji To Gym # We can also convert our Jumanji environments to a Gym environment! Below is an example of how to convert a Jumanji environment into a Gym environment. 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) gym_env = jumanji . wrappers . JumanjiToGymWrapper ( env ) obs = gym_env . reset () action = gym_env . action_space . sample () observation , reward , done , extra = gym_env . step ( action ) ... Auto-reset an Environment # Below is an example of how to extend the functionality of the Snake environment to automatically reset whenever the environment reaches a terminal state. The Snake game terminates when the snake hits the wall, using the AutoResetWrapper the environment will be reset once a terminal state has been reached. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import jax.random import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) env = jumanji . wrappers . AutoResetWrapper ( env ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) print ( \"New episode\" ) for i in range ( 100 ): action = env . action_spec () . generate_value () # Returns jnp.array(0) when using Snake. state , timestep = env . step ( state , action ) if timestep . first (): print ( \"New episode\" )","title":"Wrapper"},{"location":"guides/wrappers/#wrappers","text":"The Wrapper interface is used for extending Jumanji Environment to add features like auto reset and vectorised environments. Jumanji provides wrappers to convert a Jumanji Environment to a DeepMind or Gym environment.","title":"Wrappers"},{"location":"guides/wrappers/#jumanji-to-deepmind-environment","text":"We can also convert our Jumanji environments to a DeepMind environment: 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) dm_env = jumanji . wrappers . JumanjiToDMEnvWrapper ( env ) timestep = dm_env . reset () action = dm_env . action_spec () . generate_value () next_timestep = dm_env . step ( action ) ...","title":"Jumanji to DeepMind Environment"},{"location":"guides/wrappers/#jumanji-to-gym","text":"We can also convert our Jumanji environments to a Gym environment! Below is an example of how to convert a Jumanji environment into a Gym environment. 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) gym_env = jumanji . wrappers . JumanjiToGymWrapper ( env ) obs = gym_env . reset () action = gym_env . action_space . sample () observation , reward , done , extra = gym_env . step ( action ) ...","title":"Jumanji To Gym"},{"location":"guides/wrappers/#auto-reset-an-environment","text":"Below is an example of how to extend the functionality of the Snake environment to automatically reset whenever the environment reaches a terminal state. The Snake game terminates when the snake hits the wall, using the AutoResetWrapper the environment will be reset once a terminal state has been reached. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import jax.random import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) env = jumanji . wrappers . AutoResetWrapper ( env ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) print ( \"New episode\" ) for i in range ( 100 ): action = env . action_spec () . generate_value () # Returns jnp.array(0) when using Snake. state , timestep = env . step ( state , action ) if timestep . first (): print ( \"New episode\" )","title":"Auto-reset an Environment"}]} \ No newline at end of file +{"config":{"indexing":"full","lang":["en"],"min_search_length":3,"prebuild_index":false,"separator":"[\\s\\-]+"},"docs":[{"location":"","text":"Environments | Installation | Quickstart | Training | Citation | Docs Welcome to the Jungle! \ud83c\udf34 # Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX. Jumanji is helping pioneer a new wave of hardware-accelerated research and development in the field of RL. Jumanji's high-speed environments enable faster iteration and large-scale experimentation while simultaneously reducing complexity. Originating in the Research Team at InstaDeep , Jumanji is now developed jointly with the open-source community. To join us in these efforts, reach out, raise issues and read our contribution guidelines or just star \ud83c\udf1f to stay up to date with the latest developments! Goals \ud83d\ude80 # Provide a simple, well-tested API for JAX-based environments. Make research in RL more accessible. Facilitate the research on RL for problems in the industry and help close the gap between research and industrial applications. Provide environments whose difficulty can be scaled to be arbitrarily hard. Overview \ud83e\udd9c # \ud83e\udd51 Environment API : core abstractions for JAX-based environments. \ud83d\udd79\ufe0f Environment Suite : a collection of RL environments ranging from simple games to NP-hard combinatorial problems. \ud83c\udf6c Wrappers : easily connect to your favourite RL frameworks and libraries such as Acme , Stable Baselines3 , RLlib , OpenAI Gym and DeepMind-Env through our dm_env and gym wrappers. \ud83c\udf93 Examples : guides to facilitate Jumanji's adoption and highlight the added value of JAX-based environments. \ud83c\udfce\ufe0f Training: example agents that can be used as inspiration for the agents one may implement in their research. Environments \ud83c\udf0d Jumanji provides a diverse range of environments ranging from simple games to NP-hard combinatorial problems. Environment Category Registered Version(s) Source Description \ud83d\udd22 Game2048 Logic Game2048-v1 code doc \ud83c\udfa8 GraphColoring Logic GraphColoring-v0 code doc \ud83d\udca3 Minesweeper Logic Minesweeper-v0 code doc \ud83c\udfb2 RubiksCube Logic RubiksCube-v0 RubiksCube-partly-scrambled-v0 code doc \u270f\ufe0f Sudoku Logic Sudoku-v0 Sudoku-very-easy-v0 code doc \ud83d\udce6 BinPack (3D BinPacking Problem) Packing BinPack-v2 code doc \ud83c\udfed JobShop (Job Shop Scheduling Problem) Packing JobShop-v0 code doc \ud83c\udf92 Knapsack Packing Knapsack-v1 code doc \u2592 Tetris Packing Tetris-v0 code doc \ud83e\uddf9 Cleaner Routing Cleaner-v0 code doc Connector Routing Connector-v2 code doc \ud83d\ude9a CVRP (Capacitated Vehicle Routing Problem) Routing CVRP-v1 code doc \ud83d\ude9a MultiCVRP (Multi-Agent Capacitated Vehicle Routing Problem) Routing MultiCVRP-v0 code doc Maze Routing Maze-v0 code doc RobotWarehouse Routing RobotWarehouse-v0 code doc \ud83d\udc0d Snake Routing Snake-v1 code doc \ud83d\udcec TSP (Travelling Salesman Problem) Routing TSP-v1 code doc Multi Minimum Spanning Tree Problem Routing MMST-v0 code doc Installation \ud83c\udfac You can install the latest release of Jumanji from PyPI: 1 pip install jumanji Alternatively, you can install the latest development version directly from GitHub: 1 pip install git+https://github.com/instadeepai/jumanji.git Jumanji has been tested on Python 3.8 and 3.9. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide ). Rendering: Matplotlib is used for rendering all the environments. To visualize the environments you will need a GUI backend. For example, on Linux, you can install Tk via: apt-get install python3-tk , or using conda: conda install tk . Check out Matplotlib backends for a list of backends you can use. Quickstart \u26a1 RL practitioners will find Jumanji's interface familiar as it combines the widely adopted OpenAI Gym and DeepMind Environment interfaces. From OpenAI Gym, we adopted the idea of a registry and the render method, while our TimeStep structure is inspired by DeepMind Environment. Basic Usage \ud83e\uddd1\u200d\ud83d\udcbb # 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import jax import jumanji # Instantiate a Jumanji environment using the registry env = jumanji . make ( 'Snake-v1' ) # Reset your (jit-able) environment key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) # (Optional) Render the env state env . render ( state ) # Interact with the (jit-able) environment action = env . action_spec () . generate_value () # Action selection (dummy value here) state , timestep = jax . jit ( env . step )( state , action ) # Take a step and observe the next state and time step state represents the internal state of the environment: it contains all the information required to take a step when executing an action. This should not be confused with the observation contained in the timestep , which is the information perceived by the agent. timestep is a dataclass containing step_type , reward , discount , observation and extras . This structure is similar to dm_env.TimeStep except for the extras field that was added to allow users to log environments metrics that are neither part of the agent's observation nor part of the environment's internal state. Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c # Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of a more advanced usage in the advanced usage guide . Registry and Versioning \ud83d\udcd6 # Like OpenAI Gym, Jumanji keeps a strict versioning of its environments for reproducibility reasons. We maintain a registry of standard environments with their configuration. For each environment, a version suffix is appended, e.g. Snake-v1 . When changes are made to environments that might impact learning results, the version number is incremented by one to prevent potential confusion. For a full list of registered versions of each environment, check out the documentation . Training \ud83c\udfce\ufe0f To showcase how to train RL agents on Jumanji environments, we provide a random agent and a vanilla actor-critic (A2C) agent. These agents can be found in jumanji/training/ . Because the environment framework in Jumanji is so flexible, it allows pretty much any problem to be implemented as a Jumanji environment, giving rise to very diverse observations. For this reason, environment-specific networks are required to capture the symmetries of each environment. Alongside the A2C agent implementation, we provide examples of such environment-specific actor-critic networks in jumanji/training/networks . \u26a0\ufe0f The example agents in jumanji/training are only meant to serve as inspiration for how one can implement an agent. Jumanji is first and foremost a library of environments - as such, the agents and networks will not be maintained to a production standard. For more information on how to use the example agents, see the training guide . Contributing \ud83e\udd1d # Contributions are welcome! See our issue tracker for good first issues . Please read our contributing guidelines for details on how to submit pull requests, our Contributor License Agreement, and community guidelines. Citing Jumanji \u270f\ufe0f If you use Jumanji in your work, please cite the library using: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 @misc{bonnet2023jumanji, title={Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX}, author={ Cl\u00e9ment Bonnet and Daniel Luo and Donal Byrne and Shikha Surana and Vincent Coyette and Paul Duckworth and Laurence I. Midgley and Tristan Kalloniatis and Sasha Abramowitz and Cemlyn N. Waters and Andries P. Smit and Nathan Grinsztajn and Ulrich A. Mbou Sob and Omayma Mahjoub and Elshadai Tegegn and Mohamed A. Mimouni and Raphael Boige and Ruan de Kock and Daniel Furelos-Blanco and Victor Le and Arnu Pretorius and Alexandre Laterre }, year={2023}, eprint={2306.09884}, url={https://arxiv.org/abs/2306.09884}, archivePrefix={arXiv}, primaryClass={cs.LG} } See Also \ud83d\udd0e # Other works have embraced the approach of writing RL environments in JAX. In particular, we suggest users check out the following sister repositories: \ud83e\udd16 Qdax is a library to accelerate Quality-Diversity and neuro-evolution algorithms through hardware accelerators and parallelization. \ud83c\udf33 Evojax provides tools to enable neuroevolution algorithms to work with neural networks running across multiple TPU/GPUs. \ud83e\uddbe Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. \ud83c\udfcb\ufe0f\u200d Gymnax implements classic environments including classic control, bsuite, MinAtar and a collection of meta RL tasks. \ud83c\udfb2 Pgx provides classic board game environments like Backgammon, Shogi, and Go. Acknowledgements \ud83d\ude4f # The development of this library was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) \ud83c\udf24.","title":"Home"},{"location":"#welcome-to-the-jungle","text":"Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX. Jumanji is helping pioneer a new wave of hardware-accelerated research and development in the field of RL. Jumanji's high-speed environments enable faster iteration and large-scale experimentation while simultaneously reducing complexity. Originating in the Research Team at InstaDeep , Jumanji is now developed jointly with the open-source community. To join us in these efforts, reach out, raise issues and read our contribution guidelines or just star \ud83c\udf1f to stay up to date with the latest developments!","title":"Welcome to the Jungle! \ud83c\udf34"},{"location":"#goals","text":"Provide a simple, well-tested API for JAX-based environments. Make research in RL more accessible. Facilitate the research on RL for problems in the industry and help close the gap between research and industrial applications. Provide environments whose difficulty can be scaled to be arbitrarily hard.","title":"Goals \ud83d\ude80"},{"location":"#overview","text":"\ud83e\udd51 Environment API : core abstractions for JAX-based environments. \ud83d\udd79\ufe0f Environment Suite : a collection of RL environments ranging from simple games to NP-hard combinatorial problems. \ud83c\udf6c Wrappers : easily connect to your favourite RL frameworks and libraries such as Acme , Stable Baselines3 , RLlib , OpenAI Gym and DeepMind-Env through our dm_env and gym wrappers. \ud83c\udf93 Examples : guides to facilitate Jumanji's adoption and highlight the added value of JAX-based environments. \ud83c\udfce\ufe0f Training: example agents that can be used as inspiration for the agents one may implement in their research.","title":"Overview \ud83e\udd9c"},{"location":"#basic-usage","text":"1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import jax import jumanji # Instantiate a Jumanji environment using the registry env = jumanji . make ( 'Snake-v1' ) # Reset your (jit-able) environment key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) # (Optional) Render the env state env . render ( state ) # Interact with the (jit-able) environment action = env . action_spec () . generate_value () # Action selection (dummy value here) state , timestep = jax . jit ( env . step )( state , action ) # Take a step and observe the next state and time step state represents the internal state of the environment: it contains all the information required to take a step when executing an action. This should not be confused with the observation contained in the timestep , which is the information perceived by the agent. timestep is a dataclass containing step_type , reward , discount , observation and extras . This structure is similar to dm_env.TimeStep except for the extras field that was added to allow users to log environments metrics that are neither part of the agent's observation nor part of the environment's internal state.","title":"Basic Usage \ud83e\uddd1\u200d\ud83d\udcbb"},{"location":"#advanced-usage","text":"Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of a more advanced usage in the advanced usage guide .","title":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c"},{"location":"#registry-and-versioning","text":"Like OpenAI Gym, Jumanji keeps a strict versioning of its environments for reproducibility reasons. We maintain a registry of standard environments with their configuration. For each environment, a version suffix is appended, e.g. Snake-v1 . When changes are made to environments that might impact learning results, the version number is incremented by one to prevent potential confusion. For a full list of registered versions of each environment, check out the documentation .","title":"Registry and Versioning \ud83d\udcd6"},{"location":"#contributing","text":"Contributions are welcome! See our issue tracker for good first issues . Please read our contributing guidelines for details on how to submit pull requests, our Contributor License Agreement, and community guidelines.","title":"Contributing \ud83e\udd1d"},{"location":"#see-also","text":"Other works have embraced the approach of writing RL environments in JAX. In particular, we suggest users check out the following sister repositories: \ud83e\udd16 Qdax is a library to accelerate Quality-Diversity and neuro-evolution algorithms through hardware accelerators and parallelization. \ud83c\udf33 Evojax provides tools to enable neuroevolution algorithms to work with neural networks running across multiple TPU/GPUs. \ud83e\uddbe Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. \ud83c\udfcb\ufe0f\u200d Gymnax implements classic environments including classic control, bsuite, MinAtar and a collection of meta RL tasks. \ud83c\udfb2 Pgx provides classic board game environments like Backgammon, Shogi, and Go.","title":"See Also \ud83d\udd0e"},{"location":"#acknowledgements","text":"The development of this library was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) \ud83c\udf24.","title":"Acknowledgements \ud83d\ude4f"},{"location":"api/env/","text":"Environment ( ABC , Generic ) # Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. The API is inspired by brax . unwrapped : Environment property readonly # reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> Spec # Returns the observation spec. Returns: Type Description observation_spec a NestedSpec tree of spec. action_spec ( self ) -> Spec # Returns the action spec. Returns: Type Description action_spec a NestedSpec tree of spec. reward_spec ( self ) -> Array # Describes the reward returned by the environment. By default, this is assumed to be a single float. Returns: Type Description reward_spec a specs.Array spec. discount_spec ( self ) -> BoundedArray # Describes the discount returned by the environment. By default, this is assumed to be a single float between 0 and 1. Returns: Type Description discount_spec a specs.BoundedArray spec. render ( self , state : ~ State ) -> Any # Render frames of the environment for a given state. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. __enter__ ( self ) -> Environment special # __exit__ ( self , * args : Any ) -> None special # Calls :meth: close() .","title":"Base"},{"location":"api/env/#jumanji.env.Environment","text":"Environment written in Jax that differs from the gym API to make the step and reset functions jittable. The state contains all the dynamics and data needed to step the environment, no computation stored in attributes of self. The API is inspired by brax .","title":"Environment"},{"location":"api/env/#jumanji.env.Environment.unwrapped","text":"","title":"unwrapped"},{"location":"api/env/#jumanji.env.Environment.reset","text":"Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment,","title":"reset()"},{"location":"api/env/#jumanji.env.Environment.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment,","title":"step()"},{"location":"api/env/#jumanji.env.Environment.observation_spec","text":"Returns the observation spec. Returns: Type Description observation_spec a NestedSpec tree of spec.","title":"observation_spec()"},{"location":"api/env/#jumanji.env.Environment.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a NestedSpec tree of spec.","title":"action_spec()"},{"location":"api/env/#jumanji.env.Environment.reward_spec","text":"Describes the reward returned by the environment. By default, this is assumed to be a single float. Returns: Type Description reward_spec a specs.Array spec.","title":"reward_spec()"},{"location":"api/env/#jumanji.env.Environment.discount_spec","text":"Describes the discount returned by the environment. By default, this is assumed to be a single float between 0 and 1. Returns: Type Description discount_spec a specs.BoundedArray spec.","title":"discount_spec()"},{"location":"api/env/#jumanji.env.Environment.render","text":"Render frames of the environment for a given state. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required","title":"render()"},{"location":"api/env/#jumanji.env.Environment.close","text":"Perform any necessary cleanup.","title":"close()"},{"location":"api/env/#jumanji.env.Environment.__enter__","text":"","title":"__enter__()"},{"location":"api/env/#jumanji.env.Environment.__exit__","text":"Calls :meth: close() .","title":"__exit__()"},{"location":"api/types/","text":"types # StepType ( int8 ) # Defines the status of a TimeStep within a sequence. First: 0 Mid: 1 Last: 2 TimeStep ( Generic , Mapping ) dataclass # Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806). A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type , an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount . The first TimeStep in a sequence will have StepType.FIRST . The final TimeStep will have StepType.LAST . All other TimeStep s in a sequence will have `StepType.MID. Attributes: Name Type Description step_type StepType A StepType enum value. reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST , i.e. at the start of a sequence. discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1] , or None if step_type is StepType.FIRST , i.e. at the start of a sequence. observation ~Observation A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. step_type : StepType dataclass-field # reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] dataclass-field # discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] dataclass-field # observation : ~ Observation dataclass-field # extras : Optional [ Dict ] dataclass-field # __eq__ ( self , other ) special # __init__ ( self , step_type : StepType , reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , extras : Optional [ Dict ] = None ) -> None special # __repr__ ( self ) special # __getitem__ ( self , x ) special # __len__ ( self ) special # __iter__ ( self ) special # first ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # mid ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # last ( self ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # from_tuple ( args ) # to_tuple ( self ) # replace ( self , ** kwargs ) # __getstate__ ( self ) special # __setstate__ ( self , state ) special # restart ( observation : ~ Observation , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.FIRST . Parameters: Name Type Description Default observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a reset. transition ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] = None , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.MID . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a transition. termination ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the termination of an episode. truncation ( reward : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ], observation : ~ Observation , discount : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] = None , extras : Optional [ Dict ] = None , shape : Union [ int , Sequence [ int ]] = ()) -> TimeStep # Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the truncation of an episode. get_valid_dtype ( dtype : Union [ numpy . dtype , type ]) -> dtype # Cast a dtype taking into account the user type precision. E.g., if 64 bit is not enabled, jnp.dtype(jnp.float_) is still float64. By passing the given dtype through jnp.empty we get the supported dtype of float32. Parameters: Name Type Description Default dtype Union[numpy.dtype, type] jax numpy dtype or string specifying the array dtype. required Returns: Type Description dtype dtype converted to the correct type precision.","title":"Types"},{"location":"api/types/#jumanji.types","text":"","title":"types"},{"location":"api/types/#jumanji.types.StepType","text":"Defines the status of a TimeStep within a sequence. First: 0 Mid: 1 Last: 2","title":"StepType"},{"location":"api/types/#jumanji.types.TimeStep","text":"Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806). A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type , an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount . The first TimeStep in a sequence will have StepType.FIRST . The final TimeStep will have StepType.LAST . All other TimeStep s in a sequence will have `StepType.MID. Attributes: Name Type Description step_type StepType A StepType enum value. reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST , i.e. at the start of a sequence. discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] A scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1] , or None if step_type is StepType.FIRST , i.e. at the start of a sequence. observation ~Observation A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.","title":"TimeStep"},{"location":"api/types/#jumanji.types.restart","text":"Returns a TimeStep with step_type set to StepType.FIRST . Parameters: Name Type Description Default observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a reset.","title":"restart()"},{"location":"api/types/#jumanji.types.transition","text":"Returns a TimeStep with step_type set to StepType.MID . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as a transition.","title":"transition()"},{"location":"api/types/#jumanji.types.termination","text":"Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the termination of an episode.","title":"termination()"},{"location":"api/types/#jumanji.types.truncation","text":"Returns a TimeStep with step_type set to StepType.LAST . Parameters: Name Type Description Default reward Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. required observation ~Observation array or tree of arrays. required discount Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] array. None extras Optional[Dict] environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. None shape Union[int, Sequence[int]] optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. () Returns: Type Description TimeStep TimeStep identified as the truncation of an episode.","title":"truncation()"},{"location":"api/types/#jumanji.types.get_valid_dtype","text":"Cast a dtype taking into account the user type precision. E.g., if 64 bit is not enabled, jnp.dtype(jnp.float_) is still float64. By passing the given dtype through jnp.empty we get the supported dtype of float32. Parameters: Name Type Description Default dtype Union[numpy.dtype, type] jax numpy dtype or string specifying the array dtype. required Returns: Type Description dtype dtype converted to the correct type precision.","title":"get_valid_dtype()"},{"location":"api/wrappers/","text":"wrappers # Wrapper ( Environment , Generic ) # Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72 unwrapped : Environment property readonly # Returns the wrapped env. __init__ ( self , env : Environment ) special # reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep ] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> Spec # Returns the observation spec. action_spec ( self ) -> Spec # Returns the action spec. render ( self , state : ~ State ) -> Any # Compute render frames during initialisation of the environment. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits. __enter__ ( self ) -> Wrapper special # __exit__ ( self , * args : Any ) -> None special # JumanjiToDMEnvWrapper ( Environment ) # A wrapper that converts Environment to dm_env.Environment. unwrapped : Environment property readonly # __init__ ( self , env : Environment , key : Optional [ jax . _src . prng . PRNGKeyArray ] = None ) special # Create the wrapped environment. Parameters: Name Type Description Default env Environment Environment to wrap to a dm_env.Environment . required key Optional[jax._src.prng.PRNGKeyArray] optional key to initialize the Environment with. None reset ( self ) -> TimeStep # Starts a new sequence and returns the first TimeStep of this sequence. Returns: Type Description A `TimeStep` namedtuple containing step_type: A StepType of FIRST . reward: None , indicating the reward is undefined. discount: None , indicating the discount is undefined. observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec() . step ( self , action : ndarray ) -> TimeStep # Updates the environment according to the action and returns a TimeStep . If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step will start a new sequence and action will be ignored. This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored. Parameters: Name Type Description Default action ndarray A NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec() . required Returns: Type Description A `TimeStep` namedtuple containing step_type: A StepType value. reward: Reward at this timestep, or None if step_type is StepType.FIRST . Must conform to the specification returned by reward_spec() . discount: A discount in the range [0, 1], or None if step_type is StepType.FIRST . Must conform to the specification returned by discount_spec() . observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec() . observation_spec ( self ) -> Array # Returns the dm_env observation spec. action_spec ( self ) -> Array # Returns the dm_env action spec. MultiToSingleWrapper ( Wrapper ) # A wrapper that converts a multi-agent Environment to a single-agent Environment. __init__ ( self , env : Environment , reward_aggregator : Callable = < function sum at 0x7fe959e73700 > , discount_aggregator : Callable = < function amax at 0x7fe959e73ee0 > ) special # Create the wrapped environment. Parameters: Name Type Description Default env Environment Environment to wrap to a dm_env.Environment . required reward_aggregator Callable a function to aggregate all agents rewards into a single scalar value, e.g. sum. discount_aggregator Callable a function to aggregate all agents discounts into a single scalar value, e.g. max. reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Run one timestep of the environment's dynamics. The rewards are aggregated into a single value based on the given reward aggregator. The discount value is set to the largest discount of all the agents. This essentially means that if any single agent is alive, the discount value won't be zero. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, VmapWrapper ( Wrapper ) # Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: - observation_spec - action_spec - reward_spec - discount_spec reset ( self , key : PRNGKeyArray ) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Resets the environment to an initial state. The first dimension of the key will dictate the number of concurrent environments. To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments. Parameters: Name Type Description Default key PRNGKeyArray random keys used to reset the environments where the first dimension is the number of desired environments. required Returns: Type Description state State object corresponding to the new state of the environments, timestep: TimeStep object corresponding the first timesteps returned by the environments, step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Run one timestep of the environment's dynamics. The first dimension of the state will dictate the number of concurrent environments. See VmapWrapper.reset for more details on how to get a state of concurrent environments. Parameters: Name Type Description Default state ~State State object containing the dynamics of the environments. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take. required Returns: Type Description state State object corresponding to the next states of the environments, timestep: TimeStep object corresponding the timesteps returned by the environments, render ( self , state : ~ State ) -> Any # Render the first environment state of the given batch. The remaining elements of the batched state are ignored. Parameters: Name Type Description Default state ~State State object containing the current dynamics of the environment. required AutoResetWrapper ( Wrapper ) # Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper ), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead. step ( self , state : ~ State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ ~ State , jumanji . types . TimeStep [ ~ Observation ]] # Step the environment, with automatic resetting if the episode terminates. JumanjiToGymWrapper ( Env ) # A wrapper that converts a Jumanji Environment to one that follows the gym.Env API. unwrapped : Environment property readonly # Returns the base non-wrapped environment. Returns: Type Description Env The base non-wrapped gym.Env instance __init__ ( self , env : Environment , seed : int = 0 , backend : Optional [ str ] = None ) special # Create the Gym environment. Parameters: Name Type Description Default env Environment Environment to wrap to a gym.Env . required seed int the seed that is used to initialize the environment's PRNG. 0 backend Optional[str] the XLA backend. None reset ( self , * , seed : Optional [ int ] = None , return_info : bool = False , options : Optional [ dict ] = None ) -> Union [ Any , Tuple [ Any , Union [ Any ]]] # Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence. Returns: Type Description obs an element of the environment's observation_space. info (optional): contains supplementary information such as metrics. step ( self , action : ndarray ) -> Tuple [ Any , float , bool , Optional [ Any ]] # Updates the environment according to the action and returns an Observation . Parameters: Name Type Description Default action ndarray A NumPy array representing the action provided by the agent. required Returns: Type Description observation an element of the environment's observation_space. reward: the amount of reward returned as a result of taking the action. terminated: whether a terminal state is reached. info: contains supplementary information such as metrics. seed ( self , seed : int = 0 ) -> None # Function which sets the seed for the environment's random number generator(s). Parameters: Name Type Description Default seed int the seed value for the random number generator(s). 0 render ( self , mode : str = 'human' ) -> Any # Renders the environment. Parameters: Name Type Description Default mode str currently not used since Jumanji does not currently support modes. 'human' close ( self ) -> None # Closes the environment, important for rendering where pygame is imported. jumanji_to_gym_obs ( observation : ~ Observation ) -> Any # Convert a Jumanji observation into a gym observation. Parameters: Name Type Description Default observation ~Observation JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods implemented. required Returns: Type Description Any Numpy array or nested dictionary of numpy arrays.","title":"Wrappers"},{"location":"api/wrappers/#jumanji.wrappers","text":"","title":"wrappers"},{"location":"api/wrappers/#jumanji.wrappers.Wrapper","text":"Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72","title":"Wrapper"},{"location":"api/wrappers/#jumanji.wrappers.JumanjiToDMEnvWrapper","text":"A wrapper that converts Environment to dm_env.Environment.","title":"JumanjiToDMEnvWrapper"},{"location":"api/wrappers/#jumanji.wrappers.MultiToSingleWrapper","text":"A wrapper that converts a multi-agent Environment to a single-agent Environment.","title":"MultiToSingleWrapper"},{"location":"api/wrappers/#jumanji.wrappers.VmapWrapper","text":"Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: - observation_spec - action_spec - reward_spec - discount_spec","title":"VmapWrapper"},{"location":"api/wrappers/#jumanji.wrappers.AutoResetWrapper","text":"Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep is reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras retrieved from the transition to the terminal state. WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper ), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead.","title":"AutoResetWrapper"},{"location":"api/wrappers/#jumanji.wrappers.JumanjiToGymWrapper","text":"A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.","title":"JumanjiToGymWrapper"},{"location":"api/wrappers/#jumanji.wrappers.jumanji_to_gym_obs","text":"Convert a Jumanji observation into a gym observation. Parameters: Name Type Description Default observation ~Observation JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods implemented. required Returns: Type Description Any Numpy array or nested dictionary of numpy arrays.","title":"jumanji_to_gym_obs()"},{"location":"api/environments/bin_pack/","text":"BinPack ( Environment ) # Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this problem. An EMS is a 3D-rectangular space that lives inside the container and has the following Properties It does not intersect any items, and it is not fully included into any other EMSs. It is defined by 2 3D-points, hence 6 coordinates (x1, x2, y1, y2, z1, z2), the first point corresponding to its bottom-left location while the second defining its top-right corner. observation: Observation ems: EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,), coordinates of all EMSs at the current timestep. ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid. items: Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,), characteristics of all items for this instance. items_mask: jax array (bool) of shape (max_num_items,) indicates the items that are valid. items_placed: jax array (bool) of shape (max_num_items,) indicates the items that have been placed so far. action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) mask of the joint action space: True if the action (ems_id, item_id) is valid. action: MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). ems_id: int between 0 and obs_num_ems - 1 (included). item_id: int between 0 and max_num_items - 1 (included). reward: jax array (float) of shape (), could be either: dense: increase in volume utilization of the container due to packing the chosen item. sparse: volume utilization of the container at the end of the episode. episode termination: if no action can be performed, i.e. no items fit in any EMSs, or all items have been packed. if an invalid action is taken, i.e. an item that does not fit in an EMS or one that is already packed. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). 1 2 3 4 5 6 7 8 from jumanji.environments import BinPack env = BinPack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . bin_pack . generator . Generator ] = None , obs_num_ems : int = 40 , reward_fn : Optional [ jumanji . environments . packing . bin_pack . reward . RewardFn ] = None , normalize_dimensions : bool = True , debug : bool = False , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . bin_pack . types . State ]] = None ) special # Instantiates a BinPack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.bin_pack.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator , ToyGenerator , CSVGenerator ]. Defaults to RandomGenerator that generates up to 20 items maximum and that can handle 40 EMSs. None obs_num_ems int number of EMSs (possible spaces in which to place an item) to show to the agent. If obs_num_ems is smaller than generator.max_num_ems , the first obs_num_ems largest EMSs (in terms of volume) will be returned in the observation. The good number heavily depends on the number of items (given by the instance generator). Default to 40 EMSs observable. 40 reward_fn Optional[jumanji.environments.packing.bin_pack.reward.RewardFn] compute the reward based on the current state, the chosen action, the next state, whether the transition is valid and if it is terminal. Implemented options are [ DenseReward , SparseReward ]. In each case, the total return at the end of an episode is the volume utilization of the container. Defaults to DenseReward . None normalize_dimensions bool if True, the observation is normalized (float) along each dimension into a unit cubic container. If False, the observation is returned in millimeters, i.e. integers (for both items and EMSs). Default to True. True debug bool if True, will add to timestep.extras an invalid_ems_from_env field that checks if an invalid EMS was created by the environment, which should not happen. Computing this metric slows down the environment. Default to False. False viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.bin_pack.types.State]] Viewer used for rendering. Defaults to BinPackViewer with \"human\" render mode. None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . bin_pack . types . Observation ] # Specifications of the observation of the BinPack environment. Returns: Type Description Spec for the `Observation` whose fields are ems: if normalize_dimensions: tree of BoundedArray (float) of shape (obs_num_ems,). else: tree of BoundedArray (int32) of shape (obs_num_ems,). ems_mask: BoundedArray (bool) of shape (obs_num_ems,). items: if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,). else: tree of BoundedArray (int32) of shape (max_num_items,). items_mask: BoundedArray (bool) of shape (max_num_items,). items_placed: BoundedArray (bool) of shape (max_num_items,). action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items). action_spec ( self ) -> MultiDiscreteArray # Specifications of the action expected by the BinPack environment. Returns: Type Description MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). - ems_id int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included). reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . bin_pack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . bin_pack . types . Observation ]] # Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. Also contains the following metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of active EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . bin_pack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . bin_pack . types . Observation ]] # Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates. Parameters: Name Type Description Default state State State object containing the data of the current instance. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given item at the location of the given EMS. If the action is not valid, the flag invalid_action will be set to True in timestep.extras and the episode terminates. required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. Also contains metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current dynamics of the environment. required close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"BinPack"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack","text":"Problem of 3D bin packing, where a set of items have to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this problem. An EMS is a 3D-rectangular space that lives inside the container and has the following Properties It does not intersect any items, and it is not fully included into any other EMSs. It is defined by 2 3D-points, hence 6 coordinates (x1, x2, y1, y2, z1, z2), the first point corresponding to its bottom-left location while the second defining its top-right corner. observation: Observation ems: EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,), coordinates of all EMSs at the current timestep. ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid. items: Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,), characteristics of all items for this instance. items_mask: jax array (bool) of shape (max_num_items,) indicates the items that are valid. items_placed: jax array (bool) of shape (max_num_items,) indicates the items that have been placed so far. action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) mask of the joint action space: True if the action (ems_id, item_id) is valid. action: MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). ems_id: int between 0 and obs_num_ems - 1 (included). item_id: int between 0 and max_num_items - 1 (included). reward: jax array (float) of shape (), could be either: dense: increase in volume utilization of the container due to packing the chosen item. sparse: volume utilization of the container at the end of the episode. episode termination: if no action can be performed, i.e. no items fit in any EMSs, or all items have been packed. if an invalid action is taken, i.e. an item that does not fit in an EMS or one that is already packed. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). 1 2 3 4 5 6 7 8 from jumanji.environments import BinPack env = BinPack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"BinPack"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.__init__","text":"Instantiates a BinPack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.bin_pack.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator , ToyGenerator , CSVGenerator ]. Defaults to RandomGenerator that generates up to 20 items maximum and that can handle 40 EMSs. None obs_num_ems int number of EMSs (possible spaces in which to place an item) to show to the agent. If obs_num_ems is smaller than generator.max_num_ems , the first obs_num_ems largest EMSs (in terms of volume) will be returned in the observation. The good number heavily depends on the number of items (given by the instance generator). Default to 40 EMSs observable. 40 reward_fn Optional[jumanji.environments.packing.bin_pack.reward.RewardFn] compute the reward based on the current state, the chosen action, the next state, whether the transition is valid and if it is terminal. Implemented options are [ DenseReward , SparseReward ]. In each case, the total return at the end of an episode is the volume utilization of the container. Defaults to DenseReward . None normalize_dimensions bool if True, the observation is normalized (float) along each dimension into a unit cubic container. If False, the observation is returned in millimeters, i.e. integers (for both items and EMSs). Default to True. True debug bool if True, will add to timestep.extras an invalid_ems_from_env field that checks if an invalid EMS was created by the environment, which should not happen. Computing this metric slows down the environment. Default to False. False viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.bin_pack.types.State]] Viewer used for rendering. Defaults to BinPackViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.observation_spec","text":"Specifications of the observation of the BinPack environment. Returns: Type Description Spec for the `Observation` whose fields are ems: if normalize_dimensions: tree of BoundedArray (float) of shape (obs_num_ems,). else: tree of BoundedArray (int32) of shape (obs_num_ems,). ems_mask: BoundedArray (bool) of shape (obs_num_ems,). items: if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,). else: tree of BoundedArray (int32) of shape (max_num_items,). items_mask: BoundedArray (bool) of shape (max_num_items,). items_placed: BoundedArray (bool) of shape (max_num_items,). action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items).","title":"observation_spec()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.action_spec","text":"Specifications of the action expected by the BinPack environment. Returns: Type Description MultiDiscreteArray (int32) of shape (obs_num_ems, max_num_items). - ems_id int between 0 and obs_num_ems - 1 (included). - item_id: int between 0 and max_num_items - 1 (included).","title":"action_spec()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.reset","text":"Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. Also contains the following metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of active EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.","title":"reset()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.step","text":"Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates. Parameters: Name Type Description Default state State State object containing the data of the current instance. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given item at the location of the given EMS. If the action is not valid, the flag invalid_action will be set to True in timestep.extras and the episode terminates. required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. Also contains metrics in the extras field: - volume_utilization: utilization (in [0, 1]) of the container. - packed_items: number of items that are packed in the container. - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container. - active_ems: number of EMSs in the current instance. - invalid_action: True if the action that was just taken was invalid. - invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.","title":"step()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current dynamics of the environment. required","title":"render()"},{"location":"api/environments/bin_pack/#jumanji.environments.packing.bin_pack.env.BinPack.close","text":"Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"close()"},{"location":"api/environments/cleaner/","text":"Cleaner ( Environment ) # A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: (int32) the number of step since the beginning of the episode. action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left) reward: jax array (float) of shape () +1 every time a tile is cleaned and a configurable penalty (-0.5 by default) for each timestep. episode termination: All tiles are clean. The number of steps is greater than the limit. An invalid action is selected for any of the agents. state: State grid: jax array (int32) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: jax array (int32) of shape () the number of steps since the beginning of the episode. key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic. 1 2 3 4 5 6 7 8 from jumanji.environments import Cleaner env = Cleaner () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . cleaner . generator . Generator ] = None , time_limit : Optional [ int ] = None , penalty_per_timestep : float = 0.5 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . cleaner . types . State ]] = None ) -> None special # Instantiates a Cleaner environment. Parameters: Name Type Description Default num_agents number of agents. Defaults to 3. required time_limit Optional[int] max number of steps in an episode. Defaults to num_rows * num_cols . None generator Optional[jumanji.environments.routing.cleaner.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 , num_cols=10 and num_agents=3 . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]] Viewer used for rendering. Defaults to CleanerViewer with \"human\" render mode. None penalty_per_timestep float the penalty returned at each timestep in the reward. 0.5 __repr__ ( self ) -> str special # observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . cleaner . types . Observation ] # Specification of the observation of the Cleaner environment. Returns: Type Description Spec for the `Observation`, consisting of the fields grid: BoundedArray (int32) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive). agent_locations_spec: BoundedArray (int32) of shape (num_agents, 2). Maximum value for the first column is num_rows, and maximum value for the second is num_cols. action_mask: BoundedArray (bool) of shape (num_agent, 4). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Specification of the action for the Cleaner environment. Returns: Type Description action_spec a specs.MultiDiscreteArray spec. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . cleaner . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cleaner . types . Observation ]] # Reset the environment to its initial state. All the tiles except upper left are dirty, and the agents start in the upper left corner of the grid. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . cleaner . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cleaner . types . Observation ]] # Run one timestep of the environment's dynamics. If an action is invalid, the corresponding agent does not move and the episode terminates. Parameters: Name Type Description Default state State current environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Jax array of shape (num_agents,). Each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left). required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required animate ( self , states : Sequence [ jumanji . environments . routing . cleaner . types . State ], interval : int = 200 , save_path : Optional [ str ] = None ) -> FuncAnimation # Creates an animated gif of the Cleaner environment based on the sequence of states. Parameters: Name Type Description Default states Sequence[jumanji.environments.routing.cleaner.types.State] sequence of environment states corresponding to consecutive timesteps. required interval int delay between frames in milliseconds, default to 200. 200 save_path Optional[str] the path where the animation file should be saved. If it is None, the plot will not be saved. None Returns: Type Description animation.FuncAnimation the animation object that was created. close ( self ) -> None # Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"Cleaner"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner","text":"A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze. observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: (int32) the number of step since the beginning of the episode. action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left) reward: jax array (float) of shape () +1 every time a tile is cleaned and a configurable penalty (-0.5 by default) for each timestep. episode termination: All tiles are clean. The number of steps is greater than the limit. An invalid action is selected for any of the agents. state: State grid: jax array (int32) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall. agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board. action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed. step_count: jax array (int32) of shape () the number of steps since the beginning of the episode. key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic. 1 2 3 4 5 6 7 8 from jumanji.environments import Cleaner env = Cleaner () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Cleaner"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.__init__","text":"Instantiates a Cleaner environment. Parameters: Name Type Description Default num_agents number of agents. Defaults to 3. required time_limit Optional[int] max number of steps in an episode. Defaults to num_rows * num_cols . None generator Optional[jumanji.environments.routing.cleaner.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 , num_cols=10 and num_agents=3 . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]] Viewer used for rendering. Defaults to CleanerViewer with \"human\" render mode. None penalty_per_timestep float the penalty returned at each timestep in the reward. 0.5","title":"__init__()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.__repr__","text":"","title":"__repr__()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.observation_spec","text":"Specification of the observation of the Cleaner environment. Returns: Type Description Spec for the `Observation`, consisting of the fields grid: BoundedArray (int32) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive). agent_locations_spec: BoundedArray (int32) of shape (num_agents, 2). Maximum value for the first column is num_rows, and maximum value for the second is num_cols. action_mask: BoundedArray (bool) of shape (num_agent, 4). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.action_spec","text":"Specification of the action for the Cleaner environment. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.reset","text":"Reset the environment to its initial state. All the tiles except upper left are dirty, and the agents start in the upper left corner of the grid. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.","title":"reset()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.step","text":"Run one timestep of the environment's dynamics. If an action is invalid, the corresponding agent does not move and the episode terminates. Parameters: Name Type Description Default state State current environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Jax array of shape (num_agents,). Each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left). required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required","title":"render()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.animate","text":"Creates an animated gif of the Cleaner environment based on the sequence of states. Parameters: Name Type Description Default states Sequence[jumanji.environments.routing.cleaner.types.State] sequence of environment states corresponding to consecutive timesteps. required interval int delay between frames in milliseconds, default to 200. 200 save_path Optional[str] the path where the animation file should be saved. If it is None, the plot will not be saved. None Returns: Type Description animation.FuncAnimation the animation object that was created.","title":"animate()"},{"location":"api/environments/cleaner/#jumanji.environments.routing.cleaner.env.Cleaner.close","text":"Perform any necessary cleanup. Environments will automatically :meth: close() themselves when garbage collected or when the program exits.","title":"close()"},{"location":"api/environments/connector/","text":"Connector ( Environment ) # The Connector environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets. observation - Observation action mask: jax array (bool) of shape (num_agents, 5). step_count: jax array (int32) of shape () the current episode step. grid: jax array (int32) of shape (grid_size, grid_size) with 2 agents you might have a grid like this: 4 0 1 5 0 1 6 3 2 which means agent 1 has moved from the top right of the grid down and is currently in the bottom right corner and is aiming to get to the middle bottom cell. Agent 2 started in the top left and moved down once towards its target in the bottom left. action: jax array (int32) of shape (num_agents,): can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left]. each value in the array corresponds to an agent's action. reward: jax array (float) of shape (): dense: reward is 1 for each successful connection on that step. Additionally, each pair of points that have not connected receives a penalty reward of -0.03. episode termination: all agents either can't move (no available actions) or have connected to their target. the time limit is reached. state: State: key: jax PRNG key used to randomly spawn agents and targets. grid: jax array (int32) of shape (grid_size, grid_size) giving the observation. step_count: jax array (int32) of shape () number of steps elapsed in the current episode. 1 2 3 4 5 6 7 8 from jumanji.environments import Connector env = Connector () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . connector . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . connector . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the connector grid. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the initial environment timestep. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . connector . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . connector . types . Observation ]] # Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take for each agent. - 0 no op - 1 move up - 2 move right - 3 move down - 4 move left required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment. render ( self , state : State ) -> Optional [ numpy . ndarray [ Any , numpy . dtype [ + ScalarType ]]] # Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . connector . types . Observation ] # Specifications of the observation of the Connector environment. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (int32) of shape (grid_size, grid_size). action_mask: BoundedArray (bool) of shape (num_agents, 5). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec for the Connector environment. 5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,). Returns: Type Description observation_spec MultiDiscreteArray of shape (num_agents,).","title":"Connector"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector","text":"The Connector environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets. observation - Observation action mask: jax array (bool) of shape (num_agents, 5). step_count: jax array (int32) of shape () the current episode step. grid: jax array (int32) of shape (grid_size, grid_size) with 2 agents you might have a grid like this: 4 0 1 5 0 1 6 3 2 which means agent 1 has moved from the top right of the grid down and is currently in the bottom right corner and is aiming to get to the middle bottom cell. Agent 2 started in the top left and moved down once towards its target in the bottom left. action: jax array (int32) of shape (num_agents,): can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left]. each value in the array corresponds to an agent's action. reward: jax array (float) of shape (): dense: reward is 1 for each successful connection on that step. Additionally, each pair of points that have not connected receives a penalty reward of -0.03. episode termination: all agents either can't move (no available actions) or have connected to their target. the time limit is reached. state: State: key: jax PRNG key used to randomly spawn agents and targets. grid: jax array (int32) of shape (grid_size, grid_size) giving the observation. step_count: jax array (int32) of shape () number of steps elapsed in the current episode. 1 2 3 4 5 6 7 8 from jumanji.environments import Connector env = Connector () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Connector"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the connector grid. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the initial environment timestep.","title":"reset()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.step","text":"Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the actions to take for each agent. - 0 no op - 1 move up - 2 move right - 3 move down - 4 move left required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment.","title":"step()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.render","text":"Render the given state of the environment. Parameters: Name Type Description Default state State State object containing the current environment state. required","title":"render()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.observation_spec","text":"Specifications of the observation of the Connector environment. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (int32) of shape (grid_size, grid_size). action_mask: BoundedArray (bool) of shape (num_agents, 5). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/connector/#jumanji.environments.routing.connector.env.Connector.action_spec","text":"Returns the action spec for the Connector environment. 5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,). Returns: Type Description observation_spec MultiDiscreteArray of shape (num_agents,).","title":"action_spec()"},{"location":"api/environments/cvrp/","text":"CVRP ( Environment ) # Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (float) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). unvisited_nodes: jax array (bool) of shape (num_nodes + 1,) indicates nodes that remain to be visited. position: jax array (int32) of shape () the index of the last visited node. trajectory: jax array (int32) of shape (2 * num_nodes,) array of node indices defining the route (set to DEPOT_IDX if not filled yet). capacity: jax array (float) of shape () the current capacity of the vehicle. action_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> invalid/valid action). action: jax array (int32) of shape () [0, ..., num_nodes] -> node to visit. 0 corresponds to visiting the depot. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. episode termination: if no action can be performed, i.e. all nodes have been visited. if an invalid action is taken, i.e. a previously visited city other than the depot is chosen. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). [1] Toth P., Vigo D. (2014). \"Vehicle routing: problems, methods, and applications\". 1 2 3 4 5 6 7 8 from jumanji.environments import CVRP env = CVRP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . cvrp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . cvrp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . cvrp . types . State ]] = None ) special # Instantiates a CVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates CVRP instances with 20 cities sampled from a uniform distribution, a maximum vehicle capacity of 30, and a maximum city demand of 10. None reward_fn Optional[jumanji.environments.routing.cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cvrp.types.State]] Viewer used for rendering. Defaults to CVRPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cvrp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . cvrp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] jax array (int32) of shape () containing the index of the next node to visit. required Returns: Type Description state, timestep next state of the environment and timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . cvrp . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_nodes + 1, 2). demands: BoundedArray (float) of shape (num_nodes + 1,). unvisited_nodes: BoundedArray (bool) of shape (num_nodes + 1,). position: DiscreteArray (num_values = num_nodes + 1) of shape (). trajectory: BoundedArray (int32) of shape (2 * num_nodes,). capacity: BoundedArray (float) of shape (). action_mask: BoundedArray (bool) of shape (num_nodes + 1,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"CVRP"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP","text":"Capacitated Vehicle Routing Problem (CVRP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (float) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). unvisited_nodes: jax array (bool) of shape (num_nodes + 1,) indicates nodes that remain to be visited. position: jax array (int32) of shape () the index of the last visited node. trajectory: jax array (int32) of shape (2 * num_nodes,) array of node indices defining the route (set to DEPOT_IDX if not filled yet). capacity: jax array (float) of shape () the current capacity of the vehicle. action_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> invalid/valid action). action: jax array (int32) of shape () [0, ..., num_nodes] -> node to visit. 0 corresponds to visiting the depot. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. episode termination: if no action can be performed, i.e. all nodes have been visited. if an invalid action is taken, i.e. a previously visited city other than the depot is chosen. state: State coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot. demands: jax array (int32) of shape (num_nodes + 1,) the associated cost of each node and the depot (0.0 for the depot). position: jax array (int32) the index of the last visited node. capacity: jax array (int32) the current capacity of the vehicle. visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). num_visits: int32 number of actions that have been taken (i.e., unique visits). [1] Toth P., Vigo D. (2014). \"Vehicle routing: problems, methods, and applications\". 1 2 3 4 5 6 7 8 from jumanji.environments import CVRP env = CVRP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"CVRP"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.__init__","text":"Instantiates a CVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates CVRP instances with 20 cities sampled from a uniform distribution, a maximum vehicle capacity of 30, and a maximum city demand of 10. None reward_fn Optional[jumanji.environments.routing.cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cvrp.types.State]] Viewer used for rendering. Defaults to CVRPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] jax array (int32) of shape () containing the index of the next node to visit. required Returns: Type Description state, timestep next state of the environment and timestep to be observed.","title":"step()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_nodes + 1, 2). demands: BoundedArray (float) of shape (num_nodes + 1,). unvisited_nodes: BoundedArray (bool) of shape (num_nodes + 1,). position: DiscreteArray (num_values = num_nodes + 1) of shape (). trajectory: BoundedArray (int32) of shape (2 * num_nodes,). capacity: BoundedArray (float) of shape (). action_mask: BoundedArray (bool) of shape (num_nodes + 1,).","title":"observation_spec()"},{"location":"api/environments/cvrp/#jumanji.environments.routing.cvrp.env.CVRP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/game_2048/","text":"Game2048 ( Environment ) # Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile with twice the value, until the player at least creates a tile with the value 2048 to consider it a win. observation: Observation board: jax array (int32) of shape (board_size, board_size) the current state of the board. An empty tile is represented by zero whereas a non-empty tile is an exponent of 2, e.g. 1, 2, 3, 4, ... (corresponding to 2, 4, 8, 16, ...). action_mask: jax array (bool) of shape (4,) indicates which actions are valid in the current state of the environment. action: jax array (int32) of shape (). Is in [0, 1, 2, 3] representing the actions up, right, down, and left, respectively. reward: jax array (float) of shape (). The reward is 0 except when the player combines tiles to create a new tile with twice the value. In this case, the reward is the value of the new tile. episode termination: if no more valid moves exist (this can happen when the board is full). state: State board: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. action_mask: same as observation. score: jax array (int32) of shape (), the sum of all tile values on the board. key: jax array (uint32) of shape (2,) random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Game2048 env = Game2048 () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , board_size : int = 4 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . game_2048 . types . State ]] = None ) -> None special # Initialize the 2048 game. Parameters: Name Type Description Default board_size int size of the board. Defaults to 4. 4 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.game_2048.types.State]] Viewer used for rendering. Defaults to Game2048Viewer . None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . game_2048 . types . Observation ] # Specifications of the observation of the Game2048 environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: Array (jnp.int32) of shape (board_size, board_size). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec DiscreteArray spec object. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . game_2048 . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . game_2048 . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random number generator key. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . game_2048 . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . game_2048 . types . Observation ]] # Updates the environment state after the agent takes an action. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"Game2048"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048","text":"Environment for the game 2048. The game consists of a board of size board_size x board_size (4x4 by default) in which the player can take actions to move the tiles on the board up, down, left, or right. The goal of the game is to combine tiles with the same number to create a tile with twice the value, until the player at least creates a tile with the value 2048 to consider it a win. observation: Observation board: jax array (int32) of shape (board_size, board_size) the current state of the board. An empty tile is represented by zero whereas a non-empty tile is an exponent of 2, e.g. 1, 2, 3, 4, ... (corresponding to 2, 4, 8, 16, ...). action_mask: jax array (bool) of shape (4,) indicates which actions are valid in the current state of the environment. action: jax array (int32) of shape (). Is in [0, 1, 2, 3] representing the actions up, right, down, and left, respectively. reward: jax array (float) of shape (). The reward is 0 except when the player combines tiles to create a new tile with twice the value. In this case, the reward is the value of the new tile. episode termination: if no more valid moves exist (this can happen when the board is full). state: State board: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. action_mask: same as observation. score: jax array (int32) of shape (), the sum of all tile values on the board. key: jax array (uint32) of shape (2,) random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Game2048 env = Game2048 () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Game2048"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.__init__","text":"Initialize the 2048 game. Parameters: Name Type Description Default board_size int size of the board. Defaults to 4. 4 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.game_2048.types.State]] Viewer used for rendering. Defaults to Game2048Viewer . None","title":"__init__()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.observation_spec","text":"Specifications of the observation of the Game2048 environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: Array (jnp.int32) of shape (board_size, board_size). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.action_spec","text":"Returns the action spec. 4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec DiscreteArray spec object.","title":"action_spec()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random number generator key. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/game_2048/#jumanji.environments.logic.game_2048.env.Game2048.step","text":"Updates the environment state after the agent takes an action. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"step()"},{"location":"api/environments/graph_coloring/","text":"GraphColoring ( Environment ) # Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. observation: Observation adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. current_node_index: integer representing the current node being colored. action: int, the color to be assigned to the current node (0 to num_nodes - 1) reward: float, a sparse reward is provided at the end of the episode. Equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors. episode termination: if all nodes have been assigned a color or if an invalid action is taken. state: State adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned. current_node_index: jax array (int) with shape (), index of the current node. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import GraphColoring env = GraphColoring () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . graph_coloring . generator . Generator ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . graph_coloring . types . State ]] = None ) special # Instantiate a GraphColoring environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.graph_coloring.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator which generates graphs with 20 num_nodes and edge_probability equal to 0.8. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] environment viewer for rendering. Defaults to GraphColoringViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . graph_coloring . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . graph_coloring . types . Observation ]] # Resets the environment to an initial state. Returns: Type Description Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] The initial state and timestep. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . graph_coloring . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . graph_coloring . types . Observation ]] # Updates the environment state after the agent takes an action. Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . graph_coloring . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state. colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node. current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node. action_spec ( self ) -> DiscreteArray # Specification of the action for the GraphColoring environment. Returns: Type Description action_spec specs.DiscreteArray object","title":"GraphColoring"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring","text":"Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. observation: Observation adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. current_node_index: integer representing the current node being colored. action: int, the color to be assigned to the current node (0 to num_nodes - 1) reward: float, a sparse reward is provided at the end of the episode. Equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors. episode termination: if all nodes have been assigned a color or if an invalid action is taken. state: State adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph. colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned. current_node_index: jax array (int) with shape (), index of the current node. action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment. key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import GraphColoring env = GraphColoring () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"GraphColoring"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.__init__","text":"Instantiate a GraphColoring environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.graph_coloring.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator which generates graphs with 20 num_nodes and edge_probability equal to 0.8. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] environment viewer for rendering. Defaults to GraphColoringViewer . None","title":"__init__()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.reset","text":"Resets the environment to an initial state. Returns: Type Description Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] The initial state and timestep.","title":"reset()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.step","text":"Updates the environment state after the agent takes an action. Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task. Parameters: Name Type Description Default state State the current state of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action taken by the agent. required Returns: Type Description state the new state of the environment. timestep: the next timestep.","title":"step()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state. colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node. current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node.","title":"observation_spec()"},{"location":"api/environments/graph_coloring/#jumanji.environments.logic.graph_coloring.env.GraphColoring.action_spec","text":"Specification of the action for the GraphColoring environment. Returns: Type Description action_spec specs.DiscreteArray object","title":"action_spec()"},{"location":"api/environments/job_shop/","text":"JobShop ( Environment ) # The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given num_jobs jobs, each consisting of at most max_num_ops ops, which need to be processed on num_machines machines. Each operation (op) has a specific machine that it needs to be processed on and a duration (which must be less than or equal to max_duration_op ). The goal is to minimise the total length of the schedule, also known as the makespan. [1] https://developers.google.com/optimization/scheduling/job_shop. observation: Observation ops_machine_ids: jax array (int32) of (num_jobs, max_num_ops) id of the machine each operation must be processed on. ops_durations: jax array (int32) of (num_jobs, max_num_ops) processing time of each operation. ops_mask: jax array (bool) of (num_jobs, max_num_ops) indicating which operations have yet to be scheduled. machines_job_ids: jax array (int32) of shape (num_machines,) id of the job (or no-op) that each machine is processing. machines_remaining_times: jax array (int32) of shape (num_machines,) specifying, for each machine, the number of time steps until available. action_mask: jax array (bool) of shape (num_machines, num_jobs + 1) indicates which job(s) (or no-op) can legally be scheduled on each machine. action: jax array (int32) of shape (num_machines,). reward: jax array (float) of shape (). A reward of -1 is given each time step. If all machines are simultaneously idle or the agent selects an invalid action, the agent is given a large penalty of -num_jobs * max_num_ops * max_op_duration which is an upper bound on the makespan. episode termination: Finished schedule: all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. state: State ops_machine_ids: same as observation. ops_durations: same as observation. ops_mask: same as observation. machines_job_ids: same as observation. machines_remaining_times: same as observation. action_mask: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. scheduled_times: jax array (int32) of shape (num_jobs, max_num_ops), specifying the timestep at which every op (scheduled so far) was scheduled. 1 2 3 4 5 6 7 8 from jumanji.environments import JobShop env = JobShop () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . job_shop . generator . Generator ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . job_shop . types . State ]] = None ) special # Instantiate a JobShop environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.job_shop.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are ['ToyGenerator', 'RandomGenerator']. Defaults to RandomGenerator with 20 jobs, 10 machines, up to 8 ops for any given job, and a max operation duration of 6. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.job_shop.types.State]] Viewer used for rendering. Defaults to JobShopViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . job_shop . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . job_shop . types . Observation ]] # Resets the environment by creating a new problem instance and initialising the state and timestep. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state the environment state after the reset. timestep: the first timestep returned by the environment after the reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . job_shop . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . job_shop . types . Observation ]] # Updates the status of all machines, the status of the operations, and increments the time step. It updates the environment state and the timestep (which contains the new observation). It calculates the reward based on the three terminal conditions: - The action provided by the agent is invalid. - The schedule has finished. - All machines do a no-op that leads to all machines being simultaneously idle. Parameters: Name Type Description Default state State the environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action to take. required Returns: Type Description state the updated environment state. timestep: the updated timestep. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . job_shop . types . Observation ] # Specifications of the observation of the JobShop environment. Returns: Type Description Spec containing the specifications for all the `Observation` fields ops_machine_ids: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_durations: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_mask: BoundedArray (bool) of shape (num_jobs, max_num_ops). machines_job_ids: BoundedArray (int32) of shape (num_machines,). machines_remaining_times: BoundedArray (int32) of shape (num_machines,). action_mask: BoundedArray (bool) of shape (num_machines, num_jobs + 1). action_spec ( self ) -> MultiDiscreteArray # Specifications of the action in the JobShop environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"JobShop"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop","text":"The Job Shop Scheduling Problem, as described in [1], is one of the best known combinatorial optimization problems. We are given num_jobs jobs, each consisting of at most max_num_ops ops, which need to be processed on num_machines machines. Each operation (op) has a specific machine that it needs to be processed on and a duration (which must be less than or equal to max_duration_op ). The goal is to minimise the total length of the schedule, also known as the makespan. [1] https://developers.google.com/optimization/scheduling/job_shop. observation: Observation ops_machine_ids: jax array (int32) of (num_jobs, max_num_ops) id of the machine each operation must be processed on. ops_durations: jax array (int32) of (num_jobs, max_num_ops) processing time of each operation. ops_mask: jax array (bool) of (num_jobs, max_num_ops) indicating which operations have yet to be scheduled. machines_job_ids: jax array (int32) of shape (num_machines,) id of the job (or no-op) that each machine is processing. machines_remaining_times: jax array (int32) of shape (num_machines,) specifying, for each machine, the number of time steps until available. action_mask: jax array (bool) of shape (num_machines, num_jobs + 1) indicates which job(s) (or no-op) can legally be scheduled on each machine. action: jax array (int32) of shape (num_machines,). reward: jax array (float) of shape (). A reward of -1 is given each time step. If all machines are simultaneously idle or the agent selects an invalid action, the agent is given a large penalty of -num_jobs * max_num_ops * max_op_duration which is an upper bound on the makespan. episode termination: Finished schedule: all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. state: State ops_machine_ids: same as observation. ops_durations: same as observation. ops_mask: same as observation. machines_job_ids: same as observation. machines_remaining_times: same as observation. action_mask: same as observation. step_count: jax array (int32) of shape (), the number of time steps in the episode so far. scheduled_times: jax array (int32) of shape (num_jobs, max_num_ops), specifying the timestep at which every op (scheduled so far) was scheduled. 1 2 3 4 5 6 7 8 from jumanji.environments import JobShop env = JobShop () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"JobShop"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.__init__","text":"Instantiate a JobShop environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.job_shop.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are ['ToyGenerator', 'RandomGenerator']. Defaults to RandomGenerator with 20 jobs, 10 machines, up to 8 ops for any given job, and a max operation duration of 6. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.job_shop.types.State]] Viewer used for rendering. Defaults to JobShopViewer . None","title":"__init__()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.reset","text":"Resets the environment by creating a new problem instance and initialising the state and timestep. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state the environment state after the reset. timestep: the first timestep returned by the environment after the reset.","title":"reset()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.step","text":"Updates the status of all machines, the status of the operations, and increments the time step. It updates the environment state and the timestep (which contains the new observation). It calculates the reward based on the three terminal conditions: - The action provided by the agent is invalid. - The schedule has finished. - All machines do a no-op that leads to all machines being simultaneously idle. Parameters: Name Type Description Default state State the environment state. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] the action to take. required Returns: Type Description state the updated environment state. timestep: the updated timestep.","title":"step()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.observation_spec","text":"Specifications of the observation of the JobShop environment. Returns: Type Description Spec containing the specifications for all the `Observation` fields ops_machine_ids: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_durations: BoundedArray (int32) of shape (num_jobs, max_num_ops). ops_mask: BoundedArray (bool) of shape (num_jobs, max_num_ops). machines_job_ids: BoundedArray (int32) of shape (num_machines,). machines_remaining_times: BoundedArray (int32) of shape (num_machines,). action_mask: BoundedArray (bool) of shape (num_machines, num_jobs + 1).","title":"observation_spec()"},{"location":"api/environments/job_shop/#jumanji.environments.packing.job_shop.env.JobShop.action_spec","text":"Specifications of the action in the JobShop environment. The action gives each machine a job id ranging from 0, 1, ..., num_jobs where the last value corresponds to a no-op. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/knapsack/","text":"Knapsack ( Environment ) # Knapsack environment as described in [1]. observation: Observation weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. action_mask: jax array (bool) of shape (num_items,) binary mask denoting which items can be packed into the knapsack. action: jax array (int32) of shape () [0, ..., num_items - 1] -> item to pack. reward: jax array (float) of shape (), could be either: dense: the value of the item to pack at the current timestep. sparse: the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. episode termination: if no action can be performed, i.e. all items are packed or each remaining item's weight is larger than the bag capacity. if an invalid action is taken, i.e. the chosen item is already packed or has a weight larger than the bag capacity. state: State weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. remaining_budget: jax array (float) the budget currently remaining. [1] https://arxiv.org/abs/2010.16011 1 2 3 4 5 6 7 8 from jumanji.environments import Knapsack env = Knapsack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . packing . knapsack . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . packing . knapsack . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . knapsack . types . State ]] = None ) special # Instantiates a Knapsack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.knapsack.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'RandomGenerator' which samples Knapsack instances with 50 items and a total budget of 12.5. None reward_fn Optional[jumanji.environments.packing.knapsack.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]] Viewer used for rendering. Defaults to KnapsackViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . knapsack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . knapsack . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the weights and values of the items. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . packing . knapsack . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . knapsack . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] index of next item to take. required Returns: Type Description state next state of the environment. timestep: the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . knapsack . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for each field in the Observation weights: BoundedArray (float) of shape (num_items,). values: BoundedArray (float) of shape (num_items,). packed_items: BoundedArray (bool) of shape (num_items,). action_mask: BoundedArray (bool) of shape (num_items,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"Knapsack"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack","text":"Knapsack environment as described in [1]. observation: Observation weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. action_mask: jax array (bool) of shape (num_items,) binary mask denoting which items can be packed into the knapsack. action: jax array (int32) of shape () [0, ..., num_items - 1] -> item to pack. reward: jax array (float) of shape (), could be either: dense: the value of the item to pack at the current timestep. sparse: the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. episode termination: if no action can be performed, i.e. all items are packed or each remaining item's weight is larger than the bag capacity. if an invalid action is taken, i.e. the chosen item is already packed or has a weight larger than the bag capacity. state: State weights: jax array (float) of shape (num_items,) the weights of the items. values: jax array (float) of shape (num_items,) the values of the items. packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack. remaining_budget: jax array (float) the budget currently remaining. [1] https://arxiv.org/abs/2010.16011 1 2 3 4 5 6 7 8 from jumanji.environments import Knapsack env = Knapsack () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Knapsack"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.__init__","text":"Instantiates a Knapsack environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.packing.knapsack.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'RandomGenerator' which samples Knapsack instances with 50 items and a total budget of 12.5. None reward_fn Optional[jumanji.environments.packing.knapsack.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]] Viewer used for rendering. Defaults to KnapsackViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the weights and values of the items. required Returns: Type Description state the new state of the environment. timestep: the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] index of next item to take. required Returns: Type Description state next state of the environment. timestep: the timestep to be observed.","title":"step()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for each field in the Observation weights: BoundedArray (float) of shape (num_items,). values: BoundedArray (float) of shape (num_items,). packed_items: BoundedArray (bool) of shape (num_items,). action_mask: BoundedArray (bool) of shape (num_items,).","title":"observation_spec()"},{"location":"api/environments/knapsack/#jumanji.environments.packing.knapsack.env.Knapsack.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/macvrp/","text":"MultiCVRP ( Environment ) # Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). reward: jax array (float32) this global reward is provided to each agent. The reward is equal to the negative sum of the distances between consecutive nodes at the end of the episode over all agents. All time penalties are also added to the reward. observation and state: the observation and state variable types are defined in: jumanji/environments/routing/multi_cvrp/types.py [1] Zhang et al. (2020). \"Multi-Vehicle Routing Problems with Soft Time Windows: A Multi-Agent Reinforcement Learning Approach\". __init__ ( self , generator : Optional [ jumanji . environments . routing . multi_cvrp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . multi_cvrp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer ] = None ) special # Instantiates a MultiCVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.multi_cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ UniformRandomGenerator ]. Defaults to UniformRandomGenerator with num_customers=20 and num_vehicles=2 . None reward_fn Optional[jumanji.environments.routing.multi_cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state and whether the environment is done. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer] Viewer used for rendering. Defaults to MultiCVRPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . multi_cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . multi_cvrp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the start node. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . multi_cvrp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . multi_cvrp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next nodes to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . multi_cvrp . types . Observation ] # Returns the observation spec. Returns: Type Description observation_spec a Tuple containing the spec for each of the constituent fields of an observation. action_spec ( self ) -> BoundedArray # Returns the action spec. Returns: Type Description action_spec a specs.BoundedArray spec.","title":"Macvrp"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP","text":"Multi-Vehicle Routing Problems with Soft Time Windows (MVRPSTW) environment as described in [1]. We simplfy the naming to multi-agent capacitated vehicle routing problem (MultiCVRP). reward: jax array (float32) this global reward is provided to each agent. The reward is equal to the negative sum of the distances between consecutive nodes at the end of the episode over all agents. All time penalties are also added to the reward. observation and state: the observation and state variable types are defined in: jumanji/environments/routing/multi_cvrp/types.py [1] Zhang et al. (2020). \"Multi-Vehicle Routing Problems with Soft Time Windows: A Multi-Agent Reinforcement Learning Approach\".","title":"MultiCVRP"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.__init__","text":"Instantiates a MultiCVRP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.multi_cvrp.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ UniformRandomGenerator ]. Defaults to UniformRandomGenerator with num_customers=20 and num_vehicles=2 . None reward_fn Optional[jumanji.environments.routing.multi_cvrp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state and whether the environment is done. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer] Viewer used for rendering. Defaults to MultiCVRPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the start node. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next nodes to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed.","title":"step()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.observation_spec","text":"Returns the observation spec. Returns: Type Description observation_spec a Tuple containing the spec for each of the constituent fields of an observation.","title":"observation_spec()"},{"location":"api/environments/macvrp/#jumanji.environments.routing.multi_cvrp.env.MultiCVRP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.BoundedArray spec.","title":"action_spec()"},{"location":"api/environments/maze/","text":"Maze ( Environment ) # A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. observation: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise. episode termination (if any): agent reaches the target position. the time_limit is reached. state: State: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. key: random key (uint) of shape (2,). 1 2 3 4 5 6 7 8 from jumanji.environments import Maze env = Maze () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . maze . generator . Generator ] = None , time_limit : Optional [ int ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . maze . types . State ]] = None ) -> None special # Instantiates a Maze environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.maze.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ ToyGenerator , RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 and num_cols=10 . None time_limit Optional[int] the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] Viewer used for rendering. Defaults to MazeEnvViewer with \"human\" render mode. None observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . maze . types . Observation ] # Specifications of the observation of the Maze environment. Returns: Type Description Spec for the `Observation` whose fields are agent_position: tree of BoundedArray (int32) of shape (). target_position: tree of BoundedArray (int32) of shape (). walls: BoundedArray (bool) of shape (num_rows, num_cols). step_count: Array (int32) of shape (). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec discrete action space with 4 values. reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . maze . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . maze . types . Observation ]] # Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . maze . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . maze . types . Observation ]] # Run one timestep of the environment's dynamics. If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] (int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. required Returns: Type Description state the next state of the environment. timestep: the next timestep to be observed.","title":"Maze"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze","text":"A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position. observation: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise. episode termination (if any): agent reaches the target position. the time_limit is reached. state: State: agent_position: current 2D Position of agent. target_position: 2D Position of target cell. walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells. action_mask: array (bool) of shape (4,) defining the available actions in the current position. step_count: jax array (int32) of shape () step number of the episode. key: random key (uint) of shape (2,). 1 2 3 4 5 6 7 8 from jumanji.environments import Maze env = Maze () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Maze"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.__init__","text":"Instantiates a Maze environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.maze.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ ToyGenerator , RandomGenerator ]. Defaults to RandomGenerator with num_rows=10 and num_cols=10 . None time_limit Optional[int] the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] Viewer used for rendering. Defaults to MazeEnvViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.observation_spec","text":"Specifications of the observation of the Maze environment. Returns: Type Description Spec for the `Observation` whose fields are agent_position: tree of BoundedArray (int32) of shape (). target_position: tree of BoundedArray (int32) of shape (). walls: BoundedArray (bool) of shape (num_rows, num_cols). step_count: Array (int32) of shape (). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.action_spec","text":"Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec discrete action space with 4 values.","title":"action_spec()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.reset","text":"Resets the environment by calling the instance generator for a new instance. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding the first timestep returned by the environment after a reset.","title":"reset()"},{"location":"api/environments/maze/#jumanji.environments.routing.maze.env.Maze.step","text":"Run one timestep of the environment's dynamics. If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] (int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken. required Returns: Type Description state the next state of the environment. timestep: the next timestep to be observed.","title":"step()"},{"location":"api/environments/minesweeper/","text":"Minesweeper ( Environment ) # A JAX implementation of the minesweeper game. observation: Observation board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask: jax array (bool) of shape (num_rows, num_cols): indicates which actions are valid (not yet explored squares). num_mines: jax array (int32) of shape () , indicates the number of mines to locate. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the square to explore (row and col). reward: jax array (float32): Configurable function of state and action. By default: 1 for every timestep where a valid action is chosen that doesn't reveal a mine, 0 for revealing a mine or selecting an already revealed square (and terminate the episode). episode termination: Configurable function of state, next_state, and action. By default: Stop the episode if a mine is explored, an invalid action is selected (exploring an already explored square), or the board is solved. state: State board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. flat_mine_locations: jax array (int32) of shape (num_rows * num_cols,): indicates the (flat) locations of all the mines on the board. Will be of length num_mines. key: jax array (int32) of shape (2,) used for seeding the sampling of mine placement on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Minesweeper env = Minesweeper () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . minesweeper . generator . Generator ] = None , reward_function : Optional [ jumanji . environments . logic . minesweeper . reward . RewardFn ] = None , done_function : Optional [ jumanji . environments . logic . minesweeper . done . DoneFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . minesweeper . types . State ]] = None ) special # Instantiate a Minesweeper environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.minesweeper.generator.Generator] Generator to generate problem instances on environment reset. Implemented options are [ SamplingGenerator ]. Defaults to SamplingGenerator . The generator will have attributes: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines generated. Defaults to 10. None reward_function Optional[jumanji.environments.logic.minesweeper.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [ DefaultRewardFn ]. Defaults to DefaultRewardFn , giving a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and 0.0 for an invalid action (selecting an already revealed square). None done_function Optional[jumanji.environments.logic.minesweeper.done.DoneFn] DoneFn whose __call__ method computes the done signal given the current state, action taken, and next state. Implemented options are [ DefaultDoneFn ]. Defaults to DefaultDoneFn , ending the episode on solving the board, revealing a mine, or picking an invalid action. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.minesweeper.types.State]] Viewer to support rendering and animation methods. Implemented options are [ MinesweeperViewer ]. Defaults to MinesweeperViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . minesweeper . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . minesweeper . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for placing mines. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . minesweeper . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . minesweeper . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the row and column of the square to be explored. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . minesweeper . types . Observation ] # Specifications of the observation of the Minesweeper environment. Returns: Type Description Spec for the `Observation` whose fields are board: BoundedArray (int32) of shape (num_rows, num_cols). action_mask: BoundedArray (bool) of shape (num_rows, num_cols). num_mines: BoundedArray (int32) of shape (). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action consists of the height and width of the square to be explored. Returns: Type Description action_spec specs.MultiDiscreteArray object.","title":"Minesweeper"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper","text":"A JAX implementation of the minesweeper game. observation: Observation board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask: jax array (bool) of shape (num_rows, num_cols): indicates which actions are valid (not yet explored squares). num_mines: jax array (int32) of shape () , indicates the number of mines to locate. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the square to explore (row and col). reward: jax array (float32): Configurable function of state and action. By default: 1 for every timestep where a valid action is chosen that doesn't reveal a mine, 0 for revealing a mine or selecting an already revealed square (and terminate the episode). episode termination: Configurable function of state, next_state, and action. By default: Stop the episode if a mine is explored, an invalid action is selected (exploring an already explored square), or the board is solved. state: State board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. flat_mine_locations: jax array (int32) of shape (num_rows * num_cols,): indicates the (flat) locations of all the mines on the board. Will be of length num_mines. key: jax array (int32) of shape (2,) used for seeding the sampling of mine placement on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Minesweeper env = Minesweeper () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Minesweeper"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.__init__","text":"Instantiate a Minesweeper environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.minesweeper.generator.Generator] Generator to generate problem instances on environment reset. Implemented options are [ SamplingGenerator ]. Defaults to SamplingGenerator . The generator will have attributes: - num_rows: number of rows, i.e. height of the board. Defaults to 10. - num_cols: number of columns, i.e. width of the board. Defaults to 10. - num_mines: number of mines generated. Defaults to 10. None reward_function Optional[jumanji.environments.logic.minesweeper.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [ DefaultRewardFn ]. Defaults to DefaultRewardFn , giving a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and 0.0 for an invalid action (selecting an already revealed square). None done_function Optional[jumanji.environments.logic.minesweeper.done.DoneFn] DoneFn whose __call__ method computes the done signal given the current state, action taken, and next state. Implemented options are [ DefaultDoneFn ]. Defaults to DefaultDoneFn , ending the episode on solving the board, revealing a mine, or picking an invalid action. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.minesweeper.types.State]] Viewer to support rendering and animation methods. Implemented options are [ MinesweeperViewer ]. Defaults to MinesweeperViewer . None","title":"__init__()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for placing mines. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the row and column of the square to be explored. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.observation_spec","text":"Specifications of the observation of the Minesweeper environment. Returns: Type Description Spec for the `Observation` whose fields are board: BoundedArray (int32) of shape (num_rows, num_cols). action_mask: BoundedArray (bool) of shape (num_rows, num_cols). num_mines: BoundedArray (int32) of shape (). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/minesweeper/#jumanji.environments.logic.minesweeper.env.Minesweeper.action_spec","text":"Returns the action spec. An action consists of the height and width of the square to be explored. Returns: Type Description action_spec specs.MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/mmst/","text":"MMST ( Environment ) # The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes). Note: routing problems are randomly generated and may not be solvable! Requirements: The total number of nodes should be at least 20% more than the number of nodes we want to connect to guarantee we have enough remaining nodes to create a path with all the nodes we want to connect. An exception will be raised if the number of nodes is not greater than (0.8 x num_agents x num_nodes_per_agent). observation: Observation node_types: jax array (int) of shape (num_nodes): the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. positions: jax array (int) of shape (num_agents,): the index of the last visited node. step_count: jax array (int) of shape (): integer to keep track of the number of steps. action_mask: jax array (bool) of shape (num_agent, num_nodes): binary mask (False/True <--> invalid/valid action). reward: float action: jax array (int) of shape (num_agents,): [0,1,..., num_nodes-1] Each agent selects the next node to which it wants to connect. state: State node_type: jax array (int) of shape (num_nodes,). the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. connected_nodes: jax array (int) of shape (num_agents, time_limit). we only count each node visit once. connected_nodes_index: jax array (int) of shape (num_agents, num_nodes). position_index: jax array (int) of shape (num_agents,). node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes). positions: jax array (int) of shape (num_agents,). the index of the last visited node. action_mask: jax array (bool) of shape (num_agent, num_nodes). binary mask (False/True <--> invalid/valid action). finished_agents: jax array (bool) of shape (num_agent,). nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent). step_count: step counter. time_limit: the number of steps allowed before an episode terminates. key: PRNG key for random sample. constants definitions: Nodes INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if its has no edge with the current node or if it is a utility node already selected by another agent. UTILITY_NODE = -1: utility node (belongs to no agent). EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent. Hence it has size equal to the step limit. We use this constant to initialise this array since 0 represents the first node. DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node. Edges EMPTY_EDGE = -1: used for masking edges array. state.node_edges is the graph's adjacency matrix, but we don't represent it using 0s and 1s, we use the node values instead, i.e A_ij = j or A_ij = -1 . Also edges are masked when utility nodes are selected by an agent to make it unaccessible by other agents. Actions encoding INVALID_CHOICE = -1 INVALID_TIE_BREAK = -2 INVALID_ALREADY_TRAVERSED = -3 __init__ ( self , generator : Optional [ jumanji . environments . routing . mmst . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . mmst . reward . RewardFn ] = None , time_limit : int = 70 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . mmst . types . State ]] = None ) special # Create the MMST environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.mmst.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ SplitRandomGenerator ]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit) . None reward_fn Optional[jumanji.environments.routing.mmst.reward.RewardFn] class of type RewardFn , whose __call__ is used as a reward function. Implemented options are [ DenseRewardFn ]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)) . None time_limit int the number of steps allowed before an episode terminates. Defaults to 70. 70 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]] Viewer used for rendering. Defaults to MMSTViewer None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . mmst . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . mmst . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the different start nodes. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . mmst . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . mmst . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next node to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed. action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.MultiDiscreteArray spec. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . mmst . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are node_types: BoundedArray (int32) of shape (num_nodes,). adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. positions: BoundedArray (int32) of shape (num_agents). Current node position of agent. action_mask: BoundedArray (bool) of shape (num_agents, num_nodes,). Represents the valid actions in the current state. render ( self , state : State ) -> Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ] # Render the environment for a given state. Returns: Type Description Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of rgb pixel values in the shape (width, height, rgb).","title":"MMST"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST","text":"The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes). Note: routing problems are randomly generated and may not be solvable! Requirements: The total number of nodes should be at least 20% more than the number of nodes we want to connect to guarantee we have enough remaining nodes to create a path with all the nodes we want to connect. An exception will be raised if the number of nodes is not greater than (0.8 x num_agents x num_nodes_per_agent). observation: Observation node_types: jax array (int) of shape (num_nodes): the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. positions: jax array (int) of shape (num_agents,): the index of the last visited node. step_count: jax array (int) of shape (): integer to keep track of the number of steps. action_mask: jax array (bool) of shape (num_agent, num_nodes): binary mask (False/True <--> invalid/valid action). reward: float action: jax array (int) of shape (num_agents,): [0,1,..., num_nodes-1] Each agent selects the next node to which it wants to connect. state: State node_type: jax array (int) of shape (num_nodes,). the component type of each node (-1 represents utility nodes). adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph. connected_nodes: jax array (int) of shape (num_agents, time_limit). we only count each node visit once. connected_nodes_index: jax array (int) of shape (num_agents, num_nodes). position_index: jax array (int) of shape (num_agents,). node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes). positions: jax array (int) of shape (num_agents,). the index of the last visited node. action_mask: jax array (bool) of shape (num_agent, num_nodes). binary mask (False/True <--> invalid/valid action). finished_agents: jax array (bool) of shape (num_agent,). nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent). step_count: step counter. time_limit: the number of steps allowed before an episode terminates. key: PRNG key for random sample. constants definitions: Nodes INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if its has no edge with the current node or if it is a utility node already selected by another agent. UTILITY_NODE = -1: utility node (belongs to no agent). EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent. Hence it has size equal to the step limit. We use this constant to initialise this array since 0 represents the first node. DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node. Edges EMPTY_EDGE = -1: used for masking edges array. state.node_edges is the graph's adjacency matrix, but we don't represent it using 0s and 1s, we use the node values instead, i.e A_ij = j or A_ij = -1 . Also edges are masked when utility nodes are selected by an agent to make it unaccessible by other agents. Actions encoding INVALID_CHOICE = -1 INVALID_TIE_BREAK = -2 INVALID_ALREADY_TRAVERSED = -3","title":"MMST"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.__init__","text":"Create the MMST environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.mmst.generator.Generator] Generator whose __call__ instantiates an environment instance. Implemented options are [ SplitRandomGenerator ]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit) . None reward_fn Optional[jumanji.environments.routing.mmst.reward.RewardFn] class of type RewardFn , whose __call__ is used as a reward function. Implemented options are [ DenseRewardFn ]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)) . None time_limit int the number of steps allowed before an episode terminates. Defaults to 70. 70 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]] Viewer used for rendering. Defaults to MMSTViewer None","title":"__init__()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the problem and the different start nodes. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the index of the next node to visit. required Returns: Type Description state, timestep Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed.","title":"step()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.MultiDiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are node_types: BoundedArray (int32) of shape (num_nodes,). adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph. positions: BoundedArray (int32) of shape (num_agents). Current node position of agent. action_mask: BoundedArray (bool) of shape (num_agents, num_nodes,). Represents the valid actions in the current state.","title":"observation_spec()"},{"location":"api/environments/mmst/#jumanji.environments.routing.mmst.env.MMST.render","text":"Render the environment for a given state. Returns: Type Description Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of rgb pixel values in the shape (width, height, rgb).","title":"render()"},{"location":"api/environments/rubiks_cube/","text":"RubiksCube ( Environment ) # A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. observation: Observation cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the move to perform (face, depth, and direction). reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0. episode termination: if either the cube is solved or a time limit is reached. state: State cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import RubiksCube env = RubiksCube () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . rubiks_cube . generator . Generator ] = None , time_limit : int = 200 , reward_fn : Optional [ jumanji . environments . logic . rubiks_cube . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . rubiks_cube . types . State ]] = None ) special # Instantiate a RubiksCube environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] Generator used to generate problem instances on environment reset. Implemented options are [ ScramblingGenerator ]. Defaults to ScramblingGenerator , with 100 scrambles on reset. The generator will contain an attribute cube_size , corresponding to the number of cubies to an edge, and defaulting to 3. None time_limit int the number of steps allowed before an episode terminates. Defaults to 200. 200 reward_fn Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] RewardFn whose __call__ method computes the reward given the new state. Implemented options are [ SparseRewardFn ]. Defaults to SparseRewardFn , giving a reward of 1.0 if the cube is solved or otherwise 0.0. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] Viewer to support rendering and animation methods. Implemented options are [ RubiksCubeViewer ]. Defaults to RubiksCubeViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . rubiks_cube . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . rubiks_cube . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for scramble. required Returns: Type Description state State corresponding to the new state of the environment. timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . rubiks_cube . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . rubiks_cube . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of shape (3,) indicating the face to move, depth of the move, and the amount to move by. required Returns: Type Description next_state State corresponding to the next state of the environment. next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . rubiks_cube . types . Observation ] # Specifications of the observation of the RubiksCube environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size). step_count: BoundedArray (jnp.int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: Type Description action_spec MultiDiscreteArray object.","title":"RubiksCube"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube","text":"A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset. observation: Observation cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. action: multi discrete array containing the move to perform (face, depth, and direction). reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0. episode termination: if either the cube is solved or a time limit is reached. state: State cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble. step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset. key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset. 1 2 3 4 5 6 7 8 from jumanji.environments import RubiksCube env = RubiksCube () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"RubiksCube"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.__init__","text":"Instantiate a RubiksCube environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] Generator used to generate problem instances on environment reset. Implemented options are [ ScramblingGenerator ]. Defaults to ScramblingGenerator , with 100 scrambles on reset. The generator will contain an attribute cube_size , corresponding to the number of cubies to an edge, and defaulting to 3. None time_limit int the number of steps allowed before an episode terminates. Defaults to 200. 200 reward_fn Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] RewardFn whose __call__ method computes the reward given the new state. Implemented options are [ SparseRewardFn ]. Defaults to SparseRewardFn , giving a reward of 1.0 if the cube is solved or otherwise 0.0. None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] Viewer to support rendering and animation methods. Implemented options are [ RubiksCubeViewer ]. Defaults to RubiksCubeViewer . None","title":"__init__()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for scramble. required Returns: Type Description state State corresponding to the new state of the environment. timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array of shape (3,) indicating the face to move, depth of the move, and the amount to move by. required Returns: Type Description next_state State corresponding to the next state of the environment. next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.observation_spec","text":"Specifications of the observation of the RubiksCube environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size). step_count: BoundedArray (jnp.int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/rubiks_cube/#jumanji.environments.logic.rubiks_cube.env.RubiksCube.action_spec","text":"Returns the action spec. An action is composed of 3 elements that range in: 6 faces, each with cube_size//2 possible depths, and 3 possible directions. Returns: Type Description action_spec MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/rware/","text":"RobotWarehouse ( Environment ) # A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. Creates a grid world where multiple agents (robots) are supposed to collect shelves, bring them to a goal and then return them. Below is an example warehouse floor grid: the grid layout is instantiated using three arguments shelf_rows: number of vertical shelf clusters shelf_columns: odd number of horizontal shelf clusters column_height: height of each cluster A cluster is a set of grouped shelves (two cells wide) represented below as 1 XX Shelf cluster -> XX (this cluster is of height 3) XX Grid Layout: 1 2 3 4 shelf columns (here set to 3, i.e. v v v shelf_columns=3, must be an odd number) ---------- > -XX-XX-XX- ^ Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. > -XX-XX-XX- v column_height=3) ---------- -XX----XX- < -XX----XX- <- Shelf Row 2 (here set to 2, i.e. -XX----XX- < shelf_rows=2) ---------- ----GG---- G: is the goal positions where agents are rewarded if they successfully deliver a requested shelf (i.e toggle the load action inside the goal position while carrying a requested shelf). The final grid size will be - height: (column_height + 1) * shelf_rows + 2 - width: (2 + 1) * shelf_columns + 1 The bottom-middle column is removed to allow for agents to queue in front of the goal positions action: jax array (int) of shape (num_agents,) containing the action for each agent. (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) reward: jax array (int) of shape (), global reward shared by all agents, +1 for every successful delivery of a requested shelf to the goal position. episode termination: The number of steps is greater than the limit. Any agent selects an action which causes two agents to collide. state: State grid: an array representing the warehouse floor as a 2D grid with two separate channels one for the agents, and one for the shelves agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] request_queue: the queue of requested shelves (by ID). step_count: an integer representing the current step of the episode. action_mask: an array of shape (num_agents, 5) containing the valid actions for each agent. key: a pseudorandom number generator key. [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms in Cooperative Tasks (2021) 1 2 3 4 5 6 7 8 from jumanji.environments import RobotWarehouse env = RobotWarehouse () key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . robot_warehouse . generator . Generator ] = None , time_limit : int = 500 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . robot_warehouse . types . State ]] = None ) special # Instantiates an RobotWarehouse environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.robot_warehouse.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator with parameters: shelf_rows = 2 , shelf_columns = 3 , column_height = 8 , num_agents = 4 , sensor_range = 1 , request_queue_size = 8 . None time_limit int the maximum step limit allowed within the environment. Defaults to 500. 500 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.robot_warehouse.types.State]] viewer to render the environment. Defaults to RobotWarehouseViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . robot_warehouse . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . robot_warehouse . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . routing . robot_warehouse . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . robot_warehouse . types . Observation ]] # Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. - 0 no op - 1 move forward - 2 turn left - 3 turn right - 4 toggle load required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . robot_warehouse . types . Observation ] # Specification of the observation of the RobotWarehouse environment. Returns: Type Description Spec for the `Observation`, consisting of the fields agents_view: Array (int32) of shape (num_agents, num_obs_features). action_mask: BoundedArray (bool) of shape (num_agent, 5). step_count: BoundedArray (int32) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,).","title":"Rware"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse","text":"A JAX implementation of the 'Robotic warehouse' environment: https://github.com/semitable/robotic-warehouse which is described in the paper [1]. Creates a grid world where multiple agents (robots) are supposed to collect shelves, bring them to a goal and then return them. Below is an example warehouse floor grid: the grid layout is instantiated using three arguments shelf_rows: number of vertical shelf clusters shelf_columns: odd number of horizontal shelf clusters column_height: height of each cluster A cluster is a set of grouped shelves (two cells wide) represented below as 1 XX Shelf cluster -> XX (this cluster is of height 3) XX Grid Layout: 1 2 3 4 shelf columns (here set to 3, i.e. v v v shelf_columns=3, must be an odd number) ---------- > -XX-XX-XX- ^ Shelf Row 1 -> -XX-XX-XX- Column Height (here set to 3, i.e. > -XX-XX-XX- v column_height=3) ---------- -XX----XX- < -XX----XX- <- Shelf Row 2 (here set to 2, i.e. -XX----XX- < shelf_rows=2) ---------- ----GG---- G: is the goal positions where agents are rewarded if they successfully deliver a requested shelf (i.e toggle the load action inside the goal position while carrying a requested shelf). The final grid size will be - height: (column_height + 1) * shelf_rows + 2 - width: (2 + 1) * shelf_columns + 1 The bottom-middle column is removed to allow for agents to queue in front of the goal positions action: jax array (int) of shape (num_agents,) containing the action for each agent. (0: noop, 1: forward, 2: left, 3: right, 4: toggle_load) reward: jax array (int) of shape (), global reward shared by all agents, +1 for every successful delivery of a requested shelf to the goal position. episode termination: The number of steps is greater than the limit. Any agent selects an action which causes two agents to collide. state: State grid: an array representing the warehouse floor as a 2D grid with two separate channels one for the agents, and one for the shelves agents: a pytree of Agent type with per agent leaves: [position, direction, is_carrying] shelves: a pytree of Shelf type with per shelf leaves: [position, is_requested] request_queue: the queue of requested shelves (by ID). step_count: an integer representing the current step of the episode. action_mask: an array of shape (num_agents, 5) containing the valid actions for each agent. key: a pseudorandom number generator key. [1] Papoudakis et al., Benchmarking Multi-Agent Deep Reinforcement Learning Algorithms in Cooperative Tasks (2021) 1 2 3 4 5 6 7 8 from jumanji.environments import RobotWarehouse env = RobotWarehouse () key = jax . random . PRNGKey ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"RobotWarehouse"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.__init__","text":"Instantiates an RobotWarehouse environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.robot_warehouse.generator.Generator] callable to instantiate environment instances. Defaults to RandomGenerator with parameters: shelf_rows = 2 , shelf_columns = 3 , column_height = 8 , num_agents = 4 , sensor_range = 1 , request_queue_size = 8 . None time_limit int the maximum step limit allowed within the environment. Defaults to 500. 500 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.robot_warehouse.types.State]] viewer to render the environment. Defaults to RobotWarehouseViewer . None","title":"__init__()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment since it is stochastic. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.step","text":"Perform an environment step. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. - 0 no op - 1 move forward - 2 turn left - 3 turn right - 4 toggle load required Returns: Type Description state State object corresponding to the next state of the environment. timestep: TimeStep object corresponding the timestep returned by the environment.","title":"step()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.observation_spec","text":"Specification of the observation of the RobotWarehouse environment. Returns: Type Description Spec for the `Observation`, consisting of the fields agents_view: Array (int32) of shape (num_agents, num_obs_features). action_mask: BoundedArray (bool) of shape (num_agent, 5). step_count: BoundedArray (int32) of shape ().","title":"observation_spec()"},{"location":"api/environments/rware/#jumanji.environments.routing.robot_warehouse.env.RobotWarehouse.action_spec","text":"Returns the action spec. 5 actions: [0,1,2,3,4] -> [No Op, Forward, Left, Right, Toggle_load]. Since this is a multi-agent environment, the environment expects an array of actions. This array is of shape (num_agents,).","title":"action_spec()"},{"location":"api/environments/snake/","text":"Snake ( Environment ) # A JAX implementation of the 'Snake' game. observation: Observation grid: jax array (float) of shape (num_rows, num_cols, 5) feature maps that include information about the fruit, the snake head, its body and tail. body: 2D map with 1. where a body cell is present, else 0. head: 2D map with 1. where the snake's head is located, else 0. tail: 2D map with 1. where the snake's tail is located, else 0. fruit: 2D map with 1. where the fruit is located, else 0. norm_body_state: 2D map with a float between 0. and 1. for each body cell in the decreasing order from head to tail. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. action: jax array (int32) of shape() [0,1,2,3] -> [Up, Right, Down, Left]. reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0. episode termination: if no action can be performed, i.e. the snake is surrounded. if the time limit is reached. if an invalid action is taken, the snake exits the grid or bumps into itself. state: State body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells. body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail. head_position: Position (int32) of shape () position of the snake's head on the 2D grid. tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail. fruit_position: Position (int32) of shape () position of the fruit on the 2D grid. length: jax array (int32) of shape () current length of the snake. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Snake env = Snake () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , num_rows : int = 12 , num_cols : int = 12 , time_limit : int = 4000 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . snake . types . State ]] = None ) special # Instantiates a Snake environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 12. 12 num_cols int number of columns of the 2D grid. Defaults to 12. 12 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000. 4000 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] Viewer used for rendering. Defaults to SnakeViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . snake . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . snake . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to sample the snake and fruit positions. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . snake . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . snake . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the action to take: - 0: move up. - 1: move to the right. - 2: move down. - 3: move to the left. required Returns: Type Description state, timestep next state of the environment and timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . snake . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (float) of shape (num_rows, num_cols, 5). step_count: DiscreteArray (num_values = time_limit) of shape (). action_mask: BoundedArray (bool) of shape (4,). action_spec ( self ) -> DiscreteArray # Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"Snake"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake","text":"A JAX implementation of the 'Snake' game. observation: Observation grid: jax array (float) of shape (num_rows, num_cols, 5) feature maps that include information about the fruit, the snake head, its body and tail. body: 2D map with 1. where a body cell is present, else 0. head: 2D map with 1. where the snake's head is located, else 0. tail: 2D map with 1. where the snake's tail is located, else 0. fruit: 2D map with 1. where the fruit is located, else 0. norm_body_state: 2D map with a float between 0. and 1. for each body cell in the decreasing order from head to tail. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. action: jax array (int32) of shape() [0,1,2,3] -> [Up, Right, Down, Left]. reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0. episode termination: if no action can be performed, i.e. the snake is surrounded. if the time limit is reached. if an invalid action is taken, the snake exits the grid or bumps into itself. state: State body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells. body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail. head_position: Position (int32) of shape () position of the snake's head on the 2D grid. tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail. fruit_position: Position (int32) of shape () position of the fruit on the 2D grid. length: jax array (int32) of shape () current length of the snake. step_count: jax array (int32) of shape () current number of steps in the episode. action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position. key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset. 1 2 3 4 5 6 7 8 from jumanji.environments import Snake env = Snake () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Snake"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.__init__","text":"Instantiates a Snake environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 12. 12 num_cols int number of columns of the 2D grid. Defaults to 12. 12 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000. 4000 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] Viewer used for rendering. Defaults to SnakeViewer . None","title":"__init__()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray random key used to sample the snake and fruit positions. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the action to take: - 0: move up. - 1: move to the right. - 2: move down. - 3: move to the left. required Returns: Type Description state, timestep next state of the environment and timestep to be observed.","title":"step()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are grid: BoundedArray (float) of shape (num_rows, num_cols, 5). step_count: DiscreteArray (num_values = time_limit) of shape (). action_mask: BoundedArray (bool) of shape (4,).","title":"observation_spec()"},{"location":"api/environments/snake/#jumanji.environments.routing.snake.env.Snake.action_spec","text":"Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left]. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"api/environments/sudoku/","text":"Sudoku ( Environment ) # A JAX implementation of the sudoku game. observation: Observation board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid. action: multi discrete array containing the square to write a digit, and the digits to input. reward: jax array (float32): 1 at the end of the episode if the board is valid 0 otherwise state: State board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits). key: jax array (int32) of shape (2,) used for seeding initial sudoku configuration. 1 2 3 4 5 6 7 8 from jumanji.environments import Sudoku env = Sudoku () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . logic . sudoku . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . logic . sudoku . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . logic . sudoku . types . State ]] = None ) special # reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . logic . sudoku . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . sudoku . types . Observation ]] # Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment, step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . logic . sudoku . types . State , jumanji . types . TimeStep [ jumanji . environments . logic . sudoku . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment, observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . logic . sudoku . types . Observation ] # Returns the observation spec containing the board and action_mask arrays. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: BoundedArray (jnp.int8) of shape (9,9). action_mask: BoundedArray (bool) of shape (9,9,9). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: Type Description action_spec MultiDiscreteArray object.","title":"Sudoku"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku","text":"A JAX implementation of the sudoku game. observation: Observation board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid. action: multi discrete array containing the square to write a digit, and the digits to input. reward: jax array (float32): 1 at the end of the episode if the board is valid 0 otherwise state: State board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits). key: jax array (int32) of shape (2,) used for seeding initial sudoku configuration. 1 2 3 4 5 6 7 8 from jumanji.environments import Sudoku env = Sudoku () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Sudoku"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.__init__","text":"","title":"__init__()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.reset","text":"Resets the environment to an initial state. Parameters: Name Type Description Default key PRNGKeyArray random key used to reset the environment. required Returns: Type Description state State object corresponding to the new state of the environment, timestep: TimeStep object corresponding the first timestep returned by the environment,","title":"reset()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] Array containing the action to take. required Returns: Type Description state State object corresponding to the next state of the environment, timestep: TimeStep object corresponding the timestep returned by the environment,","title":"step()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.observation_spec","text":"Returns the observation spec containing the board and action_mask arrays. Returns: Type Description Spec containing all the specifications for all the `Observation` fields board: BoundedArray (jnp.int8) of shape (9,9). action_mask: BoundedArray (bool) of shape (9,9,9).","title":"observation_spec()"},{"location":"api/environments/sudoku/#jumanji.environments.logic.sudoku.env.Sudoku.action_spec","text":"Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell. Returns: Type Description action_spec MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/tetris/","text":"Tetris ( Environment ) # RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) representing the current state of the grid. tetromino: jax array (int32) of shape (4, 4) representing the current tetromino sampled from the tetromino list. action_mask: jax array (bool) of shape (4, num_cols). For each tetromino there are 4 rotations, each one corresponds to a line in the action_mask. Mask of the joint action space: True if the action (x_position and rotation degree) is feasible for the current tetromino and grid state. action: multi discrete array of shape (2,) rotation_index: The degree index determines the rotation of the tetromino: 0 corresponds to 0 degrees, 1 corresponds to 90 degrees, 2 corresponds to 180 degrees, and 3 corresponds to 270 degrees. x_position: int between 0 and num_cols - 1 (included). reward: The reward is 0 if no lines was cleared by the action and a convex function of the number of cleared lines otherwise. episode termination: if the tetromino cannot be placed anymore (i.e., it hits the top of the grid). 1 2 3 4 5 6 7 8 from jumanji.environments import Tetris env = Tetris () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , num_rows : int = 10 , num_cols : int = 10 , time_limit : int = 400 , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . packing . tetris . types . State ]] = None ) -> None special # Instantiates a Tetris environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 10. 10 num_cols int number of columns of the 2D grid. Defaults to 10. 10 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 400. 400 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.tetris.types.State]] Viewer used for rendering. Defaults to TetrisViewer . None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . packing . tetris . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . tetris . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for generating new tetrominoes. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number ]) -> Tuple [ jumanji . environments . packing . tetris . types . State , jumanji . types . TimeStep [ jumanji . environments . packing . tetris . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] chex.Array containing the rotation_index and x_position of the tetromino. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . packing . tetris . types . Observation ] # Specifications of the observation of the Tetris environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields grid: BoundedArray (jnp.int32) of shape (num_rows, num_cols). tetromino: BoundedArray (bool) of shape (4, 4). action_mask: BoundedArray (bool) of shape (NUM_ROTATIONS, num_cols). step_count: DiscreteArray (num_values = time_limit) of shape (). action_spec ( self ) -> MultiDiscreteArray # Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. Returns: Type Description MultiDiscreteArray The action spec, which is a specs.MultiDiscreteArray object.","title":"Tetris"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris","text":"RL Environment for the game of Tetris. The environment has a grid where the player can place tetrominoes. The environment has the following characteristics: observation: Observation grid: jax array (int32) of shape (num_rows, num_cols) representing the current state of the grid. tetromino: jax array (int32) of shape (4, 4) representing the current tetromino sampled from the tetromino list. action_mask: jax array (bool) of shape (4, num_cols). For each tetromino there are 4 rotations, each one corresponds to a line in the action_mask. Mask of the joint action space: True if the action (x_position and rotation degree) is feasible for the current tetromino and grid state. action: multi discrete array of shape (2,) rotation_index: The degree index determines the rotation of the tetromino: 0 corresponds to 0 degrees, 1 corresponds to 90 degrees, 2 corresponds to 180 degrees, and 3 corresponds to 270 degrees. x_position: int between 0 and num_cols - 1 (included). reward: The reward is 0 if no lines was cleared by the action and a convex function of the number of cleared lines otherwise. episode termination: if the tetromino cannot be placed anymore (i.e., it hits the top of the grid). 1 2 3 4 5 6 7 8 from jumanji.environments import Tetris env = Tetris () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"Tetris"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.__init__","text":"Instantiates a Tetris environment. Parameters: Name Type Description Default num_rows int number of rows of the 2D grid. Defaults to 10. 10 num_cols int number of columns of the 2D grid. Defaults to 10. 10 time_limit int time_limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 400. 400 viewer Optional[jumanji.viewer.Viewer[jumanji.environments.packing.tetris.types.State]] Viewer used for rendering. Defaults to TetrisViewer . None","title":"__init__()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray needed for generating new tetrominoes. required Returns: Type Description state State corresponding to the new state of the environment, timestep: TimeStep corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] chex.Array containing the rotation_index and x_position of the tetromino. required Returns: Type Description next_state State corresponding to the next state of the environment, next_timestep: TimeStep corresponding to the timestep returned by the environment.","title":"step()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.observation_spec","text":"Specifications of the observation of the Tetris environment. Returns: Type Description Spec containing all the specifications for all the `Observation` fields grid: BoundedArray (jnp.int32) of shape (num_rows, num_cols). tetromino: BoundedArray (bool) of shape (4, 4). action_mask: BoundedArray (bool) of shape (NUM_ROTATIONS, num_cols). step_count: DiscreteArray (num_values = time_limit) of shape ().","title":"observation_spec()"},{"location":"api/environments/tetris/#jumanji.environments.packing.tetris.env.Tetris.action_spec","text":"Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino. Returns: Type Description MultiDiscreteArray The action spec, which is a specs.MultiDiscreteArray object.","title":"action_spec()"},{"location":"api/environments/tsp/","text":"TSP ( Environment ) # Traveling Salesman Problem (TSP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: jax array (int32) of shape () the index corresponding to the last visited city. trajectory: jax array (int32) of shape (num_cities,) array of city indices defining the route (-1 --> not filled yet). action_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> illegal/legal <--> cannot be visited/can be visited). action: jax array (int32) of shape () [0, ..., num_cities - 1] -> city to visit. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. episode termination: if no action can be performed, i.e. all cities have been visited. if an invalid action is taken, i.e. an already visited city is chosen. state: State coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: int32 the identifier (index) of the last visited city. visited_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (num_cities,) the identifiers of the cities that have been visited (-1 means that no city has been visited yet at that time in the sequence). num_visited: int32 number of cities that have been visited. [1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). \"POMO: Policy Optimization with Multiple Optima for Reinforcement Learning\". 1 2 3 4 5 6 7 8 from jumanji.environments import TSP env = TSP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state ) __init__ ( self , generator : Optional [ jumanji . environments . routing . tsp . generator . Generator ] = None , reward_fn : Optional [ jumanji . environments . routing . tsp . reward . RewardFn ] = None , viewer : Optional [ jumanji . viewer . Viewer [ jumanji . environments . routing . tsp . types . State ]] = None ) special # Instantiates a TSP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.tsp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates TSP instances with 20 cities sampled from a uniform distribution. None reward_fn Optional[jumanji.environments.routing.tsp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] Viewer used for rendering. Defaults to TSPViewer with \"human\" render mode. None reset ( self , key : PRNGKeyArray ) -> Tuple [ jumanji . environments . routing . tsp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . tsp . types . Observation ]] # Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment. step ( self , state : State , action : Union [ jax . Array , numpy . ndarray , numpy . bool_ , numpy . number , float , int ]) -> Tuple [ jumanji . environments . routing . tsp . types . State , jumanji . types . TimeStep [ jumanji . environments . routing . tsp . types . Observation ]] # Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the index of the next position to visit. required Returns: Type Description state the next state of the environment. timestep: the timestep to be observed. observation_spec ( self ) -> jumanji . specs . Spec [ jumanji . environments . routing . tsp . types . Observation ] # Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_cities,). position: DiscreteArray (num_values = num_cities) of shape (). trajectory: BoundedArray (int32) of shape (num_cities,). action_mask: BoundedArray (bool) of shape (num_cities,). action_spec ( self ) -> DiscreteArray # Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"TSP"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP","text":"Traveling Salesman Problem (TSP) environment as described in [1]. observation: Observation coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: jax array (int32) of shape () the index corresponding to the last visited city. trajectory: jax array (int32) of shape (num_cities,) array of city indices defining the route (-1 --> not filled yet). action_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> illegal/legal <--> cannot be visited/can be visited). action: jax array (int32) of shape () [0, ..., num_cities - 1] -> city to visit. reward: jax array (float) of shape (), could be either: dense: the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. episode termination: if no action can be performed, i.e. all cities have been visited. if an invalid action is taken, i.e. an already visited city is chosen. state: State coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city. position: int32 the identifier (index) of the last visited city. visited_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> not visited/visited). trajectory: jax array (int32) of shape (num_cities,) the identifiers of the cities that have been visited (-1 means that no city has been visited yet at that time in the sequence). num_visited: int32 number of cities that have been visited. [1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). \"POMO: Policy Optimization with Multiple Optima for Reinforcement Learning\". 1 2 3 4 5 6 7 8 from jumanji.environments import TSP env = TSP () key = jax . random . key ( 0 ) state , timestep = jax . jit ( env . reset )( key ) env . render ( state ) action = env . action_spec () . generate_value () state , timestep = jax . jit ( env . step )( state , action ) env . render ( state )","title":"TSP"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.__init__","text":"Instantiates a TSP environment. Parameters: Name Type Description Default generator Optional[jumanji.environments.routing.tsp.generator.Generator] Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates TSP instances with 20 cities sampled from a uniform distribution. None reward_fn Optional[jumanji.environments.routing.tsp.reward.RewardFn] RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. Implemented options are [ DenseReward , SparseReward ]. Defaults to DenseReward . None viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] Viewer used for rendering. Defaults to TSPViewer with \"human\" render mode. None","title":"__init__()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.reset","text":"Resets the environment. Parameters: Name Type Description Default key PRNGKeyArray used to randomly generate the coordinates. required Returns: Type Description state State object corresponding to the new state of the environment. timestep: TimeStep object corresponding to the first timestep returned by the environment.","title":"reset()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.step","text":"Run one timestep of the environment's dynamics. Parameters: Name Type Description Default state State State object containing the dynamics of the environment. required action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] Array containing the index of the next position to visit. required Returns: Type Description state the next state of the environment. timestep: the timestep to be observed.","title":"step()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.observation_spec","text":"Returns the observation spec. Returns: Type Description Spec for the `Observation` whose fields are coordinates: BoundedArray (float) of shape (num_cities,). position: DiscreteArray (num_values = num_cities) of shape (). trajectory: BoundedArray (int32) of shape (num_cities,). action_mask: BoundedArray (bool) of shape (num_cities,).","title":"observation_spec()"},{"location":"api/environments/tsp/#jumanji.environments.routing.tsp.env.TSP.action_spec","text":"Returns the action spec. Returns: Type Description action_spec a specs.DiscreteArray spec.","title":"action_spec()"},{"location":"environments/bin_pack/","text":"BinPack Environment # We provide here an implementation of the 3D bin packing problem . In this problem, the goal of the agent is to efficiently pack a set of boxes (items) of different sizes into a single container with as little empty space as possible. Since there is only 1 bin, this formulation is equivalent to the 3D-knapsack problem. Observation # The observation given to the agent provides information on the available empty space (called EMSs), the items that still need to be packed, and information on what actions are valid at this point. The full observation is as follows: ems : EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,) , coordinates of all EMSs at the current timestep. ems_mask : jax array (bool) of shape (obs_num_ems,) , indicates the EMSs that are valid. items : Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,) , characteristics of all items for this instance. items_mask : jax array (bool) of shape (max_num_items,) , indicates the items that are valid. items_placed : jax array (bool) of shape (max_num_items,) , indicates the items that have been placed so far. action_mask : jax array (bool) of shape (obs_num_ems, max_num_items) , mask of the joint action space: True if the action [ems_id, item_id] is valid. Action # The action space is a MultiDiscreteArray of 2 integer values representing the ID of an EMS (space) and the ID of an item. For instance, [1, 5] will place item 5 in EMS 1. Reward # The reward could be either: Dense : normalized volume (relative to the container volume) of the item packed by taking the chosen action. The computed reward is equivalent to the increase in volume utilization of the container due to packing the chosen item. If the action is invalid, the reward is 0.0 instead. Sparse : computed only at the end of the episode (otherwise, returns 0.0). Returns the volume utilization of the container (between 0.0 and 1.0). If the action is invalid, the action is ignored and the reward is still returned as the current container utilization. Registered Versions \ud83d\udcd6 # BinPack-v2 , 3D bin-packing problem with a solvable random generator that generates up to 20 items maximum, that can handle 40 EMSs maximum that are given in the observation.","title":"BinPack"},{"location":"environments/bin_pack/#binpack-environment","text":"We provide here an implementation of the 3D bin packing problem . In this problem, the goal of the agent is to efficiently pack a set of boxes (items) of different sizes into a single container with as little empty space as possible. Since there is only 1 bin, this formulation is equivalent to the 3D-knapsack problem.","title":"BinPack Environment"},{"location":"environments/bin_pack/#observation","text":"The observation given to the agent provides information on the available empty space (called EMSs), the items that still need to be packed, and information on what actions are valid at this point. The full observation is as follows: ems : EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,) , coordinates of all EMSs at the current timestep. ems_mask : jax array (bool) of shape (obs_num_ems,) , indicates the EMSs that are valid. items : Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,) , characteristics of all items for this instance. items_mask : jax array (bool) of shape (max_num_items,) , indicates the items that are valid. items_placed : jax array (bool) of shape (max_num_items,) , indicates the items that have been placed so far. action_mask : jax array (bool) of shape (obs_num_ems, max_num_items) , mask of the joint action space: True if the action [ems_id, item_id] is valid.","title":"Observation"},{"location":"environments/bin_pack/#action","text":"The action space is a MultiDiscreteArray of 2 integer values representing the ID of an EMS (space) and the ID of an item. For instance, [1, 5] will place item 5 in EMS 1.","title":"Action"},{"location":"environments/bin_pack/#reward","text":"The reward could be either: Dense : normalized volume (relative to the container volume) of the item packed by taking the chosen action. The computed reward is equivalent to the increase in volume utilization of the container due to packing the chosen item. If the action is invalid, the reward is 0.0 instead. Sparse : computed only at the end of the episode (otherwise, returns 0.0). Returns the volume utilization of the container (between 0.0 and 1.0). If the action is invalid, the action is ignored and the reward is still returned as the current container utilization.","title":"Reward"},{"location":"environments/bin_pack/#registered-versions","text":"BinPack-v2 , 3D bin-packing problem with a solvable random generator that generates up to 20 items maximum, that can handle 40 EMSs maximum that are given in the observation.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/cleaner/","text":"Cleaner Environment # We provide here a JAX jit-able implementation of the Multi-Agent Cleaning environment. In this environment, multiple agents must cooperatively clean the floor of a room with complex indoor barriers (black). At the beginning of an episode, the whole floor is dirty (green). Every time an agent (red) visits a dirty tile, it is cleaned (white). The goal is to clean as many tiles as possible in a given time budget. A new maze is randomly generated using a recursive division method for each new episode. Agents always start in the top left corner of the maze. Observation # The observation seen by the agent is a NamedTuple containing the following: grid : jax array (int) of shape (num_rows, num_cols) , array representing the grid, each tile is either dirty (0), clean (1), or a wall (2). agents_locations : jax array (int) of shape (num_agents, 2) , array specifying the x and y coordinates of every agent. action_mask : jax array (bool) of shape (num_agents, 4) , array specifying, for each agent, which action (up, right, down, left) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3] for each agent. Each agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). The episode terminates if any agent meets one of the following conditions: An invalid action is taken, or An action is blocked by a wall. In both cases, the agent's position remains unchanged. Reward # The reward is global and shared among the agents. It is equal to the number of tiles which were cleaned during the time step, minus a penalty (0.5 by default) to encourage agents to clean the maze faster. Registered Versions \ud83d\udcd6 # Cleaner-v0 , a room of size 10x10 with 3 agents.","title":"Cleaner"},{"location":"environments/cleaner/#cleaner-environment","text":"We provide here a JAX jit-able implementation of the Multi-Agent Cleaning environment. In this environment, multiple agents must cooperatively clean the floor of a room with complex indoor barriers (black). At the beginning of an episode, the whole floor is dirty (green). Every time an agent (red) visits a dirty tile, it is cleaned (white). The goal is to clean as many tiles as possible in a given time budget. A new maze is randomly generated using a recursive division method for each new episode. Agents always start in the top left corner of the maze.","title":"Cleaner Environment"},{"location":"environments/cleaner/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: grid : jax array (int) of shape (num_rows, num_cols) , array representing the grid, each tile is either dirty (0), clean (1), or a wall (2). agents_locations : jax array (int) of shape (num_agents, 2) , array specifying the x and y coordinates of every agent. action_mask : jax array (bool) of shape (num_agents, 4) , array specifying, for each agent, which action (up, right, down, left) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode.","title":"Observation"},{"location":"environments/cleaner/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3] for each agent. Each agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). The episode terminates if any agent meets one of the following conditions: An invalid action is taken, or An action is blocked by a wall. In both cases, the agent's position remains unchanged.","title":"Action"},{"location":"environments/cleaner/#reward","text":"The reward is global and shared among the agents. It is equal to the number of tiles which were cleaned during the time step, minus a penalty (0.5 by default) to encourage agents to clean the maze faster.","title":"Reward"},{"location":"environments/cleaner/#registered-versions","text":"Cleaner-v0 , a room of size 10x10 with 3 agents.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/connector/","text":"Connector Environment # The Connector environment contains multiple agents spawned in a grid world with each agent representing a start and end position that need to be connected. The main goal of the environment is to connect each start and end position in as few steps as possible. However, when an agent moves it leaves behind a path, which is impassable by all agents. Thus, agents need to cooperate in order to allow each other to connect to their own targets without overlapping. An episode ends when all agents have connected to their targets or no agents can make any further moves due to being blocked. Observation # At each step observation contains 3 items: a grid, an action mask for each agent and the episode step count. grid : jax array (int32) of shape (grid_size, grid_size) , a 2D matrix that represents pairs of points that need to be connected. Each agent has three types of points: position , target and path which are represented by different numbers on the grid. The position of an agent has to connect to its target , leaving a path behind it as it moves across the grid forming its route. Each agent connects to only 1 target. action_mask : jax array (bool) of shape (num_agents, 5) , indicates which actions each agent can take. step_count : jax array (int32) of shape () , represents how many steps have been taken in the environment since the last reset. Encoding # Each agent has 3 components represented in the observation space: position , target , and path . Each agent in the environment will have an integer representing their components. Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, \u2026 Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, \u2026 Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, \u2026 Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), \u2026 Example: 1 2 3 Agent1[path=1, position=2, target=3] Agent2[path=4, position=5, target=6] Agent3[path=7, position=8, target=9] For example, on a 6x6 grid, a possible observation is shown below. 1 2 3 4 5 6 [[ 2 0 3 0 0 0] [ 1 0 4 4 4 0] [ 1 0 5 9 0 0] [ 1 0 0 0 0 0] [ 0 0 0 8 0 0] [ 0 0 6 7 7 7]] Action # The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, 4] . Each value corresponds to an agent moving in 1 of 4 cardinal directions or taking the no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left]. Reward # The reward is dense : +1.0 per agent that connects at that step and -0.03 per agent that has not connected yet. Rewards are provided in the shape (num_agents,) so that each agent can have a reward. Registered Versions \ud83d\udcd6 # Connector-v2 , grid size of 10 and 10 agents.","title":"Connector"},{"location":"environments/connector/#connector-environment","text":"The Connector environment contains multiple agents spawned in a grid world with each agent representing a start and end position that need to be connected. The main goal of the environment is to connect each start and end position in as few steps as possible. However, when an agent moves it leaves behind a path, which is impassable by all agents. Thus, agents need to cooperate in order to allow each other to connect to their own targets without overlapping. An episode ends when all agents have connected to their targets or no agents can make any further moves due to being blocked.","title":"Connector Environment"},{"location":"environments/connector/#observation","text":"At each step observation contains 3 items: a grid, an action mask for each agent and the episode step count. grid : jax array (int32) of shape (grid_size, grid_size) , a 2D matrix that represents pairs of points that need to be connected. Each agent has three types of points: position , target and path which are represented by different numbers on the grid. The position of an agent has to connect to its target , leaving a path behind it as it moves across the grid forming its route. Each agent connects to only 1 target. action_mask : jax array (bool) of shape (num_agents, 5) , indicates which actions each agent can take. step_count : jax array (int32) of shape () , represents how many steps have been taken in the environment since the last reset.","title":"Observation"},{"location":"environments/connector/#encoding","text":"Each agent has 3 components represented in the observation space: position , target , and path . Each agent in the environment will have an integer representing their components. Positions are encoded starting from 2 in multiples of 3: 2, 5, 8, \u2026 Targets are encoded starting from 3 in multiples of 3: 3, 6, 9, \u2026 Paths appear in the location of the head once it moves, starting from 1 in multiples of 3: 1, 4, 7, \u2026 Every group of 3 corresponds to 1 agent: (1,2,3), (4,5,6), \u2026 Example: 1 2 3 Agent1[path=1, position=2, target=3] Agent2[path=4, position=5, target=6] Agent3[path=7, position=8, target=9] For example, on a 6x6 grid, a possible observation is shown below. 1 2 3 4 5 6 [[ 2 0 3 0 0 0] [ 1 0 4 4 4 0] [ 1 0 5 9 0 0] [ 1 0 0 0 0 0] [ 0 0 0 8 0 0] [ 0 0 6 7 7 7]]","title":"Encoding"},{"location":"environments/connector/#action","text":"The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, 4] . Each value corresponds to an agent moving in 1 of 4 cardinal directions or taking the no-op action. That is, [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left].","title":"Action"},{"location":"environments/connector/#reward","text":"The reward is dense : +1.0 per agent that connects at that step and -0.03 per agent that has not connected yet. Rewards are provided in the shape (num_agents,) so that each agent can have a reward.","title":"Reward"},{"location":"environments/connector/#registered-versions","text":"Connector-v2 , grid size of 10 and 10 agents.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/cvrp/","text":"Capacitated Vehicle Routing Problem (CVRP) Environment # We provide here a Jax JIT-able implementation of the capacitated vehicle routing problem (CVRP) which is a specific type of VRP . CVRP is a classic combinatorial optimization problem. Given a set of nodes with specific demands, a depot node, and a vehicle with limited capacity, the goal is to determine the shortest route between the nodes such that each node (excluding depot) is visited exactly once and has its demand covered. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution between 0 and 1, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand (which is a parameter of the CVRP environment). The number of nodes with demand is a parameter of the environment. Observation # The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position of the agent as well as the current capacity. coordinates : jax array (float) of shape (num_nodes + 1, 2) , array of coordinates of each city node and the depot node. demands : jax array (float) of shape (num_nodes + 1,) , array of the demands of each city node and the depot node whose demand is set to 0. unvisited_nodes : jax array (bool) of shape (num_nodes + 1,) , array denoting which nodes remain to be visited. position : jax array (int32) of shape () , identifier (index) of the current visited node (city or depot). trajectory : jax array (int32) of shape (2 * num_nodes,) , identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). capacity : jax array (float) of shape () , current capacity of the vehicle. action_mask : jax array (bool) of shape (num_nodes + 1,) , array denoting which actions are possible (True) and which are not (False). Action # The action space is a DiscreteArray of integer values in the range of [0, num_nodes] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot. Reward # The reward could be either: Dense : the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again. Registered Versions \ud83d\udcd6 # CVRP-v1 : CVRP problem with 20 randomly generated nodes, a maximum capacity of 30, a maximum demand for each node of 10 and a dense reward function.","title":"CVRP"},{"location":"environments/cvrp/#capacitated-vehicle-routing-problem-cvrp-environment","text":"We provide here a Jax JIT-able implementation of the capacitated vehicle routing problem (CVRP) which is a specific type of VRP . CVRP is a classic combinatorial optimization problem. Given a set of nodes with specific demands, a depot node, and a vehicle with limited capacity, the goal is to determine the shortest route between the nodes such that each node (excluding depot) is visited exactly once and has its demand covered. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution between 0 and 1, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand (which is a parameter of the CVRP environment). The number of nodes with demand is a parameter of the environment.","title":"Capacitated Vehicle Routing Problem (CVRP) Environment"},{"location":"environments/cvrp/#observation","text":"The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position of the agent as well as the current capacity. coordinates : jax array (float) of shape (num_nodes + 1, 2) , array of coordinates of each city node and the depot node. demands : jax array (float) of shape (num_nodes + 1,) , array of the demands of each city node and the depot node whose demand is set to 0. unvisited_nodes : jax array (bool) of shape (num_nodes + 1,) , array denoting which nodes remain to be visited. position : jax array (int32) of shape () , identifier (index) of the current visited node (city or depot). trajectory : jax array (int32) of shape (2 * num_nodes,) , identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet). capacity : jax array (float) of shape () , current capacity of the vehicle. action_mask : jax array (bool) of shape (num_nodes + 1,) , array denoting which actions are possible (True) and which are not (False).","title":"Observation"},{"location":"environments/cvrp/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_nodes] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot.","title":"Action"},{"location":"environments/cvrp/#reward","text":"The reward could be either: Dense : the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes. In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. a previously selected node other than the depot is selected again.","title":"Reward"},{"location":"environments/cvrp/#registered-versions","text":"CVRP-v1 : CVRP problem with 20 randomly generated nodes, a maximum capacity of 30, a maximum demand for each node of 10 and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/game_2048/","text":"2048 Environment # We provide here a Jax JIT-able implementation of the game 2048 . 2048 is a popular single-player puzzle game that is played on a 4x4 grid. The game board consists of cells, each containing a power of 2, and the objective is to reach a score of at least 2048 by merging cells together. The player can shift the entire grid in one of the four directions (up, down, right, left) to combine cells of the same value. When two adjacent cells have the same value, they merge into a single cell with a value equal to the sum of the two cells. The game ends when the player is no longer able to make any further moves. The ultimate goal is to achieve the highest-valued tile possible, with the hope of surpassing 2048. With each move, the player must carefully plan and strategize to reach the highest score possible. Observation # The observation in the game 2048 includes information about the board, the action mask, and the step count. board : jax array (int32) of shape (board_size, board_size) , representing the current game state. Each nonzero element in the array corresponds to a game tile and holds an exponent of 2. The actual value of the tile is obtained by raising 2 to the power of said exponent. Here is an example of a random observation of the game board: 1 2 3 4 [[ 2 0 1 4] [ 5 3 0 2] [ 0 2 3 2] [ 1 2 0 0]] This array can be converted into the actual game board: 1 2 3 4 [[ 4 0 2 16] [ 32 8 0 4] [ 0 4 8 4] [ 2 4 0 0]] action_mask : jax array (bool) of shape (4,) , indicating which actions are valid in the current state of the environment. The actions include moving the tiles up, right, down, or left. For example, an action mask [False, True, False, False] means that the only valid action is to move the tiles rightward. Action # The action space is a DiscreteArray of integer values in [0, 1, 2, 3] . Specifically, these four actions correspond to: up (0), right (1), down (2), or left (3). Reward # Taking an action in 2048 only returns a reward when two tiles of equal value are merged into a new tile containing their sum (i.e. twice each of their values). The cumulative reward in an episode is the sum of the values of all newly created tiles. For example, if a player merges two 512-value tiles to create a new 1024-value tile, and then merges two 256-value tiles to create a new 512-value tile, the total reward from these actions is 1536 (i.e., 1024 + 512). Registered Versions \ud83d\udcd6 # Game2048-v1 , the default settings for 2048 with a board of size 4x4.","title":"Game2048"},{"location":"environments/game_2048/#2048-environment","text":"We provide here a Jax JIT-able implementation of the game 2048 . 2048 is a popular single-player puzzle game that is played on a 4x4 grid. The game board consists of cells, each containing a power of 2, and the objective is to reach a score of at least 2048 by merging cells together. The player can shift the entire grid in one of the four directions (up, down, right, left) to combine cells of the same value. When two adjacent cells have the same value, they merge into a single cell with a value equal to the sum of the two cells. The game ends when the player is no longer able to make any further moves. The ultimate goal is to achieve the highest-valued tile possible, with the hope of surpassing 2048. With each move, the player must carefully plan and strategize to reach the highest score possible.","title":"2048 Environment"},{"location":"environments/game_2048/#observation","text":"The observation in the game 2048 includes information about the board, the action mask, and the step count. board : jax array (int32) of shape (board_size, board_size) , representing the current game state. Each nonzero element in the array corresponds to a game tile and holds an exponent of 2. The actual value of the tile is obtained by raising 2 to the power of said exponent. Here is an example of a random observation of the game board: 1 2 3 4 [[ 2 0 1 4] [ 5 3 0 2] [ 0 2 3 2] [ 1 2 0 0]] This array can be converted into the actual game board: 1 2 3 4 [[ 4 0 2 16] [ 32 8 0 4] [ 0 4 8 4] [ 2 4 0 0]] action_mask : jax array (bool) of shape (4,) , indicating which actions are valid in the current state of the environment. The actions include moving the tiles up, right, down, or left. For example, an action mask [False, True, False, False] means that the only valid action is to move the tiles rightward.","title":"Observation"},{"location":"environments/game_2048/#action","text":"The action space is a DiscreteArray of integer values in [0, 1, 2, 3] . Specifically, these four actions correspond to: up (0), right (1), down (2), or left (3).","title":"Action"},{"location":"environments/game_2048/#reward","text":"Taking an action in 2048 only returns a reward when two tiles of equal value are merged into a new tile containing their sum (i.e. twice each of their values). The cumulative reward in an episode is the sum of the values of all newly created tiles. For example, if a player merges two 512-value tiles to create a new 1024-value tile, and then merges two 256-value tiles to create a new 512-value tile, the total reward from these actions is 1536 (i.e., 1024 + 512).","title":"Reward"},{"location":"environments/game_2048/#registered-versions","text":"Game2048-v1 , the default settings for 2048 with a board of size 4x4.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/graph_coloring/","text":"Graph Coloring Environment # We provide here a Jax JIT-able implementation of the Graph Coloring environment. Graph coloring is a combinatorial optimization problem where the objective is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. The GraphColoring environment is an episodic, single-agent setting that allows for the exploration of graph coloring algorithms and reinforcement learning methods. Observation # The observation in the GraphColoring environment includes information about the graph, the colors assigned to the vertices, the action mask, and the current node index. graph : jax array (bool) of shape (num_nodes, num_nodes) , representing the adjacency matrix of the graph. For example, a random observation of the graph adjacency matrix: 1 2 3 4 ```[[False, True, False, True], [ True, False, True, False], [False, True, False, True], [ True, False, True, False]]``` colors : a JAX array (int32) of shape (num_nodes,) , representing the current color assignments for the vertices. Initially, all elements are set to -1, indicating that no colors have been assigned yet. For example, an initial color assignment: [-1, -1, -1, -1] action_mask : a JAX array of boolean values, shaped (num_colors,) , which indicates the valid actions in the current state of the environment. Each position in the array corresponds to a color. True at a position signifies that the corresponding color can be used to color a node, while False indicates the opposite. For example, for 4 number of colors available: [True, False, True, False] current_node_index : an integer representing the current node being colored. For example, an initial current_node_index might be 0. Action # The action space is a DiscreteArray of integer values in [0, 1, ..., num_colors - 1] . Each action corresponds to assigning a color to the current node. Reward # The reward in the GraphColoring environment is given as follows: sparse reward : a reward is provided at the end of the episode and equals the negative of the number of unique colors used to color all vertices in the graph. The agent's goal is to find a valid coloring using as few colors as possible while avoiding conflicts with adjacent nodes. Episode Termination # The goal of the agent is to find a valid coloring using as few colors as possible. An episode in the graph coloring environment can terminate under two conditions: All nodes have been assigned a color: the environment iteratively assigns colors to nodes. When all nodes have a color assigned (i.e., there are no nodes with a color value of -1), the episode ends. This is the natural termination condition and ideally the one we'd like the agent to achieve. Invalid action is taken: an action is considered invalid if it tries to assign a color to a node that is not within the allowed color set for that node at that time. The allowed color set for each node is updated after every action. If an invalid action is attempted, the episode immediately terminates and the agent receives a large negative reward. This encourages the agent to learn valid actions and discourages it from making invalid actions. Registered Versions \ud83d\udcd6 # GraphColoring-v0 : The default settings for the GraphColoring problem with a configurable number of nodes and edge_probability. The default number of nodes is 20, and the default edge probability is 0.8.","title":"GraphColoring"},{"location":"environments/graph_coloring/#graph-coloring-environment","text":"We provide here a Jax JIT-able implementation of the Graph Coloring environment. Graph coloring is a combinatorial optimization problem where the objective is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used. The GraphColoring environment is an episodic, single-agent setting that allows for the exploration of graph coloring algorithms and reinforcement learning methods.","title":"Graph Coloring Environment"},{"location":"environments/graph_coloring/#observation","text":"The observation in the GraphColoring environment includes information about the graph, the colors assigned to the vertices, the action mask, and the current node index. graph : jax array (bool) of shape (num_nodes, num_nodes) , representing the adjacency matrix of the graph. For example, a random observation of the graph adjacency matrix: 1 2 3 4 ```[[False, True, False, True], [ True, False, True, False], [False, True, False, True], [ True, False, True, False]]``` colors : a JAX array (int32) of shape (num_nodes,) , representing the current color assignments for the vertices. Initially, all elements are set to -1, indicating that no colors have been assigned yet. For example, an initial color assignment: [-1, -1, -1, -1] action_mask : a JAX array of boolean values, shaped (num_colors,) , which indicates the valid actions in the current state of the environment. Each position in the array corresponds to a color. True at a position signifies that the corresponding color can be used to color a node, while False indicates the opposite. For example, for 4 number of colors available: [True, False, True, False] current_node_index : an integer representing the current node being colored. For example, an initial current_node_index might be 0.","title":"Observation"},{"location":"environments/graph_coloring/#action","text":"The action space is a DiscreteArray of integer values in [0, 1, ..., num_colors - 1] . Each action corresponds to assigning a color to the current node.","title":"Action"},{"location":"environments/graph_coloring/#reward","text":"The reward in the GraphColoring environment is given as follows: sparse reward : a reward is provided at the end of the episode and equals the negative of the number of unique colors used to color all vertices in the graph. The agent's goal is to find a valid coloring using as few colors as possible while avoiding conflicts with adjacent nodes.","title":"Reward"},{"location":"environments/graph_coloring/#episode-termination","text":"The goal of the agent is to find a valid coloring using as few colors as possible. An episode in the graph coloring environment can terminate under two conditions: All nodes have been assigned a color: the environment iteratively assigns colors to nodes. When all nodes have a color assigned (i.e., there are no nodes with a color value of -1), the episode ends. This is the natural termination condition and ideally the one we'd like the agent to achieve. Invalid action is taken: an action is considered invalid if it tries to assign a color to a node that is not within the allowed color set for that node at that time. The allowed color set for each node is updated after every action. If an invalid action is attempted, the episode immediately terminates and the agent receives a large negative reward. This encourages the agent to learn valid actions and discourages it from making invalid actions.","title":"Episode Termination"},{"location":"environments/graph_coloring/#registered-versions","text":"GraphColoring-v0 : The default settings for the GraphColoring problem with a configurable number of nodes and edge_probability. The default number of nodes is 20, and the default edge probability is 0.8.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/job_shop/","text":"JobShop Environment # We provide here a JAX jit-able implementation of the job shop scheduling problem . It is NP-hard and one of the most well-known combinatorial optimisation problems. The problem formulation is: N jobs , each consisting of a sequence of operations , need to be scheduled on M machines. For each job, its operations must be processed in order . This is called the precedence constraints . Only one operation in a job can be processed at any given time. A machine can only work on one operation at a time. Once started, an operation must run to completion. The goal of the agent is to determine the schedule that minimises the time needed to process all the jobs. The length of the schedule is also known as its makespan . Observation # The observation seen by the agent is a NamedTuple containing the following: ops_machine_ids : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the machine each op must be processed on. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_durations : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the processing time of each operation. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_mask : jax array (bool) of shape (num_jobs, max_num_ops) . For each job, indicates which operations remain to be scheduled. False if the op has been scheduled or if the op was added for padding, True otherwise. The first True in each row (i.e. each job) identifies the next operation for that job. machines_job_ids : jax array (int32) of shape (num_machines,) . For each machine, it specifies the job currently being processed. Note that -1 means no-op in which case the remaining time until available is always 0. machines_remaining_times : jax array (int32) of shape (num_machines,) . For each machine, it specifies the number of time steps until available. action_mask : jax array (bool) of (num_machines, num_jobs + 1) . For each machine, it indicates which jobs (or no-op) can legally be scheduled. The last column corresponds to no-op. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, ..., num_jobs] for each machine. Thus, an action consists of the following: for each machine, decide which job (or no-op) to schedule at the current time step. The action is represented as a 1-dimensional array of length num_machines . For example, suppose we have M=5 machines and there are N=10 jobs. A legal action might be 1 action = [ 4 , 7 , 0 , 10 , 10 ] This action represents scheduling Job 4 on Machine 0, Job 7 on Machine 1, Job 0 on Machine 2, No-op on Machine 3, No-op on Machine 4. As such, the action is multidimensional and can be thought of as each machine (each agent) deciding which job (or no-op) to schedule. Importantly, the action space is a product of the marginal action space of each agent (machine). The rationale for having a no-op is the following: A machine might be busy processing an operation, in which case a no-op is the only allowed action for that machine. There might not be any jobs that can be scheduled on a machine. There may be scenarios where waiting to schedule a job via one or more no-op(s) ultimately minimizes the makespan. Reward # The reward setting is dense: a reward of -1 is given each time step if none of the termination criteria are met. An episode will terminate in any of the three scenarios below: Finished schedule : all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. If all machines are simultaneously idle or the agent selects an invalid action, this is reflected in a large penalty in the reward. This would be -num_jobs * max_num_ops * max_op_duration which is a upper bound on the makespan, corresponding to if every job had max_num_ops operations and every operation had a processing time of max_op_duration . Registered Versions \ud83d\udcd6 # JobShop-v0 : job-shop scheduling problem with 20 jobs, 10 machines, a maximum of 8 operations per job, and a max operation duration of 6 timesteps per operation.","title":"JobShop"},{"location":"environments/job_shop/#jobshop-environment","text":"We provide here a JAX jit-able implementation of the job shop scheduling problem . It is NP-hard and one of the most well-known combinatorial optimisation problems. The problem formulation is: N jobs , each consisting of a sequence of operations , need to be scheduled on M machines. For each job, its operations must be processed in order . This is called the precedence constraints . Only one operation in a job can be processed at any given time. A machine can only work on one operation at a time. Once started, an operation must run to completion. The goal of the agent is to determine the schedule that minimises the time needed to process all the jobs. The length of the schedule is also known as its makespan .","title":"JobShop Environment"},{"location":"environments/job_shop/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: ops_machine_ids : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the machine each op must be processed on. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_durations : jax array (int32) of shape (num_jobs, max_num_ops) . For each job, it specifies the processing time of each operation. Note that a -1 corresponds to padded ops since not all jobs have the same number of ops. ops_mask : jax array (bool) of shape (num_jobs, max_num_ops) . For each job, indicates which operations remain to be scheduled. False if the op has been scheduled or if the op was added for padding, True otherwise. The first True in each row (i.e. each job) identifies the next operation for that job. machines_job_ids : jax array (int32) of shape (num_machines,) . For each machine, it specifies the job currently being processed. Note that -1 means no-op in which case the remaining time until available is always 0. machines_remaining_times : jax array (int32) of shape (num_machines,) . For each machine, it specifies the number of time steps until available. action_mask : jax array (bool) of (num_machines, num_jobs + 1) . For each machine, it indicates which jobs (or no-op) can legally be scheduled. The last column corresponds to no-op.","title":"Observation"},{"location":"environments/job_shop/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, ..., num_jobs] for each machine. Thus, an action consists of the following: for each machine, decide which job (or no-op) to schedule at the current time step. The action is represented as a 1-dimensional array of length num_machines . For example, suppose we have M=5 machines and there are N=10 jobs. A legal action might be 1 action = [ 4 , 7 , 0 , 10 , 10 ] This action represents scheduling Job 4 on Machine 0, Job 7 on Machine 1, Job 0 on Machine 2, No-op on Machine 3, No-op on Machine 4. As such, the action is multidimensional and can be thought of as each machine (each agent) deciding which job (or no-op) to schedule. Importantly, the action space is a product of the marginal action space of each agent (machine). The rationale for having a no-op is the following: A machine might be busy processing an operation, in which case a no-op is the only allowed action for that machine. There might not be any jobs that can be scheduled on a machine. There may be scenarios where waiting to schedule a job via one or more no-op(s) ultimately minimizes the makespan.","title":"Action"},{"location":"environments/job_shop/#reward","text":"The reward setting is dense: a reward of -1 is given each time step if none of the termination criteria are met. An episode will terminate in any of the three scenarios below: Finished schedule : all operations (and thus all jobs) every job have been processed. Illegal action: the agent ignores the action mask and takes an illegal action. Simultaneously idle: all machines are inactive at the same time. If all machines are simultaneously idle or the agent selects an invalid action, this is reflected in a large penalty in the reward. This would be -num_jobs * max_num_ops * max_op_duration which is a upper bound on the makespan, corresponding to if every job had max_num_ops operations and every operation had a processing time of max_op_duration .","title":"Reward"},{"location":"environments/job_shop/#registered-versions","text":"JobShop-v0 : job-shop scheduling problem with 20 jobs, 10 machines, a maximum of 8 operations per job, and a max operation duration of 6 timesteps per operation.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/knapsack/","text":"Knapskack Environment # We provide here a Jax JIT-able implementation of the knapskack problem . The knapsack problem is a famous problem in combinatorial optimization. The goal is to determine, given a set of items, each with a weight and a value, which items to include in a collection so that the total weight is less than or equal to a given limit and the total value is as large as possible. The decision problem form of the knapsack problem is NP-complete, thus there is no known algorithm both correct and fast (polynomial-time) in all cases. When the environment is reset, a new problem instance is generated, by sampling weights and values from a uniform distribution between 0 and 1. The weight limit of the knapsack is a parameter of the environment. A trajectory terminates when no further item can be added to the knapsack or the chosen action is invalid. Observation # The observation given to the agent provides information regarding the weights and the values of all the items, as well as, which items have been packed into the knapsack. weights : jax array (float) of shape (num_items,) , array of weights of the items to be packed into the knapsack. values : jax array (float) of shape (num_items,) , array of values of the items to be packed into the knapsack. packed_items : jax array (bool) of shape (num_items,) , array of binary values denoting which items are already packed into the knapsack. action_mask : jax array (bool) of shape (num_items,) , array of binary values denoting which items can be packed into the knapsack. Action # The action space is a DiscreteArray of integer values in the range of [0, num_items-1] . An action is the index of the next item to pack. Reward # The reward can be either: Dense : the value of the item to pack at the current timestep. Sparse : the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity. Registered Versions \ud83d\udcd6 # Knapsack-v1 : Knapsack problem with 50 randomly generated items, a total budget of 12.5 and a dense reward function.","title":"Knapsack"},{"location":"environments/knapsack/#knapskack-environment","text":"We provide here a Jax JIT-able implementation of the knapskack problem . The knapsack problem is a famous problem in combinatorial optimization. The goal is to determine, given a set of items, each with a weight and a value, which items to include in a collection so that the total weight is less than or equal to a given limit and the total value is as large as possible. The decision problem form of the knapsack problem is NP-complete, thus there is no known algorithm both correct and fast (polynomial-time) in all cases. When the environment is reset, a new problem instance is generated, by sampling weights and values from a uniform distribution between 0 and 1. The weight limit of the knapsack is a parameter of the environment. A trajectory terminates when no further item can be added to the knapsack or the chosen action is invalid.","title":"Knapskack Environment"},{"location":"environments/knapsack/#observation","text":"The observation given to the agent provides information regarding the weights and the values of all the items, as well as, which items have been packed into the knapsack. weights : jax array (float) of shape (num_items,) , array of weights of the items to be packed into the knapsack. values : jax array (float) of shape (num_items,) , array of values of the items to be packed into the knapsack. packed_items : jax array (bool) of shape (num_items,) , array of binary values denoting which items are already packed into the knapsack. action_mask : jax array (bool) of shape (num_items,) , array of binary values denoting which items can be packed into the knapsack.","title":"Observation"},{"location":"environments/knapsack/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_items-1] . An action is the index of the next item to pack.","title":"Action"},{"location":"environments/knapsack/#reward","text":"The reward can be either: Dense : the value of the item to pack at the current timestep. Sparse : the sum of the values of the items packed in the bag at the end of the episode. In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or has a weight larger than the bag capacity.","title":"Reward"},{"location":"environments/knapsack/#registered-versions","text":"Knapsack-v1 : Knapsack problem with 50 randomly generated items, a total budget of 12.5 and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/maze/","text":"Maze Environment # We provide here a Jax JIT-able implementation of a 2D maze problem. The maze is a size-configurable 2D matrix where each cell represents either free space (white) or wall (black). The goal is for the agent (green) to reach the single target cell (red). It is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target. The agent may choose to move one space up, right, down, or left: (\"N\", \u201cE\u201d, \"S\", \"W\"). If the way is blocked by a wall, it will remain at the same position. Each maze is randomly generated using a recursive division function. By default, a new maze, initial agent position and target position are generated each time the environment is reset. Observation # As an observation, the agent has access to the current maze configuration in the array named walls . It also has access to its current position agent_position , the target's target_position , the number of steps step_count elapsed in the current episode and the action mask action_mask . agent_position : Position(row, col) (int32) each of shape () , agent position in the maze. target_position : Position(row, col) (int32) each of shape () , target position in the maze. walls : jax array (bool) of shape (num_rows, num_cols) , indicates whether a grid cell is a wall. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. action_mask : jax array (bool) of shape (4,) , binary values denoting whether each action is possible. An example 5x5 observation walls array, is shown below. 1 represents a wall, and 0 represents free space. 1 2 3 4 5 [0, 1, 0, 0, 0], [0, 1, 0, 1, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0] Action # The action space is a DiscreteArray of integer values in the range of [0, 3]. I.e. the agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). If an invalid action is taken, or an action is blocked by a wall, a no-op is performed and the agent's position remains unchanged. Reward # Maze is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target position. An episode ends when the agent reaches the target position, or after a set number of steps (by default, this is twice the number of cells in the maze, i.e. step_limit=2*num_rows*num_cols ). Registered Versions \ud83d\udcd6 # Maze-v0 , maze with 10 rows and 10 cols.","title":"Maze"},{"location":"environments/maze/#maze-environment","text":"We provide here a Jax JIT-able implementation of a 2D maze problem. The maze is a size-configurable 2D matrix where each cell represents either free space (white) or wall (black). The goal is for the agent (green) to reach the single target cell (red). It is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target. The agent may choose to move one space up, right, down, or left: (\"N\", \u201cE\u201d, \"S\", \"W\"). If the way is blocked by a wall, it will remain at the same position. Each maze is randomly generated using a recursive division function. By default, a new maze, initial agent position and target position are generated each time the environment is reset.","title":"Maze Environment"},{"location":"environments/maze/#observation","text":"As an observation, the agent has access to the current maze configuration in the array named walls . It also has access to its current position agent_position , the target's target_position , the number of steps step_count elapsed in the current episode and the action mask action_mask . agent_position : Position(row, col) (int32) each of shape () , agent position in the maze. target_position : Position(row, col) (int32) each of shape () , target position in the maze. walls : jax array (bool) of shape (num_rows, num_cols) , indicates whether a grid cell is a wall. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. action_mask : jax array (bool) of shape (4,) , binary values denoting whether each action is possible. An example 5x5 observation walls array, is shown below. 1 represents a wall, and 0 represents free space. 1 2 3 4 5 [0, 1, 0, 0, 0], [0, 1, 0, 1, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0]","title":"Observation"},{"location":"environments/maze/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, 3]. I.e. the agent can take one of four actions: up ( 0 ), right ( 1 ), down ( 2 ), or left ( 3 ). If an invalid action is taken, or an action is blocked by a wall, a no-op is performed and the agent's position remains unchanged.","title":"Action"},{"location":"environments/maze/#reward","text":"Maze is a sparse reward problem, where the agent receives a reward of 0 at every step and a reward of 1 for reaching the target position. An episode ends when the agent reaches the target position, or after a set number of steps (by default, this is twice the number of cells in the maze, i.e. step_limit=2*num_rows*num_cols ).","title":"Reward"},{"location":"environments/maze/#registered-versions","text":"Maze-v0 , maze with 10 rows and 10 cols.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/minesweeper/","text":"Minesweeper Environment # We provide here a Jax JIT-able implementation of the Minesweeper game. Observation # The observation given to the agent consists of: board : jax array (int32) of shape (num_rows, num_cols) : each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask : jax array (bool) of shape (num_rows, num_cols) : indicates which actions are valid (not yet explored squares). This can also be determined from the board which will have an entry of -1 in all of these positions. num_mines : jax array (int32) of shape () , indicates the number of mines to locate. step_count : jax array (int32) of shape () : specifies how many timesteps have elapsed since environment reset. Action # The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore, e.g. [3, 6] for the cell located on the third row and sixth column. If either a mined square or an already explored square is selected, the episode terminates (the latter are termed invalid actions ). Also, exploring a square will reveal only the contents of that square. This differs slightly from the usual implementation of the game, which automatically and recursively reveals neighbouring squares if there are no adjacent mines. Reward # The reward is configurable, but default to +1 for exploring a new square that does not contain a mine, and 0 otherwise (which also terminates the episode). The episode also terminates if the board is solved. Registered Versions \ud83d\udcd6 # Minesweeper-v0 , the classic game on a 10x10 grid with 10 mines to locate.","title":"Minesweeper"},{"location":"environments/minesweeper/#minesweeper-environment","text":"We provide here a Jax JIT-able implementation of the Minesweeper game.","title":"Minesweeper Environment"},{"location":"environments/minesweeper/#observation","text":"The observation given to the agent consists of: board : jax array (int32) of shape (num_rows, num_cols) : each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares. action_mask : jax array (bool) of shape (num_rows, num_cols) : indicates which actions are valid (not yet explored squares). This can also be determined from the board which will have an entry of -1 in all of these positions. num_mines : jax array (int32) of shape () , indicates the number of mines to locate. step_count : jax array (int32) of shape () : specifies how many timesteps have elapsed since environment reset.","title":"Observation"},{"location":"environments/minesweeper/#action","text":"The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore, e.g. [3, 6] for the cell located on the third row and sixth column. If either a mined square or an already explored square is selected, the episode terminates (the latter are termed invalid actions ). Also, exploring a square will reveal only the contents of that square. This differs slightly from the usual implementation of the game, which automatically and recursively reveals neighbouring squares if there are no adjacent mines.","title":"Action"},{"location":"environments/minesweeper/#reward","text":"The reward is configurable, but default to +1 for exploring a new square that does not contain a mine, and 0 otherwise (which also terminates the episode). The episode also terminates if the board is solved.","title":"Reward"},{"location":"environments/minesweeper/#registered-versions","text":"Minesweeper-v0 , the classic game on a 10x10 grid with 10 mines to locate.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/mmst/","text":"MMST Environment # The multi minimum spanning tree (mmst) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes) in the shortest time possible. An episode ends when all group of nodes are connected or the maximum number of steps is reached. Note: This environment can be treated as a multi agent problem with each agent atempting to connect one group of node. In this implementation, we treat the problem as single agent that outputs multiple actions per nodes. Observation # At each step observation contains 4 items: a node_types, an adjacency matrix for the graph, an action mask for each group of nodes (agent) and current node positon of each agent. node_types : Array representing the types of nodes in the problem. For example, if we have 12 nodes, their indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11. Let's consider we have 2 agents. Agent 0 wants to connect nodes (0, 1, 9), and agent 1 wants to connect nodes (3, 5, 8). The remaining nodes are considered utility nodes. Therefore, in the state view, the node_types are represented as [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1]. When generating the problem, each agent starts from one of its nodes. So, if agent 0 starts on node 1 and agent 1 on node 3, the connected_nodes array will have values [1, -1, ...] and [3, -1, ...] respectively. The agent's observation is represented using the following rules: - Each agent should see its connected nodes on the path as 0. - Nodes that the agent still needs to connect are represented as 1. - The next agent's nodes are represented by 2 and 3, the next by 4 and 5, and so on. - Utility unconnected nodes are represented by -1. For the 12 node example mentioned above, the expected observation view node_types will have the following values: node_types = jnp.array( [ [1, 0, -1, 2, -1, 3, 1, -1, 3, 1, -1, -1], [3, 2, -1, 0, -1, 1, 3, -1, 1, 3, -1, -1], ], dtype=jnp.int32, ) Note: to make the environment single agent, we use the first agent's observation. adj_matrix : Adjacency matrix representing the connections between nodes. positions : Current node positions of the agents. In our current problem, this will be represented as jnp.array([1, 3]). step_count : integer to keep track of the number of steps. action_mask : Binary mask indicating the validity of each action. Given the current node on which the agent is located, this mask determines if there is a valid edge to every other node. Action # The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, num_nodes-1] . During every step, an agent picks the next node it wants to move to. An action is invalid if the agent picks a node it has no edge to or the node is a utility node already been used by another agent. Reward # At every step, an agent receives a reward of 10.0 if it gets a valid connection, a reward of -1.0 if it does not connect and an extra penalty of -1.0 if it chooses an invalid action. The total step reward is the sum of rewards per agent. Registered Versions \ud83d\udcd6 # MMST-v0 , 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent and step limit of 70.","title":"MMST"},{"location":"environments/mmst/#mmst-environment","text":"The multi minimum spanning tree (mmst) environment consists of a random connected graph with groups of nodes (same node types) that needs to be connected. The goal of the environment is to connect all nodes of the same type together without using the same utility nodes (nodes that do not belong to any group of nodes) in the shortest time possible. An episode ends when all group of nodes are connected or the maximum number of steps is reached. Note: This environment can be treated as a multi agent problem with each agent atempting to connect one group of node. In this implementation, we treat the problem as single agent that outputs multiple actions per nodes.","title":"MMST Environment"},{"location":"environments/mmst/#observation","text":"At each step observation contains 4 items: a node_types, an adjacency matrix for the graph, an action mask for each group of nodes (agent) and current node positon of each agent. node_types : Array representing the types of nodes in the problem. For example, if we have 12 nodes, their indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11. Let's consider we have 2 agents. Agent 0 wants to connect nodes (0, 1, 9), and agent 1 wants to connect nodes (3, 5, 8). The remaining nodes are considered utility nodes. Therefore, in the state view, the node_types are represented as [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1]. When generating the problem, each agent starts from one of its nodes. So, if agent 0 starts on node 1 and agent 1 on node 3, the connected_nodes array will have values [1, -1, ...] and [3, -1, ...] respectively. The agent's observation is represented using the following rules: - Each agent should see its connected nodes on the path as 0. - Nodes that the agent still needs to connect are represented as 1. - The next agent's nodes are represented by 2 and 3, the next by 4 and 5, and so on. - Utility unconnected nodes are represented by -1. For the 12 node example mentioned above, the expected observation view node_types will have the following values: node_types = jnp.array( [ [1, 0, -1, 2, -1, 3, 1, -1, 3, 1, -1, -1], [3, 2, -1, 0, -1, 1, 3, -1, 1, 3, -1, -1], ], dtype=jnp.int32, ) Note: to make the environment single agent, we use the first agent's observation. adj_matrix : Adjacency matrix representing the connections between nodes. positions : Current node positions of the agents. In our current problem, this will be represented as jnp.array([1, 3]). step_count : integer to keep track of the number of steps. action_mask : Binary mask indicating the validity of each action. Given the current node on which the agent is located, this mask determines if there is a valid edge to every other node.","title":"Observation"},{"location":"environments/mmst/#action","text":"The action space is a MultiDiscreteArray of shape (num_agents,) of integer values in the range of [0, num_nodes-1] . During every step, an agent picks the next node it wants to move to. An action is invalid if the agent picks a node it has no edge to or the node is a utility node already been used by another agent.","title":"Action"},{"location":"environments/mmst/#reward","text":"At every step, an agent receives a reward of 10.0 if it gets a valid connection, a reward of -1.0 if it does not connect and an extra penalty of -1.0 if it chooses an invalid action. The total step reward is the sum of rewards per agent.","title":"Reward"},{"location":"environments/mmst/#registered-versions","text":"MMST-v0 , 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent and step limit of 70.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/multi_cvrp/","text":"Multi Agent Capacitated Vehicle Routing Problem - MultiCVRP Environment # We provide here a Jax JIT-able implementation of the multi-agent capacitated vehicle routing problem (MultiCVRP) which is specified in MVRPSTW . This environment introduces the problem of routing multiple agents in a coordinated manner, specifically in the context of collecting items from various locations. Each agent controls one vehicle. The problem, called the multi-agent capacitated vehicle routing problem (MultiCVRP), entails directing a group of agents to different locations on a map. They need to collectively go to each node and return items to the depot location. To make the problem a bit more realistic we consider the multi-vehicle routing problem with soft time windows (MVRPSTW). In this formulation, each location on the map also has a soft time window in which the items must be collected. If the items are collected outside this window a penalty is provided to the agents. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution inside the map boundries, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand. The number of nodes with demand is a parameter of the environment. Observation # Each agent receives information on the coordinates, demands, time windows and penalty coefficients of all the customer nodes. Futhermore the agents receive positions, local times and vehicle capacity information on all vehicles. Lastly an action mask is also provided to each agent. node_coordinates : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the coordinates of each customer node and the depot node. node_demands : jax array (int16) of shape (num_vehicles, num_customers + 1,) , shows an array of the demands of each city node (and depot node where the demand is set to 0). node_time_windows : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the early and late time cutoffs for each customer. node_penalty_coefs : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows the early and late penalty coefficients for arriving early or late at a customer's location. other_vehicles_position : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the positions of all other vehicles. other_vehicles_local_times : jax array (float32) of shape (num_vehicles, num_vehicles - 1) , shows the local times of all other vehicles. other_vehicles_capacities : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the capacities of all other vehicles. vehicle_position : jax array (int16) of shape (num_vehicles) , shows the positions of the vehicles controlled by the agents. vehicle_local_time : jax array (float32) of shape (num_vehicles) , shows the local times of the vehicles controlled by the agents. vehicle_capacity : jax array (int16) of shape (num_vehicles) , shows the capacity of the vehicles controlled by the agents. action_mask : jax array (bool) of shape (num_vehicles, num_customers + 1,) , denoting which actions are possible (True) and which are not (False). Action # Each agent's action space is a BoundedArray of integer values in the range of [0, num_customers] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot. Reward # Dense : The reward is equal to the sum of negative distances of the current location and next location of all the vehicles. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Sparse : The reward is 0 at every step but the last, where the reward is the negative of the length of the path chosen by all the agents combined. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Registered Versions \ud83d\udcd6 # MultiCVRP-v0 : MultiCVRP problem with 20 customers (randomly generated), maximum capacity of 20, and maximum demand of 10 with two vehicles.","title":"MultiCVRP"},{"location":"environments/multi_cvrp/#multi-agent-capacitated-vehicle-routing-problem-multicvrp-environment","text":"We provide here a Jax JIT-able implementation of the multi-agent capacitated vehicle routing problem (MultiCVRP) which is specified in MVRPSTW . This environment introduces the problem of routing multiple agents in a coordinated manner, specifically in the context of collecting items from various locations. Each agent controls one vehicle. The problem, called the multi-agent capacitated vehicle routing problem (MultiCVRP), entails directing a group of agents to different locations on a map. They need to collectively go to each node and return items to the depot location. To make the problem a bit more realistic we consider the multi-vehicle routing problem with soft time windows (MVRPSTW). In this formulation, each location on the map also has a soft time window in which the items must be collected. If the items are collected outside this window a penalty is provided to the agents. A new problem instance is generated by resetting the environment. The problem instance contains coordinates for each node sampled from a uniform distribution inside the map boundries, and each node (except for depot) has a specific demand which is an integer value sampled from a uniform distribution between 1 and the maximum demand. The number of nodes with demand is a parameter of the environment.","title":"Multi Agent Capacitated Vehicle Routing Problem - MultiCVRP Environment"},{"location":"environments/multi_cvrp/#observation","text":"Each agent receives information on the coordinates, demands, time windows and penalty coefficients of all the customer nodes. Futhermore the agents receive positions, local times and vehicle capacity information on all vehicles. Lastly an action mask is also provided to each agent. node_coordinates : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the coordinates of each customer node and the depot node. node_demands : jax array (int16) of shape (num_vehicles, num_customers + 1,) , shows an array of the demands of each city node (and depot node where the demand is set to 0). node_time_windows : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows an array of the early and late time cutoffs for each customer. node_penalty_coefs : jax array (float32) of shape (num_vehicles, num_customers + 1, 2) , shows the early and late penalty coefficients for arriving early or late at a customer's location. other_vehicles_position : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the positions of all other vehicles. other_vehicles_local_times : jax array (float32) of shape (num_vehicles, num_vehicles - 1) , shows the local times of all other vehicles. other_vehicles_capacities : jax array (int16) of shape (num_vehicles, num_vehicles - 1) , shows the capacities of all other vehicles. vehicle_position : jax array (int16) of shape (num_vehicles) , shows the positions of the vehicles controlled by the agents. vehicle_local_time : jax array (float32) of shape (num_vehicles) , shows the local times of the vehicles controlled by the agents. vehicle_capacity : jax array (int16) of shape (num_vehicles) , shows the capacity of the vehicles controlled by the agents. action_mask : jax array (bool) of shape (num_vehicles, num_customers + 1,) , denoting which actions are possible (True) and which are not (False).","title":"Observation"},{"location":"environments/multi_cvrp/#action","text":"Each agent's action space is a BoundedArray of integer values in the range of [0, num_customers] . An action is the index of the next node to visit, and an action value of 0 corresponds to visiting the depot.","title":"Action"},{"location":"environments/multi_cvrp/#reward","text":"Dense : The reward is equal to the sum of negative distances of the current location and next location of all the vehicles. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred. Sparse : The reward is 0 at every step but the last, where the reward is the negative of the length of the path chosen by all the agents combined. Time penalities are added if the agents arrived early or late to specific customers. If the max step limit is reached, the episode ends with a large negative reward which is equal to the maximum negative distance reward that can be incurred.","title":"Reward"},{"location":"environments/multi_cvrp/#registered-versions","text":"MultiCVRP-v0 : MultiCVRP problem with 20 customers (randomly generated), maximum capacity of 20, and maximum demand of 10 with two vehicles.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/robot_warehouse/","text":"RobotWarehouse Environment # We provide a JAX jit-able implementation of the Robotic Warehouse environment. The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. Real-world applications inspire the simulator, in which robots pick up shelves and deliver them to a workstation. Humans access the content of a shelf, and then robots can return them to empty shelf locations. The goal is to successfully deliver as many requested shelves in a given time budget. Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse. Observation # The observation seen by the agent is a NamedTuple containing the following: agents_view : jax array (int32) of shape (num_agents, num_obs_features) , array representing the agent's view of other agents and shelves. action_mask : jax array (bool) of shape (num_agents, 5) , array specifying, for each agent, which action (noop, forward, left, right, toggle_load) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode. Action # The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3, 4] for each agent. Each agent can take one of five actions: noop ( 0 ), forward ( 1 ), turn left ( 2 ), turn right ( 3 ), or toggle_load ( 4 ). The episode terminates under the following conditions: An invalid action is taken, or An agent collides with another agent. Reward # The reward is global and shared among the agents. It is equal to the number of shelves which were delivered successfully during the time step (i.e., +1 for each shelf). Registered Versions \ud83d\udcd6 # RobotWarehouse-v0 , a warehouse with 4 agents each with a sensor range of 1, a warehouse floor with 2 shelf rows, 3 shelf columns, a column height of 8, and a shelf request queue of 8.","title":"RobotWarehouse"},{"location":"environments/robot_warehouse/#robotwarehouse-environment","text":"We provide a JAX jit-able implementation of the Robotic Warehouse environment. The Robot Warehouse (RWARE) environment simulates a warehouse with robots moving and delivering requested goods. Real-world applications inspire the simulator, in which robots pick up shelves and deliver them to a workstation. Humans access the content of a shelf, and then robots can return them to empty shelf locations. The goal is to successfully deliver as many requested shelves in a given time budget. Once a shelf has been delivered, a new shelf is requested at random. Agents start each episode at random locations within the warehouse.","title":"RobotWarehouse Environment"},{"location":"environments/robot_warehouse/#observation","text":"The observation seen by the agent is a NamedTuple containing the following: agents_view : jax array (int32) of shape (num_agents, num_obs_features) , array representing the agent's view of other agents and shelves. action_mask : jax array (bool) of shape (num_agents, 5) , array specifying, for each agent, which action (noop, forward, left, right, toggle_load) is legal. step_count : jax array (int32) of shape () , number of steps elapsed in the current episode.","title":"Observation"},{"location":"environments/robot_warehouse/#action","text":"The action space is a MultiDiscreteArray containing an integer value in [0, 1, 2, 3, 4] for each agent. Each agent can take one of five actions: noop ( 0 ), forward ( 1 ), turn left ( 2 ), turn right ( 3 ), or toggle_load ( 4 ). The episode terminates under the following conditions: An invalid action is taken, or An agent collides with another agent.","title":"Action"},{"location":"environments/robot_warehouse/#reward","text":"The reward is global and shared among the agents. It is equal to the number of shelves which were delivered successfully during the time step (i.e., +1 for each shelf).","title":"Reward"},{"location":"environments/robot_warehouse/#registered-versions","text":"RobotWarehouse-v0 , a warehouse with 4 agents each with a sensor range of 1, a warehouse floor with 2 shelf rows, 3 shelf columns, a column height of 8, and a shelf request queue of 8.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/rubiks_cube/","text":"Rubik's Cube Environment # We provide here a Jax JIT-able implementation of the Rubik's cube . The environment contains an implementation of the classic 3x3x3 cube by default, and configurably other sizes. The goal of the agent is to match all stickers on each face to a single colour. On resetting the environment the cube will be randomly scrambled with a configurable number of turns (by default 100). Observation # The observation given to the agent gives a view of the current state of the cube, cube : jax array (int8) of shape (6, cube_size, cube_size) whose values are in [0, 1, 2, 3, 4, 5] (corresponding to the different sticker colors). The indices of the array specify the sticker position - first the face (in the order up , front , right , back , left , down ) and then the row and column. Note that the orientation of each face is as follows: UP: LEFT face on the left and BACK face pointing up FRONT: LEFT face on the left and UP face pointing up RIGHT: FRONT face on the left and UP face pointing up BACK: RIGHT face on the left and UP face pointing up LEFT: BACK face on the left and UP face pointing up DOWN: LEFT face on the left and FRONT face pointing up step_count : jax array (int32) of shape () , representing the number of steps in the episode thus far. Action # The action space is a MultiDiscreteArray , specifically a tuple of an index between 0 and 5 (since there are 6 faces), an index between 0 and cube_size//2 (the number of possible depths), and an index between 0 and 2 (3 possible directions). An action thus consists of three pieces of information: Face to turn, Depth of the turn (possible depths are between 0 representing the outer layer and cube_size//2 representing the layer closest to the middle), Direction of turn (possible directions are clockwise, anti-clockwise, or a half turn). Reward # The reward function is configurable, but by default is the fully sparse reward giving +1 for solving the cube and otherwise 0 . The episode terminates if either the cube is solved or a configurable horizon (by default 200 ) is reached. Registered Versions \ud83d\udcd6 # RubiksCube-v0 , the standard Rubik's Cube puzzle with faces of size 3x3. RubiksCube-partly-scrambled-v0 , an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only 7 scrambles at reset time, making it technically maximum 7 actions away from the solution.","title":"RubiksCube"},{"location":"environments/rubiks_cube/#rubiks-cube-environment","text":"We provide here a Jax JIT-able implementation of the Rubik's cube . The environment contains an implementation of the classic 3x3x3 cube by default, and configurably other sizes. The goal of the agent is to match all stickers on each face to a single colour. On resetting the environment the cube will be randomly scrambled with a configurable number of turns (by default 100).","title":"Rubik's Cube Environment"},{"location":"environments/rubiks_cube/#observation","text":"The observation given to the agent gives a view of the current state of the cube, cube : jax array (int8) of shape (6, cube_size, cube_size) whose values are in [0, 1, 2, 3, 4, 5] (corresponding to the different sticker colors). The indices of the array specify the sticker position - first the face (in the order up , front , right , back , left , down ) and then the row and column. Note that the orientation of each face is as follows: UP: LEFT face on the left and BACK face pointing up FRONT: LEFT face on the left and UP face pointing up RIGHT: FRONT face on the left and UP face pointing up BACK: RIGHT face on the left and UP face pointing up LEFT: BACK face on the left and UP face pointing up DOWN: LEFT face on the left and FRONT face pointing up step_count : jax array (int32) of shape () , representing the number of steps in the episode thus far.","title":"Observation"},{"location":"environments/rubiks_cube/#action","text":"The action space is a MultiDiscreteArray , specifically a tuple of an index between 0 and 5 (since there are 6 faces), an index between 0 and cube_size//2 (the number of possible depths), and an index between 0 and 2 (3 possible directions). An action thus consists of three pieces of information: Face to turn, Depth of the turn (possible depths are between 0 representing the outer layer and cube_size//2 representing the layer closest to the middle), Direction of turn (possible directions are clockwise, anti-clockwise, or a half turn).","title":"Action"},{"location":"environments/rubiks_cube/#reward","text":"The reward function is configurable, but by default is the fully sparse reward giving +1 for solving the cube and otherwise 0 . The episode terminates if either the cube is solved or a configurable horizon (by default 200 ) is reached.","title":"Reward"},{"location":"environments/rubiks_cube/#registered-versions","text":"RubiksCube-v0 , the standard Rubik's Cube puzzle with faces of size 3x3. RubiksCube-partly-scrambled-v0 , an easier version of the standard Rubik's Cube puzzle with faces of size 3x3 yet only 7 scrambles at reset time, making it technically maximum 7 actions away from the solution.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/snake/","text":"Snake Environment \ud83d\udc0d # We provide here an implementation of the Snake environment from (Bonnet et al., 2021) . The goal of the agent is to navigate in a grid world (by default of size 12x12) to collect as many fruits as possible without colliding with its own body (i.e. looping on itself). Observation # grid : jax array (float) of shape (num_rows, num_cols, 5) , feature maps (image) that include information about the fruit, the snake head, its body and tail. step_count : jax array (int32) of shape () , current number of steps in the episode. action_mask : jax array (bool) of shape (4,) , array specifying which directions the snake can move in from its current position. Action # The action space is a DiscreteArray of integer values: [0,1,2,3] -> [Up, Right, Down, Left] . Reward # The reward is +1 upon collection of a fruit and 0 otherwise. Registered Versions \ud83d\udcd6 # Snake-v1 : Snake game on a board of size 12x12 with a time limit of 4000 .","title":"Snake"},{"location":"environments/snake/#snake-environment","text":"We provide here an implementation of the Snake environment from (Bonnet et al., 2021) . The goal of the agent is to navigate in a grid world (by default of size 12x12) to collect as many fruits as possible without colliding with its own body (i.e. looping on itself).","title":"Snake Environment \ud83d\udc0d"},{"location":"environments/snake/#observation","text":"grid : jax array (float) of shape (num_rows, num_cols, 5) , feature maps (image) that include information about the fruit, the snake head, its body and tail. step_count : jax array (int32) of shape () , current number of steps in the episode. action_mask : jax array (bool) of shape (4,) , array specifying which directions the snake can move in from its current position.","title":"Observation"},{"location":"environments/snake/#action","text":"The action space is a DiscreteArray of integer values: [0,1,2,3] -> [Up, Right, Down, Left] .","title":"Action"},{"location":"environments/snake/#reward","text":"The reward is +1 upon collection of a fruit and 0 otherwise.","title":"Reward"},{"location":"environments/snake/#registered-versions","text":"Snake-v1 : Snake game on a board of size 12x12 with a time limit of 4000 .","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/sudoku/","text":"Sudoku Environment # We provide here a Jax JIT-able implementation of the Sudoku puzzle game. Observation # The observation given to the agent consists of: board : jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask : jax array (bool) of shape (9,9,9): indicates which actions are valid. Action # The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore and the digits to write in the cell, e.g. [3, 6, 8] for writing the digit 9 in the cell located on the fourth row and seventh column. Reward # The reward is 1 at the end of the episode if the board is correctly solved, and 0 in every other case. Termination # An episode terminates when there are no more legal actions available, this could happen if the board is solved or if the agent finds itself in a dead-end. Registered Versions \ud83d\udcd6 # Sudoku-v0 , the classic game on a 9x9 grid, 10000 random puzzles with mixed difficulty are included by default. Sudoku-very-easy-v0 , the classic game on a 9x9 grid, only 1000 very-easy random puzzles (>46 clues) included by default. Using custom puzzle instances # If one wants to include its own database of puzzles, the DatabaseGenerator can be initialized with any collection of puzzles using the argument custom_boards . Some references for databases of puzzle of various difficulties: - https://www.kaggle.com/datasets/rohanrao/sudoku - https://www.kaggle.com/datasets/informoney/4-million-sudoku-puzzles-easytohard Difficulty level as a function of number of clues # Adapted from An Algorithm for Generating only Desired Permutations for Solving Sudoku Puzzle .","title":"Sudoku"},{"location":"environments/sudoku/#sudoku-environment","text":"We provide here a Jax JIT-able implementation of the Sudoku puzzle game.","title":"Sudoku Environment"},{"location":"environments/sudoku/#observation","text":"The observation given to the agent consists of: board : jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8. action_mask : jax array (bool) of shape (9,9,9): indicates which actions are valid.","title":"Observation"},{"location":"environments/sudoku/#action","text":"The action space is a MultiDiscreteArray of integer values representing coordinates of the square to explore and the digits to write in the cell, e.g. [3, 6, 8] for writing the digit 9 in the cell located on the fourth row and seventh column.","title":"Action"},{"location":"environments/sudoku/#reward","text":"The reward is 1 at the end of the episode if the board is correctly solved, and 0 in every other case.","title":"Reward"},{"location":"environments/sudoku/#termination","text":"An episode terminates when there are no more legal actions available, this could happen if the board is solved or if the agent finds itself in a dead-end.","title":"Termination"},{"location":"environments/sudoku/#registered-versions","text":"Sudoku-v0 , the classic game on a 9x9 grid, 10000 random puzzles with mixed difficulty are included by default. Sudoku-very-easy-v0 , the classic game on a 9x9 grid, only 1000 very-easy random puzzles (>46 clues) included by default.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/sudoku/#using-custom-puzzle-instances","text":"If one wants to include its own database of puzzles, the DatabaseGenerator can be initialized with any collection of puzzles using the argument custom_boards . Some references for databases of puzzle of various difficulties: - https://www.kaggle.com/datasets/rohanrao/sudoku - https://www.kaggle.com/datasets/informoney/4-million-sudoku-puzzles-easytohard","title":"Using custom puzzle instances"},{"location":"environments/sudoku/#difficulty-level-as-a-function-of-number-of-clues","text":"Adapted from An Algorithm for Generating only Desired Permutations for Solving Sudoku Puzzle .","title":"Difficulty level as a function of number of clues"},{"location":"environments/tetris/","text":"Tetris Environment # We provide here a Jax JIT-able implementation of the game Tetris. Tetris is a popular single-player game that is played on a 2D grid by fitting falling blocks of various Tetrominoes together to create horizontal lines without any gaps. As each line is completed, it disappears, and the player earns points. If the stack of blocks reaches the top of the game grid, the game ends. The objective of Tetris is to score as many points as possible before the game ends, by clearing as many lines as possible. Tetris consists of 7 types of Tetrominoes, which are shapes that represent the letters \"I\", \"O\", \"S\", \"Z\", \"L\", \"J\", and \"T\" as shown in the image below. Observation # The observation in Tetris includes information about the grid, the Tetromino and the action mask. grid : jax array (int32) of shape (num_rows, num_cols) , representing the current grid state. The grid is filled with zeros for the empty cells and with ones for the filled cells. Here is an example of a random observation of the grid: 1 2 3 4 5 6 7 8 9 [ [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0, 1, 1], [0, 1, 1, 1, 0, 1], [0, 1, 0, 1, 1, 1], [1, 1, 0, 1, 1, 1], ] tetromino : jax array (int32) of shape (4, 4) , where a value of 1 indicates a filled cell and a value of 0 indicates an empty cell. Here is an example of an I tetromino: 1 2 3 4 5 6 [ [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], ] action_mask : jax array (bool) of shape (4, num_cols) , indicating which actions are valid in the current state of the environment. Each row in the action mask corresponds to a Tetromino for a certain rotation (example: the first row for 0 degrees rotation, the second row for 90 degrees rotation, and so on). Here is an example of an action mask that corresponds to the same grid and the tetromino examples: 1 2 3 4 5 6 [ [ True, False, True, True, False, False], [ True, True, False, False, False, False], [ True, False, True, True, False, False], [ True, True, False, False, False, False], ] - step_count : jax array (int32) of shape () , integer to keep track of the number of steps. Action # The action space in Tetris is represented as a MultiDiscreteArray of two integer values. The first integer value corresponds to the selected X-position where the Tetromino will be placed, and the second integer value represents the index for the rotation degree. The rotation degree index can take four possible values: 0 for \"0 degrees\", 1 for \"90 degrees\", 2 for \"180 degrees\", and 3 for \"270 degrees\". For example, an action of [7, 2] means placing the Tetromino in the seventh column with a rotation of 180 degrees. Reward # Dense: the reward is based on the number of lines cleared and the reward_list [0, 40, 100, 300, 1200] . If no lines are cleared, the reward is 0. As the number of cleared lines increases, so does the reward, with the maximum reward of 1200 being awarded for clearing four lines at once. Registered Versions \ud83d\udcd6 # Tetris-v0 , the default settings for tetris with a grid of size 10x10.","title":"Tetris"},{"location":"environments/tetris/#tetris-environment","text":"We provide here a Jax JIT-able implementation of the game Tetris. Tetris is a popular single-player game that is played on a 2D grid by fitting falling blocks of various Tetrominoes together to create horizontal lines without any gaps. As each line is completed, it disappears, and the player earns points. If the stack of blocks reaches the top of the game grid, the game ends. The objective of Tetris is to score as many points as possible before the game ends, by clearing as many lines as possible. Tetris consists of 7 types of Tetrominoes, which are shapes that represent the letters \"I\", \"O\", \"S\", \"Z\", \"L\", \"J\", and \"T\" as shown in the image below.","title":"Tetris Environment"},{"location":"environments/tetris/#observation","text":"The observation in Tetris includes information about the grid, the Tetromino and the action mask. grid : jax array (int32) of shape (num_rows, num_cols) , representing the current grid state. The grid is filled with zeros for the empty cells and with ones for the filled cells. Here is an example of a random observation of the grid: 1 2 3 4 5 6 7 8 9 [ [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0, 1, 1], [0, 1, 1, 1, 0, 1], [0, 1, 0, 1, 1, 1], [1, 1, 0, 1, 1, 1], ] tetromino : jax array (int32) of shape (4, 4) , where a value of 1 indicates a filled cell and a value of 0 indicates an empty cell. Here is an example of an I tetromino: 1 2 3 4 5 6 [ [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], ] action_mask : jax array (bool) of shape (4, num_cols) , indicating which actions are valid in the current state of the environment. Each row in the action mask corresponds to a Tetromino for a certain rotation (example: the first row for 0 degrees rotation, the second row for 90 degrees rotation, and so on). Here is an example of an action mask that corresponds to the same grid and the tetromino examples: 1 2 3 4 5 6 [ [ True, False, True, True, False, False], [ True, True, False, False, False, False], [ True, False, True, True, False, False], [ True, True, False, False, False, False], ] - step_count : jax array (int32) of shape () , integer to keep track of the number of steps.","title":"Observation"},{"location":"environments/tetris/#action","text":"The action space in Tetris is represented as a MultiDiscreteArray of two integer values. The first integer value corresponds to the selected X-position where the Tetromino will be placed, and the second integer value represents the index for the rotation degree. The rotation degree index can take four possible values: 0 for \"0 degrees\", 1 for \"90 degrees\", 2 for \"180 degrees\", and 3 for \"270 degrees\". For example, an action of [7, 2] means placing the Tetromino in the seventh column with a rotation of 180 degrees.","title":"Action"},{"location":"environments/tetris/#reward","text":"Dense: the reward is based on the number of lines cleared and the reward_list [0, 40, 100, 300, 1200] . If no lines are cleared, the reward is 0. As the number of cleared lines increases, so does the reward, with the maximum reward of 1200 being awarded for clearing four lines at once.","title":"Reward"},{"location":"environments/tetris/#registered-versions","text":"Tetris-v0 , the default settings for tetris with a grid of size 10x10.","title":"Registered Versions \ud83d\udcd6"},{"location":"environments/tsp/","text":"Traveling Salesman Problem (TSP) Environment # We provide here a Jax JIT-able implementation of the traveling salesman problem (TSP) . TSP is a well-known combinatorial optimization problem. Given a set of cities and the distances between them, the goal is to determine the shortest route that visits each city exactly once and finishes in the starting city. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. When the environment is reset, a new problem instance is generated by sampling coordinates (a pair for each city) from a uniform distribution between 0 and 1. The number of cities is a parameter of the environment. A trajectory terminates when no new cities can be visited or the last action was invalid (i.e., the agent attempted to revisit a city). Observation # The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position (city) of the agent. coordinates : jax array (float) of shape (num_cities, 2) , array of coordinates of each city. position : jax array (int32) of shape () , identifier (index) of the last visited city. trajectory : jax array (int32) of shape (num_cities,) , city indices defining the route ( -1 --> not filled yet). action_mask : jax array (bool) of shape (num_cities,) , binary values denoting whether a city can be visited. Action # The action space is a DiscreteArray of integer values in the range of [0, num_cities-1] . An action is the index of the next city to visit. Reward # The reward could be either: Dense : the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again. Registered Versions \ud83d\udcd6 # TSP-v1 : TSP problem with 20 randomly generated cities and a dense reward function.","title":"TSP"},{"location":"environments/tsp/#traveling-salesman-problem-tsp-environment","text":"We provide here a Jax JIT-able implementation of the traveling salesman problem (TSP) . TSP is a well-known combinatorial optimization problem. Given a set of cities and the distances between them, the goal is to determine the shortest route that visits each city exactly once and finishes in the starting city. The problem is NP-complete, thus there is no known algorithm both correct and fast (i.e., that runs in polynomial time) for any instance of the problem. When the environment is reset, a new problem instance is generated by sampling coordinates (a pair for each city) from a uniform distribution between 0 and 1. The number of cities is a parameter of the environment. A trajectory terminates when no new cities can be visited or the last action was invalid (i.e., the agent attempted to revisit a city).","title":"Traveling Salesman Problem (TSP) Environment"},{"location":"environments/tsp/#observation","text":"The observation given to the agent provides information on the problem layout, the visited/unvisited cities and the current position (city) of the agent. coordinates : jax array (float) of shape (num_cities, 2) , array of coordinates of each city. position : jax array (int32) of shape () , identifier (index) of the last visited city. trajectory : jax array (int32) of shape (num_cities,) , city indices defining the route ( -1 --> not filled yet). action_mask : jax array (bool) of shape (num_cities,) , binary values denoting whether a city can be visited.","title":"Observation"},{"location":"environments/tsp/#action","text":"The action space is a DiscreteArray of integer values in the range of [0, num_cities-1] . An action is the index of the next city to visit.","title":"Action"},{"location":"environments/tsp/#reward","text":"The reward could be either: Dense : the negative distance between the current city and the chosen next city to go to. It is 0 for the first chosen city, and for the last city, it also includes the distance to the initial city to complete the tour. Sparse : the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities. It is computed by starting at the first city and ending there, after visiting all the cities. In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again.","title":"Reward"},{"location":"environments/tsp/#registered-versions","text":"TSP-v1 : TSP problem with 20 randomly generated cities and a dense reward function.","title":"Registered Versions \ud83d\udcd6"},{"location":"guides/advanced_usage/","text":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c # Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of this below, where we use jax.vmap and jax.lax.scan to generate a batch of rollouts in the Snake environment. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import jax import jumanji from jumanji.wrappers import AutoResetWrapper env = jumanji . make ( \"Snake-v1\" ) # Create a Snake environment env = AutoResetWrapper ( env ) # Automatically reset the environment when an episode terminates batch_size = 7 rollout_length = 5 num_actions = env . action_spec () . num_values random_key = jax . random . PRNGKey ( 0 ) key1 , key2 = jax . random . split ( random_key ) def step_fn ( state , key ): action = jax . random . randint ( key = key , minval = 0 , maxval = num_actions , shape = ()) new_state , timestep = env . step ( state , action ) return new_state , timestep def run_n_steps ( state , key , n ): random_keys = jax . random . split ( key , n ) state , rollout = jax . lax . scan ( step_fn , state , random_keys ) return rollout # Instantiate a batch of environment states keys = jax . random . split ( key1 , batch_size ) state , timestep = jax . vmap ( env . reset )( keys ) # Collect a batch of rollouts keys = jax . random . split ( key2 , batch_size ) rollout = jax . vmap ( run_n_steps , in_axes = ( 0 , 0 , None ))( state , keys , rollout_length ) # Shape and type of given rollout: # TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)","title":"Advanced Usage"},{"location":"guides/advanced_usage/#advanced-usage","text":"Being written in JAX, Jumanji's environments benefit from many of its features including automatic vectorization/parallelization ( jax.vmap , jax.pmap ) and JIT-compilation ( jax.jit ), which can be composed arbitrarily. We provide an example of this below, where we use jax.vmap and jax.lax.scan to generate a batch of rollouts in the Snake environment. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import jax import jumanji from jumanji.wrappers import AutoResetWrapper env = jumanji . make ( \"Snake-v1\" ) # Create a Snake environment env = AutoResetWrapper ( env ) # Automatically reset the environment when an episode terminates batch_size = 7 rollout_length = 5 num_actions = env . action_spec () . num_values random_key = jax . random . PRNGKey ( 0 ) key1 , key2 = jax . random . split ( random_key ) def step_fn ( state , key ): action = jax . random . randint ( key = key , minval = 0 , maxval = num_actions , shape = ()) new_state , timestep = env . step ( state , action ) return new_state , timestep def run_n_steps ( state , key , n ): random_keys = jax . random . split ( key , n ) state , rollout = jax . lax . scan ( step_fn , state , random_keys ) return rollout # Instantiate a batch of environment states keys = jax . random . split ( key1 , batch_size ) state , timestep = jax . vmap ( env . reset )( keys ) # Collect a batch of rollouts keys = jax . random . split ( key2 , batch_size ) rollout = jax . vmap ( run_n_steps , in_axes = ( 0 , 0 , None ))( state , keys , rollout_length ) # Shape and type of given rollout: # TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)","title":"Advanced Usage \ud83e\uddd1\u200d\ud83d\udd2c"},{"location":"guides/registration/","text":"Environment Registry # Jumanji adopts the convention defined in Gym of having an environment registry and a make function to instantiate environments. Create an environment # To instantiate a Jumanji registered environment, we provide the convenient function jumanji.make . It can be used as follows: 1 2 3 4 5 6 import jax import jumanji env = jumanji . make ( 'BinPack-v1' ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) The environment ID is composed of two parts, the environment name and its version. To get the full list of registered environments, you can use the registered_environments util. \u26a0\ufe0f Warning 1 2 3 4 Users can provide additional key-word arguments in the call to `jumanji.make(env_id, ...)`. These are then passed to the class constructor. Because they can be used to overwrite the intended configuration of the environment when registered, we discourage users to do so. However, we are mindful of particular use cases that might require this flexibility. Although the make function provides a unified way to instantiate environments, users can always instantiate them by importing the corresponding environment class. Register your environment # In addition to the environments available in Jumanji, users can register their custom environment and access them through the familiar jumanji.make function. Assuming you created an environment by subclassing Jumanji Environment base class, you can register it as follows: 1 2 3 4 5 6 7 from jumanji import register register ( id = \"CustomEnv-v0\" , # format: (env_name)-v(version) entry_point = \"path.to.your.package:CustomEnv\" , # class constructor kwargs = { ... }, # environment configuration ) To successfully register your environment, make sure to provide the right path to your class constructor. The kwargs argument is there to configurate the environment and allow you to register scenarios with a specific set of arguments. The environment ID must respect the format (EnvName)-v(version) , where the version number starts at v0 . For examples on how to register environments, please see our jumanji/__init__.py file. 1 Note that Jumanji doesn't allow users to overwrite the registration of an existing environment. To verify that your custom environment has been registered correctly, you can inspect the listing of registered environments using the registered_environments util.","title":"Registration"},{"location":"guides/registration/#environment-registry","text":"Jumanji adopts the convention defined in Gym of having an environment registry and a make function to instantiate environments.","title":"Environment Registry"},{"location":"guides/registration/#create-an-environment","text":"To instantiate a Jumanji registered environment, we provide the convenient function jumanji.make . It can be used as follows: 1 2 3 4 5 6 import jax import jumanji env = jumanji . make ( 'BinPack-v1' ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) The environment ID is composed of two parts, the environment name and its version. To get the full list of registered environments, you can use the registered_environments util. \u26a0\ufe0f Warning 1 2 3 4 Users can provide additional key-word arguments in the call to `jumanji.make(env_id, ...)`. These are then passed to the class constructor. Because they can be used to overwrite the intended configuration of the environment when registered, we discourage users to do so. However, we are mindful of particular use cases that might require this flexibility. Although the make function provides a unified way to instantiate environments, users can always instantiate them by importing the corresponding environment class.","title":"Create an environment"},{"location":"guides/registration/#register-your-environment","text":"In addition to the environments available in Jumanji, users can register their custom environment and access them through the familiar jumanji.make function. Assuming you created an environment by subclassing Jumanji Environment base class, you can register it as follows: 1 2 3 4 5 6 7 from jumanji import register register ( id = \"CustomEnv-v0\" , # format: (env_name)-v(version) entry_point = \"path.to.your.package:CustomEnv\" , # class constructor kwargs = { ... }, # environment configuration ) To successfully register your environment, make sure to provide the right path to your class constructor. The kwargs argument is there to configurate the environment and allow you to register scenarios with a specific set of arguments. The environment ID must respect the format (EnvName)-v(version) , where the version number starts at v0 . For examples on how to register environments, please see our jumanji/__init__.py file. 1 Note that Jumanji doesn't allow users to overwrite the registration of an existing environment. To verify that your custom environment has been registered correctly, you can inspect the listing of registered environments using the registered_environments util.","title":"Register your environment"},{"location":"guides/training/","text":"Training # Jumanji provides a training script train.py to train an online agent on a specified Jumanji environment given an environment-specific network. Agents # Jumanji provides two example agents in jumanji/training/agents/ to get you started with training on Jumanji environments: Random agent: uses the action mask to randomly sample valid actions. A2C agent: online advantage actor-critic agent that follows from [Mnih et al., 2016] . Configuration # In each environment-specific config YAML file, you will see a \"training\" section like below: 1 2 3 4 5 training : num_epochs : 1000 num_learner_steps_per_epoch : 50 n_steps : 20 total_batch_size : 64 Here, num_epochs corresponds to the number of data points in your plots. An epoch can be thought as an iteration. num_learner_steps_per_epoch is the number of learner steps that happen in each epoch. After every learner step, the A2C agent's policy is updated. n_steps is the sequence length (consecutive environment steps in a batch). total_batch_size is the number of environments that are run in parallel. So in the above example, 64 environments are running in parallel. Each of these 64 environments run 20 environment steps. After this, the agent's policy is updated via SGD. This constitutes a single learner step. 50 such learner steps are done for the epoch in question. After this, evaluation is done using the updated policy. The above procedure is done for 1000 epochs. Evaluation # Two types of evaluation are recorded: Stochastic evaluation (same policy used during training) Greedy evaluation (argmax over the action logits)","title":"Training"},{"location":"guides/training/#training","text":"Jumanji provides a training script train.py to train an online agent on a specified Jumanji environment given an environment-specific network.","title":"Training"},{"location":"guides/training/#agents","text":"Jumanji provides two example agents in jumanji/training/agents/ to get you started with training on Jumanji environments: Random agent: uses the action mask to randomly sample valid actions. A2C agent: online advantage actor-critic agent that follows from [Mnih et al., 2016] .","title":"Agents"},{"location":"guides/training/#configuration","text":"In each environment-specific config YAML file, you will see a \"training\" section like below: 1 2 3 4 5 training : num_epochs : 1000 num_learner_steps_per_epoch : 50 n_steps : 20 total_batch_size : 64 Here, num_epochs corresponds to the number of data points in your plots. An epoch can be thought as an iteration. num_learner_steps_per_epoch is the number of learner steps that happen in each epoch. After every learner step, the A2C agent's policy is updated. n_steps is the sequence length (consecutive environment steps in a batch). total_batch_size is the number of environments that are run in parallel. So in the above example, 64 environments are running in parallel. Each of these 64 environments run 20 environment steps. After this, the agent's policy is updated via SGD. This constitutes a single learner step. 50 such learner steps are done for the epoch in question. After this, evaluation is done using the updated policy. The above procedure is done for 1000 epochs.","title":"Configuration"},{"location":"guides/training/#evaluation","text":"Two types of evaluation are recorded: Stochastic evaluation (same policy used during training) Greedy evaluation (argmax over the action logits)","title":"Evaluation"},{"location":"guides/wrappers/","text":"Wrappers # The Wrapper interface is used for extending Jumanji Environment to add features like auto reset and vectorised environments. Jumanji provides wrappers to convert a Jumanji Environment to a DeepMind or Gym environment. Jumanji to DeepMind Environment # We can also convert our Jumanji environments to a DeepMind environment: 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) dm_env = jumanji . wrappers . JumanjiToDMEnvWrapper ( env ) timestep = dm_env . reset () action = dm_env . action_spec () . generate_value () next_timestep = dm_env . step ( action ) ... Jumanji To Gym # We can also convert our Jumanji environments to a Gym environment! Below is an example of how to convert a Jumanji environment into a Gym environment. 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) gym_env = jumanji . wrappers . JumanjiToGymWrapper ( env ) obs = gym_env . reset () action = gym_env . action_space . sample () observation , reward , done , extra = gym_env . step ( action ) ... Auto-reset an Environment # Below is an example of how to extend the functionality of the Snake environment to automatically reset whenever the environment reaches a terminal state. The Snake game terminates when the snake hits the wall, using the AutoResetWrapper the environment will be reset once a terminal state has been reached. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import jax.random import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) env = jumanji . wrappers . AutoResetWrapper ( env ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) print ( \"New episode\" ) for i in range ( 100 ): action = env . action_spec () . generate_value () # Returns jnp.array(0) when using Snake. state , timestep = env . step ( state , action ) if timestep . first (): print ( \"New episode\" )","title":"Wrapper"},{"location":"guides/wrappers/#wrappers","text":"The Wrapper interface is used for extending Jumanji Environment to add features like auto reset and vectorised environments. Jumanji provides wrappers to convert a Jumanji Environment to a DeepMind or Gym environment.","title":"Wrappers"},{"location":"guides/wrappers/#jumanji-to-deepmind-environment","text":"We can also convert our Jumanji environments to a DeepMind environment: 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) dm_env = jumanji . wrappers . JumanjiToDMEnvWrapper ( env ) timestep = dm_env . reset () action = dm_env . action_spec () . generate_value () next_timestep = dm_env . step ( action ) ...","title":"Jumanji to DeepMind Environment"},{"location":"guides/wrappers/#jumanji-to-gym","text":"We can also convert our Jumanji environments to a Gym environment! Below is an example of how to convert a Jumanji environment into a Gym environment. 1 2 3 4 5 6 7 8 9 import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) gym_env = jumanji . wrappers . JumanjiToGymWrapper ( env ) obs = gym_env . reset () action = gym_env . action_space . sample () observation , reward , done , extra = gym_env . step ( action ) ...","title":"Jumanji To Gym"},{"location":"guides/wrappers/#auto-reset-an-environment","text":"Below is an example of how to extend the functionality of the Snake environment to automatically reset whenever the environment reaches a terminal state. The Snake game terminates when the snake hits the wall, using the AutoResetWrapper the environment will be reset once a terminal state has been reached. 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 import jax.random import jumanji.wrappers env = jumanji . make ( \"Snake-6x6-v0\" ) env = jumanji . wrappers . AutoResetWrapper ( env ) key = jax . random . PRNGKey ( 0 ) state , timestep = env . reset ( key ) print ( \"New episode\" ) for i in range ( 100 ): action = env . action_spec () . generate_value () # Returns jnp.array(0) when using Snake. state , timestep = env . step ( state , action ) if timestep . first (): print ( \"New episode\" )","title":"Auto-reset an Environment"}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 03a2b75ec..8aed7fe6b 100644 Binary files a/sitemap.xml.gz and b/sitemap.xml.gz differ