diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 2a82cad..5a38fa0 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -16,6 +16,7 @@ include("natural_manifolds/categorical.jl") include("natural_manifolds/dirichlet.jl") include("natural_manifolds/exponential.jl") include("natural_manifolds/gamma.jl") +include("natural_manifolds/inverse_gamma.jl") include("natural_manifolds/geometric.jl") include("natural_manifolds/laplace.jl") include("natural_manifolds/lognormal.jl") diff --git a/src/natural_manifolds/inverse_gamma.jl b/src/natural_manifolds/inverse_gamma.jl new file mode 100644 index 0000000..67969c0 --- /dev/null +++ b/src/natural_manifolds/inverse_gamma.jl @@ -0,0 +1,19 @@ +""" + get_natural_manifold_base(::Type{ExponentialFamily.GammaInverse}, ::Tuple{}, conditioner=nothing) + +Get the natural manifold base for the `ExponentialFamily.GammaInverse` distribution. +""" +function get_natural_manifold_base(::Type{ExponentialFamily.GammaInverse}, ::Tuple{}, conditioner=nothing) + return ProductManifold( + ShiftedNegativeNumbers(static(-1)), ShiftedNegativeNumbers(static(0)) + ) +end + +""" + partition_point(::Type{ExponentialFamily.GammaInverse}, ::Tuple{}, p, conditioner=nothing) + +Converts the `point` to a compatible representation for the natural manifold of type `ExponentialFamily.GammaInverse`. +""" +function partition_point(::Type{ExponentialFamily.GammaInverse}, ::Tuple{}, p, conditioner=nothing) + return ArrayPartition(view(p, 1:1), view(p, 2:2)) +end \ No newline at end of file diff --git a/test/natural_manifolds/gamma_inverse_tests.jl b/test/natural_manifolds/gamma_inverse_tests.jl new file mode 100644 index 0000000..37c86b1 --- /dev/null +++ b/test/natural_manifolds/gamma_inverse_tests.jl @@ -0,0 +1,7 @@ +@testitem "Check `GammaInverse` natural manifold" begin + include("natural_manifolds_setuptests.jl") + + test_natural_manifold() do rng + return ExponentialFamily.InverseGamma(10rand(rng), 10rand(rng)) + end +end \ No newline at end of file