Skip to content

Commit

Permalink
Merge pull request #31 from ReactiveBayes/direction_rule_as_parameter
Browse files Browse the repository at this point in the history
Direction rule as argument to projection parameters
  • Loading branch information
Nimrais authored Aug 5, 2024
2 parents 5856bae + 83f6b96 commit 85d7077
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/projected_to.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,15 @@ The following parameters are available:
* `niterations = 100`: The number of iterations for the optimization procedure.
* `tolerance = 1e-6`: The tolerance for the norm of the gradient.
* `stepsize = ConstantStepsize(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
* `direction = BoundedNormUpdateRule(static(1.0))`: Direction update rule. Accepts `Manopt.DirectionUpdateRule` from `Manopt.jl`.
* `usebuffer = Val(true)`: Whether to use a buffer for the projection. Must be either `Val(true)` or `Val(false)`. Disabling buffer can be useful for debugging purposes.
"""
Base.@kwdef struct ProjectionParameters{S,I,T,P,B}
# Container for the tunable options of the projection optimization procedure.
# Each field is a free type parameter so that every field stays concretely typed.
Base.@kwdef struct ProjectionParameters{S,I,T,P,D,B}
strategy::S = DefaultStrategy() # strategy used to compute gradients during projection
niterations::I = 100 # number of iterations for the optimization procedure
tolerance::T = 1e-6 # tolerance on the norm of the gradient (stopping criterion)
stepsize::P = ConstantStepsize(0.1) # stepsize rule; accepts stepsizes from `Manopt.jl`
direction::D = BoundedNormUpdateRule(static(1.0)) # direction update rule; accepts `Manopt.DirectionUpdateRule`
usebuffer::B = Val(true) # `Val(true)`/`Val(false)` — whether to use a buffer for the projection
end

Expand All @@ -124,6 +126,7 @@ getstrategy(parameters::ProjectionParameters) = parameters.strategy
# Return the configured number of optimization iterations.
getniterations(parameters::ProjectionParameters) = parameters.niterations
# Return the gradient-norm tolerance used as a stopping criterion.
gettolerance(parameters::ProjectionParameters) = parameters.tolerance
# Return the `Manopt.jl`-compatible stepsize rule.
getstepsize(parameters::ProjectionParameters) = parameters.stepsize
# Return the `Manopt.jl` direction update rule (new accessor for the `direction` field).
getdirection(parameters::ProjectionParameters) = parameters.direction

# Dispatch `f` through the buffer configuration stored in `usebuffer`
# (`Val(true)` or `Val(false)`), forwarding the parameters alongside.
with_buffer(f::F, parameters::ProjectionParameters) where {F} =
with_buffer(f, parameters.usebuffer, parameters)
Expand Down Expand Up @@ -292,7 +295,7 @@ function _kernel_project_to(
p;
stopping_criterion = get_stopping_criterion(parameters),
stepsize = getstepsize(parameters),
direction = BoundedNormUpdateRule(static(1)),
direction = getdirection(parameters),
kwargs...,
)

Expand Down
85 changes: 84 additions & 1 deletion test/projection/projected_to_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,87 @@ end
@test_logs match_mode = :all project_to(prj, targetfn)
@test_logs match_mode = :all project_to(prj, targetfn, debug = [])

end
end

@testitem "Direction rule can improve for MLE" begin
    using BayesBase, ExponentialFamily, Distributions
    using ExponentialFamilyProjection, StableRNGs

    # For every target distribution, loosening the norm bound of the
    # `BoundedNormUpdateRule` should monotonically improve (reduce) the KL
    # divergence of the MLE projection, and the improvement should be non-trivial.
    target_dists = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8))

    for target in target_dists
        seeded_rng = StableRNG(42)
        samples = rand(seeded_rng, target, 4000)

        bounds = [0.01, 0.1, 10.0]

        kls = map(bounds) do bound
            params = ProjectionParameters(
                direction = ExponentialFamilyProjection.BoundedNormUpdateRule(bound)
            )
            prj = ProjectedTo(ExponentialFamily.exponential_family_typetag(target), ()..., parameters = params)
            fitted = project_to(prj, samples)
            kldivergence(fitted, target)
        end

        @testset "true dist $(target)" begin
            # Divergence must decrease as the bound grows (descending order)
            @test issorted(kls, rev = true)
            # Relative improvement from tightest to loosest bound exceeds 5%
            @test (kls[1] - kls[end]) / kls[1] > 0.05
        end

    end
end

@testitem "MomentumGradient direction update rule on logpdf" begin
    using BayesBase, ExponentialFamily, Distributions
    using ExponentialFamilyProjection, ExponentialFamilyManifolds, Manopt, StableRNGs

    # Project a log-density onto the `MvNormalMeanCovariance` natural manifold,
    # driving the optimizer with a `MomentumGradient` direction update rule that
    # is passed through `ProjectionParameters`.
    true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
    logp = (x) -> logpdf(true_dist, x)

    manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)
    # Seed the RNG so the initial point — and hence the whole test — is
    # deterministic, matching the samples-based MomentumGradient test.
    rng = StableRNG(42)
    initialpoint = rand(rng, manifold)
    direction = MomentumGradient(manifold, initialpoint)

    momentum_parameters = ProjectionParameters(
        direction = direction,
        niterations = 1000,
        tolerance = 1e-8
    )

    projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = momentum_parameters)

    approximated = project_to(projection, logp, initialpoint = initialpoint)

    @test approximated isa MvNormalMeanCovariance
    @test kldivergence(approximated, true_dist) < 0.01
    @test projection.parameters.direction isa MomentumGradient
end

@testitem "MomentumGradient direction update rule on samples" begin
    using BayesBase, ExponentialFamily, Distributions
    using ExponentialFamilyProjection, ExponentialFamilyManifolds, Manopt, StableRNGs

    # Fit `MvNormalMeanCovariance` to samples drawn from a known 2-D Gaussian,
    # using a `MomentumGradient` direction update rule via `ProjectionParameters`.
    reference = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
    rng = StableRNG(42)
    observations = rand(rng, reference, 1000)

    manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)

    # Note: same RNG stream continues from the sample draw above
    start = rand(rng, manifold)
    momentum_rule = MomentumGradient(manifold, start)

    params = ProjectionParameters(
        direction = momentum_rule,
        niterations = 1000,
        tolerance = 1e-8
    )

    prj = ProjectedTo(MvNormalMeanCovariance, 2, parameters = params)
    fitted = project_to(prj, observations, initialpoint = start)

    @test fitted isa MvNormalMeanCovariance
    @test kldivergence(fitted, reference) < 0.01 # Ensure good approximation
    @test prj.parameters.direction isa MomentumGradient
end

0 comments on commit 85d7077

Please sign in to comment.