Support Apple Accelerate and improve MKL integration #355

Merged
merged 11 commits into from
Aug 8, 2023
5 changes: 4 additions & 1 deletion ext/LinearSolveMKLExt.jl
@@ -16,6 +16,9 @@
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
end
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
@@ -39,7 +42,7 @@
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :MKLLUFactorization)
fact = LU(getrf!(A)...)
fact = LU(getrf!(A; ipiv = cacheval.ipiv)...)
cache.cacheval = fact
cache.isfresh = false
end
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
@@ -23,6 +23,10 @@ using EnumX
using Requires
import InteractiveUtils

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok

import GPUArraysCore
import Preferences

@@ -87,6 +91,7 @@ include("solve_function.jl")
include("default.jl")
include("init.jl")
include("extension_algs.jl")
include("appleaccelerate.jl")
include("deprecated.jl")

@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
@@ -185,6 +190,7 @@ export CudaOffloadFactorization
export MKLPardisoFactorize, MKLPardisoIterate
export PardisoJL
export MKLLUFactorization
export AppleAccelerateLUFactorization

export OperatorAssumptions, OperatorCondition

65 changes: 65 additions & 0 deletions src/appleaccelerate.jl
@@ -0,0 +1,65 @@
# For now, only use BLAS from Accelerate (that is to say, vecLib)
global const libacc = "/System/Library/Frameworks/Accelerate.framework/Accelerate"

"""
```julia
AppleAccelerateLUFactorization()
```

A wrapper over Apple's Accelerate library. Calls Accelerate directly, pre-allocating the workspace
to avoid allocations, and does not require libblastrampoline.
"""
struct AppleAccelerateLUFactorization <: AbstractFactorization end

function appleaccelerate_isavailable()
libacc_hdl = dlopen_e(libacc)
if libacc_hdl == C_NULL
return false
end

if dlsym_e(libacc_hdl, "dgemm\$NEWLAPACK\$ILP64") == C_NULL
return false
end
return true
end
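
The availability check above opens the framework and then looks up one ILP64 symbol, returning `false` if either step fails. As a rough, platform-independent sketch of the same probe, the snippet below uses `Libdl.dlopen` and `Libdl.dlsym` with `throw_error = false`, which return `nothing` on failure instead of `C_NULL`. The helper name `has_symbol` is hypothetical and not part of this PR.

```julia
using Libdl

# Hypothetical helper (not in the PR): true only if `path` can be loaded
# and the loaded library exports `symbol`.
function has_symbol(path::AbstractString, symbol::AbstractString)
    # With throw_error = false, dlopen returns `nothing` when loading fails.
    handle = Libdl.dlopen(path; throw_error = false)
    handle === nothing && return false
    # Likewise, dlsym returns `nothing` when the symbol is missing.
    return Libdl.dlsym(handle, symbol; throw_error = false) !== nothing
end

# A path that does not exist is reported as unavailable rather than throwing.
missing_lib = has_symbol("/nonexistent/library/path", "dgemm")
```

On macOS, the arguments matching the PR would be `libacc` and `"dgemm\$NEWLAPACK\$ILP64"`.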

function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1,stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
end

ccall(("dgetrf\$NEWLAPACK\$ILP64", libacc), Cvoid,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[] #Error code is stored in LU factorization type
end

default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false

function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
fact = LU(aa_getrf!(A; ipiv = cacheval.ipiv)...)
cache.cacheval = fact
cache.isfresh = false
end
y = ldiv!(cache.u, @get_cacheval(cache, :AppleAccelerateLUFactorization), cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
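
The `solve!` method above wraps the raw `(A, ipiv, info)` tuple from `aa_getrf!` in an `LU` object and then solves with a three-argument `ldiv!`. A minimal sketch of that pattern, runnable without Accelerate, is given below. It substitutes the standard library's `LinearAlgebra.LAPACK.getrf!` for `aa_getrf!` (an assumption made only so the example runs on any platform), and the 2×2 system is made up for illustration.

```julia
using LinearAlgebra

A = [4.0 3.0; 6.0 3.0]
b = [10.0, 12.0]

# getrf! overwrites its argument with the L and U factors,
# so factorize a copy and leave A untouched.
factors, ipiv, info = LinearAlgebra.LAPACK.getrf!(copy(A))

# Wrap the raw LAPACK output in an LU object, mirroring
# `LU(aa_getrf!(A; ipiv = cacheval.ipiv)...)` in the PR.
fact = LU(factors, ipiv, info)

# Solve into a pre-allocated output vector, as `ldiv!(cache.u, fact, cache.b)` does.
u = similar(b)
ldiv!(u, fact, b)
```

Because `LU` stores `info`, the error code travels with the factorization, which is what the comment on the return line of `aa_getrf!` refers to.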
3 changes: 3 additions & 0 deletions test/basictests.jl
@@ -213,6 +213,9 @@ end
test_interface(alg, prob1, prob2)
end
end
if LinearSolve.appleaccelerate_isavailable()
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
end
end
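
The test above only exercises the new algorithm when Accelerate is present. User code could follow the same guard; the sketch below assumes a LinearSolve version that includes this PR and falls back to the generic `LUFactorization()` elsewhere. The matrix and right-hand side are made-up illustration values.

```julia
using LinearSolve

A = [4.0 3.0; 6.0 3.0]
b = [10.0, 12.0]
prob = LinearProblem(A, b)

# Use Accelerate only where the framework and its ILP64 symbols are found
# (macOS); otherwise fall back to the generic LU factorization.
alg = LinearSolve.appleaccelerate_isavailable() ?
      AppleAccelerateLUFactorization() : LUFactorization()

sol = solve(prob, alg)
```

Since `default_alias_A` and `default_alias_b` return `false` for the new algorithm, `A` and `b` should not be overwritten by the solve.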

@testset "Generic Factorizations" begin