Skip to content

Commit

Permalink
Use Pivoted QR for Underdetermined Systems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 24, 2023
1 parent 86beacf commit d814042
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down
24 changes: 19 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18}
T13, T14, T15, T16, T17, T18, T19}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
end

# Legacy fallback
Expand Down Expand Up @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.MKLLUFactorization
Expand Down Expand Up @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif assump.condition === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.QRFactorization
if size(A, 1) < size(A, 2)

Check warning on line 203 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L203

Added line #L203 was not covered by tests
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted

Check warning on line 205 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L205

Added line #L205 was not covered by tests
else
DefaultAlgorithmChoice.QRFactorization

Check warning on line 207 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L207

Added line #L207 was not covered by tests
end
elseif assump.condition === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
if size(A, 1) < size(A, 2)

Check warning on line 210 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L210

Added line #L210 was not covered by tests
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted

Check warning on line 212 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L212

Added line #L212 was not covered by tests
else
DefaultAlgorithmChoice.QRFactorization

Check warning on line 214 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L214

Added line #L214 was not covered by tests
end
elseif assump.condition === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
else
Expand Down Expand Up @@ -247,6 +258,8 @@ function algchoice_to_alg(alg::Symbol)
NormalCholeskyFactorization()
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
QRFactorization(ColumnNorm())
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -310,6 +323,7 @@ function defaultalg_symbol(::Type{T}) where {T}
Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted

"""
if alg.alg === DefaultAlgorithmChoice.LUFactorization
Expand Down
6 changes: 6 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ function QRFactorization(inplace = true)
QRFactorization(pivot, 16, inplace)
end

@static if VERSION v"1.7beta"
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
Expand Down
9 changes: 9 additions & 0 deletions test/nonsquare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
condition = OperatorCondition.WellConditioned))).u;

# Underdetermined
m, n = 2, 3

A = rand(m, n)
b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u res

0 comments on commit d814042

Please sign in to comment.