Skip to content

Commit

Permalink
fix mutated db
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent 54f0722 commit e4f0785
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
else
(dval.b for dval in dres)

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L30

Added line #L30 was not covered by tests
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))

Check warning on line 33 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L33

Added line #L33 was not covered by tests
end

Expand Down Expand Up @@ -89,20 +90,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dr.u for dr in dres)

Check warning on line 90 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L90

Added line #L90 was not covered by tests
end

cache = (res, resvals, deepcopy(linsolve.val))
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve = cache

@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

if EnzymeRules.width(config) == 1
dys = (dys,)
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)

Check warning on line 94 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
else
Expand All @@ -115,6 +102,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
(dval.b for dval in linsolve.dval)

Check warning on line 102 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L102

Added line #L102 was not covered by tests
end

cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 106 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L105-L106

Added lines #L105 - L106 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache

Check warning on line 110 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L109-L110

Added lines #L109 - L110 were not covered by tests

@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 113 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L112-L113

Added lines #L112 - L113 were not covered by tests

if EnzymeRules.width(config) == 1
dys = (dys,)

Check warning on line 116 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
end

for (dA, db, dy) in zip(dAs, dbs, dys)
z = if _linsolve.cacheval isa Factorization
_linsolve.cacheval' \ dy
Expand Down

0 comments on commit e4f0785

Please sign in to comment.