diff --git a/src/cg.jl b/src/cg.jl index 1345a6232..50517e427 100644 --- a/src/cg.jl +++ b/src/cg.jl @@ -210,9 +210,15 @@ kwargs_cg = (:M, :ldiv, :radius, :linesearch, :atol, :rtol, :itmax, :timemax, :v (zero_curvature || solved) && continue α = γ / pAp - + # Compute step size to boundary if applicable. - σ = radius > 0 ? maximum(to_boundary(n, x, p, radius, dNorm2=pNorm²)) : α + if radius == 0 + σ = α + elseif MisI + σ = maximum(to_boundary(n, x, p, z, radius, dNorm2=pNorm²)) + else + σ = maximum(to_boundary(n, x, p, z, radius, M=M, ldiv=!ldiv)) + end kdisplay(iter, verbose) && @printf(iostream, " %8.1e %8.1e %8.1e %.2fs\n", pAp, α, σ, ktimer(start_time)) diff --git a/src/krylov_utils.jl b/src/krylov_utils.jl index ca73bc121..239e56b75 100644 --- a/src/krylov_utils.jl +++ b/src/krylov_utils.jl @@ -405,21 +405,32 @@ If `flip` is set to `true`, `σ1` and `σ2` are computed such that ‖x - σi d‖ = radius, i = 1, 2. """ -function to_boundary(n :: Int, x :: AbstractVector{FC}, d :: AbstractVector{FC}, radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} +function to_boundary(n :: Int, x :: AbstractVector{FC}, d :: AbstractVector{FC}, z::AbstractVector{FC}, radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T), M=I, ldiv :: Bool = false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} radius > 0 || error("radius must be positive") - # ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - Δ²). - rxd = @kdotr(n, x, d) - flip && (rxd = -rxd) - dNorm2 == zero(T) && (dNorm2 = @kdotr(n, d, d)) + if M === I + # ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - Δ²). + rxd = @kdotr(n, x, d) + dNorm2 == zero(T) && (dNorm2 = @kdotr(n, d, d)) + xNorm2 == zero(T) && (xNorm2 = @kdotr(n, x, x)) + else + # (dᴴMd) σ² + (xᴴMd + dᴴMx) σ + (xᴴMx - Δ²). + mulorldiv!(z, M, x, ldiv) + rxd = dot(z, d) + xNorm2 = dot(z, x) + mulorldiv!(z, M, d, ldiv) + dNorm2 = dot(z, d) + end dNorm2 == zero(T) && error("zero direction") - xNorm2 == zero(T) && (xNorm2 = @kdotr(n, x, x)) + flip && (rxd = -rxd) + radius2 = radius * radius (xNorm2 ≤ radius2) || error(@sprintf("outside of the trust region: ‖x‖²=%7.1e, Δ²=%7.1e", xNorm2, radius2)) # q₂ = ‖d‖², q₁ = xᴴd + dᴴx, q₀ = ‖x‖² - Δ² # ‖x‖² ≤ Δ² ⟹ (q₁)² - 4 * q₂ * q₀ ≥ 0 roots = roots_quadratic(dNorm2, 2 * rxd, xNorm2 - radius2) + return roots # `σ1` and `σ2` end