From e56227aeb4bcef3102dc4e843ff41ae804fb0c6b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 28 Oct 2023 11:21:25 +0100 Subject: [PATCH] Fix enzyme batch mode Fixes https://github.com/EnzymeAD/Enzyme.jl/issues/1075 --- ext/LinearSolveEnzymeExt.jl | 14 +++++------ test/enzyme.jl | 48 +++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index fea332dcd..a25fe14ba 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -58,14 +58,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i d_b .= 0 end else - for i in 1:EnzymeRules.width(config) - if d_A !== prob_d_A[i] - prob_d_A[i] .+= d_A[i] - d_A[i] .= 0 + for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b) + if _d_A !== _prob_d_A + _prob_d_A .+= _d_A + _d_A .= 0 end - if d_b !== prob_d_b[i] - prob_d_b[i] .+= d_b[i] - d_b[i] .= 0 + if _d_b !== _prob_d_b + _prob_d_b .+= _d_b + _d_b .= 0 end end end diff --git a/test/enzyme.jl b/test/enzyme.jl index 62904c055..a194450d0 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -33,15 +33,25 @@ b1 = rand(n); db1 = zeros(n); db12 = zeros(n); -#= -# Batch test fails -# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075 +# Batch test +n = 4 +A = rand(n, n); +dA = zeros(n, n); +dA2 = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); +db12 = zeros(n); -function fbatch(y, A, b1; alg = LUFactorization()) +function f(A, b1; alg = LUFactorization()) prob = LinearProblem(A, b1) - sol1 = solve(prob, alg) + s1 = sol1.u + norm(s1) +end +function fbatch(y, A, b1; alg = LUFactorization()) + prob = LinearProblem(A, b1) + sol1 = solve(prob, alg) s1 = sol1.u y[1] = norm(s1) nothing @@ -50,16 +60,28 @@ end y = [0.0] dy1 = [1.0] dy2 = [1.0] +Enzyme.autodiff(Reverse, fbatch, Duplicated(y, dy1), Duplicated(copy(A), dA), Duplicated(copy(b1), db1)) + +@test y[1] ≈ f(copy(A),b1) +dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) +db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) + +@test dA ≈ dA_2 +@test db1 ≈ db1_2 + +y .= 0 +dy1 .= 1 +dy2 .= 1 +dA .= 0 +dA2 .= 0 +db1 .= 0 +db12 .= 0 Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12))) -dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A)) -db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1)) - -@test_broken dA ≈ dA_2 -@test_broken dA2 ≈ dA_2 -@test_broken db1 ≈ db1_2 -@test_broken db12 ≈ db1_2 -=# +@test dA ≈ dA_2 +@test db1 ≈ db1_2 +@test dA2 ≈ dA_2 +@test db12 ≈ db1_2 function f(A, b1, b2; alg = LUFactorization()) prob = LinearProblem(A, b1)