From 5494ba3dc8e5e058cd9755018afe48fa49a286e8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 23:23:57 +0000 Subject: [PATCH 01/12] Bump jax from 0.4.14 to 0.4.17 Bumps [jax](https://github.com/google/jax) from 0.4.14 to 0.4.17. - [Release notes](https://github.com/google/jax/releases) - [Changelog](https://github.com/google/jax/blob/main/CHANGELOG.md) - [Commits](https://github.com/google/jax/compare/jax-v0.4.14...jax-v0.4.17) --- updated-dependencies: - dependency-name: jax dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- requirements_jax.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_jax.txt b/requirements_jax.txt index 76f9015e..48feeed9 100644 --- a/requirements_jax.txt +++ b/requirements_jax.txt @@ -1 +1 @@ -jax==0.4.14 +jax==0.4.17 From 00690d608d6bf373bcc92fc93b5411ceeb0ad371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Thu, 5 Oct 2023 10:17:05 +0200 Subject: [PATCH 02/12] graceful fallback if custom calls are broken --- veros/core/operators.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/veros/core/operators.py b/veros/core/operators.py index 216501ff..5e626aea 100644 --- a/veros/core/operators.py +++ b/veros/core/operators.py @@ -104,7 +104,12 @@ def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask, use_ext=None): import jax.lax import jax.numpy as jnp - from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT + try: + from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT + except ImportError: + # graceful fallback if TDMA extension is broken + HAS_CPU_EXT = False + HAS_GPU_EXT = False if use_ext is None: use_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or ( From 4abe3351341a3c74dee8378d49e31399a281194d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 10:34:18 +0200 Subject: [PATCH 03/12] fix jax custom calls --- veros/core/operators.py | 23 ++++++--- veros/core/special/tdma_.py | 98 ++++++++++++++++++++----------------- 2 files changed, 68 insertions(+), 53 deletions(-) diff --git a/veros/core/operators.py b/veros/core/operators.py index 5e626aea..9dd826c5 100644 --- a/veros/core/operators.py +++ b/veros/core/operators.py @@ -107,20 +107,27 @@ def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask, use_ext=None): try: from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT except ImportError: - # graceful fallback if TDMA extension is broken - HAS_CPU_EXT = False - HAS_GPU_EXT = False - - if use_ext is None: - use_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or ( + if use_ext: + raise + has_ext = False + else: + has_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or ( HAS_GPU_EXT and runtime_settings.device == "gpu" ) + if use_ext is None: + if not has_ext: + warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX") + use_ext = False + else: + use_ext = True + + if use_ext and not has_ext: + raise RuntimeError("Could not use custom TDMA implementation") + if use_ext: return tdma(a, b, c, d, water_mask, edge_mask) - warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX") - a = water_mask * a * jnp.logical_not(edge_mask) b = jnp.where(water_mask, b, 1.0) c = water_mask * c diff --git a/veros/core/special/tdma_.py b/veros/core/special/tdma_.py index c5898314..c8f413ca 100644 --- a/veros/core/special/tdma_.py +++ b/veros/core/special/tdma_.py @@ -16,12 +16,13 @@ import jax.numpy as jnp import jax -from jax import abstract_arrays -from jax.core import Primitive +from jax.core import Primitive, ShapedArray + from jax.lib import xla_client -from jax.interpreters import xla +from jax.interpreters import xla, mlir +import jaxlib.mlir.ir as ir +from jaxlib.mlir.dialects import mhlo -_ops = xla_client.ops if HAS_CPU_EXT: for kernel_name in (b"tdma_cython_double", b"tdma_cython_float"): @@ -31,11 +32,16 @@ if HAS_GPU_EXT: for kernel_name in (b"tdma_cuda_double", b"tdma_cuda_float"): fn = tdma_cuda_.gpu_custom_call_targets[kernel_name] - xla_client.register_custom_call_target(kernel_name, fn, platform="gpu") + xla_client.register_custom_call_target(kernel_name, fn, platform="CUDA") + +def as_mhlo_constant(val, dtype): + if isinstance(val, mhlo.ConstantOp): + return val -def _constant_s64_scalar(c, x): - return _ops.Constant(c, np.int64(x)) + return mhlo.ConstantOp( + ir.DenseElementsAttr.get(np.array([val], dtype=dtype), type=mlir.dtype_to_ir_type(np.dtype(dtype))) + ).result def tdma(a, b, c, d, interior_mask, edge_mask, device=None): @@ -64,20 +70,23 @@ def tdma_impl(*args, **kwargs): return xla.apply_primitive(tdma_p, *args, **kwargs) -def tdma_xla_encode_cpu(builder, a, b, c, d, system_depths): +def tdma_xla_encode_cpu(ctx, a, b, c, d, system_depths): # try import again to trigger exception on ImportError from veros.core.special import tdma_cython_ # noqa: F401 - x_shape = builder.GetShape(a) - dtype = x_shape.element_type() - dims = x_shape.dimensions() + x_aval, *_ = ctx.avals_in + np_dtype = x_aval.dtype + + x_type = ir.RankedTensorType(a.type) + dtype = x_type.element_type + dims = x_type.shape supported_dtypes = ( np.dtype(np.float32), np.dtype(np.float64), ) - if dtype not in supported_dtypes: + if np_dtype not in supported_dtypes: raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}") # compute number of elements to vectorize over @@ -87,24 +96,19 @@ def tdma_xla_encode_cpu(builder, a, b, c, d, system_depths): stride = dims[-1] - sys_depth_shape = builder.get_shape(system_depths) - sys_depth_dtype = sys_depth_shape.element_type() - sys_depth_dims = sys_depth_shape.dimensions() - assert sys_depth_dtype is np.dtype(np.int32) - assert tuple(sys_depth_dims) == tuple(dims[:-1]) - - arr_shape = xla_client.Shape.array_shape(dtype, dims) - out_shape = xla_client.Shape.tuple_shape([arr_shape, xla_client.Shape.array_shape(dtype, (stride,))]) + out_types = [ + ir.RankedTensorType.get(dims, dtype), + ir.RankedTensorType.get((stride,), dtype), + ] - if dtype is np.dtype(np.float32): + if np_dtype is np.dtype(np.float32): kernel = b"tdma_cython_float" - elif dtype is np.dtype(np.float64): + elif np_dtype is np.dtype(np.float64): kernel = b"tdma_cython_double" else: raise RuntimeError("got unrecognized dtype") - out = _ops.CustomCall( - builder, + out = mlir.custom_call( kernel, operands=( a, @@ -112,31 +116,34 @@ def tdma_xla_encode_cpu(builder, a, b, c, d, system_depths): c, d, system_depths, - _constant_s64_scalar(builder, num_systems), - _constant_s64_scalar(builder, stride), + as_mhlo_constant(num_systems, np.int64), + as_mhlo_constant(stride, np.int64), ), - shape=out_shape, + result_types=out_types, ) - return _ops.GetTupleElement(out, 0) + return out.results[:-1] -def tdma_xla_encode_gpu(builder, a, b, c, d, system_depths): +def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths): # try import again to trigger exception on ImportError from veros.core.special import tdma_cuda_ # noqa: F401 if system_depths is not None: raise ValueError("TDMA does not support system_depths argument on GPU") - a_shape = builder.get_shape(a) - dtype = a_shape.element_type() - dims = a_shape.dimensions() + x_aval, *_ = ctx.avals_in + x_nptype = x_aval.dtype + + x_type = ir.RankedTensorType(a.type) + dtype = x_type.element_type + dims = x_type.shape supported_dtypes = ( np.dtype(np.float32), np.dtype(np.float64), ) - if dtype not in supported_dtypes: + if x_nptype not in supported_dtypes: raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}") # compute number of elements to vectorize over @@ -153,31 +160,32 @@ def tdma_xla_encode_gpu(builder, a, b, c, d, system_depths): else: raise RuntimeError("got unrecognized dtype") - opaque = tdma_cuda_.build_tridiag_descriptor(num_systems, system_depth) + descriptor = tdma_cuda_.build_tridiag_descriptor(num_systems, system_depth) ndims = len(dims) arr_layout = tuple(range(ndims - 2, -1, -1)) + (ndims - 1,) - arr_shape = xla_client.Shape.array_shape(dtype, dims, arr_layout) - out_shape = xla_client.Shape.tuple_shape([arr_shape, arr_shape]) - out = _ops.CustomCallWithLayout( - builder, + out_types = [ir.RankedTensorType.get(dims, dtype), ir.RankedTensorType.get(dims, dtype)] + out_layouts = (arr_layout, arr_layout) + + out = mlir.custom_call( kernel, operands=(a, b, c, d), - shape_with_layout=out_shape, - operand_shapes_with_layout=(arr_shape,) * 4, - opaque=opaque, + result_tyes=out_types, + result_layouts=out_layouts, + operand_layouts=(arr_layout,) * 4, + backend_config=descriptor, ) - return _ops.GetTupleElement(out, 0) + return out.results[:-1] def tdma_abstract_eval(a, b, c, d, system_depths): - return abstract_arrays.ShapedArray(a.shape, a.dtype) + return ShapedArray(a.shape, a.dtype) tdma_p = Primitive("tdma") tdma_p.def_impl(tdma_impl) tdma_p.def_abstract_eval(tdma_abstract_eval) -xla.backend_specific_translations["cpu"][tdma_p] = tdma_xla_encode_cpu -xla.backend_specific_translations["gpu"][tdma_p] = tdma_xla_encode_gpu +mlir.register_lowering(tdma_p, tdma_xla_encode_cpu, platform="cpu") +mlir.register_lowering(tdma_p, tdma_xla_encode_gpu, platform="cuda") From f6b5ce8464982adb9ff8914a52c0b9d9bf1f0036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:08:20 +0200 Subject: [PATCH 04/12] drop python 3.7 --- .github/workflows/test-all.yml | 2 +- .github/workflows/test-install.yml | 4 ++-- doc/tutorial/erda.rst | 16 +++++++++++++--- setup.py | 5 +++-- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 8a50a7ae..a789a089 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -18,7 +18,7 @@ jobs: matrix: os: [ubuntu-20.04] - python-version: ["3.7", "3.10"] + python-version: ["3.8", "3.11"] backend: [numpy, jax] env: diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index a5799e2d..993ca562 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -20,13 +20,13 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.7", "3.10"] + python-version: ["3.8", "3.11"] nocc: [false] include: # also test whether installation without C compiler works - os: ubuntu-latest - python-version: "3.7" + python-version: "3.8" nocc: true runs-on: ${{ matrix.os }} diff --git a/doc/tutorial/erda.rst b/doc/tutorial/erda.rst index 39635529..4562f975 100644 --- a/doc/tutorial/erda.rst +++ b/doc/tutorial/erda.rst @@ -141,12 +141,22 @@ In order to install Veros with the `veros-bgc biogeochemistry plugin Date: Tue, 10 Oct 2023 11:26:06 +0200 Subject: [PATCH 05/12] :bug: --- .github/workflows/test-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index a789a089..0aa815c9 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -48,7 +48,7 @@ jobs: path: ~/.cache/pip key: ${{ matrix.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements*.txt') }} restore-keys: | - ${{ matrix.os }}-pip- + ${{ matrix.os }}-pip-${{ matrix.python-version }}- - name: Restore PyOM2 build cache uses: actions/cache@v2 From e24e1c562efb62e1dd63b39238a8bed982c96137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:31:39 +0200 Subject: [PATCH 06/12] :bug: --- .github/workflows/test-all.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 0aa815c9..2db7b6d6 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -17,13 +17,13 @@ jobs: fail-fast: false matrix: - os: [ubuntu-20.04] + os: [ubuntu-22.04] python-version: ["3.8", "3.11"] backend: [numpy, jax] env: PYOM2_DIR: /home/runner/pyom2 - PETSC_VERSION: 3.15 + PETSC_VERSION: 3.20 PETSC_DIR: /home/runner/petsc PETSC_ARCH: arch-linux-c-opt OMPI_MCA_rmaps_base_oversubscribe: "1" From cba67614fc6199c1a242b40de2395e70b1066cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:34:16 +0200 Subject: [PATCH 07/12] :bug: --- .github/workflows/test-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 2db7b6d6..87476d63 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -80,7 +80,7 @@ jobs: run: | git clone -b v$PETSC_VERSION --depth 1 https://gitlab.com/petsc/petsc.git $PETSC_DIR pushd $PETSC_DIR - ./configure --with-debugging=0 -with-shared-libraries --with-precision=double + python2 ./configure --with-debugging=0 -with-shared-libraries --with-precision=double make all popd From 08d78f6429214e6e4967f3f901090d0c3a78e5b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:37:22 +0200 Subject: [PATCH 08/12] :bug: --- .github/workflows/test-all.yml | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 87476d63..84b0c1bb 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -57,13 +57,6 @@ jobs: path: ${{ env.PYOM2_DIR }}/py_src/*.so key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('vendor/pyom2/patches/*.patch') }}-${{ hashFiles('requirements.txt') }} - - name: Restore PETSc build cache - uses: actions/cache@v2 - id: petsc-cache - with: - path: ${{ env.PETSC_DIR }} - key: ${{ matrix.os }}-${{ env.PETSC_VERSION }} - - name: Setup Python environment run: | python -m pip install --upgrade pip @@ -75,15 +68,6 @@ jobs: sudo apt-get update sudo apt-get install libopenmpi-dev - - name: Build PETSc - if: steps.petsc-cache.outputs.cache-hit != 'true' - run: | - git clone -b v$PETSC_VERSION --depth 1 https://gitlab.com/petsc/petsc.git $PETSC_DIR - pushd $PETSC_DIR - python2 ./configure --with-debugging=0 -with-shared-libraries --with-precision=double - make all - popd - - name: Install Veros run: | pip install mpi4py @@ -94,7 +78,7 @@ jobs: else pip install -e .[test] fi - pip install petsc4py==$PETSC_VERSION --no-deps + pip install petsc==$PETSC_VERSION petsc4py==$PETSC_VERSION --no-deps # Build PyOM2 after Veros to make sure we have compatible versions of NumPy / f2py - name: Build PyOM2 From 91176b5fcceab567f3ea038914b708eed4e495e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:41:58 +0200 Subject: [PATCH 09/12] :bug: --- .github/workflows/test-all.yml | 20 ++++++++++++++++++-- doc/tutorial/erda.rst | 15 ++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 84b0c1bb..c084d6e1 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -23,7 +23,7 @@ jobs: env: PYOM2_DIR: /home/runner/pyom2 - PETSC_VERSION: 3.20 + PETSC_VERSION: "3.20" PETSC_DIR: /home/runner/petsc PETSC_ARCH: arch-linux-c-opt OMPI_MCA_rmaps_base_oversubscribe: "1" @@ -57,6 +57,13 @@ jobs: path: ${{ env.PYOM2_DIR }}/py_src/*.so key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('vendor/pyom2/patches/*.patch') }}-${{ hashFiles('requirements.txt') }} + - name: Restore PETSc build cache + uses: actions/cache@v2 + id: petsc-cache + with: + path: ${{ env.PETSC_DIR }} + key: ${{ matrix.os }}-${{ env.PETSC_VERSION }} + - name: Setup Python environment run: | python -m pip install --upgrade pip @@ -68,6 +75,15 @@ jobs: sudo apt-get update sudo apt-get install libopenmpi-dev + - name: Build PETSc + if: steps.petsc-cache.outputs.cache-hit != 'true' + run: | + git clone -b v$PETSC_VERSION --depth 1 https://gitlab.com/petsc/petsc.git $PETSC_DIR + pushd $PETSC_DIR + ./configure --with-debugging=0 -with-shared-libraries --with-precision=double + make all + popd + - name: Install Veros run: | pip install mpi4py @@ -78,7 +94,7 @@ jobs: else pip install -e .[test] fi - pip install petsc==$PETSC_VERSION petsc4py==$PETSC_VERSION --no-deps + pip install petsc4py==$PETSC_VERSION --no-deps # Build PyOM2 after Veros to make sure we have compatible versions of NumPy / f2py - name: Build PyOM2 diff --git a/doc/tutorial/erda.rst b/doc/tutorial/erda.rst index 4562f975..2133a3ac 100644 --- a/doc/tutorial/erda.rst +++ b/doc/tutorial/erda.rst @@ -71,15 +71,20 @@ Data Analysis Gateway (DAG) In order to install Veros on a DAG instance do the following after launching the **Terminal**: -1. Clone the Veros repository +1. Clone the Veros repository: .. exec:: from veros import __version__ as veros_version + if "0+untagged" in veros_version: + veros_version = "main" + else: + veros_version = f"v{veros_version}" if "+" in veros_version: - veros_version, _ = veros_version.split("+") + veros_version, _ = veros_version.split("+") print(".. code-block::\n") - print(f" $ git clone https://github.com/team-ocean/veros.git -b v{veros_version}") + print(" $ cd ~/modi_mount") + print(f" $ git clone https://github.com/team-ocean/veros.git -b {veros_version}") (or `any other version of Veros `__). @@ -139,9 +144,9 @@ MPI Oriented Development and Investigation (MODI) In order to install Veros with the `veros-bgc biogeochemistry plugin `__ start an **Ocean HPC Notebook** from the **Jupyter service** home page following :ref:`the instructions above `. -1. Launch the **Terminal**, change your current directory to ~/modi_mount and clone the Veros repository:: +1. Launch the **Terminal**, change your current directory to ~/modi_mount and clone the Veros repository: -.. exec:: + .. exec:: from veros import __version__ as veros_version if "0+untagged" in veros_version: From 88038c694d4ff3ac814a75185e7c673f57a8e1f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 11:43:56 +0200 Subject: [PATCH 10/12] :bug: --- .github/workflows/test-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index c084d6e1..8231e1dd 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -23,7 +23,7 @@ jobs: env: PYOM2_DIR: /home/runner/pyom2 - PETSC_VERSION: "3.20" + PETSC_VERSION: "3.20.0" PETSC_DIR: /home/runner/petsc PETSC_ARCH: arch-linux-c-opt OMPI_MCA_rmaps_base_oversubscribe: "1" From b5a7de3ddab49b591eee906f2eaca855d330e745 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 12:06:38 +0200 Subject: [PATCH 11/12] :bug: --- .github/workflows/test-all.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-all.yml b/.github/workflows/test-all.yml index 8231e1dd..2f9d93a9 100644 --- a/.github/workflows/test-all.yml +++ b/.github/workflows/test-all.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: - os: [ubuntu-22.04] + os: [ubuntu-20.04] python-version: ["3.8", "3.11"] backend: [numpy, jax] From 906077431f02f636ce34b9d62c2495dd7373b432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Tue, 10 Oct 2023 13:24:03 +0200 Subject: [PATCH 12/12] :bug: --- veros/core/special/tdma_.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/veros/core/special/tdma_.py b/veros/core/special/tdma_.py index c8f413ca..3b6f12ab 100644 --- a/veros/core/special/tdma_.py +++ b/veros/core/special/tdma_.py @@ -1,3 +1,4 @@ +# defensive imports since extensions are optional try: from veros.core.special import tdma_cython_ except ImportError: @@ -13,16 +14,31 @@ HAS_GPU_EXT = True import numpy as np -import jax.numpy as jnp import jax +import jax.numpy as jnp from jax.core import Primitive, ShapedArray - from jax.lib import xla_client from jax.interpreters import xla, mlir import jaxlib.mlir.ir as ir from jaxlib.mlir.dialects import mhlo +try: + from jax.interpreters.mlir import custom_call # noqa: F401 +except ImportError: + # TODO: remove once we require jax > 0.4.16 + from jaxlib.hlo_helpers import custom_call as _custom_call + + # Recent versions return a structure with a field 'results'. We mock it on + # older versions + from collections import namedtuple + + MockResult = namedtuple("MockResult", ["results"]) + + def custom_call(*args, result_types, **kwargs): + results = _custom_call(*args, out_types=result_types, **kwargs) + return MockResult(results) + if HAS_CPU_EXT: for kernel_name in (b"tdma_cython_double", b"tdma_cython_float"): @@ -108,7 +124,7 @@ def tdma_xla_encode_cpu(ctx, a, b, c, d, system_depths): else: raise RuntimeError("got unrecognized dtype") - out = mlir.custom_call( + out = custom_call( kernel, operands=( a, @@ -168,7 +184,7 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths): out_types = [ir.RankedTensorType.get(dims, dtype), ir.RankedTensorType.get(dims, dtype)] out_layouts = (arr_layout, arr_layout) - out = mlir.custom_call( + out = custom_call( kernel, operands=(a, b, c, d), result_tyes=out_types,