diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 2083694c0b..86f2e10454 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -333,6 +333,9 @@
this nested module into other IRs/binary formats and lowering `call_function_in_module`
to something that can dispatch calls to another runtime / VM.
+* Measurement primitives now support dynamic shape at the frontend, although at the PennyLane
+ side, the corresponding operations still lack such support.
+ [(#1170)](https://github.com/PennyLaneAI/catalyst/pull/1170)
Contributors
diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py
index 7f66b38348..e6801229ab 100644
--- a/frontend/catalyst/api_extensions/differentiation.py
+++ b/frontend/catalyst/api_extensions/differentiation.py
@@ -30,7 +30,7 @@
from pennylane import QNode
import catalyst
-from catalyst.jax_extras import Jaxpr
+from catalyst.jax_extras import Jaxpr, make_jaxpr2
from catalyst.jax_primitives import (
GradParams,
expval_p,
@@ -807,8 +807,7 @@ def _make_jaxpr_check_differentiable(
return the output tree."""
method = grad_params.method
with mark_gradient_tracing(method):
- jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args, **kwargs)
- _, out_tree = tree_flatten(shape)
+ jaxpr, _, out_tree = make_jaxpr2(f)(*args, **kwargs)
for pos, arg in enumerate(jaxpr.in_avals):
if arg.dtype.kind != "f" and pos in grad_params.expanded_argnums:
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 03a68697be..9ba097a0b4 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -1658,10 +1658,8 @@ def _hamiltonian_lowering(jax_ctx: mlir.LoweringRuleContext, coeffs: ir.Value, *
def _sample_abstract_eval(obs, shots, shape):
assert isinstance(obs, AbstractObs)
- if obs.primitive is compbasis_p:
- assert shape == (shots, obs.num_qubits)
- else:
- assert shape == (shots,)
+ if Signature.is_dynamic_shape(shape):
+ return core.DShapedArray(shape, np.dtype("float64"))
return core.ShapedArray(shape, jax.numpy.float64)
@@ -1671,16 +1669,18 @@ def _sample_def_impl(ctx, obs, shots, shape): # pragma: no cover
raise NotImplementedError()
-def _sample_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple):
+def _sample_lowering(
+ jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: ir.Value, shape: tuple
+):
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True
i64_type = ir.IntegerType.get_signless(64, ctx)
- shots_attr = ir.IntegerAttr.get(i64_type, shots)
+ shots_val = TensorExtractOp(i64_type, shots, []).result
f64_type = ir.F64Type.get()
result_type = ir.RankedTensorType.get(shape, f64_type)
- return SampleOp(result_type, obs, shots_attr).results
+ return SampleOp(result_type, obs, shots_val).results
#
@@ -1695,25 +1695,27 @@ def _counts_def_impl(ctx, obs, shots, shape): # pragma: no cover
def _counts_abstract_eval(obs, shots, shape):
assert isinstance(obs, AbstractObs)
- if obs.primitive is compbasis_p:
- assert shape == (2**obs.num_qubits,)
- else:
- assert shape == (2,)
+ if Signature.is_dynamic_shape(shape):
+ return core.DShapedArray(shape, np.dtype("float64")), core.DShapedArray(
+ shape, np.dtype("int64")
+ )
return core.ShapedArray(shape, jax.numpy.float64), core.ShapedArray(shape, jax.numpy.int64)
-def _counts_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int, shape: tuple):
+def _counts_lowering(
+ jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: ir.Value, shape: tuple
+):
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True
i64_type = ir.IntegerType.get_signless(64, ctx)
- shots_attr = ir.IntegerAttr.get(i64_type, shots)
+ shots_val = TensorExtractOp(i64_type, shots, []).result
f64_type = ir.F64Type.get()
eigvals_type = ir.RankedTensorType.get(shape, f64_type)
counts_type = ir.RankedTensorType.get(shape, i64_type)
- return CountsOp(eigvals_type, counts_type, obs, shots_attr).results
+ return CountsOp(eigvals_type, counts_type, obs, shots_val).results
#
@@ -1787,10 +1789,10 @@ def _var_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shots: int,
def _probs_abstract_eval(obs, shape, shots=None):
assert isinstance(obs, AbstractObs)
- if obs.primitive is compbasis_p:
- assert shape == (2**obs.num_qubits,)
- else:
- raise TypeError("probs only supports computational basis")
+ assert obs.primitive is compbasis_p, "probs only supports computational basis"
+
+ if Signature.is_dynamic_shape(shape):
+ return core.DShapedArray(shape, np.dtype("float64"))
return core.ShapedArray(shape, jax.numpy.float64)
@@ -1816,10 +1818,10 @@ def _probs_lowering(jax_ctx: mlir.LoweringRuleContext, obs: ir.Value, shape: tup
def _state_abstract_eval(obs, shape, shots=None):
assert isinstance(obs, AbstractObs)
- if obs.primitive is compbasis_p:
- assert shape == (2**obs.num_qubits,)
- else:
- raise TypeError("state only supports computational basis")
+ assert obs.primitive is compbasis_p, "state only supports computational basis"
+
+ if Signature.is_dynamic_shape(shape):
+ return core.DShapedArray(shape, np.dtype("complex128"))
return core.ShapedArray(shape, jax.numpy.complex128)
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 1bfb118626..ad51ab8c76 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -876,7 +876,7 @@ def trace_quantum_measurements(
out_classical_tracers.append(o.mv)
else:
shape = (shots, nqubits) if using_compbasis else (shots,)
- result = sample_p.bind(obs_tracers, shots=shots, shape=shape)
+ result = sample_p.bind(obs_tracers, shots, shape=shape)
if using_compbasis:
result = jnp.astype(result, jnp.int64)
@@ -909,7 +909,7 @@ def trace_quantum_measurements(
"Please specify a finite number of shots."
)
shape = (2**nqubits,) if using_compbasis else (2,)
- results = counts_p.bind(obs_tracers, shots=shots, shape=shape)
+ results = counts_p.bind(obs_tracers, shots, shape=shape)
if using_compbasis:
results = (jnp.asarray(results[0], jnp.int64), results[1])
out_classical_tracers.extend(results)
diff --git a/frontend/catalyst/utils/calculate_grad_shape.py b/frontend/catalyst/utils/calculate_grad_shape.py
index 5465045bce..bcce6b0b63 100644
--- a/frontend/catalyst/utils/calculate_grad_shape.py
+++ b/frontend/catalyst/utils/calculate_grad_shape.py
@@ -17,7 +17,7 @@
"""
-from jax.core import ShapedArray
+from jax.core import DShapedArray, ShapedArray, Tracer
class Signature:
@@ -76,15 +76,33 @@ def get_results(self):
@staticmethod
def is_tensor(x):
- """Determine whether a type ``x`` is a ``jax.core.ShapedArray``.
+ """Determine whether a type ``x`` is a ``jax.core.DShapedArray``
+ or ``jax.core.ShapedArray``.
Args:
x: The type to be tested.
Returns:
- bool: Whether the type ``x`` is a ``jax.core.ShapedArray``
+ bool: Whether the type ``x`` is a ``jax.core.DShapedArray``
+ or ``jax.core.ShapedArray``
"""
- return isinstance(x, ShapedArray)
+ return isinstance(x, (DShapedArray, ShapedArray))
+
+ @staticmethod
+ def is_dynamic_shape(shape):
+ """Determine whether a shape contains a tracer or not.
+
+ Args:
+ shape: The shape to be tested.
+
+ Returns:
+ bool: Whether the shape contains a tracer or not.
+ """
+ for s in shape:
+ if isinstance(s, Tracer):
+ return True
+
+ return False
def __eq__(self, other):
return self.xs == other.xs and self.ys == other.ys
@@ -123,8 +141,12 @@ def calculate_grad_shape(signature, indices) -> Signature:
grad_res_shape.append(axis)
element_type = diff_arg_type.dtype
- grad_res_type = (
- ShapedArray(grad_res_shape, element_type) if grad_res_shape else diff_arg_type
- )
+ grad_res_type = diff_arg_type
+ if grad_res_shape:
+ XShapedArray = (
+ DShapedArray if Signature.is_dynamic_shape(grad_res_shape) else ShapedArray
+ )
+ grad_res_type = XShapedArray(grad_res_shape, element_type)
+
grad_result_types.append(grad_res_type)
return Signature(signature.get_inputs(), grad_result_types)
diff --git a/frontend/test/lit/test_measurements.py b/frontend/test/lit/test_measurements.py
index 916d721688..0e2ab9c5f5 100644
--- a/frontend/test/lit/test_measurements.py
+++ b/frontend/test/lit/test_measurements.py
@@ -69,6 +69,7 @@ def sample2(x: float, y: float):
@qjit(target="mlir")
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def sample3(x: float, y: float):
+ # CHECK: [[shots:%.+]] = arith.constant 1000 : i64
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
@@ -76,7 +77,7 @@ def sample3(x: float, y: float):
qml.RZ(0.1, wires=0)
# CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]]
- # CHECK: quantum.sample [[obs]] {shots = 1000 : i64} : tensor<1000x2xf64>
+ # CHECK: quantum.sample [[obs]] [[shots]] : tensor<1000x2xf64>
return qml.sample()
@@ -131,6 +132,7 @@ def counts2(x: float, y: float):
@qjit(target="mlir")
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=1000))
def counts3(x: float, y: float):
+ # CHECK: [[shots:%.+]] = arith.constant 1000 : i64
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
@@ -138,7 +140,7 @@ def counts3(x: float, y: float):
qml.RZ(0.1, wires=0)
# CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]]
- # CHECK: quantum.counts [[obs]] {shots = 1000 : i64} : tensor<4xf64>, tensor<4xi64>
+ # CHECK: quantum.counts [[obs]] [[shots]] : tensor<4xf64>, tensor<4xi64>
return qml.counts()
diff --git a/frontend/test/pytest/test_measurement_primitives.py b/frontend/test/pytest/test_measurement_primitives.py
index 41229d2dc2..733177edf3 100644
--- a/frontend/test/pytest/test_measurement_primitives.py
+++ b/frontend/test/pytest/test_measurement_primitives.py
@@ -16,6 +16,7 @@
"""
import jax
+from catalyst.jax_extras import DynamicJaxprTracer
from catalyst.jax_primitives import (
compbasis_p,
counts_p,
@@ -40,6 +41,25 @@ def f():
assert jaxpr.eqns[1].outvars[0].aval.shape == (5, 0)
+def test_sample_dynamic():
+ """Test that the sample primitive can be captured into jaxpr
+ when using tracers."""
+
+ def f(shots):
+ obs = compbasis_p.bind()
+ return sample_p.bind(obs, shots=shots, shape=(shots, 0))
+
+ jaxpr = jax.make_jaxpr(f)(5).jaxpr
+ shots_tracer, shape_value = jaxpr.eqns[1].params.values()
+
+ assert jaxpr.eqns[1].primitive == sample_p
+ assert isinstance(shots_tracer, DynamicJaxprTracer)
+ assert isinstance(shape_value[0], DynamicJaxprTracer)
+ assert shape_value[1] == 0
+ assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
+ assert jaxpr.eqns[1].outvars[0].aval.shape[1] == 0
+
+
def test_counts():
"""Test that the counts primitive can be captured by jaxpr."""
@@ -54,6 +74,24 @@ def f():
assert jaxpr.eqns[1].outvars[1].aval.shape == (1,)
+def test_counts_dynamic():
+ """Test that the counts primitive can be captured by jaxpr
+ when using tracers."""
+
+ def f(dim_0):
+ obs = compbasis_p.bind()
+ return counts_p.bind(obs, shots=5, shape=(dim_0,))
+
+ jaxpr = jax.make_jaxpr(f)(1)
+ shots, shape_value = jaxpr.eqns[1].params.values()
+
+ assert jaxpr.eqns[1].primitive == counts_p
+ assert isinstance(shape_value[0], DynamicJaxprTracer)
+ assert shots == 5
+ assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
+ assert isinstance(jaxpr.eqns[1].outvars[1].aval.shape[0], jax.core.Var)
+
+
def test_expval():
"""Test that the expval primitive can be captured by jaxpr."""
@@ -81,7 +119,7 @@ def f():
def test_probs():
- """Test that the var primitive can be captured by jaxpr."""
+ """Test that the probs primitive can be captured by jaxpr."""
def f():
obs = compbasis_p.bind()
@@ -93,6 +131,23 @@ def f():
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,)
+def test_probs_dynamic():
+ """Test that the probs primitive can be captured by jaxpr
+ when using tracers."""
+
+ def f(dim_0):
+ obs = compbasis_p.bind()
+ return probs_p.bind(obs, shots=5, shape=(dim_0,))
+
+ jaxpr = jax.make_jaxpr(f)(1)
+ shots, shape_value = jaxpr.eqns[1].params.values()
+
+ assert jaxpr.eqns[1].primitive == probs_p
+ assert isinstance(shape_value[0], DynamicJaxprTracer)
+ assert shots == 5
+ assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
+
+
def test_state():
"""Test that the state primitive can be captured by jaxpr."""
@@ -104,3 +159,20 @@ def f():
assert jaxpr.eqns[1].primitive == state_p
assert jaxpr.eqns[1].params == {"shape": (1,), "shots": 5}
assert jaxpr.eqns[1].outvars[0].aval.shape == (1,)
+
+
+def test_state_dynamic():
+ """Test that the state primitive can be captured by jaxpr
+ when using tracers."""
+
+ def f(dim_0):
+ obs = compbasis_p.bind()
+ return state_p.bind(obs, shots=5, shape=(dim_0,))
+
+ jaxpr = jax.make_jaxpr(f)(1)
+ shots, shape_value = jaxpr.eqns[1].params.values()
+
+ assert jaxpr.eqns[1].primitive == state_p
+ assert isinstance(shape_value[0], DynamicJaxprTracer)
+ assert shots == 5
+ assert isinstance(jaxpr.eqns[1].outvars[0].aval.shape[0], jax.core.Var)
diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td
index cf8532fdfc..d8bf319b63 100644
--- a/mlir/include/Quantum/IR/QuantumOps.td
+++ b/mlir/include/Quantum/IR/QuantumOps.td
@@ -742,13 +742,13 @@ def SampleOp : Measurement_Op<"sample"> {
Example:
```mlir
- func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit)
+ func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit, %shots : i64)
{
%obs1 = quantum.compbasis %q0, %q1 : !quantum.obs
- %samples = quantum.samples %obs1 {shots=1000} : tensor<1000xf64>
+ %samples = quantum.samples %obs1 %shots : tensor<1000xf64>
%obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs
- %samples2 = quantum.samples %obs2 {shots=1000} : tensor<1000x2xf64>
+ %samples2 = quantum.samples %obs2 %shots : tensor<1000x2xf64>
func.return
}
@@ -757,13 +757,13 @@ def SampleOp : Measurement_Op<"sample"> {
let arguments = (ins
ObservableType:$obs,
+ I64:$shots,
Optional<
AnyTypeOf<[
MemRefRankOf<[F64], [1]>,
MemRefRankOf<[F64], [2]>
]>
- >:$in_data,
- I64Attr:$shots
+ >:$in_data
);
let results = (outs
@@ -776,7 +776,7 @@ def SampleOp : Measurement_Op<"sample"> {
);
let assemblyFormat = [{
- $obs ( `in` `(` $in_data^ `:` type($in_data) `)` )? attr-dict ( `:` type($samples)^ )?
+ $obs ( `in` `(` $in_data^ `:` type($in_data) `)` )? $shots attr-dict ( `:` type($samples)^ )?
}];
let extraClassDeclaration = [{
@@ -809,10 +809,10 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
func.func @foo(%q0: !quantum.bit, %q1: !quantum.bit)
{
%obs = quantum.compbasis %q0, %q1 : !quantum.obs
- %counts = quantum.counts %obs {shots=1000} : tensor<4xf64>, tensor<4xi64>
+ %counts = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64>
%obs2 = quantum.pauli %q0[3], %q1[1] : !quantum.obs
- %counts2 = quantum.counts %obs2 {shots=1000} : tensor<2xf64>, tensor<2xi64>
+ %counts2 = quantum.counts %obs2 %shots : tensor<2xf64>, tensor<2xi64>
func.return
}
@@ -821,9 +821,9 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
let arguments = (ins
ObservableType:$obs,
+ I64:$shots,
Optional>:$in_eigvals,
- Optional>:$in_counts,
- I64Attr:$shots
+ Optional>:$in_counts
);
let results = (outs
@@ -834,7 +834,7 @@ def CountsOp : Measurement_Op<"counts", [SameVariadicOperandSize, SameVariadicRe
let assemblyFormat = [{
$obs
( `in` `(` $in_eigvals^ `:` type($in_eigvals) `,` $in_counts `:` type($in_counts) `)` )?
- attr-dict ( `:` type($eigvals)^ `,` type($counts) )?
+ $shots attr-dict ( `:` type($eigvals)^ `,` type($counts) )?
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp
index 9b86f63d7b..cce0d4d606 100644
--- a/mlir/lib/Quantum/IR/QuantumOps.cpp
+++ b/mlir/lib/Quantum/IR/QuantumOps.cpp
@@ -191,19 +191,6 @@ LogicalResult SampleOp::verify()
return emitOpError("either tensors must be returned or memrefs must be used as inputs");
}
- Type toVerify = getSamples() ? getSamples().getType() : getInData().getType();
- if (getObs().getDefiningOp() &&
- failed(verifyTensorResult(toVerify, getShots(), numQubits.value()))) {
- // In the computational basis, Pennylane adds a second dimension for the number of qubits.
- return emitOpError("return tensor must have 2D static shape equal to "
- "(number of shots, number of qubits in observable)");
- }
- else if (!getObs().getDefiningOp() &&
- failed(verifyTensorResult(toVerify, getShots()))) {
- // For any given observables, Pennylane always returns a 1D tensor.
- return emitOpError("return tensor must have 1D static shape equal to (number of shots)");
- }
-
return success();
}
diff --git a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp
index 147c15c6f8..3e205b9be2 100644
--- a/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp
+++ b/mlir/lib/Quantum/Transforms/BufferizationPatterns.cpp
@@ -72,7 +72,8 @@ struct BufferizeSampleOp : public OpConversionPattern {
MemRefType resultType = cast(getTypeConverter()->convertType(tensorType));
Location loc = op.getLoc();
Value allocVal = rewriter.replaceOpWithNewOp(op, resultType);
- rewriter.create(loc, TypeRange{}, ValueRange{adaptor.getObs(), allocVal},
+ rewriter.create(loc, TypeRange{},
+ ValueRange{adaptor.getObs(), adaptor.getShots(), allocVal},
op->getAttrs());
return success();
}
@@ -122,8 +123,8 @@ struct BufferizeCountsOp : public OpConversionPattern {
Value allocVal0 = rewriter.create(loc, resultType0);
Value allocVal1 = rewriter.create(loc, resultType1);
rewriter.replaceOp(op, ValueRange{allocVal0, allocVal1});
- rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), allocVal0, allocVal1,
- adaptor.getShotsAttr());
+ rewriter.create(loc, nullptr, nullptr, adaptor.getObs(), adaptor.getShots(),
+ allocVal0, allocVal1);
return success();
}
};
diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
index 1983a61497..4adb9d7adf 100644
--- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
@@ -733,16 +733,19 @@ template class SampleBasedPattern : public OpConversionPattern {
assert(isa(adaptor.getObs().getDefiningOp()));
ValueRange qubits = adaptor.getObs().getDefiningOp()->getOperands();
- Value numShots = rewriter.create(loc, op.getShotsAttr());
Value numQubits =
rewriter.create(loc, rewriter.getI64IntegerAttr(qubits.size()));
- SmallVector args = {structPtr, numShots, numQubits};
- args.insert(args.end(), qubits.begin(), qubits.end());
+
+ Value numShots;
if constexpr (std::is_same_v) {
+ auto sampleAdaptor = cast(adaptor);
+ numShots = sampleAdaptor.getShots();
rewriter.create(loc, adaptor.getInData(), structPtr);
}
else if constexpr (std::is_same_v) {
+ auto countsAdaptor = cast(adaptor);
+ numShots = countsAdaptor.getShots();
auto aStruct = rewriter.create(loc, structType);
auto bStruct =
rewriter.create(loc, aStruct, adaptor.getInEigvals(), 0);
@@ -751,6 +754,9 @@ template class SampleBasedPattern : public OpConversionPattern {
rewriter.create(loc, cStruct, structPtr);
}
+ SmallVector args = {structPtr, numShots, numQubits};
+ args.insert(args.end(), qubits.begin(), qubits.end());
+
rewriter.create(loc, fnDecl, args);
return structPtr;
diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir
index 2d58772dff..f882d81d50 100644
--- a/mlir/test/Quantum/BufferizationTest.mlir
+++ b/mlir/test/Quantum/BufferizationTest.mlir
@@ -18,18 +18,18 @@
// Measurements //
//////////////////
-func.func @counts(%q0: !quantum.bit, %q1: !quantum.bit) -> (tensor<4xf64>, tensor<4xi64>) {
+func.func @counts(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64) -> (tensor<4xf64>, tensor<4xi64>) {
%obs = quantum.compbasis %q0, %q1 : !quantum.obs
- %samples:2 = quantum.counts %obs {shots=2} : tensor<4xf64>, tensor<4xi64>
+ %samples:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64>
func.return %samples#0, %samples#1 : tensor<4xf64>, tensor<4xi64>
}
// -----
-func.func @sample(%q0: !quantum.bit, %q1: !quantum.bit) {
+func.func @sample(%q0: !quantum.bit, %q1: !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q0, %q1 : !quantum.obs
// CHECK: quantum.sample {{.*}} : memref<1000x2xf64>
- %samples = quantum.sample %obs {shots=1000} : tensor<1000x2xf64>
+ %samples = quantum.sample %obs %shots : tensor<1000x2xf64>
func.return
}
diff --git a/mlir/test/Quantum/ConversionTest.mlir b/mlir/test/Quantum/ConversionTest.mlir
index 58e3d1154d..a9deeb67bb 100644
--- a/mlir/test/Quantum/ConversionTest.mlir
+++ b/mlir/test/Quantum/ConversionTest.mlir
@@ -409,25 +409,23 @@ func.func @measure(%q : !quantum.bit) -> !quantum.bit {
// CHECK: llvm.func @__catalyst__qis__Sample(!llvm.ptr, i64, i64, ...)
// CHECK-LABEL: @sample
-func.func @sample(%q : !quantum.bit) {
+func.func @sample(%q : !quantum.bit, %shots1 : i64, %shots2 : i64) {
%o1 = quantum.compbasis %q : !quantum.obs
%o2 = quantum.compbasis %q, %q : !quantum.obs
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
// CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK-DAG: [[c1000:%.+]] = llvm.mlir.constant(1000 : i64)
// CHECK-DAG: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
- // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], [[c1000]], [[c1]], %arg0)
+ // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], %arg1, [[c1]], %arg0)
%alloc1 = memref.alloc() : memref<1000x1xf64>
- quantum.sample %o1 in(%alloc1 : memref<1000x1xf64>) {shots = 1000 : i64}
+ quantum.sample %o1 in(%alloc1 : memref<1000x1xf64>) %shots1
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
// CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: [[c2000:%.+]] = llvm.mlir.constant(2000 : i64)
// CHECK: [[c2:%.+]] = llvm.mlir.constant(2 : i64)
- // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], [[c2000]], [[c2]], %arg0, %arg0)
+ // CHECK: llvm.call @__catalyst__qis__Sample([[ptr]], %arg2, [[c2]], %arg0, %arg0)
%alloc2 = memref.alloc() : memref<2000x2xf64>
- quantum.sample %o2 in(%alloc2 : memref<2000x2xf64>) {shots = 2000 : i64}
+ quantum.sample %o2 in(%alloc2 : memref<2000x2xf64>) %shots2
return
}
@@ -437,27 +435,25 @@ func.func @sample(%q : !quantum.bit) {
// CHECK: llvm.func @__catalyst__qis__Counts(!llvm.ptr, i64, i64, ...)
// CHECK-LABEL: @counts
-func.func @counts(%q : !quantum.bit) {
+func.func @counts(%q : !quantum.bit, %shots1 : i64, %shots2 : i64) {
%o1 = quantum.compbasis %q : !quantum.obs
%o2 = quantum.compbasis %q, %q : !quantum.obs
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
// CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-DAG: [[c1000:%.+]] = llvm.mlir.constant(1000 : i64)
// CHECK-DAG: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
- // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], [[c1000]], [[c1]], %arg0)
+ // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], %arg1, [[c1]], %arg0)
%in_eigvals1 = memref.alloc() : memref<2xf64>
%in_counts1 = memref.alloc() : memref<2xi64>
- quantum.counts %o1 in(%in_eigvals1 : memref<2xf64>, %in_counts1 : memref<2xi64>) {shots = 1000 : i64}
+ quantum.counts %o1 in(%in_eigvals1 : memref<2xf64>, %in_counts1 : memref<2xi64>) %shots1
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
// CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[c2000:%.+]] = llvm.mlir.constant(2000 : i64)
// CHECK: [[c2:%.+]] = llvm.mlir.constant(2 : i64)
- // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], [[c2000]], [[c2]], %arg0, %arg0)
+ // CHECK: llvm.call @__catalyst__qis__Counts([[ptr]], %arg2, [[c2]], %arg0, %arg0)
%in_eigvals2 = memref.alloc() : memref<4xf64>
%in_counts2 = memref.alloc() : memref<4xi64>
- quantum.counts %o2 in(%in_eigvals2 : memref<4xf64>, %in_counts2 : memref<4xi64>) {shots = 2000 : i64}
+ quantum.counts %o2 in(%in_eigvals2 : memref<4xf64>, %in_counts2 : memref<4xi64>) %shots2
return
}
diff --git a/mlir/test/Quantum/VerifierTest.mlir b/mlir/test/Quantum/VerifierTest.mlir
index 0e23bc615d..21260f55eb 100644
--- a/mlir/test/Quantum/VerifierTest.mlir
+++ b/mlir/test/Quantum/VerifierTest.mlir
@@ -178,133 +178,130 @@ func.func @tensorobs(%q0 : !quantum.bit, %q1 : !quantum.bit, %q2 : !quantum.bit)
// -----
-func.func @sample1(%q : !quantum.bit) {
+func.func @sample1(%q : !quantum.bit, %shots: i64) {
%obs = quantum.namedobs %q[Identity] : !quantum.obs
- // expected-error@+1 {{return tensor must have 1D static shape equal to (number of shots)}}
- %err = quantum.sample %obs { shots=1000 } : tensor<1xf64>
+ %err = quantum.sample %obs %shots : tensor<1xf64>
- %samples = quantum.sample %obs { shots=1000 } : tensor<1000xf64>
+ %samples = quantum.sample %obs %shots : tensor<1000xf64>
return
}
// -----
-func.func @sample2(%q : !quantum.bit) {
+func.func @sample2(%q : !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q : !quantum.obs
- // expected-error@+1 {{return tensor must have 2D static shape equal to (number of shots, number of qubits in observable)}}
- %err = quantum.sample %obs { shots=1000 } : tensor<1000xf64>
+ %err = quantum.sample %obs %shots : tensor<1000xf64>
- %samples = quantum.sample %obs { shots=1000 } : tensor<1000x1xf64>
+ %samples = quantum.sample %obs %shots : tensor<1000x1xf64>
return
}
// -----
-func.func @sample3(%q : !quantum.bit) {
+func.func @sample3(%q : !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q : !quantum.obs
%alloc0 = memref.alloc() : memref<1000xf64>
- // expected-error@+1 {{return tensor must have 2D static shape equal to (number of shots, number of qubits in observable)}}
- quantum.sample %obs in(%alloc0 : memref<1000xf64>) { shots = 1000 }
+ quantum.sample %obs in(%alloc0 : memref<1000xf64>) %shots
%alloc1 = memref.alloc() : memref<1000x1xf64>
- quantum.sample %obs in(%alloc1 : memref<1000x1xf64>) { shots = 1000 }
+ quantum.sample %obs in(%alloc1 : memref<1000x1xf64>) %shots
return
}
// -----
-func.func @sample4(%q : !quantum.bit) {
+func.func @sample4(%q : !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q : !quantum.obs
%alloc = memref.alloc() : memref<1000xf64>
// expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}}
- quantum.sample %obs in (%alloc : memref<1000xf64>) { shots=1000 } : tensor<1000xf64>
+ quantum.sample %obs in (%alloc : memref<1000xf64>) %shots : tensor<1000xf64>
- %samples = quantum.sample %obs { shots=1000 } : tensor<1000x1xf64>
+ %samples = quantum.sample %obs %shots : tensor<1000x1xf64>
return
}
// -----
-func.func @sample5(%q : !quantum.bit) {
+func.func @sample5(%q : !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q : !quantum.obs
// expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}}
- quantum.sample %obs { shots=1000 }
+ quantum.sample %obs %shots
return
}
// -----
-func.func @counts1(%q0 : !quantum.bit, %q1 : !quantum.bit) {
+func.func @counts1(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) {
%obs = quantum.namedobs %q0[PauliX] : !quantum.obs
// expected-error@+1 {{number of eigenvalues or counts did not match observable}}
- %err:2 = quantum.counts %obs { shots=1000 } : tensor<4xf64>, tensor<4xi64>
+ %err:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64>
- %counts:2 = quantum.counts %obs { shots=1000 } : tensor<2xf64>, tensor<2xi64>
+ %counts:2 = quantum.counts %obs %shots : tensor<2xf64>, tensor<2xi64>
return
}
// -----
-func.func @counts2(%q0 : !quantum.bit, %q1 : !quantum.bit) {
+func.func @counts2(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) {
%obs = quantum.compbasis %q0, %q1 : !quantum.obs
// expected-error@+1 {{number of eigenvalues or counts did not match observable}}
- %err:2 = quantum.counts %obs { shots=1000 } : tensor<2xf64>, tensor<2xi64>
+ %err:2 = quantum.counts %obs %shots : tensor<2xf64>, tensor<2xi64>
- %counts:2 = quantum.counts %obs { shots=1000 } : tensor<4xf64>, tensor<4xi64>
+ %counts:2 = quantum.counts %obs %shots : tensor<4xf64>, tensor<4xi64>
return
}
// -----
-func.func @counts3(%q0 : !quantum.bit, %q1 : !quantum.bit) {
+func.func @counts3(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) {
%obs = quantum.namedobs %q0[PauliX] : !quantum.obs
%in_eigvals_1 = memref.alloc() : memref<4xf64>
%in_counts_1 = memref.alloc() : memref<4xi64>
// expected-error@+1 {{number of eigenvalues or counts did not match observable}}
- quantum.counts %obs in(%in_eigvals_1 : memref<4xf64>, %in_counts_1 : memref<4xi64>) { shots=1000 }
+ quantum.counts %obs in(%in_eigvals_1 : memref<4xf64>, %in_counts_1 : memref<4xi64>) %shots
%in_eigvals_2 = memref.alloc() : memref<2xf64>
%in_counts_2 = memref.alloc() : memref<2xi64>
- quantum.counts %obs in(%in_eigvals_2 : memref<2xf64>, %in_counts_2 : memref<2xi64>) { shots=1000 }
+ quantum.counts %obs in(%in_eigvals_2 : memref<2xf64>, %in_counts_2 : memref<2xi64>) %shots
return
}
// -----
-func.func @counts4(%q0 : !quantum.bit, %q1 : !quantum.bit) {
+func.func @counts4(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) {
%obs = quantum.namedobs %q0[PauliX] : !quantum.obs
// expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}}
- quantum.counts %obs { shots=1000 }
+ quantum.counts %obs %shots
return
}
// -----
-func.func @counts5(%q0 : !quantum.bit, %q1 : !quantum.bit) {
+func.func @counts5(%q0 : !quantum.bit, %q1 : !quantum.bit, %shots: i64) {
%obs = quantum.namedobs %q0[PauliX] : !quantum.obs
%in_eigvals = memref.alloc() : memref<2xf64>
%in_counts = memref.alloc() : memref<2xi64>
// expected-error@+1 {{either tensors must be returned or memrefs must be used as inputs}}
- quantum.counts %obs in(%in_eigvals : memref<2xf64>, %in_counts : memref<2xi64>) { shots=1000 } : tensor<2xf64>, tensor<2xi64>
+ quantum.counts %obs in(%in_eigvals : memref<2xf64>, %in_counts : memref<2xi64>) %shots : tensor<2xf64>, tensor<2xi64>
return
}