Skip to content

Commit

Permalink
Using and testing all backends
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianPfaff committed Oct 16, 2023
1 parent 2ff69c8 commit f331275
Show file tree
Hide file tree
Showing 159 changed files with 2,306 additions and 1,450 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --extras healpy_support
poetry install --extras healpy_support --extras pytorch_support
- name: List files and check Python and package versions
run: |
Expand All @@ -40,12 +40,17 @@ jobs:
- name: Run tests
run: |
poetry env use python
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results.xml ./pyrecest
export PYRECEST_BACKEND=numpy
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results_numpy.xml ./pyrecest
export PYRECEST_BACKEND=pytorch
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results_pytorch.xml ./pyrecest
env:
PYTHONPATH: ${{ github.workspace }}

- name: Publish test results
if: always()
uses: EnricoMi/publish-unit-test-result-action@v2
with:
files: junit_test_results.xml
files: |
junit_test_results_numpy.xml
junit_test_results_pytorch.xml
53 changes: 0 additions & 53 deletions .github/workflows/tests_pytorch.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/update-requirements.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: python -m poetry update

- name: Update requirements.txt
run: python -m poetry export --format requirements.txt --output requirements.txt --extras healpy_support --without-hashes
run: python -m poetry export --format requirements.txt --output requirements.txt --extras healpy_support --extras pytorch_support --without-hashes

- name: Update requirements-dev.txt
run: python -m poetry export --with dev --format requirements.txt --output requirements-dev.txt --without-hashes
Expand Down
6 changes: 6 additions & 0 deletions .jscpd.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"ignore": [
"pyrecest/_backend/**"
]
}

222 changes: 216 additions & 6 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ shapely = "*"

[tool.poetry.extras]
healpy_support = ["healpy"]
pytorch_support = ["torch"]

[tool.poetry.group.dev.dependencies]
healpy = "*"
torch = "*"
autopep8 = "^2.0.2"
pytest = "*"
parameterized = "*"
Expand Down
2 changes: 2 additions & 0 deletions pyrecest/_backend/.pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[MESSAGES CONTROL]
disable=all
2 changes: 1 addition & 1 deletion pyrecest/_backend/_shared_numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def assignment_by_sum(x, values, indices, axis=0):


def ndim(x):
return x.ndim
return _np.ndim(x)


