diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl
index 169d2fa0c..d5c860c4c 100644
--- a/src/ReactiveMP.jl
+++ b/src/ReactiveMP.jl
@@ -123,6 +123,7 @@ include("nodes/bifm.jl")
 include("nodes/bifm_helper.jl")
 include("nodes/probit.jl")
 include("nodes/poisson.jl")
+include("nodes/continuous_transition.jl")
 include("nodes/half_normal.jl")
 
 include("nodes/flow/flow.jl")
diff --git a/src/nodes/continuous_transition.jl b/src/nodes/continuous_transition.jl
new file mode 100644
index 000000000..8a707d969
--- /dev/null
+++ b/src/nodes/continuous_transition.jl
@@ -0,0 +1,123 @@
+export CTransition, ContinuousTransition, CTMeta, ContinuousTransitionMeta
+
+import LazyArrays
+import StatsFuns: log2π
+
+@doc raw"""
+The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`.
+ContinuousTransition node is primarily used in two regimes:
+
+# When no structure on A is specified:
+```julia
+transformation = a -> reshape(a, 2, 2)
+...
+a ~ MvNormalMeanCovariance(zeros(2), Diagonal(ones(2)))
+y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)}
+...
+```
+# When some structure of A is known:
+```julia
+transformation = a -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+...
+a ~ MvNormalMeanCovariance(zeros(1), Diagonal(ones(1)))
+y ~ ContinuousTransition(x, a, W) where {meta = CTMeta(transformation)}
+...
+```
+To construct the matrix `A`, the elements of `a` are reshaped into `A` according to the transformation function provided in the meta. If you intend to use univariate Gaussian distributions, use it as a vector of length `1`, e.g. `a ~ MvNormalMeanCovariance([0.0], [1.;])`.
+
+Check ContinuousTransitionMeta for more details on how to specify the transformation function that **must** return a matrix.
+
+```julia
+y ~ ContinuousTransition(x, a, W) where {meta = ContinuousTransitionMeta(transformation)}
+```
+Interfaces:
+1. y - n-dimensional output of the ContinuousTransition node.
+2. x - m-dimensional input of the ContinuousTransition node.
+3. a - any-dimensional vector that casts into the matrix `A`.
+4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing.
+
+Note that you can set W to a fixed value or put a prior on it to control the amount of jitter.
+"""
+struct ContinuousTransition end
+
+const CTransition = ContinuousTransition
+
+@node ContinuousTransition Stochastic [y, x, a, W]
+
+@doc raw"""
+`ContinuousTransitionMeta` is used as a metadata flag in `ContinuousTransition` to define the transformation function for constructing the matrix `A` from vector `a`.
+
+`ContinuousTransitionMeta` requires a transformation function that maps the vector `a` to the matrix `A`. The transformation is linearised (first-order expansion) around an expansion point derived from the posterior over `a`; if the transformation is linear, the expansion is exact.
+
+Constructors:
+- `ContinuousTransitionMeta(transformation::Function)`: Constructs a `ContinuousTransitionMeta` struct with the transformation function.
+
+Fields:
+- `f`: Represents the transformation function that transforms vector `a` into matrix `A`
+
+The `ContinuousTransitionMeta` struct plays a pivotal role in defining how the vector `a` is transformed into the matrix `A`, thus influencing the behavior of the `ContinuousTransition` node.
+"""
+struct ContinuousTransitionMeta{F <: Function}
+    f::F # transformation function
+
+    function ContinuousTransitionMeta(transformation::F) where {F}
+        return new{F}(transformation)
+    end
+end
+
+const CTMeta = ContinuousTransitionMeta
+
+gettransformation(meta::CTMeta) = meta.f
+# getctoutputdim(meta::CTMeta, J) = div(size(J, 2), size(J, 1)) # returns dy where J ∈ ℝ^{dx × dydx}
+
+getjacobians(ctmeta::CTMeta, a) = process_Fs(gettransformation(ctmeta), a)
+process_Fs(f::Function, a) = [ForwardDiff.jacobian(a -> f(a)[i, :], a) for i in 1:size(f(a), 1)]
+
+default_meta(::Type{CTransition}) = error("ContinuousTransition node requires meta flag explicitly specified")
+
+default_functional_dependencies_pipeline(::Type{<:ContinuousTransition}) = RequireMarginalFunctionalDependencies((3,), (nothing,))
+
+"""
+    `ctcompanion_matrix` casts a vector `a` into a matrix `A` by means of linearization of the transformation function `f` around the expansion point `a0 = a + epsilon`.
+"""
+function ctcompanion_matrix(a, epsilon, meta::CTMeta)
+    a0 = a + epsilon # expansion point
+    Js = getjacobians(meta, a0)
+    f = gettransformation(meta)
+    da = a - a0 # computed once, before `A` is mutated: `f(a0)` may share memory with `a0` (e.g. `reshape`)
+    # we approximate each row of A by a linear function and create a matrix A composed of the approximated rows
+    A = f(a0)
+    for i in eachindex(Js)
+        A[i, :] .+= Js[i] * da
+    end
+    return A
+end
+
+@average_energy ContinuousTransition (q_y_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
+    ma, Va = mean_cov(q_a)
+    myx, Vyx = mean_cov(q_y_x)
+    mW = mean(q_W)
+
+    Fs = getjacobians(meta, ma) # dx × dydx
+    dy = length(Fs)
+
+    # the Gaussian normaliser is over `y` only, so the `log2π` term below scales with `dy`
+    mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta)
+
+    mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end]
+    my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy]
+    Vyx = @view Vyx[1:dy, (dy + 1):end]
+
+    g1 = -mA * Vyx'
+    g2 = g1'
+    trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma))
+    xxt = mx * mx'
+    for (i, j) in Iterators.product(1:dy, 1:dy)
+        FjVaFi = Fs[j] * Va * Fs[i]'
+        trWSU += mW[j, i] * tr(Vx * FjVaFi) # covariance part of E[x x'] = Vx + mx mx', as in the :W rule
+        trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi)
+    end
+    AE = dy / 2 * log2π - mean(logdet, q_W) / 2 + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2
+
+    return AE
+end
diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl
new file mode 100644
index 000000000..340ae5577
--- /dev/null
+++ b/src/rules/continuous_transition/W.jl
@@ -0,0 +1,35 @@
+function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
+    dy = length(my)
+    G₁ = (my * my' + Vy)
+    G₂ = ((my * mx' + Vyx) * mA')
+    G₃ = transpose(G₂)
+    Ex_xx = rank1update(Vx, mx)
+    G₅ = zeros(eltype(ma), dy, dy)
+    G₆ = zeros(eltype(ma), dy, dy)
+    mamat = ma * ma'
+    for (i, j) in Iterators.product(1:dy, 1:dy)
+        tmp = Fs[i]' * Ex_xx * Fs[j]
+        G₅[i, j] = tr(tmp * mamat)
+        G₆[i, j] = tr(tmp * Va)
+    end
+    return G₁ - G₂ - G₃ + G₅ + G₆
+end
+
+@rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin
+    ma, Va = mean_cov(q_a)
+    Fs = getjacobians(meta, ma)
+    dy = length(Fs)
+
+    epsilon = sqrt.(var(q_a))
+    mA = ctcompanion_matrix(ma, epsilon, meta)
+
+    myx, Vyx = mean_cov(q_y_x)
+
+    mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end]
+    my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy]
+    Vyx = @views Vyx[1:dy, (dy + 1):end]
+
+    Δ = compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
+
+    return WishartFast(dy + 2, Δ)
+end
diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl
new file mode 100644
index 000000000..ccc768598
--- /dev/null
+++ b/src/rules/continuous_transition/a.jl
@@ -0,0 +1,25 @@
+@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
+    ma = mean(q_a)
+    mW = mean(q_W)
+    myx, Vyx = mean_cov(q_y_x)
+
+    Fs = getjacobians(meta, ma)
+    dy = length(Fs)
+
+    mx, Vx = @views myx[(dy + 1):end], Vyx[(dy + 1):end, (dy + 1):end]
+    my, Vy = @views myx[1:dy], Vyx[1:dy, 1:dy]
+    Vyx = @view Vyx[1:dy, (dy + 1):end]
+
+    xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma))
+
+    Vxymxy = rank1update(Vyx', mx, my)
+    Vxmx = rank1update(Vx, mx)
+    for i in 1:dy
+        xi += Fs[i]' * Vxymxy * mW[:, i]
+        for j in 1:dy
+            W += mW[j, i] * Fs[i]' * Vxmx * Fs[j]
+        end
+    end
+
+    return MvNormalWeightedMeanPrecision(xi, W)
+end
diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl
new file mode 100644
index 000000000..abd21539b
--- /dev/null
+++ b/src/rules/continuous_transition/marginals.jl
@@ -0,0 +1,38 @@
+
+@marginalrule ContinuousTransition(:y_x) (m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta) = begin
+    return continuous_tranition_marginal(m_y, m_x, q_a, q_W, meta)
+end
+
+function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamily, m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)
+    ma, Va = mean_cov(q_a)
+
+    Fs = getjacobians(meta, ma)
+    dy = length(Fs)
+
+    mW = mean(q_W)
+
+    epsilon = sqrt.(var(q_a))
+    mA = ctcompanion_matrix(ma, epsilon, meta)
+
+    xiy, Wy = weightedmean_precision(m_y)
+    xix, Wx = weightedmean_precision(m_x)
+
+    W_11 = Wy + mW
+
+    # off-diagonal blocks of the joint precision over (y, x)
+    W_12 = negate_inplace!(mW * mA)
+
+    W_21 = negate_inplace!(mA' * mW)
+
+    Ξ = Wx
+    for (i, j) in Iterators.product(1:dy, 1:dy)
+        Ξ += mW[j, i] * Fs[j] * Va * Fs[i]'
+    end
+
+    W_22 = Ξ + mA' * mW * mA
+
+    W = [W_11 W_12; W_21 W_22]
+    ξ = [xiy; xix]
+
+    return MvNormalWeightedMeanPrecision(ξ, W)
+end
diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl
new file mode 100644
index 000000000..56169db82
--- /dev/null
+++ b/src/rules/continuous_transition/x.jl
@@ -0,0 +1,24 @@
+@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
+    ma, Va = mean_cov(q_a)
+    my, Wy = mean_precision(m_y)
+    mW = mean(q_W)
+
+    Fs = getjacobians(meta, ma)
+    dy = length(Fs)
+
+    epsilon = sqrt.(var(q_a))
+    mA = ctcompanion_matrix(ma, epsilon, meta)
+
+    # Woodbury identity
+    # inv(inv(Wy) + inv(mW)) = Wy - Wy * inv(Wy + mW) * Wy
+    WymW = Wy - Wy * cholinv(Wy + mW) * Wy
+    Ξ = mA' * WymW * mA
+
+    for (i, j) in Iterators.product(1:dy, 1:dy)
+        Ξ += mW[j, i] * Fs[j] * Va * Fs[i]'
+    end
+
+    z = mA' * WymW * my
+
+    return MvNormalWeightedMeanPrecision(z, Ξ)
+end
diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl
new file mode 100644
index 000000000..864ce4237
--- /dev/null
+++ b/src/rules/continuous_transition/y.jl
@@ -0,0 +1,14 @@
+@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
+    ma = mean(q_a)
+    mx, Vx = mean_cov(m_x)
+
+    mW = mean(q_W)
+
+    epsilon = sqrt.(var(q_a))
+    mA = ctcompanion_matrix(ma, epsilon, meta)
+
+    Vy = mA * Vx * mA' + cholinv(mW)
+    my = mA * mx
+
+    return MvNormalMeanCovariance(my, Vy)
+end
diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl
index d4d20ddd4..c152401ec 100644
--- a/src/rules/prototypes.jl
+++ b/src/rules/prototypes.jl
@@ -110,6 +110,12 @@ include("transition/out.jl")
 include("transition/in.jl")
 include("transition/a.jl")
 
+include("continuous_transition/y.jl")
+include("continuous_transition/x.jl")
+include("continuous_transition/a.jl")
+include("continuous_transition/W.jl")
+include("continuous_transition/marginals.jl")
+
 include("autoregressive/y.jl")
 include("autoregressive/x.jl")
 include("autoregressive/theta.jl")
diff --git a/test/nodes/test_continuous_transition.jl b/test/nodes/test_continuous_transition.jl
new file mode 100644
index 000000000..34950603f
--- /dev/null
+++ b/test/nodes/test_continuous_transition.jl
@@ -0,0 +1,39 @@
+module ContinuousTransitionNodeTest
+
+using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily
+
+import ReactiveMP: getdimensionality, getjacobians, gettransformation, getunits, ctcompanion_matrix
+
+@testset "ContinuousTransitionNode" begin
+    dy, dx = 2, 3
+    meta = CTMeta(a -> reshape(a, dy, dx))
+    @testset "Creation" begin
+        node = make_node(ContinuousTransition, FactorNodeCreationOptions(nothing, meta, nothing))
+
+        @test functionalform(node) === ContinuousTransition
+        @test sdtype(node) === Stochastic()
+        @test name.(interfaces(node)) === (:y, :x, :a, :W)
+        @test factorisation(node) === ((1, 2, 3, 4),)
+    end
+
+    @testset "AverageEnergy" begin
+        q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5))
+        q_a = MvNormalMeanCovariance(zeros(6), diageye(6))
+        q_W = Wishart(3, diageye(2))
+
+        marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))
+
+        @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.415 atol = 1e-2
+        @test length(getjacobians(meta, mean(q_a))) == dy
+    end
+
+    @testset "ContinuousTransition Functionality" begin
+        m_a = randn(6)
+        A = ctcompanion_matrix(m_a, zeros(length(m_a)), meta)
+
+        @test size(A) == (dy, dx)
+        @test A == gettransformation(meta)(m_a)
+    end
+end
+
+end
diff --git a/test/nodes/test_transfominator.jl b/test/nodes/test_transfominator.jl
new file mode 100644
index 000000000..1c4c77350
--- /dev/null
+++ b/test/nodes/test_transfominator.jl
@@ -0,0 +1,22 @@
+score(
+    AverageEnergy(),
+    ContinuousTransition,
+    Val{(:y_x, :a, :W)}(),
+    (
+        Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing),
+        Marginal(MvNormalMeanPrecision(zeros(4), diageye(4)), false, false, nothing),
+        Marginal(Wishart(2, diageye(2)), false, false, nothing)
+    ),
+    CTMeta(a -> reshape(a, 2, 2))
+)
+score(
+    AverageEnergy(),
+    ContinuousTransition,
+    Val{(:y_x, :a, :W)}(),
+    (
+        Marginal(MvNormalMeanPrecision(zeros(5), diageye(5)), false, false, nothing),
+        Marginal(MvNormalMeanPrecision(zeros(6), diageye(6)), false, false, nothing),
+        Marginal(Wishart(2, diageye(2)), false, false, nothing)
+    ),
+    CTMeta(a -> reshape(a, 2, 3))
+)
diff --git a/test/rules/continuous_transition/test_W.jl b/test/rules/continuous_transition/test_W.jl
new file mode 100644
index 000000000..dd55b3807
--- /dev/null
+++ b/test/rules/continuous_transition/test_W.jl
@@ -0,0 +1,68 @@
+module RulesContinuousTransitionTestW
+
+using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra
+
+import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits, WishartFast
+
+@testset "rules:ContinuousTransition:W" begin
+    rng = MersenneTwister(42)
+
+    @testset "Linear transformation" begin
+        # the following rule is used for testing purposes only
+        # It is derived separately by Thijs van de Laar
+        function benchmark_rule(q_y_x, mA, ΣA, UA)
+            myx, Vyx = mean_cov(q_y_x)
+
+            dy = size(mA, 1)
+            Vx = Vyx[(dy + 1):end, (dy + 1):end]
+            Vy = Vyx[1:dy, 1:dy]
+            mx = myx[(dy + 1):end]
+            my = myx[1:dy]
+            Vyx = Vyx[1:dy, (dy + 1):end]
+
+            G = tr(Vx * UA) * ΣA + mA * Vx * mA' - mA * Vyx' - Vyx * mA' + Vy + ΣA * mx' * UA * mx + (mA * mx - my) * (mA * mx - my)'
+
+            return WishartFast(dy + 2, G)
+        end
+
+        @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta)" begin
+            for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
+                dydx = dy * dx
+                transformation = (a) -> reshape(a, dy, dx)
+                mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)
+
+                metal = CTMeta(transformation)
+                Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
+                μx, Σx = rand(rng, dx), Lx * Lx'
+                μy, Σy = rand(rng, dy), Ly * Ly'
+
+                qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
+                qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))
+
+                @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [(
+                    input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule(qyx, mA, ΣA, UA)
+                )]
+            end
+        end
+    end
+
+    @testset "Nonlinear transformation" begin
+        @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin
+            dy, dx = 2, 2
+            dydx = dy * dy
+            transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+
+            metanl = CTMeta(transformation)
+            μx, Σx = zeros(dx), diageye(dx)
+            μy, Σy = zeros(dy), diageye(dy)
+
+            qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
+            qa = MvNormalMeanCovariance(zeros(1), diageye(1))
+            @test_rules [check_type_promotion = true] ContinuousTransition(:W, Marginalisation) [(
+                input = (q_y_x = qyx, q_a = qa, meta = metanl), output = WishartFast(dy + 2, dy * diageye(dy))
+            )]
+        end
+    end
+end
+
+end
diff --git a/test/rules/continuous_transition/test_a.jl b/test/rules/continuous_transition/test_a.jl
new file mode 100644
index 000000000..533a26397
--- /dev/null
+++ b/test/rules/continuous_transition/test_a.jl
@@ -0,0 +1,66 @@
+module RulesContinuousTransitionTestA
+
+using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions
+
+import ReactiveMP: @test_rules, getjacobians, getunits
+
+@testset "rules:ContinuousTransition:a" begin
+    rng = MersenneTwister(42)
+
+    @testset "Linear transformation" begin
+
+        # the following rule is used for testing purposes only
+        # It is derived separately by Thijs van de Laar
+        function benchmark_rule(q_y_x, q_W)
+            myx, Vyx = mean_cov(q_y_x)
+            dy = size(q_W.S, 1)
+            Vx = Vyx[(dy + 1):end, (dy + 1):end]
+            mx = myx[(dy + 1):end]
+            my = myx[1:dy]
+            Vyx = Vyx[1:dy, (dy + 1):end]
+            mW = mean(q_W)
+            Λ = kron(Vx + mx * mx', mW)
+            return MvNormalWeightedMeanPrecision(Λ * vec((Vyx + my * mx') * inv((Vx + mx * mx'))), Λ)
+        end
+
+        @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin
+            for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
+                dydx = dy * dx
+                transformation = (a) -> reshape(a, dy, dx)
+                a0 = rand(Float32, dydx)
+                metal = CTMeta(transformation)
+                Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
+                μx, Σx = rand(rng, dx), Lx * Lx'
+                μy, Σy = rand(rng, dy), Ly * Ly'
+
+                qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
+                qa = MvNormalMeanCovariance(a0, diageye(dydx))
+                qW = Wishart(dy + 1, diageye(dy))
+                @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
+                    input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qyx, qW)
+                )]
+            end
+        end
+    end
+
+    @testset "Nonlinear transformation" begin
+        @testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin
+            dy, dx = 2, 2
+            dydx = dy * dy
+            transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+            a0 = zeros(Int, 1)
+            metanl = CTMeta(transformation)
+            μx, Σx = ones(dx), diageye(dx)
+            μy, Σy = ones(dy), diageye(dy)
+
+            qyx = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx])
+            qa = MvNormalMeanCovariance(a0, diageye(1))
+            qW = Wishart(dy, diageye(dy))
+            @test_rules [check_type_promotion = true] ContinuousTransition(:a, Marginalisation) [(
+                input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metanl), output = MvNormalWeightedMeanPrecision(zeros(1), (qW.df * dy * dx) * diageye(1))
+            )]
+        end
+    end
+end
+
+end
diff --git a/test/rules/continuous_transition/test_marginals.jl b/test/rules/continuous_transition/test_marginals.jl
new file mode 100644
index 000000000..d8f9550ab
--- /dev/null
+++ b/test/rules/continuous_transition/test_marginals.jl
@@ -0,0 +1,69 @@
+module RulesContinuousTransitionTestMarginals
+
+using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra
+import ReactiveMP: @test_marginalrules
+
+@testset "marginalrules:ContinuousTransition" begin
+    rng = MersenneTwister(42)
+    @testset "Linear transformation" begin
+        # the following rule is used for testing purposes only
+        # It is derived separately by Thijs van de Laar
+        function benchmark_rule(m_x, m_y, q_W, q_A)
+            mx, Wx = mean_invcov(m_x)
+            my, Wy = mean_invcov(m_y)
+            mW = mean(q_W)
+            mA, ΣA, UA = q_A.M, q_A.U, q_A.V
+
+            U = Wx + tr(mW * ΣA) * UA
+
+            Wq = [Wy+mW -mW*mA; -mA'*mW U+mA' * mW * mA]
+            return MvNormalWeightedMeanPrecision([Wy * my; Wx * mx], Wq)
+        end
+        @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin
+            for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
+                transformation = (a) -> reshape(a, dy, dx)
+
+                mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)
+                qA = MatrixNormal(mA, ΣA, UA)
+
+                Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
+                μx, Σx = rand(rng, dx), Lx * Lx'
+                μy, Σy = rand(rng, dy), Ly * Ly'
+
+                my = MvNormalMeanCovariance(μy, Σy)
+                mx = MvNormalMeanCovariance(μx, Σx)
+                qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))
+                qW = Wishart(dy + 1, diageye(dy))
+
+                metal = CTMeta(transformation)
+
+                @test_marginalrules [check_type_promotion = true, atol = 1e-3] ContinuousTransition(:y_x) [(
+                    input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(mx, my, qW, qA)
+                )]
+            end
+        end
+    end
+
+    @testset "Nonlinear transformation" begin
+        @testset "y_x: (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_a::NormalDistributionsFamily, q_W::Any)" begin
+            dy, dx = 2, 2
+            transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+
+            μx, Σx = zeros(dx), diageye(dx)
+            μy, Σy = zeros(dy), diageye(dy)
+
+            my = MvNormalMeanCovariance(μy, Σy)
+            mx = MvNormalMeanCovariance(μx, Σx)
+            qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1))
+            qW = Wishart(dy, diageye(dy))
+
+            metanl = CTMeta(transformation)
+
+            @test_marginalrules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y_x) [(
+                input = (m_y = my, m_x = mx, q_a = qa, q_W = qW, meta = metanl),
+                output = MvNormalWeightedMeanPrecision(zeros(4), [(dy + qW.df - 1)*diageye(dy) -(qW.df)diageye(dx); -(qW.df)diageye(dx) (dy + qW.df - 1)diageye(dy)])
+            )]
+        end
+    end
+end
+end
diff --git a/test/rules/continuous_transition/test_x.jl b/test/rules/continuous_transition/test_x.jl
new file mode 100644
index 000000000..e22d8c7d4
--- /dev/null
+++ b/test/rules/continuous_transition/test_x.jl
@@ -0,0 +1,65 @@
+module RulesContinuousTransitionTestX
+
+using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions, LinearAlgebra
+
+import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits
+
+@testset "rules:ContinuousTransition:x" begin
+    rng = MersenneTwister(42)
+
+    @testset "Linear transformation" begin
+        # the following rule is used for testing purposes only
+        # It is derived separately by Thijs van de Laar
+        function benchmark_rule(q_y, q_W, mA, ΣA, UA)
+            my, Vy = mean_cov(q_y)
+
+            mW = mean(q_W)
+
+            Λ = tr(mW * ΣA) * UA + mA' * inv(Vy + inv(mW)) * mA
+            ξ = mA' * inv(Vy + inv(mW)) * my
+            return MvNormalWeightedMeanPrecision(ξ, Λ)
+        end
+
+        @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin
+            for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
+                dydx = dy * dx
+                transformation = (a) -> reshape(a, dy, dx)
+
+                mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)
+
+                metal = CTMeta(transformation)
+                Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
+                μy, Σy = rand(rng, dy), Ly * Ly'
+
+                qy = MvNormalMeanCovariance(μy, Σy)
+                qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
+                qW = Wishart(dy + 1, diageye(dy))
+
+                @test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [(
+                    input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA)
+                )]
+            end
+        end
+    end
+
+    @testset "Nonlinear transformation" begin
+        @testset "Structured: (m_y::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin
+            dy, dx = 2, 2
+            dydx = dy * dy
+            transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+
+            metanl = CTMeta(transformation)
+            μy, Σy = zeros(dy), diageye(dy)
+
+            qy = MvNormalMeanCovariance(μy, Σy)
+            qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1))
+            qW = Wishart(dy + 1, diageye(dy))
+
+            @test_rules [check_type_promotion = true] ContinuousTransition(:x, Marginalisation) [(
+                input = (m_y = qy, q_a = qa, q_W = qW, meta = metanl), output = MvGaussianWeightedMeanPrecision(zeros(dx), 3 / 4 * diageye(dx))
+            )]
+        end
+    end
+end
+
+end
diff --git a/test/rules/continuous_transition/test_y.jl b/test/rules/continuous_transition/test_y.jl
new file mode 100644
index 000000000..9cb1b2a5a
--- /dev/null
+++ b/test/rules/continuous_transition/test_y.jl
@@ -0,0 +1,61 @@
+module RulesContinuousTransitionTestY
+
+using Test, ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions
+
+import ReactiveMP: @test_rules, ctcompanion_matrix, getjacobians, getunits
+
+@testset "rules:ContinuousTransition:y" begin
+    rng = MersenneTwister(42)
+
+    @testset "Linear transformation" begin
+        # the following rule is used for testing purposes only
+        # It is derived separately by Thijs van de Laar
+        function benchmark_rule(q_x, q_W, mA)
+            mx, Vx = mean_cov(q_x)
+            mW = mean(q_W)
+            return MvNormalMeanCovariance(mA * mx, mA * Vx * mA' + inv(mW))
+        end
+
+        @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta)" begin
+            for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
+                dydx = dy * dx
+                transformation = (a) -> reshape(a, dy, dx)
+
+                mA = rand(rng, dy, dx)
+
+                metal = CTMeta(transformation)
+                Lx = rand(rng, dx, dx)
+                μx, Σx = rand(rng, dx), Lx * Lx'
+
+                qx = MvNormalMeanCovariance(μx, Σx)
+                qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
+                qW = Wishart(dy + 1, diageye(dy))
+
+                @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y, Marginalisation) [(
+                    input = (m_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qx, qW, mA)
+                )]
+            end
+        end
+    end
+
+    @testset "Nonlinear transformation" begin
+        @testset "Structured: (m_x::MultivariateNormalDistributionsFamily, q_a::Any, q_W::Any, meta::CTMeta)" begin
+            dy, dx = 2, 2
+            dydx = dy * dy
+            transformation = (a) -> [cos(a[1]) -sin(a[1]); sin(a[1]) cos(a[1])]
+
+            metanl = CTMeta(transformation)
+            μx, Σx = zeros(dx), diageye(dx)
+
+            qx = MvNormalMeanCovariance(μx, Σx)
+            qa = MvNormalMeanCovariance(zeros(1), tiny * diageye(1))
+            qW = Wishart(dy + 1, diageye(dy))
+
+            @test_rules [check_type_promotion = true] ContinuousTransition(:y, Marginalisation) [(
+                input = (m_x = qx, q_a = qa, q_W = qW, meta = metanl), output = MvGaussianMeanCovariance(zeros(dy), 4 / 3 * diageye(dy))
+            )]
+        end
+    end
+end
+
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 87e64fc2e..633948f09 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -252,6 +252,7 @@ end
     addtests(testrunner, "nodes/test_uniform.jl")
     addtests(testrunner, "nodes/test_normal_mixture.jl")
     addtests(testrunner, "nodes/test_softdot.jl")
+    addtests(testrunner, "nodes/test_continuous_transition.jl")
 
     addtests(testrunner, "rules/uniform/test_out.jl")
 
@@ -388,5 +389,11 @@ end
     addtests(testrunner, "rules/autoregressive/test_theta.jl")
     addtests(testrunner, "rules/autoregressive/test_marginals.jl")
 
+    addtests(testrunner, "rules/continuous_transition/test_a.jl")
+    addtests(testrunner, "rules/continuous_transition/test_W.jl")
+    addtests(testrunner, "rules/continuous_transition/test_x.jl")
+    addtests(testrunner, "rules/continuous_transition/test_y.jl")
+    addtests(testrunner, "rules/continuous_transition/test_marginals.jl")
+
     run(testrunner)
 end