From f5763f55e13c15ffc4d4ff72ec9d09576e168727 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C4=B0smail=20=C5=9Een=C3=B6z?=
Date: Fri, 2 Aug 2024 16:31:38 +0200
Subject: [PATCH 1/4] direction rule as argument to projection parameters

---
 src/projected_to.jl | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/projected_to.jl b/src/projected_to.jl
index 2874c81..2d9398f 100644
--- a/src/projected_to.jl
+++ b/src/projected_to.jl
@@ -103,13 +103,15 @@ The following parameters are available:
 * `niterations = 100`: The number of iterations for the optimization procedure.
 * `tolerance = 1e-6`: The tolerance for the norm of the gradient.
 * `stepsize = ConstantStepsize(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
+* `direction = BoundedNormUpdateRule(static(1.0))`: Direction update rule. Accepts `Manopt.DirectionUpdateRule` from `Manopt.jl`.
 * `usebuffer = Val(true)`: Whether to use a buffer for the projection. Must be either `Val(true)` or `Val(false)`. Disabling buffer can be useful for debugging purposes. 
""" -Base.@kwdef struct ProjectionParameters{S,I,T,P,B} +Base.@kwdef struct ProjectionParameters{S,I,T,P,D,B} strategy::S = DefaultStrategy() niterations::I = 100 tolerance::T = 1e-6 stepsize::P = ConstantStepsize(0.1) + direction::D = BoundedNormUpdateRule(static(1.0)) usebuffer::B = Val(true) end @@ -124,6 +126,7 @@ getstrategy(parameters::ProjectionParameters) = parameters.strategy getniterations(parameters::ProjectionParameters) = parameters.niterations gettolerance(parameters::ProjectionParameters) = parameters.tolerance getstepsize(parameters::ProjectionParameters) = parameters.stepsize +getdirection(parameters::ProjectionParameters) = parameters.direction with_buffer(f::F, parameters::ProjectionParameters) where {F} = with_buffer(f, parameters.usebuffer, parameters) @@ -292,7 +295,7 @@ function _kernel_project_to( p; stopping_criterion = get_stopping_criterion(parameters), stepsize = getstepsize(parameters), - direction = BoundedNormUpdateRule(static(1)), + direction = getdirection(parameters), kwargs..., ) From 1b0df2b362372e3894761d52c5078e0367389a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0smail=20=C5=9Een=C3=B6z?= Date: Sat, 3 Aug 2024 13:28:53 +0200 Subject: [PATCH 2/4] add tests for direction rule effect comparison --- test/projection/projected_to_tests.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl index e4d650a..062fe57 100644 --- a/test/projection/projected_to_tests.jl +++ b/test/projection/projected_to_tests.jl @@ -496,4 +496,29 @@ end @test_logs match_mode = :all project_to(prj, targetfn) @test_logs match_mode = :all project_to(prj, targetfn, debug = []) +end + +@testitem "Direction rule effect comparison for MLE" begin + using BayesBase, ExponentialFamily, Distributions, JET + using ExponentialFamilyProjection, StableRNGs + + true_dist = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8)) + for 
dist in true_dist + data = rand(StableRNG(42), dist, 100) + divergences = [] + for norm in (0.0:0.01:1.0) + parameters = ProjectionParameters( + direction = ExponentialFamilyProjection.BoundedNormUpdateRule(norm) + ) + projection = ProjectedTo(ExponentialFamily.exponential_family_typetag.(dist), ()..., parameters = parameters) + approximated = project_to(projection, data) + push!(divergences, kldivergence(approximated, dist)) + end + + Δdivergences = divergences[1:end-1] - divergences[2:end] + for i = 1 : length(Δdivergences) - 1 + @test Δdivergences[i] ≠ Δdivergences[i + 1] + end + end + end \ No newline at end of file From 03a3bc731884d48068c4b98f452388c621038672 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 5 Aug 2024 11:31:37 +0200 Subject: [PATCH 3/4] test: usefull tests for Direction rule --- test/projection/projected_to_tests.jl | 84 ++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 13 deletions(-) diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl index 062fe57..5e71548 100644 --- a/test/projection/projected_to_tests.jl +++ b/test/projection/projected_to_tests.jl @@ -498,27 +498,85 @@ end end -@testitem "Direction rule effect comparison for MLE" begin - using BayesBase, ExponentialFamily, Distributions, JET +@testitem "Direction rule can improve for MLE" begin + using BayesBase, ExponentialFamily, Distributions using ExponentialFamilyProjection, StableRNGs - true_dist = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8)) - for dist in true_dist - data = rand(StableRNG(42), dist, 100) - divergences = [] - for norm in (0.0:0.01:1.0) + dists = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8)) + + for dist in dists + rng = StableRNG(42) + data = rand(rng, dist, 4000) + + norm_bounds = [0.01, 0.1, 10.0] + + divergences = map(norm_bounds) do norm parameters = ProjectionParameters( direction = 
ExponentialFamilyProjection.BoundedNormUpdateRule(norm) ) - projection = ProjectedTo(ExponentialFamily.exponential_family_typetag.(dist), ()..., parameters = parameters) + projection = ProjectedTo(ExponentialFamily.exponential_family_typetag(dist), ()..., parameters = parameters) approximated = project_to(projection, data) - push!(divergences, kldivergence(approximated, dist)) + kldivergence(approximated, dist) end - - Δdivergences = divergences[1:end-1] - divergences[2:end] - for i = 1 : length(Δdivergences) - 1 - @test Δdivergences[i] ≠ Δdivergences[i + 1] + + @testset "true dist $(dist)" begin + @test issorted(divergences, rev=true) + @test (divergences[1] - divergences[end]) / divergences[1] > 0.05 end + end +end + +@testitem "MomentumGradient direction update rule on logpdf" begin + using BayesBase, ExponentialFamily, Distributions + using ExponentialFamilyProjection, ExponentialFamilyManifolds, Manopt, StableRNGs + + + true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0]) + logp = (x) -> logpdf(true_dist, x) + + manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing) + initialpoint = rand(manifold) + direction = MomentumGradient(manifold, initialpoint) + + momentum_parameters = ProjectionParameters( + direction = direction, + niterations = 1000, + tolerance = 1e-8 + ) + + projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters) + + approximated = project_to(projection, logp, initialpoint = initialpoint) + + @test approximated isa MvNormalMeanCovariance + @test kldivergence(approximated, true_dist) < 0.01 + @test projection.parameters.direction isa MomentumGradient +end + +@testitem "MomentumGradient direction update rule on samples" begin + using BayesBase, ExponentialFamily, Distributions + using ExponentialFamilyProjection, ExponentialFamilyManifolds, Manopt, StableRNGs + true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0]) + rng = StableRNG(42) + samples = rand(rng, true_dist, 1000) + + 
manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing) + + initialpoint = rand(rng, manifold) + direction = MomentumGradient(manifold, initialpoint) + + momentum_parameters = ProjectionParameters( + direction = direction, + niterations = 1000, + tolerance = 1e-8 + ) + + projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters) + approximated = project_to(projection, samples, initialpoint = initialpoint) + + @test approximated isa MvNormalMeanCovariance + @test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation + @test projection.parameters.direction isa MomentumGradient end \ No newline at end of file From 83f6b96a332278f77e609267c3c1f753af8b7894 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 5 Aug 2024 11:32:41 +0200 Subject: [PATCH 4/4] style: :art: --- test/projection/projected_to_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl index 5e71548..7904b80 100644 --- a/test/projection/projected_to_tests.jl +++ b/test/projection/projected_to_tests.jl @@ -579,4 +579,4 @@ end @test approximated isa MvNormalMeanCovariance @test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation @test projection.parameters.direction isa MomentumGradient -end \ No newline at end of file +end