Skip to content

Commit

Permalink
add sigma as explicit argument in GaussianFilterOperator and convert …
Browse files Browse the repository at this point in the history
…correctly to numpy/cupy arrays
  • Loading branch information
gschramm committed Oct 18, 2023
1 parent 627fe97 commit bce7076
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
20 changes: 18 additions & 2 deletions python/parallelproj/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,18 +314,22 @@ def iscomplex(self) -> bool:
class GaussianFilterOperator(LinearOperator):
"""Gaussian filter operator"""

def __init__(self, in_shape: tuple[int,...], **kwargs):
def __init__(self, in_shape: tuple[int, ...], sigma: float | npt.NDArray,
**kwargs):
"""init method
Parameters
----------
in_shape : tuple[int, ...]
shape of the input array
sigma: float | array
standard deviation of the gaussian filter
**kwargs : sometype
passed to the ndimage gaussian_filter function
"""
super().__init__()
self._in_shape = in_shape
self._sigma = sigma
self._kwargs = kwargs

@property
Expand All @@ -341,14 +345,26 @@ def _apply(self, x: npt.ArrayLike) -> npt.ArrayLike:
xp = array_api_compat.get_namespace(x)

if parallelproj.is_cuda_array(x):
import cupy as cp
import array_api_compat.cupy as cp
import cupyx.scipy.ndimage as ndimagex
if array_api_compat.is_array_api_obj(self._sigma):
sigma = cp.asarray(self._sigma)
else:
sigma = self._sigma

return xp.asarray(ndimagex.gaussian_filter(cp.asarray(x),
sigma=sigma,
**self._kwargs),
device=device(x))
else:
import scipy.ndimage as ndimage
if array_api_compat.is_array_api_obj(self._sigma):
sigma = np.asarray(self._sigma)
else:
sigma = self._sigma

return xp.asarray(ndimage.gaussian_filter(np.asarray(x),
sigma=sigma,
**self._kwargs),
device=device(x))

Expand Down
8 changes: 6 additions & 2 deletions test/parallelproj/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ def elemenwise_test(xp: ModuleType, dev: str):
def gaussian_test(xp: ModuleType, dev: str):
np.random.seed(0)
in_shape = (32, 32)
sigma = 2.3
sigma1 = 2.3

op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma)
op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma1)
op.adjointness_test(xp, dev)

sigma2 = xp.asarray([2.3, 1.2], device=dev)
op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma2)
op.adjointness_test(xp, dev)


Expand Down

0 comments on commit bce7076

Please sign in to comment.