Return DShapeArray objects in measure primitives when possible #1170

Draft · wants to merge 23 commits into base: main

Changes from 16 commits · 23 commits
79b8dc5
Replace ShapedArrays with DShapeArrays in measure primitives
rauletorresc Oct 1, 2024
4c4b75f
Use make_jaxpr2 instead of jax.make_jaxpr
rauletorresc Oct 2, 2024
beecfd4
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 2, 2024
a01831a
Add test
rauletorresc Oct 3, 2024
d99681c
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 3, 2024
ecfffa6
Check if the shape contains a tracer
rauletorresc Oct 3, 2024
0b6ef61
Return a DShapedArray or a ShapedArray according to the context
rauletorresc Oct 3, 2024
4305870
Fix documentation
rauletorresc Oct 3, 2024
8c79bd6
Fix test
rauletorresc Oct 4, 2024
f51fb90
Use public methods
rauletorresc Oct 4, 2024
8b71383
Add tests for probs, state and counts primitives
rauletorresc Oct 8, 2024
b4e5e35
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 8, 2024
cdafde6
Update changelog
rauletorresc Oct 8, 2024
de568ab
[NFC] Fix doc strings
rauletorresc Oct 8, 2024
e1f0449
Change error for assertions
rauletorresc Oct 8, 2024
04f02fd
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 8, 2024
f28b3e4
Modify MLIR lowering of sample primitive
rauletorresc Oct 9, 2024
cdd7e9f
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 10, 2024
1b0a193
Modify bufferization and conversion patterns
rauletorresc Oct 10, 2024
e5b1315
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 10, 2024
b72158e
Merge branch 'main' into raultorres/measures_and_dshapearrays
rauletorresc Oct 15, 2024
7f36d1a
Modify CountsOp at the MLIR level
rauletorresc Oct 16, 2024
2307a8e
Trace shots and extract from tensor
rauletorresc Oct 16, 2024
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -293,6 +293,9 @@
this nested module into other IRs/binary formats and lowering `call_function_in_module`
to something that can dispatch calls to another runtime / VM.

* Measurement primitives now support dynamic shapes at the frontend, although the corresponding
  operations on the PennyLane side still lack such support.
[(#1170)](https://github.com/PennyLaneAI/catalyst/pull/1170)
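The selection this PR implements — a `ShapedArray` for fully static shapes, a `DShapedArray` once any dimension is a tracer — can be sketched without JAX using stand-in classes. All names below are illustrative stand-ins for `jax.core` types, not the Catalyst API:

```python
# Stand-ins for jax.core.Tracer / ShapedArray / DShapedArray, for illustration only.
class Tracer:
    """A dimension whose value is only known at trace time."""

class ShapedArray:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = tuple(shape), dtype

class DShapedArray(ShapedArray):
    """Abstract value whose shape may contain tracers."""

def measurement_aval(shape, dtype):
    # Dynamic if any dimension is a tracer rather than a concrete int.
    if any(isinstance(d, Tracer) for d in shape):
        return DShapedArray(shape, dtype)
    return ShapedArray(shape, dtype)

print(type(measurement_aval((5, 0), "float64")).__name__)         # static shape
print(type(measurement_aval((Tracer(), 0), "float64")).__name__)  # traced dimension
```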

<h3>Contributors</h3>

5 changes: 2 additions & 3 deletions frontend/catalyst/api_extensions/differentiation.py
Expand Up @@ -31,7 +31,7 @@
from pennylane import QNode

import catalyst
from catalyst.jax_extras import Jaxpr
from catalyst.jax_extras import Jaxpr, make_jaxpr2
from catalyst.jax_primitives import (
GradParams,
expval_p,
@@ -800,8 +800,7 @@ def _make_jaxpr_check_differentiable(
return the output tree."""
method = grad_params.method
with mark_gradient_tracing(method):
jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args, **kwargs)
_, out_tree = tree_flatten(shape)
jaxpr, _, out_tree = make_jaxpr2(f)(*args, **kwargs)

for pos, arg in enumerate(jaxpr.in_avals):
if arg.dtype.kind != "f" and pos in grad_params.expanded_argnums:
30 changes: 14 additions & 16 deletions frontend/catalyst/jax_primitives.py
@@ -1633,10 +1633,8 @@ def _hamiltonian_lowering(jax_ctx: mlir.LoweringRuleContext, coeffs: ir.Value, *
def _sample_abstract_eval(obs, shots, shape):
assert isinstance(obs, AbstractObs)

if obs.primitive is compbasis_p:
assert shape == (shots, obs.num_qubits)
else:
assert shape == (shots,)
if Signature.is_dynamic_shape(shape):
Collaborator: Nice 👍 We could also keep the assertions on the static case.

return core.DShapedArray(shape, np.dtype("float64"))

return core.ShapedArray(shape, jax.numpy.float64)

@@ -1670,10 +1668,10 @@ def _counts_def_impl(ctx, obs, shots, shape): # pragma: no cover
def _counts_abstract_eval(obs, shots, shape):
assert isinstance(obs, AbstractObs)

if obs.primitive is compbasis_p:
assert shape == (2**obs.num_qubits,)
else:
assert shape == (2,)
if Signature.is_dynamic_shape(shape):
return core.DShapedArray(shape, np.dtype("float64")), core.DShapedArray(
shape, np.dtype("int64")
)

return core.ShapedArray(shape, jax.numpy.float64), core.ShapedArray(shape, jax.numpy.int64)

@@ -1762,10 +1760,10 @@ def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int,
def _probs_abstract_eval(obs, shape, shots=None):
assert isinstance(obs, AbstractObs)

if obs.primitive is compbasis_p:
assert shape == (2**obs.num_qubits,)
else:
raise TypeError("probs only supports computational basis")
assert obs.primitive is compbasis_p, "probs only supports computational basis"

if Signature.is_dynamic_shape(shape):
return core.DShapedArray(shape, np.dtype("float64"))

return core.ShapedArray(shape, jax.numpy.float64)

@@ -1791,10 +1789,10 @@ def _probs_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shape: tup
def _state_abstract_eval(obs, shape, shots=None):
assert isinstance(obs, AbstractObs)

if obs.primitive is compbasis_p:
assert shape == (2**obs.num_qubits,)
else:
raise TypeError("state only supports computational basis")
assert obs.primitive is compbasis_p, "state only supports computational basis"

if Signature.is_dynamic_shape(shape):
return core.DShapedArray(shape, np.dtype("complex128"))

return core.ShapedArray(shape, jax.numpy.complex128)

36 changes: 29 additions & 7 deletions frontend/catalyst/utils/calculate_grad_shape.py
@@ -17,7 +17,7 @@
"""


from jax.core import ShapedArray
from jax.core import DShapedArray, ShapedArray, Tracer


class Signature:
@@ -76,15 +76,33 @@ def get_results(self):

@staticmethod
def is_tensor(x):
"""Determine whether a type ``x`` is a ``jax.core.ShapedArray``.
"""Determine whether a type ``x`` is a ``jax.core.DShapedArray``
or ``jax.core.ShapedArray``.

Args:
x: The type to be tested.

Returns:
bool: Whether the type ``x`` is a ``jax.core.ShapedArray``
bool: Whether the type ``x`` is a ``jax.core.DShapedArray``
or ``jax.core.ShapedArray``
"""
return isinstance(x, ShapedArray)
return isinstance(x, (DShapedArray, ShapedArray))

@staticmethod
def is_dynamic_shape(shape):
"""Determine whether a shape contains a tracer or not.

Args:
shape: The shape to be tested.

Returns:
bool: Whether the shape contains a tracer or not.
"""
for s in shape:
if isinstance(s, Tracer):
return True

return False
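The loop in `is_dynamic_shape` is equivalent to a single `any(...)` expression. A minimal JAX-free sketch, with a stand-in class in place of `jax.core.Tracer`:

```python
class Tracer:  # stand-in for jax.core.Tracer, for illustration only
    pass

def is_dynamic_shape(shape):
    """True when any dimension of `shape` is a tracer rather than a concrete int."""
    return any(isinstance(s, Tracer) for s in shape)

assert not is_dynamic_shape((5, 0))     # all-static shape
assert is_dynamic_shape((Tracer(), 0))  # one traced dimension makes it dynamic
assert not is_dynamic_shape(())         # empty (scalar) shape is static
```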

def __eq__(self, other):
return self.xs == other.xs and self.ys == other.ys
@@ -123,8 +141,12 @@ def calculate_grad_shape(signature, indices) -> Signature:
grad_res_shape.append(axis)
element_type = diff_arg_type.dtype

grad_res_type = (
ShapedArray(grad_res_shape, element_type) if grad_res_shape else diff_arg_type
)
grad_res_type = diff_arg_type
if grad_res_shape:
XShapedArray = (
DShapedArray if Signature.is_dynamic_shape(grad_res_shape) else ShapedArray
)
grad_res_type = XShapedArray(grad_res_shape, element_type)

grad_result_types.append(grad_res_type)
return Signature(signature.get_inputs(), grad_result_types)
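The class-selection pattern used for the gradient result type can be exercised in isolation. A sketch under the same stand-in assumptions as above (`Tracer`, `ShapedArray`, `DShapedArray`, and `grad_result_type` are illustrative names, not the Catalyst API):

```python
# Stand-ins for jax.core types, for illustration only.
class Tracer: ...
class ShapedArray:
    def __init__(self, shape, dtype): self.shape, self.dtype = tuple(shape), dtype
class DShapedArray(ShapedArray): ...

def grad_result_type(grad_res_shape, diff_arg_type, element_type):
    """Pick DShapedArray when the gradient shape contains a tracer, else ShapedArray;
    fall back to the argument's own type for an empty shape."""
    if not grad_res_shape:
        return diff_arg_type
    cls = DShapedArray if any(isinstance(d, Tracer) for d in grad_res_shape) else ShapedArray
    return cls(grad_res_shape, element_type)

scalar = ShapedArray((), "float64")
assert grad_result_type([], scalar, "float64") is scalar
assert type(grad_result_type([3, 2], scalar, "float64")) is ShapedArray
assert type(grad_result_type([Tracer(), 2], scalar, "float64")) is DShapedArray
```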
74 changes: 73 additions & 1 deletion frontend/test/pytest/test_measurement_primitives.py
@@ -16,6 +16,7 @@
"""
import jax

from catalyst.jax_extras import DynamicJaxprTracer
from catalyst.jax_primitives import (
compbasis_p,
counts_p,
@@ -40,6 +41,25 @@ def f():
assert jaxpr.eqns[1].outvars[0].aval.shape == (5, 0)


def test_sample_dynamic():
"""Test that the sample primitive can be captured into jaxpr
when using tracers."""

def f(shots):
obs = compbasis_p.bind()
return sample_p.bind(obs, shots=shots, shape=(shots, 0))

jaxpr = jax.make_jaxpr(f)(5).jaxpr
shots_tracer, shape_value = jaxpr.eqns[1].params.values()

assert jaxpr.eqns[1].primitive == sample_p
assert isinstance(shots_tracer, DynamicJaxprTracer)
assert isinstance(shape_value[0], DynamicJaxprTracer)
assert shape_value[1] == 0
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
assert jaxpr.eqns[1].outvars[0].aval.shape[1] == 0


def test_counts():
"""Test that the counts primitive can be captured by jaxpr."""

@@ -54,6 +74,24 @@ def f():
assert jaxpr.eqns[1].outvars[1].aval.shape == (1,)


def test_counts_dynamic():
"""Test that the counts primitive can be captured by jaxpr
when using tracers."""

def f(dim_0):
obs = compbasis_p.bind()
return counts_p.bind(obs, shots=5, shape=(dim_0,))

jaxpr = jax.make_jaxpr(f)(1)
shots, shape_value = jaxpr.eqns[1].params.values()

assert jaxpr.eqns[1].primitive == counts_p
assert isinstance(shape_value[0], DynamicJaxprTracer)
assert shots == 5
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
assert isinstance(jaxpr.eqns[1].outvars[1].aval.shape[0], jax.core.Var)


def test_expval():
"""Test that the expval primitive can be captured by jaxpr."""

@@ -81,7 +119,7 @@ def f():


def test_probs():
"""Test that the var primitive can be captured by jaxpr."""
"""Test that the probs primitive can be captured by jaxpr."""

def f():
obs = compbasis_p.bind()
@@ -93,6 +131,23 @@ def f():
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,)


def test_probs_dynamic():
"""Test that the probs primitive can be captured by jaxpr
when using tracers."""

def f(dim_0):
obs = compbasis_p.bind()
return probs_p.bind(obs, shots=5, shape=(dim_0,))

jaxpr = jax.make_jaxpr(f)(1)
shots, shape_value = jaxpr.eqns[1].params.values()

assert jaxpr.eqns[1].primitive == probs_p
assert isinstance(shape_value[0], DynamicJaxprTracer)
assert shots == 5
Collaborator comment on lines +143 to +147: Do you know what changed that means there are no escaped tracer issues compared to your previous test?

Collaborator: Note the jaxpr for the example here is different than what JAX would generate, and I'm wondering if that isn't going to lead to issues (like escaped tracers).

Our jaxpr contains the tracer in the jaxpr equation:

{ lambda ; a:i64[]. let
    b:AbstractObs(num_qubits=0,primitive=compbasis) = compbasis 
    c:f64[a] d:i64[a] = counts[
      shape=(Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,)
      shots=5
    ] b
  in (c, d) }

JAX's jaxpr for a similar example (jnp.ones((dim_0, 5))) does not insert the tracer as constant data into the jaxpr equation:

{ lambda ; a:i64[]. let
    b:f64[a,5] = broadcast_in_dim[broadcast_dimensions=() shape=(None, 5)] 1.0 a
  in (b,) }

Instead we find None in the place where the tracer would be.

assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)


def test_state():
"""Test that the state primitive can be captured by jaxpr."""

@@ -104,3 +159,20 @@ def f():
assert jaxpr.eqns[1].primitive == state_p
assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5}
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,)


def test_state_dynamic():
"""Test that the state primitive can be captured by jaxpr
when using tracers."""

def f(dim_0):
obs = compbasis_p.bind()
return state_p.bind(obs, shots=5, shape=(dim_0,))

jaxpr = jax.make_jaxpr(f)(1)
shots, shape_value = jaxpr.eqns[1].params.values()
Collaborator comment on lines +168 to +173: I think we still need to cover more of the pipeline in our tests (ignoring the MLIR core for now, which can be addressed in a different PR). From what I can see we test the bind function and abstract evaluation of jaxpr primitives, but that leaves out most of the code that constitutes the frontend. The first step might be to identify which functions could be affected by this change, from the start of tracing until jax spits out the mlir, and then ensure those functions are tested for the dynamic case. Or alternatively develop tests that cover roughly that entire span.


assert jaxpr.eqns[1].primitive == state_p
assert isinstance(shape_value[0], DynamicJaxprTracer)
assert shots == 5
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)