Skip to content

Commit

Permalink
More caching
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent e4f0785 commit 3d25e6f
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,48 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dval.b for dval in dres)
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))

prob_d_A = if EnzymeRules.width(config) == 1
prob.dval.A
else
(dval.A for dval in prob.dval)
end
prob_d_b = if EnzymeRules.width(config) == 1
prob.dval.b
else
(dval.b for dval in prob.dval)
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b = cache
d_A, d_b, prob_d_A, prob_d_b = cache

if EnzymeRules.width(config) == 1
if d_A !== prob.dval.A
prob.dval.A .+= d_A
if d_A !== prob_d_A
prob_d_A .+= d_A
d_A .= 0
end
if d_b !== prob.dval.b
prob.dval.b .+= d_b
if d_b !== prob_d_b
prob_d_b .+= d_b
d_b .= 0
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob.dval.A
prob.dval.A[i] .+= d_A[i]
if d_A !== prob_d_A[i]
prob_d_A[i] .+= d_A[i]
d_A[i] .= 0
end
if d_b !== prob.dval.b
prob.dval.b[i] .+= d_b[i]
if d_b !== prob_d_b[i]
prob_d_b[i] .+= d_b[i]
d_b[i] .= 0
end
end
end

@show "dA init rev", d_A, Base.pointer_from_objref(d_A), prob_d_A, Base.pointer_from_objref(prob_d_A)

return (nothing, nothing)
end

Expand Down Expand Up @@ -87,22 +101,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
resvals = if EnzymeRules.width(config) == 1
dres.u
else
(dr.u for dr in dres)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
dres[i].u
end
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)
else
(dval.A for dval in linsolve.dval)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].A
end
end

dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)
else
(dval.b for dval in linsolve.dval)
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].b
end
end

cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs)
cachesolve = deepcopy(linsolve.val)

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
end

Expand Down

0 comments on commit 3d25e6f

Please sign in to comment.