Skip to content

Commit

Permalink
Merge pull request #142 from FlorianPfaff/watson_to_bingham
Browse files Browse the repository at this point in the history
Added to_bingham for WatsonDistribution
  • Loading branch information
FlorianPfaff authored Aug 27, 2023
2 parents dbc48c1 + 8ff9122 commit 56b37b5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
15 changes: 3 additions & 12 deletions pyrecest/distributions/hypersphere_subset/watson_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,18 @@ def pdf(self, xs):
return p

def to_bingham(self) -> BinghamDistribution:
"""
Converts the Watson distribution to a Bingham distribution.
Returns:
BinghamDistribution: The converted distribution.
Raises:
NotImplementedError: If kappa is less than 0.
"""
if self.kappa < 0:
raise NotImplementedError(
"Conversion to Bingham is not implemented for kappa<0"
)

M = np.tile(self.mu, (1, self.dim + 1))
E = np.eye(self.dim + 1)
M = np.tile(self.mu.reshape(-1, 1), (1, self.input_dim))
E = np.eye(self.input_dim)
E[0, 0] = 0
M = M + E
Q, _ = qr(M)
M = np.hstack([Q[:, 1:], Q[:, 0].reshape(-1, 1)])
Z = np.vstack([np.full((self.dim, 1), -self.kappa), 0])
Z = np.hstack([np.full((self.dim), -self.kappa), 0])
return BinghamDistribution(Z, M)

def sample(self, n):
Expand Down
31 changes: 20 additions & 11 deletions pyrecest/tests/distributions/test_watson_distribution.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import unittest

import numpy as np
from pyrecest.distributions import WatsonDistribution
from pyrecest.distributions import BinghamDistribution, WatsonDistribution


class TestWatsonDistribution(unittest.TestCase):
def setUp(self):
self.xs = np.array(
[[1, 0, 0], [1, 2, 2], [0, 1, 0], [0, 0, 1], [1, 1, 1], [-1, -1, -1]],
dtype=float,
)
self.xs = self.xs / np.linalg.norm(self.xs, axis=1, keepdims=True)

def test_constructor(self):
mu = np.array([1, 2, 3])
mu = mu / np.linalg.norm(mu)
Expand All @@ -14,20 +21,14 @@ def test_constructor(self):
self.assertIsInstance(w, WatsonDistribution)
np.testing.assert_array_equal(w.mu, mu)
self.assertEqual(w.kappa, kappa)
self.assertEqual(w.dim, len(mu) - 1)
self.assertEqual(w.input_dim, np.size(mu))

def test_pdf(self):
mu = np.array([1, 2, 3])
mu = mu / np.linalg.norm(mu)
kappa = 2
w = WatsonDistribution(mu, kappa)

xs = np.array(
[[1, 0, 0], [1, 2, 2], [0, 1, 0], [0, 0, 1], [1, 1, 1], [-1, -1, -1]],
dtype=float,
)
xs = xs / np.linalg.norm(xs, axis=1, keepdims=True)

expected_pdf_values = np.array(
[
0.0388240901641662,
Expand All @@ -39,18 +40,26 @@ def test_pdf(self):
]
)

pdf_values = w.pdf(xs)
pdf_values = w.pdf(self.xs)
np.testing.assert_almost_equal(pdf_values, expected_pdf_values, decimal=5)

def test_integrate(self):
mu = np.array([1, 2, 3])
mu = mu / np.linalg.norm(mu)
kappa = 2
w = WatsonDistribution(mu, kappa)

# Test integral
self.assertAlmostEqual(w.integrate(), 1, delta=1e-5)

def test_to_bingham(self):
mu = np.array([1.0, 0.0, 0.0])
kappa = 2.0
watson_dist = WatsonDistribution(mu, kappa)
bingham_dist = watson_dist.to_bingham()
self.assertIsInstance(bingham_dist, BinghamDistribution)
np.testing.assert_almost_equal(
watson_dist.pdf(self.xs), bingham_dist.pdf(self.xs), decimal=5
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 56b37b5

Please sign in to comment.