Skip to content

Commit

Permalink
Merge pull request #349 from SciML/mkl
Browse files Browse the repository at this point in the history
Setup MKL direct factorizations
  • Loading branch information
ChrisRackauckas authored Aug 1, 2023
2 parents 98a2292 + aaf64d3 commit 1ee467b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 2 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
Expand All @@ -37,6 +38,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
LinearSolveCUDAExt = "CUDA"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveMKLExt = "MKL_jll"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolvePardisoExt = "Pardiso"

Expand Down Expand Up @@ -70,6 +72,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -78,4 +81,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]
50 changes: 50 additions & 0 deletions ext/LinearSolveMKLExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module LinearSolveMKLExt

using MKL_jll
using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearAlgebra
const usemkl = MKL_jll.is_available()

using LinearSolve
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase

function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[] #Error code is stored in LU factorization type
end

default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false

function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :MKLLUFactorization)
fact = LU(getrf!(A)...)
cache.cacheval = fact
cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

end
4 changes: 4 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ end
@require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin
include("../ext/LinearSolveKrylovKitExt.jl")
end
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
end
end

Expand Down Expand Up @@ -181,6 +184,7 @@ export HYPREAlgorithm
export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
export PardisoJL
export MKLLUFactorization

export OperatorAssumptions, OperatorCondition

Expand Down
10 changes: 10 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,13 @@ A wrapper over the IterativeSolvers.jl MINRES.
"""
function IterativeSolversJL_MINRES end

"""
```julia
MKLLUFactorization()
```
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct MKLLUFactorization <: AbstractFactorization end
3 changes: 2 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators
using IterativeSolvers, KrylovKit
using IterativeSolvers, KrylovKit, MKL_jll
using Test
import Random

Expand Down Expand Up @@ -207,6 +207,7 @@ end
QRFactorization(),
SVDFactorization(),
RFLUFactorization(),
MKLLUFactorization(),
LinearSolve.defaultalg(prob1.A, prob1.b))
@testset "$alg" begin
test_interface(alg, prob1, prob2)
Expand Down

0 comments on commit 1ee467b

Please sign in to comment.