Skip to content

Commit

Permalink
🐛
Browse files Browse the repository at this point in the history
  • Loading branch information
dionhaefner committed Oct 10, 2023
1 parent b5a7de3 commit 9060774
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions veros/core/special/tdma_.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# defensive imports since extensions are optional
try:
from veros.core.special import tdma_cython_
except ImportError:
Expand All @@ -13,16 +14,31 @@
HAS_GPU_EXT = True

import numpy as np
import jax.numpy as jnp

import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray

from jax.lib import xla_client
from jax.interpreters import xla, mlir
import jaxlib.mlir.ir as ir
from jaxlib.mlir.dialects import mhlo

try:
from jax.interpreters.mlir import custom_call # noqa: F401
except ImportError:
# TODO: remove once we require jax > 0.4.16
from jaxlib.hlo_helpers import custom_call as _custom_call

# Recent versions return a structure with a field 'results'. We mock it on
# older versions
from collections import namedtuple

MockResult = namedtuple("MockResult", ["results"])

def custom_call(*args, result_types, **kwargs):
results = _custom_call(*args, out_types=result_types, **kwargs)
return MockResult(results)


if HAS_CPU_EXT:
for kernel_name in (b"tdma_cython_double", b"tdma_cython_float"):
Expand Down Expand Up @@ -108,7 +124,7 @@ def tdma_xla_encode_cpu(ctx, a, b, c, d, system_depths):
else:
raise RuntimeError("got unrecognized dtype")

out = mlir.custom_call(
out = custom_call(
kernel,
operands=(
a,
Expand Down Expand Up @@ -168,7 +184,7 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
out_types = [ir.RankedTensorType.get(dims, dtype), ir.RankedTensorType.get(dims, dtype)]
out_layouts = (arr_layout, arr_layout)

out = mlir.custom_call(
out = custom_call(
kernel,
operands=(a, b, c, d),
result_tyes=out_types,
Expand Down

0 comments on commit 9060774

Please sign in to comment.