diff --git a/src/block_krylov_solvers.jl b/src/block_krylov_solvers.jl index 070070ac5..075b86783 100644 --- a/src/block_krylov_solvers.jl +++ b/src/block_krylov_solvers.jl @@ -32,6 +32,10 @@ mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM} D :: SM τ :: SV tmp :: SM + Vₖ₋₁ :: SM + Vₖ :: SM + wₖ₋₂ :: SM + wₖ₋₁ :: SM warm_start :: Bool stats :: SimpleStats{T} end @@ -48,8 +52,12 @@ function BlockMinresSolver(m, n, p, SV, SM) D = SM(undef, 2p, p) τ = SV(undef, p) tmp = C isa Matrix ? SM(undef, 0, 0) : SM(undef, p, p) + Vₖ₋₁ = SM(undef, n, p) + Vₖ = SM(undef, n, p) + wₖ₋₂ = SM(undef, n, p) + wₖ₋₁ = SM(undef, n, p) stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown") - solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, tmp, false, stats) + solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, tmp, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, false, stats) return solver end diff --git a/src/block_minres.jl b/src/block_minres.jl index ad143515c..19c36e34c 100644 --- a/src/block_minres.jl +++ b/src/block_minres.jl @@ -104,12 +104,14 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his ktypeof(B) <: SM || error("ktypeof(B) is not a subtype of $SM") # Set up workspace. - ΔX, X, W, V, Z = solver.ΔX, solver.X, solver.W, solver.V, solver.Z + Vₖ₋₁, Vₖ = solver.Vₖ₋₁, solver.Vₖ + ΔX, X, W, Z = solver.ΔX, solver.X, solver.W, solver.Z C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats + wₖ₋₂, wₖ₋₁ = solver.wₖ₋₂, solver.wₖ₋₁ warm_start = solver.warm_start RNorms = stats.residuals reset!(stats) - R₀ = warm_start ? Q : B + R₀ = warm_start ? W : B # Define the blocks D1 and D2 D1 = view(D, 1:p, :) @@ -126,11 +128,10 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his # Initial residual R₀. if warm_start - mul!(Q, A, ΔX) - Q .= B .- Q + mul!(W, A, ΔX) + W .= B .- W end - MisI || mulorldiv!(R₀, M, W, ldiv) # R₀ = M(B - AX₀) - RNorm = norm(R₀) # ‖R₀‖_F + RNorm = norm(R₀) # ‖R₀‖_F history && push!(RNorms, RNorm) iter = 0 @@ -162,9 +163,9 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his # Continue the block-Lanczos process. mul!(W, A, V) # Q ← AVₖ for i = 1 : inner_iter - mul!(Ω, V', W) # Ωₖ = Vₖᴴ * Q - (iter ≥ 2) && mul!(Q, Vold, Ψ') # Q ← Q - Vₖ₋₁ * Ψₖᴴ - mul!(Q, V, R, α, β) # Q = Q - Vₖ * Ωₖ + mul!(Ω, Vₖ', W) # Ωₖ = Vₖᴴ * Q + (iter ≥ 2) && mul!(Q, Vₖ₋₁, Ψ') # Q ← Q - Vₖ₋₁ * Ψₖᴴ + mul!(Q, Vₖ, Ω, α, β) # Q = Q - Vₖ * Ωₖ end # Vₖ₊₁ and Ψₖ₊₁ are stored in Q and C. @@ -174,7 +175,22 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his householder!(Q, C, τ, solver.tmp) end - # Update the QR factorization of Tₖ₊₁.ₖ. + # Update the QR factorization of Tₖ₊₁.ₖ = Qₖ [ Rₖ ]. + # [ Oᵀ ] + # + # [ Ω₁ Ψ₂ᴴ 0 • • • 0 ] [ Λ₁ Γ₁ Π₁ 0 • • 0 ] + # [ Ψ₂ Ω₂ • • • ] [ 0 Λ₂ Γ₂ • • • ] + # [ 0 • • • • • ] [ • • Λ₃ • • • • ] + # [ • • • • • • • ] = Qₖ [ • • • • • 0 ] + # [ • • • • • 0 ] [ • • • • Πₖ₋₂] + # [ • • • • Ψₖᴴ ] [ • • • Γₖ₋₁] + # [ • • Ψₖ Ωₖ ] [ 0 • • • • 0 Λₖ ] + # [ 0 • • • • 0 Ψₖ₊₁] [ 0 • • • • • 0 ] + # + # If k = 1, we don't have any previous reflection. + # If k = 2, we apply the last reflection. + # If k ≥ 3, we only apply the two previous reflections. + # Apply previous Householder reflections Θₖ₋₂. if k ≥ 3 D1 .= Rₖ₋₂.ₖ @@ -202,45 +218,54 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his householder!(H[i], Rₖ.ₖ, τ[i], solver.tmp, compact=true) end - # Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁) - D1 .= Λbarₖ + # Update Zₖ = (Qₖ)ᴴΨ₁E₁ = (Φ₁, ..., Φₖ, Φbarₖ₊₁) + D1 .= Φbarₖ D2 .= zero(FC) kormqr!('L', trans, H[i], τ[i], D) - Λₖ .= D1 + Φₖ = D1 # Compute the directions Wₖ, the last columns of Wₖ = Vₖ(Rₖ)⁻¹ ⟷ (Rₖ)ᵀ(Wₖ)ᵀ = (Vₖ)ᵀ # R₁₁w₁ = v₁ if iter == 1 - # wₖ = wₖ₋₁ - # kaxpy!(n, one(FC), uₖ, wₖ) - # wₖ .= wₖ ./ δₖ + wₖ = wₖ₋₁ + wₖ .+= vₖ + ldiv!(LowerTriangular(R₁₁), wₖ) end # R₂₂w₂ = (v₂ - R₂₁w₁) if iter == 2 - # wₖ = wₖ₋₂ - # kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ) - # kaxpy!(n, one(FC), uₖ, wₖ) - # wₖ .= wₖ ./ δₖ + wₖ = wₖ₋₂ + wₖ .-= R₂₁ * wₖ₋₁ + wₖ .+= vₖ + ldiv!(LowerTriangular(R₂₁), wₖ) end # Rₖₖwₖ = (vₖ - Rₖₖ₋₁wₖ₋₁ - Rₖₖ₋₂wₖ₋₂) if iter ≥ 3 - # kscal!(n, -ϵₖ₋₂, wₖ₋₂) - # wₖ = wₖ₋₂ - # kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ) - # kaxpy!(n, one(FC), uₖ, wₖ) - # wₖ .= wₖ ./ δₖ + lmul!(UpperTriangular(Rₖₖ₋₂), wₖ₋₂) + wₖ = wₖ₋₂ + wₖ .-= Rₖₖ₋₁ * wₖ₋₁ + wₖ .+= vₖ + ldiv!(LowerTriangular(Rₖₖ), wₖ) end # Update Xₖ = VₖYₖ = WₖZₖ - # Xₖ = Xₖ₋₁ + Λₖ * wₖ - mul!(X, Λₖ, W[i], γ, β) + # Xₖ = Xₖ₋₁ + Φₖ * wₖ + mul!(X, Φₖ, wₖ, γ, β) # Update residual norm estimate. - # ‖ M(B - AXₖ) ‖_F = ‖Λbarₖ₊₁‖_F + # ‖ M(B - AXₖ) ‖_F = ‖Φbarₖ₊₁‖_F C .= D2 RNorm = norm(C) history && push!(RNorms, RNorm) + # Compute vₖ and vₖ₊₁ + copyto!(Vₖ₋₁, Vₖ) # vₖ₋₁ ← vₖ + copyto!(Vₖ, Q) # vₖ ← vₖ₊₁ + + # Update directions for X + if iter ≥ 2 + @kswap!(wₖ₋₂, wₖ₋₁) + end + # Update stopping criterion. user_requested_exit = callback(solver) :: Bool solved = RNorm ≤ ε