diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 430abf1a..56a02dc7 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -8,10 +8,11 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" +import asyncio from os import cpu_count -from concurrent.futures import ProcessPoolExecutor, as_completed +from functools import partial from pathlib import Path -from typing import Tuple +from typing import Callable, TypeVar import numpy as np from nibabel.loadsave import load as _nbload @@ -27,65 +28,19 @@ _as_homogeneous, ) +R = TypeVar("R") + SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8 """Minimum number of volumes to automatically serialize 4D transforms.""" -def _apply_volume( - index: int, - data: np.ndarray, - targets: np.ndarray, - order: int = 3, - mode: str = "constant", - cval: float = 0.0, - prefilter: bool = True, -) -> Tuple[int, np.ndarray]: - """ - Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization. +async def worker(job: Callable[[], R], semaphore) -> R: + async with semaphore: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, job) - Parameters - ---------- - index : :obj:`int` - The index of the volume to apply the interpolation to. - data : :obj:`~numpy.ndarray` - The input data array. - targets : :obj:`~numpy.ndarray` - The target coordinates for mapping. - order : :obj:`int`, optional - The order of the spline interpolation, default is 3. - The order has to be in the range 0-5. - mode : :obj:`str`, optional - Determines how the input image is extended when the resamplings overflows - a border. One of ``'constant'``, ``'reflect'``, ``'nearest'``, ``'mirror'``, - or ``'wrap'``. Default is ``'constant'``. - cval : :obj:`float`, optional - Constant value for ``mode='constant'``. Default is 0.0. - prefilter: :obj:`bool`, optional - Determines if the image's data array is prefiltered with - a spline filter before interpolation. The default is ``True``, - which will create a temporary *float64* array of filtered values - if *order > 1*. If setting this to ``False``, the output will be - slightly blurred if *order > 1*, unless the input is prefiltered, - i.e. it is the result of calling the spline filter on the original - input. - - Returns - ------- - (:obj:`int`, :obj:`~numpy.ndarray`) - The index and the array resulting from the interpolation. - - """ - return index, ndi.map_coordinates( - data, - targets, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) - -def apply( +async def apply( transform: TransformBase, spatialimage: str | Path | SpatialImage, reference: str | Path | SpatialImage = None, @@ -94,9 +49,9 @@ def apply( cval: float = 0.0, prefilter: bool = True, output_dtype: np.dtype = None, - serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH, - njobs: int = None, dtype_width: int = 8, + serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH, + max_concurrent: int = min(cpu_count(), 12), ) -> SpatialImage | np.ndarray: """ Apply a transformation to an image, resampling on the reference spatial object. @@ -118,7 +73,7 @@ def apply( or ``'wrap'``. Default is ``'constant'``. cval : :obj:`float`, optional Constant value for ``mode='constant'``. Default is 0.0. - prefilter: :obj:`bool`, optional + prefilter : :obj:`bool`, optional Determines if the image's data array is prefiltered with a spline filter before interpolation. The default is ``True``, which will create a temporary *float64* array of filtered values @@ -126,7 +81,7 @@ def apply( slightly blurred if *order > 1*, unless the input is prefiltered, i.e. it is the result of calling the spline filter on the original input. - output_dtype: :obj:`~numpy.dtype`, optional + output_dtype : :obj:`~numpy.dtype`, optional The dtype of the returned array or image, if specified. If ``None``, the default behavior is to use the effective dtype of the input image. If slope and/or intercept are defined, the effective @@ -135,10 +90,17 @@ def apply( If ``reference`` is defined, then the return value is an image, with a data array of the effective dtype but with the on-disk dtype set to the input image's on-disk dtype. - dtype_width: :obj:`int` + dtype_width : :obj:`int` Cap the width of the input data type to the given number of bytes. This argument is intended to work as a way to implement lower memory requirements in resampling. + serialize_nvols : :obj:`int` + Minimum number of volumes in a 3D+t (that is, a series of 3D transformations + independent in time) to resample on a one-by-one basis. + Serialized resampling can be executed concurrently (parallelized) with + the argument ``max_concurrent``. + max_concurrent : :obj:`int` + Maximum number of 3D resamplings to be executed concurrently. Returns ------- @@ -201,46 +163,47 @@ def apply( else None ) - njobs = cpu_count() if njobs is None or njobs < 1 else njobs + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" + ) - with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor: - results = [] - for t in range(n_resamplings): - xfm_t = transform if n_resamplings == 1 else transform[t] + semaphore = asyncio.Semaphore(max_concurrent) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) - ) + tasks = [] + for t in range(n_resamplings): + xfm_t = transform if n_resamplings == 1 else transform[t] - data_t = ( - data - if data is not None - else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) + if targets is None: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) ) - results.append( - executor.submit( - _apply_volume, - t, - data_t, - targets, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, + data_t = ( + data + if data is not None + else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) + ) + + tasks.append( + asyncio.create_task( + worker( + partial( + ndi.map_coordinates, + data_t, + targets, + output=resampled[..., t], + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ), + semaphore, ) ) - - # Order F ensures individual volumes are contiguous in memory - # Also matches NIfTI, making final save more efficient - resampled = np.zeros( - (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" ) - - for future in as_completed(results): - t, resampled_t = future.result() - resampled[..., t] = resampled_t + await asyncio.gather(*tasks) else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype) diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index f944b225..2384ad97 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -15,7 +15,7 @@ from nitransforms import nonlinear as nitnl from nitransforms import manip as nitm from nitransforms import io -from nitransforms.resampling import apply, _apply_volume +from nitransforms.resampling import apply RMSE_TOL_LINEAR = 0.09 RMSE_TOL_NONLINEAR = 0.05 @@ -363,16 +363,3 @@ def test_LinearTransformsMapping_apply( reference=testdata_path / "sbref.nii.gz", serialize_nvols=2 if serialize_4d else np.inf, ) - - -@pytest.mark.parametrize("t", list(range(4))) -def test_apply_helper(monkeypatch, t): - """Ensure the apply helper function correctly just decorates with index.""" - from nitransforms.resampling import ndi - - def _retval(*args, **kwargs): - return 1 - - monkeypatch.setattr(ndi, "map_coordinates", _retval) - - assert _apply_volume(t, None, None) == (t, 1)