diff --git a/python/parallelproj/operators.py b/python/parallelproj/operators.py index 6f323e0b..761d85de 100644 --- a/python/parallelproj/operators.py +++ b/python/parallelproj/operators.py @@ -314,18 +314,22 @@ def iscomplex(self) -> bool: class GaussianFilterOperator(LinearOperator): """Gaussian filter operator""" - def __init__(self, in_shape: tuple[int,...], **kwargs): + def __init__(self, in_shape: tuple[int, ...], sigma: float | npt.NDArray, + **kwargs): """init method Parameters ---------- in_shape : tuple[int, ...] shape of the input array + sigma: float | array + standard deviation of the gaussian filter **kwargs : sometype passed to the ndimage gaussian_filter function """ super().__init__() self._in_shape = in_shape + self._sigma = sigma self._kwargs = kwargs @property @@ -341,14 +345,26 @@ def _apply(self, x: npt.ArrayLike) -> npt.ArrayLike: xp = array_api_compat.get_namespace(x) if parallelproj.is_cuda_array(x): - import cupy as cp + import array_api_compat.cupy as cp import cupyx.scipy.ndimage as ndimagex + if array_api_compat.is_array_api_obj(self._sigma): + sigma = cp.asarray(self._sigma) + else: + sigma = self._sigma + return xp.asarray(ndimagex.gaussian_filter(cp.asarray(x), + sigma=sigma, **self._kwargs), device=device(x)) else: import scipy.ndimage as ndimage + if array_api_compat.is_array_api_obj(self._sigma): + sigma = np.asarray(self._sigma) + else: + sigma = self._sigma + return xp.asarray(ndimage.gaussian_filter(np.asarray(x), + sigma=sigma, **self._kwargs), device=device(x)) diff --git a/test/parallelproj/test_operators.py b/test/parallelproj/test_operators.py index 565281ce..95b6805b 100644 --- a/test/parallelproj/test_operators.py +++ b/test/parallelproj/test_operators.py @@ -49,9 +49,13 @@ def elemenwise_test(xp: ModuleType, dev: str): def gaussian_test(xp: ModuleType, dev: str): np.random.seed(0) in_shape = (32, 32) - sigma = 2.3 + sigma1 = 2.3 - op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma) + op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma1) + op.adjointness_test(xp, dev) + + sigma2 = xp.asarray([2.3, 1.2], device=dev) + op = parallelproj.GaussianFilterOperator(in_shape, sigma=sigma2) op.adjointness_test(xp, dev)