-
Suppose I start with an MPS simulation. Consider, for example:

# %% Import Quimb + defs
from quimb.tensor.circuit import MPS_computational_state
from quimb.tensor.tensor_builder import MPO_rand_herm
from quimb.gates import CNOT, H
import quimb as qu
import quimb.tensor as qtn
import jax
import numpy
# Presumably registers quimb's tensor-network classes as JAX pytrees so that
# `cost` (wrapped in jax.jit + jax.value_and_grad) can take an MPS directly as
# its argument -- confirm against the quimb docs for the installed version.
# Must run before `cost` is first called.
qtn.jax_register_pytree()
# %%
# Number of qubits / MPS sites; `circuit` starts from "0" * n.
n: int = 10
# Random MPO (Hermitian, per MPO_rand_herm's name) with bond dimension 2,
# tagged "HAM"; it is the operator whose expectation value `cost` computes.
# No random seed is set in the visible code, so it differs run to run.
A = MPO_rand_herm(n, bond_dim=2, tags=["HAM"])
def two_qubit_layer(circ, gate_round, contract=False, *, gate=None):
    """Apply one layer of identical two-qubit entangling gates to ``circ``.

    Every wire ``i`` is paired with wire ``(i + offset) % n_wires``, where
    ``offset = (gate_round % (n_wires - 1)) + 1``. Successive rounds cycle
    the offset through ``1 .. n_wires - 1``, so over ``n_wires - 1`` rounds
    each wire gets paired with every other wire.

    Parameters
    ----------
    circ
        Object with an ``L`` attribute (number of wires) and an in-place
        ``gate_(gate, wires, contract=...)`` method, e.g. the MPS returned by
        ``MPS_computational_state``.
    gate_round : int
        Index of this layer; selects the pairing offset.
    contract : bool or str, optional
        Forwarded unchanged to every ``circ.gate_`` call
        (e.g. ``"swap+split"``).
    gate : optional, keyword-only
        Two-qubit gate to apply to every pair. Defaults to ``CNOT``.

    Notes
    -----
    With fewer than two wires there are no pairs to entangle, so the call is
    a no-op. This also avoids the modulo-by-zero that
    ``gate_round % (n_wires - 1)`` would hit when ``n_wires == 1``.
    """
    n_wires = circ.L
    if n_wires < 2:
        return
    if gate is None:
        gate = CNOT
    # offset is in 1 .. n_wires - 1, so a wire is never paired with itself.
    offset = (gate_round % (n_wires - 1)) + 1
    for i in range(n_wires):
        circ.gate_(gate, (i, (i + offset) % n_wires), contract=contract)
def circuit(n_layers, params, contract=False, *, n_qubits=None):
    """Build a layered circuit as an MPS, starting from |00...0>.

    Each layer ``j`` applies, in this order:

    1. ``H`` to every qubit,
    2. ``qu.phase_gate(params[j])`` to every qubit,
    3. ``two_qubit_layer(psi, j, contract=contract)``.

    Parameters
    ----------
    n_layers : int
        Number of layers. ``params`` must have at least this many entries,
        otherwise ``params[j]`` raises.
    params : indexable of float
        Per-layer phase angle; ``params[j]`` is used in layer ``j``.
    contract : bool or str, optional
        Forwarded unchanged to every gate application (e.g. ``"swap+split"``).
    n_qubits : int, optional, keyword-only
        Number of qubits. Defaults to the module-level ``n``, so existing
        calls behave exactly as before.

    Returns
    -------
    The MPS created by ``MPS_computational_state``, with all gates applied
    in place.
    """
    if n_qubits is None:
        n_qubits = n  # backward-compatible default: module-level qubit count
    psi = MPS_computational_state("0" * n_qubits)
    for j in range(n_layers):
        for i in range(n_qubits):
            psi.gate_(H, i, contract=contract)
        # One single-qubit gate per layer, shared by all qubits. Assumed to be
        # diag(1, e^{i*phi}), i.e. Rz(phi) only up to a global phase -- confirm.
        layer_phase = qu.phase_gate(params[j])
        for i in range(n_qubits):
            psi.gate_(layer_phase, i, contract=contract)
        two_qubit_layer(psi, j, contract=contract)
    return psi
