-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Return DShapeArray objects in measure primitives when possible #1170
base: main
Are you sure you want to change the base?
Changes from 16 commits
79b8dc5
4c4b75f
beecfd4
a01831a
d99681c
ecfffa6
0b6ef61
4305870
8c79bd6
f51fb90
8b71383
b4e5e35
cdafde6
de568ab
e1f0449
04f02fd
f28b3e4
cdd7e9f
1b0a193
e5b1315
b72158e
7f36d1a
2307a8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
""" | ||
import jax | ||
|
||
from catalyst.jax_extras import DynamicJaxprTracer | ||
from catalyst.jax_primitives import ( | ||
compbasis_p, | ||
counts_p, | ||
|
@@ -40,6 +41,25 @@ def f(): | |
assert jaxpr.eqns[1].outvars[0].aval.shape == (5, 0) | ||
|
||
|
||
def test_sample_dynamic(): | ||
"""Test that the sample primitive can be captured into jaxpr | ||
when using tracers.""" | ||
|
||
def f(shots): | ||
obs = compbasis_p.bind() | ||
return sample_p.bind(obs, shots=shots, shape=(shots, 0)) | ||
|
||
jaxpr = jax.make_jaxpr(f)(5).jaxpr | ||
shots_tracer, shape_value = jaxpr.eqns[1].params.values() | ||
|
||
assert jaxpr.eqns[1].primitive == sample_p | ||
assert isinstance(shots_tracer, DynamicJaxprTracer) | ||
assert isinstance(shape_value[0], DynamicJaxprTracer) | ||
assert shape_value[1] == 0 | ||
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) | ||
assert jaxpr.eqns[1].outvars[0].aval.shape[1] == 0 | ||
|
||
|
||
def test_counts(): | ||
"""Test that the counts primitive can be captured by jaxpr.""" | ||
|
||
|
@@ -54,6 +74,24 @@ def f(): | |
assert jaxpr.eqns[1].outvars[1].aval.shape == (1,) | ||
|
||
|
||
def test_counts_dynamic(): | ||
"""Test that the counts primitive can be captured by jaxpr | ||
when using tracers.""" | ||
|
||
def f(dim_0): | ||
obs = compbasis_p.bind() | ||
return counts_p.bind(obs, shots=5, shape=(dim_0,)) | ||
|
||
jaxpr = jax.make_jaxpr(f)(1) | ||
shots, shape_value = jaxpr.eqns[1].params.values() | ||
|
||
assert jaxpr.eqns[1].primitive == counts_p | ||
assert isinstance(shape_value[0], DynamicJaxprTracer) | ||
assert shots == 5 | ||
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) | ||
assert isinstance(jaxpr.eqns[1].outvars[1].aval.shape[0], jax.core.Var) | ||
|
||
|
||
def test_expval(): | ||
"""Test that the expval primitive can be captured by jaxpr.""" | ||
|
||
|
@@ -81,7 +119,7 @@ def f(): | |
|
||
|
||
def test_probs(): | ||
"""Test that the var primitive can be captured by jaxpr.""" | ||
"""Test that the probs primitive can be captured by jaxpr.""" | ||
|
||
def f(): | ||
obs = compbasis_p.bind() | ||
|
@@ -93,6 +131,23 @@ def f(): | |
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,) | ||
|
||
|
||
def test_probs_dynamic(): | ||
"""Test that the probs primitive can be captured by jaxpr | ||
when using tracers.""" | ||
|
||
def f(dim_0): | ||
obs = compbasis_p.bind() | ||
return probs_p.bind(obs, shots=5, shape=(dim_0,)) | ||
|
||
jaxpr = jax.make_jaxpr(f)(1) | ||
shots, shape_value = jaxpr.eqns[1].params.values() | ||
|
||
assert jaxpr.eqns[1].primitive == probs_p | ||
assert isinstance(shape_value[0], DynamicJaxprTracer) | ||
assert shots == 5 | ||
Comment on lines
+143
to
+147
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know what changed that means there are no escaped tracer issues compared to your previous test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note the jaxpr for the example here is different than what JAX would generate, and I'm wondering if that isn't going to lead to issues (like escaped tracers). Our jaxpr, contains tracer in the jaxpr equation: { lambda ; a:i64[]. let
b:AbstractObs(num_qubits=0,primitive=compbasis) = compbasis
c:f64[a] d:i64[a] = counts[
shape=(Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,)
shots=5
] b
in (c, d) } JAX's jaxpr for similar example ( { lambda ; a:i64[]. let
b:f64[a,5] = broadcast_in_dim[broadcast_dimensions=() shape=(None, 5)] 1.0 a
in (b,) } Instead we find |
||
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) | ||
|
||
|
||
def test_state(): | ||
"""Test that the state primitive can be captured by jaxpr.""" | ||
|
||
|
@@ -104,3 +159,20 @@ def f(): | |
assert jaxpr.eqns[1].primitive == state_p | ||
assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5} | ||
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,) | ||
|
||
|
||
def test_state_dynamic(): | ||
"""Test that the state primitive can be captured by jaxpr | ||
when using tracers.""" | ||
|
||
def f(dim_0): | ||
obs = compbasis_p.bind() | ||
return state_p.bind(obs, shots=5, shape=(dim_0,)) | ||
|
||
jaxpr = jax.make_jaxpr(f)(1) | ||
shots, shape_value = jaxpr.eqns[1].params.values() | ||
Comment on lines
+168
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we still need to cover more of the pipeline in our tests (ignoring the MLIR core for now, which can be addressed in a different PR). |
||
|
||
assert jaxpr.eqns[1].primitive == state_p | ||
assert isinstance(shape_value[0], DynamicJaxprTracer) | ||
assert shots == 5 | ||
assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice 👍 We could also keep the assertions on the static case.