Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making arithmetic templates QJIT compatible #6307

Merged
merged 20 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* The quantum arithmetic templates are now QJIT compatible.
[(#6307)](https://github.com/PennyLaneAI/pennylane/pull/6307)

<h4>Capturing and representing hybrid programs</h4>

* Differentiation of hybrid programs via `qml.grad` and `qml.jacobian` can now be captured
Expand Down
6 changes: 5 additions & 1 deletion pennylane/templates/subroutines/multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,5 +228,9 @@ def compute_decomposition(k, x_wires, mod, work_wires): # pylint: disable=argum
for x_wire, aux_wire in zip(x_wires, wires_aux_swap):
op_list.append(qml.SWAP(wires=[x_wire, aux_wire]))
inv_k = pow(k, -1, mod)
op_list.extend(qml.adjoint(_mul_out_k_mod)(inv_k, x_wires, mod, work_wire_aux, wires_aux))

# Adjoint is not iterable in QJIT
willjmax marked this conversation as resolved.
Show resolved Hide resolved
for op in reversed(_mul_out_k_mod(inv_k, x_wires, mod, work_wire_aux, wires_aux)):
op_list.append(qml.adjoint(op))
willjmax marked this conversation as resolved.
Show resolved Hide resolved

return op_list
45 changes: 29 additions & 16 deletions pennylane/templates/subroutines/phase_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,25 @@ def __init__(

num_work_wires = 0 if work_wire is None else len(work_wire)

if mod is None:
mod = 2 ** len(x_wires)
elif mod != 2 ** len(x_wires) and num_work_wires != 1:
raise ValueError(f"If mod is not 2^{len(x_wires)}, one work wire should be provided.")
if not isinstance(k, int) or not isinstance(mod, int):
raise ValueError("Both k and mod must be integers")
if mod > 2 ** len(x_wires):
raise ValueError(
"PhaseAdder must have enough x_wires to represent mod. The maximum mod "
f"with len(x_wires)={len(x_wires)} is {2 ** len(x_wires)}, but received {mod}."
)
if work_wire is not None:
if any(wire in work_wire for wire in x_wires):
raise ValueError("None of the wires in work_wire should be included in x_wires.")
if not qml.math.is_abstract(mod):
willjmax marked this conversation as resolved.
Show resolved Hide resolved
if mod is None:
mod = 2 ** len(x_wires)
elif mod != 2 ** len(x_wires) and num_work_wires != 1:
raise ValueError(
f"If mod is not 2^{len(x_wires)}, one work wire should be provided."
)
if not isinstance(k, int) or not isinstance(mod, int):
raise ValueError("Both k and mod must be integers")
if mod > 2 ** len(x_wires):
raise ValueError(
"PhaseAdder must have enough x_wires to represent mod. The maximum mod "
f"with len(x_wires)={len(x_wires)} is {2 ** len(x_wires)}, but received {mod}."
)
if work_wire is not None:
if any(wire in work_wire for wire in x_wires):
raise ValueError(
"None of the wires in work_wire should be included in x_wires."
)

self.hyperparameters["k"] = k % mod
self.hyperparameters["mod"] = mod
Expand Down Expand Up @@ -216,12 +221,20 @@ def compute_decomposition(k, x_wires, mod, work_wire): # pylint: disable=argume
else:
aux_k = x_wires[0]
op_list.extend(_add_k_fourier(k, x_wires))
op_list.extend(qml.adjoint(_add_k_fourier)(mod, x_wires))

# Adjoint is not iterable in QJIT
for op in reversed(_add_k_fourier(mod, x_wires)):
op_list.append(qml.adjoint(op))

op_list.append(qml.adjoint(qml.QFT)(wires=x_wires))
op_list.append(qml.ctrl(qml.PauliX(work_wire), control=aux_k, control_values=1))
op_list.append(qml.QFT(wires=x_wires))
op_list.extend(qml.ctrl(op, control=work_wire) for op in _add_k_fourier(mod, x_wires))
op_list.extend(qml.adjoint(_add_k_fourier)(k, x_wires))

# Adjoint is not iterable in QJIT
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
for op in reversed(_add_k_fourier(k, x_wires)):
op_list.append(qml.adjoint(op))

op_list.append(qml.adjoint(qml.QFT)(wires=x_wires))
op_list.append(qml.ctrl(qml.PauliX(work_wire), control=aux_k, control_values=0))
op_list.append(qml.QFT(wires=x_wires))
Expand Down
Loading