diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl index 45e9231d1..c3e474c71 100644 --- a/ext/LinearSolveEnzymeExt.jl +++ b/ext/LinearSolveEnzymeExt.jl @@ -83,12 +83,13 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line (dr.u for dr in dres) end - cache = (res, resvals, linsolve.val) + cache = (res, resvals) return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache) end function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache} - y, dys, _linsolve = cache + y, dys = cache + _linsolve = linsolve.val @assert !(typeof(linsolve) <: Const) @assert !(typeof(linsolve) <: Active) @@ -110,8 +111,8 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s end for (dA, db, dy) in zip(dAs, dbs, dys) - z = if linsolve.cacheval isa Factorization - linsolve.cacheval' \ dy + z = if _linsolve.cacheval isa Factorization + _linsolve.cacheval' \ dy elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization linsolve.cacheval[1]' \ dy elseif linsolve.alg isa AbstractKrylovSubspaceMethod