diff --git a/src/distributions/bernoulli.jl b/src/distributions/bernoulli.jl index 8ba29f91..96eb944a 100644 --- a/src/distributions/bernoulli.jl +++ b/src/distributions/bernoulli.jl @@ -89,6 +89,10 @@ isbasemeasureconstant(::Type{Bernoulli}) = ConstantBaseMeasure() getbasemeasure(::Type{Bernoulli}) = (x) -> oneunit(x) getsufficientstatistics(::Type{Bernoulli}) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Bernoulli}) = (_) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Bernoulli}) = (η) -> begin (η₁,) = unpack_parameters(Bernoulli, η) return -log(logistic(-η₁)) diff --git a/src/distributions/beta.jl b/src/distributions/beta.jl index 7e363a9e..eb0a845a 100644 --- a/src/distributions/beta.jl +++ b/src/distributions/beta.jl @@ -57,6 +57,10 @@ isbasemeasureconstant(::Type{Beta}) = ConstantBaseMeasure() getbasemeasure(::Type{Beta}) = (x) -> oneunit(x) getsufficientstatistics(::Type{Beta}) = (log, mirrorlog) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Beta}) = (_) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Beta}) = (η) -> begin (η₁, η₂) = unpack_parameters(Beta, η) return logbeta(η₁ + one(η₁), η₂ + one(η₂)) diff --git a/src/distributions/binomial.jl b/src/distributions/binomial.jl index 79a00ad4..bfdf8735 100644 --- a/src/distributions/binomial.jl +++ b/src/distributions/binomial.jl @@ -87,6 +87,10 @@ isbasemeasureconstant(::Type{Binomial}) = NonConstantBaseMeasure() getbasemeasure(::Type{Binomial}, ntrials) = Base.Fix1(binomial, ntrials) getsufficientstatistics(::Type{Binomial}, _) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Binomial}) = (_) -> begin + error("Expectation of log base measure is not defined for Binomial distribution") +end + getlogpartition(::NaturalParametersSpace, ::Type{Binomial}, ntrials) = (η) -> begin (η₁,) = unpack_parameters(Binomial, η) return ntrials * log1pexp(η₁) diff --git a/src/distributions/categorical.jl b/src/distributions/categorical.jl index bd256d06..445f9958 100644 --- a/src/distributions/categorical.jl +++ b/src/distributions/categorical.jl @@ -67,6 +67,10 @@ isbasemeasureconstant(::Type{Categorical}) = ConstantBaseMeasure() getbasemeasure(::Type{Categorical}, _) = (x) -> oneunit(x) getsufficientstatistics(::Type{Categorical}, conditioner) = ((x) -> OneElement(x, conditioner),) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Categorical}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner) = (η) -> begin if (conditioner !== length(η)) diff --git a/src/distributions/chi_squared.jl b/src/distributions/chi_squared.jl index 209267cc..d268dae0 100644 --- a/src/distributions/chi_squared.jl +++ b/src/distributions/chi_squared.jl @@ -64,6 +64,10 @@ getlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin return loggamma(η1 + o) + (η1 + o) * logtwo end +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin + η/2 +end + getgradlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin (η1,) = unpack_parameters(Chisq, η) return SA[digamma(η1 + one(η1)) + logtwo] diff --git a/src/distributions/dirichlet.jl b/src/distributions/dirichlet.jl index 3dc19268..a34c827b 100644 --- a/src/distributions/dirichlet.jl +++ b/src/distributions/dirichlet.jl @@ -55,6 +55,10 @@ isbasemeasureconstant(::Type{Dirichlet}) = ConstantBaseMeasure() getbasemeasure(::Type{Dirichlet}) = (x) -> one(Float64) getsufficientstatistics(::Type{Dirichlet}) = (x -> vmap(log, x),) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Dirichlet}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Dirichlet}) = (η) -> begin (η1,) = unpack_parameters(Dirichlet, η) firstterm = mapreduce(x -> loggamma(x + 1), +, η1) diff --git a/src/distributions/erlang.jl b/src/distributions/erlang.jl index e65e344a..e4f82752 100644 --- a/src/distributions/erlang.jl +++ b/src/distributions/erlang.jl @@ -51,6 +51,10 @@ isbasemeasureconstant(::Type{Erlang}) = ConstantBaseMeasure() getbasemeasure(::Type{Erlang}) = (x) -> one(x) getsufficientstatistics(::Type{Erlang}) = (log, identity) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Erlang}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Erlang}) = (η) -> begin (η1, η2) = unpack_parameters(Erlang, η) diff --git a/src/distributions/exponential.jl b/src/distributions/exponential.jl index 1278ee24..84858ab2 100644 --- a/src/distributions/exponential.jl +++ b/src/distributions/exponential.jl @@ -39,6 +39,10 @@ isbasemeasureconstant(::Type{Exponential}) = ConstantBaseMeasure() getbasemeasure(::Type{Exponential}) = (x) -> oneunit(x) getsufficientstatistics(::Type{Exponential}) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Exponential}) = (_) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin (η₁,) = unpack_parameters(Exponential, η) return -log(-η₁) diff --git a/src/distributions/gamma_family/gamma_family.jl b/src/distributions/gamma_family/gamma_family.jl index a08976b0..4fa0958f 100644 --- a/src/distributions/gamma_family/gamma_family.jl +++ b/src/distributions/gamma_family/gamma_family.jl @@ -90,6 +90,10 @@ isbasemeasureconstant(::Type{Gamma}) = ConstantBaseMeasure() getbasemeasure(::Type{Gamma}) = (x) -> oneunit(x) getsufficientstatistics(::Type{Gamma}) = (log, identity) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Gamma}) = (_) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Gamma}) = (η) -> begin (η₁, η₂) = unpack_parameters(Gamma, η) return loggamma(η₁ + one(η₁)) - (η₁ + one(η₁)) * log(-η₂) diff --git a/src/distributions/gamma_inverse.jl b/src/distributions/gamma_inverse.jl index e521ab60..b3af39e4 100644 --- a/src/distributions/gamma_inverse.jl +++ b/src/distributions/gamma_inverse.jl @@ -58,6 +58,10 @@ isbasemeasureconstant(::Type{GammaInverse}) = ConstantBaseMeasure() getbasemeasure(::Type{GammaInverse}) = (x) -> oneunit(x) getsufficientstatistics(::Type{GammaInverse}) = (log, inv) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{GammaInverse}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{GammaInverse}) = (η) -> begin (η₁, η₂) = unpack_parameters(GammaInverse, η) return loggamma(-η₁ - one(η₁)) - (-η₁ - one(η₁)) * log(-η₂) diff --git a/src/distributions/geometric.jl b/src/distributions/geometric.jl index 81e00cdf..5a78f4ea 100644 --- a/src/distributions/geometric.jl +++ b/src/distributions/geometric.jl @@ -39,6 +39,10 @@ isbasemeasureconstant(::Type{Geometric}) = ConstantBaseMeasure() getbasemeasure(::Type{Geometric}) = (x) -> one(x) getsufficientstatistics(::Type{Geometric}) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Geometric}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Geometric}) = (η) -> begin (η1,) = unpack_parameters(Geometric, η) return -log(one(η1) - exp(η1)) diff --git a/src/distributions/laplace.jl b/src/distributions/laplace.jl index 4d60b1c5..af21bd78 100644 --- a/src/distributions/laplace.jl +++ b/src/distributions/laplace.jl @@ -163,6 +163,10 @@ getsufficientstatistics(::Type{Laplace}, conditioner) = ( (x) -> abs(x - conditioner), ) +getexpectationlogbasemeasure(::Type{Laplace}, conditioner) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{Laplace}, _) = (η) -> begin (η₁,) = unpack_parameters(Laplace, η) return log(-2 / η₁) diff --git a/src/distributions/lognormal.jl b/src/distributions/lognormal.jl index 9c1cf971..187a5a84 100644 --- a/src/distributions/lognormal.jl +++ b/src/distributions/lognormal.jl @@ -48,6 +48,10 @@ isbasemeasureconstant(::Type{LogNormal}) = ConstantBaseMeasure() getbasemeasure(::Type{LogNormal}) = (x) -> invsqrt2π getsufficientstatistics(::Type{LogNormal}) = (log, x -> abs2(log(x))) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{LogNormal}) = (_) -> begin + 0.5*log2π +end + getlogpartition(::NaturalParametersSpace, ::Type{LogNormal}) = (η) -> begin (η₁, η₂) = unpack_parameters(LogNormal, η) return -(η₁ + 1)^2 / (4η₂) - log(-2η₂) / 2 diff --git a/src/distributions/matrix_dirichlet.jl b/src/distributions/matrix_dirichlet.jl index 06010965..3dd978f1 100644 --- a/src/distributions/matrix_dirichlet.jl +++ b/src/distributions/matrix_dirichlet.jl @@ -143,6 +143,10 @@ isbasemeasureconstant(::Type{MatrixDirichlet}) = ConstantBaseMeasure() getbasemeasure(::Type{MatrixDirichlet}) = (x) -> one(Float64) getsufficientstatistics(::Type{MatrixDirichlet}) = (x -> vmap(log, x),) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MatrixDirichlet}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{MatrixDirichlet}) = (η) -> begin (η1,) = unpack_parameters(MatrixDirichlet, η) diff --git a/src/distributions/mv_normal_wishart.jl b/src/distributions/mv_normal_wishart.jl index 6c634235..54d04da7 100644 --- a/src/distributions/mv_normal_wishart.jl +++ b/src/distributions/mv_normal_wishart.jl @@ -185,6 +185,10 @@ function getsufficientstatistics(::Type{MvNormalWishart}) ) end +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (_) -> begin + log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (η) -> begin η1, η2, η3, η4 = unpack_parameters(MvNormalWishart, η) d = length(η1) diff --git a/src/distributions/negative_binomial.jl b/src/distributions/negative_binomial.jl index c0d3ddf6..4071978d 100644 --- a/src/distributions/negative_binomial.jl +++ b/src/distributions/negative_binomial.jl @@ -101,6 +101,10 @@ isbasemeasureconstant(::Type{NegativeBinomial}) = NonConstantBaseMeasure() getbasemeasure(::Type{NegativeBinomial}, conditioner) = (x) -> binomial(Int(x + conditioner - 1), x) getsufficientstatistics(::Type{NegativeBinomial}, conditioner) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NegativeBinomial}, conditioner) = (η) -> begin + error("Expectation of log base measure is not implemented for NegativeBinomial distribution") +end + getlogpartition(::NaturalParametersSpace, ::Type{NegativeBinomial}, conditioner) = (η) -> begin (η1,) = unpack_parameters(NegativeBinomial, η) return -conditioner * log(one(η1) - exp(η1)) diff --git a/src/distributions/normal_family/normal_family.jl b/src/distributions/normal_family/normal_family.jl index 80e70658..4299e69c 100644 --- a/src/distributions/normal_family/normal_family.jl +++ b/src/distributions/normal_family/normal_family.jl @@ -578,6 +578,10 @@ isbasemeasureconstant(::Type{NormalMeanVariance}) = ConstantBaseMeasure() getbasemeasure(::Type{NormalMeanVariance}) = (x) -> convert(typeof(x), invsqrt2π) getsufficientstatistics(::Type{NormalMeanVariance}) = (identity, abs2) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) -> begin + 0.5*log2π +end + getlogpartition(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) -> begin (η₁, η₂) = unpack_parameters(NormalMeanVariance, η) return -abs2(η₁) / 4η₂ - log(-2η₂) / 2 @@ -691,6 +695,11 @@ isbasemeasureconstant(::Type{MvNormalMeanCovariance}) = ConstantBaseMeasure() getbasemeasure(::Type{MvNormalMeanCovariance}) = (x) -> (2π)^(length(x) / -2) getsufficientstatistics(::Type{MvNormalMeanCovariance}) = (identity, (x) -> x * x') +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η) -> begin + k = div(-1 + isqrt(1 + 4 * length(η)), 2) + return k * log2π / 2 +end + getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η) -> begin (η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η) k = length(η₁) diff --git a/src/distributions/normal_gamma.jl b/src/distributions/normal_gamma.jl index a557d50f..bd0f1c9c 100644 --- a/src/distributions/normal_gamma.jl +++ b/src/distributions/normal_gamma.jl @@ -137,6 +137,10 @@ getbasemeasure(::Type{NormalGamma}) = (x) -> invsqrt2π # x is a 2d vector where first dimension is mean and the second dimension is precision component getsufficientstatistics(::Type{NormalGamma}) = (x -> x[1] * x[2], x -> x[2] * x[1]^2, x -> log(x[2]), x -> x[2]) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NormalGamma}) = (_) -> begin + 0.5*log2π +end + getlogpartition(::NaturalParametersSpace, ::Type{NormalGamma}) = (η) -> begin (η1, η2, η3, η4) = unpack_parameters(NormalGamma, η) η3half = η3 + (1 / 2) diff --git a/src/distributions/pareto.jl b/src/distributions/pareto.jl index da09c42a..a2e0d2ea 100644 --- a/src/distributions/pareto.jl +++ b/src/distributions/pareto.jl @@ -143,7 +143,11 @@ isbasemeasureconstant(::Type{Pareto}) = ConstantBaseMeasure() getbasemeasure(::Type{Pareto}, _) = (x) -> oneunit(x) getsufficientstatistics(::Type{Pareto}, conditioner) = (log,) -getlogpartition(::NaturalParametersSpace, ::Type{Pareto}, conditioner) = (η) -> begin +getexpectationlogbasemeasure(::Type{Pareto}, _) = (_) -> begin + log(1) +end + +getlogpartition(::Type{Pareto}, conditioner) = (η) -> begin (η1,) = unpack_parameters(Pareto, η) return log(conditioner^(one(η1) + η1) / (-one(η1) - η1)) end diff --git a/src/distributions/poisson.jl b/src/distributions/poisson.jl index abdabd6f..1fc76f2a 100644 --- a/src/distributions/poisson.jl +++ b/src/distributions/poisson.jl @@ -55,6 +55,10 @@ isbasemeasureconstant(::Type{Poisson}) = NonConstantBaseMeasure() getbasemeasure(::Type{Poisson}) = (x) -> one(x) / factorial(x) getsufficientstatistics(::Type{Poisson}) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Poisson}) = (_) -> begin + error("Expectation of log base measure is not implemented for Poisson distribution") +end + getlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin (η1,) = unpack_parameters(Poisson, η) return exp(η1) diff --git a/src/distributions/rayleigh.jl b/src/distributions/rayleigh.jl index 7f1cf973..e8b3c9e1 100644 --- a/src/distributions/rayleigh.jl +++ b/src/distributions/rayleigh.jl @@ -51,6 +51,10 @@ isbasemeasureconstant(::Type{Rayleigh}) = NonConstantBaseMeasure() getbasemeasure(::Type{Rayleigh}) = identity getsufficientstatistics(::Type{Rayleigh}) = (x -> x^2,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin + error("Expectation of log base measure is not implemented for Rayleigh distribution") +end + getlogpartition(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin (η1,) = unpack_parameters(Rayleigh, η) return -log(-2 * η1) diff --git a/src/distributions/von_mises.jl b/src/distributions/von_mises.jl index ec4a5bd4..c38adc2a 100644 --- a/src/distributions/von_mises.jl +++ b/src/distributions/von_mises.jl @@ -82,11 +82,17 @@ isbasemeasureconstant(::Type{VonMises}) = ConstantBaseMeasure() getbasemeasure(::Type{VonMises}, _) = (x) -> inv(twoπ) getsufficientstatistics(::Type{VonMises}, _) = (cos, sin) + +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{VonMises}, _) = (_) -> begin + return log(twoπ) +end + getgradlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin u = sqrt(dot(η, η)) same_part = besseli(1, u) / (u * besseli(0, u)) return SA[η[1] * same_part, η[2] * same_part] end + getlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin return log(besseli(0, sqrt(dot(η, η)))) end diff --git a/src/distributions/von_mises_fisher.jl b/src/distributions/von_mises_fisher.jl index c2f05445..75d6bcd0 100644 --- a/src/distributions/von_mises_fisher.jl +++ b/src/distributions/von_mises_fisher.jl @@ -74,6 +74,10 @@ isbasemeasureconstant(::Type{VonMisesFisher}) = ConstantBaseMeasure() getbasemeasure(::Type{VonMisesFisher}) = (x) -> (inv2π)^(length(x) / 2) getsufficientstatistics(::Type{VonMisesFisher}) = (identity,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{VonMisesFisher}) = (η) -> begin + length(η)/2 * log(inv2π) +end + getlogpartition(::NaturalParametersSpace, ::Type{VonMisesFisher}) = (η) -> begin κ = sqrt(η' * η) p = length(η) diff --git a/src/distributions/weibull.jl b/src/distributions/weibull.jl index a25dd0a0..a5f78229 100644 --- a/src/distributions/weibull.jl +++ b/src/distributions/weibull.jl @@ -94,6 +94,10 @@ isbasemeasureconstant(::Type{Weibull}) = NonConstantBaseMeasure() getbasemeasure(::Type{Weibull}, conditioner) = x -> x^(conditioner - 1) getsufficientstatistics(::Type{Weibull}, conditioner) = (x -> x^conditioner,) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Weibull}, conditioner) = (η) -> begin + error("Expectation of log base measure is not implemented for Weibull distribution") +end + getlogpartition(::NaturalParametersSpace, ::Type{Weibull}, conditioner) = (η) -> begin (η1,) = unpack_parameters(Weibull, η) return -log(-η1) - log(conditioner) diff --git a/src/distributions/wishart.jl b/src/distributions/wishart.jl index 4af9dcbe..681149e0 100644 --- a/src/distributions/wishart.jl +++ b/src/distributions/wishart.jl @@ -253,6 +253,10 @@ isbasemeasureconstant(::Type{WishartFast}) = ConstantBaseMeasure() getbasemeasure(::Type{WishartFast}) = (x) -> one(Float64) getsufficientstatistics(::Type{WishartFast}) = (chollogdet, identity) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{WishartFast}) = (η) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{WishartFast}) = (η) -> begin η1, η2 = unpack_parameters(WishartFast, η) p = first(size(η2)) diff --git a/src/distributions/wishart_inverse.jl b/src/distributions/wishart_inverse.jl index 8d109bb5..c0114dbd 100644 --- a/src/distributions/wishart_inverse.jl +++ b/src/distributions/wishart_inverse.jl @@ -275,6 +275,10 @@ isbasemeasureconstant(::Type{InverseWishartFast}) = ConstantBaseMeasure() getbasemeasure(::Type{InverseWishartFast}) = (x) -> one(Float64) getsufficientstatistics(::Type{InverseWishartFast}) = (chollogdet, cholinv) +getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{InverseWishartFast}) = (η) -> begin + return log(1) +end + getlogpartition(::NaturalParametersSpace, ::Type{InverseWishartFast}) = (η) -> begin η1, η2 = unpack_parameters(InverseWishartFast, η) p = first(size(η2)) diff --git a/src/exponential_family.jl b/src/exponential_family.jl index 4552bbf0..6722a59d 100644 --- a/src/exponential_family.jl +++ b/src/exponential_family.jl @@ -325,11 +325,27 @@ function fisherinformation(ef::ExponentialFamilyDistribution, η = getnaturalpar return getfisherinformation(ef)(η) end +""" + expectationlogbasemeasure(distribution, η) + +Return the computed value of the expectation of the log base measure of the exponential family distribution at the point `η`. +By default `η = getnaturalparameters(ef)`. +""" +function expectationlogbasemeasure(ef::ExponentialFamilyDistribution, η = getnaturalparameters(ef)) + return getexpectationlogbasemeasure(ef)(η) +end + getbasemeasure(ef::ExponentialFamilyDistribution) = getbasemeasure(ef.attributes, ef) getbasemeasure(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} = getbasemeasure(T, getconditioner(ef)) getbasemeasure(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) = getbasemeasure(attributes) +getexpectationlogbasemeasure(ef::ExponentialFamilyDistribution) = getexpectationlogbasemeasure(ef.attributes, ef) +getexpectationlogbasemeasure(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} = + getexpectationlogbasemeasure(T, getconditioner(ef)) +getexpectationlogbasemeasure(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) = + getexpectationlogbasemeasure(attributes) + getsufficientstatistics(ef::ExponentialFamilyDistribution) = getsufficientstatistics(ef.attributes, ef) getsufficientstatistics(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} = getsufficientstatistics(T, getconditioner(ef)) @@ -416,6 +432,15 @@ For conditional exponential family distributions requires an extra `conditioner` """ getbasemeasure(::Type{T}, ::Nothing) where {T <: Distribution} = getbasemeasure(T) +""" +getexpectationlogbasemeasure(::Type{<:Distribution}, [ conditioner ]) + +A specific verion of `getexpectationlogbasemeasure` defined particularly for distribution types from `Distributions.jl` package. +Does not require an instance of the `ExponentialFamilyDistribution` and can be called directly with a specific distribution type instead. +For conditional exponential family distributions requires an extra `conditioner` argument. +""" +getexpectationlogbasemeasure(::Type{T}, ::Nothing) where {T <: Distribution} = getexpectationlogbasemeasure(NaturalParametersSpace(), T) + """ getsufficientstatistics(::Type{<:Distribution}, [ conditioner ]) @@ -675,6 +700,21 @@ Evaluates and returns the cumulative distribution function of the exponential fa """ BayesBase.cdf(ef::ExponentialFamilyDistribution{D}, x) where {D <: Distribution} = cdf(Base.convert(Distribution, ef), x) + +function _entropy(η, _logpartition, _grad_logpartition, _expectionlogbasemeasure) + return _logpartition - dot(η, _grad_logpartition) + _expectionlogbasemeasure +end + +""" + entropy(ef::ExponentialFamilyDistribution) + +Evaluates and returns the entropy of the exponential family distribution. +""" +function BayesBase.entropy(ef::ExponentialFamilyDistribution) + return _entropy(getnaturalparameters(ef), logpartition(ef), gradlogpartition(ef), expectationlogbasemeasure(ef)) +end + + BayesBase.variate_form(::Type{<:ExponentialFamilyDistribution{D}}) where {D <: Distribution} = variate_form(D) BayesBase.variate_form(::Type{<:ExponentialFamilyDistribution{V}}) where {V <: VariateForm} = V diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index 4256dcef..85938a6a 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -341,6 +341,8 @@ function run_test_gradlogpartition_properties(distribution; nsamples = 6000, tes if test_against_forwardiff @test gradient ≈ ForwardDiff.gradient((η) -> getlogpartition(ef)(η), getnaturalparameters(ef)) end + + @test entropy(ef) ≈ entropy(distribution) end function run_test_fisherinformation_against_hessian(distribution; assume_ours_faster = true, assume_no_allocations = true)