diff --git a/src/quaxed/numpy/_creation_functions.py b/src/quaxed/numpy/_creation_functions.py index debd91c..1a8c762 100644 --- a/src/quaxed/numpy/_creation_functions.py +++ b/src/quaxed/numpy/_creation_functions.py @@ -21,19 +21,18 @@ import jax import jax.numpy as jnp from jaxtyping import ArrayLike +from plum import dispatch from quax import Value from quaxed._types import DType from quaxed._utils import quaxify -from ._dispatch import dispatcher - T = TypeVar("T") # ============================================================================= -@dispatcher +@dispatch def arange( start: ArrayLike, stop: ArrayLike | None, @@ -45,7 +44,7 @@ def arange( return jnp.arange(start, stop, step, dtype=dtype) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def arange( start: ArrayLike, stop: ArrayLike | None, @@ -58,7 +57,7 @@ def arange( return arange(start, stop, step, dtype=dtype) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def arange( start: ArrayLike, /, @@ -71,7 +70,7 @@ def arange( return arange(start, stop, step, dtype=dtype) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def arange( *, start: ArrayLike, @@ -101,7 +100,7 @@ def asarray( # ============================================================================= -@dispatcher # type: ignore[misc] +@dispatch # type: ignore[misc] def empty_like( prototype: ArrayLike, /, @@ -115,7 +114,7 @@ def empty_like( # ============================================================================= -@dispatcher +@dispatch def full( shape: tuple[int, ...] | int, fill_value: ArrayLike, @@ -125,7 +124,7 @@ def full( return jnp.full(shape, fill_value, dtype=dtype) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def full( shape: tuple[int, ...] | int, *, @@ -138,7 +137,7 @@ def full( # ============================================================================= -@dispatcher +@dispatch def full_like( x: ArrayLike, /, @@ -150,7 +149,7 @@ def full_like( return jnp.full_like(x, fill_value, dtype=dtype, shape=shape) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def full_like( x: ArrayLike, *, @@ -167,7 +166,7 @@ def full_like( # ============================================================================= -@dispatcher +@dispatch def linspace( # noqa: PLR0913 start: ArrayLike, stop: ArrayLike, @@ -184,7 +183,7 @@ def linspace( # noqa: PLR0913 ) -@dispatcher # type: ignore[no-redef] +@dispatch # type: ignore[no-redef] def linspace( # noqa: PLR0913 start: ArrayLike, stop: ArrayLike, @@ -215,7 +214,7 @@ def meshgrid( # ============================================================================= -@dispatcher # type: ignore[misc] +@dispatch # type: ignore[misc] def ones_like( x: ArrayLike, /, *, dtype: DType | None = None, shape: tuple[int, ...] | None = None ) -> ArrayLike: @@ -244,7 +243,7 @@ def triu(x: ArrayLike, /, *, k: int = 0) -> ArrayLike: # @partial(jax.jit, static_argnames=("dtype", "device")) # @quaxify -@dispatcher # type: ignore[misc] +@dispatch # type: ignore[misc] def zeros_like( x: ArrayLike, /,