Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Continuous Transition Node for Latent State Transformation #309

Merged
merged 47 commits into from
Dec 18, 2023

Conversation

albertpod
Copy link
Member

@albertpod albertpod commented Apr 7, 2023

This draft pull request (PR) addresses an issue related to variational message passing (VMP) in the context of updating a full-rank rectangular matrix within a "latent state transformation" model.

In this model, an m-dimensional latent state (x) is transformed into an n-dimensional observation (y) (or vice versa) through a linear transformation (y = K(a) x), where K(.) is a function that transforms a vector of arbitrary length into a matrix A (i.e., K(a) = A). The objective is to infer the variational posterior distribution of a, x, y, and process noise precision W.

To achieve this, we introduced a new node called ContinuousTransition. The update rules are implemented using structured variational inference. The complete derivation stack is based on unpublished material, which is available for reviewers.

The use of the ContinuousTransition node is primarily in two regimes:

  1. When no structure on A is specified:
transformation = a -> reshape(a, 2, 2)
...
a ~ MvNormalMeanCovariance(zeros(2), Diagonal(ones(2)))
y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation, a)}
...
  1. When structure is known:
transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
...
a ~ MvNormalMeanCovariance(zeros(1), Diagonal(ones(1)))
y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation, a)}
...

Milestones:

  • Finalize derivations for the ContinuousTransition node.
  • Add documentation for the ContinuousTransition node.
  • Write unit tests for the implemented update rules.
  • [ ] Provide examples and benchmarks showcasing the ContinuousTransition node's capabilities.
  • Refine the name and API of the Transfominator node based on feedback.

@albertpod albertpod marked this pull request as draft April 7, 2023 10:24
@albertpod albertpod changed the title [WIP] Add Transfominator Node for Latent State Transformation and Dimensionality Reduction [WIP] Add Continuous Transition Node for Latent State Transformation and Dimensionality Reduction May 5, 2023
@ThijsvdLaar
Copy link

Very nice work, I think this node may prove very useful. Two comments from my side:

  1. In the CTMeta it's not clear to me what the \hat{a} represents. Why does it need to be passed explicitly? And why in your example of the post above is a a Gaussian while the CTMeta constructor requires a Vector{<:Real}? A more explicit naming could already help here.
  2. It may be good to mention the two use cases of the post above in the docs as well.

Copy link
Member

@bartvanerp bartvanerp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool new feature! I think we can start a lot of nice projects with this addition. I left some comments, but mostly found a lot of places where we can optimize the code. I did not annotate all of them, but optimizing the code is definitely required. I leave it in the middle whether it should be part of this PR.

src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Show resolved Hide resolved
src/rules/continuous_transition/a.jl Outdated Show resolved Hide resolved
src/rules/continuous_transition/marginals.jl Outdated Show resolved Hide resolved
src/rules/continuous_transition/marginals.jl Outdated Show resolved Hide resolved
src/rules/continuous_transition/marginals.jl Outdated Show resolved Hide resolved
albertpod and others added 5 commits December 11, 2023 16:32
Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com>
Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com>
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/nodes/continuous_transition.jl Outdated Show resolved Hide resolved
src/rules/continuous_transition/a.jl Outdated Show resolved Hide resolved
@albertpod
Copy link
Member Author

albertpod commented Dec 12, 2023

@ThijsvdLaar, thanks for the suggestion. Both points are addressed in the latest commit.

@albertpod
Copy link
Member Author

albertpod commented Dec 12, 2023

@bartvanerp, thanks for checking. I've optimized where evident. Perhaps further optimization can be addressed with new issues.

@albertpod albertpod requested review from bartvanerp and removed request for bvdmitri December 12, 2023 20:54
@albertpod
Copy link
Member Author

@bartvanerp or @bvdmitri only remaining issue with this PR now is that in case of linear transformation (e.g. reshape) we compute Jacobians, which is unnecessary). Should I improve on it within this PR or we a corresponding issue? In case of a former, what do you think is best way of doing it?

My idea is to make meta smth like:

struct ContinuousTransitionMeta
    f::Union{Vector{<:AbstractMatrix}, <:Function} 

    function ContinuousTransitionMeta(transformation::F) where {F}
        return new{F}(transformation)
    end
   
   function ContinuousTransitionMeta(transformation::F, a0::Vector{<:Real}) where {F}
        dy = size(transformation(a0), 1)
        Js = [ForwardDiff.jacobian(a -> transformation(a)[i, :], 1:length(a0)) for i in 1:dy]
        return new(Js)
    end

end

@bartvanerp
Copy link
Member

I am fine with creating an issue for this and assigning someone. For later reference: an example solution could be to create an anonymous function and pre-specify its jacobian, such as

foo(x) = reshape(x, d1, d2)
jacobian_sliced(f::foo, i) = .... # this should be specified by the user in their experiment
jacobian_sliced(f::Function, i) = ForwardDiff.jacobian(a -> transformation(a)[i, :] # generic case (fallback in case specialization not specified)

I would say that this kind of optimization has to be performed by the user himself as an advanced feature and only needs to be documented.

Also note that are a couple of unresolved comments.

@albertpod
Copy link
Member Author

@bartvanerp I've resolved the rest of the comments.
There's an issue that must be opened regarding AE and Jacobian computations.
I will open it after the merge.

@bvdmitri
Copy link
Member

What do you mean unnecessary? In either case you can have a boolean flag, which controls whether to compute the jacobians or not (true by default for example)

@albertpod
Copy link
Member Author

@bvdmitri with unnecessary computations, I meant that in the case of a linear transformation, we don't need to compute Jacobians each time when ctcompanion_matrix is called. As for now, it happens regardless of transformation type.

@bvdmitri
Copy link
Member

But your code cannot assume anything about the type of the transformation, how can you decide if it is necessary or not?

@albertpod
Copy link
Member Author

@bvdmitri, you are correct. That's why I disregarded this option of precomputed Jacobians and kept it generic.
However, the semi-experienced user can specify a flag that precomutes the Jacobians. It's a burden of the user that should be appropriately documented.

@albertpod albertpod merged commit f4835c2 into main Dec 18, 2023
10 checks passed
@albertpod albertpod deleted the transfominator branch December 18, 2023 14:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants