diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2083694c0b..86f2e10454 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -333,6 +333,9 @@ this nested module into other IRs/binary formats and lowering `call_function_in_module` to something that can dispatch calls to another runtime / VM. +* Measurement primitives now support dynamic shape at the frontend, although at the PennyLane + side, the corresponding operations still lack such support. + [(#1170)](https://github.com/PennyLaneAI/catalyst/pull/1170)

Contributors

diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 7f66b38348..e6801229ab 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -30,7 +30,7 @@ from pennylane import QNode import catalyst -from catalyst.jax_extras import Jaxpr +from catalyst.jax_extras import Jaxpr, make_jaxpr2 from catalyst.jax_primitives import ( GradParams, expval_p, @@ -807,8 +807,7 @@ def _make_jaxpr_check_differentiable( return the output tree.""" method = grad_params.method with mark_gradient_tracing(method): - jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args, **kwargs) - _, out_tree = tree_flatten(shape) + jaxpr, _, out_tree = make_jaxpr2(f)(*args, **kwargs) for pos, arg in enumerate(jaxpr.in_avals): if arg.dtype.kind != "f" and pos in grad_params.expanded_argnums: diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 03a68697be..9ba097a0b4 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -1658,10 +1658,8 @@ def _hamiltonian_lowering(jax_ctx: mlir.LoweringRuleContext, coeffs: ir.Value, * def _sample_abstract_eval(obs, shots, shape): assert isinstance(obs, AbstractObs) - if obs.primitive is compbasis_p: - assert shape == (shots, obs.num_qubits) - else: - assert shape == (shots,) + if Signature.is_dynamic_shape(shape): + return core.DShapedArray(shape, np.dtype("float64")) return core.ShapedArray(shape, jax.numpy.float64) @@ -1671,16 +1669,18 @@ def _sample_def_impl(ctx, obs, shots, shape): # pragma: no cover raise NotImplementedError() -def _sample_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple): +def _sample_lowering( + jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: ir.Value, shape: tuple +): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) + shots_val = TensorExtractOp(i64_type, shots, []).result f64_type = ir.F64Type.get() result_type = ir.RankedTensorType.get(shape, f64_type) - return SampleOp(result_type, obs, shots_attr).results + return SampleOp(result_type, obs, shots_val).results # @@ -1695,25 +1695,27 @@ def _counts_def_impl(ctx, obs, shots, shape): # pragma: no cover def _counts_abstract_eval(obs, shots, shape): assert isinstance(obs, AbstractObs) - if obs.primitive is compbasis_p: - assert shape == (2**obs.num_qubits,) - else: - assert shape == (2,) + if Signature.is_dynamic_shape(shape): + return core.DShapedArray(shape, np.dtype("float64")), core.DShapedArray( + shape, np.dtype("int64") + ) return core.ShapedArray(shape, jax.numpy.float64), core.ShapedArray(shape, jax.numpy.int64) -def _counts_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple): +def _counts_lowering( + jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: ir.Value, shape: tuple +): ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True i64_type = ir.IntegerType.get_signless(64, ctx) - shots_attr = ir.IntegerAttr.get(i64_type, shots) + shots_val = TensorExtractOp(i64_type, shots, []).result f64_type = ir.F64Type.get() eigvals_type = ir.RankedTensorType.get(shape, f64_type) counts_type = ir.RankedTensorType.get(shape, i64_type) - return CountsOp(eigvals_type, counts_type, obs, shots_attr).results + return CountsOp(eigvals_type, counts_type, obs, shots_val).results # @@ -1787,10 +1789,10 @@ def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, def _probs_abstract_eval(obs, shape, shots=None): assert isinstance(obs, AbstractObs) - if obs.primitive is compbasis_p: - assert shape == (2**obs.num_qubits,) - else: - raise TypeError("probs only supports computational basis") + assert obs.primitive is compbasis_p, "probs only supports computational basis" + + if Signature.is_dynamic_shape(shape): + return core.DShapedArray(shape, np.dtype("float64")) return core.ShapedArray(shape, jax.numpy.float64) @@ -1816,10 +1818,10 @@ def _probs_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shape: tup def _state_abstract_eval(obs, shape, shots=None): assert isinstance(obs, AbstractObs) - if obs.primitive is compbasis_p: - assert shape == (2**obs.num_qubits,) - else: - raise TypeError("state only supports computational basis") + assert obs.primitive is compbasis_p, "state only supports computational basis" + + if Signature.is_dynamic_shape(shape): + return core.DShapedArray(shape, np.dtype("complex128")) return core.ShapedArray(shape, jax.numpy.complex128) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 1bfb118626..ad51ab8c76 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -876,7 +876,7 @@ def trace_quantum_measurements( out_classical_tracers.append(o.mv) else: shape = (shots, nqubits) if using_compbasis else (shots,) - result = sample_p.bind(obs_tracers, shots=shots, shape=shape) + result = sample_p.bind(obs_tracers, shots, shape=shape) if using_compbasis: result = jnp.astype(result, jnp.int64) @@ -909,7 +909,7 @@ def trace_quantum_measurements( "Please specify a finite number of shots." ) shape = (2**nqubits,) if using_compbasis else (2,) - results = counts_p.bind(obs_tracers, shots=shots, shape=shape) + results = counts_p.bind(obs_tracers, shots, shape=shape) if using_compbasis: results = (jnp.asarray(results[0], jnp.int64), results[1]) out_classical_tracers.extend(results) diff --git a/frontend/catalyst/utils/calculate_grad_shape.py b/frontend/catalyst/utils/calculate_grad_shape.py index 5465045bce..bcce6b0b63 100644 --- a/frontend/catalyst/utils/calculate_grad_shape.py +++ b/frontend/catalyst/utils/calculate_grad_shape.py @@ -17,7 +17,7 @@ """ -from jax.core import ShapedArray +from jax.core import DShapedArray, ShapedArray, Tracer class Signature: @@ -76,15 +76,33 @@ def get_results(self): @staticmethod def is_tensor(x): - """Determine whether a type ``x`` is a ``jax.core.ShapedArray``. + """Determine whether a type ``x`` is a ``jax.core.DShapedArray`` + or ``jax.core.ShapedArray``. Args: x: The type to be tested. Returns: - bool: Whether the type ``x`` is a ``jax.core.ShapedArray`` + bool: Whether the type ``x`` is a ``jax.core.DShapedArray`` + or ``jax.core.ShapedArray`` """ - return isinstance(x, ShapedArray) + return isinstance(x, (DShapedArray, ShapedArray)) + + @staticmethod + def is_dynamic_shape(shape): + """Determine whether a shape contains a tracer or not. + + Args: + shape: The shape to be tested. + + Returns: + bool: Whether the shape contains a tracer or not. + """ + for s in shape: + if isinstance(s, Tracer): + return True + + return False def __eq__(self, other): return self.xs == other.xs and self.ys == other.ys @@ -123,8 +141,12 @@ def calculate_grad_shape(signature, indices) -> Signature: grad_res_shape.append(axis) element_type = diff_arg_type.dtype - grad_res_type = ( - ShapedArray(grad_res_shape, element_type) if grad_res_shape else diff_arg_type - ) + grad_res_type = diff_arg_type + if grad_res_shape: + XShapedArray = ( + DShapedArray if Signature.is_dynamic_shape(grad_res_shape) else ShapedArray + ) + grad_res_type = XShapedArray(grad_res_shape, element_type) + grad_result_types.append(grad_res_type) return Signature(signature.get_inputs(), grad_result_types) diff --git a/frontend/test/lit/test_measurements.py b/frontend/test/lit/test_measurements.py index 916d721688..0e2ab9c5f5 100644 --- a/frontend/test/lit/test_measurements.py +++ b/frontend/test/lit/test_measurements.py @@ -69,6 +69,7 @@ def sample2(x: float, y: float): @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000)) def sample3(x: float, y: float): + # CHECK: [[shots:%.+]] = arith.constant 1000 : i64 qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) @@ -76,7 +77,7 @@ def sample3(x: float, y: float): qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] - # CHECK: quantum.sample [[obs]] {shots = 1000 : i64} : tensor<1000x2xf64> + # CHECK: quantum.sample [[obs]] [[shots]] : tensor<1000x2xf64> return qml.sample() @@ -131,6 +132,7 @@ def counts2(x: float, y: float): @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000)) def counts3(x: float, y: float): + # CHECK: [[shots:%.+]] = arith.constant 1000 : i64 qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) @@ -138,7 +140,7 @@ def counts3(x: float, y: float): qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] - # CHECK: quantum.counts [[obs]] {shots = 1000 : i64} : tensor<4xf64>, tensor<4xi64> + # CHECK: quantum.counts [[obs]] [[shots]] : tensor<4xf64>, tensor<4xi64> return qml.counts() diff --git a/frontend/test/pytest/test_measurement_primitives.py b/frontend/test/pytest/test_measurement_primitives.py index 41229d2dc2..733177edf3 100644 --- a/frontend/test/pytest/test_measurement_primitives.py +++ b/frontend/test/pytest/test_measurement_primitives.py @@ -16,6 +16,7 @@ """ import jax +from catalyst.jax_extras import DynamicJaxprTracer from catalyst.jax_primitives import ( compbasis_p, counts_p, @@ -40,6 +41,25 @@ def f(): assert jaxpr.eqns[1].outvars[0].aval.shape == (5, 0) +def test_sample_dynamic(): + """Test that the sample primitive can be captured into jaxpr + when using tracers.""" + + def f(shots): + obs = compbasis_p.bind() + return sample_p.bind(obs, shots=shots, shape=(shots, 0)) + + jaxpr = jax.make_jaxpr(f)(5).jaxpr + shots_tracer, shape_value = jaxpr.eqns[1].params.values() + + assert jaxpr.eqns[1].primitive == sample_p + assert isinstance(shots_tracer, DynamicJaxprTracer) + assert isinstance(shape_value[0], DynamicJaxprTracer) + assert shape_value[1] == 0 + assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) + assert jaxpr.eqns[1].outvars[0].aval.shape[1] == 0 + + def test_counts(): """Test that the counts primitive can be captured by jaxpr.""" @@ -54,6 +74,24 @@ def f(): assert jaxpr.eqns[1].outvars[1].aval.shape == (1,) +def test_counts_dynamic(): + """Test that the counts primitive can be captured by jaxpr + when using tracers.""" + + def f(dim_0): + obs = compbasis_p.bind() + return counts_p.bind(obs, shots=5, shape=(dim_0,)) + + jaxpr = jax.make_jaxpr(f)(1) + shots, shape_value = jaxpr.eqns[1].params.values() + + assert jaxpr.eqns[1].primitive == counts_p + assert isinstance(shape_value[0], DynamicJaxprTracer) + assert shots == 5 + assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) + assert isinstance(jaxpr.eqns[1].outvars[1].aval.shape[0], jax.core.Var) + + def test_expval(): """Test that the expval primitive can be captured by jaxpr.""" @@ -81,7 +119,7 @@ def f(): def test_probs(): - """Test that the var primitive can be captured by jaxpr.""" + """Test that the probs primitive can be captured by jaxpr.""" def f(): obs = compbasis_p.bind() @@ -93,6 +131,23 @@ def f(): assert jaxpr.eqns[1].outvars[0].aval.shape == (1,) +def test_probs_dynamic(): + """Test that the probs primitive can be captured by jaxpr + when using tracers.""" + + def f(dim_0): + obs = compbasis_p.bind() + return probs_p.bind(obs, shots=5, shape=(dim_0,)) + + jaxpr = jax.make_jaxpr(f)(1) + shots, shape_value = jaxpr.eqns[1].params.values() + + assert jaxpr.eqns[1].primitive == probs_p + assert isinstance(shape_value[0], DynamicJaxprTracer) + assert shots == 5 + assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) + + def test_state(): """Test that the state primitive can be captured by jaxpr.""" @@ -104,3 +159,20 @@ def f(): assert jaxpr.eqns[1].primitive == state_p assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5} assert jaxpr.eqns[1].outvars[0].aval.shape == (1,) + + +def test_state_dynamic(): + """Test that the state primitive can be captured by jaxpr + when using tracers.""" + + def f(dim_0): + obs = compbasis_p.bind() + return state_p.bind(obs, shots=5, shape=(dim_0,)) + + jaxpr = jax.make_jaxpr(f)(1) + shots, shape_value = jaxpr.eqns[1].params.values() + + assert jaxpr.eqns[1].primitive == state_p + assert isinstance(shape_value[0], DynamicJaxprTracer) + assert shots == 5 + assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index cf8532fdfc..d8bf319b63 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -742,13 +742,13 @@ def SampleOp : Measurement_Op<"sample"> { Example: ```mlir - func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit) + func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit, %shots : i64) { %obs1 = quantum.compbasis %q0, %q1 : !quantum.obs - %samples = quantum.samples %obs1 {shots=1000} : tensor<1000xf64> + %samples = quantum.samples %obs1 %shots : tensor<1000xf64> %obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs - %samples2 = quantum.samples %obs2 {shots=1000} : tensor<1000x2xf64> + %samples2 = quantum.samples %obs2 %shots : tensor<1000x2xf64> func.return } @@ -757,13 +757,13 @@ def SampleOp : Measurement_Op<"sample"> { let arguments = (ins ObservableType:$obs, + I64:$shots, Optional< AnyTypeOf<[ MemRefRankOf<[F64], [1]>, MemRefRankOf<[F64], [2]> ]> - >:$in_data, - I64Attr:$shots + >:$in_data ); let results = (outs @@ -776,7 +776,7 @@ def SampleOp : Measurement_Op<"sample"> { ); let assemblyFormat = [{ - $obs ( `in` `(` $in_data^ `:` type($in_data) `)` )? attr-dict ( `:` type($samples)^ )? + $obs ( `in` `(` $in_data^ `:` type($in_data) `)` )? $shots attr-dict ( `:` type($samples)^ )? }]; let extraClassDeclaration = [{ @@ -809,10 +809,10 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit) { %obs = quantum.compbasis %q0, %q1 : !quantum.obs - %counts = quantum.counts %obs {shots=1000} : tensor<4xf64>, tensor<4xi64> + %counts = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64> %obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs - %counts2 = quantum.counts %obs2 {shots=1000} : tensor<2xf64>, tensor<2xi64> + %counts2 = quantum.counts %obs2 %shots : tensor<2xf64>, tensor<2xi64> func.return } @@ -821,9 +821,9 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe let arguments = (ins ObservableType:$obs, + I64:$shots, Optional>:$in_eigvals, - Optional>:$in_counts, - I64Attr:$shots + Optional>:$in_counts ); let results = (outs @@ -834,7 +834,7 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe let assemblyFormat = [{ $obs ( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )? - attr-dict ( `:` type($eigvals)^ `,` type($counts) )? + $shots attr-dict ( `:` type($eigvals)^ `,` type($counts) )? }]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 9b86f63d7b..cce0d4d606 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -191,19 +191,6 @@ LogicalResult SampleOp::verify() return emitOpError("either tensors must be returned or memrefs must be used as inputs"); } - Type toVerify = getSamples() ? getSamples().getType() : getInData().getType(); - if (getObs().getDefiningOp() && - failed(verifyTensorResult(toVerify, getShots(), numQubits.value()))) { - // In the computational basis, Pennylane adds a second dimension for the number of qubits. - return emitOpError("return tensor must have 2D static shape equal to " - "(number of shots, number of qubits in observable)"); - } - else if (!getObs().getDefiningOp() && - failed(verifyTensorResult(toVerify, getShots()))) { - // For any given observables, Pennylane always returns a 1D tensor. - return emitOpError("return tensor must have 1D static shape equal to (number of shots)"); - } - return success(); } diff --git a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp index 147c15c6f8..3e205b9be2 100644 --- a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp @@ -72,7 +72,8 @@ struct BufferizeSampleOp : public OpConversionPattern { MemRefType resultType = cast(getTypeConverter()->convertType(tensorType)); Location loc = op.getLoc(); Value allocVal = rewriter.replaceOpWithNewOp(op, resultType); - rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), allocVal}, + rewriter.create(loc, TypeRange{}, + ValueRange{adaptor.getObs(), adaptor.getShots(), allocVal}, op->getAttrs()); return success(); } @@ -122,8 +123,8 @@ struct BufferizeCountsOp : public OpConversionPattern { Value allocVal0 = rewriter.create(loc, resultType0); Value allocVal1 = rewriter.create(loc, resultType1); rewriter.replaceOp(op, ValueRange{allocVal0, allocVal1}); - rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), allocVal0, allocVal1, - adaptor.getShotsAttr()); + rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), adaptor.getShots(), + allocVal0, allocVal1); return success(); } }; diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp index 1983a61497..4adb9d7adf 100644 --- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp @@ -733,16 +733,19 @@ template class SampleBasedPattern : public OpConversionPattern { assert(isa(adaptor.getObs().getDefiningOp())); ValueRange qubits = adaptor.getObs().getDefiningOp()->getOperands(); - Value numShots = rewriter.create(loc, op.getShotsAttr()); Value numQubits = rewriter.create(loc, rewriter.getI64IntegerAttr(qubits.size())); - SmallVector args = {structPtr, numShots, numQubits}; - args.insert(args.end(), qubits.begin(), qubits.end()); + + Value numShots; if constexpr (std::is_same_v) { + auto sampleAdaptor = cast(adaptor); + numShots = sampleAdaptor.getShots(); rewriter.create(loc, adaptor.getInData(), structPtr); } else if constexpr (std::is_same_v) { + auto countsAdaptor = cast(adaptor); + numShots = countsAdaptor.getShots(); auto aStruct = rewriter.create(loc, structType); auto bStruct = rewriter.create(loc, aStruct, adaptor.getInEigvals(), 0); @@ -751,6 +754,9 @@ template class SampleBasedPattern : public OpConversionPattern { rewriter.create(loc, cStruct, structPtr); } + SmallVector args = {structPtr, numShots, numQubits}; + args.insert(args.end(), qubits.begin(), qubits.end()); + rewriter.create(loc, fnDecl, args); return structPtr; diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir index 2d58772dff..f882d81d50 100644 --- a/mlir/test/Quantum/BufferizationTest.mlir +++ b/mlir/test/Quantum/BufferizationTest.mlir @@ -18,18 +18,18 @@ // Measurements // ////////////////// -func.func @counts(%q0: !quantum.bit, %q1: !quantum.bit) -> (tensor<4xf64>, tensor<4xi64>) { +func.func @counts(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64) -> (tensor<4xf64>, tensor<4xi64>) { %obs = quantum.compbasis %q0, %q1 : !quantum.obs - %samples:2 = quantum.counts %obs {shots=2} : tensor<4xf64>, tensor<4xi64> + %samples:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64> func.return %samples#0, %samples#1 : tensor<4xf64>, tensor<4xi64> } // ----- -func.func @sample(%q0: !quantum.bit, %q1: !quantum.bit) { +func.func @sample(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q0, %q1 : !quantum.obs // CHECK: quantum.sample {{.*}} : memref<1000x2xf64> - %samples = quantum.sample %obs {shots=1000} : tensor<1000x2xf64> + %samples = quantum.sample %obs %shots : tensor<1000x2xf64> func.return } diff --git a/mlir/test/Quantum/ConversionTest.mlir b/mlir/test/Quantum/ConversionTest.mlir index 58e3d1154d..a9deeb67bb 100644 --- a/mlir/test/Quantum/ConversionTest.mlir +++ b/mlir/test/Quantum/ConversionTest.mlir @@ -409,25 +409,23 @@ func.func @measure(%q : !quantum.bit) -> !quantum.bit { // CHECK: llvm.func @__catalyst__qis__Sample(!llvm.ptr, i64, i64, ...) // CHECK-LABEL: @sample -func.func @sample(%q : !quantum.bit) { +func.func @sample(%q : !quantum.bit, %shots1 : i64, %shots2 : i64) { %o1 = quantum.compbasis %q : !quantum.obs %o2 = quantum.compbasis %q, %q : !quantum.obs // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK-DAG: [[c1000:%.+]] = llvm.mlir.constant(1000 : i64) // CHECK-DAG: [[c1:%.+]] = llvm.mlir.constant(1 : i64) - // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], [[c1000]], [[c1]], %arg0) + // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], %arg1, [[c1]], %arg0) %alloc1 = memref.alloc() : memref<1000x1xf64> - quantum.sample %o1 in(%alloc1 : memref<1000x1xf64>) {shots = 1000 : i64} + quantum.sample %o1 in(%alloc1 : memref<1000x1xf64>) %shots1 // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[c2000:%.+]] = llvm.mlir.constant(2000 : i64) // CHECK: [[c2:%.+]] = llvm.mlir.constant(2 : i64) - // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], [[c2000]], [[c2]], %arg0, %arg0) + // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], %arg2, [[c2]], %arg0, %arg0) %alloc2 = memref.alloc() : memref<2000x2xf64> - quantum.sample %o2 in(%alloc2 : memref<2000x2xf64>) {shots = 2000 : i64} + quantum.sample %o2 in(%alloc2 : memref<2000x2xf64>) %shots2 return } @@ -437,27 +435,25 @@ func.func @sample(%q : !quantum.bit) { // CHECK: llvm.func @__catalyst__qis__Counts(!llvm.ptr, i64, i64, ...) // CHECK-LABEL: @counts -func.func @counts(%q : !quantum.bit) { +func.func @counts(%q : !quantum.bit, %shots1 : i64, %shots2 : i64) { %o1 = quantum.compbasis %q : !quantum.obs %o2 = quantum.compbasis %q, %q : !quantum.obs // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK-DAG: [[c1000:%.+]] = llvm.mlir.constant(1000 : i64) // CHECK-DAG: [[c1:%.+]] = llvm.mlir.constant(1 : i64) - // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], [[c1000]], [[c1]], %arg0) + // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], %arg1, [[c1]], %arg0) %in_eigvals1 = memref.alloc() : memref<2xf64> %in_counts1 = memref.alloc() : memref<2xi64> - quantum.counts %o1 in(%in_eigvals1 : memref<2xf64>, %in_counts1 : memref<2xi64>) {shots = 1000 : i64} + quantum.counts %o1 in(%in_eigvals1 : memref<2xf64>, %in_counts1 : memref<2xi64>) %shots1 // CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64) // CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[c2000:%.+]] = llvm.mlir.constant(2000 : i64) // CHECK: [[c2:%.+]] = llvm.mlir.constant(2 : i64) - // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], [[c2000]], [[c2]], %arg0, %arg0) + // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], %arg2, [[c2]], %arg0, %arg0) %in_eigvals2 = memref.alloc() : memref<4xf64> %in_counts2 = memref.alloc() : memref<4xi64> - quantum.counts %o2 in(%in_eigvals2 : memref<4xf64>, %in_counts2 : memref<4xi64>) {shots = 2000 : i64} + quantum.counts %o2 in(%in_eigvals2 : memref<4xf64>, %in_counts2 : memref<4xi64>) %shots2 return } diff --git a/mlir/test/Quantum/VerifierTest.mlir b/mlir/test/Quantum/VerifierTest.mlir index 0e23bc615d..21260f55eb 100644 --- a/mlir/test/Quantum/VerifierTest.mlir +++ b/mlir/test/Quantum/VerifierTest.mlir @@ -178,133 +178,130 @@ func.func @tensorobs(%q0 : !quantum.bit, %q1 : !quantum.bit, %q2 : !quantum.bit) // ----- -func.func @sample1(%q : !quantum.bit) { +func.func @sample1(%q : !quantum.bit, %shots: i64) { %obs = quantum.namedobs %q[Identity] : !quantum.obs - // expected-error@+1 {{return tensor must have 1D static shape equal to (number of shots)}} - %err = quantum.sample %obs { shots=1000 } : tensor<1xf64> + %err = quantum.sample %obs %shots : tensor<1xf64> - %samples = quantum.sample %obs { shots=1000 } : tensor<1000xf64> + %samples = quantum.sample %obs %shots : tensor<1000xf64> return } // ----- -func.func @sample2(%q : !quantum.bit) { +func.func @sample2(%q : !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q : !quantum.obs - // expected-error@+1 {{return tensor must have 2D static shape equal to (number of shots, number of qubits in observable)}} - %err = quantum.sample %obs { shots=1000 } : tensor<1000xf64> + %err = quantum.sample %obs %shots : tensor<1000xf64> - %samples = quantum.sample %obs { shots=1000 } : tensor<1000x1xf64> + %samples = quantum.sample %obs %shots : tensor<1000x1xf64> return } // ----- -func.func @sample3(%q : !quantum.bit) { +func.func @sample3(%q : !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q : !quantum.obs %alloc0 = memref.alloc() : memref<1000xf64> - // expected-error@+1 {{return tensor must have 2D static shape equal to (number of shots, number of qubits in observable)}} - quantum.sample %obs in(%alloc0 : memref<1000xf64>) { shots = 1000 } + quantum.sample %obs in(%alloc0 : memref<1000xf64>) %shots %alloc1 = memref.alloc() : memref<1000x1xf64> - quantum.sample %obs in(%alloc1 : memref<1000x1xf64>) { shots = 1000 } + quantum.sample %obs in(%alloc1 : memref<1000x1xf64>) %shots return } // ----- -func.func @sample4(%q : !quantum.bit) { +func.func @sample4(%q : !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q : !quantum.obs %alloc = memref.alloc() : memref<1000xf64> // expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}} - quantum.sample %obs in (%alloc : memref<1000xf64>) { shots=1000 } : tensor<1000xf64> + quantum.sample %obs in (%alloc : memref<1000xf64>) %shots : tensor<1000xf64> - %samples = quantum.sample %obs { shots=1000 } : tensor<1000x1xf64> + %samples = quantum.sample %obs %shots : tensor<1000x1xf64> return } // ----- -func.func @sample5(%q : !quantum.bit) { +func.func @sample5(%q : !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q : !quantum.obs // expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}} - quantum.sample %obs { shots=1000 } + quantum.sample %obs %shots return } // ----- -func.func @counts1(%q0 : !quantum.bit, %q1 : !quantum.bit) { +func.func @counts1(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) { %obs = quantum.namedobs %q0[PauliX] : !quantum.obs // expected-error@+1 {{number of eigenvalues or counts did not match observable}} - %err:2 = quantum.counts %obs { shots=1000 } : tensor<4xf64>, tensor<4xi64> + %err:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64> - %counts:2 = quantum.counts %obs { shots=1000 } : tensor<2xf64>, tensor<2xi64> + %counts:2 = quantum.counts %obs %shots : tensor<2xf64>, tensor<2xi64> return } // ----- -func.func @counts2(%q0 : !quantum.bit, %q1 : !quantum.bit) { +func.func @counts2(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) { %obs = quantum.compbasis %q0, %q1 : !quantum.obs // expected-error@+1 {{number of eigenvalues or counts did not match observable}} - %err:2 = quantum.counts %obs { shots=1000 } : tensor<2xf64>, tensor<2xi64> + %err:2 = quantum.counts %obs %shots : tensor<2xf64>, tensor<2xi64> - %counts:2 = quantum.counts %obs { shots=1000 } : tensor<4xf64>, tensor<4xi64> + %counts:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64> return } // ----- -func.func @counts3(%q0 : !quantum.bit, %q1 : !quantum.bit) { +func.func @counts3(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) { %obs = quantum.namedobs %q0[PauliX] : !quantum.obs %in_eigvals_1 = memref.alloc() : memref<4xf64> %in_counts_1 = memref.alloc() : memref<4xi64> // expected-error@+1 {{number of eigenvalues or counts did not match observable}} - quantum.counts %obs in(%in_eigvals_1 : memref<4xf64>, %in_counts_1 : memref<4xi64>) { shots=1000 } + quantum.counts %obs in(%in_eigvals_1 : memref<4xf64>, %in_counts_1 : memref<4xi64>) %shots %in_eigvals_2 = memref.alloc() : memref<2xf64> %in_counts_2 = memref.alloc() : memref<2xi64> - quantum.counts %obs in(%in_eigvals_2 : memref<2xf64>, %in_counts_2 : memref<2xi64>) { shots=1000 } + quantum.counts %obs in(%in_eigvals_2 : memref<2xf64>, %in_counts_2 : memref<2xi64>) %shots return } // ----- -func.func @counts4(%q0 : !quantum.bit, %q1 : !quantum.bit) { +func.func @counts4(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) { %obs = quantum.namedobs %q0[PauliX] : !quantum.obs // expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}} - quantum.counts %obs { shots=1000 } + quantum.counts %obs %shots return } // ----- -func.func @counts5(%q0 : !quantum.bit, %q1 : !quantum.bit) { +func.func @counts5(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) { %obs = quantum.namedobs %q0[PauliX] : !quantum.obs %in_eigvals = memref.alloc() : memref<2xf64> %in_counts = memref.alloc() : memref<2xi64> // expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}} - quantum.counts %obs in(%in_eigvals : memref<2xf64>, %in_counts : memref<2xi64>) { shots=1000 } : tensor<2xf64>, tensor<2xi64> + quantum.counts %obs in(%in_eigvals : memref<2xf64>, %in_counts : memref<2xi64>) %shots : tensor<2xf64>, tensor<2xi64> return }