Skip to content

Commit

Permalink
Add preconditioning to to_boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
mpf committed May 21, 2024
1 parent b5bb89a commit 6e97d0a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,15 @@ kwargs_cg = (:M, :ldiv, :radius, :linesearch, :atol, :rtol, :itmax, :timemax, :v
(zero_curvature || solved) && continue

α = γ / pAp

# Compute step size to boundary if applicable.
σ = radius > 0 ? maximum(to_boundary(n, x, p, radius, dNorm2=pNorm²)) : α
if radius == 0
σ = α
elseif MisI
σ = maximum(to_boundary(n, x, p, z, radius, dNorm2=pNorm²))
else
σ = maximum(to_boundary(n, x, p, z, radius, M=M, ldiv=!ldiv))
end

kdisplay(iter, verbose) && @printf(iostream, " %8.1e %8.1e %8.1e %.2fs\n", pAp, α, σ, ktimer(start_time))

Expand Down
23 changes: 17 additions & 6 deletions src/krylov_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,21 +405,32 @@ If `flip` is set to `true`, `σ1` and `σ2` are computed such that
‖x - σi d‖ = radius, i = 1, 2.
"""
function to_boundary(n :: Int, x :: AbstractVector{FC}, d :: AbstractVector{FC}, radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
function to_boundary(n :: Int, x :: AbstractVector{FC}, d :: AbstractVector{FC}, z::AbstractVector{FC}, radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T), M=I, ldiv :: Bool = false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
radius > 0 || error("radius must be positive")

# ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - Δ²).
rxd = @kdotr(n, x, d)
flip && (rxd = -rxd)
dNorm2 == zero(T) && (dNorm2 = @kdotr(n, d, d))
if M === I
# ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - Δ²).
rxd = @kdotr(n, x, d)
dNorm2 == zero(T) && (dNorm2 = @kdotr(n, d, d))
xNorm2 == zero(T) && (xNorm2 = @kdotr(n, x, x))
else
# (dᴴMd) σ² + (xᴴMd + dᴴMx) σ + (xᴴMx - Δ²).
mulorldiv!(z, M, x, ldiv)
rxd = dot(z, d)
xNorm2 = dot(z, x)
mulorldiv!(z, M, d, ldiv)
dNorm2 = dot(z, d)
end
dNorm2 == zero(T) && error("zero direction")
xNorm2 == zero(T) && (xNorm2 = @kdotr(n, x, x))
flip && (rxd = -rxd)

radius2 = radius * radius
(xNorm2 radius2) || error(@sprintf("outside of the trust region: ‖x‖²=%7.1e, Δ²=%7.1e", xNorm2, radius2))

# q₂ = ‖d‖², q₁ = xᴴd + dᴴx, q₀ = ‖x‖² - Δ²
# ‖x‖² ≤ Δ² ⟹ (q₁)² - 4 * q₂ * q₀ ≥ 0
roots = roots_quadratic(dNorm2, 2 * rxd, xNorm2 - radius2)

return roots # `σ1` and `σ2`
end

Expand Down

0 comments on commit 6e97d0a

Please sign in to comment.