From e1bfcda2006090aa63347f162883ad58622b9b1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0smail=20=C5=9Een=C3=B6z?= Date: Wed, 10 Jul 2024 13:37:10 +0300 Subject: [PATCH 01/14] add manifold for Categorical distribution --- src/ExponentialFamilyManifolds.jl | 1 + src/natural_manifolds/categorical.jl | 19 +++++++++++++++++++ test/natural_manifolds/categorical_tests.jl | 9 +++++++++ 3 files changed, 29 insertions(+) create mode 100644 src/natural_manifolds/categorical.jl create mode 100644 test/natural_manifolds/categorical_tests.jl diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 390e9ff..0079dce 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -11,6 +11,7 @@ include("natural_manifolds/bernoulli.jl") include("natural_manifolds/beta.jl") include("natural_manifolds/binomial.jl") include("natural_manifolds/chisq.jl") +include("natural_manifolds/categorical.jl") include("natural_manifolds/dirichlet.jl") include("natural_manifolds/exponential.jl") include("natural_manifolds/gamma.jl") diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl new file mode 100644 index 0000000..583f323 --- /dev/null +++ b/src/natural_manifolds/categorical.jl @@ -0,0 +1,19 @@ + +""" + get_natural_manifold_base(::Type{Categorical}, dims::Tuple{Int}, conditioner=nothing) + +Get the natural manifold base for the `Categorical` distribution. +""" +function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) + return Euclidean(conditioner) +end + +""" + partition_point(::Type{Categorical}, dims::Tuple{Int}, p, conditioner=nothing) + +Converts the `point` to a compatible representation for the natural manifold of type `Categorical`. +""" +function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) + # See comment in `get_natural_manifold_base` for `Categorical` + return ArrayPartition(p) +end \ No newline at end of file diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl new file mode 100644 index 0000000..989859b --- /dev/null +++ b/test/natural_manifolds/categorical_tests.jl @@ -0,0 +1,9 @@ +@testitem "Check `Categorical` natural manifold" begin + include("natural_manifolds_setuptests.jl") + + test_natural_manifold() do rng + p = rand(rng, 10) + normalize!(p, 1) + return Categorical(p) + end +end \ No newline at end of file From ddd97e1d22390e7aab8a688d6ae14ea11c5c1465 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 15:42:43 +0200 Subject: [PATCH 02/14] feat: add SinglePointManifold --- src/ExponentialFamilyManifolds.jl | 1 + src/SinglePointManifold.jl | 68 +++++++++++++++++++++++++++++ test/single_point_manifold_tests.jl | 67 ++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 src/SinglePointManifold.jl create mode 100644 test/single_point_manifold_tests.jl diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 390e9ff..cf1fbad 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -5,6 +5,7 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge include("symmetric_negative_definite.jl") include("shifted_negative_numbers.jl") include("shifted_positive_numbers.jl") +include("SinglePointManifold.jl") include("natural_manifolds.jl") include("natural_manifolds/bernoulli.jl") diff --git a/src/SinglePointManifold.jl b/src/SinglePointManifold.jl new file mode 100644 index 0000000..9386b8e --- /dev/null +++ b/src/SinglePointManifold.jl @@ -0,0 +1,68 @@ +using ManifoldsBase +using Random + +""" + SymmetricNegativeDefinite(point) + +This manifold represents a set from one point. +""" +struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} + point::T + representation_size::R +end + +function SinglePointManifold(point::T) where {T} + return SinglePointManifold(point, size(point)) +end + +function Base.show(io::IO, M::SinglePointManifold) + print(io, "SinglePointManifold(", M.point, ")") +end + +ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 +ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size +ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) + +ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() + +function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) + if p != M.point + return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") + end + return nothing +end + +function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) + if !iszero(X) && size(M.point) == size(X) + return DomainError(X, "The tangent space of $(M) contains only the zero vector.") + end + return nothing +end + +ManifoldsBase.is_flat(::SinglePointManifold) = true + +ManifoldsBase.embed(::SinglePointManifold, p) = p +ManifoldsBase.embed(::SinglePointManifold, p, X) = X + +function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) + return zero(eltype(X)) +end + +function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) + q .= M.point + return q +end + +function ManifoldsBase.log!(::SinglePointManifold, X, p, q) + X .= zero(eltype(X)) + return X +end + +function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) + fill!(Y, zero(eltype(Y))) + return Y +end + +function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) + return fill!(X, zero(eltype(X))) +end \ No newline at end of file diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl new file mode 100644 index 0000000..ed37275 --- /dev/null +++ b/test/single_point_manifold_tests.jl @@ -0,0 +1,67 @@ +using Test +using ManifoldsBase +using Random +using StaticArrays + +@testitem "Generic properties of SinglePointManifold" begin + import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension + import ExponentialFamilyManifolds: SinglePointManifold + using ManifoldsBase, Static, StaticArrays, JET, Manifolds + + points = [ + 0, + 0.0, + 0.0f0, + 1, + 1.0, + 1.0f0, + -1, + 2, + π, + rand(), + randn() + ] + + for p in points + M = SinglePointManifold(p) + + @test repr(M) == "SinglePointManifold($p)" + + @test @inferred(representation_size(M)) === () + @test @inferred(manifold_dimension(M)) === 0 + @test @inferred(is_flat(M)) === true + @test injectivity_radius(M) ≈ 0 + + @test_throws MethodError get_embedding(M) + + @test check_point(M, p) === nothing + @test check_point(M, p + 1) isa DomainError + @test check_point(M, p - 1) isa DomainError + + @test check_vector(M, p, 0) === nothing + @test check_vector(M, p, 1) isa DomainError + @test check_vector(M, p, -1) isa DomainError + + @test @eval(@allocated(representation_size($M))) === 0 + @test @eval(@allocated(manifold_dimension($M))) === 0 + @test @eval(@allocated(is_flat($M))) === 0 + + X = [1] + Y = [1] + + @test_opt inner(M, p, X, Y) + @test_opt inner(M, p, 0, 0) + end + + vector_points = [[1], [1, 2], [1, 2, 3]] + + for p in vector_points + M = SinglePointManifold(p) + q = similar(p) + X = zero_vector(M, p) + @test ManifoldsBase.exp!(M, q, p, X) == p + @test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p) + @test ManifoldsBase.log(M, p, p) == zero_vector(M, p) + @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) + end +end \ No newline at end of file From 6ceae1f52e83731ceb1ba49f8cc1cc4dcf200c60 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 15:48:10 +0200 Subject: [PATCH 03/14] fix: implement get_natural_manifold_base for Categorical --- src/natural_manifolds/categorical.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl index 583f323..282af7c 100644 --- a/src/natural_manifolds/categorical.jl +++ b/src/natural_manifolds/categorical.jl @@ -5,7 +5,9 @@ Get the natural manifold base for the `Categorical` distribution. """ function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) - return Euclidean(conditioner) + return ProductManifold( + Euclidean(conditioner-1), SinglePointManifold(0) + ) end """ @@ -15,5 +17,5 @@ Converts the `point` to a compatible representation for the natural manifold of """ function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) # See comment in `get_natural_manifold_base` for `Categorical` - return ArrayPartition(p) + return ArrayPartition(p[1:end-1], p[end]) end \ No newline at end of file From bed5de4c8a6dfb48c4e16d6fc21820f383b03316 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 15:49:42 +0200 Subject: [PATCH 04/14] test: clean test --- test/single_point_manifold_tests.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index ed37275..f6647c3 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -1,8 +1,3 @@ -using Test -using ManifoldsBase -using Random -using StaticArrays - @testitem "Generic properties of SinglePointManifold" begin import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension import ExponentialFamilyManifolds: SinglePointManifold From 77a67d2e9c0054abd718087fd0bcbf2734840a98 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:11:41 +0200 Subject: [PATCH 05/14] fix: implement rand for SinglePointManifold --- src/SinglePointManifold.jl | 10 ++++++- src/natural_manifolds/categorical.jl | 2 +- test/natural_manifolds/categorical_tests.jl | 30 +++++++++++++++++++++ test/single_point_manifold_tests.jl | 5 ++++ 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/SinglePointManifold.jl b/src/SinglePointManifold.jl index 9386b8e..812bfec 100644 --- a/src/SinglePointManifold.jl +++ b/src/SinglePointManifold.jl @@ -65,4 +65,12 @@ end function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) return fill!(X, zero(eltype(X))) -end \ No newline at end of file +end + +function Random.rand(M::SinglePointManifold; kwargs...) + return rand(Random.default_rng(), M; kwargs...) +end + +function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...) + return M.point +end diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl index 282af7c..49ee771 100644 --- a/src/natural_manifolds/categorical.jl +++ b/src/natural_manifolds/categorical.jl @@ -6,7 +6,7 @@ Get the natural manifold base for the `Categorical` distribution. """ function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) return ProductManifold( - Euclidean(conditioner-1), SinglePointManifold(0) + Euclidean(conditioner-1), SinglePointManifold([0]) ) end diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl index 989859b..94ee120 100644 --- a/test/natural_manifolds/categorical_tests.jl +++ b/test/natural_manifolds/categorical_tests.jl @@ -6,4 +6,34 @@ normalize!(p, 1) return Categorical(p) end +end + +@testitem "Check that optimization work on Categorical" begin + include("natural_manifolds_setuptests.jl") + + using Manopt, ForwardDiff + using BayesBase + + rng = StableRNG(42) + p = rand(StableRNG(42), 10) + normalize!(p, 1) + distribution = Categorical(p) + sample = rand(rng, distribution) + dims = size(sample) + ef = convert(ExponentialFamilyDistribution, distribution) + T = ExponentialFamily.exponential_family_typetag(ef) + M = get_natural_manifold(T, dims, getconditioner(ef)) + + function f(M, p) + return (mean(p) - 0.5)^2 + end + + function g(M, p) + X = ForwardDiff.gradient((p) -> f(M, p), p) + return project!(M, X, p, X) + end + + q = gradient_descent(M, f, g, rand(rng, M)) + @show is_point(M, q) + @test mean(q) ≈ 0.5 atol = 1e-1 end \ No newline at end of file diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index f6647c3..5f7d53a 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -2,6 +2,10 @@ import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension import ExponentialFamilyManifolds: SinglePointManifold using ManifoldsBase, Static, StaticArrays, JET, Manifolds + using StableRNGs + + rng = StableRNG(42) + points = [ 0, @@ -58,5 +62,6 @@ @test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p) @test ManifoldsBase.log(M, p, p) == zero_vector(M, p) @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) + @test rand(rng, M) ∈ M end end \ No newline at end of file From 2f0cbb6e4009bdc7fe834da1e04038a8d654e944 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:15:07 +0200 Subject: [PATCH 06/14] test: show -> test --- test/natural_manifolds/categorical_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl index 94ee120..38b84f3 100644 --- a/test/natural_manifolds/categorical_tests.jl +++ b/test/natural_manifolds/categorical_tests.jl @@ -34,6 +34,6 @@ end end q = gradient_descent(M, f, g, rand(rng, M)) - @show is_point(M, q) + @test q ∈ M @test mean(q) ≈ 0.5 atol = 1e-1 -end \ No newline at end of file +end From 3637e4f44c6f20b25ba3957123d4f188ab5896d7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:18:43 +0200 Subject: [PATCH 07/14] test(cov): improve codecaverage --- test/single_point_manifold_tests.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index 5f7d53a..b50ca6d 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -1,5 +1,5 @@ @testitem "Generic properties of SinglePointManifold" begin - import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension + import ManifoldsBase: check_point, check_vector, embed, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension import ExponentialFamilyManifolds: SinglePointManifold using ManifoldsBase, Static, StaticArrays, JET, Manifolds using StableRNGs @@ -50,6 +50,10 @@ @test_opt inner(M, p, X, Y) @test_opt inner(M, p, 0, 0) + + @test embed(M, p) == p + @test embed(M, p, 0) == 0 + @test inner(M, p, 0, 0) == 0 end vector_points = [[1], [1, 2], [1, 2, 3]] @@ -64,4 +68,4 @@ @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) @test rand(rng, M) ∈ M end -end \ No newline at end of file +end From 0204834a6e07b9bfc7d937f77cad6549cb8e50c7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:34:59 +0200 Subject: [PATCH 08/14] fix: appropriate array partion --- src/natural_manifolds/categorical.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl index 49ee771..2b3e509 100644 --- a/src/natural_manifolds/categorical.jl +++ b/src/natural_manifolds/categorical.jl @@ -17,5 +17,5 @@ Converts the `point` to a compatible representation for the natural manifold of """ function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) # See comment in `get_natural_manifold_base` for `Categorical` - return ArrayPartition(p[1:end-1], p[end]) + return ArrayPartition(p[1:end-1], p[end:end]) end \ No newline at end of file From eb22c17e8652194c7f8bda11fbbfc5de7806100f Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:53:33 +0200 Subject: [PATCH 09/14] docs: add SinglePointManifold into docs --- docs/src/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/index.md b/docs/src/index.md index b58e96b..1002d99 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -81,6 +81,7 @@ ExponentialFamilyManifolds.partition_point ExponentialFamilyManifolds.ShiftedPositiveNumbers ExponentialFamilyManifolds.ShiftedNegativeNumbers ExponentialFamilyManifolds.SymmetricNegativeDefinite +ExponentialFamilyManifolds.SinglePointManifold ``` ## Optimization example From 4f1ca5edc821d648899af359771c9d68ddd6fd56 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 16:55:23 +0200 Subject: [PATCH 10/14] test: add test for def retraction and rand --- test/single_point_manifold_tests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index b50ca6d..51621c6 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -3,6 +3,7 @@ import ExponentialFamilyManifolds: SinglePointManifold using ManifoldsBase, Static, StaticArrays, JET, Manifolds using StableRNGs + using Random rng = StableRNG(42) @@ -30,6 +31,7 @@ @test @inferred(manifold_dimension(M)) === 0 @test @inferred(is_flat(M)) === true @test injectivity_radius(M) ≈ 0 + @test default_retraction_method(M) == ExponentialRetraction() @test_throws MethodError get_embedding(M) @@ -67,5 +69,6 @@ @test ManifoldsBase.log(M, p, p) == zero_vector(M, p) @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) @test rand(rng, M) ∈ M + @test rand(M) ∈ M end end From 80f9586167e586b475900a1ef3d5b108f2fe0da3 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 22 Jul 2024 23:59:04 +0200 Subject: [PATCH 11/14] test: more representative test --- src/natural_manifolds/categorical.jl | 3 +-- test/natural_manifolds/categorical_tests.jl | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl index 2b3e509..33c080c 100644 --- a/src/natural_manifolds/categorical.jl +++ b/src/natural_manifolds/categorical.jl @@ -16,6 +16,5 @@ end Converts the `point` to a compatible representation for the natural manifold of type `Categorical`. """ function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) - # See comment in `get_natural_manifold_base` for `Categorical` - return ArrayPartition(p[1:end-1], p[end:end]) + return ArrayPartition(view(p, 1:conditioner-1), view(p, conditioner:conditioner)) end \ No newline at end of file diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl index 38b84f3..48dd550 100644 --- a/test/natural_manifolds/categorical_tests.jl +++ b/test/natural_manifolds/categorical_tests.jl @@ -25,12 +25,13 @@ end M = get_natural_manifold(T, dims, getconditioner(ef)) function f(M, p) - return (mean(p) - 0.5)^2 + ef = convert(ExponentialFamilyDistribution, M, p) + return (mean(ef) - 0.5)^2 end function g(M, p) - X = ForwardDiff.gradient((p) -> f(M, p), p) - return project!(M, X, p, X) + ef = convert(ExponentialFamilyDistribution, M, p) + return project(M, p, 2 * (mean(ef) - 0.5) * p ./ 10) end q = gradient_descent(M, f, g, rand(rng, M)) From 6eaa19d6519b622d3d03c281d8779079f5a99594 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 23 Jul 2024 12:14:22 +0200 Subject: [PATCH 12/14] test: add gradient test for SinglePointManifold --- src/ExponentialFamilyManifolds.jl | 2 +- src/SinglePointManifold.jl | 76 ----------------------------- test/single_point_manifold_tests.jl | 34 +++++++++++++ 3 files changed, 35 insertions(+), 77 deletions(-) delete mode 100644 src/SinglePointManifold.jl diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 1d4f181..2a82cad 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -5,7 +5,7 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge include("symmetric_negative_definite.jl") include("shifted_negative_numbers.jl") include("shifted_positive_numbers.jl") -include("SinglePointManifold.jl") +include("single_point_manifold.jl") include("natural_manifolds.jl") include("natural_manifolds/bernoulli.jl") diff --git a/src/SinglePointManifold.jl b/src/SinglePointManifold.jl deleted file mode 100644 index 812bfec..0000000 --- a/src/SinglePointManifold.jl +++ /dev/null @@ -1,76 +0,0 @@ -using ManifoldsBase -using Random - -""" - SymmetricNegativeDefinite(point) - -This manifold represents a set from one point. -""" -struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} - point::T - representation_size::R -end - -function SinglePointManifold(point::T) where {T} - return SinglePointManifold(point, size(point)) -end - -function Base.show(io::IO, M::SinglePointManifold) - print(io, "SinglePointManifold(", M.point, ")") -end - -ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 -ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size -ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) - -ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() - -function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) - if p != M.point - return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") - end - return nothing -end - -function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) - if !iszero(X) && size(M.point) == size(X) - return DomainError(X, "The tangent space of $(M) contains only the zero vector.") - end - return nothing -end - -ManifoldsBase.is_flat(::SinglePointManifold) = true - -ManifoldsBase.embed(::SinglePointManifold, p) = p -ManifoldsBase.embed(::SinglePointManifold, p, X) = X - -function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) - return zero(eltype(X)) -end - -function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) - q .= M.point - return q -end - -function ManifoldsBase.log!(::SinglePointManifold, X, p, q) - X .= zero(eltype(X)) - return X -end - -function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) - fill!(Y, zero(eltype(Y))) - return Y -end - -function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) - return fill!(X, zero(eltype(X))) -end - -function Random.rand(M::SinglePointManifold; kwargs...) - return rand(Random.default_rng(), M; kwargs...) -end - -function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...) - return M.point -end diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl index 51621c6..0f18a17 100644 --- a/test/single_point_manifold_tests.jl +++ b/test/single_point_manifold_tests.jl @@ -72,3 +72,37 @@ @test rand(M) ∈ M end end + +@testitem "Simple manifold optimization problem #1" begin + using Manopt, ForwardDiff, Static, StableRNGs, LinearAlgebra + + import ExponentialFamilyManifolds: SinglePointManifold + + for a in (2.0, 3.0), + b in (10.0, 5.0), + c in (1.0, 10.0, -1.0), + eps in (1e-4, 1e-5, 1e-8, 1e-10), + stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) + + f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1] + grad_f(M, x) = 2 .* a .* x .+ b + + rng = StableRNG(42) + + for s in [0, 0.0, 10] + M = SinglePointManifold(s) + p0 = rand(rng, M) + + q1 = gradient_descent( + M, + f, + grad_f, + p0; + stepsize=stepsize, + stopping_criterion=StopAfterIteration(1) + ) + + @test q1 ≈ s + end + end +end \ No newline at end of file From 68e2f0d0d75601042edb149279bfe5d9f5d8bd0c Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 23 Jul 2024 12:14:41 +0200 Subject: [PATCH 13/14] refactor: rename file --- src/single_point_manifold.jl | 76 ++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/single_point_manifold.jl diff --git a/src/single_point_manifold.jl b/src/single_point_manifold.jl new file mode 100644 index 0000000..8dcc4b2 --- /dev/null +++ b/src/single_point_manifold.jl @@ -0,0 +1,76 @@ +using ManifoldsBase +using Random + +""" + SinglePointManifold(point) + +This manifold represents a set from one point. +""" +struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} + point::T + representation_size::R +end + +function SinglePointManifold(point::T) where {T} + return SinglePointManifold(point, size(point)) +end + +function Base.show(io::IO, M::SinglePointManifold) + print(io, "SinglePointManifold(", M.point, ")") +end + +ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 +ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size +ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) + +ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() + +function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) + if p[1] != M.point + return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") + end + return nothing +end + +function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) + if !iszero(X) && size(M.point) == size(X) + return DomainError(X, "The tangent space of $(M) contains only the zero vector.") + end + return nothing +end + +ManifoldsBase.is_flat(::SinglePointManifold) = true + +ManifoldsBase.embed(::SinglePointManifold, p) = p +ManifoldsBase.embed(::SinglePointManifold, p, X) = X + +function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) + return zero(eltype(X)) +end + +function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) + q .= M.point + return q +end + +function ManifoldsBase.log!(::SinglePointManifold, X, p, q) + X .= zero(eltype(X)) + return X +end + +function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) + fill!(Y, zero(eltype(Y))) + return Y +end + +function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) + return fill!(X, zero(eltype(X))) +end + +function Random.rand(M::SinglePointManifold; kwargs...) + return rand(Random.default_rng(), M; kwargs...) +end + +function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...) + return M.point +end From 99b234b40699f62078c705a550785abe3eb4320d Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Thu, 25 Jul 2024 14:35:15 +0200 Subject: [PATCH 14/14] fix: update ExponetialFamily.jl --- Project.toml | 2 +- src/single_point_manifold.jl | 2 +- test/natural_manifolds/categorical_tests.jl | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 9c1cd5b..d42fc33 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [compat] BayesBase = "1.3" -ExponentialFamily = "1.4.3" +ExponentialFamily = "1.5.1" LinearAlgebra = "1.10" Manifolds = "0.9" ManifoldsBase = "0.15" diff --git a/src/single_point_manifold.jl b/src/single_point_manifold.jl index 8dcc4b2..2c36d81 100644 --- a/src/single_point_manifold.jl +++ b/src/single_point_manifold.jl @@ -26,7 +26,7 @@ ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) - if p[1] != M.point + if !(p ≈ M.point) return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") end return nothing diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl index 48dd550..c613a20 100644 --- a/test/natural_manifolds/categorical_tests.jl +++ b/test/natural_manifolds/categorical_tests.jl @@ -26,12 +26,12 @@ end function f(M, p) ef = convert(ExponentialFamilyDistribution, M, p) - return (mean(ef) - 0.5)^2 + η = getnaturalparameters(ef) + return (mean(η) - 0.5)^2 end function g(M, p) - ef = convert(ExponentialFamilyDistribution, M, p) - return project(M, p, 2 * (mean(ef) - 0.5) * p ./ 10) + return project(M, p, 2 * p ./ 10) end q = gradient_descent(M, f, g, rand(rng, M))