Skip to content

Commit

Permalink
fix: resolve some failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jul 30, 2024
1 parent e0bde09 commit fbb0451
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
4 changes: 4 additions & 0 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def __repr__(self):
"""Beautify the python representation."""
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"

def __len__(self):
"""Enable len() -- for compatibility, only len == 1 is supported."""
return 1

@property
def ndim(self):
"""Get the dimensions of the transform."""
Expand Down
21 changes: 10 additions & 11 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def apply(

# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

# Number of transformations
data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1]
Expand All @@ -115,16 +114,17 @@ def apply(
serialize_4d = n_resamplings >= serialize_nvols

targets = None
ref_ndcoords = _ref.ndcoords.T
if hasattr(transform, "to_field") and callable(transform.to_field):
targets = ImageGrid(spatialimage).index(
_as_homogeneous(
transform.to_field(reference=reference).map(_ref.ndcoords.T),
transform.to_field(reference=reference).map(ref_ndcoords),
dim=_ref.ndim,
)
)
elif xfm_nvols == 1:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
)

if serialize_4d:
Expand All @@ -137,15 +137,15 @@ def apply(
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(spatialimage.size, len(transform)), dtype=output_dtype, order="F"
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)

for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(_ref.ndcoords.T), dim=_ref.ndim)
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)

# Interpolate
Expand All @@ -156,7 +156,6 @@ def apply(
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
targets,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
Expand All @@ -168,7 +167,7 @@ def apply(

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
)

# Cast 3D data into 4D if 4D nonsequential transform
Expand All @@ -181,7 +180,6 @@ def apply(
resampled = ndi.map_coordinates(
data,
targets,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
Expand All @@ -190,13 +188,14 @@ def apply(

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
hdr = _ref.header.copy() if _ref.header is not None else spatialimage.header.__class__()
hdr.set_data_dtype(output_dtype)
hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())

moved = spatialimage.__class__(
resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)),
resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
_ref.affine,
hdr,
)
return moved

return resampled
output_dtype = output_dtype or input_dtype
return resampled.astype(output_dtype)
7 changes: 5 additions & 2 deletions nitransforms/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests of the base module."""
import numpy as np
import nibabel as nb
from nibabel.arrayproxy import get_obj_dtype

import pytest
import h5py

Expand Down Expand Up @@ -97,7 +99,7 @@ def _to_hdf5(klass, x5_root):
fname = testdata_path / "someones_anatomy.nii.gz"

img = nb.load(fname)
imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype())
imgdata = np.asanyarray(img.dataobj, dtype=get_obj_dtype(img.dataobj))

# Test identity transform - setting reference
xfm = TransformBase()
Expand All @@ -111,7 +113,8 @@ def _to_hdf5(klass, x5_root):
xfm = nitl.Affine()
xfm.reference = fname
moved = apply(xfm, fname, order=0)
assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()))

assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=get_obj_dtype(moved.dataobj)))

# Test ndim returned by affine
assert nitl.Affine().ndim == 3
Expand Down

0 comments on commit fbb0451

Please sign in to comment.