Skip to content

Commit

Permalink
Add needs_square_A trait
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 24, 2023
1 parent a880003 commit 86beacf
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.11.1"
version = "2.12.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
40 changes: 40 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,46 @@ end
include("factorization_sparse.jl")
end

# Solver Specific Traits
## Needs Square Matrix
"""
needs_square_A(alg)
Returns `true` if the algorithm requires a square matrix.
Note that this checks if the implementation of the algorithm needs a square matrix by
trying to solve an underdetermined system. It is recommended to add a dispatch to this
function for custom algorithms!
"""
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
"""
    needs_square_A(alg::SciMLLinearSolveAlgorithm)

Fallback trait query: probe `alg` with a fixed non-square (3×2, overdetermined)
problem. If the solve succeeds the algorithm handles rectangular matrices and
`false` is returned; if it throws, the algorithm is assumed to require a square
`A` and `true` is returned.

Algorithms should override this trait directly instead of relying on the probe.
"""
function needs_square_A(alg::SciMLLinearSolveAlgorithm)
    try
        A = [1.0 2.0;
             3.0 4.0;
             5.0 6.0]
        b = ones(Float64, 3)
        solve(LinearProblem(A, b), alg)
        return false
    catch
        # Any failure (DimensionMismatch, MethodError, ...) is interpreted as
        # "requires a square A". Bare `catch` — the exception is never used.
        return true
    end
end
# These factorization algorithms also operate on rectangular matrices, so they
# do not need a square A.
for T in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
    :NormalBunchKaufmanFactorization)
    @eval needs_square_A(::$(T)) = false
end
# KrylovJL wrappers around least-squares / minimum-norm Krylov solvers accept
# rectangular systems as well.
for solver in (Krylov.lsmr!, Krylov.craigmr!)
    @eval needs_square_A(::KrylovJL{$(typeof(solver))}) = false
end
# Direct factorizations whose implementations here require a square matrix.
# NOTE(review): SVDFactorization appears in this list even though SVD itself is
# defined for rectangular matrices — confirm the wrapper really rejects them.
for T in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
    :GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
    :RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
    :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
    :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
    :MKLLUFactorization, :MetalLUFactorization)
    @eval needs_square_A(::$(T)) = true
end

# Flag recording whether the active BLAS vendor is OpenBLAS.
# NOTE(review): initialized to `true` here; presumably updated elsewhere
# (e.g. at module init) based on the detected BLAS — confirm.
const IS_OPENBLAS = Ref(true)

# Predicate accessor for the OpenBLAS flag above.
isopenblas() = getindex(IS_OPENBLAS)

Expand Down
7 changes: 7 additions & 0 deletions test/nonsquare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u ≈ res
@test !LinearSolve.needs_square_A(QRFactorization())
@test solve(prob, QRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(FastQRFactorization())
@test solve(prob, FastQRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(KrylovJL_LSMR())
@test solve(prob, KrylovJL_LSMR()) ≈ res

A = sprand(m, n, 0.5)
Expand All @@ -23,6 +27,7 @@ A = sprand(n, m, 0.5)
b = rand(n)
prob = LinearProblem(A, b)
res = Matrix(A) \ b
@test !LinearSolve.needs_square_A(KrylovJL_CRAIGMR())
@test solve(prob, KrylovJL_CRAIGMR()) ≈ res

A = sprandn(1000, 100, 0.1)
Expand All @@ -35,7 +40,9 @@ A = randn(1000, 100)
b = randn(1000)
@test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b))
solve(LinearProblem(A, b)).u;
@test !LinearSolve.needs_square_A(NormalCholeskyFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
@test !LinearSolve.needs_square_A(NormalBunchKaufmanFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
Expand Down

0 comments on commit 86beacf

Please sign in to comment.