From 76c51d59663b09ab3691a0849959e470a1acbdaa Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 31 Mar 2023 18:39:06 +0200 Subject: [PATCH 01/38] Add transfominator node --- Project.toml | 1 + src/ReactiveMP.jl | 1 + src/nodes/transfominator.jl | 43 +++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) create mode 100644 src/nodes/transfominator.jl diff --git a/Project.toml b/Project.toml index 3bedb0267..cacbbd071 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Dmitry Bagaev ", "Albert Podusenko Date: Fri, 31 Mar 2023 18:50:02 +0200 Subject: [PATCH 02/38] Fix TMeta --- src/nodes/transfominator.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nodes/transfominator.jl b/src/nodes/transfominator.jl index ef01350ad..439f8e8c9 100644 --- a/src/nodes/transfominator.jl +++ b/src/nodes/transfominator.jl @@ -12,10 +12,10 @@ struct TMeta Fs :: Vector{<:AbstractMatrix} # masks es :: Vector{<:AbstractVector} # unit vectors - function TMeta(ds::Tuple) + function TMeta(ds::Tuple{T, T}) where {T} dim1, dim2 = ds Fs = [tmask(dim1, dim2, i) for i in 1:dim2] - es = [StandardBasisVector(dim1, i, one(T)) for i in 1:dim2] + es = [StandardBasisVector(dim2, i, one(T)) for i in 1:dim2] return new(ds, Fs, es) end end From 07f965972767609f0f26ffa510c2dda9fc440af3 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 7 Apr 2023 12:10:21 +0200 Subject: [PATCH 03/38] Add transfominator node --- src/nodes/transfominator.jl | 49 +++++++++++++++++---- src/rules/prototypes.jl | 6 +++ src/rules/transfominator/h.jl | 17 +++++++ src/rules/transfominator/lambda.jl | 26 +++++++++++ src/rules/transfominator/marginals.jl | 40 +++++++++++++++++ src/rules/transfominator/x.jl | 20 +++++++++ src/rules/transfominator/y.jl | 21 +++++++++ test/nodes/test_transfominator.jl | 22 +++++++++ test/rules/transfominator/test_h.jl | 1 + test/rules/transfominator/test_lambda.jl | 1 + test/rules/transfominator/test_marginals.jl | 28 
++++++++++++ test/rules/transfominator/test_x.jl | 3 ++ test/rules/transfominator/test_y.jl | 3 ++ 13 files changed, 228 insertions(+), 9 deletions(-) create mode 100644 src/rules/transfominator/h.jl create mode 100644 src/rules/transfominator/lambda.jl create mode 100644 src/rules/transfominator/marginals.jl create mode 100644 src/rules/transfominator/x.jl create mode 100644 src/rules/transfominator/y.jl create mode 100644 test/nodes/test_transfominator.jl create mode 100644 test/rules/transfominator/test_h.jl create mode 100644 test/rules/transfominator/test_lambda.jl create mode 100644 test/rules/transfominator/test_marginals.jl create mode 100644 test/rules/transfominator/test_x.jl create mode 100644 test/rules/transfominator/test_y.jl diff --git a/src/nodes/transfominator.jl b/src/nodes/transfominator.jl index 439f8e8c9..3c5155588 100644 --- a/src/nodes/transfominator.jl +++ b/src/nodes/transfominator.jl @@ -7,17 +7,47 @@ struct Transfominator end const transfominator = Transfominator +@node Transfominator Stochastic [y, x, h, Λ] + struct TMeta - ds :: Tuple # dimensionality of Transfominator - Fs :: Vector{<:AbstractMatrix} # masks - es :: Vector{<:AbstractVector} # unit vectors - - function TMeta(ds::Tuple{T, T}) where {T} - dim1, dim2 = ds - Fs = [tmask(dim1, dim2, i) for i in 1:dim2] - es = [StandardBasisVector(dim2, i, one(T)) for i in 1:dim2] + ds::Tuple # dimensionality of Transfominator (dy, dx) + Fs::Vector{<:AbstractMatrix} # masks + es::Vector{<:AbstractVector} # unit vectors + + function TMeta(ds::Tuple{T, T}) where {T <: Integer} + dy, dx = ds + Fs = [tmask(dx, dy, i) for i in 1:dy] + es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] return new(ds, Fs, es) end + + function TMeta(dy::T, dx::T) where {T <: Integer} + Fs = [tmask(dx, dy, i) for i in 1:dy] + es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] + return new((dy, dx), Fs, es) + end +end + +@average_energy Transfominator (q_y_x::MultivariateNormalDistributionsFamily, 
q_h::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::TMeta) = begin + mh, Vh = mean_cov(q_h) + myx, Vyx = mean_cov(q_y_x) + mΛ = mean(q_Λ) + + dy, dx = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + n = div(ndims(q_y_x), 2) + mH = tcompanion_matrix(mh, meta) + mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + g₁ = my' * mΛ * my + tr(Vy * mΛ) + g₂ = mx' * mH' * mΛ * my + tr(Vyx * mH' * mΛ) + g₃ = g₂ + G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (mh * mh' + Vh) * Fs[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + g₄ = mx' * G * mx + tr(Vx * G) + AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) + + return AE end getdimensionality(meta::TMeta) = meta.ds @@ -38,6 +68,7 @@ end function tcompanion_matrix(w, meta::TMeta) Fs, es = getmasks(meta), getunits(meta) - L = sum(es[i] * w' * Fs[i]' for i in 1:dim2) + dy, dx = getdimensionality(meta) + L = sum(es[i] * w' * Fs[i]' for i in 1:dy) return L end diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 5c56ab658..25a523800 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -105,6 +105,12 @@ include("transition/out.jl") include("transition/in.jl") include("transition/a.jl") +include("transfominator/y.jl") +include("transfominator/x.jl") +include("transfominator/h.jl") +include("transfominator/lambda.jl") +include("transfominator/marginals.jl") + include("autoregressive/y.jl") include("autoregressive/x.jl") include("autoregressive/theta.jl") diff --git a/src/rules/transfominator/h.jl b/src/rules/transfominator/h.jl new file mode 100644 index 000000000..5e4ff6333 --- /dev/null +++ b/src/rules/transfominator/h.jl @@ -0,0 +1,17 @@ +@rule Transfominator(:h, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta) = begin + dy, dx = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + + myx, Vyx = mean_cov(q_y_x) + + 
mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + + mΛ = mean(q_Λ) + + D = sum(sum(es[i]' * mΛ * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) + z = sum(Fs[i]' * (mx * my' + Vyx') * mΛ * es[i] for i in 1:length(Fs)) + + return MvNormalWeightedMeanPrecision(z, D) +end diff --git a/src/rules/transfominator/lambda.jl b/src/rules/transfominator/lambda.jl new file mode 100644 index 000000000..02fecbd63 --- /dev/null +++ b/src/rules/transfominator/lambda.jl @@ -0,0 +1,26 @@ +function compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) + G₁ = (my * my' + Vy) + G₂ = ((my * mx' + Vyx) * mH') + G₃ = transpose(G₂) + Ex_xx = mx * mx' + Vx + G₅ = sum(sum(es[i] * mh' * Fs[i]'Ex_xx * Fs[j] * mh * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Vh) * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + Δ = G₁ - G₂ - G₃ + G₅ + G₆ +end + +@rule Transfominator(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, meta::TMeta) = begin + dy, dx = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + + mh, Vh = mean_cov(q_h) + mH = tcompanion_matrix(mh, meta) + myx, Vyx = mean_cov(q_y_x) + + mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + + Δ = compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) + + return WishartMessage(length(Fs) + 2, Δ) +end diff --git a/src/rules/transfominator/marginals.jl b/src/rules/transfominator/marginals.jl new file mode 100644 index 000000000..37a6718fd --- /dev/null +++ b/src/rules/transfominator/marginals.jl @@ -0,0 +1,40 @@ + +@marginalrule Transfominator(:y_x) ( + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta +) 
= begin + return transfominator_marginal(m_y, m_x, q_h, q_Λ, meta) +end + +function transfominator_marginal( + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta +) + dx, dy = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + + mh, Vh = mean_cov(q_h) + mΛ = mean(q_Λ) + + mH = tcompanion_matrix(mh, meta) + + b_my, b_Vy = mean_cov(m_y) + f_mx, f_Vx = mean_cov(m_x) + + inv_b_Vy = cholinv(b_Vy) + inv_f_Vx = cholinv(f_Vx) + + Ξ = inv_f_Vx + sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + + W_11 = inv_b_Vy + mΛ + + # negate_inplace!(mW * mH) + W_12 = -(mΛ * mH) + + W_21 = -(mH' * mΛ) + + W_22 = Ξ + mH' * mΛ * mH + + W = [W_11 W_12; W_21 W_22] + ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] + + return MvNormalWeightedMeanPrecision(ξ, W) +end diff --git a/src/rules/transfominator/x.jl b/src/rules/transfominator/x.jl new file mode 100644 index 000000000..02f5bf1d1 --- /dev/null +++ b/src/rules/transfominator/x.jl @@ -0,0 +1,20 @@ +@rule Transfominator(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta) = begin + mh, Vh = mean_cov(q_h) + my, Vy = mean_cov(m_y) + + mΛ = mean(q_Λ) + + dy, dx = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + + mH = tcompanion_matrix(mh, meta) + + Λ = sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + + Σ₁ = Hermitian(pinv(mH) * (Vy) * pinv(mH') + pinv(mH' * mΛ * mH)) + + Ξ = (pinv(Σ₁) + Λ) + z = pinv(Σ₁) * pinv(mH) * my + + return MvNormalWeightedMeanPrecision(z, Ξ) +end diff --git a/src/rules/transfominator/y.jl b/src/rules/transfominator/y.jl new file mode 100644 index 000000000..df088d19c --- /dev/null +++ b/src/rules/transfominator/y.jl @@ -0,0 +1,21 @@ +@rule Transfominator(:y, Marginalisation) 
(m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta) = begin + mh, Vh = mean_cov(q_h) + mx, Wx = mean_invcov(m_x) + + mΛ = mean(q_Λ) + + dy, dx = getdimensionality(meta) + Fs, es = getmasks(meta), getunits(meta) + + mH = tcompanion_matrix(mh, meta) + + Λ = sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + + Ξ = Λ + Wx + z = Wx * mx + + Vy = mH * inv(Ξ) * mH' + inv(mΛ) + my = mH * inv(Ξ) * z + + return MvNormalMeanCovariance(my, Vy) +end diff --git a/test/nodes/test_transfominator.jl b/test/nodes/test_transfominator.jl new file mode 100644 index 000000000..ee2f3a858 --- /dev/null +++ b/test/nodes/test_transfominator.jl @@ -0,0 +1,22 @@ +score( + AverageEnergy(), + Transfominator, + Val{(:y_x, :h, :Λ)}(), + ( + Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing), + Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing), + Marginal(Wishart(2, diageye(2)), false, false, nothing) + ), + TMeta((2, 2)) +) +score( + AverageEnergy(), + Transfominator, + Val{(:y_x, :h, :Λ)}(), + ( + Marginal(MvNormalMeanPrecision(zeros(5), diageye(5)), false, false, nothing), + Marginal(MvNormalMeanPrecision(zeros(6), diageye(6)), false, false, nothing), + Marginal(Wishart(2, diageye(2)), false, false, nothing) + ), + TMeta((2, 3)) +) diff --git a/test/rules/transfominator/test_h.jl b/test/rules/transfominator/test_h.jl new file mode 100644 index 000000000..c796c4ae7 --- /dev/null +++ b/test/rules/transfominator/test_h.jl @@ -0,0 +1 @@ +@call_rule Transfominator(:h, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3)) diff --git a/test/rules/transfominator/test_lambda.jl b/test/rules/transfominator/test_lambda.jl new file mode 100644 index 000000000..e5e16004f --- /dev/null +++ b/test/rules/transfominator/test_lambda.jl @@ -0,0 +1 @@ +@call_rule Transfominator(:Λ, 
Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_h = MvNormalMeanCovariance(randn(6), diageye(6)), meta = TMeta(2, 3)) diff --git a/test/rules/transfominator/test_marginals.jl b/test/rules/transfominator/test_marginals.jl new file mode 100644 index 000000000..3b17c9b19 --- /dev/null +++ b/test/rules/transfominator/test_marginals.jl @@ -0,0 +1,28 @@ +module RulesTransfominatorTest + +using Test +using ReactiveMP +using Random +using LinearAlgebra +using Distributions + +import ReactiveMP: @test_marginalrules + +# @call_marginalrule Transfominator(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(2), diageye(2)), q_h = MvNormalMeanPrecision(ones(4), diageye(4)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 2)) +# @call_marginalrule Transfominator(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(3), diageye(3)), q_h = MvNormalMeanPrecision(ones(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3)) + +@testset "marginalrules:Transfominator" begin + @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any)" begin + @test_marginalrules [with_float_conversions = true] Transfominator(:y_x) [( + input = ( + m_y = MvNormalMeanPrecision(ones(2), diageye(2)), + m_x = MvNormalMeanPrecision(ones(2), diageye(2)), + q_h = MvNormalMeanPrecision(ones(4), diageye(4)), + q_Λ = Wishart(2, diageye(2)), + meta = TMeta(2, 2) + ), + output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0]) + )] + end +end +end diff --git a/test/rules/transfominator/test_x.jl b/test/rules/transfominator/test_x.jl new file mode 100644 index 000000000..c6d6c9db1 --- /dev/null +++ b/test/rules/transfominator/test_x.jl @@ -0,0 +1,3 @@ +@call_rule Transfominator(:x, Marginalisation) ( + m_y = MvNormalMeanPrecision(randn(2), diageye(2)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), 
meta = TMeta(2, 3) +) diff --git a/test/rules/transfominator/test_y.jl b/test/rules/transfominator/test_y.jl new file mode 100644 index 000000000..023f5a48a --- /dev/null +++ b/test/rules/transfominator/test_y.jl @@ -0,0 +1,3 @@ +@call_rule Transfominator(:y, Marginalisation) ( + m_x = MvNormalMeanPrecision(randn(3), diageye(3)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3) +) From b1fbb0b242ed6be8cddc10c19c7a2585b80fbde8 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 5 May 2023 13:37:12 +0200 Subject: [PATCH 04/38] Update node --- src/nodes/transfominator.jl | 16 +++++++++++++++- src/rules/transfominator/lambda.jl | 4 ++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/nodes/transfominator.jl b/src/nodes/transfominator.jl index 3c5155588..605c2e517 100644 --- a/src/nodes/transfominator.jl +++ b/src/nodes/transfominator.jl @@ -3,6 +3,20 @@ export transfominator, Transfominator, TMeta import LazyArrays, BlockArrays import StatsFuns: log2π +@doc raw""" +The Transfominator node is a node that transforms a n-dimensional vector x into m-dimensional vector y. +The transformation is achieved by casting n*m-dimensional vector h into a m×n H matrix. + +```julia +y ~ Transfominator(x, h, Λ) +``` + +Interfaces: +1. y - latent output of the Transfominator node +2. x - latent input of the Transfominator node +3. h - latent vector that casts into the matrix H +4. 
Λ - latent precision matrix (could be fixed) +""" struct Transfominator end const transfominator = Transfominator @@ -54,7 +68,7 @@ getdimensionality(meta::TMeta) = meta.ds getmasks(meta::TMeta) = meta.Fs getunits(meta::TMeta) = meta.es -@node Transfominator Stochastic [y, x, w, Λ] +@node Transfominator Stochastic [y, x, h, Λ] default_meta(::Type{TMeta}) = error("Transfominator node requires meta flag explicitly specified") diff --git a/src/rules/transfominator/lambda.jl b/src/rules/transfominator/lambda.jl index 02fecbd63..07e3d392f 100644 --- a/src/rules/transfominator/lambda.jl +++ b/src/rules/transfominator/lambda.jl @@ -5,7 +5,7 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) Ex_xx = mx * mx' + Vx G₅ = sum(sum(es[i] * mh' * Fs[i]'Ex_xx * Fs[j] * mh * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Vh) * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) - Δ = G₁ - G₂ - G₃ + G₅ + G₆ + return G₁ - G₂ - G₃ + G₅ + G₆ end @rule Transfominator(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, meta::TMeta) = begin @@ -22,5 +22,5 @@ end Δ = compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) - return WishartMessage(length(Fs) + 2, Δ) + return WishartMessage(dy+2, Δ) end From 35bb086b9368576cc094a1f6240d5eca2d4a7a17 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 5 May 2023 18:37:12 +0200 Subject: [PATCH 05/38] Update node --- src/ReactiveMP.jl | 2 +- ...sfominator.jl => continuous_transition.jl} | 49 ++++++++++--------- .../h.jl | 2 +- .../lambda.jl | 5 +- .../marginals.jl | 13 +++-- .../x.jl | 4 +- .../y.jl | 4 +- src/rules/prototypes.jl | 10 ++-- test/nodes/test_transfominator.jl | 8 +-- test/rules/continuous_transition/test_h.jl | 1 + .../continuous_transition/test_lambda.jl | 1 + .../continuous_transition/test_marginals.jl | 28 +++++++++++ .../test_x.jl | 4 +- .../test_y.jl | 4 +- test/rules/transfominator/test_h.jl 
| 1 - test/rules/transfominator/test_lambda.jl | 1 - test/rules/transfominator/test_marginals.jl | 28 ----------- 17 files changed, 83 insertions(+), 82 deletions(-) rename src/nodes/{transfominator.jl => continuous_transition.jl} (54%) rename src/rules/{transfominator => continuous_transition}/h.jl (80%) rename src/rules/{transfominator => continuous_transition}/lambda.jl (76%) rename src/rules/{transfominator => continuous_transition}/marginals.jl (76%) rename src/rules/{transfominator => continuous_transition}/x.jl (68%) rename src/rules/{transfominator => continuous_transition}/y.jl (66%) create mode 100644 test/rules/continuous_transition/test_h.jl create mode 100644 test/rules/continuous_transition/test_lambda.jl create mode 100644 test/rules/continuous_transition/test_marginals.jl rename test/rules/{transfominator => continuous_transition}/test_x.jl (60%) rename test/rules/{transfominator => continuous_transition}/test_y.jl (60%) delete mode 100644 test/rules/transfominator/test_h.jl delete mode 100644 test/rules/transfominator/test_lambda.jl delete mode 100644 test/rules/transfominator/test_marginals.jl diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index de707486d..9ed0686bd 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -157,7 +157,7 @@ include("nodes/bifm.jl") include("nodes/bifm_helper.jl") include("nodes/probit.jl") include("nodes/poisson.jl") -include("nodes/transfominator.jl") +include("nodes/continuous_transition.jl") include("nodes/flow/flow.jl") include("nodes/delta/delta.jl") diff --git a/src/nodes/transfominator.jl b/src/nodes/continuous_transition.jl similarity index 54% rename from src/nodes/transfominator.jl rename to src/nodes/continuous_transition.jl index 605c2e517..bd31ccce7 100644 --- a/src/nodes/transfominator.jl +++ b/src/nodes/continuous_transition.jl @@ -1,48 +1,49 @@ -export transfominator, Transfominator, TMeta +export transfominator, CTransition, ContinuousTransition, CTMeta import LazyArrays, BlockArrays import 
StatsFuns: log2π @doc raw""" -The Transfominator node is a node that transforms a n-dimensional vector x into m-dimensional vector y. +The ContinuousTransition node is a node that transforms a n-dimensional vector x into m-dimensional vector y. The transformation is achieved by casting n*m-dimensional vector h into a m×n H matrix. ```julia -y ~ Transfominator(x, h, Λ) +y ~ ContinuousTransition(x, h, Λ) ``` Interfaces: -1. y - latent output of the Transfominator node -2. x - latent input of the Transfominator node +1. y - latent output of the ContinuousTransition node +2. x - latent input of the ContinuousTransition node 3. h - latent vector that casts into the matrix H 4. Λ - latent precision matrix (could be fixed) """ -struct Transfominator end +struct ContinuousTransition end -const transfominator = Transfominator +const transfominator = ContinuousTransition +const CTransition = ContinuousTransition -@node Transfominator Stochastic [y, x, h, Λ] +@node ContinuousTransition Stochastic [y, x, h, Λ] -struct TMeta - ds::Tuple # dimensionality of Transfominator (dy, dx) +struct CTMeta + ds::Tuple # dimensionality of ContinuousTransition (dy, dx) Fs::Vector{<:AbstractMatrix} # masks es::Vector{<:AbstractVector} # unit vectors - function TMeta(ds::Tuple{T, T}) where {T <: Integer} + function CTMeta(ds::Tuple{T, T}) where {T <: Integer} dy, dx = ds - Fs = [tmask(dx, dy, i) for i in 1:dy] + Fs = [ctmask(dx, dy, i) for i in 1:dy] es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] return new(ds, Fs, es) end - function TMeta(dy::T, dx::T) where {T <: Integer} - Fs = [tmask(dx, dy, i) for i in 1:dy] + function CTMeta(dy::T, dx::T) where {T <: Integer} + Fs = [ctmask(dx, dy, i) for i in 1:dy] es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] return new((dy, dx), Fs, es) end end -@average_energy Transfominator (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::TMeta) = begin +@average_energy 
ContinuousTransition (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::CTMeta) = begin mh, Vh = mean_cov(q_h) myx, Vyx = mean_cov(q_y_x) mΛ = mean(q_Λ) @@ -50,7 +51,7 @@ end dy, dx = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) n = div(ndims(q_y_x), 2) - mH = tcompanion_matrix(mh, meta) + mH = ctcompanion_matrix(mh, meta) mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] Vyx = Vyx[1:dy, (dy + 1):end] @@ -64,15 +65,15 @@ end return AE end -getdimensionality(meta::TMeta) = meta.ds -getmasks(meta::TMeta) = meta.Fs -getunits(meta::TMeta) = meta.es +getdimensionality(meta::CTMeta) = meta.ds +getmasks(meta::CTMeta) = meta.Fs +getunits(meta::CTMeta) = meta.es -@node Transfominator Stochastic [y, x, h, Λ] +@node ContinuousTransition Stochastic [y, x, h, Λ] -default_meta(::Type{TMeta}) = error("Transfominator node requires meta flag explicitly specified") +default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta flag explicitly specified") -function tmask(dim1, dim2, index) +function ctmask(dim1, dim2, index) F = zeros(dim1, dim1 * dim2) start_col = (index - 1) * dim1 + 1 end_col = start_col + dim1 - 1 @@ -80,9 +81,9 @@ function tmask(dim1, dim2, index) return F end -function tcompanion_matrix(w, meta::TMeta) +function ctcompanion_matrix(w, meta::CTMeta) Fs, es = getmasks(meta), getunits(meta) - dy, dx = getdimensionality(meta) + dy, _ = getdimensionality(meta) L = sum(es[i] * w' * Fs[i]' for i in 1:dy) return L end diff --git a/src/rules/transfominator/h.jl b/src/rules/continuous_transition/h.jl similarity index 80% rename from src/rules/transfominator/h.jl rename to src/rules/continuous_transition/h.jl index 5e4ff6333..d26b3a7fc 100644 --- a/src/rules/transfominator/h.jl +++ b/src/rules/continuous_transition/h.jl @@ -1,4 +1,4 @@ -@rule Transfominator(:h, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, 
meta::TMeta) = begin +@rule ContinuousTransition(:h, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin dy, dx = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) diff --git a/src/rules/transfominator/lambda.jl b/src/rules/continuous_transition/lambda.jl similarity index 76% rename from src/rules/transfominator/lambda.jl rename to src/rules/continuous_transition/lambda.jl index 07e3d392f..c1f1545ab 100644 --- a/src/rules/transfominator/lambda.jl +++ b/src/rules/continuous_transition/lambda.jl @@ -8,12 +8,12 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) return G₁ - G₂ - G₃ + G₅ + G₆ end -@rule Transfominator(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, meta::TMeta) = begin +@rule ContinuousTransition(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin dy, dx = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) mh, Vh = mean_cov(q_h) - mH = tcompanion_matrix(mh, meta) + mH = ctcompanion_matrix(mh, meta) myx, Vyx = mean_cov(q_y_x) mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] @@ -22,5 +22,6 @@ end Δ = compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) + # NOTE: WishartMessage stores inverse of scale matrix return WishartMessage(dy+2, Δ) end diff --git a/src/rules/transfominator/marginals.jl b/src/rules/continuous_transition/marginals.jl similarity index 76% rename from src/rules/transfominator/marginals.jl rename to src/rules/continuous_transition/marginals.jl index 37a6718fd..7c5633727 100644 --- a/src/rules/transfominator/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -1,20 +1,19 @@ -@marginalrule Transfominator(:y_x) ( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta +@marginalrule 
ContinuousTransition(:y_x) ( + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta ) = begin - return transfominator_marginal(m_y, m_x, q_h, q_Λ, meta) + return continuous_tranition_marginal(m_y, m_x, q_h, q_Λ, meta) end -function transfominator_marginal( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta +function continuous_tranition_marginal( + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta ) - dx, dy = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) mh, Vh = mean_cov(q_h) mΛ = mean(q_Λ) - mH = tcompanion_matrix(mh, meta) + mH = ctcompanion_matrix(mh, meta) b_my, b_Vy = mean_cov(m_y) f_mx, f_Vx = mean_cov(m_x) diff --git a/src/rules/transfominator/x.jl b/src/rules/continuous_transition/x.jl similarity index 68% rename from src/rules/transfominator/x.jl rename to src/rules/continuous_transition/x.jl index 02f5bf1d1..f96843919 100644 --- a/src/rules/transfominator/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -1,4 +1,4 @@ -@rule Transfominator(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta) = begin +@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin mh, Vh = mean_cov(q_h) my, Vy = mean_cov(m_y) @@ -7,7 +7,7 @@ dy, dx = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) - mH = tcompanion_matrix(mh, meta) + mH = ctcompanion_matrix(mh, meta) Λ = sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) diff --git a/src/rules/transfominator/y.jl b/src/rules/continuous_transition/y.jl 
similarity index 66% rename from src/rules/transfominator/y.jl rename to src/rules/continuous_transition/y.jl index df088d19c..501dcae75 100644 --- a/src/rules/transfominator/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -1,4 +1,4 @@ -@rule Transfominator(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::TMeta) = begin +@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin mh, Vh = mean_cov(q_h) mx, Wx = mean_invcov(m_x) @@ -7,7 +7,7 @@ dy, dx = getdimensionality(meta) Fs, es = getmasks(meta), getunits(meta) - mH = tcompanion_matrix(mh, meta) + mH = ctcompanion_matrix(mh, meta) Λ = sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 25a523800..a0d8db46e 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -105,11 +105,11 @@ include("transition/out.jl") include("transition/in.jl") include("transition/a.jl") -include("transfominator/y.jl") -include("transfominator/x.jl") -include("transfominator/h.jl") -include("transfominator/lambda.jl") -include("transfominator/marginals.jl") +include("continuous_transition/y.jl") +include("continuous_transition/x.jl") +include("continuous_transition/h.jl") +include("continuous_transition/lambda.jl") +include("continuous_transition/marginals.jl") include("autoregressive/y.jl") include("autoregressive/x.jl") diff --git a/test/nodes/test_transfominator.jl b/test/nodes/test_transfominator.jl index ee2f3a858..1c4c77350 100644 --- a/test/nodes/test_transfominator.jl +++ b/test/nodes/test_transfominator.jl @@ -1,22 +1,22 @@ score( AverageEnergy(), - Transfominator, + ContinuousTransition, Val{(:y_x, :h, :Λ)}(), ( Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing), 
Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing), Marginal(Wishart(2, diageye(2)), false, false, nothing) ), - TMeta((2, 2)) + CTMeta((2, 2)) ) score( AverageEnergy(), - Transfominator, + ContinuousTransition, Val{(:y_x, :h, :Λ)}(), ( Marginal(MvNormalMeanPrecision(zeros(5), diageye(5)), false, false, nothing), Marginal(MvNormalMeanPrecision(zeros(6), diageye(6)), false, false, nothing), Marginal(Wishart(2, diageye(2)), false, false, nothing) ), - TMeta((2, 3)) + CTMeta((2, 3)) ) diff --git a/test/rules/continuous_transition/test_h.jl b/test/rules/continuous_transition/test_h.jl new file mode 100644 index 000000000..bdda2fa3d --- /dev/null +++ b/test/rules/continuous_transition/test_h.jl @@ -0,0 +1 @@ +@call_rule ContinuousTransition(:h, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3)) diff --git a/test/rules/continuous_transition/test_lambda.jl b/test/rules/continuous_transition/test_lambda.jl new file mode 100644 index 000000000..3b588591f --- /dev/null +++ b/test/rules/continuous_transition/test_lambda.jl @@ -0,0 +1 @@ +@call_rule ContinuousTransition(:Λ, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_h = MvNormalMeanCovariance(randn(6), diageye(6)), meta = CTMeta(2, 3)) diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl new file mode 100644 index 000000000..b00ac22ff --- /dev/null +++ b/test/rules/continuous_transition/test_marginals.jl @@ -0,0 +1,28 @@ +module RulesContinuousTransitionTest + +using Test +using ReactiveMP +using Random +using LinearAlgebra +using Distributions + +import ReactiveMP: @test_marginalrules + +# @call_marginalrule ContinuousTransition(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(2), diageye(2)), q_h = MvNormalMeanPrecision(ones(4), diageye(4)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 2)) +# 
@call_marginalrule ContinuousTransition(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(3), diageye(3)), q_h = MvNormalMeanPrecision(ones(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3)) + +@testset "marginalrules:ContinuousTransition" begin + @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any)" begin + @test_marginalrules [with_float_conversions = true] ContinuousTransition(:y_x) [( + input = ( + m_y = MvNormalMeanPrecision(ones(2), diageye(2)), + m_x = MvNormalMeanPrecision(ones(2), diageye(2)), + q_h = MvNormalMeanPrecision(ones(4), diageye(4)), + q_Λ = Wishart(2, diageye(2)), + meta = CTMeta(2, 2) + ), + output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0]) + )] + end +end +end diff --git a/test/rules/transfominator/test_x.jl b/test/rules/continuous_transition/test_x.jl similarity index 60% rename from test/rules/transfominator/test_x.jl rename to test/rules/continuous_transition/test_x.jl index c6d6c9db1..e67d8f4d3 100644 --- a/test/rules/transfominator/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -1,3 +1,3 @@ -@call_rule Transfominator(:x, Marginalisation) ( - m_y = MvNormalMeanPrecision(randn(2), diageye(2)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3) +@call_rule ContinuousTransition(:x, Marginalisation) ( + m_y = MvNormalMeanPrecision(randn(2), diageye(2)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3) ) diff --git a/test/rules/transfominator/test_y.jl b/test/rules/continuous_transition/test_y.jl similarity index 60% rename from test/rules/transfominator/test_y.jl rename to test/rules/continuous_transition/test_y.jl index 023f5a48a..8f146ad4f 100644 --- a/test/rules/transfominator/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -1,3 +1,3 @@ -@call_rule Transfominator(:y, 
Marginalisation) ( - m_x = MvNormalMeanPrecision(randn(3), diageye(3)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3) +@call_rule ContinuousTransition(:y, Marginalisation) ( + m_x = MvNormalMeanPrecision(randn(3), diageye(3)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3) ) diff --git a/test/rules/transfominator/test_h.jl b/test/rules/transfominator/test_h.jl deleted file mode 100644 index c796c4ae7..000000000 --- a/test/rules/transfominator/test_h.jl +++ /dev/null @@ -1 +0,0 @@ -@call_rule Transfominator(:h, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 3)) diff --git a/test/rules/transfominator/test_lambda.jl b/test/rules/transfominator/test_lambda.jl deleted file mode 100644 index e5e16004f..000000000 --- a/test/rules/transfominator/test_lambda.jl +++ /dev/null @@ -1 +0,0 @@ -@call_rule Transfominator(:Λ, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_h = MvNormalMeanCovariance(randn(6), diageye(6)), meta = TMeta(2, 3)) diff --git a/test/rules/transfominator/test_marginals.jl b/test/rules/transfominator/test_marginals.jl deleted file mode 100644 index 3b17c9b19..000000000 --- a/test/rules/transfominator/test_marginals.jl +++ /dev/null @@ -1,28 +0,0 @@ -module RulesTransfominatorTest - -using Test -using ReactiveMP -using Random -using LinearAlgebra -using Distributions - -import ReactiveMP: @test_marginalrules - -# @call_marginalrule Transfominator(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(2), diageye(2)), q_h = MvNormalMeanPrecision(ones(4), diageye(4)), q_Λ = Wishart(2, diageye(2)), meta = TMeta(2, 2)) -# @call_marginalrule Transfominator(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(3), diageye(3)), q_h = MvNormalMeanPrecision(ones(6), diageye(6)), q_Λ = Wishart(2, 
diageye(2)), meta = TMeta(2, 3)) - -@testset "marginalrules:Transfominator" begin - @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any)" begin - @test_marginalrules [with_float_conversions = true] Transfominator(:y_x) [( - input = ( - m_y = MvNormalMeanPrecision(ones(2), diageye(2)), - m_x = MvNormalMeanPrecision(ones(2), diageye(2)), - q_h = MvNormalMeanPrecision(ones(4), diageye(4)), - q_Λ = Wishart(2, diageye(2)), - meta = TMeta(2, 2) - ), - output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0]) - )] - end -end -end From 997dea4c0d03ff14036818cb47634dfb804b94ca Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 5 May 2023 18:57:18 +0200 Subject: [PATCH 06/38] Make format --- src/nodes/continuous_transition.jl | 18 +++++++++++------- src/rules/continuous_transition/lambda.jl | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index bd31ccce7..e420f8f53 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -4,18 +4,22 @@ import LazyArrays, BlockArrays import StatsFuns: log2π @doc raw""" -The ContinuousTransition node is a node that transforms a n-dimensional vector x into m-dimensional vector y. -The transformation is achieved by casting n*m-dimensional vector h into a m×n H matrix. +The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear transformation with a n×m-dimensional matrix H that is constructed from a n*m-dimensional vector h. + +To construct the matrix H, the elements of h are filled into H starting with the first row, one element at a time. + +The transformation is performed with the following syntax: ```julia y ~ ContinuousTransition(x, h, Λ) ``` - Interfaces: -1. y - latent output of the ContinuousTransition node -2. 
x - latent input of the ContinuousTransition node -3. h - latent vector that casts into the matrix H -4. Λ - latent precision matrix (could be fixed) +1. y - n-dimensional output of the ContinuousTransition node. +2. x - m-dimensional input of the ContinuousTransition node. +3. h - nm-dimensional vector that casts into the matrix H. +4. Λ - n×n-dimensional precision matrix used to soften the transition and perform variational message passing, as belief-propagation is not feasible for y = Hx. + +Note that you can set Λ to a fixed value or put a prior on it to control the amount of jitter. """ struct ContinuousTransition end diff --git a/src/rules/continuous_transition/lambda.jl b/src/rules/continuous_transition/lambda.jl index c1f1545ab..da581af8a 100644 --- a/src/rules/continuous_transition/lambda.jl +++ b/src/rules/continuous_transition/lambda.jl @@ -23,5 +23,5 @@ end Δ = compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) # NOTE: WishartMessage stores inverse of scale matrix - return WishartMessage(dy+2, Δ) + return WishartMessage(dy + 2, Δ) end From 6eb557cf29021ffd5b624acd71e9b87fc6af4ee7 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Tue, 23 May 2023 13:56:50 +0300 Subject: [PATCH 07/38] Add matrix normal node --- src/ReactiveMP.jl | 1 + src/nodes/matrix_normal.jl | 7 +++++++ src/rules/matrix_normal/out.jl | 3 +++ src/rules/prototypes.jl | 2 ++ 4 files changed, 13 insertions(+) create mode 100644 src/nodes/matrix_normal.jl create mode 100644 src/rules/matrix_normal/out.jl diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index 9ed0686bd..64af21b48 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -133,6 +133,7 @@ include("nodes/uninformative.jl") include("nodes/uniform.jl") include("nodes/normal_mean_variance.jl") include("nodes/normal_mean_precision.jl") +include("nodes/matrix_normal.jl") include("nodes/mv_normal_mean_covariance.jl") include("nodes/mv_normal_mean_precision.jl") include("nodes/mv_normal_mean_scale_precision.jl") diff --git 
a/src/nodes/matrix_normal.jl b/src/nodes/matrix_normal.jl new file mode 100644 index 000000000..83517dddc --- /dev/null +++ b/src/nodes/matrix_normal.jl @@ -0,0 +1,7 @@ +@node MatrixNormal Stochastic [out, M, U, V] + +# default method for mean-field assumption +@average_energy MatrixNormal (q_out::Any, q_M::Any, q_U::Any, q_V::Any) = begin + q_Σ = PointMass(kron(mean(q_U), mean(q_V))) + -score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), map((q) -> Marginal(q, false, false, nothing), (q_out, q_M, q_Σ)), nothing) +end \ No newline at end of file diff --git a/src/rules/matrix_normal/out.jl b/src/rules/matrix_normal/out.jl new file mode 100644 index 000000000..3539b7917 --- /dev/null +++ b/src/rules/matrix_normal/out.jl @@ -0,0 +1,3 @@ +@rule MatrixNormal(:out, Marginalisation) (q_M::PointMass, q_U::PointMass, q_V::PointMass, ) = begin + MvNormalMeanCovariance(vec(mean(q_M)), kron(mean(q_U), mean(q_V))) +end diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index a0d8db46e..3152e2fab 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -62,6 +62,8 @@ include("normal_mean_variance/mean.jl") include("normal_mean_variance/var.jl") include("normal_mean_variance/marginals.jl") +include("matrix_normal/out.jl") + include("mv_normal_mean_precision/out.jl") include("mv_normal_mean_precision/mean.jl") include("mv_normal_mean_precision/precision.jl") From 030acdadca290aa6eb1a5082f2f75f09b3c263e7 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 3 Nov 2023 12:17:05 +0100 Subject: [PATCH 08/38] Modify MatNormal --- src/nodes/matrix_normal.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nodes/matrix_normal.jl b/src/nodes/matrix_normal.jl index 83517dddc..6ab777434 100644 --- a/src/nodes/matrix_normal.jl +++ b/src/nodes/matrix_normal.jl @@ -1,7 +1,8 @@ @node MatrixNormal Stochastic [out, M, U, V] -# default method for mean-field assumption +# we use equivalence of `` @average_energy MatrixNormal 
(q_out::Any, q_M::Any, q_U::Any, q_V::Any) = begin q_Σ = PointMass(kron(mean(q_U), mean(q_V))) - -score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), map((q) -> Marginal(q, false, false, nothing), (q_out, q_M, q_Σ)), nothing) + q_m = PointMass(vec(mean(q_M))) + -score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), map((q) -> Marginal(q, false, false, nothing), (q_out, q_m, q_Σ)), nothing) end \ No newline at end of file From 1c591b782a9e97acb0da09401a917132a353f53e Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 3 Nov 2023 19:48:20 +0100 Subject: [PATCH 09/38] Update CTransition --- src/nodes/continuous_transition.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index e420f8f53..7b0ff0f69 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -1,6 +1,6 @@ export transfominator, CTransition, ContinuousTransition, CTMeta -import LazyArrays, BlockArrays +import LazyArrays import StatsFuns: log2π @doc raw""" From 6de02d70fe6e3a09772040a0e9421a275bb6283d Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 10 Nov 2023 15:58:31 +0100 Subject: [PATCH 10/38] Update ctransition --- src/nodes/continuous_transition.jl | 121 ++++++++++-------- src/rules/continuous_transition/W.jl | 28 ++++ .../continuous_transition/{h.jl => a.jl} | 8 +- src/rules/continuous_transition/lambda.jl | 27 ---- src/rules/continuous_transition/marginals.jl | 26 ++-- src/rules/continuous_transition/x.jl | 18 +-- src/rules/continuous_transition/y.jl | 18 +-- src/rules/prototypes.jl | 4 +- 8 files changed, 136 insertions(+), 114 deletions(-) create mode 100644 src/rules/continuous_transition/W.jl rename src/rules/continuous_transition/{h.jl => a.jl} (65%) delete mode 100644 src/rules/continuous_transition/lambda.jl diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 7b0ff0f69..be91503ab 
100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -4,90 +4,107 @@ import LazyArrays import StatsFuns: log2π @doc raw""" -The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear transformation with a n×m-dimensional matrix H that is constructed from a n*m-dimensional vector h. +The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a `nm`-dimensional vector `a`. -To construct the matrix H, the elements of h are filled into H starting with the first row, one element at a time. +To construct the matrix A, the elements of `a` are filled into A according to the transformation function provided with meta. -The transformation is performed with the following syntax: +Check CTMeta for more details on how to specify the transformation function that **must** return a matrix. ```julia -y ~ ContinuousTransition(x, h, Λ) +y ~ ContinuousTransition(x, a, W) ``` Interfaces: 1. y - n-dimensional output of the ContinuousTransition node. 2. x - m-dimensional input of the ContinuousTransition node. -3. h - nm-dimensional vector that casts into the matrix H. -4. Λ - n×n-dimensional precision matrix used to soften the transition and perform variational message passing, as belief-propagation is not feasible for y = Hx. +3. a - `nm`-dimensional vector that casts into the matrix `A`. +4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing, as belief-propagation is not feasible for `y = Ax`. -Note that you can set Λ to a fixed value or put a prior on it to control the amount of jitter. +Note that you can set W to a fixed value or put a prior on it to control the amount of jitter. 
""" struct ContinuousTransition end -const transfominator = ContinuousTransition const CTransition = ContinuousTransition -@node ContinuousTransition Stochastic [y, x, h, Λ] +@node ContinuousTransition Stochastic [y, x, a, W] +@doc raw""" +`CTMeta` is used as a metadata flag in `ContinuousTransition` to define the transformation function for constructing the matrix `A` from vector `a`. + +There are two scenarios for specifying the transformation: +1. **Linear Transformation**: In this case, `CTMeta` requires a transformation function and the length of vector `a`. +2. **Nonlinear Transformation**: For nonlinear transformations, `CTMeta` expects a transformation function and a vector `â`, which acts as an expansion point for approximating the transformation linearly. + +Constructors: +- `CTMeta(transformation::Function, len::Integer)`: Used for linear transformations. +- `CTMeta(transformation::Function, â::Vector{<:Real})`: Used for nonlinear transformations. + +Fields: +- `ds`: A tuple indicating the dimensionality of the ContinuousTransition (dy, dx). +- `Fs`: Represents the masks, which can be either a Vector of AbstractMatrices or a Function, depending on the transformation type. +- `es`: A Vector of unit vectors used in the transformation process. + +The `CTMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node. 
+""" struct CTMeta ds::Tuple # dimensionality of ContinuousTransition (dy, dx) - Fs::Vector{<:AbstractMatrix} # masks + Fs::Union{Vector{<:AbstractMatrix}, <:Function} # masks es::Vector{<:AbstractVector} # unit vectors - function CTMeta(ds::Tuple{T, T}) where {T <: Integer} - dy, dx = ds - Fs = [ctmask(dx, dy, i) for i in 1:dy] - es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] - return new(ds, Fs, es) + # meta for linear transformation of a vector to a matrix + function CTMeta(transformation::Function, len::Integer) + dy, dx = size(transformation(zeros(len))) + Fs = [ForwardDiff.jacobian(a -> transformation(a)[i, :], 1:len) for i in 1:dy] + es = [StandardBasisVector(dy, i, 1.0) for i in 1:dy] + return new((dy, dx), Fs, es) end - function CTMeta(dy::T, dx::T) where {T <: Integer} - Fs = [ctmask(dx, dy, i) for i in 1:dy] - es = [StandardBasisVector(dy, i, one(T)) for i in 1:dy] - return new((dy, dx), Fs, es) + # meta for nonlinear transformation of a vector to a matrix + function CTMeta(transformation::Function, â::Vector{<:Real}) + dy, dx = size(transformation(â)) + es = [StandardBasisVector(dy, i, 1.0) for i in 1:dy] + return new((dy, dx), transformation, es) end + +end + +getunits(meta::CTMeta) = meta.es +getdimensionality(meta::CTMeta) = meta.ds + +getmasks(ctmeta::CTMeta, a) = process_Fs(ctmeta.Fs, a) +process_Fs(Fs::Vector{<:AbstractMatrix}, a) = Fs +process_Fs(Fs::Function, a) = [ForwardDiff.jacobian(a -> Fs(a)[i, :], a) for i in 1:size(Fs(a), 1)] + +@node ContinuousTransition Stochastic [y, x, a, W] + +default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta flag explicitly specified") + +default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing, )) + +function ctcompanion_matrix(a, meta::CTMeta) + Fs, es = getmasks(meta, a), getunits(meta) + dy, _ = getdimensionality(meta) + A = sum(es[i] * a' * Fs[i]' for i in 1:dy) + return A end -@average_energy 
ContinuousTransition (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Wishart, meta::CTMeta) = begin - mh, Vh = mean_cov(q_h) +@average_energy ContinuousTransition (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Wishart, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) myx, Vyx = mean_cov(q_y_x) - mΛ = mean(q_Λ) + mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) + Fs, es = getmasks(meta, ma), getunits(meta) n = div(ndims(q_y_x), 2) - mH = ctcompanion_matrix(mh, meta) + mA = ctcompanion_matrix(ma, meta) mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] Vyx = Vyx[1:dy, (dy + 1):end] - g₁ = my' * mΛ * my + tr(Vy * mΛ) - g₂ = mx' * mH' * mΛ * my + tr(Vyx * mH' * mΛ) + g₁ = my' * mW * my + tr(Vy * mW) + g₂ = mx' * mA' * mW * my + tr(Vyx * mA' * mW) g₃ = g₂ - G = sum(sum(es[i]' * mΛ * es[j] * Fs[i] * (mh * mh' + Vh) * Fs[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + G = sum(sum(es[i]' * mW * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) g₄ = mx' * G * mx + tr(Vx * G) - AE = n / 2 * log2π - 0.5 * mean(logdet, q_Λ) + 0.5 * (g₁ - g₂ - g₃ + g₄) + AE = n / 2 * log2π - 0.5 * mean(logdet, q_W) + 0.5 * (g₁ - g₂ - g₃ + g₄) return AE -end - -getdimensionality(meta::CTMeta) = meta.ds -getmasks(meta::CTMeta) = meta.Fs -getunits(meta::CTMeta) = meta.es - -@node ContinuousTransition Stochastic [y, x, h, Λ] - -default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta flag explicitly specified") - -function ctmask(dim1, dim2, index) - F = zeros(dim1, dim1 * dim2) - start_col = (index - 1) * dim1 + 1 - end_col = start_col + dim1 - 1 - @inbounds F[1:dim1, start_col:end_col] = I(dim1) - return F -end - -function ctcompanion_matrix(w, meta::CTMeta) - Fs, es = getmasks(meta), getunits(meta) - dy, _ = getdimensionality(meta) - L = sum(es[i] * w' * 
Fs[i]' for i in 1:dy) - return L -end +end \ No newline at end of file diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl new file mode 100644 index 000000000..842822ebf --- /dev/null +++ b/src/rules/continuous_transition/W.jl @@ -0,0 +1,28 @@ +function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs, es) + G₁ = (my * my' + Vy) + G₂ = ((my * mx' + Vyx) * mA') + G₃ = transpose(G₂) + Ex_xx = mx * mx' + Vx + G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + return G₁ - G₂ - G₃ + G₅ + G₆ +end + +@rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin + dy, dx = getdimensionality(meta) + + ma, Va = mean_cov(q_a) + Fs, es = getmasks(meta, ma), getunits(meta) + + mA = ctcompanion_matrix(ma, meta) + myx, Vyx = mean_cov(q_y_x) + + mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + + Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs, es) + + # NOTE: WishartFast stores inverse of scale matrix + return WishartFast(dy + 2, Δ) +end diff --git a/src/rules/continuous_transition/h.jl b/src/rules/continuous_transition/a.jl similarity index 65% rename from src/rules/continuous_transition/h.jl rename to src/rules/continuous_transition/a.jl index d26b3a7fc..a3b6e7395 100644 --- a/src/rules/continuous_transition/h.jl +++ b/src/rules/continuous_transition/a.jl @@ -1,6 +1,8 @@ -@rule ContinuousTransition(:h, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin +@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin dy, dx = getdimensionality(meta) - Fs, 
es = getmasks(meta), getunits(meta) + + ma = mean(q_a) + Fs, es = getmasks(meta, ma), getunits(meta) myx, Vyx = mean_cov(q_y_x) @@ -8,7 +10,7 @@ my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] Vyx = Vyx[1:dy, (dy + 1):end] - mΛ = mean(q_Λ) + mΛ = mean(q_W) D = sum(sum(es[i]' * mΛ * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) z = sum(Fs[i]' * (mx * my' + Vyx') * mΛ * es[i] for i in 1:length(Fs)) diff --git a/src/rules/continuous_transition/lambda.jl b/src/rules/continuous_transition/lambda.jl deleted file mode 100644 index da581af8a..000000000 --- a/src/rules/continuous_transition/lambda.jl +++ /dev/null @@ -1,27 +0,0 @@ -function compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) - G₁ = (my * my' + Vy) - G₂ = ((my * mx' + Vyx) * mH') - G₃ = transpose(G₂) - Ex_xx = mx * mx' + Vx - G₅ = sum(sum(es[i] * mh' * Fs[i]'Ex_xx * Fs[j] * mh * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) - G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Vh) * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) - return G₁ - G₂ - G₃ + G₅ + G₆ -end - -@rule ContinuousTransition(:Λ, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin - dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) - - mh, Vh = mean_cov(q_h) - mH = ctcompanion_matrix(mh, meta) - myx, Vyx = mean_cov(q_y_x) - - mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] - my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] - Vyx = Vyx[1:dy, (dy + 1):end] - - Δ = compute_delta(my, Vy, mx, Vx, Vyx, mH, Vh, mh, Fs, es) - - # NOTE: WishartMessage stores inverse of scale matrix - return WishartMessage(dy + 2, Δ) -end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index 7c5633727..2bffa94ef 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -1,19 +1,21 @@ @marginalrule 
ContinuousTransition(:y_x) ( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta ) = begin - return continuous_tranition_marginal(m_y, m_x, q_h, q_Λ, meta) + return continuous_tranition_marginal(m_y, m_x, q_a, q_W, meta) end function continuous_tranition_marginal( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta + m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta ) - Fs, es = getmasks(meta), getunits(meta) - mh, Vh = mean_cov(q_h) - mΛ = mean(q_Λ) + ma, Va = mean_cov(q_a) - mH = ctcompanion_matrix(mh, meta) + Fs, es = getmasks(meta, ma), getunits(meta) + + mW = mean(q_W) + + mA = ctcompanion_matrix(ma, meta) b_my, b_Vy = mean_cov(m_y) f_mx, f_Vx = mean_cov(m_x) @@ -21,16 +23,16 @@ function continuous_tranition_marginal( inv_b_Vy = cholinv(b_Vy) inv_f_Vx = cholinv(f_Vx) - Ξ = inv_f_Vx + sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - W_11 = inv_b_Vy + mΛ + W_11 = inv_b_Vy + mW # negate_inplace!(mW * mH) - W_12 = -(mΛ * mH) + W_12 = -(mW * mA) - W_21 = -(mH' * mΛ) + W_21 = -(mA' * mW) - W_22 = Ξ + mH' * mΛ * mH + W_22 = Ξ + mA' * mW * mA W = [W_11 W_12; W_21 W_22] ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index f96843919..4add0bf65 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -1,20 +1,20 @@ -@rule 
ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin - mh, Vh = mean_cov(q_h) +@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) my, Vy = mean_cov(m_y) - mΛ = mean(q_Λ) + mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) + Fs, es = getmasks(meta, ma), getunits(meta) - mH = ctcompanion_matrix(mh, meta) + mA = ctcompanion_matrix(ma, meta) - Λ = sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - Σ₁ = Hermitian(pinv(mH) * (Vy) * pinv(mH') + pinv(mH' * mΛ * mH)) + Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) - Ξ = (pinv(Σ₁) + Λ) - z = pinv(Σ₁) * pinv(mH) * my + Ξ = (pinv(Σ₁) + W) + z = pinv(Σ₁) * pinv(mA) * my return MvNormalWeightedMeanPrecision(z, Ξ) end diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index 501dcae75..ee3b8fdaf 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -1,21 +1,21 @@ -@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_h::MultivariateNormalDistributionsFamily, q_Λ::Any, meta::CTMeta) = begin - mh, Vh = mean_cov(q_h) +@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) mx, Wx = mean_invcov(m_x) - mΛ = mean(q_Λ) + mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta), getunits(meta) + Fs, es = getmasks(meta, ma), getunits(meta) - mH = ctcompanion_matrix(mh, meta) + mA = ctcompanion_matrix(ma, meta) - Λ = 
sum(sum(es[j]' * mΛ * es[i] * Fs[j] * Vh * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - Ξ = Λ + Wx + Ξ = W + Wx z = Wx * mx - Vy = mH * inv(Ξ) * mH' + inv(mΛ) - my = mH * inv(Ξ) * z + Vy = mA * inv(Ξ) * mA' + inv(mW) + my = mA * inv(Ξ) * z return MvNormalMeanCovariance(my, Vy) end diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 8e3608c0a..934149a25 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -114,8 +114,8 @@ include("transition/a.jl") include("continuous_transition/y.jl") include("continuous_transition/x.jl") -include("continuous_transition/h.jl") -include("continuous_transition/lambda.jl") +include("continuous_transition/a.jl") +include("continuous_transition/W.jl") include("continuous_transition/marginals.jl") include("autoregressive/y.jl") From 841a57e9910069f2dfd1328af650cdce31722163 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 12 Nov 2023 14:51:00 +0100 Subject: [PATCH 11/38] Update ctransition --- src/nodes/continuous_transition.jl | 6 +++++- src/rules/continuous_transition/a.jl | 6 +++--- .../continuous_transition/{test_lambda.jl => test_W.jl} | 0 test/rules/continuous_transition/{test_h.jl => test_a.jl} | 0 4 files changed, 8 insertions(+), 4 deletions(-) rename test/rules/continuous_transition/{test_lambda.jl => test_W.jl} (100%) rename test/rules/continuous_transition/{test_h.jl => test_a.jl} (100%) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index be91503ab..1c18369a5 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -50,6 +50,8 @@ struct CTMeta Fs::Union{Vector{<:AbstractMatrix}, <:Function} # masks es::Vector{<:AbstractVector} # unit vectors + # NOTE: this meta is not user-friendly, I don't like supplying the length of a vector + # perhaps making mutable struct with empty meta first will be better 
from user perspective # meta for linear transformation of a vector to a matrix function CTMeta(transformation::Function, len::Integer) dy, dx = size(transformation(zeros(len))) @@ -72,7 +74,9 @@ getdimensionality(meta::CTMeta) = meta.ds getmasks(ctmeta::CTMeta, a) = process_Fs(ctmeta.Fs, a) process_Fs(Fs::Vector{<:AbstractMatrix}, a) = Fs -process_Fs(Fs::Function, a) = [ForwardDiff.jacobian(a -> Fs(a)[i, :], a) for i in 1:size(Fs(a), 1)] + +# NOTE: this doesn't seem to be the right way of working with nonlinar approximation +process_Fs(Fs::Function, a) = [ForwardDiff.jacobian(a -> Fs(a)[i, :], a) for i in 1:size(Fs(a), 1)] @node ContinuousTransition Stochastic [y, x, a, W] diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index a3b6e7395..e021a1d4d 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -10,10 +10,10 @@ my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] Vyx = Vyx[1:dy, (dy + 1):end] - mΛ = mean(q_W) + mW = mean(q_W) - D = sum(sum(es[i]' * mΛ * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) - z = sum(Fs[i]' * (mx * my' + Vyx') * mΛ * es[i] for i in 1:length(Fs)) + D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) + z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:length(Fs)) return MvNormalWeightedMeanPrecision(z, D) end diff --git a/test/rules/continuous_transition/test_lambda.jl b/test/rules/continuous_transition/test_W.jl similarity index 100% rename from test/rules/continuous_transition/test_lambda.jl rename to test/rules/continuous_transition/test_W.jl diff --git a/test/rules/continuous_transition/test_h.jl b/test/rules/continuous_transition/test_a.jl similarity index 100% rename from test/rules/continuous_transition/test_h.jl rename to test/rules/continuous_transition/test_a.jl From bfbb15e3b2f87f798f9f9c13a00075164c6f37f4 Mon Sep 17 00:00:00 2001 From: Albert Podusenko 
Date: Sun, 12 Nov 2023 14:56:39 +0100 Subject: [PATCH 12/38] Make format --- src/nodes/continuous_transition.jl | 7 +++---- src/nodes/matrix_normal.jl | 2 +- src/rules/continuous_transition/marginals.jl | 1 - src/rules/matrix_normal/out.jl | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 1c18369a5..b7a58d654 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -66,13 +66,12 @@ struct CTMeta es = [StandardBasisVector(dy, i, 1.0) for i in 1:dy] return new((dy, dx), transformation, es) end - end getunits(meta::CTMeta) = meta.es getdimensionality(meta::CTMeta) = meta.ds -getmasks(ctmeta::CTMeta, a) = process_Fs(ctmeta.Fs, a) +getmasks(ctmeta::CTMeta, a) = process_Fs(ctmeta.Fs, a) process_Fs(Fs::Vector{<:AbstractMatrix}, a) = Fs # NOTE: this doesn't seem to be the right way of working with nonlinar approximation @@ -82,7 +81,7 @@ process_Fs(Fs::Function, a) = [ForwardDiff.jacobian(a -> Fs(a)[i, :], a) for i i default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta flag explicitly specified") -default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing, )) +default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing,)) function ctcompanion_matrix(a, meta::CTMeta) Fs, es = getmasks(meta, a), getunits(meta) @@ -111,4 +110,4 @@ end AE = n / 2 * log2π - 0.5 * mean(logdet, q_W) + 0.5 * (g₁ - g₂ - g₃ + g₄) return AE -end \ No newline at end of file +end diff --git a/src/nodes/matrix_normal.jl b/src/nodes/matrix_normal.jl index 6ab777434..3ec66c2fe 100644 --- a/src/nodes/matrix_normal.jl +++ b/src/nodes/matrix_normal.jl @@ -5,4 +5,4 @@ q_Σ = PointMass(kron(mean(q_U), mean(q_V))) q_m = PointMass(vec(mean(q_M))) -score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), map((q) 
-> Marginal(q, false, false, nothing), (q_out, q_m, q_Σ)), nothing) -end \ No newline at end of file +end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index 2bffa94ef..bc5d7a09a 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -8,7 +8,6 @@ end function continuous_tranition_marginal( m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta ) - ma, Va = mean_cov(q_a) Fs, es = getmasks(meta, ma), getunits(meta) diff --git a/src/rules/matrix_normal/out.jl b/src/rules/matrix_normal/out.jl index 3539b7917..837e1d7be 100644 --- a/src/rules/matrix_normal/out.jl +++ b/src/rules/matrix_normal/out.jl @@ -1,3 +1,3 @@ -@rule MatrixNormal(:out, Marginalisation) (q_M::PointMass, q_U::PointMass, q_V::PointMass, ) = begin +@rule MatrixNormal(:out, Marginalisation) (q_M::PointMass, q_U::PointMass, q_V::PointMass) = begin MvNormalMeanCovariance(vec(mean(q_M)), kron(mean(q_U), mean(q_V))) end From 6b48cf8448fbdd60ecbc62094a5f7e0337151987 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 12 Nov 2023 16:16:01 +0100 Subject: [PATCH 13/38] WIP: Add tests --- test/nodes/test_continuous_transition.jl | 49 +++++++++++++++++++ test/rules/continuous_transition/test_W.jl | 28 ++++++++++- test/rules/continuous_transition/test_a.jl | 27 +++++++++- .../continuous_transition/test_marginals.jl | 26 ++++------ test/rules/continuous_transition/test_x.jl | 28 +++++++++-- test/rules/continuous_transition/test_y.jl | 26 ++++++++-- test/runtests.jl | 1 + 7 files changed, 161 insertions(+), 24 deletions(-) create mode 100644 test/nodes/test_continuous_transition.jl diff --git a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl new file mode 100644 index 000000000..87b4ce14d --- /dev/null +++ b/test/nodes/test_continuous_transition.jl @@ -0,0 +1,49 
@@ +module ContinuousTransitionNodeTest + +using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily + +import ReactiveMP: getdimensionality, getmasks, ctcompanion_matrix, getunits + +@testset "ContinuousTransitionNode" begin + @testset "Creation" begin + meta = CTMeta((a) -> reshape(a, 2, 3), 6) # Example transformation function and vector length + node = make_node(ContinuousTransition, FactorNodeCreationOptions(nothing, meta, nothing)) + + @test functionalform(node) === ContinuousTransition + @test sdtype(node) === Stochastic() + @test name.(interfaces(node)) === (:y, :x, :a, :W) + @test factorisation(node) === ((1, 2, 3, 4),) + @test getdimensionality(metadata(node)) == (2, 3) # Based on the transformation function dimensions + end + + @testset "AverageEnergy" begin + # This is an example setup, you'll need to adjust the distributions and marginals according to your needs + q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)) + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) # Adjust the dimension according to `a` + q_W = Wishart(3, diageye(2)) # Adjust the degrees of freedom and scale matrix as needed + + marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 #ExpectedValue + end + + @testset "CTransitionFunctionality" begin + a = rand(6) # Example vector `a` + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + A = ctcompanion_matrix(a, meta) + + @test size(A) == (2, 3) + @test A == reshape(a, 2, 3) # This is based on the transformation function provided in meta + end + + @testset "MetadataFunctionality" begin + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test getdimensionality(meta) == (2, 3) + @test length(getmasks(meta, rand(6))) == 2 # Based on `dy` + @test length(getunits(meta)) == 2 # Based on `dy` + end +end 
+ +end diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 3b588591f..68df9c9a9 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -1 +1,27 @@ -@call_rule ContinuousTransition(:Λ, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_h = MvNormalMeanCovariance(randn(6), diageye(6)), meta = CTMeta(2, 3)) +module RulesContinuousTransitionTest + +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + +import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits, WishartFast + +@testset "rules:ContinuousTransition:W" begin + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta)" begin + # Example transformation function and vector length for CTMeta + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [ + ( + input = ( + q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)), # Adjust dimensions as needed + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), + meta = meta + ), + output = WishartFast(4, 4 * diageye(2)) + ) + # Additional test cases with different distributions and metadata settings + ] + end + # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here +end + +end diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index bdda2fa3d..98a9c6f6b 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -1 +1,26 @@ -@call_rule ContinuousTransition(:h, Marginalisation) (q_y_x = MvNormalMeanCovariance(randn(5), diageye(5)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3)) +module RulesContinuousTransitionTestA + +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + +import ReactiveMP: 
@test_rules, getmasks, getunits + +@testset "rules:ContinuousTransition:a" begin + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + # Example transformation function and vector length for CTMeta + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [ + ( + input = ( + q_y_x = MvNormalMeanCovariance([zeros(2); zeros(3)], [diageye(2) zeros(2, 3); zeros(3, 2) diageye(3)]), + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), + q_W = Wishart(3, diageye(2)), + meta = meta + ), + output = MvNormalWeightedMeanPrecision(zeros(6), 3 * diageye(6)) + ) + ] + end +end + +end diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl index b00ac22ff..d8a023432 100644 --- a/test/rules/continuous_transition/test_marginals.jl +++ b/test/rules/continuous_transition/test_marginals.jl @@ -1,27 +1,21 @@ module RulesContinuousTransitionTest -using Test -using ReactiveMP -using Random -using LinearAlgebra -using Distributions - +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra import ReactiveMP: @test_marginalrules -# @call_marginalrule ContinuousTransition(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(2), diageye(2)), q_h = MvNormalMeanPrecision(ones(4), diageye(4)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 2)) -# @call_marginalrule ContinuousTransition(:y_x) (m_y = MvNormalMeanPrecision(ones(2), diageye(2)), m_x = MvNormalMeanPrecision(ones(3), diageye(3)), q_h = MvNormalMeanPrecision(ones(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3)) - @testset "marginalrules:ContinuousTransition" begin - @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any)" begin - @test_marginalrules 
[with_float_conversions = true] ContinuousTransition(:y_x) [( + @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test_marginalrules [check_type_promotion = true] ContinuousTransition(:y_x) [( input = ( m_y = MvNormalMeanPrecision(ones(2), diageye(2)), - m_x = MvNormalMeanPrecision(ones(2), diageye(2)), - q_h = MvNormalMeanPrecision(ones(4), diageye(4)), - q_Λ = Wishart(2, diageye(2)), - meta = CTMeta(2, 2) + m_x = MvNormalMeanPrecision(ones(3), diageye(3)), + q_a = MvNormalMeanPrecision(ones(6), diageye(6)), + q_W = Wishart(2, diageye(2)), + meta = meta ), - output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0]) + output = MvNormalWeightedMeanPrecision(zeros(5), diageye(5)) )] end end diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index e67d8f4d3..ccd4ac03f 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -1,3 +1,25 @@ -@call_rule ContinuousTransition(:x, Marginalisation) ( - m_y = MvNormalMeanPrecision(randn(2), diageye(2)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3) -) +module RulesContinuousTransitionTest + +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + +import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits + +@testset "rules:ContinuousTransition:x" begin + @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + # Example transformation function and vector length for CTMeta + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [ + ( + input = (m_y = MvNormalMeanCovariance(zeros(2), diageye(2)), q_a = MvNormalMeanCovariance(zeros(6), 
diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), + output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) + ) + # Additional test cases with different distributions and metadata settings + # Each case should represent a realistic scenario for your application + ] + end + + # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here +end + +end diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index 8f146ad4f..81b4cdd48 100644 --- a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -1,3 +1,23 @@ -@call_rule ContinuousTransition(:y, Marginalisation) ( - m_x = MvNormalMeanPrecision(randn(3), diageye(3)), q_h = MvNormalMeanPrecision(randn(6), diageye(6)), q_Λ = Wishart(2, diageye(2)), meta = CTMeta(2, 3) -) +module RulesContinuousTransitionTest + +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + +import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits + +@testset "rules:ContinuousTransition:y" begin + @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + # Example transformation function and vector length for CTMeta + meta = CTMeta((a) -> reshape(a, 2, 3), 6) + + @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [ + ( + input = (m_x = MvNormalMeanCovariance(zeros(3), diageye(3)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), + output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) + ) + ] + end + + # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 60dedd7d0..b5191a582 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -251,6 +251,7 @@ end addtests(testrunner, "nodes/test_uniform.jl") 
addtests(testrunner, "nodes/test_normal_mixture.jl") addtests(testrunner, "nodes/test_softdot.jl") + addtests(testrunner, "nodes/test_continuous_transition.jl") addtests(testrunner, "rules/uniform/test_out.jl") From 0194e69114e1ef1890edeeb7a469285eea7fcc38 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sun, 12 Nov 2023 16:57:56 +0100 Subject: [PATCH 14/38] WIP: Format tests --- test/rules/continuous_transition/test_W.jl | 20 +++++++++----------- test/rules/continuous_transition/test_a.jl | 20 +++++++++----------- test/rules/continuous_transition/test_x.jl | 15 +++++++-------- test/rules/continuous_transition/test_y.jl | 10 ++++------ 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 68df9c9a9..b3f2380a3 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -9,17 +9,15 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits, WishartF # Example transformation function and vector length for CTMeta meta = CTMeta((a) -> reshape(a, 2, 3), 6) - @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [ - ( - input = ( - q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)), # Adjust dimensions as needed - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), - meta = meta - ), - output = WishartFast(4, 4 * diageye(2)) - ) - # Additional test cases with different distributions and metadata settings - ] + @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [( + input = ( + q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)), # Adjust dimensions as needed + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), + meta = meta + ), output = WishartFast(4, 4 * diageye(2)) + ) + # Additional test cases with different distributions and metadata settings +] end # Additional tests for edge cases, errors, or specific behaviors of the rule can be 
added here end diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index 98a9c6f6b..43118dd7e 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -9,17 +9,15 @@ import ReactiveMP: @test_rules, getmasks, getunits # Example transformation function and vector length for CTMeta meta = CTMeta((a) -> reshape(a, 2, 3), 6) - @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [ - ( - input = ( - q_y_x = MvNormalMeanCovariance([zeros(2); zeros(3)], [diageye(2) zeros(2, 3); zeros(3, 2) diageye(3)]), - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), - q_W = Wishart(3, diageye(2)), - meta = meta - ), - output = MvNormalWeightedMeanPrecision(zeros(6), 3 * diageye(6)) - ) - ] + @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [( + input = ( + q_y_x = MvNormalMeanCovariance([zeros(2); zeros(3)], [diageye(2) zeros(2, 3); zeros(3, 2) diageye(3)]), + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), + q_W = Wishart(3, diageye(2)), + meta = meta + ), + output = MvNormalWeightedMeanPrecision(zeros(6), 3 * diageye(6)) + )] end end diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index ccd4ac03f..f06434804 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -9,14 +9,13 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits # Example transformation function and vector length for CTMeta meta = CTMeta((a) -> reshape(a, 2, 3), 6) - @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [ - ( - input = (m_y = MvNormalMeanCovariance(zeros(2), diageye(2)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), - output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) - ) - # Additional test cases with different 
distributions and metadata settings - # Each case should represent a realistic scenario for your application - ] + @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [( + input = (m_y = MvNormalMeanCovariance(zeros(2), diageye(2)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), + output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) + ) + # Additional test cases with different distributions and metadata settings + # Each case should represent a realistic scenario for your application +] end # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index 81b4cdd48..6933812d4 100644 --- a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -9,12 +9,10 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits # Example transformation function and vector length for CTMeta meta = CTMeta((a) -> reshape(a, 2, 3), 6) - @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [ - ( - input = (m_x = MvNormalMeanCovariance(zeros(3), diageye(3)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), - output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) - ) - ] + @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [( + input = (m_x = MvNormalMeanCovariance(zeros(3), diageye(3)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), + output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) + )] end # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here From df158cb37a862df91661e848bc52d19251f58c63 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 17:24:25 +0100 Subject: [PATCH 
15/38] Add tests for a and W --- src/nodes/continuous_transition.jl | 61 +++++++------ src/rules/continuous_transition/W.jl | 11 ++- src/rules/continuous_transition/a.jl | 21 ++--- src/rules/continuous_transition/marginals.jl | 12 +-- src/rules/continuous_transition/x.jl | 4 +- src/rules/continuous_transition/y.jl | 4 +- test/nodes/test_continuous_transition.jl | 35 ++++---- test/rules/continuous_transition/test_W.jl | 90 ++++++++++++++++---- test/rules/continuous_transition/test_a.jl | 86 ++++++++++++++++--- test/rules/continuous_transition/test_x.jl | 2 +- test/rules/continuous_transition/test_y.jl | 2 +- 11 files changed, 218 insertions(+), 110 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index b7a58d654..3a8499c13 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -4,20 +4,20 @@ import LazyArrays import StatsFuns: log2π @doc raw""" -The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a `nm`-dimensional vector `a`. +The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`. -To construct the matrix A, the elements of `a` are filled into A according to the transformation function provided with meta. +To construct the matrix A, the elements of `a` are filled into A according to the transformation function provided with meta. `a` must be of MultivariateNormalDistributionsFamily type. If you intend to use univariate Gaussian, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`. Check CTMeta for more details on how to specify the transformation function that **must** return a matrix. 
```julia -y ~ ContinuousTransition(x, a, W) +y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation, â)} ``` Interfaces: 1. y - n-dimensional output of the ContinuousTransition node. 2. x - m-dimensional input of the ContinuousTransition node. -3. a - `nm`-dimensional vector that casts into the matrix `A`. -4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing, as belief-propagation is not feasible for `y = Ax`. +3. a - any-dimensional vector that casts into the matrix `A`. +4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing. Note that you can set W to a fixed value or put a prior on it to control the amount of jitter. """ @@ -47,35 +47,25 @@ The `CTMeta` struct plays a pivotal role in defining how the vector `a` is trans """ struct CTMeta ds::Tuple # dimensionality of ContinuousTransition (dy, dx) - Fs::Union{Vector{<:AbstractMatrix}, <:Function} # masks + f::Function # transformation function es::Vector{<:AbstractVector} # unit vectors - # NOTE: this meta is not user-friendly, I don't like supplying the length of a vector + # NOTE: this meta is not user-friendly, I don't like a vector # perhaps making mutable struct with empty meta first will be better from user perspective - # meta for linear transformation of a vector to a matrix - function CTMeta(transformation::Function, len::Integer) - dy, dx = size(transformation(zeros(len))) - Fs = [ForwardDiff.jacobian(a -> transformation(a)[i, :], 1:len) for i in 1:dy] - es = [StandardBasisVector(dy, i, 1.0) for i in 1:dy] - return new((dy, dx), Fs, es) - end - - # meta for nonlinear transformation of a vector to a matrix + # meta for transformation of a vector to a matrix function CTMeta(transformation::Function, â::Vector{<:Real}) dy, dx = size(transformation(â)) - es = [StandardBasisVector(dy, i, 1.0) for i in 1:dy] + es = [StandardBasisVector(dy, i, one(eltype(first(â)))) for i in 
1:dy] return new((dy, dx), transformation, es) end end getunits(meta::CTMeta) = meta.es getdimensionality(meta::CTMeta) = meta.ds +gettransformation(meta::CTMeta) = meta.f -getmasks(ctmeta::CTMeta, a) = process_Fs(ctmeta.Fs, a) -process_Fs(Fs::Vector{<:AbstractMatrix}, a) = Fs - -# NOTE: this doesn't seem to be the right way of working with nonlinar approximation -process_Fs(Fs::Function, a) = [ForwardDiff.jacobian(a -> Fs(a)[i, :], a) for i in 1:size(Fs(a), 1)] +getjacobians(ctmeta::CTMeta, a) = process_Fs(gettransformation(ctmeta), a) +process_Fs(f::Function, a) = [ForwardDiff.jacobian(a -> f(a)[i, :], a) for i in 1:size(f(a), 1)] @node ContinuousTransition Stochastic [y, x, a, W] @@ -83,31 +73,38 @@ default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta fl default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing,)) -function ctcompanion_matrix(a, meta::CTMeta) - Fs, es = getmasks(meta, a), getunits(meta) +""" + `ctcompanion_matrix` casts a vector `a` into a matrix `A` by means of linearization of the transformation function `f` around the expansion point `a0`. 
+""" +function ctcompanion_matrix(a, epsilon, meta::CTMeta) + a0 = a .+ epsilon # expansion point + Js, es = getjacobians(meta, a0), getunits(meta) + f = gettransformation(meta) dy, _ = getdimensionality(meta) - A = sum(es[i] * a' * Fs[i]' for i in 1:dy) + # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows + A = sum(es[i] * (f(a0)[i, :] + Js[i] * (a - a0))' for i in 1:dy) return A end -@average_energy ContinuousTransition (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Wishart, meta::CTMeta) = begin +@average_energy ContinuousTransition (q_y_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) myx, Vyx = mean_cov(q_y_x) mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta, ma), getunits(meta) + Fs, es = getjacobians(meta, ma), getunits(meta) n = div(ndims(q_y_x), 2) - mA = ctcompanion_matrix(ma, meta) - mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] - my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] - Vyx = Vyx[1:dy, (dy + 1):end] + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + + mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = @view Vyx[1:dy, (dy + 1):end] g₁ = my' * mW * my + tr(Vy * mW) g₂ = mx' * mA' * mW * my + tr(Vyx * mA' * mW) g₃ = g₂ G = sum(sum(es[i]' * mW * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) g₄ = mx' * G * mx + tr(Vx * G) - AE = n / 2 * log2π - 0.5 * mean(logdet, q_W) + 0.5 * (g₁ - g₂ - g₃ + g₄) + AE = n / 2 * log2π - (mean(logdet, q_W) - (g₁ - g₂ - g₃ + g₄)) / 2 return AE end diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl index 842822ebf..785ae41ee 100644 --- a/src/rules/continuous_transition/W.jl +++ b/src/rules/continuous_transition/W.jl @@ -12,17 +12,16 @@ end dy, dx = getdimensionality(meta) ma, Va = mean_cov(q_a) - Fs, es = 
getmasks(meta, ma), getunits(meta) + Fs, es = getjacobians(meta, ma), getunits(meta) - mA = ctcompanion_matrix(ma, meta) + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) myx, Vyx = mean_cov(q_y_x) - mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] - my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] - Vyx = Vyx[1:dy, (dy + 1):end] + mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = @views Vyx[1:dy, (dy + 1):end] Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs, es) - # NOTE: WishartFast stores inverse of scale matrix return WishartFast(dy + 2, Δ) end diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index e021a1d4d..f2055ea9c 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -1,19 +1,20 @@ -@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin - dy, dx = getdimensionality(meta) +@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma = mean(q_a) - Fs, es = getmasks(meta, ma), getunits(meta) - + mW = mean(q_W) myx, Vyx = mean_cov(q_y_x) - mx, Vx = myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] - my, Vy = myx[1:dy], Vyx[1:dy, 1:dy] - Vyx = Vyx[1:dy, (dy + 1):end] + dy, dx = getdimensionality(meta) - mW = mean(q_W) + mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] + my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] + Vyx = @view Vyx[1:dy, (dy + 1):end] + + Fs, es = getjacobians(meta, ma), getunits(meta) - D = sum(sum(es[i]' * mW * es[j] * Fs[i]' * (mx * mx' + Vx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) - z = sum(Fs[i]' * (mx * my' + Vyx') * mW * es[i] for i in 1:length(Fs)) + # rank1update(Vyx, mx, my) equivalent to ξ = (Vyx + mx * my') + D = sum(sum(es[j]' * mW 
* es[i] * Fs[i]' * rank1update(Vx, mx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) + z = sum(Fs[i]' * rank1update(Vyx', mx, my) * mW * es[i] for i in 1:length(Fs)) return MvNormalWeightedMeanPrecision(z, D) end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index bc5d7a09a..94ddea21e 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -1,20 +1,16 @@ -@marginalrule ContinuousTransition(:y_x) ( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta -) = begin +@marginalrule ContinuousTransition(:y_x) (m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta) = begin return continuous_tranition_marginal(m_y, m_x, q_a, q_W, meta) end -function continuous_tranition_marginal( - m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta -) +function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta) ma, Va = mean_cov(q_a) - Fs, es = getmasks(meta, ma), getunits(meta) + Fs, es = getjacobians(meta, ma), getunits(meta) mW = mean(q_W) - mA = ctcompanion_matrix(ma, meta) + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) b_my, b_Vy = mean_cov(m_y) f_mx, f_Vx = mean_cov(m_x) diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 4add0bf65..edabd060f 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -5,9 +5,9 @@ mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta, ma), getunits(meta) + Fs, es = getjacobians(meta, ma), getunits(meta) - mA = ctcompanion_matrix(ma, meta) + mA = 
ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index ee3b8fdaf..f34c8f6ce 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -5,9 +5,9 @@ mW = mean(q_W) dy, dx = getdimensionality(meta) - Fs, es = getmasks(meta, ma), getunits(meta) + Fs, es = getjacobians(meta, ma), getunits(meta) - mA = ctcompanion_matrix(ma, meta) + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) diff --git a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl index 87b4ce14d..943582f72 100644 --- a/test/nodes/test_continuous_transition.jl +++ b/test/nodes/test_continuous_transition.jl @@ -2,47 +2,46 @@ module ContinuousTransitionNodeTest using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily -import ReactiveMP: getdimensionality, getmasks, ctcompanion_matrix, getunits +import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, ctcompanion_matrix @testset "ContinuousTransitionNode" begin + + dy, dx = 2, 3 + a0 = rand(dx*dy) # Example vector `a0` + meta = CTMeta(a -> reshape(a, dy, dx), a0) @testset "Creation" begin - meta = CTMeta((a) -> reshape(a, 2, 3), 6) # Example transformation function and vector length node = make_node(ContinuousTransition, FactorNodeCreationOptions(nothing, meta, nothing)) @test functionalform(node) === ContinuousTransition @test sdtype(node) === Stochastic() @test name.(interfaces(node)) === (:y, :x, :a, :W) @test factorisation(node) === ((1, 2, 3, 4),) - @test getdimensionality(metadata(node)) == (2, 3) # Based on the transformation function dimensions + @test getdimensionality(metadata(node)) == (dy, dx) # Based on the transformation function dimensions end @testset 
"AverageEnergy" begin # This is an example setup, you'll need to adjust the distributions and marginals according to your needs q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)) q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) # Adjust the dimension according to `a` - q_W = Wishart(3, diageye(2)) # Adjust the degrees of freedom and scale matrix as needed + q_W = Wishart(3, diageye(2)) marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 #ExpectedValue + @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 + @show getjacobians(meta, a0) end - @testset "CTransitionFunctionality" begin - a = rand(6) # Example vector `a` - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - A = ctcompanion_matrix(a, meta) + @testset "CTransition Functionality" begin + A = ctcompanion_matrix(a0, zeros(length(a0)), meta) - @test size(A) == (2, 3) - @test A == reshape(a, 2, 3) # This is based on the transformation function provided in meta + @test size(A) == (dy, dx) + @test A == gettransformation(meta)(a0) end - @testset "MetadataFunctionality" begin - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test getdimensionality(meta) == (2, 3) - @test length(getmasks(meta, rand(6))) == 2 # Based on `dy` - @test length(getunits(meta)) == 2 # Based on `dy` + @testset "Metadata Functionality" begin + @test getdimensionality(meta) == (dy, dx) + @test length(getjacobians(meta, a0)) == dy # Based on `dy` + @test length(getunits(meta)) == dy # Based on `dy` end end diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index b3f2380a3..1d2f0383a 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -1,25 +1,83 @@ 
module RulesContinuousTransitionTest -using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra -import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits, WishartFast +import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, WishartFast @testset "rules:ContinuousTransition:W" begin - @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta)" begin - # Example transformation function and vector length for CTMeta - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [( - input = ( - q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)), # Adjust dimensions as needed - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), - meta = meta - ), output = WishartFast(4, 4 * diageye(2)) - ) - # Additional test cases with different distributions and metadata settings -] + rng = MersenneTwister(42) + + @testset "Linear transformation" begin + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule(q_y_x, mA, ΣA, UA) + myx, Vyx = mean_cov(q_y_x) + + dy = size(mA, 1) + Vx = Vyx[(dy + 1):end, (dy + 1):end] + Vy = Vyx[1:dy, 1:dy] + mx = myx[(dy + 1):end] + my = myx[1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + + G = tr(Vx*UA)*ΣA + mA*Vx*mA' - mA*Vyx' - Vyx*mA' + Vy + ΣA*mx'*UA*mx + (mA*mx-my)*(mA*mx-my)' + + return WishartFast(dy + 2, G) + end + + + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta)" begin + + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + a0 = rand(dydx) + metal = CTMeta(transformation, a0) + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + μx, Σx = rand(rng, dx), Lx * Lx' + 
μy, Σy = rand(rng, dy), Ly * Ly' + + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + + qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) + + @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [( + input = ( + q_y_x = qyx, + q_a = qa, + meta = metal + ), output = benchmark_rule(qyx, mA, ΣA, UA) + ) + ] + end + end end - # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here + + @testset "Nonlinear transformation" begin + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin + + dy, dx = 2, 2 + dydx = dy * dy + transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + a0 = zeros(1) + metanl = CTMeta(transformation, a0) + μx, Σx = zeros(dx), diageye(dx) + μy, Σy = zeros(dy), diageye(dy) + + qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + qa = MvNormalMeanCovariance(a0, diageye(1)) + @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [( + input = ( + q_y_x = qyx, + q_a = qa, + meta = metanl + ), + output = WishartFast(dy+2, dy*diageye(dy)) + )] + end + end + end end diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index 43118dd7e..34039c1d7 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -2,23 +2,81 @@ module RulesContinuousTransitionTestA using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions -import ReactiveMP: @test_rules, getmasks, getunits +import ReactiveMP: @test_rules, getjacobians, getunits @testset "rules:ContinuousTransition:a" begin - @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin - # Example transformation function and vector length for 
CTMeta - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [( - input = ( - q_y_x = MvNormalMeanCovariance([zeros(2); zeros(3)], [diageye(2) zeros(2, 3); zeros(3, 2) diageye(3)]), - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), - q_W = Wishart(3, diageye(2)), - meta = meta - ), - output = MvNormalWeightedMeanPrecision(zeros(6), 3 * diageye(6)) - )] + rng = MersenneTwister(42) + + @testset "Linear transformation" begin + + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule(q_y_x, q_W) + myx, Vyx = mean_cov(q_y_x) + dy = size(q_W.S, 1) + Vx = Vyx[(dy + 1):end, (dy + 1):end] + mx = myx[(dy + 1):end] + my = myx[1:dy] + Vyx = Vyx[1:dy, (dy + 1):end] + mW = mean(q_W) + Λ = kron(Vx + mx*mx', mW) + return MvNormalWeightedMeanPrecision(Λ*vec((Vyx + my*mx')*inv((Vx + mx*mx'))), Λ) + end + + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + a0 = rand(dydx) + metal = CTMeta(transformation, a0) + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + μx, Σx = rand(rng, dx), Lx * Lx' + μy, Σy = rand(rng, dy), Ly * Ly' + + qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qW = Wishart(dy+1, diageye(dy)) + @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [( + input = ( + q_y_x = qyx, + q_a = qa, + q_W = qW, + meta = metal + ), + output = benchmark_rule(qyx, qW) + )] + end + end + end + + @testset "Nonlinear transformation" begin + + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin + + dy, dx = 2, 2 + dydx = dy * dy + transformation = (a) -> 
[cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + a0 = zeros(1) + metanl = CTMeta(transformation, a0) + μx, Σx = ones(dx), diageye(dx) + μy, Σy = ones(dy), diageye(dy) + + qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + qa = MvNormalMeanCovariance(a0, diageye(1)) + qW = Wishart(dy, diageye(dy)) + @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [( + input = ( + q_y_x = qyx, + q_a = qa, + q_W = qW, + meta = metanl + ), + output = MvNormalWeightedMeanPrecision(zeros(1), (qW.df*dy*dx)*diageye(1)) + )] + end end + end end diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index f06434804..f058aaf17 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -2,7 +2,7 @@ module RulesContinuousTransitionTest using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions -import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits +import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits @testset "rules:ContinuousTransition:x" begin @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index 6933812d4..f3d79edc2 100644 --- a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -2,7 +2,7 @@ module RulesContinuousTransitionTest using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions -import ReactiveMP: @test_rules, ctcompanion_matrix, getmasks, getunits +import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits @testset "rules:ContinuousTransition:y" begin @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin From 
ce18c07f42a7ff7563fc7a2cf1aa6e5f91d9608d Mon Sep 17 00:00:00 2001 From: bvdmitri Date: Fri, 8 Dec 2023 18:14:47 +0100 Subject: [PATCH 16/38] fix tests --- test/rules/continuous_transition/test_W.jl | 4 ++-- test/rules/continuous_transition/test_a.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 1d2f0383a..3bd02fd12 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -31,7 +31,7 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) - a0 = rand(dydx) + a0 = rand(Float32, dydx) metal = CTMeta(transformation, a0) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' @@ -60,7 +60,7 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(1) + a0 = zeros(Int, 1) metanl = CTMeta(transformation, a0) μx, Σx = zeros(dx), diageye(dx) μy, Σy = zeros(dy), diageye(dy) diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index 34039c1d7..864595160 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -57,7 +57,7 @@ import ReactiveMP: @test_rules, getjacobians, getunits dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(1) + a0 = zeros(Int, 1) metanl = CTMeta(transformation, a0) μx, Σx = ones(dx), diageye(dx) μy, Σy = ones(dy), diageye(dy) From 6b4ffae0b5f110229d9f623247f196f15c39c442 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 18:27:49 +0100 Subject: [PATCH 17/38] Update test W --- test/rules/continuous_transition/test_W.jl | 6 
++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 1d2f0383a..401482d27 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -31,13 +31,15 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) - a0 = rand(dydx) + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + + a0 = vec(mA) + metal = CTMeta(transformation, a0) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' μy, Σy = rand(rng, dy), Ly * Ly' - mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) From 467b2db69a4a7b410709f709a9ba395444adb1ae Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 18:50:06 +0100 Subject: [PATCH 18/38] Update rules --- src/nodes/continuous_transition.jl | 2 - src/rules/continuous_transition/x.jl | 6 +- test/rules/continuous_transition/test_W.jl | 2 +- test/rules/continuous_transition/test_a.jl | 2 +- test/rules/continuous_transition/test_x.jl | 77 ++++++++++++++++++---- 5 files changed, 68 insertions(+), 21 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 3a8499c13..d5134e684 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -67,8 +67,6 @@ gettransformation(meta::CTMeta) = meta.f getjacobians(ctmeta::CTMeta, a) = process_Fs(gettransformation(ctmeta), a) process_Fs(f::Function, a) = [ForwardDiff.jacobian(a -> f(a)[i, :], a) for i in 1:size(f(a), 1)] -@node ContinuousTransition Stochastic [y, x, a, W] - default_meta(::Type{CTMeta}) = error("ContinuousTransition node requires meta flag explicitly specified") 
default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing,)) diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index edabd060f..3e81ee70c 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -11,10 +11,8 @@ W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - Σ₁ = Hermitian(pinv(mA) * (Vy) * pinv(mA') + pinv(mA' * mW * mA)) - - Ξ = (pinv(Σ₁) + W) - z = pinv(Σ₁) * pinv(mA) * my + z = mA'* inv(Vy + inv(mW)) * my + Ξ = mA'* inv(Vy + inv(mW)) * mA + W return MvNormalWeightedMeanPrecision(z, Ξ) end diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 38f897e76..cf89b36ec 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -33,7 +33,7 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish transformation = (a) -> reshape(a, dy, dx) mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) - a0 = vec(mA) + a0 = Float32.(vec(mA)) metal = CTMeta(transformation, a0) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index 864595160..2204c5c26 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -28,7 +28,7 @@ import ReactiveMP: @test_rules, getjacobians, getunits for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) - a0 = rand(dydx) + a0 = rand(Float32, dydx) metal = CTMeta(transformation, a0) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index f058aaf17..428ae8fdb 100644 --- 
a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -1,24 +1,75 @@ module RulesContinuousTransitionTest -using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions +using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits @testset "rules:ContinuousTransition:x" begin - @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin - # Example transformation function and vector length for CTMeta - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [( - input = (m_y = MvNormalMeanCovariance(zeros(2), diageye(2)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), - output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) - ) - # Additional test cases with different distributions and metadata settings - # Each case should represent a realistic scenario for your application -] + + rng = MersenneTwister(42) + + @testset "Linear transformation" begin + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule(q_y, q_W, mA, ΣA, UA) + my, Vy = mean_cov(q_y) + + mW = mean(q_W) + + Λ = tr(mW*ΣA)*UA + mA'*inv(Vy + inv(mW))*mA + ξ = mA'*inv(Vy + inv(mW))*my + return MvNormalWeightedMeanPrecision(ξ, Λ) + end + + @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + + a0 = Float32.(vec(mA)) + metal = CTMeta(transformation, a0) + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + 
μy, Σy = rand(rng, dy), Ly * Ly' + + qy = MvNormalMeanCovariance(μy, Σy) + qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qW = Wishart(dy+1, diageye(dy)) + + @test_rules [check_type_promotion = false] ContinuousTransition(:x, Marginalisation) [( + input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), + output = benchmark_rule(qy, qW, mA, ΣA, UA) + ) + # Additional test cases with different distributions and metadata settings + # Each case should represent a realistic scenario for your application + ] + end + end + end + + @testset "Nonlinear transformation" begin + @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin + + dy, dx = 2, 2 + dydx = dy * dy + transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + a0 = zeros(Float32, 1) + metanl = CTMeta(transformation, a0) + μy, Σy = zeros(dy), diageye(dy) + + qy = MvNormalMeanCovariance(μy, Σy) + qa = MvNormalMeanCovariance(a0, tiny*diageye(1)) + qW = Wishart(dy+1, diageye(dy)) + + @test_rules [check_type_promotion = false] ContinuousTransition(:x, Marginalisation) [( + input = (m_y = qy, q_a = qa, q_W = qW, meta = metanl), + output = MvGaussianWeightedMeanPrecision(zeros(dx), 3/4*diageye(dx)) + ) + ] + end end - # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here end end From 382ab2d2a95f8c7b1062fdd26167c6b9920097f2 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 19:14:11 +0100 Subject: [PATCH 19/38] Add tests for x and y --- src/rules/continuous_transition/a.jl | 1 - src/rules/continuous_transition/x.jl | 4 +- src/rules/continuous_transition/y.jl | 11 ++-- test/nodes/test_continuous_transition.jl | 9 ++-- test/rules/continuous_transition/test_W.jl | 27 +++------- test/rules/continuous_transition/test_a.jl | 40 +++++---------- test/rules/continuous_transition/test_x.jl | 32 +++++------- test/rules/continuous_transition/test_y.jl | 58 +++++++++++++++++++--- 8 files changed, 90 
insertions(+), 92 deletions(-) diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index f2055ea9c..5992bf0e7 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -1,5 +1,4 @@ @rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin - ma = mean(q_a) mW = mean(q_W) myx, Vyx = mean_cov(q_y_x) diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 3e81ee70c..31f39ae1c 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -11,8 +11,8 @@ W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - z = mA'* inv(Vy + inv(mW)) * my - Ξ = mA'* inv(Vy + inv(mW)) * mA + W + z = mA' * inv(Vy + inv(mW)) * my + Ξ = mA' * inv(Vy + inv(mW)) * mA + W return MvNormalWeightedMeanPrecision(z, Ξ) end diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index f34c8f6ce..b812c50fd 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -1,6 +1,6 @@ @rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) - mx, Wx = mean_invcov(m_x) + mx, Vx = mean_cov(m_x) mW = mean(q_W) @@ -9,13 +9,8 @@ mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) - W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - - Ξ = W + Wx - z = Wx * mx - - Vy = mA * inv(Ξ) * mA' + inv(mW) - my = mA * inv(Ξ) * z + Vy = mA * Vx * mA' + inv(mW) + my = mA * mx return MvNormalMeanCovariance(my, Vy) end diff --git a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl index 943582f72..3dfceb31c 100644 --- 
a/test/nodes/test_continuous_transition.jl +++ b/test/nodes/test_continuous_transition.jl @@ -5,9 +5,8 @@ using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, ctcompanion_matrix @testset "ContinuousTransitionNode" begin - dy, dx = 2, 3 - a0 = rand(dx*dy) # Example vector `a0` + a0 = rand(dx * dy) # Example vector `a0` meta = CTMeta(a -> reshape(a, dy, dx), a0) @testset "Creation" begin node = make_node(ContinuousTransition, FactorNodeCreationOptions(nothing, meta, nothing)) @@ -23,11 +22,11 @@ import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, # This is an example setup, you'll need to adjust the distributions and marginals according to your needs q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)) q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) # Adjust the dimension according to `a` - q_W = Wishart(3, diageye(2)) + q_W = Wishart(3, diageye(2)) marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) - @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 + @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 @show getjacobians(meta, a0) end @@ -35,7 +34,7 @@ import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, A = ctcompanion_matrix(a0, zeros(length(a0)), meta) @test size(A) == (dy, dx) - @test A == gettransformation(meta)(a0) + @test A == gettransformation(meta)(a0) end @testset "Metadata Functionality" begin diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index cf89b36ec..71d0ff77c 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -15,19 +15,17 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, 
getjacobians, getunits, Wish dy = size(mA, 1) Vx = Vyx[(dy + 1):end, (dy + 1):end] - Vy = Vyx[1:dy, 1:dy] + Vy = Vyx[1:dy, 1:dy] mx = myx[(dy + 1):end] my = myx[1:dy] Vyx = Vyx[1:dy, (dy + 1):end] - G = tr(Vx*UA)*ΣA + mA*Vx*mA' - mA*Vyx' - Vyx*mA' + Vy + ΣA*mx'*UA*mx + (mA*mx-my)*(mA*mx-my)' + G = tr(Vx * UA) * ΣA + mA * Vx * mA' - mA * Vyx' - Vyx * mA' + Vy + ΣA * mx' * UA * mx + (mA * mx - my) * (mA * mx - my)' return WishartFast(dy + 2, G) end - @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta)" begin - for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) @@ -39,47 +37,34 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' μy, Σy = rand(rng, dy), Ly * Ly' - qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [( - input = ( - q_y_x = qyx, - q_a = qa, - meta = metal - ), output = benchmark_rule(qyx, mA, ΣA, UA) - ) - ] + input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule(qyx, mA, ΣA, UA) + )] end end end @testset "Nonlinear transformation" begin @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin - dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] a0 = zeros(Int, 1) metanl = CTMeta(transformation, a0) - μx, Σx = zeros(dx), diageye(dx) + μx, Σx = zeros(dx), diageye(dx) μy, Σy = zeros(dy), diageye(dy) qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) qa = MvNormalMeanCovariance(a0, diageye(1)) @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [( - input = ( - q_y_x = qyx, - 
q_a = qa, - meta = metanl - ), - output = WishartFast(dy+2, dy*diageye(dy)) + input = (q_y_x = qyx, q_a = qa, meta = metanl), output = WishartFast(dy + 2, dy * diageye(dy)) )] end end - end end diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index 2204c5c26..acd9910ce 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -13,18 +13,17 @@ import ReactiveMP: @test_rules, getjacobians, getunits # It is derived separately by Thijs van de Laar function benchmark_rule(q_y_x, q_W) myx, Vyx = mean_cov(q_y_x) - dy = size(q_W.S, 1) - Vx = Vyx[(dy + 1):end, (dy + 1):end] - mx = myx[(dy + 1):end] - my = myx[1:dy] + dy = size(q_W.S, 1) + Vx = Vyx[(dy + 1):end, (dy + 1):end] + mx = myx[(dy + 1):end] + my = myx[1:dy] Vyx = Vyx[1:dy, (dy + 1):end] - mW = mean(q_W) - Λ = kron(Vx + mx*mx', mW) - return MvNormalWeightedMeanPrecision(Λ*vec((Vyx + my*mx')*inv((Vx + mx*mx'))), Λ) + mW = mean(q_W) + Λ = kron(Vx + mx * mx', mW) + return MvNormalWeightedMeanPrecision(Λ * vec((Vyx + my * mx') * inv((Vx + mx * mx'))), Λ) end @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin - for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) @@ -33,50 +32,35 @@ import ReactiveMP: @test_rules, getjacobians, getunits Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' μy, Σy = rand(rng, dy), Ly * Ly' - + qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) qa = MvNormalMeanCovariance(a0, diageye(dydx)) - qW = Wishart(dy+1, diageye(dy)) + qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [( - input = ( - q_y_x = qyx, - q_a = qa, - q_W = qW, - meta = metal - ), - output = benchmark_rule(qyx, qW) + input = (q_y_x = qyx, q_a = qa, q_W = qW, 
meta = metal), output = benchmark_rule(qyx, qW) )] end end end @testset "Nonlinear transformation" begin - @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin - dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] a0 = zeros(Int, 1) metanl = CTMeta(transformation, a0) - μx, Σx = ones(dx), diageye(dx) + μx, Σx = ones(dx), diageye(dx) μy, Σy = ones(dy), diageye(dy) qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) qa = MvNormalMeanCovariance(a0, diageye(1)) qW = Wishart(dy, diageye(dy)) @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [( - input = ( - q_y_x = qyx, - q_a = qa, - q_W = qW, - meta = metanl - ), - output = MvNormalWeightedMeanPrecision(zeros(1), (qW.df*dy*dx)*diageye(1)) + input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metanl), output = MvNormalWeightedMeanPrecision(zeros(1), (qW.df * dy * dx) * diageye(1)) )] end end - end end diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index 428ae8fdb..4746a67d3 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -5,7 +5,6 @@ using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, Lin import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits @testset "rules:ContinuousTransition:x" begin - rng = MersenneTwister(42) @testset "Linear transformation" begin @@ -16,8 +15,8 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits mW = mean(q_W) - Λ = tr(mW*ΣA)*UA + mA'*inv(Vy + inv(mW))*mA - ξ = mA'*inv(Vy + inv(mW))*my + Λ = tr(mW * ΣA) * UA + mA' * inv(Vy + inv(mW)) * mA + ξ = mA' * inv(Vy + inv(mW)) * my return MvNormalWeightedMeanPrecision(ξ, Λ) end @@ -32,25 +31,23 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits metal = CTMeta(transformation, a0) Lx, Ly = 
rand(rng, dx, dx), rand(rng, dy, dy) μy, Σy = rand(rng, dy), Ly * Ly' - + qy = MvNormalMeanCovariance(μy, Σy) qa = MvNormalMeanCovariance(a0, diageye(dydx)) - qW = Wishart(dy+1, diageye(dy)) - + qW = Wishart(dy + 1, diageye(dy)) + @test_rules [check_type_promotion = false] ContinuousTransition(:x, Marginalisation) [( - input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), - output = benchmark_rule(qy, qW, mA, ΣA, UA) + input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA) ) # Additional test cases with different distributions and metadata settings # Each case should represent a realistic scenario for your application - ] +] end end end @testset "Nonlinear transformation" begin - @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin - + @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] @@ -59,17 +56,14 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits μy, Σy = zeros(dy), diageye(dy) qy = MvNormalMeanCovariance(μy, Σy) - qa = MvNormalMeanCovariance(a0, tiny*diageye(1)) - qW = Wishart(dy+1, diageye(dy)) + qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qW = Wishart(dy + 1, diageye(dy)) - @test_rules [check_type_promotion = false] ContinuousTransition(:x, Marginalisation) [( - input = (m_y = qy, q_a = qa, q_W = qW, meta = metanl), - output = MvGaussianWeightedMeanPrecision(zeros(dx), 3/4*diageye(dx)) - ) - ] + @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [( + input = (m_y = qy, q_a = qa, q_W = qW, meta = metanl), output = MvGaussianWeightedMeanPrecision(zeros(dx), 3 / 4 * diageye(dx)) + )] end end - end end diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index f3d79edc2..06c71170a 100644 --- 
a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -5,14 +5,56 @@ using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits @testset "rules:ContinuousTransition:y" begin - @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin - # Example transformation function and vector length for CTMeta - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [( - input = (m_x = MvNormalMeanCovariance(zeros(3), diageye(3)), q_a = MvNormalMeanCovariance(zeros(6), diageye(6)), q_W = Wishart(3, diageye(2)), meta = meta), - output = MvNormalMeanCovariance(zeros(2), 1 / 3 * diageye(2)) - )] + rng = MersenneTwister(42) + + @testset "Linear transformation" begin + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule(q_x, q_W, mA) + mx, Vx = mean_cov(q_x) + mW = mean(q_W) + return MvNormalMeanCovariance(mA * mx, mA * Vx * mA' + inv(mW)) + end + + @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + + mA = rand(rng, dy, dx) + a0 = Float32.(vec(mA)) + metal = CTMeta(transformation, a0) + Lx = rand(rng, dx, dx) + μx, Σx = rand(rng, dx), Lx * Lx' + + qx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qW = Wishart(dy + 1, diageye(dy)) + + @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y, Marginalisation) [( + input = (m_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qx, qW, mA) + )] + end + end + end + + @testset "Nonlinear 
transformation" begin + @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin + dy, dx = 2, 2 + dydx = dy * dy + transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + a0 = zeros(Float32, 1) + metanl = CTMeta(transformation, a0) + μx, Σx = zeros(dx), diageye(dx) + + qx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qW = Wishart(dy + 1, diageye(dy)) + + @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [( + input = (m_x = qx, q_a = qa, q_W = qW, meta = metanl), output = MvGaussianMeanCovariance(zeros(dy), 4 / 3 * diageye(dy)) + )] + end end # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here From 5640921fd0cf25a08320a03c9b9e63ae47983e54 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 20:11:09 +0100 Subject: [PATCH 20/38] Update test marginals --- .../continuous_transition/test_marginals.jl | 76 +++++++++++++++---- 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl index d8a023432..10009fc93 100644 --- a/test/rules/continuous_transition/test_marginals.jl +++ b/test/rules/continuous_transition/test_marginals.jl @@ -4,19 +4,69 @@ using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, Lin import ReactiveMP: @test_marginalrules @testset "marginalrules:ContinuousTransition" begin - @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin - meta = CTMeta((a) -> reshape(a, 2, 3), 6) - - @test_marginalrules [check_type_promotion = true] ContinuousTransition(:y_x) [( - input = ( - m_y = MvNormalMeanPrecision(ones(2), diageye(2)), - m_x = MvNormalMeanPrecision(ones(3), diageye(3)), - q_a = MvNormalMeanPrecision(ones(6), diageye(6)), - q_W = Wishart(2, diageye(2)), - 
meta = meta - ), - output = MvNormalWeightedMeanPrecision(zeros(5), diageye(5)) - )] + rng = MersenneTwister(42) + @testset "Linear transformation" begin + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule(m_x, m_y, q_W, q_A) + mx, Wx = mean_invcov(m_x) + my, Wy = mean_invcov(m_y) + mW = mean(q_W) + mA, ΣA, UA = q_A.M, q_A.U, q_A.V + + U = Wx + tr(mW * ΣA) * UA + + Wq = [Wy+mW -mW*mA; -mA'*mW U+mA' * mW * mA] + return MvNormalWeightedMeanPrecision([Wy * my; Wx * mx], Wq) + end + @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + transformation = (a) -> reshape(a, dy, dx) + + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + qA = MatrixNormal(mA, ΣA, UA) + a0 = Float32.(vec(mA)) + + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + μx, Σx = rand(rng, dx), Lx * Lx' + μy, Σy = rand(rng, dy), Ly * Ly' + + my = MvNormalMeanCovariance(μy, Σy) + mx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) + qW = Wishart(dy + 1, diageye(dy)) + + metal = CTMeta(transformation, a0) + + @test_marginalrules [check_type_promotion = false, atol = 1e-5] ContinuousTransition(:y_x) [( + input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(mx, my, qW, qA) + )] + end + end + end + + @testset "Nonlinear transformation" begin + @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin + dy, dx = 2, 2 + transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] + + a0 = zeros(Int, 1) + + μx, Σx = zeros(dx), diageye(dx) + μy, Σy = zeros(dy), diageye(dy) + + my = MvNormalMeanCovariance(μy, Σy) + mx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qW = Wishart(dy, diageye(dy)) + + metanl = 
CTMeta(transformation, a0) + + @test_marginalrules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y_x) [( + input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metanl), + output = MvNormalWeightedMeanPrecision(zeros(4), [(dy + qW.df - 1)*diageye(dy) -(qW.df)diageye(dx); -(qW.df)diageye(dx) (dy + qW.df - 1)diageye(dy)]) + )] + end end end end From 51c2813db101ff9d4ce78ebeebf4a23803459f3d Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 20:17:10 +0100 Subject: [PATCH 21/38] Update promote type --- test/rules/continuous_transition/test_marginals.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl index 10009fc93..c8ee5a417 100644 --- a/test/rules/continuous_transition/test_marginals.jl +++ b/test/rules/continuous_transition/test_marginals.jl @@ -38,7 +38,7 @@ import ReactiveMP: @test_marginalrules metal = CTMeta(transformation, a0) - @test_marginalrules [check_type_promotion = false, atol = 1e-5] ContinuousTransition(:y_x) [( + @test_marginalrules [check_type_promotion = true, atol = 1e-3] ContinuousTransition(:y_x) [( input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(mx, my, qW, qA) )] end From 68719f59c287f8ef092ea8d6459ca22d1f3a4ad0 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 20:37:54 +0100 Subject: [PATCH 22/38] Update docs --- src/nodes/continuous_transition.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index d5134e684..4052f5ea9 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -30,17 +30,14 @@ const CTransition = ContinuousTransition @doc raw""" `CTMeta` is used as a metadata flag in `ContinuousTransition` to define the transformation function for constructing the matrix `A` from vector `a`. 
-There are two scenarios for specifying the transformation: -1. **Linear Transformation**: In this case, `CTMeta` requires a transformation function and the length of vector `a`. -2. **Nonlinear Transformation**: For nonlinear transformations, `CTMeta` expects a transformation function and a vector `â`, which acts as an expansion point for approximating the transformation linearly. +`CTMeta` requires a transformation function and the length of vector `a`, which acts as an expansion point for approximating the transformation linearly. If transformation appears to be linear, then no approximation is performed. Constructors: -- `CTMeta(transformation::Function, len::Integer)`: Used for linear transformations. -- `CTMeta(transformation::Function, â::Vector{<:Real})`: Used for nonlinear transformations. +- `CTMeta(transformation::Function, â::Vector{<:Real})`: Constructs a `CTMeta` struct with the transformation function and allocated basis vectors. Fields: - `ds`: A tuple indicating the dimensionality of the ContinuousTransition (dy, dx). -- `Fs`: Represents the masks, which can be either a Vector of AbstractMatrices or a Function, depending on the transformation type. +- `f`: Represents the transformation function that transforms vector `a` into matrix `A` - `es`: A Vector of unit vectors used in the transformation process. The `CTMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node. 
From 5f267dc890a5e630fcda7381b8f3144e8a3ee4c4 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 20:44:30 +0100 Subject: [PATCH 23/38] Update CTransition node --- src/nodes/continuous_transition.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 4052f5ea9..e73e5234a 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -1,4 +1,4 @@ -export transfominator, CTransition, ContinuousTransition, CTMeta +export CTransition, ContinuousTransition, CTMeta import LazyArrays import StatsFuns: log2π From c2648622b2728cff6311ddd7355d58b54873935c Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 8 Dec 2023 21:24:09 +0100 Subject: [PATCH 24/38] Remove matrix normal --- src/ReactiveMP.jl | 1 - src/nodes/matrix_normal.jl | 8 -------- src/rules/matrix_normal/out.jl | 3 --- src/rules/prototypes.jl | 1 - 4 files changed, 13 deletions(-) delete mode 100644 src/nodes/matrix_normal.jl delete mode 100644 src/rules/matrix_normal/out.jl diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index b0da1e521..d5c860c4c 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -98,7 +98,6 @@ include("nodes/uninformative.jl") include("nodes/uniform.jl") include("nodes/normal_mean_variance.jl") include("nodes/normal_mean_precision.jl") -include("nodes/matrix_normal.jl") include("nodes/mv_normal_mean_covariance.jl") include("nodes/mv_normal_mean_precision.jl") include("nodes/mv_normal_mean_scale_precision.jl") diff --git a/src/nodes/matrix_normal.jl b/src/nodes/matrix_normal.jl deleted file mode 100644 index 3ec66c2fe..000000000 --- a/src/nodes/matrix_normal.jl +++ /dev/null @@ -1,8 +0,0 @@ -@node MatrixNormal Stochastic [out, M, U, V] - -# we use equivalence of `` -@average_energy MatrixNormal (q_out::Any, q_M::Any, q_U::Any, q_V::Any) = begin - q_Σ = PointMass(kron(mean(q_U), mean(q_V))) - q_m = PointMass(vec(mean(q_M))) - 
-score(AverageEnergy(), MvNormalMeanCovariance, Val{(:out, :μ, :Σ)}(), map((q) -> Marginal(q, false, false, nothing), (q_out, q_m, q_Σ)), nothing) -end diff --git a/src/rules/matrix_normal/out.jl b/src/rules/matrix_normal/out.jl deleted file mode 100644 index 837e1d7be..000000000 --- a/src/rules/matrix_normal/out.jl +++ /dev/null @@ -1,3 +0,0 @@ -@rule MatrixNormal(:out, Marginalisation) (q_M::PointMass, q_U::PointMass, q_V::PointMass) = begin - MvNormalMeanCovariance(vec(mean(q_M)), kron(mean(q_U), mean(q_V))) -end diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 934149a25..af1aefeb7 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -62,7 +62,6 @@ include("normal_mean_variance/mean.jl") include("normal_mean_variance/var.jl") include("normal_mean_variance/marginals.jl") -include("matrix_normal/out.jl") include("mv_normal_mean_precision/out.jl") include("mv_normal_mean_precision/mean.jl") From 1972f23c91af91f37754dd783514fcf3962d6380 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sat, 9 Dec 2023 14:32:56 +0100 Subject: [PATCH 25/38] Make format --- src/rules/prototypes.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index af1aefeb7..c152401ec 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -62,7 +62,6 @@ include("normal_mean_variance/mean.jl") include("normal_mean_variance/var.jl") include("normal_mean_variance/marginals.jl") - include("mv_normal_mean_precision/out.jl") include("mv_normal_mean_precision/mean.jl") include("mv_normal_mean_precision/precision.jl") From 240c44929d283c3d6fbf3c02256edd0b3ddf8728 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sat, 9 Dec 2023 15:18:42 +0100 Subject: [PATCH 26/38] Update tests --- test/runtests.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index f8d5d25b3..633948f09 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -389,5 +389,11 @@ end 
addtests(testrunner, "rules/autoregressive/test_theta.jl") addtests(testrunner, "rules/autoregressive/test_marginals.jl") + addtests(testrunner, "rules/continuous_transition/test_a.jl") + addtests(testrunner, "rules/continuous_transition/test_W.jl") + addtests(testrunner, "rules/continuous_transition/test_x.jl") + addtests(testrunner, "rules/continuous_transition/test_y.jl") + addtests(testrunner, "rules/continuous_transition/test_marginals.jl") + run(testrunner) end From ad7f7fd57eab98590da83bbb411cd86f662608ba Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Sat, 9 Dec 2023 16:38:43 +0100 Subject: [PATCH 27/38] Speed up --- src/rules/continuous_transition/x.jl | 10 ++++++---- src/rules/continuous_transition/y.jl | 2 +- test/rules/continuous_transition/test_x.jl | 7 ++----- test/rules/continuous_transition/test_y.jl | 2 -- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 31f39ae1c..f40b0f222 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -1,6 +1,6 @@ @rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) - my, Vy = mean_cov(m_y) + my, Wy = mean_precision(m_y) mW = mean(q_W) @@ -10,9 +10,11 @@ mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - - z = mA' * inv(Vy + inv(mW)) * my - Ξ = mA' * inv(Vy + inv(mW)) * mA + W + # Woodbury identity + # inv(inv(Wy) + inv(mW)) = Wy - Wy * inv(Wy + mW) * Wy + WymW = Wy - Wy * inv(Wy + mW) * Wy + z = mA' * WymW * my + Ξ = mA' * WymW * mA + W return MvNormalWeightedMeanPrecision(z, Ξ) end diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index b812c50fd..a74c0bd6f 100644 --- 
a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -1,5 +1,5 @@ @rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin - ma, Va = mean_cov(q_a) + ma = mean(q_a) mx, Vx = mean_cov(m_x) mW = mean(q_W) diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index 4746a67d3..84da782ec 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -36,12 +36,9 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits qa = MvNormalMeanCovariance(a0, diageye(dydx)) qW = Wishart(dy + 1, diageye(dy)) - @test_rules [check_type_promotion = false] ContinuousTransition(:x, Marginalisation) [( + @test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [( input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA) - ) - # Additional test cases with different distributions and metadata settings - # Each case should represent a realistic scenario for your application -] + )] end end end diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index 06c71170a..f3a3b2abb 100644 --- a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -56,8 +56,6 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits )] end end - - # Additional tests for edge cases, errors, or specific behaviors of the rule can be added here end end From ddd30a870954d407e4ee805dd9264df771fff6f2 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 11 Dec 2023 12:20:24 +0100 Subject: [PATCH 28/38] Change inv to cholinv --- src/rules/continuous_transition/x.jl | 3 +-- src/rules/continuous_transition/y.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git 
a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index f40b0f222..ff4977fbb 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -1,7 +1,6 @@ @rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) my, Wy = mean_precision(m_y) - mW = mean(q_W) dy, dx = getdimensionality(meta) @@ -12,7 +11,7 @@ W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) # Woodbury identity # inv(inv(Wy) + inv(mW)) = Wy - Wy * inv(Wy + mW) * Wy - WymW = Wy - Wy * inv(Wy + mW) * Wy + WymW = Wy - Wy * cholinv(Wy + mW) * Wy z = mA' * WymW * my Ξ = mA' * WymW * mA + W diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index a74c0bd6f..c9c61d858 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -9,7 +9,7 @@ mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) - Vy = mA * Vx * mA' + inv(mW) + Vy = mA * Vx * mA' + cholinv(mW) my = mA * mx return MvNormalMeanCovariance(my, Vy) From 77cb69bda71913796b27bc290fa2cc3a9304cb0d Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 11 Dec 2023 16:32:21 +0100 Subject: [PATCH 29/38] Update src/nodes/continuous_transition.jl Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com> --- src/nodes/continuous_transition.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index e73e5234a..44bafd6d1 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -42,10 +42,10 @@ Fields: The `CTMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node. 
""" -struct CTMeta - ds::Tuple # dimensionality of ContinuousTransition (dy, dx) - f::Function # transformation function - es::Vector{<:AbstractVector} # unit vectors +struct CTMeta{T <: Tuple, F <: Function, V <: Vector{<:AbstractVector}} + ds::T # dimensionality of ContinuousTransition (dy, dx) + f::F# transformation function + es::V # unit vectors # NOTE: this meta is not user-friendly, I don't like a vector # perhaps making mutable struct with empty meta first will be better from user perspective From 8de19999ebae64d3c705b4056a1bc124a374d951 Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 11 Dec 2023 16:38:50 +0100 Subject: [PATCH 30/38] Update src/nodes/continuous_transition.jl Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com> --- src/nodes/continuous_transition.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 44bafd6d1..534d9cfdb 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -6,7 +6,7 @@ import StatsFuns: log2π @doc raw""" The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`. -To construct the matrix A, the elements of `a` are filled into A according to the transformation function provided with meta. `a` must be of MultivariateNormalDistributionsFamily type. If you intend to use univariate Gaussian, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`. +To construct the matrix `A`, the elements of `a` are reshaped into `A` according to the transformation function provided in the meta. `a` must be of `MultivariateNormalDistributionsFamily` type. If you intend to use univariate Gaussian distributions, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`. 
Check CTMeta for more details on how to specify the transformation function that **must** return a matrix. From ba4b6f61dc79adee703de4f6e8ec54925aacb04e Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 11 Dec 2023 20:27:51 +0100 Subject: [PATCH 31/38] Refactor CTransition --- src/nodes/continuous_transition.jl | 73 +++++++++---------- src/rules/continuous_transition/W.jl | 18 +++-- src/rules/continuous_transition/a.jl | 9 +-- src/rules/continuous_transition/marginals.jl | 19 +++-- src/rules/continuous_transition/x.jl | 9 ++- src/rules/continuous_transition/y.jl | 6 +- test/nodes/test_continuous_transition.jl | 22 ++---- test/rules/continuous_transition/test_W.jl | 8 +- test/rules/continuous_transition/test_a.jl | 4 +- .../continuous_transition/test_marginals.jl | 9 +-- test/rules/continuous_transition/test_x.jl | 11 ++- test/rules/continuous_transition/test_y.jl | 12 +-- 12 files changed, 90 insertions(+), 110 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 534d9cfdb..c50e6592a 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -1,4 +1,4 @@ -export CTransition, ContinuousTransition, CTMeta +export CTransition, ContinuousTransition, CTMeta, ContinuousTransitionMeta import LazyArrays import StatsFuns: log2π @@ -6,12 +6,12 @@ import StatsFuns: log2π @doc raw""" The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`. -To construct the matrix `A`, the elements of `a` are reshaped into `A` according to the transformation function provided in the meta. `a` must be of `MultivariateNormalDistributionsFamily` type. If you intend to use univariate Gaussian distributions, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`. 
+To construct the matrix `A`, the elements of `a` are reshaped into `A` according to the transformation function provided in the meta. If you intend to use univariate Gaussian distributions, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`. -Check CTMeta for more details on how to specify the transformation function that **must** return a matrix. +Check ContinuousTransitionMeta for more details on how to specify the transformation function that **must** return a matrix. ```julia -y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation, â)} +y ~ ContinuousTransition(x, a, W) where {meta = ContinuousTransitionMeta(transformation)} ``` Interfaces: 1. y - n-dimensional output of the ContinuousTransition node. @@ -28,38 +28,30 @@ const CTransition = ContinuousTransition @node ContinuousTransition Stochastic [y, x, a, W] @doc raw""" -`CTMeta` is used as a metadata flag in `ContinuousTransition` to define the transformation function for constructing the matrix `A` from vector `a`. +`ContinuousTransitionMeta` is used as a metadata flag in `ContinuousTransition` to define the transformation function for constructing the matrix `A` from vector `a`. -`CTMeta` requires a transformation function and the length of vector `a`, which acts as an expansion point for approximating the transformation linearly. If transformation appears to be linear, then no approximation is performed. +`ContinuousTransitionMeta` requires a transformation function and the length of vector `a`, which acts as an expansion point for approximating the transformation linearly. If transformation appears to be linear, then no approximation is performed. Constructors: -- `CTMeta(transformation::Function, â::Vector{<:Real})`: Constructs a `CTMeta` struct with the transformation function and allocated basis vectors. 
+- `ContinuousTransitionMeta(transformation::Function, â::Vector{<:Real})`: Constructs a `ContinuousTransitionMeta` struct with the transformation function and allocated basis vectors. Fields: -- `ds`: A tuple indicating the dimensionality of the ContinuousTransition (dy, dx). - `f`: Represents the transformation function that transforms vector `a` into matrix `A` -- `es`: A Vector of unit vectors used in the transformation process. -The `CTMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node. +The `ContinuousTransitionMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node. """ -struct CTMeta{T <: Tuple, F <: Function, V <: Vector{<:AbstractVector}} - ds::T # dimensionality of ContinuousTransition (dy, dx) - f::F# transformation function - es::V # unit vectors - - # NOTE: this meta is not user-friendly, I don't like a vector - # perhaps making mutable struct with empty meta first will be better from user perspective - # meta for transformation of a vector to a matrix - function CTMeta(transformation::Function, â::Vector{<:Real}) - dy, dx = size(transformation(â)) - es = [StandardBasisVector(dy, i, one(eltype(first(â)))) for i in 1:dy] - return new((dy, dx), transformation, es) +struct ContinuousTransitionMeta{F <: Function} + f::F # transformation function + + function ContinuousTransitionMeta(transformation::F) where {F} + return new{F}(transformation) end end -getunits(meta::CTMeta) = meta.es -getdimensionality(meta::CTMeta) = meta.ds +const CTMeta = ContinuousTransitionMeta + gettransformation(meta::CTMeta) = meta.f +# getctoutputdim(meta::CTMeta, J) = div(size(J, 2), size(J, 1)) # returns dy where J ∈ ℝ^{dx × dydx} getjacobians(ctmeta::CTMeta, a) = process_Fs(gettransformation(ctmeta), a) process_Fs(f::Function, a) = [ForwardDiff.jacobian(a 
-> f(a)[i, :], a) for i in 1:size(f(a), 1)] @@ -72,12 +64,12 @@ default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = Requi `ctcompanion_matrix` casts a vector `a` into a matrix `A` by means of linearization of the transformation function `f` around the expansion point `a0`. """ function ctcompanion_matrix(a, epsilon, meta::CTMeta) - a0 = a .+ epsilon # expansion point - Js, es = getjacobians(meta, a0), getunits(meta) - f = gettransformation(meta) - dy, _ = getdimensionality(meta) + a0 = a + epsilon # expansion point + Js = getjacobians(meta, a0) + f = gettransformation(meta) + dy = length(Js) # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows - A = sum(es[i] * (f(a0)[i, :] + Js[i] * (a - a0))' for i in 1:dy) + A = sum(StandardBasisVector(dy, i) * (f(a0)[i, :] + Js[i] * (a - a0))' for i in 1:dy) return A end @@ -86,20 +78,23 @@ end myx, Vyx = mean_cov(q_y_x) mW = mean(q_W) - dy, dx = getdimensionality(meta) - Fs, es = getjacobians(meta, ma), getunits(meta) - n = div(ndims(q_y_x), 2) - mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + Fs = getjacobians(meta, ma) # dx × dydx + dy = length(Fs) + + n = div(ndims(q_y_x), 2) + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - g₁ = my' * mW * my + tr(Vy * mW) - g₂ = mx' * mA' * mW * my + tr(Vyx * mA' * mW) - g₃ = g₂ - G = sum(sum(es[i]' * mW * es[j] * Fs[i] * (ma * ma' + Va) * Fs[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) - g₄ = mx' * G * mx + tr(Vx * G) - AE = n / 2 * log2π - (mean(logdet, q_W) - (g₁ - g₂ - g₃ + g₄)) / 2 + # we proved (when Va = kron(U, S)): + # tr(W(Sx'Ux)) = tr(kron(xx', W)kron(U, S)) = tr(kron(xx', W)Va) + # sum(es[i]' * mW * es[j] * Fs[i] * Va * Fs[j]') = tr(WS)U + g1 = -mA * Vyx' + g2 = g1' + trWSU = sum(sum(StandardBasisVector(dy, i)' * mW * StandardBasisVector(dy, 
j) * Fs[i] * Va * Fs[j]' for i in 1:dy) for j in 1:dy) + + AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + tr(trWSU) + tr(kron(mx * mx', mW) * Va)) / 2 return AE end diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl index 785ae41ee..f28f6779c 100644 --- a/src/rules/continuous_transition/W.jl +++ b/src/rules/continuous_transition/W.jl @@ -1,27 +1,29 @@ -function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs, es) +function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) + dy = length(my) G₁ = (my * my' + Vy) G₂ = ((my * mx' + Vyx) * mA') G₃ = transpose(G₂) Ex_xx = mx * mx' + Vx - G₅ = sum(sum(es[i] * ma' * Fs[i]'Ex_xx * Fs[j] * ma * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) - G₆ = sum(sum(es[i] * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * es[j]' for i in 1:length(Fs)) for j in 1:length(Fs)) + G₅ = sum(sum(StandardBasisVector(dy, i) * ma' * Fs[i]'Ex_xx * Fs[j] * ma * StandardBasisVector(dy, j)' for i in 1:dy) for j in 1:dy) + G₆ = sum(sum(StandardBasisVector(dy, i) * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * StandardBasisVector(dy, j)' for i in 1:dy) for j in 1:dy) return G₁ - G₂ - G₃ + G₅ + G₆ end @rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin - dy, dx = getdimensionality(meta) - ma, Va = mean_cov(q_a) - Fs, es = getjacobians(meta, ma), getunits(meta) + Fs = getjacobians(meta, ma) + dy = length(Fs) + + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) - mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) myx, Vyx = mean_cov(q_y_x) mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @views Vyx[1:dy, (dy + 1):end] - Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs, es) + Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) return WishartFast(dy + 2, Δ) end 
diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index 5992bf0e7..e95059f25 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -3,17 +3,16 @@ mW = mean(q_W) myx, Vyx = mean_cov(q_y_x) - dy, dx = getdimensionality(meta) + Fs = getjacobians(meta, ma) + dy = length(Fs) mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - Fs, es = getjacobians(meta, ma), getunits(meta) - # rank1update(Vyx, mx, my) equivalent to ξ = (Vyx + mx * my') - D = sum(sum(es[j]' * mW * es[i] * Fs[i]' * rank1update(Vx, mx) * Fs[j] for i in 1:length(Fs)) for j in 1:length(Fs)) - z = sum(Fs[i]' * rank1update(Vyx', mx, my) * mW * es[i] for i in 1:length(Fs)) + D = sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[i]' * rank1update(Vx, mx) * Fs[j] for i in 1:dy) for j in 1:dy) + z = sum(Fs[i]' * rank1update(Vyx', mx, my) * mW * StandardBasisVector(dy, i) for i in 1:dy) return MvNormalWeightedMeanPrecision(z, D) end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index 94ddea21e..5699c2536 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -6,21 +6,20 @@ end function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta) ma, Va = mean_cov(q_a) - Fs, es = getjacobians(meta, ma), getunits(meta) + Fs = getjacobians(meta, ma) + dy = length(Fs) mW = mean(q_W) - mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) - b_my, b_Vy = mean_cov(m_y) - f_mx, f_Vx = mean_cov(m_x) + xiy, Wy = weightedmean_precision(m_y) + xix, Wx = weightedmean_precision(m_x) - inv_b_Vy = cholinv(b_Vy) - inv_f_Vx = cholinv(f_Vx) + Ξ = Wx + 
sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[j] * Va * Fs[i]' for i in 1:dy) for j in 1:dy) - Ξ = inv_f_Vx + sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) - - W_11 = inv_b_Vy + mW + W_11 = Wy + mW # negate_inplace!(mW * mH) W_12 = -(mW * mA) @@ -30,7 +29,7 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_22 = Ξ + mA' * mW * mA W = [W_11 W_12; W_21 W_22] - ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx] + ξ = [xiy; xix] return MvNormalWeightedMeanPrecision(ξ, W) end diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index ff4977fbb..79e582e57 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -3,12 +3,13 @@ my, Wy = mean_precision(m_y) mW = mean(q_W) - dy, dx = getdimensionality(meta) - Fs, es = getjacobians(meta, ma), getunits(meta) + Fs = getjacobians(meta, ma) + dy = length(Fs) - mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) - W = sum(sum(es[j]' * mW * es[i] * Fs[j] * Va * Fs[i]' for i in 1:length(Fs)) for j in 1:length(Fs)) + W = sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[j] * Va * Fs[i]' for i in 1:dy) for j in 1:dy) # Woodbury identity # inv(inv(Wy) + inv(mW)) = Wy - Wy * inv(Wy + mW) * Wy WymW = Wy - Wy * cholinv(Wy + mW) * Wy diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index c9c61d858..864ce4237 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -4,10 +4,8 @@ mW = mean(q_W) - dy, dx = getdimensionality(meta) - Fs, es = getjacobians(meta, ma), getunits(meta) - - mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) Vy = mA * Vx * mA' + cholinv(mW) my = mA * mx diff --git 
a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl index 3dfceb31c..e2ced23f1 100644 --- a/test/nodes/test_continuous_transition.jl +++ b/test/nodes/test_continuous_transition.jl @@ -6,8 +6,7 @@ import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, @testset "ContinuousTransitionNode" begin dy, dx = 2, 3 - a0 = rand(dx * dy) # Example vector `a0` - meta = CTMeta(a -> reshape(a, dy, dx), a0) + meta = CTMeta(a -> reshape(a, dy, dx)) @testset "Creation" begin node = make_node(ContinuousTransition, FactorNodeCreationOptions(nothing, meta, nothing)) @@ -15,32 +14,25 @@ import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, @test sdtype(node) === Stochastic() @test name.(interfaces(node)) === (:y, :x, :a, :W) @test factorisation(node) === ((1, 2, 3, 4),) - @test getdimensionality(metadata(node)) == (dy, dx) # Based on the transformation function dimensions end @testset "AverageEnergy" begin - # This is an example setup, you'll need to adjust the distributions and marginals according to your needs q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)) - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) # Adjust the dimension according to `a` + q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) q_W = Wishart(3, diageye(2)) marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 - @show getjacobians(meta, a0) + @show getjacobians(meta, mean(q_a)) end - @testset "CTransition Functionality" begin - A = ctcompanion_matrix(a0, zeros(length(a0)), meta) + @testset "ContinuousTransition Functionality" begin + m_a = randn(6) + A = ctcompanion_matrix(m_a, zeros(length(m_a)), meta) @test size(A) == (dy, dx) - @test A == gettransformation(meta)(a0) - end - - @testset "Metadata Functionality" begin - @test 
getdimensionality(meta) == (dy, dx) - @test length(getjacobians(meta, a0)) == dy # Based on `dy` - @test length(getunits(meta)) == dy # Based on `dy` + @test A == gettransformation(meta)(m_a) end end diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index 71d0ff77c..d40222ee1 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -31,9 +31,7 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish transformation = (a) -> reshape(a, dy, dx) mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) - a0 = Float32.(vec(mA)) - - metal = CTMeta(transformation, a0) + metal = CTMeta(transformation) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' μy, Σy = rand(rng, dy), Ly * Ly' @@ -53,8 +51,8 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(Int, 1) - metanl = CTMeta(transformation, a0) + + metanl = CTMeta(transformation) μx, Σx = zeros(dx), diageye(dx) μy, Σy = zeros(dy), diageye(dy) diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl index acd9910ce..533a26397 100644 --- a/test/rules/continuous_transition/test_a.jl +++ b/test/rules/continuous_transition/test_a.jl @@ -28,7 +28,7 @@ import ReactiveMP: @test_rules, getjacobians, getunits dydx = dy * dx transformation = (a) -> reshape(a, dy, dx) a0 = rand(Float32, dydx) - metal = CTMeta(transformation, a0) + metal = CTMeta(transformation) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' μy, Σy = rand(rng, dy), Ly * Ly' @@ -49,7 +49,7 @@ import ReactiveMP: @test_rules, getjacobians, getunits dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] a0 = zeros(Int, 1) - metanl = CTMeta(transformation, a0) + metanl = 
CTMeta(transformation) μx, Σx = ones(dx), diageye(dx) μy, Σy = ones(dy), diageye(dy) diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl index c8ee5a417..d8f9550ab 100644 --- a/test/rules/continuous_transition/test_marginals.jl +++ b/test/rules/continuous_transition/test_marginals.jl @@ -25,7 +25,6 @@ import ReactiveMP: @test_marginalrules mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) qA = MatrixNormal(mA, ΣA, UA) - a0 = Float32.(vec(mA)) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μx, Σx = rand(rng, dx), Lx * Lx' @@ -36,7 +35,7 @@ import ReactiveMP: @test_marginalrules qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) qW = Wishart(dy + 1, diageye(dy)) - metal = CTMeta(transformation, a0) + metal = CTMeta(transformation) @test_marginalrules [check_type_promotion = true, atol = 1e-3] ContinuousTransition(:y_x) [( input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(mx, my, qW, qA) @@ -50,17 +49,15 @@ import ReactiveMP: @test_marginalrules dy, dx = 2, 2 transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(Int, 1) - μx, Σx = zeros(dx), diageye(dx) μy, Σy = zeros(dy), diageye(dy) my = MvNormalMeanCovariance(μy, Σy) mx = MvNormalMeanCovariance(μx, Σx) - qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1)) qW = Wishart(dy, diageye(dy)) - metanl = CTMeta(transformation, a0) + metanl = CTMeta(transformation) @test_marginalrules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y_x) [( input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metanl), diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl index 84da782ec..e22d8c7d4 100644 --- a/test/rules/continuous_transition/test_x.jl +++ b/test/rules/continuous_transition/test_x.jl @@ -27,13 +27,12 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, 
getunits mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) - a0 = Float32.(vec(mA)) - metal = CTMeta(transformation, a0) + metal = CTMeta(transformation) Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) μy, Σy = rand(rng, dy), Ly * Ly' qy = MvNormalMeanCovariance(μy, Σy) - qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qa = MvNormalMeanCovariance(vec(mA), diageye(dydx)) qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [( @@ -48,12 +47,12 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(Float32, 1) - metanl = CTMeta(transformation, a0) + + metanl = CTMeta(transformation) μy, Σy = zeros(dy), diageye(dy) qy = MvNormalMeanCovariance(μy, Σy) - qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1)) qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [( diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl index f3a3b2abb..9cb1b2a5a 100644 --- a/test/rules/continuous_transition/test_y.jl +++ b/test/rules/continuous_transition/test_y.jl @@ -22,13 +22,13 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits transformation = (a) -> reshape(a, dy, dx) mA = rand(rng, dy, dx) - a0 = Float32.(vec(mA)) - metal = CTMeta(transformation, a0) + + metal = CTMeta(transformation) Lx = rand(rng, dx, dx) μx, Σx = rand(rng, dx), Lx * Lx' qx = MvNormalMeanCovariance(μx, Σx) - qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qa = MvNormalMeanCovariance(vec(mA), diageye(dydx)) qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y, Marginalisation) [( @@ -43,12 +43,12 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, 
getjacobians, getunits dy, dx = 2, 2 dydx = dy * dy transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] - a0 = zeros(Float32, 1) - metanl = CTMeta(transformation, a0) + + metanl = CTMeta(transformation) μx, Σx = zeros(dx), diageye(dx) qx = MvNormalMeanCovariance(μx, Σx) - qa = MvNormalMeanCovariance(a0, tiny * diageye(1)) + qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1)) qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [( From ba47042f3316cc2cac84b05ffc6bb2acd3577392 Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 11 Dec 2023 21:40:21 +0100 Subject: [PATCH 32/38] Improve speed --- src/nodes/continuous_transition.jl | 8 ++++---- src/rules/continuous_transition/a.jl | 2 +- src/rules/continuous_transition/marginals.jl | 6 +++--- test/nodes/test_continuous_transition.jl | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index c50e6592a..4348878b5 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -69,7 +69,7 @@ function ctcompanion_matrix(a, epsilon, meta::CTMeta) f = gettransformation(meta) dy = length(Js) # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows - A = sum(StandardBasisVector(dy, i) * (f(a0)[i, :] + Js[i] * (a - a0))' for i in 1:dy) + A = f(a0) + mapreduce(i -> StandardBasisVector(dy, i) * (Js[i] * (a - a0))', +, 1:dy) return A end @@ -88,13 +88,13 @@ end my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] # we proved (when Va = kron(U, S)): - # tr(W(Sx'Ux)) = tr(kron(xx', W)kron(U, S)) = tr(kron(xx', W)Va) + # sum(es[i]' * mW * es[j] * mx * Fs[i] * Va * Fs[j]' * mx') = tr(kron(xx', W)kron(U, S)) # sum(es[i]' * mW * es[j] * Fs[i] * Va * Fs[j]') = tr(WS)U g1 = -mA * Vyx' g2 = g1' trWSU = sum(sum(StandardBasisVector(dy, i)' * mW * StandardBasisVector(dy, 
j) * Fs[i] * Va * Fs[j]' for i in 1:dy) for j in 1:dy) - - AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + tr(trWSU) + tr(kron(mx * mx', mW) * Va)) / 2 + kronxxWSU = sum(sum(StandardBasisVector(dy, i)' * mW * StandardBasisVector(dy, j) * mx' * Fs[i] * Va * Fs[j]' * mx for i in 1:dy) for j in 1:dy) + AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + tr(trWSU) + tr(kronxxWSU)) / 2 return AE end diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index e95059f25..74f6b6a6d 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -12,7 +12,7 @@ # rank1update(Vyx, mx, my) equivalent to ξ = (Vyx + mx * my') D = sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[i]' * rank1update(Vx, mx) * Fs[j] for i in 1:dy) for j in 1:dy) - z = sum(Fs[i]' * rank1update(Vyx', mx, my) * mW * StandardBasisVector(dy, i) for i in 1:dy) + z = mapreduce(i -> Fs[i]' * rank1update(Vyx', mx, my) * mW * StandardBasisVector(dy, i), +, 1:dy) return MvNormalWeightedMeanPrecision(z, D) end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index 5699c2536..1db743e60 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -21,10 +21,10 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_11 = Wy + mW - # negate_inplace!(mW * mH) - W_12 = -(mW * mA) + # + W_12 = negate_inplace!(mW * mA) - W_21 = -(mA' * mW) + W_21 = negate_inplace!(mA' * mW) W_22 = Ξ + mA' * mW * mA diff --git a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl index e2ced23f1..34950603f 100644 --- a/test/nodes/test_continuous_transition.jl +++ b/test/nodes/test_continuous_transition.jl @@ -23,7 +23,7 @@ import ReactiveMP: 
getdimensionality, getjacobians, gettransformation, getunits, marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) - @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415092731310878 + @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.0 atol = 1e-2 @show getjacobians(meta, mean(q_a)) end From 9d38d6190643f6b08dd16a389b39c9d395c91e0e Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 11 Dec 2023 22:15:50 +0100 Subject: [PATCH 33/38] Fix test --- test/rules/continuous_transition/test_W.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl index d40222ee1..dd55b3807 100644 --- a/test/rules/continuous_transition/test_W.jl +++ b/test/rules/continuous_transition/test_W.jl @@ -57,7 +57,7 @@ import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, Wish μy, Σy = zeros(dy), diageye(dy) qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) - qa = MvNormalMeanCovariance(a0, diageye(1)) + qa = MvNormalMeanCovariance(zeros(1), diageye(1)) @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [( input = (q_y_x = qyx, q_a = qa, meta = metanl), output = WishartFast(dy + 2, dy * diageye(dy)) )] From 6a1af1523e310da4afb415c188fc458d4d46027a Mon Sep 17 00:00:00 2001 From: Albert Date: Tue, 12 Dec 2023 12:58:10 +0100 Subject: [PATCH 34/38] Update src/nodes/continuous_transition.jl Co-authored-by: Bart van Erp <44952318+bartvanerp@users.noreply.github.com> --- src/nodes/continuous_transition.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 4348878b5..e09ad32ea 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ 
-69,7 +69,10 @@ function ctcompanion_matrix(a, epsilon, meta::CTMeta) f = gettransformation(meta) dy = length(Js) # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows - A = f(a0) + mapreduce(i -> StandardBasisVector(dy, i) * (Js[i] * (a - a0))', +, 1:dy) + A = f(a0) + for i in 1:dy + A[i,:] .+= Js[i] * (a-a0) + end return A end From 741430afe32076898cdcf811404d4d22cd3ed422 Mon Sep 17 00:00:00 2001 From: Albert Date: Tue, 12 Dec 2023 14:34:31 +0100 Subject: [PATCH 35/38] Update rule for CTransition a --- src/rules/continuous_transition/a.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index 74f6b6a6d..52fed127e 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -10,9 +10,16 @@ my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - # rank1update(Vyx, mx, my) equivalent to ξ = (Vyx + mx * my') - D = sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[i]' * rank1update(Vx, mx) * Fs[j] for i in 1:dy) for j in 1:dy) - z = mapreduce(i -> Fs[i]' * rank1update(Vyx', mx, my) * mW * StandardBasisVector(dy, i), +, 1:dy) + xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma)) - return MvNormalWeightedMeanPrecision(z, D) + Vxymxy = rank1update(Vyx', mx, my) + Vxmx = rank1update(Vx, mx) + for i in 1:dy + xi += Fs[i]' * Vxymxy * mW[:,i] + for j in 1:dy + W += mW[j,i] * Fs[i]'*Vxmx*Fs[j] + end + end + + return MvNormalWeightedMeanPrecision(xi, W) end From fdfc607d19015da49c687aaf6c0f384aa0ba43e2 Mon Sep 17 00:00:00 2001 From: Albert Date: Tue, 12 Dec 2023 15:16:25 +0100 Subject: [PATCH 36/38] Optimize rules --- src/nodes/continuous_transition.jl | 15 +++++++++------ src/rules/continuous_transition/W.jl | 12 +++++++++--- src/rules/continuous_transition/marginals.jl | 7 +++++-- 
src/rules/continuous_transition/x.jl | 8 ++++++-- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index e09ad32ea..379ec0637 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -90,14 +90,17 @@ end mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - # we proved (when Va = kron(U, S)): - # sum(es[i]' * mW * es[j] * mx * Fs[i] * Va * Fs[j]' * mx') = tr(kron(xx', W)kron(U, S)) - # sum(es[i]' * mW * es[j] * Fs[i] * Va * Fs[j]') = tr(WS)U + g1 = -mA * Vyx' g2 = g1' - trWSU = sum(sum(StandardBasisVector(dy, i)' * mW * StandardBasisVector(dy, j) * Fs[i] * Va * Fs[j]' for i in 1:dy) for j in 1:dy) - kronxxWSU = sum(sum(StandardBasisVector(dy, i)' * mW * StandardBasisVector(dy, j) * mx' * Fs[i] * Va * Fs[j]' * mx for i in 1:dy) for j in 1:dy) - AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + tr(trWSU) + tr(kronxxWSU)) / 2 + trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) + xxt = mx * mx' + for (i, j) in Iterators.product(1:dy, 1:dy) + FjVaFi = Fs[j] * Va * Fs[i]' + trWSU += mW[j, i] * tr(FjVaFi) + trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi) + end + AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 return AE end diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl index f28f6779c..340ae5577 100644 --- a/src/rules/continuous_transition/W.jl +++ b/src/rules/continuous_transition/W.jl @@ -3,9 +3,15 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) G₁ = (my * my' + Vy) G₂ = ((my * mx' + Vyx) * mA') G₃ = transpose(G₂) - Ex_xx = mx * mx' + Vx - G₅ = sum(sum(StandardBasisVector(dy, i) * ma' * Fs[i]'Ex_xx * Fs[j] * ma * StandardBasisVector(dy, 
j)' for i in 1:dy) for j in 1:dy) - G₆ = sum(sum(StandardBasisVector(dy, i) * tr(Fs[i]' * Ex_xx * Fs[j] * Va) * StandardBasisVector(dy, j)' for i in 1:dy) for j in 1:dy) + Ex_xx = rank1update(Vx, mx) + G₅ = zeros(eltype(ma), dy, dy) + G₆ = zeros(eltype(ma), dy, dy) + mamat = ma * ma' + for (i, j) in Iterators.product(1:dy, 1:dy) + tmp = Fs[i]' * Ex_xx * Fs[j] + G₅[i, j] = tr(tmp * mamat) + G₆[i, j] = tr(tmp * Va) + end return G₁ - G₂ - G₃ + G₅ + G₆ end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index 1db743e60..81027376e 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -17,8 +17,6 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil xiy, Wy = weightedmean_precision(m_y) xix, Wx = weightedmean_precision(m_x) - Ξ = Wx + sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[j] * Va * Fs[i]' for i in 1:dy) for j in 1:dy) - W_11 = Wy + mW # @@ -26,6 +24,11 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_21 = negate_inplace!(mA' * mW) + Ξ = Wx + for (i, j) in Iterators.product(1:dy, 1:dy) + Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + end + W_22 = Ξ + mA' * mW * mA W = [W_11 W_12; W_21 W_22] diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 79e582e57..56169db82 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -9,12 +9,16 @@ epsilon = sqrt.(var(q_a)) mA = ctcompanion_matrix(ma, epsilon, meta) - W = sum(sum(StandardBasisVector(dy, j)' * mW * StandardBasisVector(dy, i) * Fs[j] * Va * Fs[i]' for i in 1:dy) for j in 1:dy) # Woodbury identity # inv(inv(Wy) + inv(mW)) = Wy - Wy * inv(Wy + mW) * Wy WymW = Wy - Wy * cholinv(Wy + mW) * Wy + Ξ = mA' * WymW * mA + + for (i, j) in Iterators.product(1:dy, 1:dy) + Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + end + z = mA' * WymW * my - Ξ = mA' * 
WymW * mA + W return MvNormalWeightedMeanPrecision(z, Ξ) end From 88fbf430dd697de8b9b418cbf32bfc04819142b4 Mon Sep 17 00:00:00 2001 From: Albert Date: Tue, 12 Dec 2023 15:23:20 +0100 Subject: [PATCH 37/38] Make format --- src/nodes/continuous_transition.jl | 4 ++-- src/rules/continuous_transition/a.jl | 8 ++++---- src/rules/continuous_transition/marginals.jl | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 379ec0637..325b77cb0 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -71,7 +71,7 @@ function ctcompanion_matrix(a, epsilon, meta::CTMeta) # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows A = f(a0) for i in 1:dy - A[i,:] .+= Js[i] * (a-a0) + A[i, :] .+= Js[i] * (a - a0) end return A end @@ -90,7 +90,7 @@ end mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end] my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - + g1 = -mA * Vyx' g2 = g1' trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index 52fed127e..ccc768598 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -10,16 +10,16 @@ my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy] Vyx = @view Vyx[1:dy, (dy + 1):end] - xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma)) + xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma)) Vxymxy = rank1update(Vyx', mx, my) Vxmx = rank1update(Vx, mx) for i in 1:dy - xi += Fs[i]' * Vxymxy * mW[:,i] + xi += Fs[i]' * Vxymxy * mW[:, i] for j in 1:dy - W += mW[j,i] * Fs[i]'*Vxmx*Fs[j] + W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] end end - + return MvNormalWeightedMeanPrecision(xi, W) end diff --git a/src/rules/continuous_transition/marginals.jl 
b/src/rules/continuous_transition/marginals.jl index 81027376e..abd21539b 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -24,7 +24,7 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_21 = negate_inplace!(mA' * mW) - Ξ = Wx + Ξ = Wx for (i, j) in Iterators.product(1:dy, 1:dy) Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' end From d5e101c9c9665f9e7366e165109be51a80e6d991 Mon Sep 17 00:00:00 2001 From: Albert Date: Tue, 12 Dec 2023 15:27:04 +0100 Subject: [PATCH 38/38] Update docs --- src/nodes/continuous_transition.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl index 325b77cb0..8a707d969 100644 --- a/src/nodes/continuous_transition.jl +++ b/src/nodes/continuous_transition.jl @@ -5,7 +5,24 @@ import StatsFuns: log2π @doc raw""" The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`. +ContinuousTransition node is primarily used in two regimes: +# When no structure on A is specified: +```julia +transformation = a -> reshape(a, 2, 2) +... +a ~ MvNormalMeanCovariance(zeros(4), Diagonal(ones(4))) +y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)} +... +``` +# When some structure of A is known: +```julia +transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])] +... +a ~ MvNormalMeanCovariance(zeros(1), Diagonal(ones(1))) +y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)} +... +``` To construct the matrix `A`, the elements of `a` are reshaped into `A` according to the transformation function provided in the meta. If you intend to use univariate Gaussian distributions, use it as a vector of length `1``, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`.
Check ContinuousTransitionMeta for more details on how to specify the transformation function that **must** return a matrix.