Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extrapolation in ZNE function #1213

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Fix extrapolation in ZNE function #1213

wants to merge 2 commits into from

Conversation

dime10
Copy link
Collaborator

@dime10 dime10 commented Oct 17, 2024

Extrapolation was done against the folding numbers instead of the scale factors. Since the folding numbers start at 0, extrapolation would always yield a result very close to the first data point.

Need to extrapolate against scale factors rather than folding numbers.
@dime10 dime10 added bug Something isn't working frontend Pull requests that update the frontend labels Oct 17, 2024
@dime10 dime10 requested a review from rmoyard October 17, 2024 16:54
@@ -273,6 +273,10 @@

<h3>Bug fixes</h3>

* Fix a bug in the extrapolation part of the `catalyst.mitigate_with_zne` function that would lead
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the sentence reads a bit redundant. A bug in extrapolation code leads to..incorrect extrapolated result"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Fix a bug in the extrapolation part of the `catalyst.mitigate_with_zne` function that would lead
* Fix a bug in `catalyst.mitigate_with_zne` that would lead

@@ -164,14 +162,14 @@ class ZNECallable(CatalystCallable):
def __init__(
self,
fn: Callable,
num_folds: jnp.ndarray,
scale_factors: Sequence,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can type them more strictly as [int]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scale_factors: Sequence,
scale_factors: Sequence[int],

@@ -209,16 +207,18 @@ def __call__(self, *args, **kwargs):
callable_fn
), "expected callable set as param on the first operation in zne target"

results = zne_p.bind(
*args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=callable_fn
fold_numbers = (jnp.asarray(self.scale_factors, dtype=int) - 1) // 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice that they became element-wise ops 🤙🏻

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants