Add Continuous Transition Node for Latent State Transformation #309
…P.jl into transfominator
Very nice work, I think this node may prove very useful. Two comments from my side:
Very cool new feature! I think we can start a lot of nice projects with this addition. I left some comments, but mostly found a lot of places where we can optimize the code. I did not annotate all of them, but optimizing the code is definitely required. I'll leave it open whether that should be part of this PR.
Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com>
@ThijsvdLaar, thanks for the suggestion. Both points are addressed in the latest commit.
@bartvanerp, thanks for checking. I've optimized where evident. Perhaps further optimization can be addressed with new issues.
@bartvanerp or @bvdmitri, the only remaining issue with this PR is that in the case of a linear transformation (e.g. reshape) we compute Jacobians, which is unnecessary. Should I improve on it within this PR, or should we open a corresponding issue? In the former case, what do you think is the best way of doing it? My idea is to make the meta something like:
using ForwardDiff

struct ContinuousTransitionMeta
    # either precomputed row-wise Jacobians or the transformation itself
    f::Union{Vector{<:AbstractMatrix}, <:Function}
    function ContinuousTransitionMeta(transformation::F) where {F}
        return new(transformation)
    end
    function ContinuousTransitionMeta(transformation::F, a0::Vector{<:Real}) where {F}
        dy = size(transformation(a0), 1)
        # one Jacobian per output row; constant for a linear transformation
        Js = [ForwardDiff.jacobian(a -> transformation(a)[i, :], a0) for i in 1:dy]
        return new(Js)
    end
end
I am fine with creating an issue for this and assigning someone. For later reference: an example solution could be to create a function and pre-specify its Jacobian, such as
foo(x) = reshape(x, d1, d2)
jacobian_sliced(f::typeof(foo), i) = .... # this should be specified by the user in their experiment
jacobian_sliced(f::Function, i) = ForwardDiff.jacobian(a -> transformation(a)[i, :] # generic case (fallback in case specialization not specified)
I would say that this kind of optimization has to be performed by the users themselves as an advanced feature and only needs to be documented. Also note that there are a couple of unresolved comments.
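The dispatch idea in the comment above can be sketched as follows. This is only an illustration under stated assumptions: the three-argument signature `jacobian_sliced(f, a, i)`, the fixed 2-by-3 shape of `foo`, and the finite-difference fallback (used here instead of ForwardDiff so the sketch needs only the standard library) are all hypothetical and not part of the PR.

```julia
using LinearAlgebra

# Hypothetical transformation for illustration: a length-6 vector to a 2x3 matrix.
foo(a) = reshape(a, 2, 3)

# Generic fallback: Jacobian of row i of f(a) with respect to a, approximated by
# central finite differences (a stand-in for ForwardDiff.jacobian in this sketch).
function jacobian_sliced(f::Function, a::AbstractVector{<:Real}, i::Integer; h = 1e-6)
    row(v) = f(v)[i, :]
    J = zeros(length(row(a)), length(a))
    for k in eachindex(a)
        e = zeros(length(a))
        e[k] = h
        J[:, k] = (row(a .+ e) .- row(a .- e)) ./ (2h)
    end
    return J
end

# User-supplied specialization: reshape is linear, so the Jacobian of row i is a
# constant selection matrix that does not depend on the value of a.
function jacobian_sliced(f::typeof(foo), a::AbstractVector{<:Real}, i::Integer)
    d1, d2 = size(f(a))
    J = zeros(d2, length(a))
    for j in 1:d2
        J[j, i + (j - 1) * d1] = 1.0  # column-major layout: A[i, j] == a[i + (j - 1) * d1]
    end
    return J
end
```

Because `typeof(foo)` is more specific than `Function`, calling `jacobian_sliced(foo, a, i)` selects the hand-written method, while any other function (for example an anonymous `a -> reshape(a, 2, 3)`) falls back to the generic one.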
@bartvanerp I've resolved the rest of the comments.
What do you mean by unnecessary? In either case you can have a boolean flag that controls whether to compute the Jacobians or not (true by default, for example).
@bvdmitri by unnecessary computations, I meant that in the case of a linear transformation, we don't need to compute Jacobians each time when
But your code cannot assume anything about the type of the transformation. How can you decide whether it is necessary or not?
@bvdmitri, you are correct. That's why I disregarded this option of precomputed Jacobians and kept it generic.
This draft pull request (PR) addresses an issue related to variational message passing (VMP) in the context of updating a full-rank rectangular matrix within a "latent state transformation" model.
In this model, an m-dimensional latent state (x) is transformed into an n-dimensional observation (y) (or vice versa) through a linear transformation (y = K(a) x), where K(.) is a function that transforms a vector of arbitrary length into a matrix A (i.e., K(a) = A). The objective is to infer the variational posterior distribution of a, x, y, and process noise precision W.
To achieve this, we introduced a new node called ContinuousTransition. The update rules are implemented using structured variational inference. The complete derivation stack is based on unpublished material, which is available for reviewers.
The use of the ContinuousTransition node is primarily in two regimes:
A is specified:
Milestones:
ContinuousTransition node.
ContinuousTransition node.
[ ] Provide examples and benchmarks showcasing the ContinuousTransition node's capabilities.
Transfominator node based on feedback.
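For readers unfamiliar with the setup, the latent state transformation model from the PR description (y = K(a) x with noise precision W) can be simulated in a few lines. This is a minimal sketch under assumptions not stated in the PR: K is taken to be a reshape, the noise is taken to be additive Gaussian with covariance inv(W), and the dimensions and variable names are hypothetical. It does not use the ContinuousTransition node or any inference API.

```julia
using LinearAlgebra, Random

# Hypothetical dimensions: m-dimensional latent state, n-dimensional observation.
m, n = 3, 2

# One possible choice of K: map a vector of length n*m to an n-by-m matrix A.
K(a) = reshape(a, n, m)

rng = MersenneTwister(1)
a = randn(rng, n * m)     # parameters of the transformation
x = randn(rng, m)         # latent state
W = Matrix(2.0I, n, n)    # noise precision; the covariance is inv(W)

# Assumed observation model: y ~ N(K(a) x, inv(W)).
# Sample the noise through the Cholesky factor of the covariance.
L = cholesky(Symmetric(inv(W))).L
y = K(a) * x + L * randn(rng, n)
```

In the inference problem described above, a, x, and W would be unknown and estimated from y; here they are fixed only to generate data.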