@jax.jit
@jax.value_and_grad
def cost(circ):
    """Return ``(Re <circ|A|circ>, gradient)`` for the MPS ``circ``.

    Because of ``jax.value_and_grad``, the call returns a ``(value, grad)``
    pair, where ``grad`` has the same pytree structure as ``circ``. That makes
    it a gradient with respect to the MPS tensors themselves, not with respect
    to the circuit parameters used to build ``circ``. ``A`` is the
    module-level MPO.
    """
    # In quimb, ``align`` (no trailing underscore) is the non-in-place variant
    # (``align_`` is the in-place one) -- confirm for the installed version.
    # The aligned networks must therefore be taken from the return value;
    # discarding it, as before, contracted ket, operator and bra with
    # mismatched indices. Using the return value also avoids mutating the
    # global ``A`` inside the jitted function.
    ket, op, bra = circ.align(A, circ.H)
    return (bra & op & ket).contract(all, output_inds=()).real
# One layer needs one random phase angle; build the circuit MPS with
# "swap+split" gate contraction, then get <psi|A|psi> and its gradient.
params = numpy.random.random(size=1)
circ = circuit(n_layers=1, params=params, contract="swap+split")
expv_qub, grad_qub = cost(circ)
#grad_qub
# MatrixProductState(tensors=10, indices=19, L=10, max_bond=4)
# Tensor(shape=(2, 2), inds=[_8b5bd2AAKLP, k0], tags={I0}),
# Tensor(shape=(2, 3, 2), inds=[_8b5bd2AAKLP, _8b5bd2AAKLQ, k1], tags={I1}),
# Tensor(shape=(3, 4, 2), inds=[_8b5bd2AAKLQ, _8b5bd2AAKLR, k2], tags={I2}),
# Tensor(shape=(4, 4, 2), inds=[_8b5bd2AAKLR, _8b5bd2AAKLS, k3], tags={I3}),
# Tensor(shape=(4, 4, 2), inds=[_8b5bd2AAKLS, _8b5bd2AAKLT, k4], tags={I4}),
# Tensor(shape=(4, 4, 2), inds=[_8b5bd2AAKLT, _8b5bd2AAKLU, k5], tags={I5}),
# Tensor(shape=(4, 4, 2), inds=[_8b5bd2AAKLU, _8b5bd2AAKLV, k6], tags={I6}),
# Tensor(shape=(4, 4, 2), inds=[_8b5bd2AAKLV, _8b5bd2AAKLW, k7], tags={I7}),
# Tensor(shape=(4, 2, 2), inds=[_8b5bd2AAKLW, _8b5bd2AAKLX, k8], tags={I8}),
# Tensor(shape=(2, 2), inds=[_8b5bd2AAKLX, k9], tags={I9}),

I would like to make sure what is […]
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 2 replies
-
I think your snippet is missing some imports to run directly, but to answer your questions:
|
Beta Was this translation helpful? Give feedback.
-
Could one brute-force fix for this be to keep the bond dimensions of the MPS matrices fixed at the maximum bond dimension, i.e. so that all tensors of the MPS have a fixed shape […]
Beta Was this translation helpful? Give feedback.
I think your snippet is missing some imports to run directly, but to answer your questions:

[…] (`circ`). For `jax` and its functional style, `grad(f(theta))`, where `theta` is any 'pytree', computes `df/dtheta`. So you need to write a function where `params` (or `PTensor` instances) are the inputs in order to get the gradient with respect to them.

[…] `jax` and other libraries with a 'static' computational graph, since the shapes are not predetermined. But t…