Skip to content

Commit

Permalink
Merge pull request #360 from biaslab/Fix_PointMass_init
Browse files Browse the repository at this point in the history
Upate for PointMass initmarginals and initmasseges
  • Loading branch information
bvdmitri authored Nov 1, 2023
2 parents 43a1fb0 + e858592 commit 6cff9c7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/variables/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ getmarginals(variables::AbstractArray{<:AbstractVariable}, skip_strategy::Margin

setmarginal!(variable::AbstractVariable, marginal) = setmarginal!(getmarginal(variable, IncludeAll()), marginal)

setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::PointMass) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables)))
setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginal::Distribution) = _setmarginals!(Base.HasLength(), variables, Iterators.repeated(marginal, length(variables)))
setmarginals!(variables::AbstractArray{<:AbstractVariable}, marginals) = _setmarginals!(Base.IteratorSize(marginals), variables, marginals)

Expand All @@ -113,6 +114,7 @@ end
setmessage!(variable::AbstractVariable, index::Int, message) = setmessage!(messageout(variable, index), message)
setmessage!(variable::AbstractVariable, message) = foreach(i -> setmessage!(variable, i, message), 1:degree(variable))

setmessages!(variables::AbstractArray{<:AbstractVariable}, message::PointMass) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables)))
setmessages!(variables::AbstractArray{<:AbstractVariable}, message::Distribution) = _setmessages!(Base.HasLength(), variables, Iterators.repeated(message, length(variables)))
setmessages!(variables::AbstractArray{<:AbstractVariable}, messages) = _setmessages!(Base.IteratorSize(messages), variables, messages)

Expand Down
2 changes: 1 addition & 1 deletion test/variables/test_variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using Test, ReactiveMP, Rocket, BayesBase, Distributions, ExponentialFamily
Base.broadcastable(::TestOptions) = Ref(TestOptions()) # for broadcasting

@testset "setmarginal! tests for randomvar" begin
for dist in (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0))
for dist in (NormalMeanVariance(-2.0, 3.0), NormalMeanPrecision(-2.0, 3.0), PointMass(2.0))
T = typeof(dist)
variable = randomvar(:r)
flag = false
Expand Down

0 comments on commit 6cff9c7

Please sign in to comment.