diff --git a/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_distribution.py b/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_distribution.py index a69b7ef9..51a8e5b1 100644 --- a/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_distribution.py @@ -339,7 +339,7 @@ def hypersph_to_cart(hypersph_coords, mode: str = "colatitude"): ) elif mode in ("elevation", "inclination"): from .abstract_sphere_subset_distribution import AbstractSphereSubsetDistribution - assert hypersph_coords.shape[1] == 2, "Mode only supports 2 dimensions" + assert hypersph_coords.shape[1] == 2, "Mode only S2 dimensions" x, y, z = AbstractSphereSubsetDistribution.sph_to_cart(hypersph_coords[:, 0], hypersph_coords[:, 1], mode=mode) cart_coords = column_stack((x, y, z)) else: @@ -347,6 +347,25 @@ def hypersph_to_cart(hypersph_coords, mode: str = "colatitude"): return cart_coords.squeeze() + @staticmethod + def cart_to_hypersph(cart_coords, mode: str = "colatitude"): + cart_coords = atleast_2d(cart_coords) + if mode == "colatitude": + cart_coords = AbstractHypersphereSubsetDistribution._cart_to_hyersph_colatitude(cart_coords) + elif mode in ("elevation", "inclination"): + from .abstract_sphere_subset_distribution import AbstractSphereSubsetDistribution + assert cart_coords.shape[1] == 3, "Mode only supports S2" + theta, phi = AbstractSphereSubsetDistribution.cart_to_sph(cart_coords[:, 0], cart_coords[:, 1], cart_coords[:, 2], mode=mode) + cart_coords = column_stack((theta, phi)) + else: + raise ValueError("Mode must be 'colatitude', 'elevation' or 'inclination'") + + return cart_coords.squeeze() + + @staticmethod + def _cart_to_hyersph_colatitude(cart_coords): + raise NotImplementedError('Not yet implemented') + @staticmethod def _get_integrand_hypersph_fun(fun: Callable): """ diff --git a/pyrecest/tests/distributions/test_abstract_hypersphere_subset_distribution.py b/pyrecest/tests/distributions/test_abstract_hypersphere_subset_distribution.py index 9c672a14..1d0ec732 100644 --- a/pyrecest/tests/distributions/test_abstract_hypersphere_subset_distribution.py +++ b/pyrecest/tests/distributions/test_abstract_hypersphere_subset_distribution.py @@ -212,7 +212,7 @@ def test_hyperspherical_to_cartesian_specific(self, dimensions, specific_functio npt.assert_allclose(cartesian_specific, cartesian_given) - @parameterized.expand( + @parameterized.expand( [ ("colatitude",), ("elevation",), @@ -226,17 +226,13 @@ def test_cart_to_sph_to_cart(self, mode): z = array([0.0, 0.0, 1.0]) # Convert to spherical coordinates and back - azimuth, theta = AbstractHypersphereSubsetDistribution.cart_to_hypersph( - x, y, z, mode=mode - ) - x_new, y_new, z_new = AbstractHypersphereSubsetDistribution.hypersph_to_cart( - azimuth, theta, mode=mode - ) + angles = AbstractHypersphereSubsetDistribution.cart_to_hypersph(column_stack((x, y, z)), mode=mode) + cart_res = AbstractHypersphereSubsetDistribution.hypersph_to_cart(angles, mode=mode) # The new Cartesian coordinates should be close to the original ones - npt.assert_allclose(x_new, x, atol=1e-7) - npt.assert_allclose(y_new, y, atol=1e-7) - npt.assert_allclose(z_new, z, atol=1e-7) + npt.assert_allclose(cart_res[:, 0], x, atol=1e-7) + npt.assert_allclose(cart_res[:, 1], y, atol=1e-7) + npt.assert_allclose(cart_res[:, 2], z, atol=1e-7) def test_pdf_hyperspherical_coords_1d(self): mu_ = array([0.5, 1.0]) / linalg.norm(array([0.5, 1.0])) diff --git a/pyrecest/tests/distributions/test_watson_distribution.py b/pyrecest/tests/distributions/test_watson_distribution.py index f1772a93..bba0c557 100644 --- a/pyrecest/tests/distributions/test_watson_distribution.py +++ b/pyrecest/tests/distributions/test_watson_distribution.py @@ -58,7 +58,7 @@ def test_integrate(self): mu = mu / linalg.norm(mu) kappa = 2.0 w = WatsonDistribution(mu, kappa) - self.assertAlmostEqual(w.integrate(), 1, delta=1e-5) + self.assertAlmostEqual(w.integrate(), 1.0, delta=1e-5) def test_to_bingham(self): mu = array([1.0, 0.0, 0.0]) diff --git a/test_plot_0.png b/test_plot_0.png index b0b32cca..024ef8f8 100644 Binary files a/test_plot_0.png and b/test_plot_0.png differ diff --git a/test_plot_1.png b/test_plot_1.png index 85d6471f..6c293ab5 100644 Binary files a/test_plot_1.png and b/test_plot_1.png differ diff --git a/test_plot_2.png b/test_plot_2.png index 33eb522e..8dcf8c1f 100644 Binary files a/test_plot_2.png and b/test_plot_2.png differ