Merge pull request #539 from team-ocean/dependabot/pip/jax-0.4.17
Bump jax from 0.4.14 to 0.4.17
dionhaefner authored Oct 10, 2023
2 parents 7dee403 + 9060774 commit 00b4cc0
Showing 7 changed files with 119 additions and 67 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test-all.yml
@@ -18,12 +18,12 @@ jobs:

matrix:
os: [ubuntu-20.04]
python-version: ["3.7", "3.10"]
python-version: ["3.8", "3.11"]
backend: [numpy, jax]

env:
PYOM2_DIR: /home/runner/pyom2
PETSC_VERSION: 3.15
PETSC_VERSION: "3.20.0"
PETSC_DIR: /home/runner/petsc
PETSC_ARCH: arch-linux-c-opt
OMPI_MCA_rmaps_base_oversubscribe: "1"
@@ -48,7 +48,7 @@ jobs:
path: ~/.cache/pip
key: ${{ matrix.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements*.txt') }}
restore-keys: |
${{ matrix.os }}-pip-
${{ matrix.os }}-pip-${{ matrix.python-version }}-
- name: Restore PyOM2 build cache
uses: actions/cache@v2
4 changes: 2 additions & 2 deletions .github/workflows/test-install.yml
@@ -20,13 +20,13 @@ jobs:

matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7", "3.10"]
python-version: ["3.8", "3.11"]
nocc: [false]

include:
# also test whether installation without C compiler works
- os: ubuntu-latest
python-version: "3.7"
python-version: "3.8"
nocc: true

runs-on: ${{ matrix.os }}
29 changes: 22 additions & 7 deletions doc/tutorial/erda.rst
@@ -71,15 +71,20 @@ Data Analysis Gateway (DAG)

In order to install Veros on a DAG instance do the following after launching the **Terminal**:

1. Clone the Veros repository
1. Clone the Veros repository:

.. exec::

from veros import __version__ as veros_version
if "0+untagged" in veros_version:
veros_version = "main"
else:
veros_version = f"v{veros_version}"
if "+" in veros_version:
veros_version, _ = veros_version.split("+")
veros_version, _ = veros_version.split("+")
print(".. code-block::\n")
print(f" $ git clone https://github.com/team-ocean/veros.git -b v{veros_version}")
print(" $ cd ~/modi_mount")
print(f" $ git clone https://github.com/team-ocean/veros.git -b {veros_version}")

(or `any other version of Veros <https://github.com/team-ocean/veros/releases>`__).
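
The branch-selection logic in the .. exec:: block above is easy to check in isolation. A minimal sketch, assuming made-up version strings (the helper name branch_for is illustrative and not part of the repository)::

    # Paraphrase of the branch-name logic in the exec block above; the function
    # name and the example version strings are hypothetical.
    def branch_for(version: str) -> str:
        if "0+untagged" in version:
            return "main"                  # untagged development build -> track main
        branch = f"v{version}"
        if "+" in branch:
            branch, _ = branch.split("+")  # drop a local-version suffix, e.g. "+0.g1234abc"
        return branch

    assert branch_for("0+untagged.20.g1234abc") == "main"
    assert branch_for("0.5.3") == "v0.5.3"
    assert branch_for("0.5.3+0.g1234abc") == "v0.5.3"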

@@ -139,14 +144,24 @@ MPI Oriented Development and Investigation (MODI)

In order to install Veros with the `veros-bgc biogeochemistry plugin <https://veros-bgc.readthedocs.io/en/latest/>`__ start an **Ocean HPC Notebook** from the **Jupyter service** home page following :ref:`the instructions above <erda-jupyter>`.

1. Launch the **Terminal**, change your current directory to ~/modi_mount and clone the Veros repository::
1. Launch the **Terminal**, change your current directory to ~/modi_mount and clone the Veros repository:

$ cd ~/modi_mount
$ git clone https://github.com/team-ocean/veros.git -b v0.2.3
.. exec::

from veros import __version__ as veros_version
if "0+untagged" in veros_version:
veros_version = "main"
else:
veros_version = f"v{veros_version}"
if "+" in veros_version:
veros_version, _ = veros_version.split("+")
print(".. code-block::\n")
print(" $ cd ~/modi_mount")
print(f" $ git clone https://github.com/team-ocean/veros.git -b {veros_version}")

2. Create a new conda environment for Veros::

$ conda create --prefix ~/modi_mount/conda-env-veros -y python=3.7
$ conda create --prefix ~/modi_mount/conda-env-veros -y python=3.11

3. To use the new environment, activate it via::

2 changes: 1 addition & 1 deletion requirements_jax.txt
@@ -1 +1 @@
jax==0.4.14
jax==0.4.17
5 changes: 3 additions & 2 deletions setup.py
@@ -22,10 +22,11 @@
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: Implementation :: CPython
Topic :: Scientific/Engineering
Operating System :: Microsoft :: Windows
@@ -158,7 +159,7 @@ def _env_to_bool(envvar):
long_description=long_description,
long_description_content_type="text/markdown",
url="https://veros.readthedocs.io",
python_requires=">=3.7",
python_requires=">=3.8",
version=versioneer.get_version(),
cmdclass=cmdclass,
packages=find_packages(),
24 changes: 18 additions & 6 deletions veros/core/operators.py
@@ -104,18 +104,30 @@ def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask, use_ext=None):
import jax.lax
import jax.numpy as jnp

from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT

if use_ext is None:
use_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or (
try:
from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT
except ImportError:
if use_ext:
raise
has_ext = False
else:
has_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or (
HAS_GPU_EXT and runtime_settings.device == "gpu"
)

if use_ext is None:
if not has_ext:
warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX")
use_ext = False
else:
use_ext = True

if use_ext and not has_ext:
raise RuntimeError("Could not use custom TDMA implementation")

if use_ext:
return tdma(a, b, c, d, water_mask, edge_mask)

warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX")

a = water_mask * a * jnp.logical_not(edge_mask)
b = jnp.where(water_mask, b, 1.0)
c = water_mask * c
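
In summary, the new use_ext handling in solve_tridiagonal_jax auto-detects the compiled TDMA extension when use_ext is None and fails loudly when the extension is requested explicitly but unavailable. A minimal sketch paraphrasing that control flow (resolve_use_ext and has_ext are illustrative names, not part of the module; the real code re-raises the original ImportError when the extension module itself is missing)::

    import warnings

    def resolve_use_ext(use_ext, has_ext):
        # use_ext=None: prefer the compiled TDMA kernel, otherwise warn and
        # fall back to the pure-JAX solver.
        if use_ext is None:
            if not has_ext:
                warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX")
                return False
            return True
        # An explicit use_ext=True no longer falls back silently.
        if use_ext and not has_ext:
            raise RuntimeError("Could not use custom TDMA implementation")
        return use_ext

Callers that leave use_ext at its default keep a fallback path, while an explicit use_ext=True is now a hard requirement.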
116 changes: 70 additions & 46 deletions veros/core/special/tdma_.py
@@ -1,3 +1,4 @@
# defensive imports since extensions are optional
try:
from veros.core.special import tdma_cython_
except ImportError:
@@ -13,15 +14,31 @@
HAS_GPU_EXT = True

import numpy as np
import jax.numpy as jnp

import jax
from jax import abstract_arrays
from jax.core import Primitive
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.lib import xla_client
from jax.interpreters import xla
from jax.interpreters import xla, mlir
import jaxlib.mlir.ir as ir
from jaxlib.mlir.dialects import mhlo

try:
from jax.interpreters.mlir import custom_call # noqa: F401
except ImportError:
# TODO: remove once we require jax > 0.4.16
from jaxlib.hlo_helpers import custom_call as _custom_call

# Recent versions return a structure with a field 'results'. We mock it on
# older versions
from collections import namedtuple

MockResult = namedtuple("MockResult", ["results"])

def custom_call(*args, result_types, **kwargs):
results = _custom_call(*args, out_types=result_types, **kwargs)
return MockResult(results)

_ops = xla_client.ops

if HAS_CPU_EXT:
for kernel_name in (b"tdma_cython_double", b"tdma_cython_float"):
@@ -31,11 +48,16 @@
if HAS_GPU_EXT:
for kernel_name in (b"tdma_cuda_double", b"tdma_cuda_float"):
fn = tdma_cuda_.gpu_custom_call_targets[kernel_name]
xla_client.register_custom_call_target(kernel_name, fn, platform="gpu")
xla_client.register_custom_call_target(kernel_name, fn, platform="CUDA")


def _constant_s64_scalar(c, x):
return _ops.Constant(c, np.int64(x))
def as_mhlo_constant(val, dtype):
if isinstance(val, mhlo.ConstantOp):
return val

return mhlo.ConstantOp(
ir.DenseElementsAttr.get(np.array([val], dtype=dtype), type=mlir.dtype_to_ir_type(np.dtype(dtype)))
).result


def tdma(a, b, c, d, interior_mask, edge_mask, device=None):
@@ -64,20 +86,23 @@ def tdma_impl(*args, **kwargs):
return xla.apply_primitive(tdma_p, *args, **kwargs)


def tdma_xla_encode_cpu(builder, a, b, c, d, system_depths):
def tdma_xla_encode_cpu(ctx, a, b, c, d, system_depths):
# try import again to trigger exception on ImportError
from veros.core.special import tdma_cython_ # noqa: F401

x_shape = builder.GetShape(a)
dtype = x_shape.element_type()
dims = x_shape.dimensions()
x_aval, *_ = ctx.avals_in
np_dtype = x_aval.dtype

x_type = ir.RankedTensorType(a.type)
dtype = x_type.element_type
dims = x_type.shape

supported_dtypes = (
np.dtype(np.float32),
np.dtype(np.float64),
)

if dtype not in supported_dtypes:
if np_dtype not in supported_dtypes:
raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}")

# compute number of elements to vectorize over
@@ -87,56 +112,54 @@ def tdma_xla_encode_cpu(builder, a, b, c, d, system_depths):

stride = dims[-1]

sys_depth_shape = builder.get_shape(system_depths)
sys_depth_dtype = sys_depth_shape.element_type()
sys_depth_dims = sys_depth_shape.dimensions()
assert sys_depth_dtype is np.dtype(np.int32)
assert tuple(sys_depth_dims) == tuple(dims[:-1])
out_types = [
ir.RankedTensorType.get(dims, dtype),
ir.RankedTensorType.get((stride,), dtype),
]

arr_shape = xla_client.Shape.array_shape(dtype, dims)
out_shape = xla_client.Shape.tuple_shape([arr_shape, xla_client.Shape.array_shape(dtype, (stride,))])

if dtype is np.dtype(np.float32):
if np_dtype is np.dtype(np.float32):
kernel = b"tdma_cython_float"
elif dtype is np.dtype(np.float64):
elif np_dtype is np.dtype(np.float64):
kernel = b"tdma_cython_double"
else:
raise RuntimeError("got unrecognized dtype")

out = _ops.CustomCall(
builder,
out = custom_call(
kernel,
operands=(
a,
b,
c,
d,
system_depths,
_constant_s64_scalar(builder, num_systems),
_constant_s64_scalar(builder, stride),
as_mhlo_constant(num_systems, np.int64),
as_mhlo_constant(stride, np.int64),
),
shape=out_shape,
result_types=out_types,
)
return _ops.GetTupleElement(out, 0)
return out.results[:-1]


def tdma_xla_encode_gpu(builder, a, b, c, d, system_depths):
def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
# try import again to trigger exception on ImportError
from veros.core.special import tdma_cuda_ # noqa: F401

if system_depths is not None:
raise ValueError("TDMA does not support system_depths argument on GPU")

a_shape = builder.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
x_aval, *_ = ctx.avals_in
x_nptype = x_aval.dtype

x_type = ir.RankedTensorType(a.type)
dtype = x_type.element_type
dims = x_type.shape

supported_dtypes = (
np.dtype(np.float32),
np.dtype(np.float64),
)

if dtype not in supported_dtypes:
if x_nptype not in supported_dtypes:
raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}")

# compute number of elements to vectorize over
@@ -153,31 +176,32 @@ def tdma_xla_encode_gpu(builder, a, b, c, d, system_depths):
else:
raise RuntimeError("got unrecognized dtype")

opaque = tdma_cuda_.build_tridiag_descriptor(num_systems, system_depth)
descriptor = tdma_cuda_.build_tridiag_descriptor(num_systems, system_depth)

ndims = len(dims)
arr_layout = tuple(range(ndims - 2, -1, -1)) + (ndims - 1,)
arr_shape = xla_client.Shape.array_shape(dtype, dims, arr_layout)
out_shape = xla_client.Shape.tuple_shape([arr_shape, arr_shape])

out = _ops.CustomCallWithLayout(
builder,
out_types = [ir.RankedTensorType.get(dims, dtype), ir.RankedTensorType.get(dims, dtype)]
out_layouts = (arr_layout, arr_layout)

out = custom_call(
kernel,
operands=(a, b, c, d),
shape_with_layout=out_shape,
operand_shapes_with_layout=(arr_shape,) * 4,
opaque=opaque,
result_types=out_types,
result_layouts=out_layouts,
operand_layouts=(arr_layout,) * 4,
backend_config=descriptor,
)
return _ops.GetTupleElement(out, 0)
return out.results[:-1]


def tdma_abstract_eval(a, b, c, d, system_depths):
return abstract_arrays.ShapedArray(a.shape, a.dtype)
return ShapedArray(a.shape, a.dtype)


tdma_p = Primitive("tdma")
tdma_p.def_impl(tdma_impl)
tdma_p.def_abstract_eval(tdma_abstract_eval)

xla.backend_specific_translations["cpu"][tdma_p] = tdma_xla_encode_cpu
xla.backend_specific_translations["gpu"][tdma_p] = tdma_xla_encode_gpu
mlir.register_lowering(tdma_p, tdma_xla_encode_cpu, platform="cpu")
mlir.register_lowering(tdma_p, tdma_xla_encode_gpu, platform="cuda")
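
The two mlir.register_lowering calls above replace the removed xla.backend_specific_translations registrations. A minimal, self-contained sketch of the same registration pattern, using a toy primitive that is not part of Veros and does not touch the TDMA kernels (all names below are hypothetical)::

    # Sketch only: a trivial "double" primitive registered the same way as tdma_p.
    import jax
    import jax.numpy as jnp
    from jax.core import Primitive, ShapedArray
    from jax.interpreters import mlir, xla

    double_p = Primitive("double_sketch")

    def double_impl(x):
        # eager execution outside of jit still goes through the XLA path
        return xla.apply_primitive(double_p, x)

    def double_abstract_eval(x):
        # output has the same shape and dtype as the input
        return ShapedArray(x.shape, x.dtype)

    double_p.def_impl(double_impl)
    double_p.def_abstract_eval(double_abstract_eval)

    # mlir.lower_fun builds an MLIR lowering rule from ordinary jax ops, so this
    # sketch avoids hand-written MLIR; the real rules above instead emit a
    # custom_call into the compiled TDMA kernels, one registration per platform.
    mlir.register_lowering(
        double_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False), platform="cpu"
    )

    print(jax.jit(double_p.bind)(jnp.arange(3.0)))  # [0. 2. 4.]

Registering through mlir.lower_fun keeps the sketch free of hand-written MLIR; the CPU and GPU rules in this file build the custom_call themselves because the kernels live outside of jax.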
