From aa52e195059af3572604d16a481a88e93c70c5d8 Mon Sep 17 00:00:00 2001 From: "Feras A. Saad" Date: Tue, 2 May 2023 10:49:20 -0400 Subject: [PATCH] Fix bug in JAX array initialization: Qs.set(0).set(P0) -> Qs.at[0].set(P0). MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The current code fails with the following AttributeError: --> 240     Qs = Qs.set(0).set(P0)  # first element requires different initialisation AttributeError: DynamicJaxprTracer has no attribute set --- bayesnewton/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesnewton/ops.py b/bayesnewton/ops.py index 9523086..747dadd 100644 --- a/bayesnewton/ops.py +++ b/bayesnewton/ops.py @@ -237,7 +237,7 @@ def parallel_filtering_operator(elem1, elem2): def make_associative_filtering_elements(As, Qs, H, ys, noise_covs, m0, P0): - Qs = Qs.set(0).set(P0) # first element requires different initialisation + Qs = Qs.at[0].set(P0) # first element requires different initialisation AA, b, C, J, eta = parallel_filtering_element(As, Qs, H, noise_covs, ys) # modify initial b to account for m0 (not needed if m0=zeros) S = H @ Qs[0] @ H.T + noise_covs[0]