From b822a69e1b89509eb3dfebad33e22c01ca41d9b7 Mon Sep 17 00:00:00 2001 From: Chengfeng-Jia Date: Fri, 27 Oct 2023 15:27:17 +0200 Subject: [PATCH 1/2] Upate for PointMass initmarginals and initmasseges --- src/variables/variable.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/variables/variable.jl b/src/variables/variable.jl index cd774848b..a3174c664 100644 --- a/src/variables/variable.jl +++ b/src/variables/variable.jl @@ -92,6 +92,7 @@ getmarginals(variables::AbstractArray{<:AbstractVariable}, skip_strategy::Margin setmarginal!(variable::AbstractVariable, marginal) = setmarginal!(getmarginal(variable, IncludeAll()), marginal) +setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::PointMass) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::Distribution) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables))) setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginals) = _setmarginals!(Base.IteratorSize(marginals), variables, marginals) @@ -110,6 +111,7 @@ end setmessage!(variable::AbstractVariable, index::Int, message) = setmessage!(messageout(variable, index), message) setmessage!(variable::AbstractVariable, message) = foreach(i -> setmessage!(variable, i, message), 1:degree(variable)) +setmessages!(variables::AbstractArray{<:AbstractVariable}, message::PointMass) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) setmessages!(variables::AbstractArray{<:AbstractVariable}, message::Distribution) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables))) setmessages!(variables::AbstractArray{<:AbstractVariable}, messages) = _setmessages!(Base.IteratorSize(messages), variables, messages) From e858592f8c6f927de907c5479024559a199d3c19 Mon Sep 17 00:00:00 2001 From: Chengfeng-Jia Date: Tue, 31 Oct 2023 09:48:10 +0100 Subject: [PATCH 2/2] Add a test for PointMass --- test/variables/test_variable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/variables/test_variable.jl b/test/variables/test_variable.jl index 05d5dbb9b..58dc857f0 100644 --- a/test/variables/test_variable.jl +++ b/test/variables/test_variable.jl @@ -14,7 +14,7 @@ using Rocket Base.broadcastable(::TestOptions) = Ref(TestOptions()) # for broadcasting @testset "setmarginal! tests for randomvar" begin - for dist in (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0)) + for dist in (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0), PointMass(2.0)) T = typeof(dist) variable = randomvar(:r) flag = false