Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add manifold for Categorical distribution #15

Merged
merged 17 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[compat]
BayesBase = "1.3"
ExponentialFamily = "1.4.3"
ExponentialFamily = "1.5.1"
LinearAlgebra = "1.10"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ ExponentialFamilyManifolds.partition_point
ExponentialFamilyManifolds.ShiftedPositiveNumbers
ExponentialFamilyManifolds.ShiftedNegativeNumbers
ExponentialFamilyManifolds.SymmetricNegativeDefinite
ExponentialFamilyManifolds.SinglePointManifold
```

## Optimization example
Expand Down
2 changes: 2 additions & 0 deletions src/ExponentialFamilyManifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge
include("symmetric_negative_definite.jl")
include("shifted_negative_numbers.jl")
include("shifted_positive_numbers.jl")
include("single_point_manifold.jl")
include("natural_manifolds.jl")

include("natural_manifolds/bernoulli.jl")
include("natural_manifolds/beta.jl")
include("natural_manifolds/binomial.jl")
include("natural_manifolds/chisq.jl")
include("natural_manifolds/categorical.jl")
include("natural_manifolds/dirichlet.jl")
include("natural_manifolds/exponential.jl")
include("natural_manifolds/gamma.jl")
Expand Down
20 changes: 20 additions & 0 deletions src/natural_manifolds/categorical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

"""
get_natural_manifold_base(::Type{Categorical}, dims::Tuple{Int}, conditioner=nothing)

Get the natural manifold base for the `Categorical` distribution.
"""
function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing)
return ProductManifold(
Euclidean(conditioner-1), SinglePointManifold([0])
)
end

"""
partition_point(::Type{Categorical}, dims::Tuple{Int}, p, conditioner=nothing)

Converts the `point` to a compatible representation for the natural manifold of type `Categorical`.
"""
function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:conditioner-1), view(p, conditioner:conditioner))
end
76 changes: 76 additions & 0 deletions src/single_point_manifold.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using ManifoldsBase
using Random

"""
SinglePointManifold(point)

This manifold represents a set from one point.
"""
struct SinglePointManifold{T, R} <: AbstractManifold{ℝ}
point::T
representation_size::R
end

function SinglePointManifold(point::T) where {T}
return SinglePointManifold(point, size(point))
end

function Base.show(io::IO, M::SinglePointManifold)
print(io, "SinglePointManifold(", M.point, ")")
end

ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0
ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size
ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point))

ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction()

function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...)
if !(p ≈ M.point)
return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).")
end
return nothing
end

function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...)
if !iszero(X) && size(M.point) == size(X)
return DomainError(X, "The tangent space of $(M) contains only the zero vector.")
end
return nothing
end

ManifoldsBase.is_flat(::SinglePointManifold) = true

ManifoldsBase.embed(::SinglePointManifold, p) = p
ManifoldsBase.embed(::SinglePointManifold, p, X) = X

function ManifoldsBase.inner(::SinglePointManifold, p, X, Y)
return zero(eltype(X))
end

function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1)
q .= M.point
return q
end

function ManifoldsBase.log!(::SinglePointManifold, X, p, q)
X .= zero(eltype(X))
return X
end

function ManifoldsBase.project!(::SinglePointManifold, Y, p, X)
fill!(Y, zero(eltype(Y)))
return Y
end

function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p)
return fill!(X, zero(eltype(X)))
end

function Random.rand(M::SinglePointManifold; kwargs...)
return rand(Random.default_rng(), M; kwargs...)
end

function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...)
return M.point
end
40 changes: 40 additions & 0 deletions test/natural_manifolds/categorical_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@testitem "Check `Categorical` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
p = rand(rng, 10)
normalize!(p, 1)
return Categorical(p)
end
end

@testitem "Check that optimization work on Categorical" begin
include("natural_manifolds_setuptests.jl")

using Manopt, ForwardDiff
using BayesBase

rng = StableRNG(42)
p = rand(StableRNG(42), 10)
normalize!(p, 1)
distribution = Categorical(p)
sample = rand(rng, distribution)
dims = size(sample)
ef = convert(ExponentialFamilyDistribution, distribution)
T = ExponentialFamily.exponential_family_typetag(ef)
M = get_natural_manifold(T, dims, getconditioner(ef))

function f(M, p)
ef = convert(ExponentialFamilyDistribution, M, p)
η = getnaturalparameters(ef)
return (mean(η) - 0.5)^2
end

function g(M, p)
return project(M, p, 2 * p ./ 10)
end

q = gradient_descent(M, f, g, rand(rng, M))
@test q ∈ M
@test mean(q) ≈ 0.5 atol = 1e-1
end
108 changes: 108 additions & 0 deletions test/single_point_manifold_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
@testitem "Generic properties of SinglePointManifold" begin
import ManifoldsBase: check_point, check_vector, embed, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension
import ExponentialFamilyManifolds: SinglePointManifold
using ManifoldsBase, Static, StaticArrays, JET, Manifolds
using StableRNGs
using Random

rng = StableRNG(42)


points = [
0,
0.0,
0.0f0,
1,
1.0,
1.0f0,
-1,
2,
π,
rand(),
randn()
]

for p in points
M = SinglePointManifold(p)

@test repr(M) == "SinglePointManifold($p)"

@test @inferred(representation_size(M)) === ()
@test @inferred(manifold_dimension(M)) === 0
@test @inferred(is_flat(M)) === true
@test injectivity_radius(M) ≈ 0
@test default_retraction_method(M) == ExponentialRetraction()

@test_throws MethodError get_embedding(M)

@test check_point(M, p) === nothing
@test check_point(M, p + 1) isa DomainError
@test check_point(M, p - 1) isa DomainError

@test check_vector(M, p, 0) === nothing
@test check_vector(M, p, 1) isa DomainError
@test check_vector(M, p, -1) isa DomainError

@test @eval(@allocated(representation_size($M))) === 0
@test @eval(@allocated(manifold_dimension($M))) === 0
@test @eval(@allocated(is_flat($M))) === 0

X = [1]
Y = [1]

@test_opt inner(M, p, X, Y)
@test_opt inner(M, p, 0, 0)

@test embed(M, p) == p
@test embed(M, p, 0) == 0
@test inner(M, p, 0, 0) == 0
end

vector_points = [[1], [1, 2], [1, 2, 3]]

for p in vector_points
M = SinglePointManifold(p)
q = similar(p)
X = zero_vector(M, p)
@test ManifoldsBase.exp!(M, q, p, X) == p
@test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p)
@test ManifoldsBase.log(M, p, p) == zero_vector(M, p)
@test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p)
@test rand(rng, M) ∈ M
@test rand(M) ∈ M
end
end

@testitem "Simple manifold optimization problem #1" begin
using Manopt, ForwardDiff, Static, StableRNGs, LinearAlgebra

import ExponentialFamilyManifolds: SinglePointManifold

for a in (2.0, 3.0),
b in (10.0, 5.0),
c in (1.0, 10.0, -1.0),
eps in (1e-4, 1e-5, 1e-8, 1e-10),
stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001))

f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1]
grad_f(M, x) = 2 .* a .* x .+ b

rng = StableRNG(42)

for s in [0, 0.0, 10]
M = SinglePointManifold(s)
p0 = rand(rng, M)

q1 = gradient_descent(
M,
f,
grad_f,
p0;
stepsize=stepsize,
stopping_criterion=StopAfterIteration(1)
)

@test q1 ≈ s
end
end
end
Loading