From 3a3c03e14ef15b74a4132910a5802a4270b794b6 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Fri, 11 Oct 2024 22:37:23 -0500 Subject: [PATCH] Add an implementation of BLOCK-MINRES --- docs/src/api.md | 10 +- docs/src/block_krylov.md | 16 ++- docs/src/block_processes.md | 2 + docs/src/gpu.md | 2 +- docs/src/preconditioners.md | 2 +- src/Krylov.jl | 1 + src/block_gmres.jl | 2 +- src/block_krylov_processes.jl | 2 +- src/block_krylov_solvers.jl | 59 +++++++- src/block_minres.jl | 244 ++++++++++++++++++++++++++++++++++ src/krylov_processes.jl | 2 +- src/krylov_solve.jl | 135 +++++++++++-------- src/usymlq.jl | 2 +- src/usymqr.jl | 2 +- 14 files changed, 411 insertions(+), 70 deletions(-) create mode 100644 src/block_minres.jl diff --git a/docs/src/api.md b/docs/src/api.md index b3050c16e..01453099d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -12,11 +12,10 @@ Krylov.LSLQStats Krylov.LsmrStats ``` -## Solver Types +## Workspace of Krylov methods ```@docs KrylovSolver -BlockKrylovSolver MinresSolver MinaresSolver CgSolver @@ -53,6 +52,13 @@ CraigSolver CraigmrSolver GpmrSolver FgmresSolver +``` + +## Workspace of block-Krylov methods + +```@docs +BlockKrylovSolver +BlockMinresSolver BlockGmresSolver ``` diff --git a/docs/src/block_krylov.md b/docs/src/block_krylov.md index 428b1ca35..160b3fddb 100644 --- a/docs/src/block_krylov.md +++ b/docs/src/block_krylov.md @@ -1,10 +1,7 @@ -## Block-GMRES - !!! note - `block_gmres` works on GPUs - with Julia 1.11. + `block_minres` and `block_gmres` work on GPUs with Julia 1.11. -If you want to use `block_gmres` on previous Julia versions, you can overload the function `Krylov.copy_triangle` with the following code: +If you want to use `block_minres` and `block_gmres` on previous Julia versions, you can overload the function `Krylov.copy_triangle` with the following code: ```julia using KernelAbstractions, Krylov @@ -23,6 +20,15 @@ function Krylov.copy_triangle(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, k::I end ``` +## Block-MINRES + +```@docs +block_minres +block_minres! +``` + +## Block-GMRES + ```@docs block_gmres block_gmres! diff --git a/docs/src/block_processes.md b/docs/src/block_processes.md index e9fc17811..3c1d114dc 100644 --- a/docs/src/block_processes.md +++ b/docs/src/block_processes.md @@ -27,6 +27,8 @@ T_{k+1,k} = The function [`hermitian_lanczos`](@ref hermitian_lanczos(::Any, ::AbstractMatrix{FC}, ::Int) where FC <: (Union{Complex{T}, T} where T <: AbstractFloat)) returns $V_{k+1}$, $\Psi_1$ and $T_{k+1,k}$. +Related method: [`BLOCK-MINRES`](@ref block_minres). + ```@docs hermitian_lanczos(::Any, ::AbstractMatrix{FC}, ::Int) where FC <: (Union{Complex{T}, T} where T <: AbstractFloat) ``` diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 2db45df1b..f783a3995 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -107,7 +107,7 @@ if CUDA.functional() symmetric = hermitian = true opM = LinearOperator(T, n, n, symmetric, hermitian, (y, x) -> ldiv_ic0!(P, x, y, z)) - # Solve an Hermitian positive definite system with an IC(0) preconditioner on GPU + # Solve a Hermitian positive definite system with an IC(0) preconditioner on GPU x, stats = cg(A_gpu, b_gpu, M=opM) end ``` diff --git a/docs/src/preconditioners.md b/docs/src/preconditioners.md index 983bc51c7..c05225a9f 100644 --- a/docs/src/preconditioners.md +++ b/docs/src/preconditioners.md @@ -46,7 +46,7 @@ A Krylov method dedicated to non-Hermitian linear systems allows the three varia ### Hermitian linear systems -Methods concerned: [`SYMMLQ`](@ref symmlq), [`CG`](@ref cg), [`CG-LANCZOS`](@ref cg_lanczos), [`CG-LANCZOS-SHIFT`](@ref cg_lanczos_shift), [`CR`](@ref cr), [`CAR`](@ref car), [`MINRES`](@ref minres), [`MINRES-QLP`](@ref minres_qlp) and [`MINARES`](@ref minares). +Methods concerned: [`SYMMLQ`](@ref symmlq), [`CG`](@ref cg), [`CG-LANCZOS`](@ref cg_lanczos), [`CG-LANCZOS-SHIFT`](@ref cg_lanczos_shift), [`CR`](@ref cr), [`CAR`](@ref car), [`MINRES`](@ref minres), [`BLOCK-MINRES`](@ref block_minres), [`MINRES-QLP`](@ref minres_qlp) and [`MINARES`](@ref minares). When $A$ is Hermitian, we can only use centered preconditioning $L^{-1}AL^{-H}y = L^{-1}b$ with $x = L^{-H}y$. Centered preconditioning is a special case of two-sided preconditioning with $P_{\ell} = L = P_r^H$ that maintains hermicity. diff --git a/src/Krylov.jl b/src/Krylov.jl index 2578d308b..ac44f8052 100644 --- a/src/Krylov.jl +++ b/src/Krylov.jl @@ -12,6 +12,7 @@ include("block_krylov_utils.jl") include("block_krylov_processes.jl") include("block_krylov_solvers.jl") +include("block_minres.jl") include("block_gmres.jl") include("cg.jl") diff --git a/src/block_gmres.jl b/src/block_gmres.jl index 997311954..83547b938 100644 --- a/src/block_gmres.jl +++ b/src/block_gmres.jl @@ -1,6 +1,6 @@ # An implementation of block-GMRES for the solution of the square linear system AX = B. # -# Alexis Montoison, +# Alexis Montoison, -- # Argonne National Laboratory -- Chicago, October 2023. export block_gmres, block_gmres! diff --git a/src/block_krylov_processes.jl b/src/block_krylov_processes.jl index 3eb03d0b4..04d1061ae 100644 --- a/src/block_krylov_processes.jl +++ b/src/block_krylov_processes.jl @@ -3,7 +3,7 @@ #### Input arguments -* `A`: a linear operator that models an Hermitian matrix of dimension `n`; +* `A`: a linear operator that models a Hermitian matrix of dimension `n`; * `B`: a matrix of size `n × p`; * `k`: the number of iterations of the block Hermitian Lanczos process. diff --git a/src/block_krylov_solvers.jl b/src/block_krylov_solvers.jl index a1e8db13a..a20c967e5 100644 --- a/src/block_krylov_solvers.jl +++ b/src/block_krylov_solvers.jl @@ -1,12 +1,64 @@ export BlockKrylovSolver -export BlockGmresSolver +export BlockMinresSolver, BlockGmresSolver -const BLOCK_KRYLOV_SOLVERS = Dict(:block_gmres => :BlockGmresSolver) +const BLOCK_KRYLOV_SOLVERS = Dict(:block_minres => :BlockMinresSolver, + :block_gmres => :BlockGmresSolver ) "Abstract type for using block Krylov solvers in-place" abstract type BlockKrylovSolver{T,FC,SV,SM} end +""" +Type for storing the vectors required by the in-place version of BLOCK-MINRES. + +The outer constructors + + solver = BlockMinresSolver(m, n, p, SV, SM) + solver = BlockMinresSolver(A, B) + +may be used in order to create these vectors. +`memory` is set to `div(n,p)` if the value given is larger than `div(n,p)`. +""" +mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM} + m :: Int + n :: Int + p :: Int + ΔX :: SM + X :: SM + W :: SM + P :: SM + Q :: SM + C :: SM + D :: SM + τ :: SV + warm_start :: Bool + stats :: SimpleStats{T} +end + +function BlockMinresSolver(m, n, p, SV, SM) + FC = eltype(SV) + T = real(FC) + ΔX = SM(undef, 0, 0) + X = SM(undef, n, p) + W = SM(undef, n, p) + P = SM(undef, 0, 0) + Q = SM(undef, 0, 0) + C = SM(undef, p, p) + D = SM(undef, 2p, p) + τ = SV(undef, p) + stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown") + solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, false, stats) + return solver +end + +function BlockMinresSolver(A, B) + m, n = size(A) + s, p = size(B) + SM = typeof(B) + SV = matrix_to_vector(SM) + BlockMinresSolver(m, n, p, SV, SM) +end + """ Type for storing the vectors required by the in-place version of BLOCK-GMRES. @@ -71,7 +123,8 @@ function BlockGmresSolver(A, B, memory = 5) end for (KS, fun, nsol, nA, nAt, warm_start) in [ - (:BlockGmresSolver, :block_gmres!, 1, 1, 0, true) + (:BlockMinresSolver, :block_minres!, 1, 1, 0, true) + (:BlockGmresSolver , :block_gmres! , 1, 1, 0, true) ] @eval begin size(solver :: $KS) = solver.m, solver.n diff --git a/src/block_minres.jl b/src/block_minres.jl new file mode 100644 index 000000000..f9f222186 --- /dev/null +++ b/src/block_minres.jl @@ -0,0 +1,244 @@ +# An implementation of block-MINRES for the solution of the square linear system AX = B. +# +# Alexis Montoison, -- +# Argonne National Laboratory -- Chicago, October 2024. + +export block_minres, block_minres! + +""" + (X, stats) = block_minres(A, b::AbstractMatrix{FC}; + M=I, ldiv::Bool=false, + atol::T=√eps(T), rtol::T=√eps(T), itmax::Int=0, + timemax::Float64=Inf, verbose::Int=0, history::Bool=false, + callback=solver->false, iostream::IO=kstdout) + +`T` is an `AbstractFloat` such as `Float32`, `Float64` or `BigFloat`. +`FC` is `T` or `Complex{T}`. + + (X, stats) = block_minres(A, B, X0::AbstractMatrix; kwargs...) + +Block-MINRES can be warm-started from an initial guess `X0` where `kwargs` are the same keyword arguments as above. + +Solve the Hermitian linear system AX = B of size n with p right-hand sides using block-MINRES. + +#### Input arguments + +* `A`: a linear operator that models a Hermitian matrix of dimension n; +* `B`: a matrix of size n × p. + +#### Optional argument + +* `X0`: a matrix of size n × p that represents an initial guess of the solution X. + +#### Keyword arguments + +* `M`: linear operator that models a Hermitian positive-definite matrix of size `n` used for centered preconditioning; +* `ldiv`: define whether the preconditioners use `ldiv!` or `mul!`; +* `atol`: absolute stopping tolerance based on the residual norm; +* `rtol`: relative stopping tolerance based on the residual norm; +* `itmax`: the maximum number of iterations. If `itmax=0`, the default number of iterations is set to `2 * div(n,p)`; +* `timemax`: the time limit in seconds; +* `verbose`: additional details can be displayed if verbose mode is enabled (verbose > 0). Information will be displayed every `verbose` iterations; +* `history`: collect additional statistics on the run such as residual norms; +* `callback`: function or functor called as `callback(solver)` that returns `true` if the block-Krylov method should terminate, and `false` otherwise; +* `iostream`: stream to which output is logged. + +#### Output arguments + +* `X`: a dense matrix of size n × p; +* `stats`: statistics collected on the run in a [`SimpleStats`](@ref) structure. +""" +function block_minres end + +""" + solver = block_minres!(solver::BlockMinresSolver, B; kwargs...) + solver = block_minres!(solver::BlockMinresSolver, B, X0; kwargs...) + +where `kwargs` are keyword arguments of [`block_minres`](@ref). + +See [`BlockMinresSolver`](@ref) for more details about the `solver`. +""" +function block_minres! end + +def_args_block_minres = (:(A ), + :(B::AbstractMatrix{FC})) + +def_optargs_block_minres = (:(X0::AbstractMatrix),) + +def_kwargs_block_minres = (:(; M = I ), + :(; ldiv::Bool = false ), + :(; atol::T = √eps(T) ), + :(; rtol::T = √eps(T) ), + :(; itmax::Int = 0 ), + :(; timemax::Float64 = Inf ), + :(; verbose::Int = 0 ), + :(; history::Bool = false ), + :(; callback = solver -> false ), + :(; iostream::IO = kstdout )) + +def_kwargs_block_minres = mapreduce(extract_parameters, vcat, def_kwargs_block_minres) + +args_block_minres = (:A, :B) +optargs_block_minres = (:X0,) +kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) + +@eval begin + function block_minres($(def_args_block_minres...), $(def_optargs_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + solver = BlockMinresSolver(A, B) + warm_start!(solver, $(optargs_block_minres...)) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + block_minres!(solver, $(args_block_minres...); $(kwargs_block_minres...)) + solver.stats.timer += elapsed_time + return solver.X, solver.stats + end + + function block_minres($(def_args_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + solver = BlockMinresSolver(A, B) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + block_minres!(solver, $(args_block_minres...); $(kwargs_block_minres...)) + solver.stats.timer += elapsed_time + return solver.X, solver.stats + end + + function block_minres!(solver :: BlockMinresSolver{T,FC,SV,SM}, $(def_args_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}} + + # Timer + start_time = time_ns() + timemax_ns = 1e9 * timemax + + n, m = size(A) + s, p = size(B) + m == n || error("System must be square") + n == s || error("Inconsistent problem size") + (verbose > 0) && @printf(iostream, "BLOCK-MINRES: system of size %d with %d right-hand sides\n", n, p) + + # Check M = Iₙ + MisI = (M === I) + MisI || error("Block-MINRES doesn't support preconditioning yet.") + + # Check type consistency + eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-matrix products." + ktypeof(B) <: SM || error("ktypeof(B) is not a subtype of $SM") + + # Set up workspace. + ΔX, X, W, V, Z = solver.ΔX, solver.X, solver.W, solver.V, solver.Z + C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats + warm_start = solver.warm_start + RNorms = stats.residuals + reset!(stats) + R₀ = warm_start ? Q : B + + # Define the blocks D1 and D2 + D1 = view(D, 1:p, :) + D2 = view(D, p+1:2p, :) + trans = FC <: AbstractFloat ? 'T' : 'C' + + # Coefficients for mul! + α = -one(FC) + β = one(FC) + γ = one(FC) + + # Initial solution X₀. + fill!(X, zero(FC)) + + # Initial residual R₀. + if warm_start + mul!(Q, A, ΔX) + Q .= B .- Q + end + MisI || mulorldiv!(R₀, M, W, ldiv) # R₀ = M(B - AX₀) + RNorm = norm(R₀) # ‖R₀‖_F + history && push!(RNorms, RNorm) + + iter = 0 + itmax == 0 && (itmax = 2*div(n,p)) + + ε = atol + rtol * RNorm + (verbose > 0) && @printf(iostream, "%5s %7s %5s\n", "k", "‖Rₖ‖", "timer") + kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, ktimer(start_time)) + + # Stopping criterion + status = "unknown" + solved = RNorm ≤ ε + tired = iter ≥ itmax + user_requested_exit = false + overtimed = false + + while !(solved || tired || user_requested_exit || overtimed) + # Update iteration index. + iter = iter + 1 + + # Initial Ψ₁ and V₁ + copyto!(V, R₀) + householder!(V, Z, τ) + + # Continue the block-Lanczos process. + mul!(W, A, V) # Q ← AVₖ + for i = 1 : inner_iter + mul!(Ω, V', W) # Ωₖ = Vₖᴴ * Q + (iter ≥ 2) && mul!(Q, ...) # Q ← Q - βₖ * Vₖ₋₁ * Ψₖᴴ + mul!(Q, V, R, α, β) # Q = Q - Vₖ * Ωₖ + end + + # Vₖ₊₁ and Ψₖ₊₁ are stored in Q and C. + householder!(Q, C, τ) + + # Update the QR factorization of Tₖ₊₁.ₖ. + # Apply previous Householder reflections Ωᵢ. + for i = 1 : inner_iter-1 + D1 .= R[nr+i] + D2 .= R[nr+i+1] + @kormqr!('L', trans, H[i], τ[i], D) + R[nr+i] .= D1 + R[nr+i+1] .= D2 + end + + # Compute and apply current Householder reflection Ωₖ. + H[inner_iter][1:p,:] .= R[nr+inner_iter] + H[inner_iter][p+1:2p,:] .= C + householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true) + + # Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁) + D1 .= Z[inner_iter] + D2 .= zero(FC) + @kormqr!('L', trans, H[inner_iter], τ[inner_iter], D) + Z[inner_iter] .= D1 + + # Update residual norm estimate. + # ‖ M(B - AXₖ) ‖_F = ‖Λbarₖ₊₁‖_F + C .= D2 + RNorm = norm(C) + history && push!(RNorms, RNorm) + + # Update stopping criterion. + user_requested_exit = callback(solver) :: Bool + solved = RNorm ≤ ε + tired = iter ≥ itmax + timer = time_ns() - start_time + overtimed = timer > timemax_ns + kdisplay(iter, verbose) && @printf(iostream, "%5d %7.1e %.2fs\n", iter, RNorm, ktimer(start_time)) + end + (verbose > 0) && @printf(iostream, "\n") + + # Termination status + tired && (status = "maximum number of iterations exceeded") + solved && (status = "solution good enough given atol and rtol") + overtimed && (status = "time limit exceeded") + user_requested_exit && (status = "user-requested exit") + + # Update Xₖ + warm_start && (X .+= ΔX) + solver.warm_start = false + + # Update stats + stats.niter = iter + stats.solved = solved + stats.timer = ktimer(start_time) + stats.status = status + return solver + end +end diff --git a/src/krylov_processes.jl b/src/krylov_processes.jl index b3fcc6efe..c5605a229 100644 --- a/src/krylov_processes.jl +++ b/src/krylov_processes.jl @@ -5,7 +5,7 @@ export hermitian_lanczos, nonhermitian_lanczos, arnoldi, golub_kahan, saunders_s #### Input arguments -* `A`: a linear operator that models an Hermitian matrix of dimension `n`; +* `A`: a linear operator that models a Hermitian matrix of dimension `n`; * `b`: a vector of length `n`; * `k`: the number of iterations of the Hermitian Lanczos process. diff --git a/src/krylov_solve.jl b/src/krylov_solve.jl index f5b1ffb4a..5cf8b274a 100644 --- a/src/krylov_solve.jl +++ b/src/krylov_solve.jl @@ -23,42 +23,41 @@ function solve! end # Krylov methods for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in [ - (:LsmrSolver , :lsmr , args_lsmr , def_args_lsmr , () , () , kwargs_lsmr , def_kwargs_lsmr ) - (:CgsSolver , :cgs , args_cgs , def_args_cgs , optargs_cgs , def_optargs_cgs , kwargs_cgs , def_kwargs_cgs ) - (:UsymlqSolver , :usymlq , args_usymlq , def_args_usymlq , optargs_usymlq , def_optargs_usymlq , kwargs_usymlq , def_kwargs_usymlq ) - (:LnlqSolver , :lnlq , args_lnlq , def_args_lnlq , () , () , kwargs_lnlq , def_kwargs_lnlq ) - (:BicgstabSolver , :bicgstab , args_bicgstab , def_args_bicgstab , optargs_bicgstab , def_optargs_bicgstab , kwargs_bicgstab , def_kwargs_bicgstab ) - (:CrlsSolver , :crls , args_crls , def_args_crls , () , () , kwargs_crls , def_kwargs_crls ) - (:LsqrSolver , :lsqr , args_lsqr , def_args_lsqr , () , () , kwargs_lsqr , def_kwargs_lsqr ) - (:MinresSolver , :minres , args_minres , def_args_minres , optargs_minres , def_optargs_minres , kwargs_minres , def_kwargs_minres ) - (:MinaresSolver , :minares , args_minares , def_args_minares , optargs_minares , def_optargs_minares , kwargs_minares , def_kwargs_minares ) - (:CgneSolver , :cgne , args_cgne , def_args_cgne , () , () , kwargs_cgne , def_kwargs_cgne ) - (:DqgmresSolver , :dqgmres , args_dqgmres , def_args_dqgmres , optargs_dqgmres , def_optargs_dqgmres , kwargs_dqgmres , def_kwargs_dqgmres ) - (:SymmlqSolver , :symmlq , args_symmlq , def_args_symmlq , optargs_symmlq , def_optargs_symmlq , kwargs_symmlq , def_kwargs_symmlq ) - (:TrimrSolver , :trimr , args_trimr , def_args_trimr , optargs_trimr , def_optargs_trimr , kwargs_trimr , def_kwargs_trimr ) - (:UsymqrSolver , :usymqr , args_usymqr , def_args_usymqr , optargs_usymqr , def_optargs_usymqr , kwargs_usymqr , def_kwargs_usymqr ) - (:BilqrSolver , :bilqr , args_bilqr , def_args_bilqr , optargs_bilqr , def_optargs_bilqr , kwargs_bilqr , def_kwargs_bilqr ) - (:CrSolver , :cr , args_cr , def_args_cr , optargs_cr , def_optargs_cr , kwargs_cr , def_kwargs_cr ) - (:CarSolver , :car , args_car , def_args_car , optargs_car , def_optargs_car , kwargs_car , def_kwargs_car ) - (:CraigmrSolver , :craigmr , args_craigmr , def_args_craigmr , () , () , kwargs_craigmr , def_kwargs_craigmr ) - (:TricgSolver , :tricg , args_tricg , def_args_tricg , optargs_tricg , def_optargs_tricg , kwargs_tricg , def_kwargs_tricg ) - (:CraigSolver , :craig , args_craig , def_args_craig , () , () , kwargs_craig , def_kwargs_craig ) - (:DiomSolver , :diom , args_diom , def_args_diom , optargs_diom , def_optargs_diom , kwargs_diom , def_kwargs_diom ) - (:LslqSolver , :lslq , args_lslq , def_args_lslq , () , () , kwargs_lslq , def_kwargs_lslq ) - (:TrilqrSolver , :trilqr , args_trilqr , def_args_trilqr , optargs_trilqr , def_optargs_trilqr , kwargs_trilqr , def_kwargs_trilqr ) - (:CrmrSolver , :crmr , args_crmr , def_args_crmr , () , () , kwargs_crmr , def_kwargs_crmr ) - (:CgSolver , :cg , args_cg , def_args_cg , optargs_cg , def_optargs_cg , kwargs_cg , def_kwargs_cg ) - (:CglsSolver , :cgls , args_cgls , def_args_cgls , () , () , kwargs_cgls , def_kwargs_cgls ) - (:CgLanczosSolver, :cg_lanczos , args_cg_lanczos , def_args_cg_lanczos , optargs_cg_lanczos, def_optargs_cg_lanczos, kwargs_cg_lanczos , def_kwargs_cg_lanczos) - (:BilqSolver , :bilq , args_bilq , def_args_bilq , optargs_bilq , def_optargs_bilq , kwargs_bilq , def_kwargs_bilq ) - (:MinresQlpSolver, :minres_qlp , args_minres_qlp , def_args_minres_qlp , optargs_minres_qlp, def_optargs_minres_qlp, kwargs_minres_qlp , def_kwargs_minres_qlp) - (:QmrSolver , :qmr , args_qmr , def_args_qmr , optargs_qmr , def_optargs_qmr , kwargs_qmr , def_kwargs_qmr ) - (:GmresSolver , :gmres , args_gmres , def_args_gmres , optargs_gmres , def_optargs_gmres , kwargs_gmres , def_kwargs_gmres ) - (:FgmresSolver , :fgmres , args_fgmres , def_args_fgmres , optargs_fgmres , def_optargs_fgmres , kwargs_fgmres , def_kwargs_fgmres ) - (:FomSolver , :fom , args_fom , def_args_fom , optargs_fom , def_optargs_fom , kwargs_fom , def_kwargs_fom ) - (:GpmrSolver , :gpmr , args_gpmr , def_args_gpmr , optargs_gpmr , def_optargs_gpmr , kwargs_gpmr , def_kwargs_gpmr ) - (:CgLanczosShiftSolver , :cg_lanczos_shift , args_cg_lanczos_shift , def_args_cg_lanczos_shift , (), (), kwargs_cg_lanczos_shift , def_kwargs_cg_lanczos_shift ) - (:CglsLanczosShiftSolver, :cgls_lanczos_shift, args_cgls_lanczos_shift, def_args_cgls_lanczos_shift, (), (), kwargs_cgls_lanczos_shift, def_kwargs_cgls_lanczos_shift) + (:LsmrSolver , :lsmr , args_lsmr , def_args_lsmr , () , () , kwargs_lsmr , def_kwargs_lsmr ) + (:CgsSolver , :cgs , args_cgs , def_args_cgs , optargs_cgs , def_optargs_cgs , kwargs_cgs , def_kwargs_cgs ) + (:UsymlqSolver , :usymlq , args_usymlq , def_args_usymlq , optargs_usymlq , def_optargs_usymlq , kwargs_usymlq , def_kwargs_usymlq ) + (:LnlqSolver , :lnlq , args_lnlq , def_args_lnlq , () , () , kwargs_lnlq , def_kwargs_lnlq ) + (:BicgstabSolver , :bicgstab , args_bicgstab , def_args_bicgstab , optargs_bicgstab , def_optargs_bicgstab , kwargs_bicgstab , def_kwargs_bicgstab ) + (:CrlsSolver , :crls , args_crls , def_args_crls , () , () , kwargs_crls , def_kwargs_crls ) + (:LsqrSolver , :lsqr , args_lsqr , def_args_lsqr , () , () , kwargs_lsqr , def_kwargs_lsqr ) + (:MinresSolver , :minres , args_minres , def_args_minres , optargs_minres , def_optargs_minres , kwargs_minres , def_kwargs_minres ) + (:MinaresSolver , :minares , args_minares , def_args_minares , optargs_minares , def_optargs_minares , kwargs_minares , def_kwargs_minares ) + (:CgneSolver , :cgne , args_cgne , def_args_cgne , () , () , kwargs_cgne , def_kwargs_cgne ) + (:DqgmresSolver , :dqgmres , args_dqgmres , def_args_dqgmres , optargs_dqgmres , def_optargs_dqgmres , kwargs_dqgmres , def_kwargs_dqgmres ) + (:SymmlqSolver , :symmlq , args_symmlq , def_args_symmlq , optargs_symmlq , def_optargs_symmlq , kwargs_symmlq , def_kwargs_symmlq ) + (:TrimrSolver , :trimr , args_trimr , def_args_trimr , optargs_trimr , def_optargs_trimr , kwargs_trimr , def_kwargs_trimr ) + (:UsymqrSolver , :usymqr , args_usymqr , def_args_usymqr , optargs_usymqr , def_optargs_usymqr , kwargs_usymqr , def_kwargs_usymqr ) + (:BilqrSolver , :bilqr , args_bilqr , def_args_bilqr , optargs_bilqr , def_optargs_bilqr , kwargs_bilqr , def_kwargs_bilqr ) + (:CrSolver , :cr , args_cr , def_args_cr , optargs_cr , def_optargs_cr , kwargs_cr , def_kwargs_cr ) + (:CarSolver , :car , args_car , def_args_car , optargs_car , def_optargs_car , kwargs_car , def_kwargs_car ) + (:CraigmrSolver , :craigmr , args_craigmr , def_args_craigmr , () , () , kwargs_craigmr , def_kwargs_craigmr ) + (:TricgSolver , :tricg , args_tricg , def_args_tricg , optargs_tricg , def_optargs_tricg , kwargs_tricg , def_kwargs_tricg ) + (:CraigSolver , :craig , args_craig , def_args_craig , () , () , kwargs_craig , def_kwargs_craig ) + (:DiomSolver , :diom , args_diom , def_args_diom , optargs_diom , def_optargs_diom , kwargs_diom , def_kwargs_diom ) + (:LslqSolver , :lslq , args_lslq , def_args_lslq , () , () , kwargs_lslq , def_kwargs_lslq ) + (:TrilqrSolver , :trilqr , args_trilqr , def_args_trilqr , optargs_trilqr , def_optargs_trilqr , kwargs_trilqr , def_kwargs_trilqr ) + (:CrmrSolver , :crmr , args_crmr , def_args_crmr , () , () , kwargs_crmr , def_kwargs_crmr ) + (:CgSolver , :cg , args_cg , def_args_cg , optargs_cg , def_optargs_cg , kwargs_cg , def_kwargs_cg ) + (:CgLanczosShiftSolver, :cg_lanczos_shift, args_cg_lanczos_shift, def_args_cg_lanczos_shift, () , () , kwargs_cg_lanczos_shift, def_kwargs_cg_lanczos_shift) + (:CglsSolver , :cgls , args_cgls , def_args_cgls , () , () , kwargs_cgls , def_kwargs_cgls ) + (:CgLanczosSolver , :cg_lanczos , args_cg_lanczos , def_args_cg_lanczos , optargs_cg_lanczos, def_optargs_cg_lanczos, kwargs_cg_lanczos , def_kwargs_cg_lanczos ) + (:BilqSolver , :bilq , args_bilq , def_args_bilq , optargs_bilq , def_optargs_bilq , kwargs_bilq , def_kwargs_bilq ) + (:MinresQlpSolver , :minres_qlp , args_minres_qlp , def_args_minres_qlp , optargs_minres_qlp, def_optargs_minres_qlp, kwargs_minres_qlp , def_kwargs_minres_qlp ) + (:QmrSolver , :qmr , args_qmr , def_args_qmr , optargs_qmr , def_optargs_qmr , kwargs_qmr , def_kwargs_qmr ) + (:GmresSolver , :gmres , args_gmres , def_args_gmres , optargs_gmres , def_optargs_gmres , kwargs_gmres , def_kwargs_gmres ) + (:FgmresSolver , :fgmres , args_fgmres , def_args_fgmres , optargs_fgmres , def_optargs_fgmres , kwargs_fgmres , def_kwargs_fgmres ) + (:FomSolver , :fom , args_fom , def_args_fom , optargs_fom , def_optargs_fom , kwargs_fom , def_kwargs_fom ) + (:GpmrSolver , :gpmr , args_gpmr , def_args_gpmr , optargs_gpmr , def_optargs_gpmr , kwargs_gpmr , def_kwargs_gpmr ) ] # Create the symbol for the in-place method krylov! = Symbol(krylov, :!) @@ -176,37 +175,67 @@ end # Block-Krylov methods for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs) in [ - (:BlockGmresSolver, :block_gmres, args_block_gmres, def_args_block_gmres, optargs_block_gmres, def_optargs_block_gmres, kwargs_block_gmres, def_kwargs_block_gmres), + (:BlockMinresSolver, :block_minres, args_block_minres, def_args_block_minres, optargs_block_minres, def_optargs_block_minres, kwargs_block_minres, def_kwargs_block_minres) + (:BlockGmresSolver , :block_gmres , args_block_gmres , def_args_block_gmres , optargs_block_gmres , def_optargs_block_gmres , kwargs_block_gmres , def_kwargs_block_gmres ) ] # Create the symbol for the in-place method krylov! = Symbol(krylov, :!) - @eval begin - ## Out-of-place - function $(krylov)($(def_args...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} - start_time = time_ns() - solver = $workspace(A, B, memory) - elapsed_time = ktimer(start_time) - timemax -= elapsed_time - $(krylov!)(solver, $(args...); $(kwargs...)) - solver.stats.timer += elapsed_time - return results(solver) - end - - if !isempty($optargs) - function $(krylov)($(def_args...), $(def_optargs...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + ## Out-of-place + if krylov == :block_gmres + @eval begin + function $(krylov)($(def_args...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} start_time = time_ns() solver = $workspace(A, B, memory) - warm_start!(solver, $(optargs...)) elapsed_time = ktimer(start_time) timemax -= elapsed_time $(krylov!)(solver, $(args...); $(kwargs...)) solver.stats.timer += elapsed_time return results(solver) end + + if !isempty($optargs) + function $(krylov)($(def_args...), $(def_optargs...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + solver = $workspace(A, B, memory) + warm_start!(solver, $(optargs...)) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + $(krylov!)(solver, $(args...); $(kwargs...)) + solver.stats.timer += elapsed_time + return results(solver) + end + end + end + else + @eval begin + function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + solver = $workspace(A, b) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + $(krylov!)(solver, $(args...); $(kwargs...)) + solver.stats.timer += elapsed_time + return results(solver) + end + + if !isempty($optargs) + function $(krylov)($(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + solver = $workspace(A, b) + warm_start!(solver, $(optargs...)) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + $(krylov!)(solver, $(args...); $(kwargs...)) + solver.stats.timer += elapsed_time + return results(solver) + end + end end + end - ## In-place + ## In-place + @eval begin solve!(solver :: $workspace{T,FC,SV,SM}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}} = $(krylov!)(solver, $(args...); $(kwargs...)) if !isempty($optargs) diff --git a/src/usymlq.jl b/src/usymlq.jl index 18b859fa1..586b13af7 100644 --- a/src/usymlq.jl +++ b/src/usymlq.jl @@ -281,7 +281,7 @@ kwargs_usymlq = (:transfer_to_usymcg, :atol, :rtol, :itmax, :timemax, :verbose, kaxpby!(n, -cₖ, uₖ, conj(sₖ), d̅) end - # Compute uₖ₊₁ and uₖ₊₁. + # Compute vₖ₊₁ and uₖ₊₁. kcopy!(m, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ diff --git a/src/usymqr.jl b/src/usymqr.jl index e385ef43f..bd7403e3d 100644 --- a/src/usymqr.jl +++ b/src/usymqr.jl @@ -293,7 +293,7 @@ kwargs_usymqr = (:atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, AᴴrNorm = abs(ζbarₖ) * √(abs2(δbarₖ) + abs2(cₖ₋₁ * γₖ₊₁)) history && push!(AᴴrNorms, AᴴrNorm) - # Compute uₖ₊₁ and uₖ₊₁. + # Compute vₖ₊₁ and uₖ₊₁. kcopy!(m, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ kcopy!(n, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