diff --git a/docs/src/index.md b/docs/src/index.md
index 52508d6..9108725 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -13,6 +13,7 @@ In order to project a log probability density function onto a member of the expo
 ```@docs
 ExponentialFamilyProjection.ProjectionParameters
 ExponentialFamilyProjection.DefaultProjectionParameters
+ExponentialFamilyProjection.getinitialpoint
 ```
 
 Read more about different optimization strategies [here](@ref opt-strategies).
@@ -33,9 +34,6 @@ The projection is performed by calling the `project_to` function with the specif
 ExponentialFamilyProjection.project_to
 ```
 
-!!! note
-    Different strategies are compatible with different types of arguments. Read [Optimization strategies](@ref opt-strategies) section for more information.
-
 ## [Optimization strategies](@id opt-strategies)
 
 The optimization procedure requires computing the expectation of the gradient to perform gradient descent in the natural parameters space. Currently, the library provides the following strategies for computing these expectations:
@@ -45,6 +43,10 @@ ExponentialFamilyProjection.DefaultStrategy
 ExponentialFamilyProjection.ControlVariateStrategy
 ExponentialFamilyProjection.MLEStrategy
 ExponentialFamilyProjection.preprocess_strategy_argument
+ExponentialFamilyProjection.create_state!
+ExponentialFamilyProjection.prepare_state!
+ExponentialFamilyProjection.compute_cost
+ExponentialFamilyProjection.compute_gradient!
 ```
 
 For high-dimensional distributions, adjusting the default number of samples might be necessary to achieve better performance.
@@ -154,6 +156,10 @@ plot!(0.0:0.01:1.0, x -> pdf(result, x), label="estimated projection", fill = 0,
 
 ## Manopt extensions
 
+```@docs
+ExponentialFamilyProjection.ProjectionCostGradientObjective
+```
+
 ### Bounded direction update rule
 
 The `ExponentialFamilyProjection.jl` package implements a specialized gradient direction rule that limits the norm (manifold-specific) of the gradient to a pre-specified value.
diff --git a/src/ExponentialFamilyProjection.jl b/src/ExponentialFamilyProjection.jl
index d1f14ef..58da1e4 100644
--- a/src/ExponentialFamilyProjection.jl
+++ b/src/ExponentialFamilyProjection.jl
@@ -22,19 +22,105 @@ __projection_fast_pack_parameters(t::NTuple{N,<:Number}) where {N} = t
 __projection_fast_pack_parameters(t) = ExponentialFamily.pack_parameters(t)
 
 include("manopt/bounded_norm_update_rule.jl")
-include("cvi.jl")
+include("manopt/projection_objective.jl")
+include("projected_to.jl")
 
 """
     preprocess_strategy_argument(strategy, argument)
 
-Checks the compatibility of `strategy` with `argument` and returns a modified strategy if needed.
+Checks the compatibility of `strategy` with `argument` and returns a modified strategy and argument if needed.
 """
 function preprocess_strategy_argument end
 
+"""
+    create_state!(
+        strategy,
+        M::AbstractManifold,
+        parameters::ProjectionParameters,
+        projection_argument,
+        initial_ef,
+        supplementary_η,
+    )
+
+Creates, initializes and returns a state for the `strategy` with the given parameters.
+"""
+function create_state! end
+
+"""
+    prepare_state!(
+        strategy,
+        state,
+        M::AbstractManifold,
+        parameters::ProjectionParameters,
+        projection_argument,
+        distribution,
+        supplementary_η,
+    )
+
+Prepares an existing `state` of the `strategy` for a new optimization iteration by setting or updating its internal parameters.
+"""
+function prepare_state! 
end
+
+"""
+    compute_cost(
+        M::AbstractManifold,
+        strategy,
+        state,
+        η,
+        logpartition,
+        gradlogpartition,
+        inv_fisher,
+    )
+
+Computes the cost using the provided `strategy`.
+
+# Arguments
+- `M::AbstractManifold`: The manifold on which the computations are performed.
+- `strategy`: The strategy used for computation of the cost value.
+- `state`: The current state for the `strategy`.
+- `η`: Parameter vector.
+- `logpartition`: The log partition of the current point (η).
+- `gradlogpartition`: The gradient of the log partition of the current point (η).
+- `inv_fisher`: The inverse Fisher information matrix of the current point (η).
+
+# Returns
+- `cost`: The computed cost value.
+"""
+function compute_cost end
+
+"""
+    compute_gradient!(
+        M::AbstractManifold,
+        strategy,
+        state,
+        X,
+        η,
+        logpartition,
+        gradlogpartition,
+        inv_fisher,
+    )
+
+Updates the gradient `X` in-place using the provided `strategy`.
+
+# Arguments
+- `M::AbstractManifold`: The manifold on which the computations are performed.
+- `strategy`: The strategy used for computation of the gradient value.
+- `state`: The current state for the `strategy`.
+- `X`: The storage for the gradient.
+- `η`: Parameter vector.
+- `logpartition`: The log partition of the current point (η).
+- `gradlogpartition`: The gradient of the log partition of the current point (η).
+- `inv_fisher`: The inverse Fisher information matrix of the current point (η).
+
+# Returns
+- `X`: The computed gradient (updated in-place).
+"""
+function compute_gradient! end
+
 include("strategies/control_variate.jl")
 include("strategies/mle.jl")
 include("strategies/default.jl")
 
-include("projected_to.jl")
+
 end
diff --git a/src/cvi.jl b/src/cvi.jl
deleted file mode 100644
index 2778976..0000000
--- a/src/cvi.jl
+++ /dev/null
@@ -1,64 +0,0 @@
-
-struct CVICostGradientObjective{F,P,S,B}
-    projection_argument::F
-    supplementary_η::P
-    strategy::S
-    buffer::B
-end
-
-get_cvi_projection_argument(obj::CVICostGradientObjective) = obj.projection_argument
-get_cvi_supplementary_η(obj::CVICostGradientObjective) = obj.supplementary_η
-get_cvi_strategy(obj::CVICostGradientObjective) = obj.strategy
-get_cvi_buffer(obj::CVICostGradientObjective) = obj.buffer
-
-function (objective::CVICostGradientObjective)(M::AbstractManifold, X, p)
-    ef = convert(ExponentialFamilyDistribution, M, p)
-
-    strategy = get_cvi_strategy(objective)
-    state = prepare_state!(
-        M,
-        strategy,
-        objective.projection_argument,
-        ef,
-        objective.supplementary_η,
-    )
-
-    logpartition = ExponentialFamily.logpartition(ef)
-    gradlogpartition = ExponentialFamily.gradlogpartition(ef)
-    inv_fisher = cholinv(ExponentialFamily.fisherinformation(ef))
-    η = copy(ExponentialFamily.getnaturalparameters(ef))
-
-    # If we have some supplementary natural parameters in the objective
-    # we must subtract them from the natural parameters of the current η
-    supplementary = get_cvi_supplementary_η(objective)
-    foreach(supplementary) do s_η
-        vmap!(-, η, η, s_η)
-    end
-
-    c = compute_cost(
-        M,
-        objective,
-        strategy,
-        state,
-        η,
-        logpartition,
-        gradlogpartition,
-        inv_fisher,
-    )
-    X = compute_gradient!(
-        M,
-        objective,
-        strategy,
-        state,
-        X,
-        η,
-        logpartition,
-        gradlogpartition,
-        inv_fisher,
-    )
-    X = project!(M, X, p, X)
-
-    return c, X
-end
-
-
diff --git a/src/manopt/projection_objective.jl b/src/manopt/projection_objective.jl
new file mode 100644
index 0000000..26b3222
--- /dev/null
+++ b/src/manopt/projection_objective.jl
@@ -0,0 +1,90 @@
+
+"""
+    ProjectionCostGradientObjective
+
+This structure provides an interface for `Manopt` to compute the cost and gradients required for the optimization procedure based on manifold projection. The actual computation of costs and gradients is defined by the `strategy` argument.
+
+# Arguments
+
+- `projection_parameters`: The parameters for projection, must be of type `ProjectionParameters`
+- `projection_argument`: The second argument of the `project_to` function.
+- `current_η`: Current optimization point.
+- `supplementary_η`: A tuple of additional natural parameters subtracted from the current point in each optimization iteration.
+- `strategy`: Specifies the method for computing costs and gradients, which may support different `projection_argument` values.
+- `strategy_state`: The state for the `strategy`, usually created with `create_state!`.
+
+!!! note
+    This structure is internal and is subject to change.
+"""
+struct ProjectionCostGradientObjective{J,F,C,P,S,T}
+    projection_parameters::J
+    projection_argument::F
+    current_η::C
+    supplementary_η::P
+    strategy::S
+    strategy_state::T
+end
+
+get_projection_parameters(obj::ProjectionCostGradientObjective) = obj.projection_parameters
+get_projection_argument(obj::ProjectionCostGradientObjective) = obj.projection_argument
+get_current_η(obj::ProjectionCostGradientObjective) = obj.current_η
+get_supplementary_η(obj::ProjectionCostGradientObjective) = obj.supplementary_η
+get_strategy(obj::ProjectionCostGradientObjective) = obj.strategy
+get_strategy_state(obj::ProjectionCostGradientObjective) = obj.strategy_state
+
+function (objective::ProjectionCostGradientObjective)(M::AbstractManifold, X, p)
+    current_η = copyto!(get_current_η(objective), p)
+    current_ef = convert(ExponentialFamilyDistribution, M, current_η)
+
+    strategy = get_strategy(objective)
+    state = get_strategy_state(objective)
+    projection_parameters = get_projection_parameters(objective)
+    projection_argument = get_projection_argument(objective)
+    supplementary_η = get_supplementary_η(objective)
+
+    state = prepare_state!(
+        strategy,
+        state,
+        M,
+        projection_parameters,
+        projection_argument,
+        current_ef,
+        supplementary_η,
+    )
+
+    logpartition = ExponentialFamily.logpartition(current_ef)
+    gradlogpartition = ExponentialFamily.gradlogpartition(current_ef)
+    inv_fisher = cholinv(ExponentialFamily.fisherinformation(current_ef))
+
+    # If we have some supplementary natural parameters in the objective
+    # we must subtract them from the natural parameters of the current point
+    foreach(supplementary_η) do s_η
+        vmap!(-, current_η, current_η, s_η)
+    end
+
+    c = compute_cost(
+        M,
+        strategy,
+        state,
+        current_η,
+        logpartition,
+        gradlogpartition,
+        inv_fisher,
+    )
+    X = compute_gradient!(
+        M,
+        strategy,
+        state,
+        X,
+        current_η,
+        logpartition,
+        gradlogpartition,
+        inv_fisher,
+    )
+    X = project!(M, X, p, X)
+
+    return c, X
+end
+
+
diff --git a/src/projected_to.jl b/src/projected_to.jl
index 2d9398f..36548e7 100644
--- a/src/projected_to.jl
+++ b/src/projected_to.jl
@@ -103,16 +103,18 @@ The following parameters are available:
 * `niterations = 100`: The number of iterations for the optimization procedure.
 * `tolerance = 1e-6`: The tolerance for the norm of the gradient.
 * `stepsize = ConstantStepsize(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
+* `seed = 42`: Optional; seed for the `rng`
+* `rng = StableRNG(seed)`: Optional; random number generator
 * `direction = BoundedNormUpdateRule(static(1.0))`: Direction update rule. 
Accepts `Manopt.DirectionUpdateRule` from `Manopt.jl`.
-* `usebuffer = Val(true)`: Whether to use a buffer for the projection. Must be either `Val(true)` or `Val(false)`. Disabling buffer can be useful for debugging purposes.
 """
-Base.@kwdef struct ProjectionParameters{S,I,T,P,D,B}
+Base.@kwdef struct ProjectionParameters{S,I,T,P,D,N,U}
     strategy::S = DefaultStrategy()
     niterations::I = 100
     tolerance::T = 1e-6
     stepsize::P = ConstantStepsize(0.1)
-    direction::D = BoundedNormUpdateRule(static(1.0))
-    usebuffer::B = Val(true)
+    seed::D = 42
+    rng::N = StableRNG(seed)
+    direction::U = BoundedNormUpdateRule(static(1.0))
 end
 
 """
@@ -126,23 +128,19 @@ getstrategy(parameters::ProjectionParameters) = parameters.strategy
 getniterations(parameters::ProjectionParameters) = parameters.niterations
 gettolerance(parameters::ProjectionParameters) = parameters.tolerance
 getstepsize(parameters::ProjectionParameters) = parameters.stepsize
+getseed(parameters::ProjectionParameters) = parameters.seed
+getrng(parameters::ProjectionParameters) = parameters.rng
 getdirection(parameters::ProjectionParameters) = parameters.direction
 
-with_buffer(f::F, parameters::ProjectionParameters) where {F} =
-    with_buffer(f, parameters.usebuffer, parameters)
-
-with_buffer(f::F, buffer, ::ProjectionParameters) where {F} = f(buffer)
-with_buffer(f::F, ::Val{false}, ::ProjectionParameters) where {F} = f(nothing)
-with_buffer(f::F, ::Val{true}, ::ProjectionParameters) where {F} =
-    let buffer = MallocSlabBuffer()
-        try
-            f(buffer)
-        catch exception
-            rethrow(exception)
-        finally
-            free(buffer)
-        end
-    end
+"""
+    getinitialpoint(strategy, M::AbstractManifold, parameters::ProjectionParameters)
+
+Returns an initial point to start optimization from. By default returns a random point drawn from `M`,
+but different strategies may implement their own methods.
+"""
+function getinitialpoint(::Any, M::AbstractManifold, parameters::ProjectionParameters)
+    return rand(getrng(parameters), M)
+end
 
 function Manopt.get_stopping_criterion(parameters::ProjectionParameters)
     return Manopt.get_stopping_criterion(
@@ -179,24 +177,24 @@ end
 
 using Manopt, StaticTools
 
 """
-    project_to(to::ProjectedTo, logf::F, supplementary..., initialpoint, kwargs...)
+    project_to(to::ProjectedTo, argument::F, supplementary..., initialpoint, kwargs...)
 
-Finds the closest projection of `logf` onto the exponential family distribution specified by `to`.
+Finds the closest projection of `argument` onto the exponential family distribution specified by `to`.
 
 # Arguments
 - `to::ProjectedTo`: Configuration for the projection. Refer to `ProjectedTo` for detailed information.
-- `logf::F`: An (un-normalized) function representing the log-PDF of an arbitrary distribution.
-- `supplementary...`: Additional distributions to project the product of `logf` and these distributions (optional).
+- `argument::F`: A function representing the (un-normalized) log-PDF of an arbitrary distribution, _or_ a collection of samples.
+- `supplementary...`: Additional distributions to project the product of `argument` and these distributions (optional).
 - `initialpoint`: Starting point for the optimization process (optional).
 - `kwargs...`: Additional arguments passed to `Manopt.gradient_descent!` (optional). For details on `gradient_descent!` parameters, see the [Manopt.jl documentation](https://manoptjl.org/stable/solvers/gradient_descent/#Manopt.gradient_descent).
 
 # Supplementary
 
 The `supplementary` distributions must match the type and conditioner of the target distribution specified in `to`. 
-Including supplementary distributions is equivalent to modified `logf` function as follows:
+Including supplementary distributions is equivalent to modifying the `argument` function as follows:
 
 ```julia
-f_modified = (x) -> logf(x) + logpdf(supplementary[1], x) + logpdf(supplementary[2], x) + ...
+f_modified = (x) -> argument(x) + logpdf(supplementary[1], x) + logpdf(supplementary[2], x) + ...
 ```
 
 ```jldoctest
@@ -210,6 +208,21 @@ ProjectedTo(Beta)
 julia> project_to(prj, f) isa ExponentialFamily.Beta
 true
 ```
+
+```jldoctest
+julia> using ExponentialFamily, BayesBase, StableRNGs
+
+julia> samples = rand(StableRNG(42), Beta(30.14, 2.71), 1_000);
+
+julia> prj = ProjectedTo(Beta; parameters = ProjectionParameters(tolerance = 1e-2))
+ProjectedTo(Beta)
+
+julia> project_to(prj, samples) isa ExponentialFamily.Beta
+true
+```
+
+!!! note
+    Different strategies are compatible with different types of arguments. Read the optimization strategies section in the documentation for more information.
 """
 function project_to(
     prj::ProjectedTo,
@@ -219,7 +232,7 @@ function project_to(
     kwargs...,
 ) where {F}
     M = get_projected_to_manifold(prj)
-    parameters = get_projected_to_parameters(prj)
+    projection_parameters = get_projected_to_parameters(prj)
 
     # "Supplementary" natural parameters are parameters that are simply being subtracted
     # from the natural parameters of the current estimated distribution. This might be useful
@@ -239,36 +252,37 @@ function project_to(
         return copy(getnaturalparameters(supplementary_ef))
     end
 
-    _strategy = preprocess_strategy_argument(getstrategy(parameters), projection_argument)
-    p = preprocess_initialpoint(initialpoint, M, _strategy)
+    strategy, projection_argument = preprocess_strategy_argument(
+        getstrategy(projection_parameters),
+        projection_argument,
+    )
+    current_η = preprocess_initialpoint(initialpoint, strategy, M, projection_parameters)
+    current_ef = convert(ExponentialFamilyDistribution, M, current_η)
 
-    _state = prepare_state!(
+    state = create_state!(
+        strategy,
         M,
-        _strategy,
+        projection_parameters,
         projection_argument,
-        convert(ExponentialFamilyDistribution, M, p),
+        current_ef,
         supplementary_η,
     )
 
-    strategy = with_state(_strategy, _state)
-
     # We disable the default `debug` statements, which are set in `Manopt`
     # in order to improve the performance a little bit
     kwargs = !haskey(kwargs, :debug) ? (; kwargs..., debug = missing) : kwargs
 
-    return with_buffer(parameters) do buffer
-        return _kernel_project_to(
-            get_projected_to_type(prj),
-            M,
-            projection_argument,
-            supplementary_η,
-            strategy,
-            buffer,
-            parameters,
-            p,
-            kwargs,
-        )
-    end
+    return _kernel_project_to(
+        get_projected_to_type(prj),
+        M,
+        projection_parameters,
+        projection_argument,
+        supplementary_η,
+        strategy,
+        state,
+        current_η,
+        kwargs,
+    )
 end
 
 # see https://docs.julialang.org/en/v1/manual/performance-tips/#kernel-functions
@@ -276,26 +290,34 @@ end
 function _kernel_project_to(
     ::Type{T},
     M,
+    projection_parameters,
     projection_argument,
     supplementary_η,
     strategy,
-    buffer,
-    parameters,
-    p,
+    state,
+    current_η,
     kwargs,
 ) where {T}
-    g_grad_g! =
-        CVICostGradientObjective(projection_argument, supplementary_η, strategy, buffer)
+    g_grad_g! 
= ProjectionCostGradientObjective(
+        projection_parameters,
+        projection_argument,
+        copy(current_η),
+        supplementary_η,
+        strategy,
+        state,
+    )
     objective = ManifoldCostGradientObjective(g_grad_g!; evaluation = InplaceEvaluation())
 
-    q = p # `gradient_descent!` overrides `q`
+    # `gradient_descent!` is a type-unstable call, so it is better not to write `q = gradient_descent!(...)`
+    # `gradient_descent!` overrides `q` in-place instead
+    q = current_η
 
     gradient_descent!(
         M,
         objective,
-        p;
-        stopping_criterion = get_stopping_criterion(parameters),
-        stepsize = getstepsize(parameters),
-        direction = getdirection(parameters),
+        current_η;
+        stopping_criterion = get_stopping_criterion(projection_parameters),
+        stepsize = getstepsize(projection_parameters),
+        direction = getdirection(projection_parameters),
         kwargs...,
     )
 
@@ -304,35 +326,44 @@ end
 
 # This function preprocesses the initial point for the projection
 # If the initial point is not provided, it generates a new one with the `getinitialpoint` function
-function preprocess_initialpoint(initialpoint::Nothing, M, strategy)
-    return getinitialpoint(strategy, M)
+function preprocess_initialpoint(initialpoint::Nothing, strategy, M, parameters)
+    return getinitialpoint(strategy, M, parameters)
 end
 
-function preprocess_initialpoint(initialpoint::Any, M, strategy)
+function preprocess_initialpoint(initialpoint::Any, strategy, M, parameters)
     return preprocess_initialpoint(
         ExponentialFamily.exponential_family_typetag(M),
         initialpoint,
-        M,
         strategy,
+        M,
+        parameters,
     )
 end
 
 # If the initial point is provided as the distribution type which we project onto,
 # we generate a new initial point using the `naturalparameters` of the distribution
-function preprocess_initialpoint(::Type{T}, initialpoint::T, M, strategy) where {T}
+function preprocess_initialpoint(
+    ::Type{T},
+    initialpoint::T,
+    strategy,
+    M,
+    parameters,
+) where {T}
     return preprocess_initialpoint(
         T,
         convert(ExponentialFamilyDistribution, initialpoint),
-        M,
         strategy,
+        M,
+        parameters,
     )
 end
 
 function preprocess_initialpoint(
     ::Type{T},
     initialpoint::ExponentialFamilyDistribution{T},
-    M,
     strategy,
+    M,
+    parameters,
 ) where {T}
     return ExponentialFamilyManifolds.partition_point(
         M,
         getnaturalparameters(initialpoint),
     )
 end
 
 # Otherwise we just copy the initial point, since we use it for the optimization in place
-function preprocess_initialpoint(_, initialpoint::AbstractArray, M, strategy)
+function preprocess_initialpoint(_, initialpoint::AbstractArray, strategy, M, parameters)
     return copy(initialpoint)
 end
\ No newline at end of file
diff --git a/src/strategies/control_variate.jl b/src/strategies/control_variate.jl
index 10fe88c..39c2c68 100644
--- a/src/strategies/control_variate.jl
+++ b/src/strategies/control_variate.jl
@@ -1,5 +1,6 @@
-using StableRNGs, LoopVectorization, Bumper, FillArrays
+using StableRNGs, LoopVectorization, Bumper, FillArrays, StaticTools
+import Random: AbstractRNG
 
 import BayesBase: InplaceLogpdf
 
 """
@@ -9,66 +10,29 @@ A strategy for gradient descent optimization and gradients computations that res
 The following parameters are available:
 
 * `nsamples = 2000`: The number of samples to use for estimates
-* `seed = 42`: The seed for the random number generator
-* `rng = StableRNG(seed)`: The random number generator
+* `buffer = StaticTools.MallocSlabBuffer()`: Advanced option; a buffer for temporary computations
 
 !!! note
     This strategy requires a function as an argument for `project_to` and cannot project a collection of samples. 
Use `MLEStrategy` to project a collection of samples.
 """
-Base.@kwdef struct ControlVariateStrategy{S,D,N,T}
+Base.@kwdef struct ControlVariateStrategy{S,B}
     nsamples::S = 2000
-    seed::D = 42
-    rng::N = StableRNG(seed)
-    state::T = nothing
+    buffer::B = StaticTools.MallocSlabBuffer()
 end
 
-getnsamples(strategy::ControlVariateStrategy) = strategy.nsamples
-getseed(strategy::ControlVariateStrategy) = strategy.seed
-getrng(strategy::ControlVariateStrategy) = strategy.rng
-getstate(strategy::ControlVariateStrategy) = strategy.state
+get_nsamples(strategy::ControlVariateStrategy) = strategy.nsamples
+get_buffer(strategy::ControlVariateStrategy) = strategy.buffer
 
 function Base.:(==)(a::ControlVariateStrategy, b::ControlVariateStrategy)::Bool
-    return getnsamples(a) == getnsamples(b) &&
-           getseed(a) == getseed(b) &&
-           getrng(a) == getrng(b) &&
-           getstate(a) == getstate(b)
+    return get_nsamples(a) == get_nsamples(b) && get_buffer(a) == get_buffer(b)
 end
 
-function getinitialpoint(strategy::ControlVariateStrategy, M::AbstractManifold)
-    return rand(getrng(strategy), M)
-end
-
-function with_state(strategy::ControlVariateStrategy, state)
-    return ControlVariateStrategy(
-        nsamples = getnsamples(strategy),
-        seed = getseed(strategy),
-        rng = getrng(strategy),
-        state = state,
-    )
-end
-
-preprocess_strategy_argument(strategy::ControlVariateStrategy, argument::Any) = strategy
+preprocess_strategy_argument(strategy::ControlVariateStrategy, argument::Any) =
+    (strategy, convert(InplaceLogpdf, argument))
 preprocess_strategy_argument(::ControlVariateStrategy, argument::AbstractArray) = error(
     lazy"The `ControlVariateStrategy` requires the projection argument to be a callable object (e.g. `Function`). Got `$(typeof(argument))` instead.",
 )
 
-function prepare_state!(
-    M::AbstractManifold,
-    strategy::ControlVariateStrategy,
-    projection_argument::F,
-    distribution,
-    supplementary_η,
-) where {F}
-    return prepare_state!(
-        M,
-        getstate(strategy),
-        strategy,
-        convert(InplaceLogpdf, projection_argument),
-        distribution,
-        supplementary_η,
-    )
-end
-
 Base.@kwdef struct ControlVariateStrategyState{M,L,LB,F,G}
     samples::M
     logpdfs::L
@@ -85,33 +49,32 @@ function Base.:(==)(a::ControlVariateStrategyState, b::ControlVariateStrategySta
         a.gradsamples == b.gradsamples
 end
 
-getsamples(state::ControlVariateStrategyState) = state.samples
-getlogpdfs(state::ControlVariateStrategyState) = state.logpdfs
-getlogbasemeasures(state::ControlVariateStrategyState) = state.logbasemeasures
-getsufficientstatistics(state::ControlVariateStrategyState) = state.sufficientstatistics
-getgradsamples(state::ControlVariateStrategyState) = state.gradsamples
+get_samples(state::ControlVariateStrategyState) = state.samples
+get_logpdfs(state::ControlVariateStrategyState) = state.logpdfs
+get_logbasemeasures(state::ControlVariateStrategyState) = state.logbasemeasures
+get_sufficientstatistics(state::ControlVariateStrategyState) = state.sufficientstatistics
+get_gradsamples(state::ControlVariateStrategyState) = state.gradsamples
 
-function prepare_state!(
-    M::AbstractManifold,
-    ::Nothing,
+function create_state!(
     strategy::ControlVariateStrategy,
-    projection_argument::InplaceLogpdf,
-    distribution,
+    M::AbstractManifold,
+    parameters::ProjectionParameters,
+    projection_argument,
+    initial_ef,
     supplementary_η,
 )
     # Create new containers for the samples, logpdfs, etc. 
- nsamples = getnsamples(strategy) - rng = getrng(strategy) - samples = prepare_samples_container(rng, distribution, nsamples, supplementary_η) - logpdfs = prepare_logpdfs_container(rng, distribution, nsamples, supplementary_η) + nsamples = get_nsamples(strategy) + rng = getrng(parameters) + samples = prepare_samples_container(rng, initial_ef, nsamples, supplementary_η) + logpdfs = prepare_logpdfs_container(rng, initial_ef, nsamples, supplementary_η) logbasemeasures = - prepare_logbasemeasures_container(rng, distribution, nsamples, supplementary_η) + prepare_logbasemeasures_container(rng, initial_ef, nsamples, supplementary_η) sufficientstatistics = - prepare_sufficientstatistics_container(rng, distribution, nsamples, supplementary_η) - gradsamples = - prepare_gradsamples_container(rng, distribution, nsamples, supplementary_η) + prepare_sufficientstatistics_container(rng, initial_ef, nsamples, supplementary_η) + gradsamples = prepare_gradsamples_container(rng, initial_ef, nsamples, supplementary_η) state = ControlVariateStrategyState( samples = samples, @@ -122,11 +85,12 @@ function prepare_state!( ) return prepare_state!( - M, - state, strategy, + state, + M, + parameters, projection_argument, - distribution, + initial_ef, supplementary_η, ) end @@ -177,31 +141,33 @@ prepare_logbasemeasures_container( ) = zeros(paramfloattype(distribution), nsamples) function prepare_state!( - M::AbstractManifold, - state::ControlVariateStrategyState, strategy::ControlVariateStrategy, - projection_argument::InplaceLogpdf, - distribution, + state::ControlVariateStrategyState, + M::AbstractManifold, + parameters::ProjectionParameters, + projection_argument, + current_ef, supplementary_η, ) # We need to reset the RNG state every time we prepare the state # This is important not only for reproducibility, but also to ensure # that the gradient computation is stable - Random.seed!(getrng(strategy), getseed(strategy)) - Random.rand!(getrng(strategy), distribution, state.samples) + Random.seed!(getrng(parameters), getseed(parameters)) + Random.rand!(getrng(parameters), current_ef, get_samples(state)) - _, sample_container = ExponentialFamily.check_logpdf(distribution, state.samples) + _, sample_container = ExponentialFamily.check_logpdf(current_ef, get_samples(state)) - glogpartion = ExponentialFamily.gradlogpartition(distribution) - J = size(state.gradsamples, 1) + glogpartion = ExponentialFamily.gradlogpartition(current_ef) + J = size(get_gradsamples(state), 1) - projection_argument(state.logpdfs, sample_container) + inplace_projection_argument = convert(BayesBase.InplaceLogpdf, projection_argument) + inplace_projection_argument(get_logpdfs(state), sample_container) one_minus_n_of_supplementary = 1 - length(supplementary_η) nonconstantbasemeasure = - ExponentialFamily.isbasemeasureconstant(distribution) === NonConstantBaseMeasure() + ExponentialFamily.isbasemeasureconstant(current_ef) === NonConstantBaseMeasure() foreach(enumerate(sample_container)) do (i, sample) # if `basemeasure` is constant we assume that @@ -209,11 +175,11 @@ function prepare_state!( if nonconstantbasemeasure @inbounds state.logbasemeasures[i] = one_minus_n_of_supplementary * - ExponentialFamily.logbasemeasure(distribution, sample) + ExponentialFamily.logbasemeasure(current_ef, sample) end sufficientstatistics = __projection_fast_pack_parameters( - ExponentialFamily.sufficientstatistics(distribution, sample), + ExponentialFamily.sufficientstatistics(current_ef, sample), ) @inbounds logpdf = state.logpdfs[i] @@ -230,7 +196,6 @@ end function 
compute_cost(
     M::AbstractManifold,
-    obj::CVICostGradientObjective,
     strategy::ControlVariateStrategy,
     state::ControlVariateStrategyState,
     η,
@@ -244,7 +209,6 @@ end
 
 function compute_gradient!(
     M::AbstractManifold,
-    obj::CVICostGradientObjective,
     strategy::ControlVariateStrategy,
     state::ControlVariateStrategyState,
     X,
@@ -253,10 +217,9 @@ function compute_gradient!(
     gradlogpartition,
     inv_fisher,
 )
-    buffer = get_cvi_buffer(obj)
+    buffer = get_buffer(strategy)
     if isnothing(buffer)
         return control_variate_compute_gradient!(
-            obj,
             strategy,
             state,
             X,
@@ -267,7 +230,6 @@ function compute_gradient!(
         )
     else
         return control_variate_compute_gradient_buffered!(
-            obj,
             strategy,
             state,
             X,
@@ -280,7 +242,6 @@ end
 
 function control_variate_compute_gradient!(
-    obj::CVICostGradientObjective,
     strategy::ControlVariateStrategy,
     state::ControlVariateStrategyState,
     X,
@@ -303,7 +264,6 @@ end
 
 function control_variate_compute_gradient_buffered!(
-    obj::CVICostGradientObjective,
     strategy::ControlVariateStrategy,
     state::ControlVariateStrategyState,
     X,
@@ -315,8 +275,8 @@ function control_variate_compute_gradient_buffered!(
     # This code is a bit involved, more comments are added
     # The `@no_escape` macro simplifies writing non-allocating code, it allows
     # to create intermediate buffers which will be freed immediately upon exiting the block
-    # uses the buffer from `get_cvi_buffer(obj)` so buffer must be relatively big
-    buffer = get_cvi_buffer(obj)
+    # uses the buffer from `get_buffer(strategy)` so the buffer must be relatively big
+    buffer = get_buffer(strategy)
 
     @no_escape buffer begin
         # First we compute the `cov` between `state.sufficientstatistics'` and `state.gradsamples'`
diff --git a/src/strategies/default.jl b/src/strategies/default.jl
index a78ea18..099562f 100644
--- a/src/strategies/default.jl
+++ b/src/strategies/default.jl
@@ -12,5 +12,7 @@ Rules:
 """
 struct DefaultStrategy end
 
-preprocess_strategy_argument(::DefaultStrategy, argument::AbstractArray) = MLEStrategy()
-preprocess_strategy_argument(::DefaultStrategy, argument::Any) = ControlVariateStrategy()
\ No newline at end of file
+preprocess_strategy_argument(::DefaultStrategy, argument::AbstractArray) =
+    preprocess_strategy_argument(MLEStrategy(), argument)
+preprocess_strategy_argument(::DefaultStrategy, argument::Any) =
+    preprocess_strategy_argument(ControlVariateStrategy(), argument)
\ No newline at end of file
diff --git a/src/strategies/mle.jl b/src/strategies/mle.jl
index c8af7d0..5fb28e5 100644
--- a/src/strategies/mle.jl
+++ b/src/strategies/mle.jl
@@ -1,61 +1,21 @@
 using ForwardDiff, LoopVectorization
 
 """
-    MLEStrategy(; kwargs...)
+    MLEStrategy()
 
 A strategy for gradient descent optimization and gradient computations that resembles MLE estimation.
 
-The following parameters are available:
-* `seed = 42`: The seed for the random number generator
-* `rng = StableRNG(seed)`: The random number generator
-
 !!! note
     This strategy requires a collection of samples as an argument for `project_to` and cannot project a function. Use `ControlVariateStrategy` to project a function. 
""" -Base.@kwdef struct MLEStrategy{D,N,T} - seed::D = 42 - rng::N = StableRNG(seed) - state::T = nothing -end - -getseed(strategy::MLEStrategy) = strategy.seed -getrng(strategy::MLEStrategy) = strategy.rng -getstate(strategy::MLEStrategy) = strategy.state - -function Base.:(==)(a::MLEStrategy, b::MLEStrategy)::Bool - return getseed(a) == getseed(b) && getrng(a) == getrng(b) && getstate(a) == getstate(b) -end - -function getinitialpoint(strategy::MLEStrategy, M::AbstractManifold) - return rand(getrng(strategy), M) -end - -function with_state(strategy::MLEStrategy, state) - return MLEStrategy(seed = getseed(strategy), rng = getrng(strategy), state = state) -end +struct MLEStrategy end -preprocess_strategy_argument(strategy::MLEStrategy, argument::AbstractArray) = strategy +preprocess_strategy_argument(strategy::MLEStrategy, argument::AbstractArray) = + (strategy, argument) preprocess_strategy_argument(::MLEStrategy, argument::Any) = error( lazy"`MLEStrategy` requires the projection argument to be an array of samples. Got `$(typeof(argument))` instead.", ) -function prepare_state!( - M::AbstractManifold, - strategy::MLEStrategy, - projection_argument::S, - distribution, - supplementary_η, -) where {S} - return prepare_state!( - M, - getstate(strategy), - strategy, - projection_argument, - distribution, - supplementary_η, - ) -end - Base.@kwdef struct MLEStrategyState{F,C,G} targetfn::F config::C @@ -70,20 +30,20 @@ gettargetfn(state::MLEStrategyState) = state.targetfn getconfig(state::MLEStrategyState) = state.config gettmpgrad(state::MLEStrategyState) = state.tmpgrad -function prepare_state!( - M::AbstractManifold, - ::Nothing, +function create_state!( strategy::MLEStrategy, - samples, - distribution, + M::AbstractManifold, + parameters::ProjectionParameters, + samples::AbstractArray, + initial_ef, supplementary_η, ) - _, sample_container = ExponentialFamily.check_logpdf(distribution, samples) + _, sample_container = ExponentialFamily.check_logpdf(initial_ef, samples) # Our samples are fixed, thus we can precompute all the `sufficientstatistics` once sufficientstatistics = zeros( - paramfloattype(distribution), - length(getnaturalparameters(distribution)), + paramfloattype(initial_ef), + length(getnaturalparameters(initial_ef)), length(samples), ) @@ -91,7 +51,7 @@ function prepare_state!( foreach(enumerate(sample_container)) do (i, sample) sample_sufficientstatistics = __projection_fast_pack_parameters( - ExponentialFamily.sufficientstatistics(distribution, sample), + ExponentialFamily.sufficientstatistics(initial_ef, sample), ) @turbo warn_check_args = false for j = 1:J @inbounds sufficientstatistics[j, i] = sample_sufficientstatistics[j] @@ -99,17 +59,18 @@ function prepare_state!( end targetfn = MLETargetFn(M, samples, sufficientstatistics) - config = ForwardDiff.GradientConfig(targetfn, getnaturalparameters(distribution)) - tmpgrad = ForwardDiff.gradient(targetfn, getnaturalparameters(distribution), config) + config = ForwardDiff.GradientConfig(targetfn, getnaturalparameters(initial_ef)) + tmpgrad = ForwardDiff.gradient(targetfn, getnaturalparameters(initial_ef), config) return MLEStrategyState(targetfn, config, tmpgrad) end function prepare_state!( - M::AbstractManifold, - state::MLEStrategyState, strategy::MLEStrategy, - samples, - distribution, + state::MLEStrategyState, + M::AbstractManifold, + parameters::ProjectionParameters, + samples::AbstractArray, + current_ef, supplementary_η, ) return state @@ -140,7 +101,6 @@ end function compute_cost( M::AbstractManifold, - 
obj::CVICostGradientObjective, strategy::MLEStrategy, state::MLEStrategyState, η, @@ -153,7 +113,6 @@ end function compute_gradient!( M::AbstractManifold, - obj::CVICostGradientObjective, strategy::MLEStrategy, state::MLEStrategyState, X, diff --git a/test/projection/projected_to_setuptests.jl b/test/projection/projected_to_setuptests.jl index 62e45ed..9fb71f4 100644 --- a/test/projection/projected_to_setuptests.jl +++ b/test/projection/projected_to_setuptests.jl @@ -102,11 +102,11 @@ function test_convergence_nsamples( parameters = ProjectionParameters( strategy = ExponentialFamilyProjection.ControlVariateStrategy( nsamples = nsamples, - seed = rand(nsamples_rng, UInt), ), niterations = nsamples_niterations, tolerance = nsamples_tolerance, stepsize = nsamples_stepsize, + seed = rand(nsamples_rng, UInt), ) projection = ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) @@ -177,11 +177,11 @@ function test_convergence_niterations( parameters = ProjectionParameters( strategy = ExponentialFamilyProjection.ControlVariateStrategy( nsamples = niterations_nsamples, - seed = rand(niterations_rng, UInt), ), niterations = niterations, tolerance = niterations_tolerance, stepsize = niterations_stepsize, + seed = rand(niterations_rng, UInt), ) projection = ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) @@ -226,12 +226,11 @@ function test_convergence_niterations_mle( experiment = map(niterations_range) do niterations data = rand(niterations_rng, distribution, niterations_nsamples) parameters = ProjectionParameters( - strategy = ExponentialFamilyProjection.MLEStrategy( - seed = rand(niterations_rng, UInt), - ), + strategy = ExponentialFamilyProjection.MLEStrategy(), niterations = niterations, tolerance = niterations_tolerance, stepsize = niterations_stepsize, + seed = rand(niterations_rng, UInt), ) projection = ProjectedTo(T, dims..., parameters = parameters, conditioner = conditioner) @@ -276,12 +275,11 @@ function test_convergence_nsamples_mle( experiment = map(nsamples_range) do nsamples data = rand(nsamples_rng, distribution, nsamples) parameters = ProjectionParameters( - strategy = ExponentialFamilyProjection.MLEStrategy( - seed = rand(nsamples_rng, UInt), - ), + strategy = ExponentialFamilyProjection.MLEStrategy(), niterations = nsamples_niterations, tolerance = nsamples_tolerance, stepsize = nsamples_stepsize, + seed = rand(nsamples_rng, UInt), ) projection = diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl index 7904b80..96a85ef 100644 --- a/test/projection/projected_to_tests.jl +++ b/test/projection/projected_to_tests.jl @@ -121,71 +121,6 @@ end end end -@testitem "ProjectionParameters usebuffer" begin - using Bumper - import ExponentialFamilyProjection: getstepsize, with_buffer - - parameters = ProjectionParameters(usebuffer = Val(true)) - - result = with_buffer(parameters) do buffer - @test buffer !== nothing - @no_escape buffer begin - container = @alloc(Float64, 10) - @test length(container) === 10 - end - return "asd" - end - @test result == "asd" - - parameters = ProjectionParameters(usebuffer = Val(false)) - - result = with_buffer(parameters) do buffer - @test buffer === nothing - return "dsa" - end - @test result == "dsa" - -end - -@testitem "Projection result should not depend on the usage of buffer" begin - using ExponentialFamily, BayesBase - distributions = [ - Beta(10, 10), - Gamma(10, 10), - Exponential(1), - LogNormal(0, 1), - Dirichlet([1, 1]), - NormalMeanVariance(0.0, 1.0), - 
MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0]), - Chisq(30.0) - ] - - for distribution in distributions - parameters_with_buffer = ProjectionParameters(usebuffer = Val(true)) - parameters_without_buffer = ProjectionParameters(usebuffer = Val(false)) - - dims = size(rand(distribution)) - - prj_with_buffer = ProjectedTo( - ExponentialFamily.exponential_family_typetag(distribution), - dims...; - parameters = parameters_with_buffer, - ) - prj_without_buffer = ProjectedTo( - ExponentialFamily.exponential_family_typetag(distribution), - dims...; - parameters = parameters_without_buffer, - ) - - targetfn = (x) -> logpdf(distribution, x) - result_with_buffer = project_to(prj_with_buffer, targetfn) - result_without_buffer = project_to(prj_without_buffer, targetfn) - - # Small differences are allowed due to different LinearAlgebra routines - @test result_with_buffer ≈ result_without_buffer - end -end - @testitem "Projection should support supplementary natural parameters" begin using ExponentialFamily, StableRNGs, BayesBase, Distributions @@ -193,9 +128,10 @@ end prj = ProjectedTo( Beta, parameters = ProjectionParameters( + strategy = ExponentialFamilyProjection.ControlVariateStrategy(), tolerance = 1e-4, niterations = 300, - strategy = ExponentialFamilyProjection.ControlVariateStrategy(rng = rng), + rng = rng, ), ) diff --git a/test/strategies/control_variate_tests.jl b/test/strategies/control_variate_tests.jl index 16c9fed..cb99afe 100644 --- a/test/strategies/control_variate_tests.jl +++ b/test/strategies/control_variate_tests.jl @@ -1,86 +1,74 @@ @testitem "ControlVariateStrategy generic properties" begin using Random, - Bumper, LinearAlgebra, Distributions, ExponentialFamily, ExponentialFamilyManifolds + BayesBase, + Bumper, + LinearAlgebra, + Distributions, + ExponentialFamily, + ExponentialFamilyManifolds import ExponentialFamilyProjection: ControlVariateStrategy, - getnsamples, - getseed, - getrng, - getbuffer, - getstate, + ProjectionParameters, + get_nsamples, + get_buffer, + create_state!, prepare_state! 
- @test ControlVariateStrategy() == ControlVariateStrategy() - @test ControlVariateStrategy(nsamples = 100) == ControlVariateStrategy(nsamples = 100) - @test ControlVariateStrategy(seed = 42) == ControlVariateStrategy(seed = 42) - @test ControlVariateStrategy(rng = MersenneTwister(42)) == - ControlVariateStrategy(rng = MersenneTwister(42)) - - @test ControlVariateStrategy(nsamples = 50) !== ControlVariateStrategy(nsamples = 100) - @test ControlVariateStrategy(seed = 41) !== ControlVariateStrategy(seed = 42) - @test ControlVariateStrategy(rng = MersenneTwister(41)) !== - ControlVariateStrategy(rng = MersenneTwister(42)) + @test ControlVariateStrategy() !== ControlVariateStrategy() # buffers are different + @test ControlVariateStrategy(nsamples = 100, buffer = nothing) == + ControlVariateStrategy(nsamples = 100, buffer = nothing) + buffer = Bumper.default_buffer() + @test ControlVariateStrategy(nsamples = 100, buffer = buffer) == + ControlVariateStrategy(nsamples = 100, buffer = buffer) + @test ControlVariateStrategy(nsamples = 50, buffer = nothing) !== + ControlVariateStrategy(nsamples = 100, buffer = nothing) + @test ControlVariateStrategy(nsamples = 50, buffer = buffer) !== + ControlVariateStrategy(nsamples = 100, buffer = buffer) @testset "nsamples" begin strategy = ControlVariateStrategy(nsamples = 100) - @test getnsamples(strategy) === 100 + @test get_nsamples(strategy) === 100 strategy = ControlVariateStrategy(nsamples = 200) - @test getnsamples(strategy) === 200 + @test get_nsamples(strategy) === 200 end - @testset "seed" begin - strategy = ControlVariateStrategy(seed = 42) + @testset "buffer" begin + strategy = ControlVariateStrategy(buffer = Bumper.default_buffer()) - @test getseed(strategy) === 42 + @test get_buffer(strategy) === Bumper.default_buffer() - strategy = ControlVariateStrategy(seed = 24) + strategy = ControlVariateStrategy(buffer = nothing) - @test getseed(strategy) === 24 + @test get_buffer(strategy) === nothing end - @testset "rng" begin - rng1 = MersenneTwister(42) - rng2 = MersenneTwister(24) - strategy = ControlVariateStrategy(rng = rng1) - - @test getrng(strategy) === rng1 - @test getrng(strategy) !== rng2 - - strategy = ControlVariateStrategy(rng = rng2) - - @test getrng(strategy) !== rng1 - @test getrng(strategy) === rng2 - end - - @testset "state" begin + @testset "create_state!" 
begin distributions = [Beta(5, 5), Chisq(10)] + parameters = ProjectionParameters() for dist in distributions ef = convert(ExponentialFamilyDistribution, Beta(5, 5)) T = ExponentialFamily.exponential_family_typetag(ef) d = size(mean(ef)) c = getconditioner(ef) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) - state1 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, ()) - state2 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, ()) + arg = (x) -> 1 + state1 = create_state!(ControlVariateStrategy(), M, parameters, arg, ef, ()) + state2 = create_state!(ControlVariateStrategy(), M, parameters, arg, ef, ()) @test state1 == state2 - @test ControlVariateStrategy(state = state1) == - ControlVariateStrategy(state = state2) - state1 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, (ef,)) - state2 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, (ef,)) + state1 = create_state!(ControlVariateStrategy(), M, parameters, arg, ef, (ef,)) + state2 = create_state!(ControlVariateStrategy(), M, parameters, arg, ef, (ef,)) @test state1 == state2 - state1 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, (ef, ef)) - state2 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, (ef, ef)) + state1 = + create_state!(ControlVariateStrategy(), M, parameters, arg, ef, (ef, ef)) + state2 = + create_state!(ControlVariateStrategy(), M, parameters, arg, ef, (ef, ef)) @test state1 == state2 - - state1 = prepare_state!(M, ControlVariateStrategy(), (x) -> 1, ef, ()) - state2 = prepare_state!(M, ControlVariateStrategy(), (x) -> 2, ef, ()) - @test state1 != state2 end end end @@ -96,12 +84,13 @@ end import ExponentialFamilyProjection: ControlVariateStrategy, ControlVariateStrategyState, + create_state!, prepare_state!, - getsamples, - getlogpdfs, - getlogbasemeasures, - getsufficientstatistics, - getgradsamples + get_samples, + get_logpdfs, + get_logbasemeasures, + get_sufficientstatistics, + get_gradsamples dists = [ NormalMeanVariance(0, 1), @@ -112,6 +101,7 @@ end ] for dist in dists + targetfn1 = let dist = dist (x) -> logpdf(dist, x) end @@ -137,12 +127,13 @@ end @testset "Empty state should create new state every time (no supplementary_η)" begin rng = StableRNG(42) - strategy = - ControlVariateStrategy(rng = rng, nsamples = nsamples, state = nothing) + parameters = ProjectionParameters(rng = rng) + strategy = ControlVariateStrategy(nsamples = nsamples) - @test_opt ignored_modules = (Base, LinearAlgebra, Distributions) prepare_state!( - M, + @test_opt ignored_modules = (Base, LinearAlgebra, Distributions) create_state!( strategy, + M, + parameters, targetfn, ef, supplementary_η, @@ -179,25 +170,28 @@ end supplementary_η, ) - state1 = prepare_state!(M, strategy, targetfn, ef, supplementary_η) - state2 = prepare_state!(M, strategy, targetfn, ef, supplementary_η) + state1 = + create_state!(strategy, M, parameters, targetfn, ef, supplementary_η) + state2 = + create_state!(strategy, M, parameters, targetfn, ef, supplementary_η) + @test state1 == state2 @test state1 !== state2 # `==` check that the content of the arrays are similar # `!==` checks that the arrays are different in memory - @test getsamples(state1) == getsamples(state2) - @test getsamples(state1) !== getsamples(state2) - @test getlogpdfs(state1) == getlogpdfs(state2) - @test getlogpdfs(state1) !== getlogpdfs(state2) - @test getsufficientstatistics(state1) == getsufficientstatistics(state2) - @test getsufficientstatistics(state1) !== getsufficientstatistics(state2) - @test getgradsamples(state1) 
== getgradsamples(state2) - @test getgradsamples(state1) !== getgradsamples(state2) + @test get_samples(state1) == get_samples(state2) + @test get_samples(state1) !== get_samples(state2) + @test get_logpdfs(state1) == get_logpdfs(state2) + @test get_logpdfs(state1) !== get_logpdfs(state2) + @test get_sufficientstatistics(state1) == get_sufficientstatistics(state2) + @test get_sufficientstatistics(state1) !== get_sufficientstatistics(state2) + @test get_gradsamples(state1) == get_gradsamples(state2) + @test get_gradsamples(state1) !== get_gradsamples(state2) if isbasemeasureconstant(ef) === ConstantBaseMeasure() - @test getlogbasemeasures(state1) === getlogbasemeasures(state2) + @test get_logbasemeasures(state1) === get_logbasemeasures(state2) else - @test getlogbasemeasures(state1) !== getlogbasemeasures(state2) + @test get_logbasemeasures(state1) !== get_logbasemeasures(state2) end samples = rand(ef, nsamples) @@ -218,24 +212,30 @@ end gradsamples = gradsamples, ) - strategy_with_state = - ControlVariateStrategy(nsamples = nsamples, state = state3) - state3_prepared = - prepare_state!(M, strategy_with_state, targetfn, ef, supplementary_η) + strategy = ControlVariateStrategy(nsamples = nsamples) + state3_prepared = prepare_state!( + strategy, + state3, + M, + parameters, + targetfn, + ef, + supplementary_η, + ) @test state3 === state3_prepared - @test getsamples(state3) === getsamples(state3_prepared) - @test getlogpdfs(state3) === getlogpdfs(state3_prepared) - @test getsufficientstatistics(state3) === - getsufficientstatistics(state3_prepared) - @test getlogbasemeasures(state3) === getlogbasemeasures(state3_prepared) - @test getgradsamples(state3) === getgradsamples(state3_prepared) - - @test getsamples(state1) == getsamples(state3) - @test getlogpdfs(state1) == getlogpdfs(state3) - @test getlogbasemeasures(state1) == getlogbasemeasures(state3) - @test getsufficientstatistics(state1) == getsufficientstatistics(state3) - @test getgradsamples(state1) == getgradsamples(state3) + @test get_samples(state3) === get_samples(state3_prepared) + @test get_logpdfs(state3) === get_logpdfs(state3_prepared) + @test get_sufficientstatistics(state3) === + get_sufficientstatistics(state3_prepared) + @test get_logbasemeasures(state3) === get_logbasemeasures(state3_prepared) + @test get_gradsamples(state3) === get_gradsamples(state3_prepared) + + @test get_samples(state1) == get_samples(state3) + @test get_logpdfs(state1) == get_logpdfs(state3) + @test get_logbasemeasures(state1) == get_logbasemeasures(state3) + @test get_sufficientstatistics(state1) == get_sufficientstatistics(state3) + @test get_gradsamples(state1) == get_gradsamples(state3) end end end @@ -243,7 +243,8 @@ end end @testitem "Gradient shouldn't depend on the scale of the `logpdf` when nsamples goes to infinity" begin - import ExponentialFamilyProjection: CVICostGradientObjective, ControlVariateStrategy + import ExponentialFamilyProjection: + ProjectionCostGradientObjective, ControlVariateStrategy, create_state! 
import ExponentialFamilyManifolds: get_natural_manifold using StableRNGs, ExponentialFamily, Manifolds, BayesBase @@ -252,6 +253,7 @@ end targetfn2 = (x) -> logpdf(dist, x) - 1000 strategy = ControlVariateStrategy(nsamples = 10^6) + parameters = ProjectionParameters() M = get_natural_manifold(Beta, ()) rng = StableRNG(42) @@ -259,8 +261,39 @@ end X1 = zero_vector(M, p) X2 = zero_vector(M, p) - objective1 = CVICostGradientObjective(targetfn1, (), strategy, nothing) - objective2 = CVICostGradientObjective(targetfn2, (), strategy, nothing) + state1 = create_state!( + strategy, + M, + parameters, + targetfn1, + convert(ExponentialFamilyDistribution, M, p), + (), + ) + state2 = create_state!( + strategy, + M, + parameters, + targetfn2, + convert(ExponentialFamilyDistribution, M, p), + (), + ) + + objective1 = ProjectionCostGradientObjective( + parameters, + targetfn1, + copy(p), + (), + strategy, + state1, + ) + objective2 = ProjectionCostGradientObjective( + parameters, + targetfn2, + copy(p), + (), + strategy, + state2, + ) c1, X1 = objective1(M, X1, p) c2, X2 = objective2(M, X2, p) @@ -299,6 +332,7 @@ end manifold = ExponentialFamilyManifolds.get_natural_manifold(typetag, dims, nothing) + targetfn_part = (x) -> logpdf(left, x) targetfn_full = (x) -> logpdf(ProductOf(left, right), x) ef = convert(ExponentialFamilyDistribution, right) @@ -309,24 +343,42 @@ end point = rand(StableRNG(42), manifold) costs = map(seeds) do seed - obj_part = ExponentialFamilyProjection.CVICostGradientObjective( + parameters = ProjectionParameters(seed = seed) + strategy = ExponentialFamilyProjection.ControlVariateStrategy( + nsamples = nsamples, + ) + state_part = ExponentialFamilyProjection.create_state!( + strategy, + manifold, + parameters, targetfn_part, + convert(ExponentialFamilyDistribution, manifold, point), supplementary_ef, - ExponentialFamilyProjection.ControlVariateStrategy( - nsamples = nsamples, - seed = seed, - ), - nothing, + ) + obj_part = ExponentialFamilyProjection.ProjectionCostGradientObjective( + parameters, + targetfn_part, + copy(point), + supplementary_ef, + strategy, + state_part, ) - obj_full = ExponentialFamilyProjection.CVICostGradientObjective( + state_full = ExponentialFamilyProjection.create_state!( + strategy, + manifold, + parameters, + targetfn_full, + convert(ExponentialFamilyDistribution, manifold, point), + [], + ) + obj_full = ExponentialFamilyProjection.ProjectionCostGradientObjective( + parameters, targetfn_full, + copy(point), [], - ExponentialFamilyProjection.ControlVariateStrategy( - nsamples = nsamples, - seed = seed, - ), - nothing, + strategy, + state_full, ) X2 = Manopt.zero_vector(manifold, point) @@ -361,4 +413,49 @@ end [0.5], ) +end + +@testitem "Projection result should not depend on the usage of buffer" begin + using ExponentialFamily, BayesBase, Bumper, StaticTools + distributions = [ + Beta(10, 10), + Gamma(10, 10), + Exponential(1), + LogNormal(0, 1), + Dirichlet([1, 1]), + NormalMeanVariance(0.0, 1.0), + MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0]), + Chisq(30.0), + ] + + for distribution in distributions + parameters_with_buffer = ProjectionParameters( + strategy = ExponentialFamilyProjection.ControlVariateStrategy( + buffer = StaticTools.MallocSlabBuffer(), + ), + ) + parameters_without_buffer = ProjectionParameters( + strategy = ExponentialFamilyProjection.ControlVariateStrategy(buffer = nothing), + ) + + dims = size(rand(distribution)) + + prj_with_buffer = ProjectedTo( + ExponentialFamily.exponential_family_typetag(distribution), + dims...; + 
parameters = parameters_with_buffer, + ) + prj_without_buffer = ProjectedTo( + ExponentialFamily.exponential_family_typetag(distribution), + dims...; + parameters = parameters_without_buffer, + ) + + targetfn = (x) -> logpdf(distribution, x) + result_with_buffer = project_to(prj_with_buffer, targetfn) + result_without_buffer = project_to(prj_without_buffer, targetfn) + + # Small differences are allowed due to different LinearAlgebra routines + @test result_with_buffer ≈ result_without_buffer + end end \ No newline at end of file diff --git a/test/strategies/mle_tests.jl b/test/strategies/mle_tests.jl index 552d0df..df8a48c 100644 --- a/test/strategies/mle_tests.jl +++ b/test/strategies/mle_tests.jl @@ -43,24 +43,25 @@ end c = getconditioner(ef) d = size(rand(rng, ef)) M = ExponentialFamilyManifolds.get_natural_manifold(T, d, c) + p = ProjectionParameters() + η = getnaturalparameters(ef) strategy = ExponentialFamilyProjection.MLEStrategy() - state = ExponentialFamilyProjection.prepare_state!(M, strategy, samples, ef, ()) - strategy = ExponentialFamilyProjection.with_state(strategy, state) - obj = ExponentialFamilyProjection.CVICostGradientObjective( + state = ExponentialFamilyProjection.create_state!(strategy, M, p, samples, ef, ()) + obj = ExponentialFamilyProjection.ProjectionCostGradientObjective( + p, samples, + copy(η), (), strategy, - nothing, + state, ) - η = getnaturalparameters(ef) _logpartition = logpartition(ef) _gradlogpartition = gradlogpartition(ef) _inv_fisher = inv(fisherinformation(ef)) cost = ExponentialFamilyProjection.compute_cost( M, - obj, strategy, state, η, @@ -73,7 +74,6 @@ end ExponentialFamilyProjection.compute_gradient!( M, - obj, strategy, state, gradient,
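
---

For reviewers: below is a minimal, self-contained sketch of how the reworked API in this changeset fits together, assuming the defaults visible in the diff (`seed`/`rng` now live in `ProjectionParameters` instead of the strategies, and `project_to` accepts either a log-pdf callable or a collection of samples). The `Beta` target, tolerance, and sample counts are illustrative values only, not recommendations.

```julia
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, StableRNGs

# Reproducibility is now controlled by `ProjectionParameters`
# (previously the `seed`/`rng` keywords lived inside the strategies)
parameters = ProjectionParameters(
    strategy = ExponentialFamilyProjection.ControlVariateStrategy(nsamples = 2000),
    niterations = 100,
    tolerance = 1e-6,
    seed = 42,
)
prj = ProjectedTo(Beta; parameters = parameters)

# Case 1: project an (un-normalized) log-pdf, handled by `ControlVariateStrategy`
target = (x) -> logpdf(Beta(30.14, 2.71), x)
result_from_logpdf = project_to(prj, target)

# Case 2: project a collection of samples, handled by `MLEStrategy`
# (`DefaultStrategy` dispatches on the argument type)
samples = rand(StableRNG(42), Beta(30.14, 2.71), 1_000)
result_from_samples = project_to(ProjectedTo(Beta), samples)
```

Both calls return an `ExponentialFamily.Beta` distribution, mirroring the `jldoctest` blocks added to `src/projected_to.jl` above.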