Skip to content

Commit

Permalink
Add actual file
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and ChrisRackauckas committed Sep 22, 2023
1 parent e53f1b0 commit 004eb9d
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module LinearSolveEnzymeExt

using LinearSolve
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = deepcopy(res)
dres.u .= 0
cache = (copy(prob.val.A), res, dres.u)
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
A, y, dy = cache

dA = prob.dval.A
db = prob.dval.b

invprob = LinearProblem(transpose(A), dy)

z = func.val(invprob, alg; kwargs...)

dA .-= z * transpose(y)
db .+= z
dy .= 0
return (nothing, nothing)
end

end

0 comments on commit 004eb9d

Please sign in to comment.