Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Apple Accelerate and improve MKL integration #355

Merged
merged 11 commits into from
Aug 8, 2023
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Expand Down
12 changes: 8 additions & 4 deletions ext/LinearSolveMKLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
end
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[] #Error code is stored in LU factorization type
A, ipiv, info[], info #Error code is stored in LU factorization type
end

default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
Expand All @@ -30,7 +33,7 @@ default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
end

function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
Expand All @@ -39,11 +42,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :MKLLUFactorization)
fact = LU(getrf!(A)...)
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

Expand Down
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ using EnumX
using Requires
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
import Preferences

Expand Down Expand Up @@ -87,6 +91,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("appleaccelerate.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
Expand Down Expand Up @@ -185,6 +190,7 @@ export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
export PardisoJL
export MKLLUFactorization
export AppleAccelerateLUFactorization

export OperatorAssumptions, OperatorCondition

Expand Down
102 changes: 102 additions & 0 deletions src/appleaccelerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using LinearAlgebra
using Libdl

# For now, only use BLAS from Accelerate (that is to say, vecLib)
global const libacc = "/System/Library/Frameworks/Accelerate.framework/Accelerate"

"""
```julia
AppleAccelerateLUFactorization()
```

A wrapper over Apple's Accelerate library. Calls Accelerate directly in a way that pre-allocates workspace
to avoid allocations and does not require libblastrampoline.
"""
struct AppleAccelerateLUFactorization <: AbstractFactorization end

function appleaccelerate_isavailable()
libacc_hdl = Libdl.dlopen_e(libacc)
if libacc_hdl == C_NULL
return false
end

if dlsym_e(libacc_hdl, "dgetrf_") == C_NULL
return false

Check warning on line 24 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
end
return true

Check warning on line 26 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L26

Added line #L26 was not covered by tests
end

function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(size(A,1),size(A,2))), info = Ref{Cint}(), check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, Cint, min(size(A,1),size(A,2)))

Check warning on line 36 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L29-L36

Added lines #L29 - L36 were not covered by tests
end

ccall(("dgetrf_", libacc), Cvoid,

Check warning on line 39 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L39

Added line #L39 was not covered by tests
(Ref{Cint}, Ref{Cint}, Ptr{Float64},
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
m, n, A, lda, ipiv, info)
ViralBShah marked this conversation as resolved.
Show resolved Hide resolved
info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type

Check warning on line 44 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L43-L44

Added lines #L43 - L44 were not covered by tests
end

function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 53 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L47-L53

Added lines #L47 - L53 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 56 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
end
nrhs = size(B, 2)
ccall(("dgetrs_", libacc), Cvoid,

Check warning on line 59 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L58-L59

Added lines #L58 - L59 were not covered by tests
(Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 64 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end

default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false

Check warning on line 68 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L67-L68

Added lines #L67 - L68 were not covered by tests

function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,

Check warning on line 70 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L70

Added line #L70 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
luinst = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
LU(luinst.factors,similar(A, Cint, 0), luinst.info), Ref{Cint}()

Check warning on line 74 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L73-L74

Added lines #L73 - L74 were not covered by tests
end

function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;

Check warning on line 77 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L77

Added line #L77 was not covered by tests
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
res = aa_getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false

Check warning on line 86 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L79-L86

Added lines #L79 - L86 were not covered by tests
end

A, info = @get_cacheval(cache, :AppleAccelerateLUFactorization)
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
m, n = size(A, 1), size(A, 2)
if m > n
Bc = copy(cache.b)
aa_getrs!('N', A.factors, A.ipiv, Bc; info)
return copyto!(cache.u, 1, Bc, 1, n)

Check warning on line 95 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L89-L95

Added lines #L89 - L95 were not covered by tests
else
copyto!(cache.u, cache.b)
aa_getrs!('N', A.factors, A.ipiv, cache.u; info)

Check warning on line 98 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end

SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)

Check warning on line 101 in src/appleaccelerate.jl

View check run for this annotation

Codecov / codecov/patch

src/appleaccelerate.jl#L101

Added line #L101 was not covered by tests
end
3 changes: 3 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ end
test_interface(alg, prob1, prob2)
end
end
if LinearSolve.appleaccelerate_isavailable()
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
end
end

@testset "Generic Factorizations" begin
Expand Down
4 changes: 3 additions & 1 deletion test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test

for alg in subtypes(LinearSolve.AbstractFactorization)
@show alg
if !(alg in [DiagonalFactorization, CudaOffloadFactorization])
if !(alg in [DiagonalFactorization, CudaOffloadFactorization, AppleAccelerateLUFactorization]) &&
(!(alg == AppleAccelerateLUFactorization) || LinearSolve.appleaccelerate_isavailable())

A = [1.0 2.0; 3.0 4.0]
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
(A = sparse(A))
Expand Down
Loading