Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add entropy for ef distribution #188

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/distributions/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ isbasemeasureconstant(::Type{Bernoulli}) = ConstantBaseMeasure()
getbasemeasure(::Type{Bernoulli}) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Bernoulli}) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Bernoulli}) = (_) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Bernoulli}) = (η) -> begin
(η₁,) = unpack_parameters(Bernoulli, η)
return -log(logistic(-η₁))
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ isbasemeasureconstant(::Type{Beta}) = ConstantBaseMeasure()
getbasemeasure(::Type{Beta}) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Beta}) = (log, mirrorlog)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Beta}) = (_) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Beta}) = (η) -> begin
(η₁, η₂) = unpack_parameters(Beta, η)
return logbeta(η₁ + one(η₁), η₂ + one(η₂))
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ isbasemeasureconstant(::Type{Binomial}) = NonConstantBaseMeasure()
getbasemeasure(::Type{Binomial}, ntrials) = Base.Fix1(binomial, ntrials)
getsufficientstatistics(::Type{Binomial}, _) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Binomial}) = (_) -> begin
error("Expectation of log base measure is not defined for Binomial distribution")
end

getlogpartition(::NaturalParametersSpace, ::Type{Binomial}, ntrials) = (η) -> begin
(η₁,) = unpack_parameters(Binomial, η)
return ntrials * log1pexp(η₁)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ isbasemeasureconstant(::Type{Categorical}) = ConstantBaseMeasure()
getbasemeasure(::Type{Categorical}, _) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Categorical}, conditioner) = ((x) -> OneElement(x, conditioner),)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Categorical}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner) =
(η) -> begin
if (conditioner !== length(η))
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/chi_squared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ getlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
return loggamma(η1 + o) + (η1 + o) * logtwo
end

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
η/2
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
(η1,) = unpack_parameters(Chisq, η)
return SA[digamma(η1 + one(η1)) + logtwo]
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ isbasemeasureconstant(::Type{Dirichlet}) = ConstantBaseMeasure()
getbasemeasure(::Type{Dirichlet}) = (x) -> one(Float64)
getsufficientstatistics(::Type{Dirichlet}) = (x -> vmap(log, x),)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Dirichlet}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Dirichlet}) = (η) -> begin
(η1,) = unpack_parameters(Dirichlet, η)
firstterm = mapreduce(x -> loggamma(x + 1), +, η1)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/erlang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ isbasemeasureconstant(::Type{Erlang}) = ConstantBaseMeasure()
getbasemeasure(::Type{Erlang}) = (x) -> one(x)
getsufficientstatistics(::Type{Erlang}) = (log, identity)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Erlang}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Erlang}) = (η) -> begin
(η1, η2) = unpack_parameters(Erlang, η)

Expand Down
4 changes: 4 additions & 0 deletions src/distributions/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ isbasemeasureconstant(::Type{Exponential}) = ConstantBaseMeasure()
getbasemeasure(::Type{Exponential}) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Exponential}) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Exponential}) = (_) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
(η₁,) = unpack_parameters(Exponential, η)
return -log(-η₁)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/gamma_family/gamma_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ isbasemeasureconstant(::Type{Gamma}) = ConstantBaseMeasure()
getbasemeasure(::Type{Gamma}) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Gamma}) = (log, identity)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Gamma}) = (_) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Gamma}) = (η) -> begin
(η₁, η₂) = unpack_parameters(Gamma, η)
return loggamma(η₁ + one(η₁)) - (η₁ + one(η₁)) * log(-η₂)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ isbasemeasureconstant(::Type{GammaInverse}) = ConstantBaseMeasure()
getbasemeasure(::Type{GammaInverse}) = (x) -> oneunit(x)
getsufficientstatistics(::Type{GammaInverse}) = (log, inv)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{GammaInverse}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{GammaInverse}) = (η) -> begin
(η₁, η₂) = unpack_parameters(GammaInverse, η)
return loggamma(-η₁ - one(η₁)) - (-η₁ - one(η₁)) * log(-η₂)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/geometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ isbasemeasureconstant(::Type{Geometric}) = ConstantBaseMeasure()
getbasemeasure(::Type{Geometric}) = (x) -> one(x)
getsufficientstatistics(::Type{Geometric}) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Geometric}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Geometric}) = (η) -> begin
(η1,) = unpack_parameters(Geometric, η)
return -log(one(η1) - exp(η1))
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ getsufficientstatistics(::Type{Laplace}, conditioner) = (
(x) -> abs(x - conditioner),
)