def get_slice(x, indices):
Expand Down
28 changes: 18 additions & 10 deletions pyrecest/distributions/abstract_dirac_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from pyrecest.backend import sum
from pyrecest.backend import ones
from pyrecest.backend import log
from pyrecest.backend import isclose
from pyrecest.backend import argmax
from pyrecest.backend import all
from pyrecest.backend import int64
from pyrecest.backend import int32
import copy
import warnings
from collections.abc import Callable
Expand All @@ -22,7 +30,7 @@ def __init__(self, d: np.ndarray, w: np.ndarray | None = None):
:param w: Weights of Dirac locations as a numpy array. If not provided, defaults to uniform weights.
"""
if w is None:
w = np.ones(d.shape[0]) / d.shape[0]
w = ones(d.shape[0]) / d.shape[0]

assert d.shape[0] == np.size(w), "Number of Diracs and weights must match."
self.d = copy.copy(d)
Expand All @@ -34,9 +42,9 @@ def normalize_in_place(self):
"""
Normalize the weights in-place to ensure they sum to 1.
"""
if not np.isclose(np.sum(self.w), 1, atol=1e-10):
if not isclose(sum(self.w), 1, atol=1e-10):
warnings.warn("Weights are not normalized.", RuntimeWarning)
self.w = self.w / np.sum(self.w)
self.w = self.w / sum(self.w)

@beartype
def normalize(self) -> "AbstractDiracDistribution":
Expand Down Expand Up @@ -67,28 +75,28 @@ def reweigh(self, f: Callable) -> "AbstractDiracDistribution":
wNew = f(dist.d)

assert wNew.shape == dist.w.shape, "Function returned wrong output dimensions."
assert np.all(wNew >= 0), "All weights should be greater than or equal to 0."
assert np.sum(wNew) > 0, "The sum of all weights should be greater than 0."
assert all(wNew >= 0), "All weights should be greater than or equal to 0."
assert sum(wNew) > 0, "The sum of all weights should be greater than 0."

dist.w = wNew * dist.w
dist.w = dist.w / np.sum(dist.w)
dist.w = dist.w / sum(dist.w)

return dist

@beartype
def sample(self, n: int | np.int32 | np.int64) -> np.ndarray:
def sample(self, n: int | int32 | int64) -> np.ndarray:
ids = np.random.choice(np.size(self.w), size=n, p=self.w)
return self.d[ids]

def entropy(self) -> float:
warnings.warn("Entropy is not defined in a continuous sense")
return -np.sum(self.w * np.log(self.w))
return -sum(self.w * log(self.w))

def integrate(self, left=None, right=None) -> np.ndarray:
assert (
left is None and right is None
), "Must overwrite in child class to use integral limits"
return np.sum(self.w)
return sum(self.w)

def log_likelihood(self, *args):
raise NotImplementedError("PDF:UNDEFINED, not supported")
Expand All @@ -112,7 +120,7 @@ def kld_numerical(self, *args):
raise NotImplementedError("PDF:UNDEFINED, not supported")

def mode(self, rel_tol=0.001):
highest_val, ind = np.max(self.w), np.argmax(self.w)
highest_val, ind = np.max(self.w), argmax(self.w)
if (highest_val / self.w.size) < (1 + rel_tol):
warnings.warn(
"The samples may be equally weighted, .mode is likely to return a bad result."
Expand Down
4 changes: 3 additions & 1 deletion pyrecest/distributions/abstract_disk_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pyrecest.backend import eye
from pyrecest.backend import array
import numpy as np

from .abstract_ellipsoidal_ball_distribution import AbstractEllipsoidalBallDistribution
Expand All @@ -10,7 +12,7 @@ class AbstractDiskDistribution(AbstractEllipsoidalBallDistribution):

# We index it using 2-D Euclidean vectors (is zero everywhere else)
def __init__(self):
super().__init__(np.array([0, 0]), np.eye(2))
super().__init__(array([0, 0]), eye(2))

def mean(self):
raise TypeError("Mean not defined for distributions on the disk.")
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pyrecest.backend import sqrt
import numbers

import numpy as np
Expand Down Expand Up @@ -50,4 +51,4 @@ def get_manifold_size(self) -> np.number | numbers.Real:
else:
c = (np.pi ** (self.dim / 2)) / gamma((self.dim / 2) + 1)

return c * np.sqrt(np.linalg.det(self.shape_matrix))
return c * sqrt(np.linalg.det(self.shape_matrix))
16 changes: 10 additions & 6 deletions pyrecest/distributions/abstract_manifold_specific_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from pyrecest.backend import squeeze
from pyrecest.backend import int64
from pyrecest.backend import int32
from pyrecest.backend import empty
import numbers
from abc import ABC, abstractmethod
from collections.abc import Callable
Expand Down Expand Up @@ -60,17 +64,17 @@ def set_mode(self, _):
raise NotImplementedError("set_mode is not implemented for this distribution")

@beartype
def sample(self, n: int | np.int32 | np.int64) -> np.ndarray:
def sample(self, n: int | int32 | int64) -> np.ndarray:
"""Obtain n samples from the distribution."""
return self.sample_metropolis_hastings(n)

# jscpd:ignore-start
@beartype
def sample_metropolis_hastings(
self,
n: int | np.int32 | np.int64,
burn_in: int | np.int32 | np.int64 = 10,
skipping: int | np.int32 | np.int64 = 5,
n: int | int32 | int64,
burn_in: int | int32 | int64 = 10,
skipping: int | int32 | int64 = 5,
proposal: Callable | None = None,
start_point: np.number | numbers.Real | np.ndarray | None = None,
) -> np.ndarray:
Expand All @@ -83,7 +87,7 @@ def sample_metropolis_hastings(
)

total_samples = burn_in + n * skipping
s = np.empty(
s = empty(
(
total_samples,
self.input_dim,
Expand All @@ -105,4 +109,4 @@ def sample_metropolis_hastings(
i += 1

relevant_samples = s[burn_in::skipping, :]
return np.squeeze(relevant_samples)
return squeeze(relevant_samples)
18 changes: 12 additions & 6 deletions pyrecest/distributions/abstract_mixture.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from pyrecest.backend import sum
from pyrecest.backend import ones
from pyrecest.backend import int64
from pyrecest.backend import int32
from pyrecest.backend import empty
from pyrecest.backend import zeros
import collections
import copy
import warnings
Expand Down Expand Up @@ -27,7 +33,7 @@ def __init__(
num_distributions = len(dists)

if weights is None:
weights = np.ones(num_distributions) / num_distributions
weights = ones(num_distributions) / num_distributions
else:
weights = np.asarray(weights)

Expand All @@ -48,9 +54,9 @@ def __init__(

self.dists = dists

if abs(np.sum(weights) - 1) > 1e-10:
if abs(sum(weights) - 1) > 1e-10:
warnings.warn("Weights of mixture do not sum to one.")
self.w = weights / np.sum(weights)
self.w = weights / sum(weights)
else:
self.w = weights

Expand All @@ -59,12 +65,12 @@ def input_dim(self) -> int:
return self.dists[0].input_dim

@beartype
def sample(self, n: int | np.int32 | np.int64) -> np.ndarray:
def sample(self, n: int | int32 | int64) -> np.ndarray:
d = np.random.choice(len(self.w), size=n, p=self.w)

occurrences = np.bincount(d, minlength=len(self.dists))
count = 0
s = np.empty((n, self.input_dim))
s = empty((n, self.input_dim))
for i, occ in enumerate(occurrences):
if occ != 0:
s[count : count + occ, :] = self.dists[i].sample(occ) # noqa: E203
Expand All @@ -79,7 +85,7 @@ def sample(self, n: int | np.int32 | np.int64) -> np.ndarray:
def pdf(self, xs: np.ndarray) -> np.ndarray:
assert xs.shape[-1] == self.input_dim, "Dimension mismatch"

p = np.zeros(1) if xs.ndim == 1 else np.zeros(xs.shape[0])
p = zeros(1) if xs.ndim == 1 else zeros(xs.shape[0])

for i, dist in enumerate(self.dists):
p += self.w[i] * dist.pdf(xs)
Expand Down
10 changes: 7 additions & 3 deletions pyrecest/distributions/abstract_orthogonal_basis_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from pyrecest.backend import real
from pyrecest.backend import imag
from pyrecest.backend import exp
from pyrecest.backend import all
import copy
import warnings
from abc import abstractmethod
Expand Down Expand Up @@ -57,14 +61,14 @@ def pdf(self, xs: np.ndarray | np.number) -> np.ndarray | np.number:
"""
val = self.value(xs)
if self.transformation == "sqrt":
assert np.all(np.imag(val) < 0.0001)
return np.real(val) ** 2
assert all(imag(val) < 0.0001)
return real(val) ** 2

if self.transformation == "identity":
return val

if self.transformation == "log":
warnings.warn("Density may not be normalized")
return np.exp(val)
return exp(val)

raise ValueError("Transformation not recognized or unsupported")
4 changes: 3 additions & 1 deletion pyrecest/distributions/abstract_periodic_distribution.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pyrecest.backend import int64
from pyrecest.backend import int32
from abc import abstractmethod

import numpy as np
Expand All @@ -10,7 +12,7 @@ class AbstractPeriodicDistribution(AbstractBoundedDomainDistribution):
"""Abstract class for a distributions on periodic manifolds."""

@beartype
def __init__(self, dim: int | np.int32 | np.int64):
def __init__(self, dim: int | int32 | int64):
super().__init__(dim=dim)

@beartype
Expand Down
Loading

0 comments on commit f331275

Please sign in to comment.