From 26e3403d95a125f4e7120c7142b9a8ab6189573c Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Wed, 23 Oct 2024 14:47:42 +0200 Subject: [PATCH 1/2] test(fix): Manopt API update --- test/shifted_negative_numbers_tests.jl | 8 ++++---- test/shifted_positive_numbers_tests.jl | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/shifted_negative_numbers_tests.jl b/test/shifted_negative_numbers_tests.jl index d48494f..a86be95 100644 --- a/test/shifted_negative_numbers_tests.jl +++ b/test/shifted_negative_numbers_tests.jl @@ -159,7 +159,7 @@ end b in (10.0, 5.0), c in (1.0, 10.0, -1.0), eps in (1e-4, 1e-5, 1e-8, 1e-10), - stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) + stepsize in (ConstantLength(0.1), ConstantLength(0.01), ConstantLength(0.001)) expected_q = -b / 2a expected_minimum = c - b^2 / (4a) @@ -232,11 +232,11 @@ end obj = ManifoldGradientObjective(missing, grad_f!; evaluation=InplaceEvaluation()) dmp = DefaultManoptProblem(M, obj) s = GradientDescentState( - M, - q; + M; + p = q, stopping_criterion=StopWhenGradientNormLessNonAllocating(1e-8), stepsize=ConstantStepsizeNonAllocating(0.1), - direction=IdentityUpdateRule(), + direction=Manopt.IdentityUpdateRule(), retraction_method=default_retraction_method(M, typeof(q)), X=zero_vector(M, q), ) diff --git a/test/shifted_positive_numbers_tests.jl b/test/shifted_positive_numbers_tests.jl index b65183e..b6dae7e 100644 --- a/test/shifted_positive_numbers_tests.jl +++ b/test/shifted_positive_numbers_tests.jl @@ -159,7 +159,7 @@ end b in (-10.0, -5.0), c in (1.0, 10.0, -1.0), eps in (1e-4, 1e-5, 1e-8, 1e-10), - stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) + stepsize in (ConstantLength(0.1), ConstantLength(0.01), ConstantLength(0.001)) expected_q = -b / 2a expected_minimum = c - b^2 / (4a) @@ -231,11 +231,11 @@ end obj = ManifoldGradientObjective(missing, grad_f!; evaluation=InplaceEvaluation()) dmp = DefaultManoptProblem(M, obj) s = GradientDescentState( - M, - q; + M; + p = q, stopping_criterion=StopWhenGradientNormLessNonAllocating(1e-8), stepsize=ConstantStepsizeNonAllocating(0.1), - direction=IdentityUpdateRule(), + direction=Manopt.IdentityUpdateRule(), retraction_method=default_retraction_method(M, typeof(q)), X=zero_vector(M, q), ) From b5f3b7049f6d65e5b35c974f7ea68f5f8b66d84f Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Wed, 23 Oct 2024 15:00:44 +0200 Subject: [PATCH 2/2] test(fix): missed ConstantLenght --- test/manopt_setuptests.jl | 2 +- test/natural_manifolds/normal_tests.jl | 2 +- test/single_point_manifold_tests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/manopt_setuptests.jl b/test/manopt_setuptests.jl index 4d42663..43af8ae 100644 --- a/test/manopt_setuptests.jl +++ b/test/manopt_setuptests.jl @@ -16,7 +16,7 @@ function (sc::StopWhenGradientNormLessNonAllocating)(mp, s, i) return false end -# Non allocating version of the same `ConstantStepsize` from `Manopt.jl` +# Non allocating version of the same `ConstantLength` from `Manopt.jl` struct ConstantStepsizeNonAllocating{T} <: Stepsize stepsize::T end diff --git a/test/natural_manifolds/normal_tests.jl b/test/natural_manifolds/normal_tests.jl index 1f9c693..e17fc72 100644 --- a/test/natural_manifolds/normal_tests.jl +++ b/test/natural_manifolds/normal_tests.jl @@ -25,6 +25,6 @@ end k = rand(rng, 1:10) m = randn(rng, k) γ = rand(rng)^2 + 1 - return MvNormalMeanScalePrecision(m, C) + return MvNormalMeanScalePrecision(m, γ) end end \ No newline at end of file diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index 0f18a17..807951a 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -82,7 +82,7 @@ end b in (10.0, 5.0), c in (1.0, 10.0, -1.0), eps in (1e-4, 1e-5, 1e-8, 1e-10), - stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) + stepsize in (ConstantLength(0.1), ConstantLength(0.01), ConstantLength(0.001)) f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1] grad_f(M, x) = 2 .* a .* x .+ b