From 440fdf6e8b86f7b77b957f76f7e7ad9f9b147b41 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sun, 31 Dec 2023 12:10:38 -0800 Subject: [PATCH] use `jax.experimental.array_api` (#9) * use jax.experimental.array_api * disable beartype checking (until we can get stuff to pass) Signed-off-by: nstarman Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 3 +- pyproject.toml | 22 +- src/array_api_jax_compat/__init__.py | 64 +- src/array_api_jax_compat/_constants.py | 2 +- .../_creation_functions.py | 175 +- .../_data_type_functions.py | 32 +- .../_elementwise_functions.py | 144 +- .../_indexing_functions.py | 7 +- .../_linear_algebra_functions.py | 11 +- .../_manipulation_functions.py | 54 +- .../_searching_functions.py | 10 +- src/array_api_jax_compat/_set_functions.py | 10 +- .../_sorting_functions.py | 12 +- .../_statistical_functions.py | 30 +- src/array_api_jax_compat/_types.py | 2 + .../_utility_functions.py | 6 +- src/array_api_jax_compat/fft.py | 33 +- src/array_api_jax_compat/linalg.py | 67 +- tests/myarray.py | 1472 +++++++++++++++++ tests/test_jax.py | 57 + tests/test_myarray.py | 1422 ++++++++++++++++ 21 files changed, 3299 insertions(+), 336 deletions(-) create mode 100644 tests/myarray.py create mode 100644 tests/test_jax.py create mode 100644 tests/test_myarray.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea02c3d..4b50bce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,7 @@ repos: - id: mixed-line-ending - id: name-tests-test args: ["--pytest-test-first"] + exclude: tests/myarray.py - id: requirements-txt-fixer - id: trailing-whitespace @@ -55,7 +56,7 @@ repos: rev: "v1.8.0" hooks: - id: mypy - files: src|tests + files: src args: [] additional_dependencies: - numpy diff --git a/pyproject.toml b/pyproject.toml index fb3da3b..ba8a376 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,8 @@ addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] xfail_strict = true filterwarnings = [ "error", + # jaxtyping + "ignore:ast\\.Str is deprecated and will be removed in Python 3.14:DeprecationWarning", ] log_cli_level = "INFO" testpaths = [ @@ -90,7 +92,7 @@ port.exclude_lines = [ ] [tool.mypy] -files = ["src", "tests"] +files = ["src"] python_version = "3.10" warn_unused_configs = true strict = true @@ -111,6 +113,7 @@ plugins = [ [[tool.mypy.overrides]] module = [ "jax.*", + "jaxtyping.*", "plum.*", "quax.*", ] @@ -124,23 +127,24 @@ src = ["src"] [tool.ruff.lint] extend-select = ["ALL"] ignore = [ - "A001", # Variable is shadowing a Python builtin - "A002", # Argument is shadowing a Python builtin + "A001", # Variable is shadowing a Python builtin + "A002", # Argument is shadowing a Python builtin "ANN101", # Missing type annotation for self in method "ANN401", # Dynamically typed expressions (typing.Any) are disallowed # TODO "ARG001", # Unused function argument - "D103", # Missing docstring in public function # TODO - "D203", # one-blank-line-before-class - "D213", # Multi-line docstring summary should start at the second line + "D103", # Missing docstring in public function # TODO + "D203", # one-blank-line-before-class + "D213", # Multi-line docstring summary should start at the second line "ERA001", # Found commented-out code "FIX002", # Line contains TODO, consider resolving the issue + "PD011", # Pandas "PYI041", # Use `float` instead of `int | float` - "TD002", # Missing author in TODO; try: `# TODO(): . - "TD003", # Missing issue link on the line following this TODO + "TD002", # Missing author in TODO; try: `# TODO(): . + "TD003", # Missing issue link on the line following this TODO ] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["INP001", "S101", "T20"] +"tests/**" = ["ANN", "INP001", "PLR0913", "S101", "T20"] "__init__.py" = ["F403"] "noxfile.py" = ["T20"] "docs/conf.py" = ["INP001"] diff --git a/src/array_api_jax_compat/__init__.py b/src/array_api_jax_compat/__init__.py index ede8412..9e24447 100644 --- a/src/array_api_jax_compat/__init__.py +++ b/src/array_api_jax_compat/__init__.py @@ -10,37 +10,41 @@ from typing import Any -from . import ( - _constants, - _creation_functions, - _data_type_functions, - _elementwise_functions, - _indexing_functions, - _linear_algebra_functions, - _manipulation_functions, - _searching_functions, - _set_functions, - _sorting_functions, - _statistical_functions, - _utility_functions, - fft, - linalg, -) -from ._constants import * -from ._creation_functions import * -from ._data_type_functions import * -from ._elementwise_functions import * -from ._indexing_functions import * -from ._linear_algebra_functions import * -from ._manipulation_functions import * -from ._searching_functions import * -from ._set_functions import * -from ._sorting_functions import * -from ._statistical_functions import * -from ._utility_functions import * -from ._version import version as __version__ +from jax.experimental.array_api import __array_api_version__ +from jaxtyping import install_import_hook -__all__ = ["__version__", "fft", "linalg"] +with install_import_hook("array_api_jax_compat", None): + from . import ( + _constants, + _creation_functions, + _data_type_functions, + _elementwise_functions, + _indexing_functions, + _linear_algebra_functions, + _manipulation_functions, + _searching_functions, + _set_functions, + _sorting_functions, + _statistical_functions, + _utility_functions, + fft, + linalg, + ) + from ._constants import * + from ._creation_functions import * + from ._data_type_functions import * + from ._elementwise_functions import * + from ._indexing_functions import * + from ._linear_algebra_functions import * + from ._manipulation_functions import * + from ._searching_functions import * + from ._set_functions import * + from ._sorting_functions import * + from ._statistical_functions import * + from ._utility_functions import * + from ._version import version as __version__ + +__all__ = ["__version__", "__array_api_version__", "fft", "linalg"] __all__ += _constants.__all__ __all__ += _creation_functions.__all__ __all__ += _data_type_functions.__all__ diff --git a/src/array_api_jax_compat/_constants.py b/src/array_api_jax_compat/_constants.py index 6fa113a..fbfe7a3 100644 --- a/src/array_api_jax_compat/_constants.py +++ b/src/array_api_jax_compat/_constants.py @@ -2,4 +2,4 @@ __all__ = ["e", "inf", "nan", "newaxis", "pi"] -from jax.numpy import e, inf, nan, newaxis, pi +from jax.experimental.array_api import e, inf, nan, newaxis, pi diff --git a/src/array_api_jax_compat/_creation_functions.py b/src/array_api_jax_compat/_creation_functions.py index 0904754..e239068 100644 --- a/src/array_api_jax_compat/_creation_functions.py +++ b/src/array_api_jax_compat/_creation_functions.py @@ -1,35 +1,36 @@ """Array API creation functions.""" __all__ = [ - # "arange", + "arange", "asarray", - # "empty", + "empty", "empty_like", - # "eye", - # "from_dlpack", - # "full", + "eye", + "from_dlpack", + "full", "full_like", - # "linspace", + "linspace", "meshgrid", - # "ones", + "ones", "ones_like", "tril", "triu", - # "zeros", + "zeros", "zeros_like", ] from functools import partial -from typing import Any, TypeVar +from typing import TypeVar import jax import jax.numpy as jnp from jax import Device +from jax.experimental import array_api from quax import Value from ._dispatch import dispatcher -from ._types import DType, NestedSequence, SupportsBufferProtocol +from ._types import DType from ._utils import quaxify T = TypeVar("T") @@ -37,26 +38,45 @@ # ============================================================================= +@dispatcher # type: ignore[misc] +def arange( + start: jax.Array | jax.core.Tracer, + /, + stop: jax.Array | None = None, + step: jax.Array | int = 1, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | jax.core.Tracer: + return array_api.arange(start, stop, step, dtype=dtype, device=device) + + +# ============================================================================= + + @partial(jax.jit, static_argnames=("dtype", "device", "copy")) @quaxify def asarray( - obj: Value - | bool - | int - | float - | complex - | NestedSequence[Any] - | SupportsBufferProtocol, + obj: Value, /, *, dtype: DType | None = None, device: Device | None = None, - copy: bool | None = None, # TODO: support # pylint: disable=unused-argument + copy: bool | None = None, ) -> Value: - out = jnp.asarray(obj, dtype=dtype) - return jax.device_put(out, device=device) - # TODO: jax.lax.cond is not yet supported by Quax. - # out = jax.lax.cond(bool(copy), lambda x: jax.lax.copy_p.bind(x), lambda x: x, out) + return array_api.asarray(obj, dtype=dtype, device=device, copy=copy) + + +# ============================================================================= + + +def empty( + shape: tuple[int, ...], + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array: + return array_api.empty(shape, dtype=dtype, device=device) # ============================================================================= @@ -72,8 +92,43 @@ def empty_like( dtype: DType | None = None, device: Device | None = None, ) -> jax.Array | jax.core.Tracer | Value: - out = jnp.empty_like(x, dtype=dtype) - return jax.device_put(out, device=device) + return array_api.empty_like(x, dtype=dtype, device=device) + + +# ============================================================================= + + +def eye( + n_rows: int, + n_cols: int | None = None, + /, + *, + k: int = 0, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array: + return array_api.eye(n_rows, n_cols, k=k, dtype=dtype, device=device) + + +# ============================================================================= + + +def from_dlpack(x: object, /) -> jax.Array: + return array_api.from_dlpack(x) + + +# ============================================================================= + + +@dispatcher # type: ignore[misc] +def full( + shape: tuple[int, ...], + fill_value: int | float | complex | bool | jax.Array, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | Value: + return array_api.full(shape, fill_value, dtype=dtype, device=device) # ============================================================================= @@ -85,13 +140,36 @@ def empty_like( def full_like( x: jax.Array | jax.core.Tracer | Value, /, - fill_value: bool | int | float | complex, + fill_value: bool | int | float | complex | jax.Array | Value, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array | jax.core.Tracer | Value: + return array_api.full_like(x, fill_value, dtype=dtype, device=device) + + +# ============================================================================= + + +@dispatcher # type: ignore[misc] +def linspace( # noqa: PLR0913 + start: int | float | complex | jax.Array, + stop: int | float | complex | jax.Array, + /, + num: int, *, dtype: DType | None = None, device: Device | None = None, + endpoint: bool = True, ) -> jax.Array | jax.core.Tracer | Value: - out = jnp.full_like(x, fill_value, dtype=dtype) - return jax.device_put(out, device=device) + return array_api.linspace( + start, + stop, + num, + dtype=dtype, + device=device, + endpoint=endpoint, + ) # ============================================================================= @@ -105,6 +183,18 @@ def meshgrid(*arrays: Value, indexing: str = "xy") -> list[Value]: # ============================================================================= +def ones( + shape: tuple[int, ...], + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array: + return array_api.ones(shape, dtype=dtype, device=device) + + +# ============================================================================= + + # @partial(jax.jit, static_argnames=("dtype", "device")) # @quaxify # TODO: quaxify won't work here because of how the function is defined. @dispatcher # type: ignore[misc] @@ -115,8 +205,7 @@ def ones_like( dtype: DType | None = None, device: Device | None = None, ) -> jax.Array | jax.core.Tracer | Value: - out = jnp.ones_like(x, dtype=dtype) - return jax.device_put(out, device=device) + return array_api.ones_like(x, dtype=dtype, device=device) # ============================================================================= @@ -125,7 +214,7 @@ def ones_like( # @partial(jax.jit, static_argnames=("k",)) @quaxify def tril(x: Value, /, *, k: int = 0) -> Value: - return jnp.tril(x, k=k) + return array_api.tril(x, k=k) # ============================================================================= @@ -134,7 +223,19 @@ def tril(x: Value, /, *, k: int = 0) -> Value: # @partial(jax.jit, static_argnames=("k",)) @quaxify def triu(x: Value, /, *, k: int = 0) -> Value: - return jnp.triu(x, k=k) + return array_api.triu(x, k=k) + + +# ============================================================================= + + +def zeros( + shape: tuple[int, ...], + *, + dtype: DType | None = None, + device: Device | None = None, +) -> jax.Array: + return array_api.zeros(shape, dtype=dtype, device=device) # ============================================================================= @@ -150,14 +251,4 @@ def zeros_like( dtype: DType | None = None, device: Device | None = None, ) -> Value | jax.core.Tracer | jax.Array: - out = jnp.zeros_like(x, dtype=dtype) - return jax.device_put(out, device=device) - - -# @dispatcher -# def zeros_like( -# x: quax.zero.Zero, /, *, dtype: DType | None = None, device: Device | None = None -# ) -> jnp.ndarray: -# out = jnp.zeros_like(x, dtype=dtype) -# out = jax.device_put(out, device=device) -# return out + return array_api.zeros_like(x, dtype=dtype, device=device) diff --git a/src/array_api_jax_compat/_data_type_functions.py b/src/array_api_jax_compat/_data_type_functions.py index 1080f76..6b9fa4f 100644 --- a/src/array_api_jax_compat/_data_type_functions.py +++ b/src/array_api_jax_compat/_data_type_functions.py @@ -1,8 +1,8 @@ __all__ = ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] -import jax -from jax import Device +from jax.experimental import array_api +from jax.experimental.array_api._data_type_functions import FInfo, IInfo from quax import Value from ._types import DType @@ -10,38 +10,28 @@ @quaxify -def astype( - x: Value, - dtype: DType, - /, - *, - copy: bool = True, # TODO: support # pylint: disable=unused-argument - device: Device | None = None, -) -> Value: - out = jax.lax.convert_element_type(x, dtype) - return jax.device_put(out, device=device) +def astype(x: Value, dtype: DType, /, *, copy: bool = True) -> Value: + return array_api.astype(x, dtype, copy=copy) @quaxify def can_cast(from_: DType | Value, to: DType, /) -> bool: - return jax.numpy.can_cast(from_, to) + return array_api.can_cast(from_, to) @quaxify -def finfo(type: DType | Value, /) -> jax.numpy.finfo: - return jax.numpy.finfo(type) +def finfo(type: DType | Value, /) -> FInfo: + return array_api.finfo(type) @quaxify -def iinfo(type: DType | Value, /) -> jax.numpy.iinfo: - return jax.numpy.iinfo(type) +def iinfo(type: DType | Value, /) -> IInfo: + return array_api.iinfo(type) -@quaxify -def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: - raise NotImplementedError +isdtype = quaxify(array_api.isdtype) @quaxify def result_type(*arrays_and_dtypes: Value | DType) -> DType: - return jax.numpy.result_type(*arrays_and_dtypes) + return array_api.result_type(*arrays_and_dtypes) diff --git a/src/array_api_jax_compat/_elementwise_functions.py b/src/array_api_jax_compat/_elementwise_functions.py index d7e2e04..68c68a1 100644 --- a/src/array_api_jax_compat/_elementwise_functions.py +++ b/src/array_api_jax_compat/_elementwise_functions.py @@ -16,7 +16,6 @@ "bitwise_xor", "ceil", "conj", - "copysign", "cos", "cosh", "divide", @@ -42,8 +41,6 @@ "logical_not", "logical_or", "logical_xor", - "maximum", - "minimum", "multiply", "negative", "not_equal", @@ -53,7 +50,6 @@ "remainder", "round", "sign", - "signbit", "sin", "sinh", "square", @@ -65,7 +61,7 @@ ] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -73,314 +69,294 @@ @quaxify def abs(x: Value, /) -> Value: - return jnp.abs(x) + return array_api.abs(x) @quaxify def acos(x: Value, /) -> Value: - return jnp.arccos(x) + return array_api.acos(x) @quaxify def acosh(x: Value, /) -> Value: - return jnp.arccosh(x) + return array_api.acosh(x) @quaxify def add(x1: Value, x2: Value, /) -> Value: - return jnp.add(x1, x2) + return array_api.add(x1, x2) @quaxify def asin(x: Value, /) -> Value: - return jnp.arcsin(x) + return array_api.asin(x) @quaxify def asinh(x: Value, /) -> Value: - return jnp.arcsinh(x) + return array_api.asinh(x) @quaxify def atan(x: Value, /) -> Value: - return jnp.arctan(x) + return array_api.atan(x) @quaxify def atan2(x1: Value, x2: Value, /) -> Value: - return jnp.arctan2(x1, x2) + return array_api.atan2(x1, x2) @quaxify def atanh(x: Value, /) -> Value: - return jnp.arctanh(x) + return array_api.atanh(x) @quaxify def bitwise_and(x1: Value, x2: Value, /) -> Value: - return jnp.bitwise_and(x1, x2) + return array_api.bitwise_and(x1, x2) @quaxify def bitwise_left_shift(x1: Value, x2: Value, /) -> Value: - return jnp.left_shift(x1, x2) + return array_api.bitwise_left_shift(x1, x2) @quaxify def bitwise_invert(x: Value, /) -> Value: - return jnp.bitwise_not(x) + return array_api.bitwise_invert(x) @quaxify def bitwise_or(x1: Value, x2: Value, /) -> Value: - return jnp.bitwise_or(x1, x2) + return array_api.bitwise_or(x1, x2) @quaxify def bitwise_right_shift(x1: Value, x2: Value, /) -> Value: - return jnp.right_shift(x1, x2) + return array_api.bitwise_right_shift(x1, x2) @quaxify def bitwise_xor(x1: Value, x2: Value, /) -> Value: - return jnp.bitwise_xor(x1, x2) + return array_api.bitwise_xor(x1, x2) @quaxify def ceil(x: Value, /) -> Value: - return jnp.ceil(x) + return array_api.ceil(x) @quaxify def conj(x: Value, /) -> Value: - return jnp.conj(x) - - -@quaxify -def copysign(x1: Value, x2: Value, /) -> Value: - return jnp.copysign(x1, x2) + return array_api.conj(x) @quaxify def cos(x: Value, /) -> Value: - return jnp.cos(x) + return array_api.cos(x) @quaxify def cosh(x: Value, /) -> Value: - return jnp.cosh(x) + return array_api.cosh(x) @quaxify def divide(x1: Value, x2: Value, /) -> Value: - return jnp.divide(x1, x2) + return array_api.divide(x1, x2) @quaxify def equal(x1: Value, x2: Value, /) -> Value: - return jnp.equal(x1, x2) + return array_api.equal(x1, x2) @quaxify def exp(x: Value, /) -> Value: - return jnp.exp(x) + return array_api.exp(x) @quaxify def expm1(x: Value, /) -> Value: - return jnp.expm1(x) + return array_api.expm1(x) @quaxify def floor(x: Value, /) -> Value: - return jnp.floor(x) + return array_api.floor(x) @quaxify def floor_divide(x1: Value, x2: Value, /) -> Value: - return jnp.floor_divide(x1, x2) + return array_api.floor_divide(x1, x2) @quaxify def greater(x1: Value, x2: Value, /) -> Value: - return jnp.greater(x1, x2) + return array_api.greater(x1, x2) @quaxify def greater_equal(x1: Value, x2: Value, /) -> Value: - return jnp.greater_equal(x1, x2) + return array_api.greater_equal(x1, x2) @quaxify def imag(x: Value, /) -> Value: - return jnp.imag(x) + return array_api.imag(x) @quaxify def isfinite(x: Value, /) -> Value: - return jnp.isfinite(x) + return array_api.isfinite(x) @quaxify def isinf(x: Value, /) -> Value: - return jnp.isinf(x) + return array_api.isinf(x) @quaxify def isnan(x: Value, /) -> Value: - return jnp.isnan(x) + return array_api.isnan(x) @quaxify def less(x1: Value, x2: Value, /) -> Value: - return jnp.less(x1, x2) + return array_api.less(x1, x2) @quaxify def less_equal(x1: Value, x2: Value, /) -> Value: - return jnp.less_equal(x1, x2) + return array_api.less_equal(x1, x2) @quaxify def log(x: Value, /) -> Value: - return jnp.log(x) + return array_api.log(x) @quaxify def log1p(x: Value, /) -> Value: - return jnp.log1p(x) + return array_api.log1p(x) @quaxify def log2(x: Value, /) -> Value: - return jnp.log2(x) + return array_api.log2(x) @quaxify def log10(x: Value, /) -> Value: - return jnp.log10(x) + return array_api.log10(x) @quaxify def logaddexp(x1: Value, x2: Value, /) -> Value: - return jnp.logaddexp(x1, x2) + return array_api.logaddexp(x1, x2) @quaxify def logical_and(x1: Value, x2: Value, /) -> Value: - return jnp.logical_and(x1, x2) + return array_api.logical_and(x1, x2) @quaxify def logical_not(x: Value, /) -> Value: - return jnp.logical_not(x) + return array_api.logical_not(x) @quaxify def logical_or(x1: Value, x2: Value, /) -> Value: - return jnp.logical_or(x1, x2) + return array_api.logical_or(x1, x2) @quaxify def logical_xor(x1: Value, x2: Value, /) -> Value: - return jnp.logical_xor(x1, x2) - - -@quaxify -def maximum(x1: Value, x2: Value, /) -> Value: - return jnp.maximum(x1, x2) - - -@quaxify -def minimum(x1: Value, x2: Value, /) -> Value: - return jnp.minimum(x1, x2) + return array_api.logical_xor(x1, x2) @quaxify def multiply(x1: Value, x2: Value, /) -> Value: - return jnp.multiply(x1, x2) + return array_api.multiply(x1, x2) @quaxify def negative(x: Value, /) -> Value: - return jnp.negative(x) + return array_api.negative(x) @quaxify def not_equal(x1: Value, x2: Value, /) -> Value: - return jnp.not_equal(x1, x2) + return array_api.not_equal(x1, x2) @quaxify def positive(x: Value, /) -> Value: - return jnp.positive(x) + return array_api.positive(x) @quaxify def pow(x1: Value, x2: Value, /) -> Value: - return jnp.power(x1, x2) + return array_api.pow(x1, x2) @quaxify def real(x: Value, /) -> Value: - return jnp.real(x) + return array_api.real(x) @quaxify def remainder(x1: Value, x2: Value, /) -> Value: - return jnp.remainder(x1, x2) + return array_api.remainder(x1, x2) @quaxify def round(x: Value, /) -> Value: - return jnp.round(x) + return array_api.round(x) @quaxify def sign(x: Value, /) -> Value: - return jnp.sign(x) - - -@quaxify -def signbit(x: Value, /) -> Value: - return jnp.signbit(x) + return array_api.sign(x) @quaxify def sin(x: Value, /) -> Value: - return jnp.sin(x) + return array_api.sin(x) @quaxify def sinh(x: Value, /) -> Value: - return jnp.sinh(x) + return array_api.sinh(x) @quaxify def square(x: Value, /) -> Value: - return jnp.square(x) + return array_api.square(x) @quaxify def sqrt(x: Value, /) -> Value: - return jnp.sqrt(x) + return array_api.sqrt(x) @quaxify def subtract(x1: Value, x2: Value, /) -> Value: - return jnp.subtract(x1, x2) + return array_api.subtract(x1, x2) @quaxify def tan(x: Value, /) -> Value: - return jnp.tan(x) + return array_api.tan(x) @quaxify def tanh(x: Value, /) -> Value: - return jnp.tanh(x) + return array_api.tanh(x) @quaxify def trunc(x: Value, /) -> Value: - return jnp.trunc(x) + return array_api.trunc(x) diff --git a/src/array_api_jax_compat/_indexing_functions.py b/src/array_api_jax_compat/_indexing_functions.py index f05b37d..4c0354d 100644 --- a/src/array_api_jax_compat/_indexing_functions.py +++ b/src/array_api_jax_compat/_indexing_functions.py @@ -1,8 +1,11 @@ __all__ = ["take"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value +from ._utils import quaxify + +@quaxify def take(x: Value, indices: Value, /, *, axis: int | None = None) -> Value: - return jnp.take(x, indices, axis=axis) + return array_api.take(x, indices, axis=axis) diff --git a/src/array_api_jax_compat/_linear_algebra_functions.py b/src/array_api_jax_compat/_linear_algebra_functions.py index 9c7366f..b2bc159 100644 --- a/src/array_api_jax_compat/_linear_algebra_functions.py +++ b/src/array_api_jax_compat/_linear_algebra_functions.py @@ -3,7 +3,7 @@ from collections.abc import Sequence -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -11,12 +11,12 @@ @quaxify def matmul(x1: Value, x2: Value, /) -> Value: - return jnp.matmul(x1, x2) + return array_api.matmul(x1, x2) @quaxify def matrix_transpose(x: Value, /) -> Value: - return jnp.transpose(x) + return array_api.matrix_transpose(x) @quaxify @@ -27,10 +27,9 @@ def tensordot( *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, ) -> Value: - return jnp.tensordot(x1, x2, axes=axes) + return array_api.tensordot(x1, x2, axes=axes) @quaxify def vecdot(x1: Value, x2: Value, /, *, axis: int = -1) -> Value: - del axis # TODO: support - return jnp.dot(x1, x2) + return array_api.vecdot(x1, x2, axis=axis) diff --git a/src/array_api_jax_compat/_manipulation_functions.py b/src/array_api_jax_compat/_manipulation_functions.py index 87dba5c..b9708ef 100644 --- a/src/array_api_jax_compat/_manipulation_functions.py +++ b/src/array_api_jax_compat/_manipulation_functions.py @@ -4,17 +4,14 @@ "concat", "expand_dims", "flip", - "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack", - "tile", - "unstack", ] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -22,12 +19,12 @@ @quaxify def broadcast_arrays(*arrays: Value) -> list[Value]: - return jnp.broadcast_arrays(*arrays) + return array_api.broadcast_arrays(*arrays) @quaxify def broadcast_to(x: Value, /, shape: tuple[int, ...]) -> Value: - return jnp.broadcast_to(x, shape) + return array_api.broadcast_to(x, shape) @quaxify @@ -37,72 +34,45 @@ def concat( *, axis: int | None = 0, ) -> Value: - return jnp.concatenate(arrays, axis=axis) + return array_api.concat(arrays, axis=axis) @quaxify def expand_dims(x: Value, /, *, axis: int = 0) -> Value: - return jnp.expand_dims(x, axis=axis) + return array_api.expand_dims(x, axis=axis) @quaxify def flip(x: Value, /, *, axis: int | tuple[int, ...] | None = None) -> Value: - return jnp.flip(x, axis=axis) - - -@quaxify -def moveaxis( - x: Value, - source: int | tuple[int, ...], - destination: int | tuple[int, ...], - /, -) -> Value: - return jnp.moveaxis(x, source, destination) + return array_api.flip(x, axis=axis) @quaxify def permute_dims(x: Value, /, axes: tuple[int, ...]) -> Value: - return jnp.transpose(x, axes) + return array_api.permute_dims(x, axes=axes) @quaxify def reshape(x: Value, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Value: - return jnp.reshape(x, shape, order="C" if copy else "K") + return array_api.reshape(x, shape, copy=copy) @quaxify def roll( x: Value, /, - shift: int | tuple[int, ...], + shift: int | tuple[int], *, axis: int | tuple[int, ...] | None = None, ) -> Value: - return jnp.roll(x, shift, axis=axis) + return array_api.roll(x, shift=shift, axis=axis) @quaxify def squeeze(x: Value, /, axis: int | tuple[int, ...]) -> Value: - return jnp.squeeze(x, axis=axis) + return array_api.squeeze(x, axis=axis) @quaxify def stack(arrays: tuple[Value, ...] | list[Value], /, *, axis: int = 0) -> Value: - return jnp.stack(arrays, axis=axis) - - -@quaxify -def tile(x: Value, repetitions: tuple[int, ...], /) -> Value: - return jnp.tile(x, repetitions) - - -@quaxify -def unstack( - x: Value, # TODO: support # pylint: disable=unused-argument - /, - *, - axis: int = 0, # TODO: support # pylint: disable=unused-argument -) -> tuple[Value, ...]: - msg = "not yet supported." - raise NotImplementedError(msg) - # return jnp.split(x, axis=axis) + return array_api.stack(arrays, axis=axis) diff --git a/src/array_api_jax_compat/_searching_functions.py b/src/array_api_jax_compat/_searching_functions.py index ff26aea..96362d5 100644 --- a/src/array_api_jax_compat/_searching_functions.py +++ b/src/array_api_jax_compat/_searching_functions.py @@ -1,7 +1,7 @@ __all__ = ["argmax", "argmin", "nonzero", "where"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -9,19 +9,19 @@ @quaxify def argmax(x: Value, /, *, axis: int | None = None, keepdims: bool = False) -> Value: - return jnp.argmax(x, axis=axis, keepdims=keepdims) + return array_api.argmax(x, axis=axis, keepdims=keepdims) @quaxify def argmin(x: Value, /, *, axis: int | None = None, keepdims: bool = False) -> Value: - return jnp.argmin(x, axis=axis, keepdims=keepdims) + return array_api.argmin(x, axis=axis, keepdims=keepdims) @quaxify def nonzero(x: Value, /) -> tuple[Value, ...]: - return jnp.nonzero(x) + return array_api.nonzero(x) @quaxify def where(condition: Value, x1: Value, x2: Value, /) -> Value: - return jnp.where(condition, x1, x2) + return array_api.where(condition, x1, x2) diff --git a/src/array_api_jax_compat/_set_functions.py b/src/array_api_jax_compat/_set_functions.py index 8e46b08..22c159b 100644 --- a/src/array_api_jax_compat/_set_functions.py +++ b/src/array_api_jax_compat/_set_functions.py @@ -1,7 +1,7 @@ __all__ = ["unique_all", "unique_counts", "unique_inverse", "unique_values"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -9,19 +9,19 @@ @quaxify def unique_all(x: Value, /) -> tuple[Value, Value, Value, Value]: - return jnp.unique(x, return_counts=True, return_index=True, return_inverse=True) + return array_api.unique_all(x) @quaxify def unique_counts(x: Value, /) -> tuple[Value, Value]: - return jnp.unique(x, return_counts=True) + return array_api.unique_counts(x) @quaxify def unique_inverse(x: Value, /) -> tuple[Value, Value]: - return jnp.unique(x, return_inverse=True) + return array_api.unique_inverse(x) @quaxify def unique_values(x: Value, /) -> Value: - return jnp.unique(x) + return array_api.unique_values(x) diff --git a/src/array_api_jax_compat/_sorting_functions.py b/src/array_api_jax_compat/_sorting_functions.py index ec9637b..96fb65b 100644 --- a/src/array_api_jax_compat/_sorting_functions.py +++ b/src/array_api_jax_compat/_sorting_functions.py @@ -1,7 +1,9 @@ +"""Sorting functions.""" + __all__ = ["argsort", "sort"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -13,10 +15,10 @@ def argsort( /, *, axis: int = -1, - descending: bool = False, # TODO: support # pylint: disable=unused-argument + descending: bool = False, stable: bool = True, ) -> Value: - return jnp.argsort(x, axis=axis, kind="stable" if stable else "quicksort") + return array_api.argsort(x, axis=axis, descending=descending, stable=stable) @quaxify @@ -25,7 +27,7 @@ def sort( /, *, axis: int = -1, - descending: bool = False, # TODO: support # pylint: disable=unused-argument + descending: bool = False, stable: bool = True, ) -> Value: - return jnp.sort(x, axis=axis, kind="stable" if stable else "quicksort") + return array_api.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/src/array_api_jax_compat/_statistical_functions.py b/src/array_api_jax_compat/_statistical_functions.py index a9c8163..6806d1e 100644 --- a/src/array_api_jax_compat/_statistical_functions.py +++ b/src/array_api_jax_compat/_statistical_functions.py @@ -1,25 +1,13 @@ -__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] +__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._types import DType from ._utils import quaxify -@quaxify -def cumulative_sum( - x: Value, - /, - *, - axis: int | None = None, - dtype: DType | None = None, - include_initial: bool = False, # TODO: support # pylint: disable=unused-argument -) -> Value: - return jnp.cumsum(x, axis=axis, dtype=dtype) - - @quaxify def max( # pylint: disable=redefined-builtin x: Value, @@ -28,7 +16,7 @@ def max( # pylint: disable=redefined-builtin axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Value: - return jnp.max(x, axis=axis, keepdims=keepdims) + return array_api.max(x, axis=axis, keepdims=keepdims) @quaxify @@ -39,7 +27,7 @@ def mean( axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Value: - return jnp.mean(x, axis=axis, keepdims=keepdims) + return array_api.mean(x, axis=axis, keepdims=keepdims) @quaxify @@ -50,7 +38,7 @@ def min( # pylint: disable=redefined-builtin axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Value: - return jnp.min(x, axis=axis, keepdims=keepdims) + return array_api.min(x, axis=axis, keepdims=keepdims) @quaxify @@ -62,7 +50,7 @@ def prod( dtype: DType | None = None, keepdims: bool = False, ) -> Value: - return jnp.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) + return array_api.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) @quaxify @@ -74,7 +62,7 @@ def std( correction: int | float = 0.0, keepdims: bool = False, ) -> Value: - return jnp.std(x, axis=axis, ddof=correction, keepdims=keepdims) + return array_api.std(x, axis=axis, correction=correction, keepdims=keepdims) @quaxify @@ -86,7 +74,7 @@ def sum( # pylint: disable=redefined-builtin dtype: DType | None = None, keepdims: bool = False, ) -> Value: - return jnp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + return array_api.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) @quaxify @@ -98,4 +86,4 @@ def var( correction: int | float = 0.0, keepdims: bool = False, ) -> Value: - return jnp.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return array_api.var(x, axis=axis, correction=correction, keepdims=keepdims) diff --git a/src/array_api_jax_compat/_types.py b/src/array_api_jax_compat/_types.py index 96dfe2e..5117515 100644 --- a/src/array_api_jax_compat/_types.py +++ b/src/array_api_jax_compat/_types.py @@ -18,6 +18,7 @@ class DType(Protocol): dtype: np.dtype[Any] +@runtime_checkable # TODO: need actual implementation class SupportsBufferProtocol(Protocol): """Supports the buffer protocol.""" @@ -25,6 +26,7 @@ class SupportsBufferProtocol(Protocol): _T_co = TypeVar("_T_co", covariant=True) +@runtime_checkable class NestedSequence(Protocol[_T_co]): """A nested sequence.""" diff --git a/src/array_api_jax_compat/_utility_functions.py b/src/array_api_jax_compat/_utility_functions.py index 8c70fbf..e7cf4f4 100644 --- a/src/array_api_jax_compat/_utility_functions.py +++ b/src/array_api_jax_compat/_utility_functions.py @@ -2,7 +2,7 @@ __all__ = ["all", "any"] -import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._utils import quaxify @@ -16,7 +16,7 @@ def all( axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Value: - return jnp.all(x, axis=axis, keepdims=keepdims) + return array_api.all(x, axis=axis, keepdims=keepdims) @quaxify @@ -27,4 +27,4 @@ def any( axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Value: - return jnp.any(x, axis=axis, keepdims=keepdims) + return array_api.any(x, axis=axis, keepdims=keepdims) diff --git a/src/array_api_jax_compat/fft.py b/src/array_api_jax_compat/fft.py index 2ab496e..da923b6 100644 --- a/src/array_api_jax_compat/fft.py +++ b/src/array_api_jax_compat/fft.py @@ -20,9 +20,8 @@ from collections.abc import Sequence from typing import Literal -import jax -import jax.numpy as jnp from jax import Device +from jax.experimental.array_api import fft as _jax_fft from quax import Value from ._utils import quaxify @@ -37,7 +36,7 @@ def fft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.fft(x, n=n, axis=axis, norm=norm) + return _jax_fft.fft(x, n=n, axis=axis, norm=norm) @quaxify @@ -49,7 +48,7 @@ def ifft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) + return _jax_fft.ifft(x, n=n, axis=axis, norm=norm) @quaxify @@ -61,7 +60,7 @@ def fftn( axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) + return _jax_fft.fftn(x, s=s, axes=axes, norm=norm) @quaxify @@ -73,7 +72,7 @@ def ifftn( axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) + return _jax_fft.ifftn(x, s=s, axes=axes, norm=norm) @quaxify @@ -85,7 +84,7 @@ def rfft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) + return _jax_fft.rfft(x, n=n, axis=axis, norm=norm) @quaxify @@ -97,7 +96,7 @@ def irfft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) + return _jax_fft.irfft(x, n=n, axis=axis, norm=norm) @quaxify @@ -109,7 +108,7 @@ def rfftn( axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) + return _jax_fft.rfftn(x, s=s, axes=axes, norm=norm) @quaxify @@ -121,7 +120,7 @@ def irfftn( axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) + return _jax_fft.irfftn(x, s=s, axes=axes, norm=norm) @quaxify @@ -133,7 +132,7 @@ def hfft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.hfft(x, n=n, axis=axis, norm=norm) + return _jax_fft.hfft(x, n=n, axis=axis, norm=norm) @quaxify @@ -145,26 +144,24 @@ def ihfft( axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Value: - return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) + return _jax_fft.ihfft(x, n=n, axis=axis, norm=norm) @quaxify def fftfreq(n: int, /, *, d: float = 1.0, device: Device | None = None) -> Value: - out = jnp.fft.fftfreq(n, d=d) - return jax.device_put(out, device=device) + return _jax_fft.fftfreq(n, d=d, device=device) @quaxify def rfftfreq(n: int, /, *, d: float = 1.0, device: Device | None = None) -> Value: - out = jnp.fft.rfftfreq(n, d=d) - return jax.device_put(out, device=device) + return _jax_fft.rfftfreq(n, d=d, device=device) @quaxify def fftshift(x: Value, /, *, axes: int | Sequence[int] | None = None) -> Value: - return jnp.fft.fftshift(x, axes=axes) + return _jax_fft.fftshift(x, axes=axes) @quaxify def ifftshift(x: Value, /, *, axes: int | Sequence[int] | None = None) -> Value: - return jnp.fft.ifftshift(x, axes=axes) + return _jax_fft.ifftshift(x, axes=axes) diff --git a/src/array_api_jax_compat/linalg.py b/src/array_api_jax_compat/linalg.py index bbea605..86fc8b1 100644 --- a/src/array_api_jax_compat/linalg.py +++ b/src/array_api_jax_compat/linalg.py @@ -31,6 +31,7 @@ from typing import Literal import jax.numpy as jnp +from jax.experimental import array_api from quax import Value from ._types import DType @@ -38,48 +39,43 @@ @quaxify -def cholesky( - x: Value, - /, - *, - upper: bool = False, # TODO: support # pylint: disable=unused-argument -) -> Value: - return jnp.linalg.cholesky(x) +def cholesky(x: Value, /, *, upper: bool = False) -> Value: + return array_api.linalg.cholesky(x, upper=upper) @quaxify def cross(x1: Value, x2: Value, /, *, axis: int = -1) -> Value: - return jnp.cross(x1, x2, axis=axis) + return array_api.linalg.cross(x1, x2, axis=axis) @quaxify def det(x: Value, /) -> Value: - return jnp.linalg.det(x) + return array_api.linalg.det(x) @quaxify def diagonal(x: Value, /, *, offset: int = 0) -> Value: - return jnp.diagonal(x, offset=offset) + return array_api.linalg.diagonal(x, offset=offset) @quaxify def eigh(x: Value, /) -> tuple[Value]: - return jnp.linalg.eigh(x) + return array_api.linalg.eigh(x) @quaxify def eigvalsh(x: Value, /) -> Value: - return jnp.linalg.eigvalsh(x) + return array_api.linalg.eigvalsh(x) @quaxify def inv(x: Value, /) -> Value: - return jnp.linalg.inv(x) + return array_api.linalg.inv(x) @quaxify def matmul(x1: Value, x2: Value, /) -> Value: - return jnp.matmul(x1, x2) + return array_api.matmul(x1, x2) @quaxify @@ -90,37 +86,32 @@ def matrix_norm( keepdims: bool = False, ord: int | float | Literal["fro", "nuc"] | None = "fro", ) -> Value: - return jnp.linalg.norm(x, keepdims=keepdims, ord=ord) + return array_api.linalg.matrix_norm(x, keepdims=keepdims, ord=ord) @quaxify def matrix_power(x: Value, n: int, /) -> Value: - return jnp.linalg.matrix_power(x, n) + return array_api.linalg.matrix_power(x, n) @quaxify def matrix_rank(x: Value, /, *, rtol: float | Value | None = None) -> Value: - return jnp.linalg.matrix_rank(x, tol=rtol) + return array_api.linalg.matrix_rank(x, rtol=rtol) @quaxify def matrix_transpose(x: Value, /) -> Value: - return jnp.transpose(x) + return array_api.linalg.matrix_transpose(x) @quaxify def outer(x1: Value, x2: Value, /) -> Value: - return jnp.outer(x1, x2) + return array_api.linalg.outer(x1, x2) @quaxify -def pinv( - x: Value, - /, - *, - rtol: float | Value | None = None, # pylint: disable=unused-argument -) -> Value: - return jnp.linalg.pinv(x, rcond=rtol) +def pinv(x: Value, /, *, rtol: float | Value | None = None) -> Value: + return array_api.linalg.pinv(x, rtol=rtol) @quaxify @@ -130,27 +121,27 @@ def qr( *, mode: Literal["reduced", "complete"] = "reduced", ) -> tuple[Value, Value]: - return jnp.linalg.qr(x, mode=mode) + return array_api.linalg.qr(x, mode=mode) @quaxify def slogdet(x: Value, /) -> tuple[Value, Value]: - return jnp.linalg.slogdet(x) + return array_api.linalg.slogdet(x) @quaxify def solve(x1: Value, x2: Value, /) -> Value: - return jnp.linalg.solve(x1, x2) + return array_api.linalg.solve(x1, x2) @quaxify def svd(x: Value, /, *, full_matrices: bool = True) -> tuple[Value, Value, Value]: - return jnp.linalg.svd(x, full_matrices=full_matrices) + return array_api.linalg.svd(x, full_matrices=full_matrices) @quaxify def svdvals(x: Value, /) -> Value: - return jnp.linalg.svd(x, compute_uv=False) + return array_api.linalg.svdvals(x) @quaxify @@ -161,7 +152,7 @@ def tensordot( *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, ) -> Value: - return jnp.tensordot(x1, x2, axes=axes) + return array_api.tensordot(x1, x2, axes=axes) @quaxify @@ -170,14 +161,8 @@ def trace(x: Value, /, *, offset: int = 0, dtype: DType | None = None) -> Value: @quaxify -def vecdot( - x1: Value, - x2: Value, - /, - *, - axis: int | None = None, # TODO: support # pylint: disable=unused-argument -) -> Value: - return jnp.dot(x1, x2) +def vecdot(x1: Value, x2: Value, /, *, axis: int | None = None) -> Value: + return array_api.vecdot(x1, x2, axis=axis) @quaxify @@ -189,4 +174,4 @@ def vector_norm( keepdims: bool = False, ord: int | float = 2, # pylint: disable=redefined-builtin ) -> Value: - return jnp.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) + return array_api.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord) diff --git a/tests/myarray.py b/tests/myarray.py new file mode 100644 index 0000000..3741b77 --- /dev/null +++ b/tests/myarray.py @@ -0,0 +1,1472 @@ +"""Test with :class:`quax.DenseArrayValue` inputs.""" + +from collections.abc import Sequence +from dataclasses import replace +from typing import Any + +import equinox as eqx +import jax +import jax.experimental.array_api as jax_xp +from jax import Device, lax +from jax._src.lax.lax import DotDimensionNumbers, PrecisionLike +from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode +from jax._src.typing import DTypeLike, Shape +from quax import ArrayValue, DenseArrayValue, register +from quax.zero import Zero + +from array_api_jax_compat._dispatch import dispatcher +from array_api_jax_compat._types import DType + + +class MyArray(ArrayValue): + """A :class:`quax.ArrayValue` that is dense. + + This is different from :class:`quax.MyArray` only in that + `quax` will not attempt to convert it to a JAX array. + """ + + array: jax.Array = eqx.field(converter=jax_xp.asarray) + + def materialise(self) -> jax.Array: + """Convert to a JAX array.""" + raise NotImplementedError + + def aval(self) -> jax.core.ShapedArray: + """Return the ShapedArray.""" + return jax.core.get_aval(self.array) + + +# ============================================================================== + + +@register(lax.abs_p) +def _abs_p(x: MyArray) -> MyArray: + return replace(x, array=lax.abs(x.array)) + + +# ============================================================================== + + +@register(lax.acos_p) +def _acos_p(x: MyArray) -> MyArray: + return replace(x, array=lax.acos(x.array)) + + +# ============================================================================== + + +@register(lax.acosh_p) +def _acosh_p(x: MyArray) -> MyArray: + return replace(x, array=lax.acosh(x.array)) + + +# ============================================================================== + + +@register(lax.add_p) +def _add_p_qq(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.add(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.after_all_p) +def _after_all_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.all_gather_p) +def _all_gather_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.all_to_all_p) +def _all_to_all_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.and_p) +def _and_p(x1: MyArray, x2: MyArray, /) -> MyArray: + return MyArray(x1.array & x2.array) + + +# ============================================================================== + + +@register(lax.approx_top_k_p) +def _approx_top_k_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.argmax_p) +def _argmax_p(operand: MyArray, *, axes: Any, index_dtype: Any) -> MyArray: + return replace(operand, array=lax.argmax(operand.array, axes[0], index_dtype)) + + +# ============================================================================== + + +@register(lax.argmin_p) +def _argmin_p(operand: MyArray, *, axes: Any, index_dtype: Any) -> MyArray: + return replace(operand, array=lax.argmin(operand.array, axes[0], index_dtype)) + + +# ============================================================================== + + +@register(lax.asin_p) +def _asin_p(x: MyArray) -> MyArray: + return replace(x, array=lax.asin(x.array)) + + +# ============================================================================== + + +@register(lax.asinh_p) +def _asinh_p(x: MyArray) -> MyArray: + return replace(x, array=lax.asinh(x.array)) + + +# ============================================================================== + + +@register(lax.atan2_p) +def _atan2_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.atan2(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.atan_p) +def _atan_p(x: MyArray) -> MyArray: + return MyArray(lax.atan(x.array)) + + +# ============================================================================== + + +@register(lax.atanh_p) +def _atanh_p(x: MyArray) -> MyArray: + return MyArray(lax.atanh(x.array)) + + +# ============================================================================== + + +@register(lax.axis_index_p) +def _axis_index_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bessel_i0e_p) +def _bessel_i0e_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bessel_i1e_p) +def _bessel_i1e_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bitcast_convert_type_p) +def _bitcast_convert_type_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.broadcast_in_dim_p) +def _broadcast_in_dim_p( + operand: MyArray, + *, + shape: Any, + broadcast_dimensions: Any, +) -> MyArray: + return replace( + operand, + array=lax.broadcast_in_dim(operand.array, shape, broadcast_dimensions), + ) + + +# ============================================================================== + + +@register(lax.cbrt_p) +def _cbrt_p(x: MyArray) -> MyArray: + return MyArray(lax.cbrt(x.array)) + + +# ============================================================================== + + +@register(lax.ceil_p) +def _ceil_p(x: MyArray) -> MyArray: + return replace(x, array=lax.ceil(x.array)) + + +# ============================================================================== + + +@register(lax.clamp_p) +def _clamp_p(min: MyArray, x: MyArray, max: MyArray) -> MyArray: + return replace(x, array=lax.clamp(min.array, x.array, max.array)) + + +# ============================================================================== + + +@register(lax.clz_p) +def _clz_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.complex_p) +def _complex_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.complex(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.concatenate_p) +def _concatenate_p( + operand0: MyArray, + *operands: MyArray | DenseArrayValue, + dimension: Any, +) -> MyArray: + return MyArray( + lax.concatenate( + [operand0.array] + [op.array for op in operands], + dimension=dimension, + ), + ) + + +# ============================================================================== + + +@register(lax.cond_p) # TODO: implement +def _cond_p(index, consts) -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.conj_p) +def _conj_p(x: MyArray, *, input_dtype: Any) -> MyArray: + del input_dtype # TODO: use this? + return replace(x, array=lax.conj(x.array)) + + +# ============================================================================== + + +@register(lax.conv_general_dilated_p) +def _conv_general_dilated_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.convert_element_type_p) +def _convert_element_type_p( + operand: MyArray, + *, + new_dtype: Any, + weak_type: Any, +) -> MyArray: + del weak_type + return replace(operand, array=lax.convert_element_type(operand.array, new_dtype)) + + +# ============================================================================== + + +@register(lax.copy_p) +def _copy_p(x: MyArray) -> MyArray: + return replace(x, array=lax.copy_p.bind(x.array)) + + +# ============================================================================== + + +@register(lax.cos_p) +def _cos_p(x: MyArray) -> MyArray: + return replace(x, array=lax.cos(x.array)) + + +# ============================================================================== + + +@register(lax.cosh_p) +def _cosh_p(x: MyArray) -> MyArray: + return replace(x, array=lax.cosh(x.array)) + + +# ============================================================================== + + +@register(lax.create_token_p) +def _create_token_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.cumlogsumexp_p) +def _cumlogsumexp_p(operand: MyArray, *, axis: Any, reverse: Any) -> MyArray: + # TODO: double check units make sense here. + return replace( + operand, + array=lax.cumlogsumexp(operand.array, axis=axis, reverse=reverse), + ) + + +# ============================================================================== + + +@register(lax.cummax_p) +def _cummax_p(operand: MyArray, *, axis: Any, reverse: Any) -> MyArray: + return replace(operand, array=lax.cummax(operand.array, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.cummin_p) +def _cummin_p(operand: MyArray, *, axis: Any, reverse: Any) -> MyArray: + return replace(operand, array=lax.cummin(operand.array, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.cumprod_p) +def _cumprod_p(operand: MyArray, *, axis: Any, reverse: Any) -> MyArray: + return replace( + operand, + array=lax.cumprod(operand.array, axis=axis, reverse=reverse), + ) + + +# ============================================================================== + + +@register(lax.cumsum_p) +def _cumsum_p(operand: MyArray, *, axis: Any, reverse: Any) -> MyArray: + return replace(operand, array=lax.cumsum(operand.array, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.device_put_p) +def _device_put_p(x: MyArray, *, device: Any, src: Any) -> MyArray: + return replace(x, array=jax.device_put(x.array, device=device, src=src)) + + +# ============================================================================== + + +@register(lax.digamma_p) +def _digamma_p(x: MyArray) -> MyArray: + return replace(x, array=lax.digamma(x.array)) + + +# ============================================================================== + + +@register(lax.div_p) +def _div_p(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.div(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.dot_general_p) # TODO: implement +def _dot_general_p( + lhs: MyArray, + rhs: MyArray, + *, + dimension_numbers: DotDimensionNumbers, + precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None, +) -> MyArray: + return MyArray( + lax.dot_general_p.bind( + lhs.array, + rhs.array, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + ), + ) + + +# ============================================================================== + + +@register(lax.dynamic_slice_p) +def _dynamic_slice_p( + operand: MyArray, + start_indices: DenseArrayValue, + dynamic_sizes: DenseArrayValue, + *, + slice_sizes: Any, +) -> MyArray: + raise NotImplementedError # TODO: implement + + +# ============================================================================== + + +@register(lax.dynamic_update_slice_p) +def _dynamic_update_slice_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.eq_p) +def _eq_p(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.eq(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.eq_to_p) +def _eq_to_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.erf_inv_p) +def _erf_inv_p(x: MyArray) -> MyArray: + return replace(x, array=lax.erf_inv(x.array)) + + +# ============================================================================== + + +@register(lax.erf_p) +def _erf_p(x: MyArray) -> MyArray: + return replace(x, array=lax.erf(x.array)) + + +# ============================================================================== + + +@register(lax.erfc_p) +def _erfc_p(x: MyArray) -> MyArray: + return replace(x, array=lax.erfc(x.array)) + + +# ============================================================================== + + +@register(lax.exp2_p) +def _exp2_p(x: MyArray) -> MyArray: + return replace(x, array=lax.exp2(x.array)) + + +# ============================================================================== + + +@register(lax.exp_p) +def _exp_p(x: MyArray) -> MyArray: + return replace(x, array=lax.exp(x.array)) + + +# ============================================================================== + + +@register(lax.expm1_p) +def _expm1_p(x: MyArray) -> MyArray: + return replace(x, array=lax.expm1(x.array)) + + +# ============================================================================== + + +@register(lax.fft_p) +def _fft_p(x: MyArray, *, fft_type: Any, fft_lengths: Any) -> MyArray: + return replace(x, array=lax.fft(x.array, fft_type, fft_lengths)) + + +# ============================================================================== + + +@register(lax.floor_p) +def _floor_p(x: MyArray) -> MyArray: + return replace(x, array=lax.floor(x.array)) + + +# ============================================================================== + + +@register(lax.gather_p) +def _gather_p( + operand: MyArray, + start_indices: DenseArrayValue | MyArray, + *, + dimension_numbers: GatherDimensionNumbers, + slice_sizes: Shape, + unique_indices: bool, + indices_are_sorted: bool, + mode: str | GatherScatterMode | None = None, + fill_value: Any = None, +) -> MyArray: + return MyArray( + lax.gather( + operand.array, + start_indices.array, + dimension_numbers=dimension_numbers, + slice_sizes=slice_sizes, + unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, + mode=mode, + fill_value=fill_value, + ), + ) + + +# ============================================================================== + + +@register(lax.ge_p) +def _ge_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.ge(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.gt_p) +def _gt_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.gt(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.igamma_grad_a_p) +def _igamma_grad_a_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.igamma_p) +def _igamma_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.igammac_p) +def _igammac_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.imag_p) +def _imag_p(x: MyArray) -> MyArray: + return replace(x, array=lax.imag(x.array)) + + +# ============================================================================== + + +@register(lax.infeed_p) +def _infeed_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.integer_pow_p) +def _integer_pow_p(x: MyArray, *, y: Any) -> MyArray: + return replace(x, array=lax.integer_pow(x.array, y)) + + +# ============================================================================== + + +# @register(lax.iota_p) +# def _iota_p(dtype: MyArray) -> MyArray: +# raise NotImplementedError + + +# ============================================================================== + + +@register(lax.is_finite_p) +def _is_finite_p(x: MyArray) -> MyArray: + return replace(x, array=lax.is_finite(x.array)) + + +# ============================================================================== + + +@register(lax.le_p) +def _le_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.le(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.le_to_p) +def _le_to_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.lgamma_p) +def _lgamma_p(x: MyArray) -> MyArray: + return replace(x, array=lax.lgamma(x.array)) + + +# ============================================================================== + + +@register(lax.linear_solve_p) +def _linear_solve_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.log1p_p) +def _log1p_p(x: MyArray) -> MyArray: + return replace(x, array=lax.log1p(x.array)) + + +# ============================================================================== + + +@register(lax.log_p) +def _log_p(x: MyArray) -> MyArray: + return replace(x, array=lax.log(x.array)) + + +# ============================================================================== + + +@register(lax.logistic_p) +def _logistic_p(x: MyArray) -> MyArray: + return replace(x, array=lax.logistic(x.array)) + + +# ============================================================================== + + +@register(lax.lt_p) +def _lt_p(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.lt(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.lt_to_p) +def _lt_to_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.max_p) +def _max_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.max(x.array, y.array)) + + +@register(lax.max_p) +def _max_p_d1(x: DenseArrayValue, y: MyArray) -> MyArray: + return MyArray(lax.max(x.array, y.array)) + + +@register(lax.max_p) +def _max_p_d2(x: MyArray, y: DenseArrayValue) -> MyArray: + return MyArray(lax.max(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.min_p) +def _min_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.min(x.array, y.array)) + + +# ============================================================================== +# Multiplication + + +@register(lax.mul_p) +def _mul_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.mul(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.ne_p) +def _ne_p(x: MyArray, y: DenseArrayValue) -> MyArray: + return MyArray(lax.ne(x.array, y.materialise())) + + +@register(lax.ne_p) +def _ne_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.ne(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.neg_p) +def _neg_p(x: MyArray) -> MyArray: + return replace(x, array=lax.neg(x.array)) + + +# ============================================================================== + + +@register(lax.nextafter_p) +def _nextafter_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.not_p) +def _not_p(x: MyArray) -> MyArray: + return replace(x, array=lax.bitwise_not(x.array)) + + +# ============================================================================== + + +@register(lax.or_p) +def _or_p(x: MyArray, y: MyArray) -> MyArray: + return replace(x, array=lax.bitwise_or(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.outfeed_p) +def _outfeed_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pad_p) +def _pad_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pmax_p) +def _pmax_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pmin_p) +def _pmin_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.polygamma_p) +def _polygamma_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.population_count_p) +def _population_count_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pow_p) +def _pow_p_qq(x: MyArray, y: MyArray) -> MyArray: + return MyArray(array=lax.pow(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.ppermute_p) +def _ppermute_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.psum_p) +def _psum_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.random_gamma_grad_p) +def _random_gamma_grad_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.real_p) +def _real_p(x: MyArray) -> MyArray: + return replace(x, array=lax.real(x.array)) + + +# ============================================================================== + + +@register(lax.reduce_and_p) +def _reduce_and_p( + operand: MyArray, + *, + axes: Sequence[int], +) -> Any: + return lax.reduce_and_p.bind(operand.array, axes=tuple(axes)) + + +# ============================================================================== + + +@register(lax.reduce_max_p) +def _reduce_max_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_min_p) +def _reduce_min_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_or_p) +def _reduce_or_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_p) +def _reduce_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_precision_p) +def _reduce_precision_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_prod_p) +def _reduce_prod_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_sum_p) +def _reduce_sum_p(operand: MyArray, *, axes: tuple[int, ...]) -> MyArray: + return MyArray(lax.reduce_sum_p.bind(operand.array, axes=axes)) + + +# ============================================================================== + + +@register(lax.reduce_window_max_p) +def _reduce_window_max_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_min_p) +def _reduce_window_min_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_p) +def _reduce_window_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_sum_p) +def _reduce_window_sum_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_xor_p) +def _reduce_xor_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.regularized_incomplete_beta_p) +def _regularized_incomplete_beta_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.rem_p) +def _rem_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.rem(x.array, y.array)) + + +@register(lax.rem_p) +def _rem_p_d1(x: DenseArrayValue, y: MyArray) -> MyArray: + return MyArray(lax.rem(x.array, y.array)) + + +@register(lax.rem_p) +def _rem_p_d1(x: MyArray, y: DenseArrayValue) -> MyArray: + return MyArray(lax.rem(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.reshape_p) +def _reshape_p(operand: MyArray, *, new_sizes: Any, dimensions: Any) -> MyArray: + return replace(operand, array=lax.reshape(operand.array, new_sizes, dimensions)) + + +# ============================================================================== + + +@register(lax.rev_p) +def _rev_p(operand: MyArray, *, dimensions: Any) -> MyArray: + return replace(operand, array=lax.rev(operand.array, dimensions)) + + +# ============================================================================== + + +@register(lax.rng_bit_generator_p) +def _rng_bit_generator_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.rng_uniform_p) +def _rng_uniform_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.round_p) +def _round_p(x: MyArray, *, rounding_method: Any) -> MyArray: + return replace(x, array=lax.round(x.array, rounding_method)) + + +# ============================================================================== + + +@register(lax.rsqrt_p) +def _rsqrt_p(x: MyArray) -> MyArray: + return replace(x, array=lax.rsqrt(x.array) ** (-1 / 2)) + + +# ============================================================================== + + +@register(lax.scan_p) +def _scan_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_add_p) +def _scatter_add_p( + operand: MyArray, + scatter_indices: MyArray | DenseArrayValue, + updates: MyArray | DenseArrayValue, + *, + update_jaxpr: Any, + update_consts: Any, + dimension_numbers: Any, + indices_are_sorted: bool, + unique_indices: bool, + mode: str | GatherScatterMode | None = None, +) -> MyArray: + return MyArray( + lax.scatter_add_p.bind( + operand.array, + scatter_indices.array, + updates.array, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ), + ) + + +@register(lax.scatter_add_p) +def _scatter_add_p( + operand: Zero, + scatter_indices: MyArray | DenseArrayValue, + updates: MyArray | DenseArrayValue, + *, + update_jaxpr: Any, + update_consts: Any, + dimension_numbers: Any, + indices_are_sorted: bool, + unique_indices: bool, + mode: str | GatherScatterMode | None = None, +) -> MyArray: + return MyArray( + lax.scatter_add_p.bind( + jax_xp.zeros_like(operand), + scatter_indices.array, + updates.array, + update_jaxpr=update_jaxpr, + update_consts=update_consts, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, + mode=mode, + ), + ) + + +# ============================================================================== + + +@register(lax.scatter_max_p) +def _scatter_max_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_min_p) +def _scatter_min_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_mul_p) +def _scatter_mul_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_p) +def _scatter_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_gather_add_p) +def _select_and_gather_add_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_scatter_add_p) +def _select_and_scatter_add_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_scatter_p) +def _select_and_scatter_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_n_p) +def _select_n_p(which: DenseArrayValue | MyArray, *cases: Zero | MyArray) -> MyArray: + if not any(isinstance(case, MyArray) for case in cases): + msg = "At least one case must be a MyArray." + raise ValueError(msg) + + # Process the cases, replacing Zero and MyArray with a materialised array. + cases_ = ( + case.array if isinstance(case, MyArray) else case.materialise() + for case in cases + ) + return MyArray(lax.select_n(which.array, *cases_)) + + +# ============================================================================== + + +@register(lax.sharding_constraint_p) +def _sharding_constraint_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.shift_left_p) +def _shift_left_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.shift_left(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.shift_right_arithmetic_p) +def _shift_right_arithmetic_p(x: MyArray, y: MyArray) -> MyArray: + return MyArray(lax.shift_right_arithmetic(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.shift_right_logical_p) +def _shift_right_logical_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.sign_p) +def _sign_p(x: MyArray) -> MyArray: + return replace(x, array=lax.sign(x.array)) + + +# ============================================================================== + + +@register(lax.sin_p) +def _sin_p(x: MyArray) -> MyArray: + return replace(x, array=lax.sin(x.array)) + + +# ============================================================================== + + +@register(lax.sinh_p) +def _sinh_p(x: MyArray) -> MyArray: + return replace(x, array=lax.sinh(x.array)) + + +# ============================================================================== + + +@register(lax.slice_p) +def _slice_p( + operand: MyArray, + *, + start_indices: Any, + limit_indices: Any, + strides: Any, +) -> MyArray: + return replace( + operand, + array=lax.slice_p.bind( + operand.array, + start_indices=start_indices, + limit_indices=limit_indices, + strides=strides, + ), + ) + + +# ============================================================================== + + +@register(lax.sort_p) +def _sort_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.sqrt_p) +def _sqrt_p(x: MyArray) -> MyArray: + return replace(x, array=lax.sqrt(x.array)) + + +# ============================================================================== + + +@register(lax.squeeze_p) +def _squeeze_p(x: MyArray, *, dimensions: Any) -> MyArray: + return replace(x, array=lax.squeeze(x.array, dimensions)) + + +# ============================================================================== + + +@register(lax.stop_gradient_p) +def _stop_gradient_p(x: MyArray) -> MyArray: + return replace(x, array=lax.stop_gradient(x.array)) + + +# ============================================================================== +# Subtraction + + +@register(lax.sub_p) +def _sub_p(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.sub(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.tan_p) +def _tan_p(x: MyArray) -> MyArray: + return replace(x, array=lax.tan(x.array)) + + +# ============================================================================== + + +@register(lax.tanh_p) +def _tanh_p(x: MyArray) -> MyArray: + return replace(x, array=lax.tanh(x.array)) + + +# ============================================================================== + + +@register(lax.top_k_p) +def _top_k_p(operand: MyArray, k: int) -> MyArray: + raise replace(operand, array=lax.top_k(operand.array, k)) + + +# ============================================================================== + + +@register(lax.transpose_p) +def _transpose_p(operand: MyArray, *, permutation: Any) -> MyArray: + return replace(operand, array=lax.transpose(operand.array, permutation)) + + +# ============================================================================== + + +@register(lax.while_p) +def _while_p() -> MyArray: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.xor_p) +def _xor_p(x: MyArray, y: DenseArrayValue | MyArray) -> MyArray: + return MyArray(lax.bitwise_xor(x.array, y.array)) + + +# ============================================================================== + + +@register(lax.zeta_p) +def _zeta_p() -> MyArray: + raise NotImplementedError + + +############################################################################### + + +@dispatcher +def arange( + start: MyArray, + stop: MyArray | None = None, + step: MyArray | None = None, + *, + dtype: Any = None, + device: Any = None, +) -> MyArray: + return MyArray( + jax_xp.arange( + start.array, + stop=stop.array if stop is not None else None, + step=step.array if step is not None else None, + dtype=dtype, + device=device, + ), + ) + + +@dispatcher # type: ignore[misc] +def empty_like( + x: MyArray, + /, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> MyArray: + return MyArray(jax_xp.empty_like(x.array, dtype=dtype, device=device)) + + +@dispatcher +def full_like( + x: MyArray, + /, + fill_value: bool | int | float | complex | MyArray, + *, + dtype: DType | None = None, + device: Device | None = None, +) -> MyArray: + return MyArray( + jax_xp.full_like(x.array, fill_value, dtype=dtype, device=device), + ) + + +@dispatcher +def linspace( + start: MyArray, + stop: MyArray, + num: int, + *, + dtype: DType | None = None, + device: Device | None = None, + endpoint: bool = True, +) -> MyArray: + return MyArray( + jax_xp.linspace( + start.array, + stop.array, + num=num, + dtype=dtype, + device=device, + endpoint=endpoint, + ), + ) + + +@dispatcher +def ones_like( + x: MyArray, + /, + dtype: DType | None = None, + device: Device | None = None, +) -> MyArray: + return MyArray(jax_xp.ones_like(x.array, dtype=dtype, device=device)) + + +@dispatcher +def zeros_like( + x: MyArray, + /, + dtype: DType | None = None, + device: Device | None = None, +) -> MyArray: + return MyArray(jax_xp.zeros_like(x.array, dtype=dtype, device=device)) diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..7acee06 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,57 @@ +"""Test with JAX inputs.""" + + +from jax.experimental import array_api as jax_xp + +import array_api_jax_compat as xp + +# ============================================================================= +# Constants + + +def test_e(): + """Test `e`.""" + assert xp.e is jax_xp.e + + +def test_inf(): + """Test `inf`.""" + assert xp.inf is jax_xp.inf + + +def test_nan(): + """Test `nan`.""" + assert xp.nan is jax_xp.nan + + +def test_newaxis(): + """Test `newaxis`.""" + assert xp.newaxis is jax_xp.newaxis + + +def test_pi(): + """Test `pi`.""" + assert xp.pi is jax_xp.pi + + +# ============================================================================= +# Creation functions + + +# def test_arange(): +# """Test `arange`.""" +# # TODO: test the start, stop, step, dtype, device arguments +# got = xp.arange(3) +# expected = jax_xp.arange(3) + +# assert isinstance(got, jnp.ndarray) +# assert jnp.array_equal(got, expected) + + +# def test_asarray(): +# """Test `asarray`.""" +# got = xp.asarray([1, 2, 3]) +# expected = jax_xp.asarray([1, 2, 3]) + +# assert isinstance(got, jnp.ndarray) +# assert jnp.array_equal(got, expected) diff --git a/tests/test_myarray.py b/tests/test_myarray.py new file mode 100644 index 0000000..590f3da --- /dev/null +++ b/tests/test_myarray.py @@ -0,0 +1,1422 @@ +"""Test with :class:`quax.DenseArrayValue` inputs.""" + +import jax.experimental.array_api as jax_xp +import jax.numpy as jnp +import pytest +from jax import Array +from jax.experimental.array_api._data_type_functions import FInfo, IInfo +from jax.experimental.array_api._set_functions import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, +) +from myarray import MyArray + +import array_api_jax_compat as xp + +############################################################################### + +# ============================================================================= +# Constants + + +def test_e(): + """Test `e`.""" + assert not isinstance(xp.e, MyArray) + + +def test_inf(): + """Test `inf`.""" + assert not isinstance(xp.inf, MyArray) + + +def test_nan(): + """Test `nan`.""" + assert not isinstance(xp.nan, MyArray) + + +def test_newaxis(): + """Test `newaxis`.""" + assert not isinstance(xp.newaxis, MyArray) + + +def test_pi(): + """Test `pi`.""" + assert not isinstance(xp.pi, MyArray) + + +# ============================================================================= +# Creation functions + + +def test_arange(): + """Test `arange`.""" + # TODO: test the start, stop, step, dtype, device arguments + got = xp.arange(MyArray(3)) + expected = MyArray(jax_xp.arange(3)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_asarray(): + """Test `asarray`.""" + # TODO: test the dtype, device, copy arguments + got = xp.asarray(MyArray([1, 2, 3])) + expected = MyArray(jax_xp.asarray([1, 2, 3])) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +@pytest.mark.xfail(reason="returns a jax.Array") +def test_empty(): + """Test `empty`.""" + # TODO: test the dtype, device arguments + got = xp.empty((2, 3)) + assert isinstance(got, MyArray) + + +def test_empty_like(): + """Test `empty_like`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.empty_like(x) + expected = MyArray(jax_xp.empty_like(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +@pytest.mark.xfail(reason="returns a jax.Array") +def test_eye(): + """Test `eye`.""" + got = xp.eye(3) + + assert isinstance(got, MyArray) + + +@pytest.mark.skip("TODO") +def test_from_dlpack(): + """Test `from_dlpack`.""" + + +@pytest.mark.xfail(reason="returns a jax.Array") +def test_full(): + """Test `full`.""" + got = xp.full((2, 3), 1.0) + + assert isinstance(got, MyArray) + + +def test_full_like(): + """Test `full_like`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.full_like(x, 1.0) + expected = MyArray(jax_xp.full_like(x.array, 1.0)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_linspace(): + """Test `linspace`.""" + # TODO: test the dtype, device, endpoint arguments + got = xp.linspace(MyArray(0.0), MyArray(10.0), 11) + expected = MyArray(jax_xp.linspace(0.0, 10.0, 11)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_meshgrid(): + """Test `meshgrid`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + + got1, got2 = xp.meshgrid(x, y) + exp1, exp2 = jax_xp.meshgrid(x.array, y.array) + + assert isinstance(got1, MyArray) + assert jnp.array_equal(got1.array, exp1) + + assert isinstance(got2, MyArray) + assert jnp.array_equal(got2.array, exp2) + + +@pytest.mark.xfail(reason="returns a jax.Array") +def test_ones(): + """Test `ones`.""" + assert isinstance(xp.ones((2, 3)), MyArray) + + +def test_ones_like(): + """Test `ones_like`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.ones_like(x) + expected = MyArray(jax_xp.ones_like(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_tril(): + """Test `tril`.""" + x = MyArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + got = xp.tril(x) + expected = MyArray(jax_xp.tril(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_triu(): + """Test `triu`.""" + x = MyArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + got = xp.triu(x) + expected = MyArray(jax_xp.triu(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +@pytest.mark.xfail(reason="returns a jax.Array") +def test_zeros(): + """Test `zeros`.""" + assert isinstance(xp.zeros((2, 3)), MyArray) + + +def test_zeros_like(): + """Test `zeros_like`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.zeros_like(x) + expected = MyArray(jax_xp.zeros_like(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Data-type functions + + +def test_astype(): + """Test `astype`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.astype(x, jnp.float32) + expected = MyArray(jax_xp.asarray(x.array, dtype=jnp.float32)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +@pytest.mark.skip("TODO") +def test_can_cast(): + """Test `can_cast`.""" + # x = jax_xp.asarray([1, 2, 3], dtype=float) + # mx = MyArray(x) + + # assert xp.can_cast(x, float) + # assert xp.can_cast(mx, float) + + +def test_finfo(): + """Test `finfo`.""" + got = xp.finfo(jnp.float32) + expected = jax_xp.finfo(jnp.float32) + + assert isinstance(got, FInfo) + for attr in FInfo.__slots__: + assert getattr(got, attr) == getattr(expected, attr) + + +def test_iinfo(): + """Test `iinfo`.""" + got = xp.iinfo(jnp.int32) + expected = jax_xp.iinfo(jnp.int32) + + assert isinstance(got, IInfo) + for attr in IInfo.__slots__: + assert getattr(got, attr) == getattr(expected, attr) + + +def test_isdtype(): + """Test `isdtype`.""" + # True by definition + + +def test_result_type(): + """Test `result_type`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.result_type(x, y) + expected = jax_xp.result_type(x.array, y.array) + + assert isinstance(got, jnp.dtype) + assert got == expected + + +# ============================================================================= +# Elementwise functions + + +def test_abs(): + """Test `abs`.""" + x = MyArray([-1, 2, -3]) + got = xp.abs(x) + expected = MyArray(jax_xp.abs(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_acos(): + """Test `acos`.""" + x = MyArray(xp.asarray([-1, 0, 1], dtype=float)) + got = xp.acos(x) + expected = MyArray(jax_xp.acos(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_acosh(): + """Test `acosh`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.acosh(x) + expected = MyArray(jax_xp.acosh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_add(): + """Test `add`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.add(x, y) + expected = MyArray(jax_xp.add(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_asin(): + """Test `asin`.""" + x = MyArray(xp.asarray([-1, 0, 1], dtype=float)) + got = xp.asin(x) + expected = MyArray(jax_xp.asin(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_asinh(): + """Test `asinh`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.asinh(x) + expected = MyArray(jax_xp.asinh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_atan(): + """Test `atan`.""" + x = MyArray(xp.asarray([-1, 0, 1], dtype=float)) + got = xp.atan(x) + expected = MyArray(jax_xp.atan(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_atan2(): + """Test `atan2`.""" + x = MyArray(xp.asarray([-1, 0, 1], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.atan2(x, y) + expected = MyArray(jax_xp.atan2(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_atanh(): + """Test `atanh`.""" + x = MyArray(xp.asarray([-1, 0, 1], dtype=float)) + got = xp.atanh(x) + expected = MyArray(jax_xp.atanh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_and(): + """Test `bitwise_and`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + y = MyArray(xp.asarray([4, 5, 6], dtype=int)) + got = xp.bitwise_and(x, y) + expected = MyArray(jax_xp.bitwise_and(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_left_shift(): + """Test `bitwise_left_shift`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + y = MyArray(xp.asarray([4, 5, 6], dtype=int)) + got = xp.bitwise_left_shift(x, y) + expected = MyArray(jax_xp.bitwise_left_shift(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_invert(): + """Test `bitwise_invert`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + got = xp.bitwise_invert(x) + expected = MyArray(jax_xp.bitwise_invert(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_or(): + """Test `bitwise_or`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + y = MyArray(xp.asarray([4, 5, 6], dtype=int)) + got = xp.bitwise_or(x, y) + expected = MyArray(jax_xp.bitwise_or(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_right_shift(): + """Test `bitwise_right_shift`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + y = MyArray(xp.asarray([4, 5, 6], dtype=int)) + got = xp.bitwise_right_shift(x, y) + expected = MyArray(jax_xp.bitwise_right_shift(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_bitwise_xor(): + """Test `bitwise_xor`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=int)) + y = MyArray(xp.asarray([4, 5, 6], dtype=int)) + got = xp.bitwise_xor(x, y) + expected = MyArray(jax_xp.bitwise_xor(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_ceil(): + """Test `ceil`.""" + x = MyArray([1.1, 2.2, 3.3]) + got = xp.ceil(x) + expected = MyArray(jax_xp.ceil(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_conj(): + """Test `conj`.""" + x = MyArray([1 + 2j, 3 + 4j]) + got = xp.conj(x) + expected = MyArray(jax_xp.conj(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_cos(): + """Test `cos`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.cos(x) + expected = MyArray(jax_xp.cos(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_cosh(): + """Test `cosh`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.cosh(x) + expected = MyArray(jax_xp.cosh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_divide(): + """Test `divide`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.divide(x, y) + expected = MyArray(jax_xp.divide(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_equal(): + """Test `equal`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.equal(x, y) + expected = MyArray(jax_xp.equal(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_exp(): + """Test `exp`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.exp(x) + expected = MyArray(jax_xp.exp(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_expm1(): + """Test `expm1`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.expm1(x) + expected = MyArray(jax_xp.expm1(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_floor(): + """Test `floor`.""" + x = MyArray([1.1, 2.2, 3.3]) + got = xp.floor(x) + expected = MyArray(jax_xp.floor(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_floor_divide(): + """Test `floor_divide`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.floor_divide(x, y) + expected = MyArray(jax_xp.floor_divide(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_greater(): + """Test `greater`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.greater(x, y) + expected = MyArray(jax_xp.greater(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_greater_equal(): + """Test `greater_equal`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.greater_equal(x, y) + expected = MyArray(jax_xp.greater_equal(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_imag(): + """Test `imag`.""" + x = MyArray([1 + 2j, 3 + 4j]) + got = xp.imag(x) + expected = MyArray(jax_xp.imag(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_isfinite(): + """Test `isfinite`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.isfinite(x) + expected = MyArray(jax_xp.isfinite(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_isinf(): + """Test `isinf`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.isinf(x) + expected = MyArray(jax_xp.isinf(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_isnan(): + """Test `isnan`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.isnan(x) + expected = MyArray(jax_xp.isnan(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_less(): + """Test `less`.""" + x = MyArray([1, 5, 3]) + y = MyArray([4, 2, 6]) + got = xp.less(x, y) + expected = MyArray(jax_xp.less(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_less_equal(): + """Test `less_equal`.""" + x = MyArray([1, 5, 3]) + y = MyArray([4, 2, 6]) + got = xp.less_equal(x, y) + expected = MyArray(jax_xp.less_equal(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_log(): + """Test `log`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.log(x) + expected = MyArray(jax_xp.log(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_log1p(): + """Test `log1p`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.log1p(x) + expected = MyArray(jax_xp.log1p(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_log2(): + """Test `log2`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.log2(x) + expected = MyArray(jax_xp.log2(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_log10(): + """Test `log10`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.log10(x) + expected = MyArray(jax_xp.log10(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_logaddexp(): + """Test `logaddexp`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.logaddexp(x, y) + expected = MyArray(jax_xp.logaddexp(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_logical_and(): + """Test `logical_and`.""" + x = MyArray([True, False, True]) + y = MyArray([False, True, False]) + got = xp.logical_and(x, y) + expected = MyArray(jax_xp.logical_and(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_logical_not(): + """Test `logical_not`.""" + x = MyArray([True, False, True]) + got = xp.logical_not(x) + expected = MyArray(jax_xp.logical_not(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_logical_or(): + """Test `logical_or`.""" + x = MyArray([True, False, True]) + y = MyArray([False, True, False]) + got = xp.logical_or(x, y) + expected = MyArray(jax_xp.logical_or(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_logical_xor(): + """Test `logical_xor`.""" + x = MyArray([True, False, True]) + y = MyArray([False, True, False]) + got = xp.logical_xor(x, y) + expected = MyArray(jax_xp.logical_xor(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_multiply(): + """Test `multiply`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.multiply(x, y) + expected = MyArray(jax_xp.multiply(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_negative(): + """Test `negative`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.negative(x) + expected = MyArray(jax_xp.negative(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_not_equal(): + """Test `not_equal`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 2, 6], dtype=float)) + got = xp.not_equal(x, y) + expected = MyArray(jax_xp.not_equal(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_positive(): + """Test `positive`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.positive(x) + expected = MyArray(jax_xp.positive(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_pow(): + """Test `pow`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 2, 6], dtype=float)) + got = xp.pow(x, y) + expected = MyArray(jax_xp.pow(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_real(): + """Test `real`.""" + x = MyArray([1 + 2j, 3 + 4j]) + got = xp.real(x) + expected = MyArray(jax_xp.real(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_remainder(): + """Test `remainder`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.remainder(x, y) + expected = MyArray(jax_xp.remainder(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_round(): + """Test `round`.""" + x = MyArray([1.1, 2.2, 3.3]) + got = xp.round(x) + expected = MyArray(jax_xp.round(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_sign(): + """Test `sign`.""" + x = MyArray([-1, 2, -3]) + got = xp.sign(x) + expected = MyArray(jax_xp.sign(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_sin(): + """Test `sin`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.sin(x) + expected = MyArray(jax_xp.sin(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_sinh(): + """Test `sinh`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.sinh(x) + expected = MyArray(jax_xp.sinh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_square(): + """Test `square`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.square(x) + expected = MyArray(jax_xp.square(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_sqrt(): + """Test `sqrt`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.sqrt(x) + expected = MyArray(jax_xp.sqrt(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_subtract(): + """Test `subtract`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.subtract(x, y) + expected = MyArray(jax_xp.subtract(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_tan(): + """Test `tan`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.tan(x) + expected = MyArray(jax_xp.tan(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_tanh(): + """Test `tanh`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.tanh(x) + expected = MyArray(jax_xp.tanh(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_trunc(): + """Test `trunc`.""" + x = MyArray([1.1, 2.2, 3.3]) + got = xp.trunc(x) + expected = MyArray(jax_xp.trunc(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Indexing functions + + +def test_take(): + """Test `take`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + indices = MyArray(xp.asarray([0, 1, 2], dtype=int)) + got = xp.take(x, indices) + expected = MyArray(jax_xp.take(x.array, indices.array, axis=None)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Linear algebra functions + + +def test_matmul(): + """Test `matmul`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.matmul(x, y) + expected = MyArray(jax_xp.matmul(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_matrix_transpose(): + """Test `matrix_transpose`.""" + x = MyArray(xp.asarray([[1, 2, 3], [4, 5, 6]], dtype=float)) + got = xp.matrix_transpose(x) + expected = MyArray(jax_xp.matrix_transpose(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_tensordot(): + """Test `tensordot`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + axes = 1 + got = xp.tensordot(x, y, axes=axes) + expected = MyArray(jax_xp.tensordot(x.array, y.array, axes=axes)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_vecdot(): + """Test `vecdot`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.vecdot(x, y) + expected = MyArray(jax_xp.vecdot(x.array, y.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Manipulation functions + + +def test_broadcast_arrays(): + """Test `broadcast_arrays`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4], dtype=float)) + got = xp.broadcast_arrays(x, y) + expected = jax_xp.broadcast_arrays(x.array, y.array) + + assert isinstance(got, tuple | list) + assert len(got) == len(expected) + for got_, expected_ in zip(got, expected, strict=True): + assert isinstance(got_, MyArray) + assert jnp.array_equal(got_.array, expected_) + + +def test_broadcast_to(): + """Test `broadcast_to`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + shape = (2, 3) + got = xp.broadcast_to(x, shape) + expected = MyArray(jax_xp.broadcast_to(x.array, shape)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_concat(): + """Test `concat`.""" + # TODO: test the axis argument + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4], dtype=float)) + got = xp.concat((x, y)) + expected = MyArray(jax_xp.concat((x.array, y.array))) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_expand_dims(): + """Test `expand_dims`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.expand_dims(x, axis=0) + expected = MyArray(jax_xp.expand_dims(x.array, axis=0)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_flip(): + """Test `flip`.""" + x = MyArray(xp.asarray([[1, 2, 3], [4, 5, 6]], dtype=float)) + got = xp.flip(x) + expected = MyArray(jax_xp.flip(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_permute_dims(): + """Test `permute_dims`.""" + x = MyArray(xp.asarray([[1, 2, 3], [4, 5, 6]], dtype=float)) + got = xp.permute_dims(x, (1, 0)) + expected = MyArray(jax_xp.permute_dims(x.array, (1, 0))) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_reshape(): + """Test `reshape`.""" + x = MyArray(xp.asarray([[1, 2, 3], [4, 5, 6]], dtype=float)) + got = xp.reshape(x, (3, 2)) + expected = MyArray(jax_xp.reshape(x.array, (3, 2))) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_roll(): + """Test `roll`.""" + x = MyArray(xp.asarray([[1, 2, 3], [4, 5, 6]], dtype=float)) + got = xp.roll(x, shift=1, axis=0) + expected = MyArray(jax_xp.roll(x.array, shift=1, axis=0)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_squeeze(): + """Test `squeeze`.""" + x = MyArray(xp.asarray([[[0], [1], [2]]], dtype=float)) + got = xp.squeeze(x, axis=(0, 2)) + expected = MyArray(jax_xp.squeeze(x.array, axis=(0, 2))) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_stack(): + """Test `stack`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + y = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.stack((x, y)) + expected = MyArray(jax_xp.stack((x.array, y.array))) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Searching functions + + +def test_argmax(): + """Test `argmax`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.argmax(x) + expected = MyArray(jax_xp.argmax(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_argmin(): + """Test `argmin`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + got = xp.argmin(x) + expected = MyArray(jax_xp.argmin(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +def test_nonzero(): + """Test `nonzero`.""" + x = MyArray(xp.asarray([1, 2, 3], dtype=float)) + (got,) = xp.nonzero(x) + (expected,) = jax_xp.nonzero(x.array) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected) + + +def test_where(): + """Test `where`.""" + condition = MyArray(xp.asarray([True, False, True])) + y = MyArray(xp.asarray([1, 2, 3], dtype=float)) + z = MyArray(xp.asarray([4, 5, 6], dtype=float)) + got = xp.where(condition, y, z) + expected = MyArray(jax_xp.where(condition.array, y.array, z.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Set functions + + +@pytest.mark.xfail(reason="value is not a MyArray") +def test_unique_all(): + """Test `unique_all`.""" + x = MyArray(xp.asarray([1, 2, 2, 3, 3, 3], dtype=float)) + got = xp.unique_all(x) + expected = jax_xp.unique_all(x.array) + + assert isinstance(got, UniqueAllResult) + + assert isinstance(got.values, MyArray) + assert jnp.array_equal(got.values, expected.values) + + assert isinstance(got.inverse, MyArray) + assert jnp.array_equal(got.inverse, expected.inverse) + + assert isinstance(got.inverse_indices, Array) + assert jnp.array_equal(got.inverse_indices, expected.inverse_indices) + + assert isinstance(got.counts, Array) + assert jnp.array_equal(got.counts, expected.counts) + + +@pytest.mark.xfail(reason="value is not a MyArray") +def test_unique_counts(): + """Test `unique_counts`.""" + x = MyArray(xp.asarray([1, 2, 2, 3, 3, 3], dtype=float)) + got = xp.unique_counts(x) + expected = jax_xp.unique_counts(x.array) + + assert isinstance(got, UniqueCountsResult) + + assert isinstance(got.values, MyArray) + assert jnp.array_equal(got.values.array, expected.values) + + assert isinstance(got.counts, Array) + assert jnp.array_equal(got.counts, expected.counts) + + +@pytest.mark.xfail(reason="value is not a MyArray") +def test_unique_inverse(): + """Test `unique_inverse`.""" + x = MyArray(xp.asarray([1, 2, 2, 3, 3, 3], dtype=float)) + got = xp.unique_inverse(x) + expected = jax_xp.unique_inverse(x.array) + + assert isinstance(got, UniqueInverseResult) + + assert isinstance(got.values, MyArray) + assert jnp.array_equal(got.values.array, expected.values) + + assert isinstance(got.inverse, MyArray) + assert jnp.array_equal(got.inverse.array, expected.inverse) + + +@pytest.mark.xfail(reason="value is not a MyArray") +def test_unique_values(): + """Test `unique_values`.""" + x = MyArray(xp.asarray([1, 2, 2, 3, 3, 3], dtype=float)) + got = xp.unique_values(x) + expected = MyArray(jax_xp.unique_values(x.array)) + + assert isinstance(got, MyArray) + assert jnp.array_equal(got.array, expected.array) + + +# ============================================================================= +# Sorting functions + + +@pytest.mark.skip("TODO") +def test_argsort(): + """Test `argsort`.""" + + +@pytest.mark.skip("TODO") +def test_sort(): + """Test `sort`.""" + + +# ============================================================================= +# Statistical functions + + +@pytest.mark.skip("TODO") +def test_max(): + """Test `max`.""" + + +@pytest.mark.skip("TODO") +def test_mean(): + """Test `mean`.""" + + +@pytest.mark.skip("TODO") +def test_min(): + """Test `min`.""" + + +@pytest.mark.skip("TODO") +def test_prod(): + """Test `prod`.""" + + +@pytest.mark.skip("TODO") +def test_std(): + """Test `std`.""" + + +@pytest.mark.skip("TODO") +def test_sum(): + """Test `sum`.""" + + +@pytest.mark.skip("TODO") +def test_var(): + """Test `var`.""" + + +# ============================================================================= +# Utility functions + + +@pytest.mark.skip("TODO") +def test_all(): + """Test `all`.""" + + +@pytest.mark.skip("TODO") +def test_any(): + """Test `any`.""" + + +# ============================================================================= +# FFT + + +@pytest.mark.skip("TODO") +def test_fft_fft(): + """Test `fft.fft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_ifft(): + """Test `fft.ifft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_fftn(): + """Test `fft.fftn`.""" + + +@pytest.mark.skip("TODO") +def test_fft_ifftn(): + """Test `fft.ifftn`.""" + + +@pytest.mark.skip("TODO") +def test_fft_rfft(): + """Test `fft.rfft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_irfft(): + """Test `fft.irfft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_rfftn(): + """Test `fft.rfftn`.""" + + +@pytest.mark.skip("TODO") +def test_fft_irfftn(): + """Test `fft.irfftn`.""" + + +@pytest.mark.skip("TODO") +def test_fft_hfft(): + """Test `fft.hfft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_ihfft(): + """Test `fft.ihfft`.""" + + +@pytest.mark.skip("TODO") +def test_fft_fftfreq(): + """Test `fft.fftfreq`.""" + + +@pytest.mark.skip("TODO") +def test_fft_rfftfreq(): + """Test `fft.rfftfreq`.""" + + +@pytest.mark.skip("TODO") +def test_fft_fftshift(): + """Test `fft.fftshift`.""" + + +@pytest.mark.skip("TODO") +def test_fft_ifftshift(): + """Test `fft.ifftshift`.""" + + +# ============================================================================= +# Linalg + + +@pytest.mark.skip("TODO") +def test_linalg_cholesky(): + """Test `linalg.cholesky`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_cross(): + """Test `linalg.cross`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_det(): + """Test `linalg.det`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_diagonal(): + """Test `linalg.diagonal`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_eigh(): + """Test `linalg.eigh`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_eigvalsh(): + """Test `linalg.eigvalsh`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_inv(): + """Test `linalg.inv`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_matmul(): + """Test `linalg.matmul`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_matrix_norm(): + """Test `linalg.matrix_norm`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_matrix_power(): + """Test `linalg.matrix_power`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_matrix_rank(): + """Test `linalg.matrix_rank`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_matrix_transpose(): + """Test `linalg.matrix_transpose`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_outer(): + """Test `linalg.outer`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_pinv(): + """Test `linalg.pinv`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_qr(): + """Test `linalg.qr`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_slogdet(): + """Test `linalg.slogdet`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_solve(): + """Test `linalg.solve`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_svd(): + """Test `linalg.svd`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_svdvals(): + """Test `linalg.svdvals`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_tensordot(): + """Test `linalg.tensordot`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_trace(): + """Test `linalg.trace`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_vecdot(): + """Test `linalg.vecdot`.""" + + +@pytest.mark.skip("TODO") +def test_linalg_vector_norm(): + """Test `linalg.vector_norm`."""