Skip to content

Commit

Permalink
infra: Update tests for AwsDevice implementation (#59)
Browse files Browse the repository at this point in the history
Also updates python-package workflow to run tests with the main branch
of dependencies instead of latest releases.
  • Loading branch information
speller26 authored Mar 2, 2021
1 parent 21ec6af commit ad3649c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 45 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ jobs:
flake8
- name: Run unit tests
run: |
tox -e unit-tests
coverage run -m pytest
coverage combine
coverage report
coverage xml
- name: Upload coverage report to Codecov
uses: codecov/codecov-action@v1
82 changes: 39 additions & 43 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@

def test_reset():
"""Tests that the members of the device are cleared on reset."""
dev = _device(wires=2)
dev = _aws_device(wires=2)
dev._circuit = CIRCUIT
dev._task = TASK

Expand All @@ -195,7 +195,7 @@ def test_reset():
@pytest.mark.parametrize("pl_op, braket_gate, qubits, params", testdata)
def test_apply(pl_op, braket_gate, qubits, params):
"""Tests that the correct Braket gate is applied for each PennyLane operation."""
dev = _device(wires=len(qubits))
dev = _aws_device(wires=len(qubits))
circuit = dev.apply([pl_op(*params, wires=qubits)])
assert circuit == Circuit().add_instruction(Instruction(braket_gate(*params), qubits))

Expand All @@ -206,14 +206,14 @@ def test_apply_inverse_gates(pl_op, braket_gate, params, inv_params, qubits):
Tests that the correct Braket gate is applied for the inverse of each PennyLane operations
where the inverse is defined.
"""
dev = _device(wires=len(qubits))
dev = _aws_device(wires=len(qubits))
circuit = dev.apply([pl_op(*params, wires=qubits).inv()])
assert circuit == Circuit().add_instruction(Instruction(braket_gate(*inv_params), qubits))


def test_apply_unused_qubits():
"""Tests that the correct circuit is created when not all qires in the dievce are used."""
dev = _device(wires=4)
dev = _aws_device(wires=4)
operations = [qml.Hadamard(wires=1), qml.CNOT(wires=[1, 2]), qml.RX(np.pi / 2, wires=2)]
rotations = [qml.RY(np.pi, wires=1)]
circuit = dev.apply(operations, rotations)
Expand All @@ -224,7 +224,7 @@ def test_apply_unused_qubits():
@pytest.mark.xfail(raises=NotImplementedError)
def test_apply_unsupported():
"""Tests that apply() throws NotImplementedError when it encounters an unknown gate."""
dev = _device(wires=2)
dev = _aws_device(wires=2)
mock_op = Mock()
mock_op.name = "foo"
mock_op.parameters = []
Expand All @@ -236,7 +236,7 @@ def test_apply_unsupported():
def test_apply_unwrap_tensor():
"""Test that apply() unwraps tensors from the PennyLane version of NumPy into standard NumPy
arrays (or floats)"""
dev = _device(wires=1)
dev = _aws_device(wires=1)

a = anp.array(0.6) # array
b = np.array(0.5, requires_grad=True) # tensor
Expand All @@ -249,10 +249,10 @@ def test_apply_unwrap_tensor():
assert not any([isinstance(angle, np.tensor) for angle in angles])


@patch.object(AwsQuantumTask, "create")
def test_execute(mock_create):
mock_create.return_value = TASK
dev = _device(wires=4, foo="bar")
@patch.object(AwsDevice, "run")
def test_execute(mock_run):
mock_run.return_value = TASK
dev = _aws_device(wires=4, foo="bar")

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
Expand Down Expand Up @@ -283,12 +283,10 @@ def test_execute(mock_create):
)
assert dev.task == TASK

mock_create.assert_called_with(
mock.ANY,
DEVICE_ARN,
mock_run.assert_called_with(
CIRCUIT,
("foo", "bar"),
SHOTS,
s3_destination_folder=("foo", "bar"),
shots=SHOTS,
poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
foo="bar",
Expand All @@ -297,7 +295,7 @@ def test_execute(mock_create):

def test_pl_to_braket_circuit():
"""Tests that a PennyLane circuit is correctly converted into a Braket circuit"""
dev = _device(wires=2, foo="bar")
dev = _aws_device(wires=2, foo="bar")

with QuantumTape() as tape:
qml.RX(0.2, wires=0)
Expand All @@ -320,7 +318,7 @@ def test_pl_to_braket_circuit():

def test_bad_statistics():
"""Test if a QuantumFunctionError is raised for an invalid return type"""
dev = _device(wires=1, foo="bar")
dev = _aws_device(wires=1, foo="bar")
observable = qml.Identity(wires=0, do_queue=False)
observable.return_type = None

Expand All @@ -330,7 +328,7 @@ def test_bad_statistics():

def test_batch_execute_non_parallel(monkeypatch):
"""Test if the batch_execute() method simply calls the inherited method if parallel=False"""
dev = _device(wires=2, foo="bar", parallel=False)
dev = _aws_device(wires=2, foo="bar", parallel=False)
assert dev.parallel is False

with monkeypatch.context() as m:
Expand All @@ -343,7 +341,7 @@ def test_batch_execute_non_parallel(monkeypatch):
def test_batch_execute_parallel(mock_run_batch):
"""Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK"""
mock_run_batch.return_value = TASK_BATCH
dev = _device(wires=4, foo="bar", parallel=True)
dev = _aws_device(wires=4, foo="bar", parallel=True)
assert dev.parallel is True

with QuantumTape() as circuit:
Expand Down Expand Up @@ -391,8 +389,8 @@ def test_batch_execute_parallel(mock_run_batch):
)


@patch.object(AwsQuantumTask, "create")
def test_execute_all_samples(mock_create):
@patch.object(AwsDevice, "run")
def test_execute_all_samples(mock_run):
result = GateModelQuantumTaskResult.from_string(
json.dumps(
{
Expand Down Expand Up @@ -436,8 +434,8 @@ def test_execute_all_samples(mock_create):
)
task = Mock()
task.result.return_value = result
mock_create.return_value = task
dev = _device(wires=3)
mock_run.return_value = task
dev = _aws_device(wires=3)

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
Expand All @@ -449,21 +447,22 @@ def test_execute_all_samples(mock_create):


@pytest.mark.xfail(raises=ValueError)
def test_non_jaqcd_device():
@patch.object(AwsDevice, "name", new_callable=mock.PropertyMock)
def test_non_jaqcd_device(name_mock):
"""Tests that BraketDevice cannot be instantiated with a non-JAQCD AwsDevice"""
_bad_device(wires=2)
_bad_aws_device(wires=2)


def test_simulator_default_shots():
"""Tests that simulator devices are analytic if ``shots`` is not supplied"""
dev = _device(wires=2, device_type=AwsDeviceType.SIMULATOR, shots=None)
dev = _aws_device(wires=2, device_type=AwsDeviceType.SIMULATOR, shots=None)
assert dev.shots == 1
assert dev.analytic


def test_simulator_0_shots():
"""Tests that simulator devices are analytic if ``shots`` is not supplied"""
dev = _device(wires=2, device_type=AwsDeviceType.SIMULATOR, shots=0)
dev = _aws_device(wires=2, device_type=AwsDeviceType.SIMULATOR, shots=0)
assert dev.shots == 1
assert dev.analytic

Expand All @@ -484,28 +483,28 @@ def test_local_0_shots():

def test_qpu_default_shots():
"""Tests that QPU devices have the right default value for ``shots``"""
dev = _device(wires=2, shots=None)
dev = _aws_device(wires=2, shots=None)
assert dev.shots == AwsDevice.DEFAULT_SHOTS_QPU
assert not dev.analytic


@pytest.mark.xfail(raises=ValueError)
def test_qpu_0_shots():
"""Tests that QPUs can not be instantiated with 0 shots"""
_device(wires=2, shots=0)
_aws_device(wires=2, shots=0)


@pytest.mark.xfail(raises=ValueError)
def test_invalid_device_type():
"""Tests that BraketDevice cannot be instantiated with an unknown device type"""
_device(wires=2, device_type="foo", shots=None)
_aws_device(wires=2, device_type="foo", shots=None)


def test_wires():
"""Test if the apply method supports custom wire labels"""

wires = ["A", 0, "B", -1]
dev = _device(wires=wires, device_type=AwsDeviceType.SIMULATOR, shots=0)
dev = _aws_device(wires=wires, device_type=AwsDeviceType.SIMULATOR, shots=0)

ops = [qml.RX(0.1, wires="A"), qml.CNOT(wires=[0, "B"]), qml.RY(0.3, wires=-1)]
target_wires = [[0], [1, 2], [3]]
Expand All @@ -517,17 +516,15 @@ def test_wires():
assert w == t


def _noop(*args, **kwargs):
return None


@patch.object(AwsDevice, "__init__", _noop)
@patch.object(AwsDevice, "type", new_callable=mock.PropertyMock)
@patch.object(AwsDevice, "properties")
@patch.object(AwsDevice, "refresh_metadata", return_value=None)
def _device(
refresh_metadata_mock,
properties_mock,
type_mock,
wires,
device_type=AwsDeviceType.QPU,
shots=SHOTS,
**kwargs
def _aws_device(
properties_mock, type_mock, wires, device_type=AwsDeviceType.QPU, shots=SHOTS, **kwargs
):
properties_mock.action = {DeviceActionType.JAQCD: "foo"}
type_mock.return_value = device_type
Expand All @@ -541,10 +538,9 @@ def _device(
)


@patch.object(AwsDevice, "type")
@patch.object(AwsDevice, "__init__", _noop)
@patch.object(AwsDevice, "properties")
@patch.object(AwsDevice, "refresh_metadata", return_value=None)
def _bad_device(refresh_metadata_mock, properties_mock, type_mock, wires, **kwargs):
def _bad_aws_device(properties_mock, wires, **kwargs):
properties_mock.action = {DeviceActionType.ANNEALING: "foo"}
properties_mock.type = AwsDeviceType.QPU
return BraketAwsQubitDevice(
Expand Down
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ commands =
coverage combine
coverage report
coverage html
coverage xml
extras = test

[testenv:integ-tests]
Expand Down

0 comments on commit ad3649c

Please sign in to comment.