Skip to content

Commit

Permalink
fix multiple solve handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 24, 2023
1 parent 3b39753 commit cbb5f1d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dr.u for dr in dres)

Check warning on line 83 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L83

Added line #L83 was not covered by tests
end

cache = (res, resvals, linsolve.val)
cache = (res, resvals)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 87 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L86-L87

Added lines #L86 - L87 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve = cache
y, dys = cache
_linsolve = linsolve.val

Check warning on line 92 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L90-L92

Added lines #L90 - L92 were not covered by tests

@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 95 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L94-L95

Added lines #L94 - L95 were not covered by tests
Expand All @@ -110,8 +111,8 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
end

for (dA, db, dy) in zip(dAs, dbs, dys)
z = if linsolve.cacheval isa Factorization
linsolve.cacheval' \ dy
z = if _linsolve.cacheval isa Factorization
_linsolve.cacheval' \ dy
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
linsolve.cacheval[1]' \ dy
elseif linsolve.alg isa AbstractKrylovSubspaceMethod

Check warning on line 118 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L113-L118

Added lines #L113 - L118 were not covered by tests
Expand Down

0 comments on commit cbb5f1d

Please sign in to comment.