
CVI re-uses the optimiser at each prod call, and this behaviour may not be suitable for some optimisers #303

Closed
Nimrais opened this issue Mar 31, 2023 · 9 comments · Fixed by #320
Labels: bug (Something isn't working), documentation (Improvements or additions to documentation), enhancement (New feature or request)

Comments

@Nimrais
Member

Nimrais commented Mar 31, 2023

Imagine you have an optimizer that has inner state and changes its parameters with each iteration (e.g., Adam, which keeps running estimates of the gradient moments).
If you try to use it in several optimization tasks with a shared state, the convergence of each subsequent task is not guaranteed.

So if CVI is used in several consecutive prod calls with such an optimizer, it can start to break at some point.

The following code starts finding a bad approximation at some point:

using ReactiveMP
using StableRNGs
using Distributions
using Flux

rng = StableRNG(42)
# A single CVI approximation with a single Flux.Adam instance, shared across all prod calls
cvi = CVI(rng, 1, 1000, Flux.Adam(0.007), ForwardDiffGrad(), 10, Val(true), true)

n1 = NormalMeanVariance(10 * randn(rng), 10 * rand(rng))
n2 = NormalMeanVariance(10 * randn(rng), 10 * rand(rng))
n_analytical = prod(ProdAnalytical(), n1, n2)
cvi_result = prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
@info n_analytical
@info prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
@info isapprox(cvi_result, n_analytical, atol = 0.1)

for i in 1:10000
    cvi_result = prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
    if !isapprox(cvi_result, n_analytical, atol = 0.1)
        @info "broken at $(i) iteration"
    end
end

This code works fine:

rng = StableRNG(42)
n1 = NormalMeanVariance(10 * randn(rng), 10 * rand(rng))
n2 = NormalMeanVariance(10 * randn(rng), 10 * rand(rng))
n_analytical = prod(ProdAnalytical(), n1, n2)
cvi = CVI(rng, 1, 1000, Flux.Adam(0.007), ForwardDiffGrad(), 10, Val(true), true)
cvi_result = prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
@info n_analytical
@info prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
@info isapprox(cvi_result, n_analytical, atol = 0.1)

for i in 1:10000
    # Re-create the CVI (and thus a fresh Flux.Adam) before every prod call
    cvi = CVI(rng, 1, 1000, Flux.Adam(0.007), ForwardDiffGrad(), 10, Val(true), true)
    cvi_result = prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
    if !isapprox(cvi_result, n_analytical, atol = 0.1)
        @info "broken at $(i) iteration"
    end
end

Should we give the user the ability to reset the optimizer between iterations or between prod calls?

@Nimrais added the bug, documentation, and enhancement labels Mar 31, 2023
@bvdmitri
Member

It's a good point; is there an API to reset Flux optimizers?
We should at least document the current behaviour.

As a workaround for now you can use a custom callback for your optimization procedure:

function use_adam_callback(λ, ∇)
    opt = Flux.Adam(0.007)                    # a fresh Adam for every single update
    return ReactiveMP.cvi_update!(opt, λ, ∇)
end

cvi = CVI(rng, 1, 1000, use_adam_callback, ForwardDiffGrad(), 10, Val(true), true)

You can generalize this pattern with a structure and do something like

cvi = CVI(rng, 1, 1000, ResetOptimizer(() -> Flux.Adam(0.007)), ForwardDiffGrad(), 10, Val(true), true)

where

struct ResetOptimizer{C}
    callback::C
end

function ReactiveMP.cvi_update!(opt::ResetOptimizer, λ::ReactiveMP.NaturalParameters, ∇::ReactiveMP.NaturalParameters)
    return ReactiveMP.cvi_update!(opt.callback(), λ, ∇)
end
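
A quick way to check this pattern would be to drop it into the reproduction loop from above (a sketch, reusing rng, n1, n2 and n_analytical from your first snippet):

cvi = CVI(rng, 1, 1000, ResetOptimizer(() -> Flux.Adam(0.007)), ForwardDiffGrad(), 10, Val(true), true)

for i in 1:10000
    cvi_result = prod(cvi, ContinuousUnivariateLogPdf((x) -> logpdf(n1, x)), n2)
    if !isapprox(cvi_result, n_analytical, atol = 0.1)
        @info "broken at $(i) iteration"
    end
end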

@Nimrais
Member Author

Nimrais commented Mar 31, 2023

Thanks!

Ideally, I want to reset the optimizer not between iterations but between different prod calls.

@Nimrais
Member Author

Nimrais commented Mar 31, 2023

I will look; I do not know whether Flux.jl has a reset API or not.

But if it were possible to make the O parameter not an optimizer but a factory of optimizers, that should fix the issue anyway.

@bvdmitri
Member

We discussed this with @albertpod, and we think this issue is quite severe: the current constructor for CVI is error-prone. This behaviour must either be documented explicitly, or we should disallow creating the CVI constructor with an actual optimizer and use the factory pattern instead. We can still change things, given that we don't have a lot of users depending on the current implementation; better to change it now than later. I like the factory pattern approach, but maybe we can simply call reset (? not sure it even exists) or something similar (? not sure Flux implements it) on the opt in the current implementation.
@Nimrais are you planning to work on it?

@bvdmitri
Member

bvdmitri commented May 7, 2023

@Nimrais Have you looked into the Flux API?

@Nimrais
Member Author

Nimrais commented May 7, 2023

@bvdmitri Yes, I did. The new version of Flux's training code has been written as an independent package called Optimisers.jl. It includes functionality that can solve our problem, called adjust!. With it, you can modify the current learning rate, momentum, and so on:
Optimisers.adjust!(opt, 0.03) # change η for the whole model...
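
For reference, a minimal standalone sketch of the explicit-state Optimisers.jl API (not wired into CVI yet; the parameter vector and the 0.03 learning rate are just illustrative values):

using Optimisers

θ = [0.0, 0.0]                                        # parameters to optimise
state = Optimisers.setup(Optimisers.Adam(0.007), θ)   # optimiser state for θ

∇ = [0.1, -0.2]                                       # some gradient
state, θ = Optimisers.update(state, θ, ∇)             # one update step

Optimisers.adjust!(state, 0.03)                       # change η for subsequent steps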

@Nimrais
Member Author

Nimrais commented May 7, 2023

I have taken some time to think, and in my opinion, there are two important points at which users can decide to reinitialise or to adjust the training hyperparameters:

  1. At the start of a new VMP iteration
  2. At the start of a new optimisation procedure.

So, ideally, as I see it, we need a structure in which we can implement two callbacks: on_vmp_iteration and on_optimization_procedure. This would allow for adjustments or any other user-specified actions to occur at these points.
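
A rough sketch of what such a structure could look like (the struct and the hook points are hypothetical; nothing like this exists in ReactiveMP yet):

# Hypothetical wrapper carrying the optimiser plus the two proposed hooks
struct OptimizerWithCallbacks{O, F, G}
    optimizer::O
    on_vmp_iteration::F            # invoked at the start of each VMP iteration
    on_optimization_procedure::G   # invoked at the start of each optimisation procedure (prod call)
end

# Example user choice: do nothing between VMP iterations,
# but re-create a fresh Adam for every new optimisation procedure
opt = OptimizerWithCallbacks(
    Flux.Adam(0.007),
    identity,
    _ -> Flux.Adam(0.007)
)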

@bvdmitri
Member

I've added a deepcopy call for the optimisers from Optimisers.jl. We need to check whether the issue has actually been resolved.

@bvdmitri
Member

This was fixed a while ago.
