diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index efe2f34..4f90924 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -23,6 +23,6 @@ using LinearAlgebra: axpby!, factorize, lu include("implicit_function.jl") include("operators.jl") -export ImplicitFunction +export ImplicitFunction, KrylovLinearSolver end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 6201e2a..c4e234e 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -1,34 +1,43 @@ """ KrylovLinearSolver -Callable object that can solve linear systems `As = b` and `AS = b` in the same way as the built-in `\\`. +Callable object that can solve linear systems `Ax = b` and `AX = B` in the same way as the built-in `\\`. Uses an iterative solver from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) under the hood. -# Note +# Constructor -This name is not exported, and thus not part of the public API, but it is used in the [`ImplicitFunction`](@ref) constructors. -""" -struct KrylovLinearSolver end + KrylovLinearSolver(; verbose=true) + +If `verbose` is `true`, the solver logs a warning in case of failure. +Otherwise it will fail silently, and may return solutions that do not exactly satisfy the linear system. + +# Callable behavior -""" (::KylovLinearSolver)(A, b::AbstractVector) Solve a linear system with a single right-hand side. -""" -function (::KrylovLinearSolver)(A, b::AbstractVector) - x, stats = gmres(A, b) - return x -end -""" (::KrylovLinearSolver)(A, B::AbstractMatrix) Solve a linear system with multiple right-hand sides. """ -function (::KrylovLinearSolver)(A, B::AbstractMatrix) +Base.@kwdef struct KrylovLinearSolver + verbose::Bool = true +end + +function (solver::KrylovLinearSolver)(A, b::AbstractVector) + x, stats = gmres(A, b) + if !stats.solved || stats.inconsistent + solver.verbose && + @warn "Failed to solve the linear system in the implicit function theorem with `Krylov.gmres`" stats + end + return x +end + +function (solver::KrylovLinearSolver)(A, B::AbstractMatrix) # X, stats = block_gmres(A, B) # https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/854 X = mapreduce(hcat, eachcol(B)) do b - first(gmres(A, b)) + solver(A, b) end return X end @@ -80,6 +89,14 @@ Picks the `lazy` parameter automatically based on the `linear_solver`, using the Picks the `linear_solver` automatically based on the `lazy` parameter. +# Callable behavior + + (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) + +Return `implicit.forward(x, args...; kwargs...)`, which can be either an `AbstractVector` `y` or a tuple `(y, z)`. + +This call makes `y` differentiable with respect to `x`. + # Function signatures There are two possible signatures for `forward` and `conditions`, which must be consistent with one another: @@ -122,9 +139,6 @@ struct ImplicitFunction{ conditions_y_backend::B2 end -""" - -""" function ImplicitFunction{lazy}( forward::F, conditions::C; @@ -163,13 +177,6 @@ function Base.show(io::IO, implicit::ImplicitFunction{lazy}) where {lazy} ) end -""" - (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) - -Return `implicit.forward(x, args...; kwargs...)`, which can be either an `AbstractVector` `y` or a tuple `(y, z)`. - -This call makes `y` differentiable with respect to `x`. -""" function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) y_or_yz = implicit.forward(x, args...; kwargs...) return y_or_yz