From 13fff0e1f16afce8b2438e9e0a970277bc5311bd Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 2 Oct 2024 15:33:40 -0400 Subject: [PATCH] Add lit test to ensure that set_basis_state is the first operation on the tape --- frontend/catalyst/jax_tracer.py | 2 +- frontend/test/lit/test_skip_initial_state_prep.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 0847a31f7f..3affaa86bb 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -183,7 +183,7 @@ def _eval_jaxpr(*args, **kwargs): # Take care when adding primitives to this set in order to avoid introducing a quadratic number of # edges to the jaxpr equation graph in ``sort_eqns()``. Each equation with a primitive in this set # is constrained to occur before all subsequent equations in the quantum operations trace. -FORCED_ORDER_PRIMITIVES = {qdevice_p, gphase_p} +FORCED_ORDER_PRIMITIVES = {qdevice_p, gphase_p, set_basis_state_p, set_state_p} PAULI_NAMED_MAP = { "I": "Identity", diff --git a/frontend/test/lit/test_skip_initial_state_prep.py b/frontend/test/lit/test_skip_initial_state_prep.py index e2b0eda70a..f74b40c82f 100644 --- a/frontend/test/lit/test_skip_initial_state_prep.py +++ b/frontend/test/lit/test_skip_initial_state_prep.py @@ -71,3 +71,17 @@ def state_prep_example_double(): # CHECK: quantum.set_state # CHECK-NOT: quantum.set_state print(state_prep_example_double.mlir) + +@qml.qjit(target="mlir") +@qml.qnode(qml.device("lightning.qubit", wires=5, shots=100)) +def multiple_basis_embedding(): + """Ok, so we only have one, but we also need to guarantee that it is the first one""" + # CHECK-LABEL: func.func private @multiple_basis_embedding + # CHECK-NOT: quantum.custom + # CHECK: quantum.set_basis_state + qml.BasisEmbedding(2, wires=[0, 1, 2]) + qml.BasisEmbedding(3, wires=[3, 4]) + qml.ctrl(qml.X(0), control=[1, 2]) + return qml.counts(wires=[3, 4]) + +print(multiple_basis_embedding.mlir)