Skip to content

Commit

Permalink
Move to new qadence (#22)
Browse files Browse the repository at this point in the history
* not saving the quantum model

* move to new qadence+refactor

* fix docstring

* added functional form of the optimizers torch style

* Extended docs

* linting + docs

* fix import

* no TM

* bump version

* relax test and mark it as flaky
  • Loading branch information
inafergra authored Aug 28, 2024
1 parent 4f93b77 commit 3bdf515
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 105 deletions.
16 changes: 12 additions & 4 deletions docs/qinfo_tools/qng.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,24 @@ for i in range(n_epochs_adam):
```

### QNG
The way to initialize the `QuantumNaturalGradient` optimizer in `qadence-libs` is slightly different from other usual Torch optimizers. Normally, one needs to pass a `params` argument to the optimizer to specify which parameters of the model should be optimized. In the `QuantumNaturalGradient`, it is assumed that all *circuit* parameters are to be optimized, whereas the *non-circuit* parameters will not be optimized. By circuit parameters, we mean parameters that somehow affect the quantum gates of the circuit and therefore influence the final quantum state. Any parameters affecting the observable (such as ouput scaling or shifting) are not considered circuit parameters, as those parameters will not be included in the QFI matrix as they don't affect the final state of the circuit.

The `QuantumNaturalGradient` constructor takes a qadence's `QuantumModel` as the 'model', and it will automatically identify its circuit and non-circuit parameters. The `approximation` argument defaults to the SPSA method, however the exact version of the QNG is also implemented and can be used for small circuits (beware of using the exact version for large circuits, as it scales badly). $\beta$ is a small constant added to the QFI matrix before inversion to ensure numerical stability,

$$(F_{ij} + \beta \mathbb{I})^{-1}$$

where $\mathbb{I}$ is the identify matrix. It is always a good idea to try out different values of $\beta$ if the training is not converging, which might be due to a too small $\beta$.

```python exec="on" source="material-block" html="1" session="main"
# Train with QNG
n_epochs_qng = 20
lr_qng = 0.1

model.reset_vparams(initial_params)
optimizer = QuantumNaturalGradient(
model.parameters(),
model=model,
lr=lr_qng,
approximation=FisherApproximation.EXACT,
model=model,
beta=0.1,
)

Expand All @@ -137,17 +144,18 @@ for i in range(n_epochs_qng):
```

### QNG-SPSA
The QNG-SPSA optimizer can be constructed similarly to the exact QNG, where now a new argument $\epsilon$ is used to control the shift used in the finite differences derivatives of the SPSA algorithm.

```python exec="on" source="material-block" html="1" session="main"
# Train with QNG-SPSA
n_epochs_qng_spsa = 20
lr_qng_spsa = 0.01

model.reset_vparams(initial_params)
optimizer = QuantumNaturalGradient(
model.parameters(),
model=model,
lr=lr_qng_spsa,
approximation=FisherApproximation.SPSA,
model=model,
beta=0.1,
epsilon=0.01,
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ authors = [
requires-python = ">=3.9,<3.12"
license = {text = "Apache 2.0"}
keywords = ["quantum"]
version = "0.1.2"
version = "0.1.3"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
Expand Down
251 changes: 157 additions & 94 deletions qadence_libs/qinfo_tools/qng.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from typing import Callable, Sequence
import re
from typing import Callable

import torch
from qadence import QuantumCircuit, QuantumModel
from qadence import QNN, Parameter, QuantumCircuit, QuantumModel
from qadence.logger import get_logger
from qadence.ml_tools.models import TransformedModule
from torch.optim.optimizer import Optimizer, required

from qadence_libs.qinfo_tools.qfi import get_quantum_fisher, get_quantum_fisher_spsa
Expand All @@ -14,12 +14,50 @@
logger = get_logger(__name__)


def _identify_circuit_vparams(
model: QuantumModel | QNN, circuit: QuantumCircuit
) -> dict[str, Parameter]:
"""Returns the parameters of the model that are circuit parameters.
Args:
model (QuantumModel|QNN): The model
circuit (QuantumCircuit): The quantum circuit
Returns:
dict[str, Parameter]:
Dictionary containing the circuit parameters
"""
non_circuit_vparams = []
circ_vparams = {}
pattern = r"_params\."
for n, p in model.named_parameters():
n = re.sub(pattern, "", n)
if p.requires_grad:
if n in circuit.parameters():
circ_vparams[n] = p
else:
non_circuit_vparams.append(n)

if len(non_circuit_vparams) > 0:
msg = f"""Parameters {non_circuit_vparams} are trainable parameters of the model
which are not part of the quantum circuit. Since the QNG optimizer can
only optimize circuit parameters, these parameter will not be optimized.
Please use another optimizer for the non-circuit parameters."""
logger.warning(msg)

return circ_vparams


class QuantumNaturalGradient(Optimizer):
"""Implements the Quantum Natural Gradient Algorithm.
There are currently two variants of the algorithm implemented: exact QNG and
the SPSA approximation.
Unlike other torch optimizers, QuantumNaturalGradient does not take a `Sequence`
of parameters as an argument, but rather the QuantumModel whose parameters are to be
optimized. All circuit parameters in the QuantumModel will be optimized.
WARNING: The exact QNG optimizer is very inefficient both in time and memory as
it calculates the exact Quantum Fisher Information of the circuit at every
iteration. Therefore, it is not meant to be run with medium to large circuits.
Expand All @@ -29,8 +67,7 @@ class QuantumNaturalGradient(Optimizer):

def __init__(
self,
params: Sequence,
model: QuantumModel = required,
model: QuantumModel | QNN = required,
lr: float = required,
approximation: FisherApproximation | str = FisherApproximation.SPSA,
beta: float = 10e-3,
Expand All @@ -39,18 +76,16 @@ def __init__(
"""
Args:
params (tuple | torch.Tensor): Variational parameters to be updated
model (QuantumModel):
Model to be optimized. The optimizers needs to access its quantum circuit
to compute the QFI matrix.
Model whose (circuit) parameters are to be optimized
lr (float): Learning rate.
approximation (FisherApproximation):
Approximation used to compute the QFI matrix. Defaults to FisherApproximation.SPSA
beta (float):
Shift applied to the QFI matrix before inversion to ensure numerical stability.
Defaults to 10e-3.
epsilon (float):
Finite difference applied when computing the SPSA derivatives. Defaults to 10e-2.
Finite difference used when computing the SPSA derivatives. Defaults to 10e-2.
"""

if 0.0 > lr:
Expand All @@ -60,39 +95,37 @@ def __init__(
if 0.0 > epsilon:
raise ValueError(f"Invalid epsilon value: {epsilon}")

if isinstance(model, TransformedModule):
logger.warning(
"The model is of type '<class TransformedModule>. "
"Keep in mind that the QNG optimizer can only optimize circuit "
"parameters. Input and output shifting/scaling parameters will not be optimized."
)
# Retrieve the quantum model from the TransformedModule
model = model.model
if not isinstance(model, QuantumModel):
raise TypeError(
"The model should be an instance of '<class QuantumModel>' "
f"or '<class TransformedModule>'. Got {type(model)}."
f"The model should be an instance of '<class QuantumModel>'. Got {type(model)}."
)

self.param_dict = model.vparams
self.model = model
self.circuit = model._circuit.abstract
if not isinstance(self.circuit, QuantumCircuit):
raise TypeError(
"The circuit should be an instance of '<class QuantumCircuit>'."
"Got {type(self.circuit)}"
f"""The circuit should be an instance of '<class QuantumCircuit>'.
Got {type(self.circuit)}"""
)

circ_vparams = _identify_circuit_vparams(model, self.circuit)
self.vparams_keys = list(circ_vparams.keys())
vparams_values = list(circ_vparams.values())

defaults = dict(
model=model,
lr=lr,
approximation=approximation,
beta=beta,
epsilon=epsilon,
)
super().__init__(params, defaults)

super().__init__(vparams_values, defaults)

if len(self.param_groups) != 1:
raise ValueError("QNG doesn't support per-parameter options (parameter groups)")

if approximation == FisherApproximation.SPSA:
state = self.state["state"]
state = self.state
state.setdefault("iter", 0)
state.setdefault("qfi_estimator", None)

Expand All @@ -107,74 +140,104 @@ def step(self, closure: Callable | None = None) -> torch.Tensor:
if closure is not None:
loss = closure()

for group in self.param_groups:
vparams_values = [p for p in group["params"] if p.requires_grad]

# Build the parameter dictionary
# We rely on the `vparam()` method in `QuantumModel` and the
# `parameters()` in `nn.Module` to give the same param ordering.
# We test for this in `test_qng.py`.
vparams_dict = dict(zip(self.param_dict.keys(), vparams_values))

approximation = group["approximation"]
grad_vec = torch.tensor([v.grad.data for v in vparams_values])
if approximation == FisherApproximation.EXACT:
# Calculate the EXACT metric tensor
metric_tensor = 0.25 * get_quantum_fisher(
self.circuit,
vparams_dict=vparams_dict,
)

with torch.no_grad():
# Apply a finite shift to the metric tensor to avoid numerical
# stability issues when solving the least squares problem
metric_tensor = metric_tensor + group["beta"] * torch.eye(len(grad_vec))

# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
metric_tensor,
grad_vec,
driver="gelsd",
).solution

# Update parameters
for i, p in enumerate(vparams_values):
p.data.add_(transf_grad[i], alpha=-group["lr"])

elif approximation == FisherApproximation.SPSA:
state = self.state["state"]
with torch.no_grad():
# Get estimation of the QFI matrix
qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa(
circuit=self.circuit,
iteration=state["iter"],
vparams_dict=vparams_dict,
previous_qfi_estimator=state["qfi_estimator"],
epsilon=group["epsilon"],
beta=group["beta"],
)

# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
0.25 * qfi_mat_positive_sd,
grad_vec,
driver="gelsd",
).solution

# Update parameters
for i, p in enumerate(vparams_values):
if p.grad is None:
continue
p.data.add_(transf_grad[i], alpha=-group["lr"])

state["iter"] += 1
state["qfi_estimator"] = qfi_estimator

else:
raise NotImplementedError(
f"Approximation {approximation} of the QNG optimizer "
"is not implemented. Choose an item from the "
f"FisherApproximation enum: {FisherApproximation.list()}."
)
assert len(self.param_groups) == 1
group = self.param_groups[0]

approximation = group["approximation"]
beta = group["beta"]
epsilon = group["epsilon"]
lr = group["lr"]
circuit = self.circuit
vparams_keys = self.vparams_keys
vparams_values = group["params"]
grad_vec = torch.tensor([v.grad.data for v in vparams_values])

if approximation == FisherApproximation.EXACT:
qng_exact(vparams_values, vparams_keys, grad_vec, lr, circuit, beta)
elif approximation == FisherApproximation.SPSA:
qng_spsa(vparams_values, vparams_keys, grad_vec, lr, circuit, self.state, epsilon, beta)
else:
raise NotImplementedError(
f"""Approximation {approximation} of the QNG optimizer
is not implemented. Choose an item from the
FisherApproximation enum: {FisherApproximation.list()}."""
)

return loss


def qng_exact(
vparams_values: list,
vparams_keys: list,
grad_vec: torch.Tensor,
lr: float,
circuit: QuantumCircuit,
beta: float,
) -> None:
"""Functional API that performs exact QNG algorithm computation.
See :class:`~qadence_libs.qinfo_tools.QuantumNaturalGradient` for details.
"""

# EXACT metric tensor
vparams_dict = dict(zip(vparams_keys, vparams_values))
metric_tensor = 0.25 * get_quantum_fisher(
circuit,
vparams_dict=vparams_dict,
)
with torch.no_grad():
# Apply a finite shift to the metric tensor to avoid numerical
# stability issues when solving the least squares problem
metric_tensor = metric_tensor + beta * torch.eye(len(grad_vec))

# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
metric_tensor,
grad_vec,
driver="gelsd",
).solution

for i, p in enumerate(vparams_values):
p.data.add_(transf_grad[i], alpha=-lr)


def qng_spsa(
vparams_values: list,
vparams_keys: list,
grad_vec: torch.Tensor,
lr: float,
circuit: QuantumCircuit,
state: dict,
epsilon: float,
beta: float,
) -> None:
"""Functional API that performs the QNG-SPSA algorithm computation.
See :class:`~qadence_libs.qinfo_tools.QuantumNaturalGradient` for details.
"""

# Get estimation of the QFI matrix
vparams_dict = dict(zip(vparams_keys, vparams_values))
qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa(
circuit=circuit,
iteration=state["iter"],
vparams_dict=vparams_dict,
previous_qfi_estimator=state["qfi_estimator"],
epsilon=epsilon,
beta=beta,
)

# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
0.25 * qfi_mat_positive_sd,
grad_vec,
driver="gelsd",
).solution

for i, p in enumerate(vparams_values):
if p.grad is None:
continue
p.data.add_(transf_grad[i], alpha=-lr)

state["iter"] += 1
state["qfi_estimator"] = qfi_estimator
2 changes: 1 addition & 1 deletion tests/constructors/test_rydberg_hea.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from qadence.blocks.analog import ConstantAnalogRotation
from qadence.circuit import QuantumCircuit
from qadence.constructors import hamiltonian_factory, total_magnetization
from qadence.models import QuantumModel
from qadence.model import QuantumModel
from qadence.operations import AnalogRY, X
from qadence.parameters import VariationalParameter
from qadence.register import Register
Expand Down
Loading

0 comments on commit 3bdf515

Please sign in to comment.