Skip to content

Commit

Permalink
getting very close
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 24, 2023
1 parent 9630121 commit b0d228d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
9 changes: 4 additions & 5 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dr.u for dr in dres)

Check warning on line 83 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L83

Added line #L83 was not covered by tests
end

cache = (res, resvals)
cache = (res, resvals, deepcopy(linsolve.val))
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 87 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L86-L87

Added lines #L86 - L87 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys = cache
_linsolve = linsolve.val
y, dys, _linsolve = cache

Check warning on line 91 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L90-L91

Added lines #L90 - L91 were not covered by tests

@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 94 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
Expand All @@ -113,9 +112,9 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
for (dA, db, dy) in zip(dAs, dbs, dys)
z = if _linsolve.cacheval isa Factorization
_linsolve.cacheval' \ dy
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
_linsolve.cacheval[1]' \ dy
elseif linsolve.alg isa AbstractKrylovSubspaceMethod
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod

Check warning on line 117 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L112-L117

Added lines #L112 - L117 were not covered by tests
# Doesn't modify `A`, so it's safe to just reuse it
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
solve(invprob;

Check warning on line 120 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L119-L120

Added lines #L119 - L120 were not covered by tests
Expand Down
54 changes: 44 additions & 10 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Enzyme, FiniteDiff
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test

n = 4
Expand All @@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test dA dA2
@test db1 db12
Expand All @@ -35,8 +35,8 @@ db12 = zeros(n);

@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test_broken dA dA_2
@test_broken dA2 dA_2
Expand All @@ -45,9 +45,8 @@ db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))

function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)

cache = init(prob, alg)
s1 = solve!(cache).u
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
Expand All @@ -60,11 +59,46 @@ db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

f(A, b1, b2)
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A))
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2))
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2))

@test dA dA2
@test db1 db12
@test db2 db22

function f2(A, b1, b2; alg = RFLUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f2(A, b1, b2)
dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA dA2
@test db1 db12
@test db2 db22

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = solve!(cache).u
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA dA2 atol=5e-5
@test db1 db12
Expand Down

0 comments on commit b0d228d

Please sign in to comment.