Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Ensure the order of set_basis_state_p and set_state_p are preserved. #1174

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

erick-xanadu
Copy link
Contributor

@erick-xanadu erick-xanadu commented Oct 2, 2024

Context: When a tape contains two state preparation operations, it will decompose one. However, the order of the decomposed operations is no longer guaranteed to come after the first state preparation. This is due to how the tracing interacts with transforms and decomposition. The only order that JAXPR cares about is the use-def chain. However, if two state preparation operations occur on a different subset of wires, then the state preparation can be placed after the one that was decomposed.

Description of the Change: We add set_basis_state_p and set_state_p to the FORCED_ORDER_PRIMITIVES set to ensure that the order is preserved as they are traced.

Benefits: No logic errors.

Possible Drawbacks: More dependency on the topological sorting. I believe some time ago it was found out that topological sorting takes a long time and it takes longer the more FORCED_ORDER_PRIMITIVES there are.

Related GitHub Issues:

TODO:

  • MLIR Pass to analyze qnodes to make this a compile time guarantee.

@paul0403
Copy link
Contributor

paul0403 commented Oct 3, 2024

Should we add some verification in MLIR? i.e. verify that there must be no other quantum.custom operations before a quantum.set_state or quantum.set_basis_state

The reason I suggest is because I think logical errors (aka silently give wrong results) are much more unfavorable as opposed to some sort of hard program termination.

@dime10
Copy link
Collaborator

dime10 commented Oct 3, 2024

Is this an issue with the implementation of set_state? In that, should it only affect the qubits it receives instead of the entire statevector?

@erick-xanadu
Copy link
Contributor Author

Should we add some verification in MLIR?

Yes.

@erick-xanadu
Copy link
Contributor Author

erick-xanadu commented Oct 3, 2024

Is this an issue with the implementation of set_state? In that, should it only affect the qubits it receives instead of the entire statevector?

Well, that is the reason why the error occurs. set_basis_state_p will go into PL/lightning's implementation which will zero out the entire state vector and set the index computed at runtime. Since it is a condition that calling SetBasisState must happen as the first operation I would say that this is a code generation error. But alternatively we can change the semantics of the runtime implementation to not do zero-out the whole state vector and only a subset of it. That would be a discussion to have with the lightning team.

@erick-xanadu
Copy link
Contributor Author

I'll wait until there's consensus on the solution's implementation. I don't mind either way, but I prefer not relying on the topological sorting.

@dime10
Copy link
Collaborator

dime10 commented Oct 4, 2024

I did talk with Lee about it and he was open to not zeroing out the entire statevector. Someone would need to implement it though. There is also some discussion about these issues in slack here (and the post right below it):
https://xanaduhq.slack.com/archives/CA89H1BAN/p1727985781572109

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants