Skip to content

Commit

Permalink
Merge pull request #353 from biaslab/dev-addmultiplicationrules
Browse files Browse the repository at this point in the history
Add new rules multiplication node
  • Loading branch information
bvdmitri authored Oct 9, 2023
2 parents fcfc9c1 + 026f8cb commit f275dd4
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export naturalparams, as_naturalparams, lognormalizer, NaturalParameters

import Distributions: mean, median, mode, shape, scale, rate, var, std, cov, invcov, entropy, pdf, logpdf, logdetcov
import Distributions: VariateForm, ValueSupport, Distribution
import LinearAlgebra: UniformScaling

import Base: prod, convert

Expand Down Expand Up @@ -140,6 +141,7 @@ automatic_convert_paramfloattype(::Type{D}, params) where {D} = error("Cannot au
Converts (if possible) the elements of the `container` to be of type `T`.
"""
convert_paramfloattype(::Type{T}, container::AbstractArray) where {T} = convert(AbstractArray{T}, container)
convert_paramfloattype(::Type{T}, container::UniformScaling) where {T} = convert(UniformScaling{T}, container)
convert_paramfloattype(::Type{T}, number::Number) where {T} = convert(T, number)
convert_paramfloattype(::Type, ::Nothing) = nothing

Expand Down
9 changes: 6 additions & 3 deletions src/helpers/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Base: IteratorEltype, HasEltype
import Base: eltype, length, size, sum
import Base: IndexStyle, IndexLinear, getindex

import LinearAlgebra: UniformScaling

import Rocket: similar_typeof

"""
Expand Down Expand Up @@ -151,9 +153,10 @@ Float64
"""
function deep_eltype end

deep_eltype(::Type{T}) where {T} = T
deep_eltype(::Type{T}) where {T <: AbstractArray} = deep_eltype(eltype(T))
deep_eltype(any) = deep_eltype(typeof(any))
deep_eltype(::Type{T}) where {T} = T
deep_eltype(::Type{T}) where {T <: AbstractArray} = deep_eltype(eltype(T))
deep_eltype(::Type{T}) where {T <: UniformScaling} = deep_eltype(eltype(T))
deep_eltype(any) = deep_eltype(typeof(any))

##

Expand Down
42 changes: 42 additions & 0 deletions src/rules/multiplication/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,45 @@ end
@rule typeof(*)(:in, Marginalisation) (m_out::UnivariateDistribution, m_A::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(*)(:A, Marginalisation) (m_out = m_out, m_in = m_A, meta = meta)
end

#------------------------
# Real * NormalDistributions
#------------------------
@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
μ_out, v_out = mean_var(m_out)
return NormalMeanVariance(μ_out / a, v_out / a^2)
end

@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalMeanCovariance, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
μ_out, v_out = mean_cov(m_out)
return MvNormalMeanCovariance(μ_out / a, v_out / a^2)
end

@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
μ_out, w_out = mean_precision(m_out)
return MvNormalMeanPrecision(μ_out / a, a^2 * w_out)
end

@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:Real}, m_out::MvNormalWeightedMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
ξ_out, w_out = weightedmean_precision(m_out)
return MvNormalWeightedMeanPrecision(a * ξ_out, a^2 * w_out)
end

@rule typeof(*)(:in, Marginalisation) (m_A::NormalDistributionsFamily, m_out::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(*)(:in, Marginalisation) (m_A = m_out, m_out = m_A, meta = meta, addons = getaddons()) # symmetric rule
end

#------------------------
# UniformScaling * NormalDistributions
#------------------------
@rule typeof(*)(:in, Marginalisation) (m_A::PointMass{<:UniformScaling}, m_out::NormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(*)(:in, Marginalisation) (m_A = PointMass(mean(m_A).λ), m_out = m_out, meta = meta, addons = getaddons()) # dispatch to real * normal
end
32 changes: 30 additions & 2 deletions src/rules/multiplication/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end
end

#------------------------
# Real * UnivariateNormalDistributions
# Real * NormalDistributions
#------------------------
@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::UnivariateNormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
Expand All @@ -63,10 +63,38 @@ end
return NormalMeanVariance(a * μ_in, a^2 * v_in)
end

@rule typeof(*)(:out, Marginalisation) (m_A::UnivariateNormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalMeanCovariance, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
μ_in, v_in = mean_cov(m_in)
return MvNormalMeanCovariance(a * μ_in, a^2 * v_in)
end

@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
μ_in, w_in = mean_precision(m_in)
return MvNormalMeanPrecision(a * μ_in, w_in / a^2)
end

@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:Real}, m_in::MvNormalWeightedMeanPrecision, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
@logscale 0
a = mean(m_A)
ξ_in, w_in = weightedmean_precision(m_in)
return MvNormalWeightedMeanPrecision(ξ_in / a, w_in / a^2)
end

@rule typeof(*)(:out, Marginalisation) (m_A::NormalDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(*)(:out, Marginalisation) (m_A = m_in, m_in = m_A, meta = meta, addons = getaddons()) # symmetric rule
end

#------------------------
# UniformScaling * NormalDistributions
#------------------------
@rule typeof(*)(:out, Marginalisation) (m_A::PointMass{<:UniformScaling}, m_in::NormalDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(*)(:out, Marginalisation) (m_A = PointMass(mean(m_A).λ), m_in = m_in, meta = meta, addons = getaddons()) # dispatch to real * normal
end

#-----------------------
# Univariate Normal * Univariate Normal
#----------------------
Expand Down
16 changes: 15 additions & 1 deletion test/rules/multiplication/test_in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@ module RulesMultiplicationInTest

using Test
using ReactiveMP
using Random, Distributions, StableRNGs
using Random, Distributions, StableRNGs, LinearAlgebra
import ReactiveMP: @test_rules, make_inversedist_message

@testset "rule:typeof(*):in" begin
@testset "Belief Propagation: (m_A::PointMass{<:Real}, m_out::MultivariateNormalDistributionsFamily)" begin
@test_rules [check_type_promotion = true] (*)(:in, Marginalisation) [
(input = (m_A = PointMass(2), m_out = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), output = MvNormalMeanCovariance([1, 2], [3 2; 2 6])),
(input = (m_A = PointMass(0.5), m_out = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), output = MvNormalMeanPrecision([2, 4], [3 2; 2 6])),
(input = (m_A = PointMass(0.5), m_out = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])), output = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6]))
]
end
@testset "Belief Propagation: (m_A::PointMass{<:UniformScaling}, m_out::MultivariateNormalDistributionsFamily)" begin
@test_rules [check_type_promotion = true] (*)(:in, Marginalisation) [
(input = (m_A = PointMass(2I), m_out = MvNormalMeanCovariance([2, 4], [12 8; 8 24])), output = MvNormalMeanCovariance([1, 2], [3 2; 2 6])),
(input = (m_A = PointMass(0.5I), m_out = MvNormalMeanPrecision([1, 2], [12 8; 8 24])), output = MvNormalMeanPrecision([2, 4], [3 2; 2 6])),
(input = (m_A = PointMass(0.5I), m_out = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24])), output = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6]))
]
end
@testset "Univariate Gaussian messages" begin
rng = StableRNG(42)
d1 = NormalMeanVariance(0.0, 1.0)
Expand Down
16 changes: 15 additions & 1 deletion test/rules/multiplication/test_out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@ module RulesMultiplicationOutTest

using Test
using ReactiveMP
using Random, Distributions, StableRNGs
using Random, Distributions, StableRNGs, LinearAlgebra
import ReactiveMP: @test_rules, besselmod, make_productdist_message

@testset "rule:typeof(*):out" begin
@testset "Belief Propagation: (m_A::PointMass{<:Real}, m_in::MultivariateNormalDistributionsFamily)" begin
@test_rules [check_type_promotion = true] (*)(:out, Marginalisation) [
(input = (m_A = PointMass(2), m_in = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), output = MvNormalMeanCovariance([2, 4], [12 8; 8 24])),
(input = (m_A = PointMass(0.5), m_in = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), output = MvNormalMeanPrecision([1, 2], [12 8; 8 24])),
(input = (m_A = PointMass(0.5), m_in = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])), output = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24]))
]
end
@testset "Belief Propagation: (m_A::PointMass{<:UniformScaling}, m_in::MultivariateNormalDistributionsFamily)" begin
@test_rules [check_type_promotion = true] (*)(:out, Marginalisation) [
(input = (m_A = PointMass(2I), m_in = MvNormalMeanCovariance([1, 2], [3 2; 2 6])), output = MvNormalMeanCovariance([2, 4], [12 8; 8 24])),
(input = (m_A = PointMass(0.5I), m_in = MvNormalMeanPrecision([2, 4], [3 2; 2 6])), output = MvNormalMeanPrecision([1, 2], [12 8; 8 24])),
(input = (m_A = PointMass(0.5I), m_in = MvNormalWeightedMeanPrecision([1, 2], [3 2; 2 6])), output = MvNormalWeightedMeanPrecision([2, 4], [12 8; 8 24]))
]
end
@testset "Univariate Gaussian messages" begin
rng = StableRNG(42)
d1 = NormalMeanVariance(0.0, 1.0)
Expand Down

0 comments on commit f275dd4

Please sign in to comment.