diff --git a/src/distributions.jl b/src/distributions.jl index e8c2f787f..a3d79ebc7 100644 --- a/src/distributions.jl +++ b/src/distributions.jl @@ -7,6 +7,7 @@ export naturalparams, as_naturalparams, lognormalizer, NaturalParameters import Distributions: mean, median, mode, shape, scale, rate, var, std, cov, invcov, entropy, pdf, logpdf, logdetcov import Distributions: VariateForm, ValueSupport, Distribution +import LinearAlgebra: UniformScaling import Base: prod, convert @@ -140,6 +141,7 @@ automatic_convert_paramfloattype(::Type{D}, params) where {D} = error("Cannot au Converts (if possible) the elements of the `container` to be of type `T`. """ convert_paramfloattype(::Type{T}, container::AbstractArray) where {T} = convert(AbstractArray{T}, container) +convert_paramfloattype(::Type{T}, container::UniformScaling) where {T} = convert(UniformScaling{T}, container) convert_paramfloattype(::Type{T}, number::Number) where {T} = convert(T, number) convert_paramfloattype(::Type, ::Nothing) = nothing diff --git a/src/helpers/helpers.jl b/src/helpers/helpers.jl index 82fca667e..e0251e2c0 100644 --- a/src/helpers/helpers.jl +++ b/src/helpers/helpers.jl @@ -9,6 +9,8 @@ import Base: IteratorEltype, HasEltype import Base: eltype, length, size, sum import Base: IndexStyle, IndexLinear, getindex +import LinearAlgebra: UniformScaling + import Rocket: similar_typeof """ @@ -151,9 +153,10 @@ Float64 """ function deep_eltype end -deep_eltype(::Type{T}) where {T} = T -deep_eltype(::Type{T}) where {T <: AbstractArray} = deep_eltype(eltype(T)) -deep_eltype(any) = deep_eltype(typeof(any)) +deep_eltype(::Type{T}) where {T} = T +deep_eltype(::Type{T}) where {T <: AbstractArray} = deep_eltype(eltype(T)) +deep_eltype(::Type{T}) where {T <: UniformScaling} = deep_eltype(eltype(T)) +deep_eltype(any) = deep_eltype(typeof(any)) ## diff --git a/src/rules/multiplication/in.jl b/src/rules/multiplication/in.jl index 1ebc4db42..55d45ec9b 100644 --- a/src/rules/multiplication/in.jl +++ b/src/rules/multiplication/in.jl @@ -66,3 +66,45 @@ end @rule typeof(*)(:in, Marginalisation) (m_out::UnivariateDistribution, m_A::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:A, Marginalisation) (m_out = m_out, m_in = m_A, meta = meta) end + +#------------------------ +# Real * NormalDistributions +#------------------------ +@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + μ_out, v_out = mean_var(m_out) + return NormalMeanVariance(μ_out / a, v_out / a^2) +end + +@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalMeanCovariance, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + μ_out, v_out = mean_cov(m_out) + return MvNormalMeanCovariance(μ_out / a, v_out / a^2) +end + +@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + μ_out, w_out = mean_precision(m_out) + return MvNormalMeanPrecision(μ_out / a, a^2 * w_out) +end + +@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalWeightedMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + ξ_out, w_out = weightedmean_precision(m_out) + return MvNormalWeightedMeanPrecision(a * ξ_out, a^2 * w_out) +end + +@rule typeof(*)(:in, Marginalisation) (m_A::NormalDistributionsFamily, m_out::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + return @call_rule typeof(*)(:in, Marginalisation) (m_A = m_out, m_out = m_A, meta = meta, addons = getaddons()) # symmetric rule +end + +#------------------------ +# UniformScaling * NormalDistributions +#------------------------ +@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:UniformScaling}, m_out::NormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + return @call_rule typeof(*)(:in, Marginalisation) (m_A = PointMass(mean(m_A).λ), m_out = m_out, meta = meta, addons = getaddons()) # dispatch to real * normal +end diff --git a/src/rules/multiplication/out.jl b/src/rules/multiplication/out.jl index 1fc462ad4..9e2d01cfa 100644 --- a/src/rules/multiplication/out.jl +++ b/src/rules/multiplication/out.jl @@ -54,7 +54,7 @@ end end #------------------------ -# Real * UnivariateNormalDistributions +# Real * NormalDistributions #------------------------ @rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin @logscale 0 @@ -63,10 +63,38 @@ end return NormalMeanVariance(a * μ_in, a^2 * v_in) end -@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalMeanCovariance, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + μ_in, v_in = mean_cov(m_in) + return MvNormalMeanCovariance(a * μ_in, a^2 * v_in) +end + +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + μ_in, w_in = mean_precision(m_in) + return MvNormalMeanPrecision(a * μ_in, w_in / a^2) +end + +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalWeightedMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + @logscale 0 + a = mean(m_A) + ξ_in, w_in = weightedmean_precision(m_in) + return MvNormalWeightedMeanPrecision(ξ_in / a, w_in / a^2) +end + +@rule typeof(*)(:out, Marginalisation) (m_A::NormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule end +#------------------------ +# UniformScaling * NormalDistributions +#------------------------ +@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:UniformScaling}, m_in::NormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin + return @call_rule typeof(*)(:out, Marginalisation) (m_A = PointMass(mean(m_A).λ), m_in = m_in, meta = meta, addons = getaddons()) # dispatch to real * normal +end + #----------------------- # Univariate Normal * Univariate Normal #---------------------- diff --git a/test/rules/multiplication/test_in.jl b/test/rules/multiplication/test_in.jl index 88c65d23e..cc669d280 100644 --- a/test/rules/multiplication/test_in.jl +++ b/test/rules/multiplication/test_in.jl @@ -2,10 +2,24 @@ module RulesMultiplicationInTest using Test using ReactiveMP -using Random, Distributions, StableRNGs +using Random, Distributions, StableRNGs, LinearAlgebra import ReactiveMP: @test_rules, make_inversedist_message @testset "rule:typeof(*):in" begin + @testset "Belief Propagation: (m_A::PointMass{<:Real}, m_out::MultivariateNormalDistributionsFamily)" begin + @test_rules [check_type_promotion = true] (*)(:in, Marginalisation) [ + (input = (m_A = PointMass(2), m_out = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), output = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), + (input = (m_A = PointMass(0.5), m_out = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), output = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), + (input = (m_A = PointMass(0.5), m_out = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])), output = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])) + ] + end + @testset "Belief Propagation: (m_A::PointMass{<:UniformScaling}, m_out::MultivariateNormalDistributionsFamily)" begin + @test_rules [check_type_promotion = true] (*)(:in, Marginalisation) [ + (input = (m_A = PointMass(2I), m_out = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), output = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), + (input = (m_A = PointMass(0.5I), m_out = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), output = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), + (input = (m_A = PointMass(0.5I), m_out = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])), output = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])) + ] + end @testset "Univariate Gaussian messages" begin rng = StableRNG(42) d1 = NormalMeanVariance(0.0, 1.0) diff --git a/test/rules/multiplication/test_out.jl b/test/rules/multiplication/test_out.jl index 935769741..5163abbe1 100644 --- a/test/rules/multiplication/test_out.jl +++ b/test/rules/multiplication/test_out.jl @@ -2,10 +2,24 @@ module RulesMultiplicationOutTest using Test using ReactiveMP -using Random, Distributions, StableRNGs +using Random, Distributions, StableRNGs, LinearAlgebra import ReactiveMP: @test_rules, besselmod, make_productdist_message @testset "rule:typeof(*):out" begin + @testset "Belief Propagation: (m_A::PointMass{<:Real}, m_in::MultivariateNormalDistributionsFamily)" begin + @test_rules [check_type_promotion = true] (*)(:out, Marginalisation) [ + (input = (m_A = PointMass(2), m_in = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), output = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), + (input = (m_A = PointMass(0.5), m_in = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), output = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), + (input = (m_A = PointMass(0.5), m_in = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])), output = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])) + ] + end + @testset "Belief Propagation: (m_A::PointMass{<:UniformScaling}, m_in::MultivariateNormalDistributionsFamily)" begin + @test_rules [check_type_promotion = true] (*)(:out, Marginalisation) [ + (input = (m_A = PointMass(2I), m_in = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), output = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), + (input = (m_A = PointMass(0.5I), m_in = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), output = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), + (input = (m_A = PointMass(0.5I), m_in = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])), output = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])) + ] + end @testset "Univariate Gaussian messages" begin rng = StableRNG(42) d1 = NormalMeanVariance(0.0, 1.0)