diff --git a/src/variables/variable.jl b/src/variables/variable.jl index 079189064..0d15cce5d 100644 --- a/src/variables/variable.jl +++ b/src/variables/variable.jl @@ -100,6 +100,7 @@ setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::Distributi setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginals) = _setmarginals!(Base.IteratorSize(marginals), variables, marginals) function _setmarginals!(::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, marginals) + @assert length(variables) == length(marginals) "Variables $(variables) and marginals $(marginals) should have the same length" foreach(zip(variables, marginals)) do (variable, marginal) setmarginal!(variable, marginal) end @@ -119,6 +120,7 @@ setmessages!(variables::AbstractArray{<:AbstractVariable}, message::Distribution setmessages!(variables::AbstractArray{<:AbstractVariable}, messages) = _setmessages!(Base.IteratorSize(messages), variables, messages) function _setmessages!(::Base.IteratorSize, variables::AbstractArray{<:AbstractVariable}, messages) + @assert length(variables) == length(messages) "Variables $(variables) and messages $(messages) should have the same length" foreach(zip(variables, messages)) do (variable, message) setmessage!(variable, message) end diff --git a/test/variables/test_variable.jl b/test/variables/test_variable.jl index 758542689..8319565ae 100644 --- a/test/variables/test_variable.jl +++ b/test/variables/test_variable.jl @@ -2,80 +2,146 @@ module ReactiveMPVariableTest using Test, ReactiveMP, Rocket, BayesBase, Distributions, ExponentialFamily -@testset "Variable" begin - import ReactiveMP: activate! - import Rocket: getscheduler +struct CustomDeterministicNode end - struct TestOptions end +@node CustomDeterministicNode Deterministic [out, x] - Rocket.getscheduler(::TestOptions) = AsapScheduler() - Base.broadcastable(::TestOptions) = Ref(TestOptions()) # for broadcasting +function test_variable_set_method(variable, dist::T, k) where {T} + activate!(variable, TestOptions()) - @testset "setmarginal! tests for randomvar" begin - for dist in (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0), PointMass(2.0)) - T = typeof(dist) - variable = randomvar(:r) - flag = false + test_out_var = randomvar(:out) - activate!(variable, TestOptions()) + # messages could be initialized only when the node is created + for _ in 1:k + make_node(identity, ReactiveMP.FactorNodeCreationOptions(ReactiveMP.DeltaFn, TestNodeMetaData(), nothing), test_out_var, variable) + end - setmarginal!(variable, dist) + @test degree(variable) === k - subscription = subscribe!(getmarginal(variable, IncludeAll()), (marginal) -> begin - @test typeof(marginal) <: Marginal{T} - @test mean(marginal) === mean(dist) - @test var(marginal) === var(dist) - flag = true - end) + # Check that before calling the `setmarginals!` all marginals are `nothing` + @test isnothing(Rocket.getrecent(getmarginal(variable, IncludeAll()))) - # Test that subscription happenend - @test flag === true + setmarginal!(variable, dist) - unsubscribe!(subscription) + marginal_subscription_flag = false + # After calling the `setmarginals!` the marginal should be equal to `dist` + subscription = subscribe!(getmarginal(variable, IncludeAll()), (marginal) -> begin + @test typeof(marginal) <: Marginal{T} + @test mean(marginal) === mean(dist) + @test var(marginal) === var(dist) + marginal_subscription_flag = true + end) + @test marginal_subscription_flag === true + unsubscribe!(subscription) - variablesmv = randomvar(:r, 2) - flagmv = false + # Check that before calling the `setmessages!` all messages are `nothing` + for node_index in 1:k + @test isnothing(Rocket.getrecent(ReactiveMP.messageout(variable, node_index))) + end - activate!.(variablesmv, TestOptions()) + for node_index in 1:k + setmessage!(variable, node_index, dist) + end - setmarginals!(variablesmv, dist) + for node_index in 1:k + message_subscription_flag = false + subscription = subscribe!(ReactiveMP.messageout(variable, node_index), (message) -> begin + @test typeof(message) <: Message{T} + @test mean(message) === mean(dist) + @test var(message) === var(dist) + message_subscription_flag = true + end) + @test message_subscription_flag === true + unsubscribe!(subscription) + end +end + +struct TestNodeMetaData end + +ReactiveMP.collect_meta(::Type{D}, options::FactorNodeCreationOptions{F, T}) where {D <: DeltaFn, F, T <: TestNodeMetaData} = TestNodeMetaData() +ReactiveMP.getinverse(::TestNodeMetaData) = nothing +ReactiveMP.getinverse(::TestNodeMetaData, k::Int) = nothing - subscriptionmv = subscribe!(getmarginals(variablesmv, IncludeAll()), (marginals) -> begin - @test length(marginals) === 2 - foreach(marginals) do marginal - @test typeof(marginal) <: Marginal{T} - @test mean(marginal) === mean(dist) - @test var(marginal) === var(dist) - end - flagmv = true - end) +function test_variables_set_methods(variables, dist::T, k::Int) where {T} + marginal_subscription_flag = false - # Test that subscription happenend - @test flagmv === true + activate!.(variables, TestOptions()) + + @test_throws AssertionError setmarginals!(variables, Iterators.repeated(dist, length(variables) - 1)) + + test_out_var = randomvar(:out) + + for _ in 1:k + make_node(identity, ReactiveMP.FactorNodeCreationOptions(ReactiveMP.DeltaFn, TestNodeMetaData(), nothing), test_out_var, variables...) + end - unsubscribe!(subscriptionmv) + @test all(degree.(variables) .== k) - variablesmx = randomvar(:r, 2, 2) - flagmx = false + @test_throws AssertionError setmessages!(variables, Iterators.repeated(dist, length(variables) - 1)) + @test_throws AssertionError setmessages!(variables, Iterators.repeated(dist, length(variables) - 1)) - activate!.(variablesmx, TestOptions()) + # Test `setmarginals!` - setmarginals!(variablesmx, dist) + # Check that before calling the `setmarginals!` all marginals are `nothing` + @test all(isnothing, Rocket.getrecent.(getmarginal.(variables, IncludeAll()))) - subscriptionmx = subscribe!(getmarginals(variablesmx, IncludeAll()), (marginals) -> begin - @test length(marginals) === 4 - foreach(marginals) do marginal - @test typeof(marginal) <: Marginal{T} - @test mean(marginal) === mean(dist) - @test var(marginal) === var(dist) - end - flagmx = true - end) + setmarginals!(variables, dist) - # Test that subscription happenend - @test flagmx === true + # After calling the `setmarginals!` all marginals should be equal to `dist` + subscription = subscribe!(getmarginals(variables, IncludeAll()), (marginals) -> begin + @test length(marginals) === length(variables) + foreach(marginals) do marginal + @test typeof(marginal) <: Marginal{T} + @test mean(marginal) === mean(dist) + @test var(marginal) === var(dist) + end + marginal_subscription_flag = true + end) + + # Test that subscription happenend + @test marginal_subscription_flag === true + unsubscribe!(subscription) + + # Check that before calling the `setmessages!` all messages are `nothing` + for node_index in 1:k + @test all(isnothing, Rocket.getrecent.(ReactiveMP.messageout.(variables, node_index))) + end + + # After calling the `setmessages!` all marginals should be equal to `dist` + setmessages!(variables, dist) + # For each outbound index + for node_index in 1:k + messages_subscription_flag = false + subscription = subscribe!(collectLatest(ReactiveMP.messageout.(variables, node_index)), (messages) -> begin + @test length(messages) === length(variables) + foreach(messages) do message + @test typeof(message) <: Message{T} + @test mean(message) === mean(dist) + @test var(message) === var(dist) + end + messages_subscription_flag = true + end) + @test messages_subscription_flag === true + unsubscribe!(subscription) + end +end + +@testset "Variable" begin + import ReactiveMP: activate! + import Rocket: getscheduler + + struct TestOptions end + + Rocket.getscheduler(::TestOptions) = AsapScheduler() + Base.broadcastable(::TestOptions) = Ref(TestOptions()) # for broadcasting - unsubscribe!(subscriptionmx) + @testset "setmarginal! and setmessages! tests for randomvar" begin + dists = (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0), PointMass(2.0)) + number_of_nodes = 1:4 + for (dist, k) in Iterators.product(dists, number_of_nodes) + test_variable_set_method(randomvar(:r), dist, k) + test_variables_set_methods(randomvar(:r, 2), dist, k) + test_variables_set_methods(randomvar(:r, 2, 2), dist, k) end end end