Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add extras #33

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/projected_to.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The following arguments are optional:

* `conditioner = nothing`: a conditioner to use for the projection, not all exponential family members require a conditioner, but some do, e.g. `Laplace`
* `parameters = DefaultProjectionParameters`: parameters for the projection procedure
* `kwargs = nothing`: Additional arguments passed to `Manopt.gradient_descent!` (optional). For details on `gradient_descent!` parameters, see the [Manopt.jl documentation](https://manoptjl.org/stable/solvers/gradient_descent/#Manopt.gradient_descent). Note, that `kwargs` passed to `project_to` take precedence over `kwargs` specified in the parameters.

```jldoctest
julia> using ExponentialFamily
Expand All @@ -33,28 +34,32 @@ julia> projected_to = ProjectedTo(Laplace, conditioner = 2.0)
ProjectedTo(Laplace, conditioner = 2.0)
```
"""
struct ProjectedTo{T,D,C,P}
struct ProjectedTo{T,D,C,P,E}
dims::D
conditioner::C
parameters::P
kwargs::E
end

ProjectedTo(
dims::Vararg{Int};
conditioner = nothing,
parameters = DefaultProjectionParameters(),
kwargs = nothing,
) = ProjectedTo(
ExponentialFamilyDistribution,
dims...,
conditioner = conditioner,
parameters = parameters,
kwargs = kwargs,
)
function ProjectedTo(
::Type{T},
dims...;
conditioner::C = nothing,
parameters::P = DefaultProjectionParameters(),
) where {T,C,P}
kwargs::E = nothing,
) where {T,C,P,E}
# Check that `dims` are all integers
if !all(d -> typeof(d) <: Int, dims)
# If not, throw an error, also suggesting to use keyword arguments
Expand All @@ -65,13 +70,14 @@ function ProjectedTo(
end
error(msg)
end
return ProjectedTo{T,typeof(dims),C,P}(dims, conditioner, parameters)
return ProjectedTo{T,typeof(dims),C,P,E}(dims, conditioner, parameters, kwargs)
end

get_projected_to_type(::ProjectedTo{T}) where {T} = T
get_projected_to_dims(prj::ProjectedTo) = prj.dims
get_projected_to_conditioner(prj::ProjectedTo) = prj.conditioner
get_projected_to_parameters(prj::ProjectedTo) = prj.parameters
get_projected_to_kwargs(prj::ProjectedTo) = prj.kwargs
get_projected_to_manifold(prj::ProjectedTo) =
ExponentialFamilyManifolds.get_natural_manifold(
get_projected_to_type(prj),
Expand Down Expand Up @@ -268,9 +274,19 @@ function project_to(
supplementary_η,
)

# First we query the `kwargs` defined in the `ProjectionParameters`
prj_kwargs = get_projected_to_kwargs(prj)
prj_kwargs = isnothing(prj_kwargs) ? (;) : prj_kwargs
# And attach the `kwargs` passed to `project_to`, those may override
# some settings in the `ProjectionParameters`
if !isnothing(kwargs)
prj_kwargs = (; prj_kwargs..., kwargs...)
end
# We disable the default `debug` statements, which are set in `Manopt`
# in order to improve the performance a little bit
kwargs = !haskey(kwargs, :debug) ? (; kwargs..., debug = missing) : kwargs
if !haskey(prj_kwargs, :debug)
prj_kwargs = (; prj_kwargs..., debug = missing)
end

return _kernel_project_to(
get_projected_to_type(prj),
Expand All @@ -281,7 +297,7 @@ function project_to(
strategy,
state,
current_η,
kwargs,
prj_kwargs,
)
end

Expand Down
28 changes: 22 additions & 6 deletions test/projection/helpers/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@

include("../projected_to_setuptests.jl")

function test_projection_with_debug(n_iterations, do_debug)
function test_projection_with_debug(n_iterations, do_debug, pass_to_prj = false)
distribution = Bernoulli(0.5)
buf = IOBuffer()
if do_debug
debug = [DebugCost(io=buf), DebugDivider("\n";io=buf)]
debug = [DebugCost(io = buf), DebugDivider("\n"; io = buf)]
else
debug = missing
end

projection_parameters = ProjectionParameters(niterations=n_iterations)
project_to(ProjectedTo(Bernoulli, parameters = projection_parameters), (x) -> logpdf(distribution, x); debug=debug)
projection_parameters = ProjectionParameters(niterations = n_iterations)
if pass_to_prj
prj = ProjectedTo(
Bernoulli,
parameters = projection_parameters,
kwargs = (debug = debug,),
)
project_to(prj, (x) -> logpdf(distribution, x))
else
prj = ProjectedTo(Bernoulli, parameters = projection_parameters)
project_to(prj, (x) -> logpdf(distribution, x); debug = debug)
end
debug_string = String(take!(buf))
if do_debug
lines = split(debug_string, '\n')
Expand All @@ -26,14 +36,20 @@
end

@testset "projections with debug" begin
for n in 1:10
for n = 1:10
test_projection_with_debug(n, true)
end
end

@testset "projections without debug" begin
for n in 1:10
for n = 1:10
test_projection_with_debug(n, false)
end
end

@testset "projections with debug passed through ProjectedTo" begin
for n = 1:10
test_projection_with_debug(n, true, true)
end
end
end
105 changes: 74 additions & 31 deletions test/projection/projected_to_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,35 +431,76 @@ end
# Do not produce debug output by default
@test_logs match_mode = :all project_to(prj, targetfn)
@test_logs match_mode = :all project_to(prj, targetfn, debug = [])


end

@testitem "kwargs in the `project_to` should take precedence over kwargs in `ProjectionParameters`" begin
using ExponentialFamilyProjection, StableRNGs, ExponentialFamily, Manopt, JET

@testset begin
rng = StableRNG(42)
prj = ProjectedTo(Beta; kwargs = (debug = missing,))
targetfn = (x) -> rand(rng) > 0.5 ? 1 : -1

@test_logs match_mode = :all project_to(prj, targetfn)
@test_logs (:warn, r"The cost increased.*") match_mode = :any project_to(
prj,
targetfn,
debug = [Manopt.DebugWarnIfCostIncreases()],
)
end

@testset begin
rng = StableRNG(42)
prj = ProjectedTo(Beta; kwargs = (debug = [Manopt.DebugWarnIfCostIncreases()],))
targetfn = (x) -> rand(rng) > 0.5 ? 1 : -1

@test_logs (:warn, r"The cost increased.*") match_mode = :any project_to(
prj,
targetfn,
)
@test_logs match_mode = :all project_to(prj, targetfn, debug = missing)
@test_logs match_mode = :all project_to(prj, targetfn, debug = [])
end

end

@testitem "Direction rule can improve for MLE" begin
@testitem "Direction rule can improve for MLE" begin
using BayesBase, ExponentialFamily, Distributions
using ExponentialFamilyProjection, StableRNGs

dists = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8))

dists = (
Beta(1, 1),
Gamma(10, 20),
Bernoulli(0.8),
NormalMeanVariance(-10, 0.1),
Poisson(4.8),
)

for dist in dists
rng = StableRNG(42)
data = rand(rng, dist, 4000)

norm_bounds = [0.01, 0.1, 10.0]

divergences = map(norm_bounds) do norm
parameters = ProjectionParameters(
direction = ExponentialFamilyProjection.BoundedNormUpdateRule(norm)
direction = ExponentialFamilyProjection.BoundedNormUpdateRule(norm),
)
projection = ProjectedTo(
ExponentialFamily.exponential_family_typetag(dist),
()...,
parameters = parameters,
)
projection = ProjectedTo(ExponentialFamily.exponential_family_typetag(dist), ()..., parameters = parameters)
approximated = project_to(projection, data)
kldivergence(approximated, dist)
end

@testset "true dist $(dist)" begin
@test issorted(divergences, rev=true)
@test issorted(divergences, rev = true)
@test (divergences[1] - divergences[end]) / divergences[1] > 0.05
end

end
end

Expand All @@ -471,20 +512,21 @@ end
true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
logp = (x) -> logpdf(true_dist, x)

manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)
manifold = ExponentialFamilyManifolds.get_natural_manifold(
MvNormalMeanCovariance,
(2,),
nothing,
)
initialpoint = rand(manifold)
direction = MomentumGradient(manifold, initialpoint)

momentum_parameters = ProjectionParameters(
direction = direction,
niterations = 1000,
tolerance = 1e-8
)
momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)

projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = momentum_parameters)

projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters)

approximated = project_to(projection, logp, initialpoint = initialpoint)

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01
@test projection.parameters.direction isa MomentumGradient
Expand All @@ -497,21 +539,22 @@ end
true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
rng = StableRNG(42)
samples = rand(rng, true_dist, 1000)

manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)


manifold = ExponentialFamilyManifolds.get_natural_manifold(
MvNormalMeanCovariance,
(2,),
nothing,
)

initialpoint = rand(rng, manifold)
direction = MomentumGradient(manifold, initialpoint)

momentum_parameters = ProjectionParameters(
direction = direction,
niterations = 1000,
tolerance = 1e-8
)

projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters)

momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)

projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = momentum_parameters)
approximated = project_to(projection, samples, initialpoint = initialpoint)

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation
@test projection.parameters.direction isa MomentumGradient
Expand Down
Loading