
Merge pull request #30 from ReactiveBayes/minor-refactor-29
Refactor MLE & ControlVariate strategies
bvdmitri authored Aug 5, 2024
2 parents 85d7077 + d6f1447 commit d2c93ae
Showing 12 changed files with 570 additions and 469 deletions.
12 changes: 9 additions & 3 deletions docs/src/index.md
@@ -13,6 +13,7 @@ In order to project a log probability density function onto a member of the expo
```@docs
ExponentialFamilyProjection.ProjectionParameters
ExponentialFamilyProjection.DefaultProjectionParameters
ExponentialFamilyProjection.getinitialpoint
```

Read more about different optimization strategies [here](@ref opt-strategies).
@@ -33,9 +34,6 @@ The projection is performed by calling the `project_to` function with the specif
ExponentialFamilyProjection.project_to
```
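For illustration, a minimal call might look as follows (a sketch rather than canonical usage; the target log-density and the choice of the `Beta` family are arbitrary here):

```julia
using ExponentialFamilyProjection, ExponentialFamily, Distributions

# An arbitrary log-density to be projected onto the Beta family
targetf = (x) -> logpdf(Beta(2.0, 8.0), x)

prj = ProjectedTo(Beta)
result = project_to(prj, targetf)
```

The returned `result` is then a member of the chosen exponential family.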

## [Optimization strategies](@id opt-strategies)

The optimization procedure requires computing the expectation of the gradient to perform gradient descent in the natural parameters space. Currently, the library provides the following strategies for computing these expectations:
@@ -45,6 +43,10 @@ ExponentialFamilyProjection.DefaultStrategy
ExponentialFamilyProjection.ControlVariateStrategy
ExponentialFamilyProjection.MLEStrategy
ExponentialFamilyProjection.preprocess_strategy_argument
ExponentialFamilyProjection.create_state!
ExponentialFamilyProjection.prepare_state!
ExponentialFamilyProjection.compute_cost
ExponentialFamilyProjection.compute_gradient!
```
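These functions form the interface that a strategy implements. As a rough, hypothetical sketch (the `MyStrategy`/`MyStrategyState` names and the placeholder bodies are invented for illustration and do not implement a meaningful strategy):

```julia
using ExponentialFamilyProjection

struct MyStrategy end
mutable struct MyStrategyState end

# No special preprocessing: keep the strategy and argument as they are
ExponentialFamilyProjection.preprocess_strategy_argument(strategy::MyStrategy, argument) =
    (strategy, argument)

# Create and initialize the state once, before the optimization starts
function ExponentialFamilyProjection.create_state!(
    ::MyStrategy, M, parameters, projection_argument, initial_ef, supplementary_η,
)
    return MyStrategyState()
end

# Refresh the state at the beginning of every optimization iteration
function ExponentialFamilyProjection.prepare_state!(
    ::MyStrategy, state::MyStrategyState, M, parameters, projection_argument, distribution, supplementary_η,
)
    return state
end

# Return a scalar cost at the natural parameters `η` (placeholder value)
function ExponentialFamilyProjection.compute_cost(
    M, ::MyStrategy, state::MyStrategyState, η, logpartition, gradlogpartition, inv_fisher,
)
    return 0.0
end

# Write the gradient into `X` in-place and return it (placeholder value)
function ExponentialFamilyProjection.compute_gradient!(
    M, ::MyStrategy, state::MyStrategyState, X, η, logpartition, gradlogpartition, inv_fisher,
)
    return fill!(X, 0)
end
```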

For high-dimensional distributions, adjusting the default number of samples might be necessary to achieve better performance.
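For example, the number of samples can be raised via the strategy (a sketch; it assumes the `nsamples` keyword of `ControlVariateStrategy` and the `strategy`/`parameters` keywords of `ProjectionParameters`/`ProjectedTo` shown below):

```julia
using ExponentialFamilyProjection, ExponentialFamily

# Use more samples than the default when estimating the expectations
strategy = ExponentialFamilyProjection.ControlVariateStrategy(nsamples = 2000)
parameters = ExponentialFamilyProjection.ProjectionParameters(strategy = strategy)
prj = ProjectedTo(MvNormalMeanCovariance, 10; parameters = parameters)
```

More samples generally reduce the variance of the estimated expectations at the cost of slower iterations.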
@@ -154,6 +156,10 @@ plot!(0.0:0.01:1.0, x -> pdf(result, x), label="estimated projection", fill = 0,

## Manopt extensions

```@docs
ExponentialFamilyProjection.ProjectionCostGradientObjective
```

### Bounded direction update rule

The `ExponentialFamilyProjection.jl` package implements a specialized gradient direction rule that limits the norm (manifold-specific) of the gradient to a pre-specified value.
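A sketch of plugging such a rule into a `Manopt` solver on a toy problem (this assumes the rule accepts a plain number as the norm limit and is passed through `Manopt`'s `direction` keyword):

```julia
using Manopt, Manifolds, ExponentialFamilyProjection

M = Euclidean(2)
f(M, p) = sum(abs2, p)
grad_f(M, p) = 2 .* p

# Each gradient step is rescaled so that its (manifold) norm does not exceed 1.0
p = gradient_descent(
    M, f, grad_f, [10.0, 10.0];
    direction = ExponentialFamilyProjection.BoundedNormUpdateRule(1.0),
)
```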
92 changes: 89 additions & 3 deletions src/ExponentialFamilyProjection.jl
@@ -22,19 +22,105 @@ __projection_fast_pack_parameters(t::NTuple{N,<:Number}) where {N} = t
__projection_fast_pack_parameters(t) = ExponentialFamily.pack_parameters(t)

include("manopt/bounded_norm_update_rule.jl")
include("manopt/projection_objective.jl")

"""
preprocess_strategy_argument(strategy, argument)
Checks the compatibility of `strategy` with `argument` and returns a modified strategy and argument if needed.
"""
function preprocess_strategy_argument end

"""
create_state!(
strategy,
M::AbstractManifold,
parameters::ProjectionParameters,
projection_argument,
initial_ef,
supplementary_η,
)
Creates, initializes and returns a state for the `strategy` with the given parameters.
"""
function create_state! end

"""
prepare_state!(
strategy,
state,
M::AbstractManifold,
parameters::ProjectionParameters,
projection_argument,
distribution,
supplementary_η,
)
Prepares an existing `state` of the `strategy` for use in the new optimization iteration by setting or updating its internal parameters.
"""
function prepare_state! end

"""
compute_cost(
M::AbstractManifold,
strategy,
state,
η,
logpartition,
gradlogpartition,
inv_fisher,
)
Computes the cost using the provided `strategy`.

# Arguments
- `M::AbstractManifold`: The manifold on which the computations are performed.
- `strategy`: The strategy used to compute the cost value.
- `state`: The current state for the `strategy`.
- `η`: The vector of natural parameters at the current point.
- `logpartition`: The log partition at the current point `η`.
- `gradlogpartition`: The gradient of the log partition at the current point `η`.
- `inv_fisher`: The inverse Fisher information matrix at the current point `η`.

# Returns
- `cost`: The computed cost value.
"""
function compute_cost end

"""
compute_gradient!(
M::AbstractManifold,
strategy,
state,
X,
η,
logpartition,
gradlogpartition,
inv_fisher,
)
Updates the gradient `X` in-place using the provided `strategy`.

# Arguments
- `M::AbstractManifold`: The manifold on which the computations are performed.
- `strategy`: The strategy used to compute the gradient value.
- `state`: The current state for the `strategy`.
- `X`: The storage for the gradient.
- `η`: The vector of natural parameters at the current point.
- `logpartition`: The log partition at the current point `η`.
- `gradlogpartition`: The gradient of the log partition at the current point `η`.
- `inv_fisher`: The inverse Fisher information matrix at the current point `η`.

# Returns
- `X`: The computed gradient (updated in-place).
"""
function compute_gradient! end
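# Illustrative call order (a sketch, with placeholder variable names): the
# `ProjectionCostGradientObjective` defined in `manopt/projection_objective.jl`
# drives the interface above roughly as
#
#   state = create_state!(strategy, M, parameters, projection_argument, initial_ef, supplementary_η)
#   state = prepare_state!(strategy, state, M, parameters, projection_argument, current_ef, supplementary_η)
#   cost  = compute_cost(M, strategy, state, η, logpartition, gradlogpartition, inv_fisher)
#   X     = compute_gradient!(M, strategy, state, X, η, logpartition, gradlogpartition, inv_fisher)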

include("strategies/control_variate.jl")
include("strategies/mle.jl")
include("strategies/default.jl")

include("projected_to.jl")


end
64 changes: 0 additions & 64 deletions src/cvi.jl

This file was deleted.

90 changes: 90 additions & 0 deletions src/manopt/projection_objective.jl
@@ -0,0 +1,90 @@

"""
ProjectionCostGradientObjective
This structure provides an interface for `Manopt` to compute the cost and gradients required for the optimization procedure based on manifold projection. The actual computation of costs and gradients is defined by the `strategy` argument.

# Arguments
- `projection_parameters`: The parameters for projection, must be of type `ProjectionParameters`.
- `projection_argument`: The second argument of the `project_to` function.
- `current_η`: The current optimization point.
- `supplementary_η`: A tuple of additional natural parameters subtracted from the current point in each optimization iteration.
- `strategy`: Specifies the method for computing costs and gradients, which may support different `projection_argument` values.
- `strategy_state`: The state for the `strategy`, usually created with `create_state!`.

!!! note
    This structure is internal and is subject to change.
"""
struct ProjectionCostGradientObjective{J,F,C,P,S,T}
projection_parameters::J
projection_argument::F
current_η::C
supplementary_η::P
strategy::S
strategy_state::T
end

get_projection_parameters(obj::ProjectionCostGradientObjective) = obj.projection_parameters
get_projection_argument(obj::ProjectionCostGradientObjective) = obj.projection_argument
get_current_η(obj::ProjectionCostGradientObjective) = obj.current_η
get_supplementary_η(obj::ProjectionCostGradientObjective) = obj.supplementary_η
get_strategy(obj::ProjectionCostGradientObjective) = obj.strategy
get_strategy_state(obj::ProjectionCostGradientObjective) = obj.strategy_state

function (objective::ProjectionCostGradientObjective)(M::AbstractManifold, X, p)
current_η = copyto!(get_current_η(objective), p)
current_ef = convert(ExponentialFamilyDistribution, M, current_η)

strategy = get_strategy(objective)
state = get_strategy_state(objective)
projection_parameters = get_projection_parameters(objective)
projection_argument = get_projection_argument(objective)
supplementary_η = get_supplementary_η(objective)

state = prepare_state!(
strategy,
state,
M,
projection_parameters,
projection_argument,
current_ef,
supplementary_η,
)

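# Precompute quantities of the current exponential family member that are
# passed to the strategy's cost and gradient computations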
logpartition = ExponentialFamily.logpartition(current_ef)
gradlogpartition = ExponentialFamily.gradlogpartition(current_ef)
inv_fisher = cholinv(ExponentialFamily.fisherinformation(current_ef))
η = copy(ExponentialFamily.getnaturalparameters(current_ef))

# If we have some supplementary natural parameters in the objective
# we must subtract them from the natural parameters of the current η
foreach(supplementary_η) do s_η
vmap!(-, current_η, current_η, s_η)
end

c = compute_cost(
M,
strategy,
state,
current_η,
logpartition,
gradlogpartition,
inv_fisher,
)
X = compute_gradient!(
M,
strategy,
state,
X,
current_η,
logpartition,
gradlogpartition,
inv_fisher,
)
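# Project the computed gradient onto the tangent space at `p`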
X = project!(M, X, p, X)

return c, X
end



5 comments on commit d2c93ae

@Nimrais (Member)

@JuliaRegistrator register()

@JuliaRegistrator

Error while trying to register: Register Failed
@Nimrais, it looks like you are not a publicly listed member/owner in the parent organization (ReactiveBayes).
If you are a member/owner, you will need to change your membership to public. See GitHub Help

@Nimrais (Member)

@bvdmitri @albertpod Can you do it?

@wouterwln (Member)

@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/113117

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.2 -m "<description of version>" d2c93ae5362d370287de6a3b5493263d493e904b
git push origin v1.1.2
