Skip to content

Commit

Permalink
Use metaprogramming to generate out-of-place methods (#887)
Browse files Browse the repository at this point in the history
* Use metaprogramming to generate out-of-place methods

* Add a function results
  • Loading branch information
amontoison authored Oct 13, 2024
1 parent 466997f commit f5fbae9
Show file tree
Hide file tree
Showing 27 changed files with 118 additions and 468 deletions.
21 changes: 0 additions & 21 deletions src/bicgstab.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,6 @@ optargs_bicgstab = (:x0,)
kwargs_bicgstab = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function bicgstab($(def_args_bicgstab...), $(def_optargs_bicgstab...); $(def_kwargs_bicgstab...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BicgstabSolver(A, b)
warm_start!(solver, $(optargs_bicgstab...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bicgstab!(solver, $(args_bicgstab...); $(kwargs_bicgstab...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function bicgstab($(def_args_bicgstab...); $(def_kwargs_bicgstab...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BicgstabSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bicgstab!(solver, $(args_bicgstab...); $(kwargs_bicgstab...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function bicgstab!(solver :: BicgstabSolver{T,FC,S}, $(def_args_bicgstab...); $(def_kwargs_bicgstab...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/bilq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,6 @@ optargs_bilq = (:x0,)
kwargs_bilq = (:c, :transfer_to_bicg, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function bilq($(def_args_bilq...), $(def_optargs_bilq...); $(def_kwargs_bilq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BilqSolver(A, b)
warm_start!(solver, $(optargs_bilq...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bilq!(solver, $(args_bilq...); $(kwargs_bilq...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function bilq($(def_args_bilq...); $(def_kwargs_bilq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BilqSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bilq!(solver, $(args_bilq...); $(kwargs_bilq...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function bilq!(solver :: BilqSolver{T,FC,S}, $(def_args_bilq...); $(def_kwargs_bilq...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/bilqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,6 @@ optargs_bilqr = (:x0, :y0)
kwargs_bilqr = (:transfer_to_bicg, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function bilqr($(def_args_bilqr...), $(def_optargs_bilqr...); $(def_kwargs_bilqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BilqrSolver(A, b)
warm_start!(solver, $(optargs_bilqr...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bilqr!(solver, $(args_bilqr...); $(kwargs_bilqr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function bilqr($(def_args_bilqr...); $(def_kwargs_bilqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BilqrSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
bilqr!(solver, $(args_bilqr...); $(kwargs_bilqr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function bilqr!(solver :: BilqrSolver{T,FC,S}, $(def_args_bilqr...); $(def_kwargs_bilqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/block_gmres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,27 +91,6 @@ optargs_block_gmres = (:X0,)
kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function block_gmres($(def_args_block_gmres...), $(def_optargs_block_gmres...); memory :: Int=20, $(def_kwargs_block_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BlockGmresSolver(A, B; memory)
warm_start!(solver, $(optargs_block_gmres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
block_gmres!(solver, $(args_block_gmres...); $(kwargs_block_gmres...))
solver.stats.timer += elapsed_time
return solver.X, solver.stats
end

function block_gmres($(def_args_block_gmres...); memory :: Int=20, $(def_kwargs_block_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BlockGmresSolver(A, B; memory)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
block_gmres!(solver, $(args_block_gmres...); $(kwargs_block_gmres...))
solver.stats.timer += elapsed_time
return solver.X, solver.stats
end

function block_gmres!(solver :: BlockGmresSolver{T,FC,SV,SM}, $(def_args_block_gmres...); $(def_kwargs_block_gmres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}}

# Timer
Expand Down
5 changes: 3 additions & 2 deletions src/block_krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Type for storing the vectors required by the in-place version of BLOCK-GMRES.
The outer constructors
solver = BlockGmresSolver(m, n, p, memory, SV, SM)
solver = BlockGmresSolver(A, B; memory=5)
solver = BlockGmresSolver(A, B, memory = 5)
may be used in order to create these vectors.
`memory` is set to `div(n,p)` if the value given is larger than `div(n,p)`.
Expand Down Expand Up @@ -59,7 +59,7 @@ function BlockGmresSolver(m, n, p, memory, SV, SM)
return solver
end

function BlockGmresSolver(A, B; memory::Int=5)
function BlockGmresSolver(A, B, memory = 5)
m, n = size(A)
s, p = size(B)
SM = typeof(B)
Expand All @@ -81,6 +81,7 @@ for (KS, fun, nsol, nA, nAt, warm_start) in [
if $nsol == 1
solution(solver :: $KS) = solver.X
solution(solver :: $KS, p :: Integer) = (p == 1) ? solution(solver) : error("solution(solver) has only one output.")
results(solver :: $KS) = (solver.X, solver.stats)
end
issolved(solver :: $KS) = solver.stats.solved
if $warm_start
Expand Down
21 changes: 0 additions & 21 deletions src/car.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,6 @@ optargs_car = (:x0,)
kwargs_car = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function car($(def_args_car...), $(def_optargs_car...); $(def_kwargs_car...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CarSolver(A, b)
warm_start!(solver, $(optargs_car...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
car!(solver, $(args_car...); $(kwargs_car...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function car($(def_args_car...); $(def_kwargs_car...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CarSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
car!(solver, $(args_car...); $(kwargs_car...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function car!(solver :: CarSolver{T,FC,S}, $(def_args_car...); $(def_kwargs_car...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,6 @@ optargs_cg = (:x0,)
kwargs_cg = (:M, :ldiv, :radius, :linesearch, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cg($(def_args_cg...), $(def_optargs_cg...); $(def_kwargs_cg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgSolver(A, b)
warm_start!(solver, $(optargs_cg...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cg!(solver, $(args_cg...); $(kwargs_cg...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cg($(def_args_cg...); $(def_kwargs_cg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cg!(solver, $(args_cg...); $(kwargs_cg...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cg!(solver :: CgSolver{T,FC,S}, $(def_args_cg...); $(def_kwargs_cg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/cg_lanczos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,6 @@ optargs_cg_lanczos = (:x0,)
kwargs_cg_lanczos = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cg_lanczos($(def_args_cg_lanczos...), $(def_optargs_cg_lanczos...); $(def_kwargs_cg_lanczos...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgLanczosSolver(A, b)
warm_start!(solver, $(optargs_cg_lanczos...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cg_lanczos!(solver, $(args_cg_lanczos...); $(kwargs_cg_lanczos...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cg_lanczos($(def_args_cg_lanczos...); $(def_kwargs_cg_lanczos...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgLanczosSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cg_lanczos!(solver, $(args_cg_lanczos...); $(kwargs_cg_lanczos...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cg_lanczos!(solver :: CgLanczosSolver{T,FC,S}, $(def_args_cg_lanczos...); $(def_kwargs_cg_lanczos...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/cgls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,6 @@ args_cgls = (:A, :b)
kwargs_cgls = (:M, :ldiv, :radius, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cgls($(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CglsSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cgls!(solver, $(args_cgls...); $(kwargs_cgls...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cgls!(solver :: CglsSolver{T,FC,S}, $(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/cgne.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,6 @@ args_cgne = (:A, :b)
kwargs_cgne = (:N, :ldiv, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cgne($(def_args_cgne...); $(def_kwargs_cgne...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgneSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cgne!(solver, $(args_cgne...); $(kwargs_cgne...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cgne!(solver :: CgneSolver{T,FC,S}, $(def_args_cgne...); $(def_kwargs_cgne...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/cgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,6 @@ optargs_cgs = (:x0,)
kwargs_cgs = (:c, :M, :N, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cgs($(def_args_cgs...), $(def_optargs_cgs...); $(def_kwargs_cgs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgsSolver(A, b)
warm_start!(solver, $(optargs_cgs...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cgs!(solver, $(args_cgs...); $(kwargs_cgs...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cgs($(def_args_cgs...); $(def_kwargs_cgs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CgsSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cgs!(solver, $(args_cgs...); $(kwargs_cgs...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cgs!(solver :: CgsSolver{T,FC,S}, $(def_args_cgs...); $(def_kwargs_cgs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
21 changes: 0 additions & 21 deletions src/cr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,6 @@ optargs_cr = (:x0,)
kwargs_cr = (:M, :ldiv, :radius, :linesearch, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function cr($(def_args_cr...), $(def_optargs_cr...); $(def_kwargs_cr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CrSolver(A, b)
warm_start!(solver, $(optargs_cr...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cr!(solver, $(args_cr...); $(kwargs_cr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cr($(def_args_cr...); $(def_kwargs_cr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CrSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
cr!(solver, $(args_cr...); $(kwargs_cr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function cr!(solver :: CrSolver{T,FC,S}, $(def_args_cr...); $(def_kwargs_cr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/craig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,6 @@ args_craig = (:A, :b)
kwargs_craig = (:M, :N, :ldiv, :transfer_to_lsqr, :sqd, , :btol, :conlim, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function craig($(def_args_craig...); $(def_kwargs_craig...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CraigSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
craig!(solver, $(args_craig...); $(kwargs_craig...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function craig!(solver :: CraigSolver{T,FC,S}, $(def_args_craig...); $(def_kwargs_craig...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/craigmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,6 @@ args_craigmr = (:A, :b)
kwargs_craigmr = (:M, :N, :ldiv, :sqd, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function craigmr($(def_args_craigmr...); $(def_kwargs_craigmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CraigmrSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
craigmr!(solver, $(args_craigmr...); $(kwargs_craigmr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.y, solver.stats)
end

function craigmr!(solver :: CraigmrSolver{T,FC,S}, $(def_args_craigmr...); $(def_kwargs_craigmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/crls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ args_crls = (:A, :b)
kwargs_crls = (:M, :ldiv, :radius, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function crls($(def_args_crls...); $(def_kwargs_crls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CrlsSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
crls!(solver, $(args_crls...); $(kwargs_crls...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function crls!(solver :: CrlsSolver{T,FC,S}, $(def_args_crls...); $(def_kwargs_crls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
10 changes: 0 additions & 10 deletions src/crmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,6 @@ args_crmr = (:A, :b)
kwargs_crmr = (:N, :ldiv, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function crmr($(def_args_crmr...); $(def_kwargs_crmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = CrmrSolver(A, b)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
crmr!(solver, $(args_crmr...); $(kwargs_crmr...))
solver.stats.timer += elapsed_time
return (solver.x, solver.stats)
end

function crmr!(solver :: CrmrSolver{T,FC,S}, $(def_args_crmr...); $(def_kwargs_crmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}

# Timer
Expand Down
Loading

0 comments on commit f5fbae9

Please sign in to comment.