Skip to content

Commit

Permalink
Merge pull request #388 from SciML/gpu_default
Browse files Browse the repository at this point in the history
Fix GPU tests
  • Loading branch information
ChrisRackauckas authored Oct 5, 2023
2 parents a53f644 + b8ef3e4 commit 296b142
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
27 changes: 11 additions & 16 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
module LinearSolveCUDAExt

using CUDA, LinearAlgebra, LinearSolve, SciMLBase
using CUDA
using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
using SciMLBase: AbstractSciMLOperator

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
if cache.isfresh
fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u)
cache = LinearSolve.set_cacheval(cache, fact)
fact = qr(CUDA.CuArray(cache.A))
cache.cacheval = fact
cache.isfresh = false
end

copyto!(cache.u, cache.b)
y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u)))
y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b)))
cache.u .= y
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa Union{MatrixOperator, DiffEqArrayOperator}
A = A.A
end

fact = qr(CUDA.CuArray(A))
return fact
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
qr(CUDA.CuArray(A))
end

end
16 changes: 8 additions & 8 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ function test_interface(alg, prob1, prob2)
@test A1 * y b1

cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
y = solve(cache)
@test A1 * y b1
solve!(cache)
@test A1 * cache.u b1

cache = LinearSolve.set_A(cache, copy(A2))
y = solve(cache)
@test A2 * y b1
cache.A = copy(A2)
solve!(cache)
@test A2 * cache.u b1

cache = LinearSolve.set_b(cache, b2)
y = solve(cache)
@test A2 * y b2
cache.b = copy(b2)
solve!(cache)
@test A2 * cache.u b2

return
end
Expand Down

0 comments on commit 296b142

Please sign in to comment.