diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a53d926d8..d1f3458d1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -28,6 +28,8 @@ jobs: group: 'LinearSolveHYPRE' - version: '1' group: 'LinearSolvePardiso' + - version: '1' + group: 'LinearSolveBandedMatrices' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 508dab4f4..3e02be667 100644 --- a/Project.toml +++ b/Project.toml @@ -107,4 +107,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices"] diff --git a/ext/LinearSolveBandedMatricesExt.jl b/ext/LinearSolveBandedMatricesExt.jl index f402e06fa..1eaf0b712 100644 --- a/ext/LinearSolveBandedMatricesExt.jl +++ b/ext/LinearSolveBandedMatricesExt.jl @@ -5,8 +5,14 @@ import LinearSolve: defaultalg, do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice # Defaults for BandedMatrices -function defaultalg(A::BandedMatrix, b, ::OperatorAssumptions) - return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) +function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions) + if oa.issq + return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) + elseif LinearSolve.is_underdetermined(A) + error("No solver for underdetermined `A::BandedMatrix` is currently implemented!") + else + return DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) + end end function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions) diff --git a/test/banded.jl b/test/banded.jl new file mode 100644 index 000000000..0bbf5e65b --- /dev/null +++ b/test/banded.jl @@ -0,0 +1,38 @@ +using BandedMatrices, LinearAlgebra, LinearSolve, Test + +# Square Case +n = 8 +A = BandedMatrix(Matrix(I, n, n), (2, 2)) +b = ones(n) +A1 = A / 1; +b1 = rand(n); +x1 = zero(b); +A2 = A / 2; +b2 = rand(n); +x2 = zero(b); + +sol1 = solve(LinearProblem(A1, b1; u0 = x1)) +@test sol1.u ≈ A1 \ b1 +sol2 = solve(LinearProblem(A2, b2; u0 = x2)) +@test sol2.u ≈ A2 \ b2 + +# Square Symmetric +A1s = Symmetric(A1) +A2s = Symmetric(A2) + +sol1s = solve(LinearProblem(A1s, b1; u0 = x1)) +@test sol1s.u ≈ A1s \ b1 +sol2s = solve(LinearProblem(A2s, b2; u0 = x2)) +@test sol2s.u ≈ A2s \ b2 + +# Underdetermined +A = BandedMatrix(rand(8, 10), (2, 2)) +b = rand(8) + +@test_throws ErrorException solve(LinearProblem(A, b)).u + +# Overdetermined +A = BandedMatrix(ones(10, 8), (2, 0)) +b = rand(10) + +@test_nowarn solve(LinearProblem(A, b)) diff --git a/test/runtests.jl b/test/runtests.jl index 4f2e78feb..bbf1cd0d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "Default Alg Tests" include("default_algs.jl") VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") @time @safetestset "Traits" include("traits.jl") + VERSION >= v"1.9" && @time @safetestset "BandedMatrices" include("banded.jl") end if GROUP == "LinearSolveCUDA"