Skip to content

Commit

Permalink
Merge pull request #195 from jmarabotto/sandbox
Browse files Browse the repository at this point in the history
ENH: Outsource ``apply()`` from transform objects
  • Loading branch information
oesteban authored May 17, 2024
2 parents f28cd14 + 5b1736b commit 8a18581
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 366 deletions.
107 changes: 9 additions & 98 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from nibabel import funcs as _nbfuncs
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image
from scipy import ndimage as ndi

EQUALITY_TOL = 1e-5

Expand Down Expand Up @@ -178,7 +177,10 @@ def __ne__(self, other):
class TransformBase:
"""Abstract image class to represent transforms."""

__slots__ = ("_reference", "_ndim",)
__slots__ = (
"_reference",
"_ndim",
)

def __init__(self, reference=None):
"""Instantiate a transform."""
Expand Down Expand Up @@ -222,101 +224,6 @@ def ndim(self):
"""Access the dimensions of the reference space."""
raise TypeError("TransformBase has no dimensions")

def apply(
self,
spatialimage,
reference=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.
Parameters
----------
spatialimage : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional
Determines how the input image is extended when the resamplings overflows
a border. Default is 'constant'.
cval : float, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: bool, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype: dtype specifier, optional
The dtype of the returned array or image, if specified.
If ``None``, the default behavior is to use the effective dtype of
the input image. If slope and/or intercept are defined, the effective
dtype is float64, otherwise it is equivalent to the input image's
``get_data_dtype()`` (on-disk type).
If ``reference`` is defined, then the return value is an image, with
a data array of the effective dtype but with the on-disk dtype set to
the input image's on-disk dtype.
Returns
-------
resampled : `spatialimage` or ndarray
The data imaged after resampling to reference space.
"""
if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

_ref = (
self.reference if reference is None else SpatialReference.factory(reference)
)

if _ref is None:
raise TransformError("Cannot apply transform without reference")

if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
)

resampled = ndi.map_coordinates(
data,
targets.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
hdr = None
if _ref.header is not None:
hdr = _ref.header.copy()
hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype())
moved = spatialimage.__class__(
resampled.reshape(_ref.shape),
_ref.affine,
hdr,
)
return moved

return resampled

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand Down Expand Up @@ -382,4 +289,8 @@ def _as_homogeneous(xyz, dtype="float32", dim=3):

def _apply_affine(x, affine, dim):
"""Get the image array's indexes corresponding to coordinates."""
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
return np.tensordot(
affine,
_as_homogeneous(x, dim=dim).T,
axes=1,
)[:dim, ...]
4 changes: 3 additions & 1 deletion nitransforms/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .linear import load as linload
from .nonlinear import load as nlinload
from .resampling import apply


def cli_apply(pargs):
Expand Down Expand Up @@ -38,7 +39,8 @@ def cli_apply(pargs):
# ensure a reference is set
xfm.reference = pargs.ref or pargs.moving

moved = xfm.apply(
moved = apply(
xfm,
pargs.moving,
order=pargs.order,
mode=pargs.mode,
Expand Down
125 changes: 4 additions & 121 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
import warnings
import numpy as np
from pathlib import Path
from scipy import ndimage as ndi

from nibabel.loadsave import load as _nbload
from nibabel.affines import from_matvec
from nibabel.arrayproxy import get_obj_dtype

from nitransforms.base import (
ImageGrid,
TransformBase,
SpatialReference,
_as_homogeneous,
EQUALITY_TOL,
)
Expand Down Expand Up @@ -113,6 +109,10 @@ def __invert__(self):
"""
return self.__class__(self._inverse)

def __len__(self):
"""Enable using len()."""
return 1 if self._matrix.ndim == 2 else len(self._matrix)

def __matmul__(self, b):
"""
Compose two Affines.
Expand Down Expand Up @@ -330,10 +330,6 @@ def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)

def __len__(self):
"""Enable using len()."""
return len(self._matrix)

def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand Down Expand Up @@ -402,119 +398,6 @@ def to_filename(self, filename, fmt="X5", moving=None):
).to_filename(filename)
return filename

def apply(
self,
spatialimage,
reference=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
):
"""
Apply a transformation to an image, resampling on the reference spatial object.
Parameters
----------
spatialimage : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional
Determines how the input image is extended when the resamplings overflows
a border. Default is "constant".
cval : float, optional
Constant value for ``mode="constant"``. Default is 0.0.
prefilter: bool, optional
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
Returns
-------
resampled : `spatialimage` or ndarray
The data imaged after resampling to reference space.
"""

if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

_ref = (
self.reference if reference is None else SpatialReference.factory(reference)
)

if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4").T

# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
)

dataobj = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
else None
)

for t, xfm_t in enumerate(self):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj
if dataobj is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
newdata = resampled.reshape(_ref.shape + (len(self),))
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
moved.header.set_data_dtype(output_dtype)
return moved

return resampled


def load(filename, fmt=None, reference=None, moving=None):
"""
Expand Down
Loading

0 comments on commit 8a18581

Please sign in to comment.