diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c29c53c..ced348a2 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -94,6 +94,10 @@ def __repr__(self): """Beautify the python representation.""" return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>" + def __len__(self): + """Enable len() -- for compatibility, only len == 1 is supported.""" + return 1 + @property def ndim(self): """Get the dimensions of the transform.""" diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 45474008..eb3f9ad0 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -97,7 +97,6 @@ def apply( # Avoid opening the data array just yet input_dtype = get_obj_dtype(spatialimage.dataobj) - output_dtype = output_dtype or input_dtype # Number of transformations data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1] @@ -115,16 +114,17 @@ def apply( serialize_4d = n_resamplings >= serialize_nvols targets = None + ref_ndcoords = _ref.ndcoords.T if hasattr(transform, "to_field") and callable(transform.to_field): targets = ImageGrid(spatialimage).index( _as_homogeneous( - transform.to_field(reference=reference).map(_ref.ndcoords.T), + transform.to_field(reference=reference).map(ref_ndcoords), dim=_ref.ndim, ) ) elif xfm_nvols == 1: targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) ) if serialize_4d: @@ -137,7 +137,7 @@ def apply( # Order F ensures individual volumes are contiguous in memory # Also matches NIfTI, making final save more efficient resampled = np.zeros( - (spatialimage.size, len(transform)), dtype=output_dtype, order="F" + (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" ) for t in range(n_resamplings): @@ -145,7 +145,7 @@ def apply( if targets is None: targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(xfm_t.map(_ref.ndcoords.T), dim=_ref.ndim) + _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) ) # Interpolate @@ -156,7 +156,6 @@ def apply( else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) ), targets, - output=output_dtype, order=order, mode=mode, cval=cval, @@ -168,7 +167,7 @@ def apply( if targets is None: targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) ) # Cast 3D data into 4D if 4D nonsequential transform @@ -181,7 +180,6 @@ def apply( resampled = ndi.map_coordinates( data, targets, - output=output_dtype, order=order, mode=mode, cval=cval, @@ -190,13 +188,14 @@ def apply( if isinstance(_ref, ImageGrid): # If reference is grid, reshape hdr = _ref.header.copy() if _ref.header is not None else spatialimage.header.__class__() - hdr.set_data_dtype(output_dtype) + hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype()) moved = spatialimage.__class__( - resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)), + resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)), _ref.affine, hdr, ) return moved - return resampled + output_dtype = output_dtype or input_dtype + return resampled.astype(output_dtype) diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index fb4be8d8..c85ac2e2 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -1,6 +1,8 @@ """Tests of the base module.""" import numpy as np import nibabel as nb +from nibabel.arrayproxy import get_obj_dtype + import pytest import h5py @@ -97,7 +99,7 @@ def _to_hdf5(klass, x5_root): fname = testdata_path / "someones_anatomy.nii.gz" img = nb.load(fname) - imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype()) + imgdata = np.asanyarray(img.dataobj, dtype=get_obj_dtype(img.dataobj)) # Test identity transform - setting reference xfm = TransformBase() @@ -111,7 +113,8 @@ def _to_hdf5(klass, x5_root): xfm = nitl.Affine() xfm.reference = fname moved = apply(xfm, fname, order=0) - assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())) + + assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=get_obj_dtype(moved.dataobj))) # Test ndim returned by affine assert nitl.Affine().ndim == 3