Skip to content

Commit

Permalink
support transpilation of switch_case (#1962)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhorii authored Oct 20, 2023
1 parent e7afbfb commit a13ff4e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 29 deletions.
4 changes: 4 additions & 0 deletions qiskit_aer/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"pauli",
"mcx_gray",
"ecr",
"switch_case",
]
),
"density_matrix": sorted(
Expand Down Expand Up @@ -149,6 +150,7 @@
"delay",
"pauli",
"ecr",
"switch_case",
]
),
"matrix_product_state": sorted(
Expand Down Expand Up @@ -191,6 +193,7 @@
"cswap",
"diagonal",
"initialize",
"switch_case",
]
),
"stabilizer": sorted(
Expand All @@ -214,6 +217,7 @@
"rx",
"ry",
"rz",
"switch_case",
]
),
"extended_stabilizer": sorted(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
issues:
- |
Though Aer supports ``switch`` for several methods, transpilation of circuits with ``switch`` has been failed.
This commit enables such transpilation by adding ``switch_case`` operations into basis gates.
87 changes: 58 additions & 29 deletions test/terra/backends/aer_simulator/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def add_jump(self, circ, jump_to, clbit=None, value=0):
instr.c_if(clbit, value)
return circ.append(instr, qubits)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_jump_always(self, method):
backend = self.backend(method=method)

Expand All @@ -76,7 +76,7 @@ def test_jump_always(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0000", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_jump_conditional(self, method):
backend = self.backend(method=method)

Expand All @@ -98,7 +98,7 @@ def test_jump_conditional(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0000 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_no_jump_conditional(self, method):
backend = self.backend(method=method)

Expand All @@ -119,7 +119,7 @@ def test_no_jump_conditional(self, method):
counts = result.get_counts()
self.assertNotEqual(len(counts), 1)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_invalid_jump(self, method):
logging.disable(level=logging.WARN)

Expand All @@ -142,7 +142,7 @@ def test_invalid_jump(self, method):

logging.disable(level=logging.NOTSET)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_duplicated_mark(self, method):
logging.disable(level=logging.WARN)

Expand All @@ -165,7 +165,7 @@ def test_duplicated_mark(self, method):

logging.disable(level=logging.NOTSET)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_if_true_body_builder(self, method):
backend = self.backend(method=method)

Expand All @@ -189,7 +189,7 @@ def test_if_true_body_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0001 1", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_if_else_body_builder(self, method):
backend = self.backend(method=method)

Expand All @@ -214,7 +214,7 @@ def test_if_else_body_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0000 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_for_loop_builder(self, method):
backend = self.backend(method=method)

Expand All @@ -240,7 +240,7 @@ def test_for_loop_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("01100", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_for_loop_builder_no_loop_variable(self, method):
backend = self.backend(method=method)

Expand All @@ -266,7 +266,7 @@ def test_for_loop_builder_no_loop_variable(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("01010", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_for_loop_break_builder(self, method):
backend = self.backend(method=method)

Expand Down Expand Up @@ -309,7 +309,7 @@ def test_for_loop_break_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("11100 1", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_for_loop_continue_builder(self, method):
backend = self.backend(method=method)

Expand Down Expand Up @@ -371,7 +371,7 @@ def test_for_loop_continue_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("11110 0 1 0 0 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_loop_no_iteration(self, method):
backend = self.backend(method=method)

Expand All @@ -390,7 +390,7 @@ def test_while_loop_no_iteration(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_loop_single_iteration(self, method):
backend = self.backend(method=method)

Expand Down Expand Up @@ -421,7 +421,7 @@ def test_while_loop_single_iteration(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("10 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_loop_double_iterations(self, method):
backend = self.backend(method=method)

Expand Down Expand Up @@ -452,7 +452,7 @@ def test_while_loop_double_iterations(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("01 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_loop_continue(self, method):
backend = self.backend(method=method)

Expand Down Expand Up @@ -486,7 +486,7 @@ def test_while_loop_continue(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0 0", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_nested_loop(self, method):
backend = self.backend(method=method)

Expand All @@ -513,7 +513,7 @@ def test_nested_loop(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("011", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_loop_last(self, method):
backend = self.backend(method=method)

Expand All @@ -527,7 +527,7 @@ def test_while_loop_last(self, method):
result = backend.run(circ, method=method).result()
self.assertSuccess(result)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_no_invalid_nested_reordering(self, method):
"""Test that the jump/mark system doesn't allow nested conditional marks to jump incorrectly
relative to their outer marks. Regression test of gh-1665."""
Expand All @@ -549,7 +549,7 @@ def test_no_invalid_nested_reordering(self, method):
self.assertSuccess(result)
self.assertEqual(result.get_counts(), {"110": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_no_invalid_reordering_if(self, method):
"""Test that the jump/mark system doesn't allow an unrelated operation to jump inside a
conditional statement."""
Expand All @@ -575,7 +575,7 @@ def test_no_invalid_reordering_if(self, method):
self.assertSuccess(result)
self.assertEqual(result.get_counts(), {"010": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_no_invalid_reordering_while(self, method):
"""Test that the jump/mark system doesn't allow an unrelated operation to jump inside a
conditional statement."""
Expand All @@ -601,7 +601,7 @@ def test_no_invalid_reordering_while(self, method):
self.assertSuccess(result)
self.assertEqual(result.get_counts(), {"010": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_transpile_break_and_continue_loop(self, method):
"""Test that transpiler can transpile break_loop and continue_loop with AerSimulator"""

Expand Down Expand Up @@ -632,7 +632,7 @@ def test_transpile_break_and_continue_loop(self, method):
result = backend.run(transpiled, method=method, shots=100).result()
self.assertEqual(result.get_counts(), {"1": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_clbit(self, method):
"""Test that a switch statement can be constructed with a bit as a condition."""

Expand Down Expand Up @@ -681,7 +681,7 @@ def test_switch_clbit(self, method):
self.assertSuccess(ret1)
self.assertEqual(ret1.get_counts(), ret1_expected.get_counts())

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_register(self, method):
"""Test that a switch statement can be constructed with a register as a condition."""

Expand Down Expand Up @@ -742,7 +742,7 @@ def test_switch_register(self, method):
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"011 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_with_default(self, method):
"""Test that a switch statement can be constructed with a default case at the end."""

Expand Down Expand Up @@ -803,7 +803,7 @@ def test_switch_with_default(self, method):
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"111 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_multiple_cases_to_same_block(self, method):
"""Test that it is possible to add multiple cases that apply to the same block, if they are
given as a compound value. This is an allowed special case of block fall-through."""
Expand Down Expand Up @@ -865,7 +865,36 @@ def test_switch_multiple_cases_to_same_block(self, method):
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"011 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_transpilation(self, method):
"""Test swtich test cases can be transpiled"""

backend = self.backend(method=method, seed_simulator=1)

qubit0 = Qubit()
qubit1 = Qubit()
qubit2 = Qubit()

creg = ClassicalRegister(2)
qc = QuantumCircuit([qubit0, qubit1, qubit2], creg)

with qc.switch(creg) as case:
with case(0):
qc.x(0)
with case(1):
qc.x(1)
with case(case.DEFAULT):
qc.x(2)

qc.measure_all()

transpiled = transpile(qc, backend)

ret0 = backend.run(transpiled, shots=100).result()
self.assertSuccess(ret0)
self.assertEqual(ret0.get_counts(), {"001 00": 100})

@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_switch_register_with_classical_expression(self, method):
"""Test that a switch statement can be constructed with a register as a condition."""

Expand Down Expand Up @@ -926,7 +955,7 @@ def test_switch_register_with_classical_expression(self, method):
self.assertSuccess(ret3)
self.assertEqual(ret3.get_counts(), {"011 11": 100})

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_if_expr_true_body_builder(self, method):
"""test expression with branch operation"""
backend = self.backend(method=method)
Expand Down Expand Up @@ -974,7 +1003,7 @@ def test_if_expr_true_body_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0001 011", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_if_expr_false_body_builder(self, method):
"""test expression with branch operation"""
backend = self.backend(method=method)
Expand Down Expand Up @@ -1026,7 +1055,7 @@ def test_if_expr_false_body_builder(self, method):
self.assertEqual(len(counts), 1)
self.assertIn("0001 011", counts)

@data("statevector", "density_matrix", "matrix_product_state")
@data("statevector", "density_matrix", "matrix_product_state", "stabilizer")
def test_while_expr_loop_break(self, method):
backend = self.backend(method=method)

Expand Down

0 comments on commit a13ff4e

Please sign in to comment.