getexpectationlogbasemeasure(::Type{Laplace}, conditioner) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{Laplace}, _) = (η) -> begin
(η₁,) = unpack_parameters(Laplace, η)
return log(-2 / η₁)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ isbasemeasureconstant(::Type{LogNormal}) = ConstantBaseMeasure()
getbasemeasure(::Type{LogNormal}) = (x) -> invsqrt2π
getsufficientstatistics(::Type{LogNormal}) = (log, x -> abs2(log(x)))

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{LogNormal}) = (_) -> begin
0.5*log2π
end

getlogpartition(::NaturalParametersSpace, ::Type{LogNormal}) = (η) -> begin
(η₁, η₂) = unpack_parameters(LogNormal, η)
return -(η₁ + 1)^2 / (4η₂) - log(-2η₂) / 2
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/matrix_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ isbasemeasureconstant(::Type{MatrixDirichlet}) = ConstantBaseMeasure()
getbasemeasure(::Type{MatrixDirichlet}) = (x) -> one(Float64)
getsufficientstatistics(::Type{MatrixDirichlet}) = (x -> vmap(log, x),)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MatrixDirichlet}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{MatrixDirichlet}) =
(η) -> begin
(η1,) = unpack_parameters(MatrixDirichlet, η)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/mv_normal_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ function getsufficientstatistics(::Type{MvNormalWishart})
)
end

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (_) -> begin
log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{MvNormalWishart}) = (η) -> begin
η1, η2, η3, η4 = unpack_parameters(MvNormalWishart, η)
d = length(η1)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/negative_binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ isbasemeasureconstant(::Type{NegativeBinomial}) = NonConstantBaseMeasure()
getbasemeasure(::Type{NegativeBinomial}, conditioner) = (x) -> binomial(Int(x + conditioner - 1), x)
getsufficientstatistics(::Type{NegativeBinomial}, conditioner) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NegativeBinomial}, conditioner) = (η) -> begin
error("Expectation of log base measure is not implemented for NegativeBinomial distribution")
end

getlogpartition(::NaturalParametersSpace, ::Type{NegativeBinomial}, conditioner) = (η) -> begin
(η1,) = unpack_parameters(NegativeBinomial, η)
return -conditioner * log(one(η1) - exp(η1))
Expand Down
9 changes: 9 additions & 0 deletions src/distributions/normal_family/normal_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,10 @@ isbasemeasureconstant(::Type{NormalMeanVariance}) = ConstantBaseMeasure()
getbasemeasure(::Type{NormalMeanVariance}) = (x) -> convert(typeof(x), invsqrt2π)
getsufficientstatistics(::Type{NormalMeanVariance}) = (identity, abs2)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) -> begin
0.5*log2π
end

getlogpartition(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) -> begin
(η₁, η₂) = unpack_parameters(NormalMeanVariance, η)
return -abs2(η₁) / 4η₂ - log(-2η₂) / 2
Expand Down Expand Up @@ -691,6 +695,11 @@ isbasemeasureconstant(::Type{MvNormalMeanCovariance}) = ConstantBaseMeasure()
getbasemeasure(::Type{MvNormalMeanCovariance}) = (x) -> (2π)^(length(x) / -2)
getsufficientstatistics(::Type{MvNormalMeanCovariance}) = (identity, (x) -> x * x')

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η) -> begin
k = div(-1 + isqrt(1 + 4 * length(η)), 2)
return k * log2π / 2
end

getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η) -> begin
(η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η)
k = length(η₁)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/normal_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ getbasemeasure(::Type{NormalGamma}) = (x) -> invsqrt2π
# x is a 2d vector where first dimension is mean and the second dimension is precision component
getsufficientstatistics(::Type{NormalGamma}) = (x -> x[1] * x[2], x -> x[2] * x[1]^2, x -> log(x[2]), x -> x[2])

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{NormalGamma}) = (_) -> begin
0.5*log2π
end

