Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return DShapeArray objects in measure primitives when possible #1170

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from

Conversation

rauletorresc
Copy link
Contributor

@rauletorresc rauletorresc commented Oct 1, 2024

Context: Measure primitives only return ShapedArray objects, which forces variables like the number of shots or number of wires to be known at compilation time. With the support for DShapedArray objects, this constraint can be removed.

Description of the Change: Return DShapedArray objects when possible.

Benefits: Eliminating the need to know the number of shots/wires statically is a step towards allowing us to compile programs that don't have a fixed number of qubits.

[sc-74736]

@rauletorresc rauletorresc added the frontend Pull requests that update the frontend label Oct 1, 2024
@rauletorresc rauletorresc requested a review from a team October 1, 2024 23:37
@rauletorresc rauletorresc self-assigned this Oct 1, 2024
@rauletorresc rauletorresc marked this pull request as ready for review October 1, 2024 23:48
frontend/catalyst/jax_primitives.py Outdated Show resolved Hide resolved
frontend/catalyst/jax_primitives.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented Oct 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.96%. Comparing base (2a49e34) to head (04f02fd).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1170      +/-   ##
==========================================
+ Coverage   97.89%   97.96%   +0.07%     
==========================================
  Files          76       76              
  Lines       10879    10885       +6     
  Branches     1289     1292       +3     
==========================================
+ Hits        10650    10664      +14     
+ Misses        178      174       -4     
+ Partials       51       47       -4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@rauletorresc rauletorresc requested a review from a team October 8, 2024 02:53
@rauletorresc rauletorresc changed the title Replace ShapedArray with DShapeArray objects in measure primitives Return DShapeArray objects in measure primitives when possible Oct 8, 2024
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Raul!

assert shape == (shots, obs.num_qubits)
else:
assert shape == (shots,)
if Signature.is_dynamic_shape(shape):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍 We could also keep the assertions on the static case.

Comment on lines +143 to +147
shots, shape_value = jaxpr.eqns[1].params.values()

assert jaxpr.eqns[1].primitive == probs_p
assert isinstance(shape_value[0], DynamicJaxprTracer)
assert shots == 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know what changed that means there are no escaped tracer issues compared to your previous test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the jaxpr for the example here is different than what JAX would generate, and I'm wondering if that isn't going to lead to issues (like escaped tracers).

Our jaxpr, contains tracer in the jaxpr equation:

{ lambda ; a:i64[]. let
    b:AbstractObs(num_qubits=0,primitive=compbasis) = compbasis 
    c:f64[a] d:i64[a] = counts[
      shape=(Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,)
      shots=5
    ] b
  in (c, d) }

JAX's jaxpr for similar example (jnp.ones((dim_0, 5))), does not insert the tracer as constant data into the jaxpr equation:

{ lambda ; a:i64[]. let
    b:f64[a,5] = broadcast_in_dim[broadcast_dimensions=() shape=(None, 5)] 1.0 a
  in (b,) }

Instead we find None in the place where the tracer would be.

Comment on lines +168 to +173
def f(dim_0):
obs = compbasis_p.bind()
return state_p.bind(obs, shots=5, shape=(dim_0,))

jaxpr = jax.make_jaxpr(f)(1)
shots, shape_value = jaxpr.eqns[1].params.values()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need to cover more of the pipeline in our tests (ignoring the MLIR core for now, which can be addressed in a different PR).
From what I can see we test the bind function and abstract evaluation of jaxpr primitives, but that leaves out most of the code that constitutes the frontend. The first step might be to identify which functions could be affected by this change, from the start of tracing until jax spits out the mlir, and then ensure those functions are tested for the dynamic case. Or alternatively develop tests that cover roughly that entire span.

@rauletorresc rauletorresc marked this pull request as draft October 9, 2024 13:50
@rauletorresc
Copy link
Contributor Author

I decided to tackle the MLIR changes in the same PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants