Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make more libraries into an extension? #530

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -26,7 +25,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand All @@ -40,11 +38,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand All @@ -55,11 +55,13 @@ LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKLUExt = "KLU"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveSparseArraysExt = "SparseArrays"

[compat]
AllocCheck = "0.1"
Expand Down
66 changes: 66 additions & 0 deletions ext/LinearSolveKLUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
module LinearSolveKLUExt

using LinearSolve, LinearSolve.LinearAlgebra
using KLU, KLU.SparseArrays

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
Float64[]))

function init_cacheval(alg::KLUFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_KLU
end

function init_cacheval(alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end

# TODO: guard this against errors
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :KLUFactorization)
if alg.reuse_symbolic
if alg.check_pattern && pattern_changed(cacheval, A)
fact = KLU.klu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
end
else
# New fact each time since the sparsity pattern can change
# and thus it needs to reallocate
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
end
cache.cacheval = fact
cache.isfresh = false
end
F = @get_cacheval(cache, :KLUFactorization)
if F.common.status == KLU.KLU_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

end
178 changes: 178 additions & 0 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
module LinearSolveSparseArraysExt

using LinearSolve
import LinearSolve: SciMLBase, LinearAlgebra, PrecompileTools, init_cacheval
using LinearSolve: DefaultLinearSolver, DefaultAlgorithmChoice
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr

# Specialize QR for the non-square case
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
function LinearSolve._ldiv!(x::Vector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::Vector)
x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
x .= A \ b
end

# Ambiguity removal
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::AbstractVector)
(A \ b)
end
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
b::SVector)
(A \ b)
end

function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
fact.rowval)
end

const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
Int[], Float64[]))

function init_cacheval(alg::UMFPACKFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_UMFPACK
end

function init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
rowvals(A), nonzeros(A)))
end

function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
if alg.reuse_symbolic
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
if alg.check_pattern && pattern_changed(cacheval, A)
fact = lu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
check = false)
else
fact = lu!(cacheval,
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)), check = false)
end
else
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

F = @get_cacheval(cache, :UMFPACKFactorization)
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
y = ldiv!(cache.u, F, cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
else
SciMLBase.build_linear_solution(
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
end
end

const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))

function init_cacheval(alg::CHOLMODFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(alg::CHOLMODFactorization,
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
Union{Float32, Float64}}
PREALLOCATED_CHOLMOD
end

function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)

if cache.isfresh
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
fact = cholesky(A; check = false)
if !LinearAlgebra.issuccess(fact)
ldlt!(fact, A; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end

cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

function LinearSolve.defaultalg(
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
end

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
else
error("Generic number sparse factorization for non-square is not currently handled")
end
end

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
if assump.issq
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
else
DefaultLinearSolver(DefaultAlgorithmChoice.UMFPACKFactorization)
end
else
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
end
end

PrecompileTools.@compile_workload begin
A = sprand(4, 4, 0.3) + I
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization())
sol = solve(prob, UMFPACKFactorization())
end

end
Loading
Loading