getlogpartition(::NaturalParametersSpace, ::Type{NormalGamma}) = (η) -> begin
(η1, η2, η3, η4) = unpack_parameters(NormalGamma, η)
η3half = η3 + (1 / 2)
Expand Down
6 changes: 5 additions & 1 deletion src/distributions/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ isbasemeasureconstant(::Type{Pareto}) = ConstantBaseMeasure()
getbasemeasure(::Type{Pareto}, _) = (x) -> oneunit(x)
getsufficientstatistics(::Type{Pareto}, conditioner) = (log,)

getlogpartition(::NaturalParametersSpace, ::Type{Pareto}, conditioner) = (η) -> begin
getexpectationlogbasemeasure(::Type{Pareto}, _) = (_) -> begin
log(1)
end

getlogpartition(::Type{Pareto}, conditioner) = (η) -> begin
(η1,) = unpack_parameters(Pareto, η)
return log(conditioner^(one(η1) + η1) / (-one(η1) - η1))
end
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ isbasemeasureconstant(::Type{Poisson}) = NonConstantBaseMeasure()
getbasemeasure(::Type{Poisson}) = (x) -> one(x) / factorial(x)
getsufficientstatistics(::Type{Poisson}) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Poisson}) = (_) -> begin
error("Expectation of log base measure is not implemented for Poisson distribution")
end

getlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
(η1,) = unpack_parameters(Poisson, η)
return exp(η1)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/rayleigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ isbasemeasureconstant(::Type{Rayleigh}) = NonConstantBaseMeasure()
getbasemeasure(::Type{Rayleigh}) = identity
getsufficientstatistics(::Type{Rayleigh}) = (x -> x^2,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin
error("Expectation of log base measure is not implemented for Rayleigh distribution")
end

getlogpartition(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin
(η1,) = unpack_parameters(Rayleigh, η)
return -log(-2 * η1)
Expand Down
6 changes: 6 additions & 0 deletions src/distributions/von_mises.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,17 @@ isbasemeasureconstant(::Type{VonMises}) = ConstantBaseMeasure()

getbasemeasure(::Type{VonMises}, _) = (x) -> inv(twoπ)
getsufficientstatistics(::Type{VonMises}, _) = (cos, sin)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{VonMises}, _) = (_) -> begin
return log(twoπ)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin
u = sqrt(dot(η, η))
same_part = besseli(1, u) / (u * besseli(0, u))
return SA[η[1] * same_part, η[2] * same_part]
end

getlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin
return log(besseli(0, sqrt(dot(η, η))))
end
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/von_mises_fisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ isbasemeasureconstant(::Type{VonMisesFisher}) = ConstantBaseMeasure()
getbasemeasure(::Type{VonMisesFisher}) = (x) -> (inv2π)^(length(x) / 2)
getsufficientstatistics(::Type{VonMisesFisher}) = (identity,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{VonMisesFisher}) = (η) -> begin
length(η)/2 * log(inv2π)
end

getlogpartition(::NaturalParametersSpace, ::Type{VonMisesFisher}) = (η) -> begin
κ = sqrt(η' * η)
p = length(η)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/weibull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ isbasemeasureconstant(::Type{Weibull}) = NonConstantBaseMeasure()
getbasemeasure(::Type{Weibull}, conditioner) = x -> x^(conditioner - 1)
getsufficientstatistics(::Type{Weibull}, conditioner) = (x -> x^conditioner,)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{Weibull}, conditioner) = (η) -> begin
error("Expectation of log base measure is not implemented for Weibull distribution")
end

getlogpartition(::NaturalParametersSpace, ::Type{Weibull}, conditioner) = (η) -> begin
(η1,) = unpack_parameters(Weibull, η)
return -log(-η1) - log(conditioner)
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ isbasemeasureconstant(::Type{WishartFast}) = ConstantBaseMeasure()
getbasemeasure(::Type{WishartFast}) = (x) -> one(Float64)
getsufficientstatistics(::Type{WishartFast}) = (chollogdet, identity)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{WishartFast}) = (η) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{WishartFast}) = (η) -> begin
η1, η2 = unpack_parameters(WishartFast, η)
p = first(size(η2))
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ isbasemeasureconstant(::Type{InverseWishartFast}) = ConstantBaseMeasure()
getbasemeasure(::Type{InverseWishartFast}) = (x) -> one(Float64)
getsufficientstatistics(::Type{InverseWishartFast}) = (chollogdet, cholinv)

getexpectationlogbasemeasure(::NaturalParametersSpace, ::Type{InverseWishartFast}) = (η) -> begin
return log(1)
end

getlogpartition(::NaturalParametersSpace, ::Type{InverseWishartFast}) = (η) -> begin
η1, η2 = unpack_parameters(InverseWishartFast, η)
p = first(size(η2))
Expand Down
40 changes: 40 additions & 0 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,27 @@ function fisherinformation(ef::ExponentialFamilyDistribution, η = getnaturalpar
return getfisherinformation(ef)(η)
end

"""
expectationlogbasemeasure(distribution, η)

Return the computed value of the expectation of the log base measure of the exponential family distribution at the point `η`.
By default `η = getnaturalparameters(ef)`.
"""
function expectationlogbasemeasure(ef::ExponentialFamilyDistribution, η = getnaturalparameters(ef))
return getexpectationlogbasemeasure(ef)(η)
end

getbasemeasure(ef::ExponentialFamilyDistribution) = getbasemeasure(ef.attributes, ef)
getbasemeasure(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} = getbasemeasure(T, getconditioner(ef))
getbasemeasure(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) =
getbasemeasure(attributes)

getexpectationlogbasemeasure(ef::ExponentialFamilyDistribution) = getexpectationlogbasemeasure(ef.attributes, ef)
getexpectationlogbasemeasure(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} =
getexpectationlogbasemeasure(T, getconditioner(ef))
getexpectationlogbasemeasure(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) =
getexpectationlogbasemeasure(attributes)

getsufficientstatistics(ef::ExponentialFamilyDistribution) = getsufficientstatistics(ef.attributes, ef)
getsufficientstatistics(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} =
getsufficientstatistics(T, getconditioner(ef))
Expand Down Expand Up @@ -416,6 +432,15 @@ For conditional exponential family distributions requires an extra `conditioner`
"""
getbasemeasure(::Type{T}, ::Nothing) where {T <: Distribution} = getbasemeasure(T)

"""
getexpectationlogbasemeasure(::Type{<:Distribution}, [ conditioner ])

A specific verion of `getexpectationlogbasemeasure` defined particularly for distribution types from `Distributions.jl` package.
Does not require an instance of the `ExponentialFamilyDistribution` and can be called directly with a specific distribution type instead.
For conditional exponential family distributions requires an extra `conditioner` argument.
"""
getexpectationlogbasemeasure(::Type{T}, ::Nothing) where {T <: Distribution} = getexpectationlogbasemeasure(NaturalParametersSpace(), T)

"""
getsufficientstatistics(::Type{<:Distribution}, [ conditioner ])

Expand Down Expand Up @@ -675,6 +700,21 @@ Evaluates and returns the cumulative distribution function of the exponential fa
"""
BayesBase.cdf(ef::ExponentialFamilyDistribution{D}, x) where {D <: Distribution} = cdf(Base.convert(Distribution, ef), x)


function _entropy(η, _logpartition, _grad_logpartition, _expectionlogbasemeasure)
return _logpartition - dot(η, _grad_logpartition) + _expectionlogbasemeasure
end

"""
entropy(ef::ExponentialFamilyDistribution)

Evaluates and returns the entropy of the exponential family distribution.
"""
function BayesBase.entropy(ef::ExponentialFamilyDistribution)
return _entropy(getnaturalparameters(ef), logpartition(ef), gradlogpartition(ef), expectationlogbasemeasure(ef))
end


BayesBase.variate_form(::Type{<:ExponentialFamilyDistribution{D}}) where {D <: Distribution} = variate_form(D)
BayesBase.variate_form(::Type{<:ExponentialFamilyDistribution{V}}) where {V <: VariateForm} = V

Expand Down
2 changes: 2 additions & 0 deletions test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ function run_test_gradlogpartition_properties(distribution; nsamples = 6000, tes
if test_against_forwardiff
@test gradient ≈ ForwardDiff.gradient((η) -> getlogpartition(ef)(η), getnaturalparameters(ef))
end

@test entropy(ef) ≈ entropy(distribution)
end

function run_test_fisherinformation_against_hessian(distribution; assume_ours_faster = true, assume_no_allocations = true)
Expand Down
Loading