# Source code for nitransforms.nonlinear
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
# See COPYING file distributed along with the NiBabel package for the
# copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Nonlinear transforms."""
import warnings
from functools import partial
from collections import namedtuple
import numpy as np
import nibabel as nb
from nitransforms import io
from nitransforms.io.base import _ensure_image
from nitransforms.io.x5 import from_filename as load_x5
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
from nitransforms.base import (
TransformBase,
TransformError,
ImageGrid,
_as_homogeneous,
)
from scipy.ndimage import map_coordinates
# Avoids circular imports
try:
from nitransforms._version import __version__
except ModuleNotFoundError: # pragma: no cover
__version__ = "0+unknown"
[docs]
class DenseFieldTransform(TransformBase):
    """Represents dense field (voxel-wise) transforms."""

    __slots__ = ("_field", "_deltas", "_is_deltas")

    def __init__(self, field=None, is_deltas=True, reference=None):
        """
        Create a dense field transform.

        Converting to a field of deformations is straightforward by just adding the corresponding
        displacement to the :math:`(x, y, z)` coordinates of each voxel.
        Numerically, deformation fields are less susceptible to rounding errors
        than displacements fields.
        SPM generally prefers deformations for that reason.

        Parameters
        ----------
        field : :obj:`numpy.array_like` or :obj:`nibabel.SpatialImage`
            The field of deformations or displacements (*deltas*). If given as a data array,
            then the reference **must** be given.
        is_deltas : :obj:`bool`
            Whether this is a displacements (deltas) field (default), or deformations.
        reference : :obj:`ImageGrid`
            Defines the domain of the transform. If not provided, the domain is defined from
            the ``field`` input.

        Raises
        ------
        TransformError
            If neither ``field`` nor ``reference`` are given, if a grid cannot be derived
            from them, or if the shape of the field does not match the reference.

        Example
        -------
        >>> DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
        <DenseFieldTransform[3D] (57, 67, 56)>

        """
        if field is None and reference is None:
            raise TransformError("cannot initialize field")

        super().__init__()

        if field is not None:
            field = _ensure_image(field)
            # Extract data if nibabel object otherwise assume numpy array
            _data = np.squeeze(
                np.asanyarray(field.dataobj)
                if hasattr(field, "dataobj")
                else field.copy()
            )

        try:
            self.reference = ImageGrid(reference if reference is not None else field)
        except AttributeError as err:
            # Keep the original error chained for easier debugging
            raise TransformError(
                "field must be a spatial image if reference is not provided"
                if reference is None
                else "reference is not a spatial image"
            ) from err

        fieldshape = (*self.reference.shape, self.reference.ndim)
        if field is None:
            # No field given: start from an all-zeros array on the reference grid
            _data = np.zeros(fieldshape)
        elif fieldshape != _data.shape:
            raise TransformError(
                f"Shape of the field ({'x'.join(str(i) for i in _data.shape)}) "
                f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
            )

        self._is_deltas = is_deltas

        if self.is_deltas:
            self._deltas = _data.copy()
            # Out-of-place sum on purpose: the array returned by ``reference.ndcoords``
            # may be shared with the reference object (TODO confirm), and ``reshape``
            # can return a view, so an in-place ``+=`` could alter the grid coordinates.
            self._field = self.reference.ndcoords.reshape(fieldshape) + self._deltas
        else:
            # Deformations are stored as given; ``_deltas`` stays unset in this case
            self._field = _data.copy()

    def __repr__(self):
        """Beautify the python representation."""
        # The last axis holds the vector components; all leading axes are spatial
        return (
            f"<{self.__class__.__name__}[{self._field.shape[-1]}D] "
            f"{self._field.shape[:-1]}>"
        )

    @property
    def is_deltas(self):
        """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
        return self._is_deltas

    @property
    def ndim(self):
        """Get the dimensions of the transform."""
        return self._field.ndim - 1

    def map(self, x, inverse=False):
        r"""
        Apply the transformation to a list of physical coordinate points.

        .. math::
            \mathbf{y} = \mathbf{x} + \Delta(\mathbf{x}),
            \label{eq:2}\tag{2}

        where :math:`\Delta(\mathbf{x})` is the value of the discrete field of displacements
        :math:`\Delta` interpolated at the location :math:`\mathbf{x}`.

        Parameters
        ----------
        x : N x D :obj:`numpy.array_like`
            Input RAS+ coordinates (i.e., physical coordinates).
        inverse : :obj:`bool`
            If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.

        Returns
        -------
        y : N x D :obj:`numpy.array_like`
            Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
            Points falling outside of the field's grid are returned unchanged.

        Raises
        ------
        NotImplementedError
            If ``inverse`` is ``True``.

        Examples
        --------
        >>> xfm = DenseFieldTransform(
        ...     test_dir / "someones_displacement_field.nii.gz",
        ...     is_deltas=False,
        ... )
        >>> xfm.map([[-6.5, -36., -19.5]]).tolist()
        [[0.0, -0.47516798973083496, 0.0]]

        >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
        [[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]

        >>> np.array_str(
        ...     xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]),
        ...     precision=3,
        ...     suppress_small=True,
        ... )
        '[[ 0.    -0.482  0.   ]\n [ 0.    -0.538  0.   ]]'

        >>> xfm = DenseFieldTransform(
        ...     test_dir / "someones_displacement_field.nii.gz",
        ...     is_deltas=True,
        ... )
        >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()  # doctest: +ELLIPSIS
        [[-6.5, -36.475..., -19.5], [-1.0, -42.038..., -11.25]]

        >>> np.array_str(
        ...     xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]),
        ...     precision=3,
        ...     suppress_small=True,
        ... )
        '[[ -6.7   -36.782 -19.2  ]\n [ -1.    -42.038 -11.25 ]]'

        """
        if inverse is True:
            raise NotImplementedError

        ijk = self.reference.index(np.array(x, dtype="float32"))
        indexes = np.round(ijk).astype("int")
        ongrid = np.where(np.linalg.norm(ijk - indexes, axis=1) < 1e-3)[0]

        if ongrid.size == np.shape(x)[0]:
            # Direct lookup is only valid inside the grid: negative indexes would
            # silently wrap around to the opposite side of the field, and indexes
            # beyond the last sample would raise IndexError. Out-of-grid points fall
            # through to the interpolation below, which leaves them unchanged.
            grid_shape = np.array(self._field.shape[:-1])
            if np.all((indexes >= 0) & (indexes < grid_shape)):
                # return self._field[*indexes.T, :]  # From Python 3.11
                return self._field[tuple(indexes.T) + (np.s_[:],)]

        mapped_coords = np.vstack(
            tuple(
                map_coordinates(
                    self._field[..., i],
                    ijk.T,
                    order=3,
                    mode="constant",
                    cval=np.nan,
                    prefilter=True,
                )
                for i in range(self.reference.ndim)
            )
        ).T

        # Set NaN values back to the original coordinates value = no displacement
        isnan = np.isnan(mapped_coords)
        mapped_coords[isnan] = np.array(x)[isnan]
        return mapped_coords

    def __matmul__(self, b):
        """
        Compose with a transform on the right.

        The coordinates stored in this field are mapped through ``b`` and the result
        is wrapped into a new deformation field defined on the same reference.

        Examples
        --------
        >>> deff = DenseFieldTransform(
        ...     test_dir / "someones_displacement_field.nii.gz",
        ...     is_deltas=False,
        ... )
        >>> deff2 = deff @ TransformBase()
        >>> deff == deff2
        True

        >>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
        >>> disp2 = disp @ TransformBase()
        >>> disp == disp2
        True

        """
        # Flatten to a list of points for ``b.map`` and restore the grid layout after
        retval = b.map(self._field.reshape((-1, self._field.shape[-1]))).reshape(
            self._field.shape
        )
        return DenseFieldTransform(retval, is_deltas=False, reference=self.reference)

    def __eq__(self, other):
        """
        Overload equals operator.

        Two transforms compare equal when their fields have the same shape and are
        numerically close. A warning is issued when the fields match but the
        references do not.

        Examples
        --------
        >>> xfm1 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
        >>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
        >>> xfm1 == xfm2
        True

        >>> xfm1 == TransformBase()
        False

        >>> xfm1 == BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
        False

        """
        if not hasattr(other, "_field") or self._field.shape != other._field.shape:
            return False

        _eq = np.allclose(self._field, other._field)
        if _eq and self._reference != other._reference:
            warnings.warn("Fields are equal, but references do not match.")
        return _eq

    def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
        """
        Store the transform in the designated format.

        Parameters
        ----------
        filename : :obj:`str` or path-like
            Output file name.
        fmt : :obj:`str`
            One of ``"afni"``, ``"itk"`` (aliases ``"ants"``, ``"elastix"``) or ``"fsl"``,
            case-insensitive.
        moving
            Not used by this implementation.
        x5_inverse : :obj:`bool`
            Not used by this implementation.

        Raises
        ------
        TypeError
            If ``fmt`` is ``"X5"`` (the default); use :meth:`to_x5` instead.
        NotImplementedError
            If ``fmt`` is not one of the supported formats.

        """
        if fmt.upper() == "X5":
            raise TypeError("Please use .to_x5()")

        # NOTE(review): when ``is_deltas`` is False the deformations are handed over to
        # the ``*DisplacementsField`` writers below as they are -- confirm this is intended.
        field = nb.Nifti1Image(
            self._deltas if self.is_deltas else self._field,
            self.reference.affine,
            None,
        )

        fmt_lower = fmt.lower()
        # Writers are imported lazily, only for the format actually requested
        if fmt_lower == "afni":
            from nitransforms.io.afni import AFNIDisplacementsField as FieldIOType
        elif fmt_lower in ("itk", "ants", "elastix"):
            from nitransforms.io.itk import ITKDisplacementsField as FieldIOType
        elif fmt_lower == "fsl":
            from nitransforms.io.fsl import FSLDisplacementsField as FieldIOType
        else:
            raise NotImplementedError(
                f"Dense field of type '{fmt}' cannot be converted."
            )

        FieldIOType.to_image(field).to_filename(filename)

    def to_x5(self, metadata=None):
        """
        Return an :class:`~nitransforms.io.x5.X5Transform` representation.

        Parameters
        ----------
        metadata : :obj:`dict`, optional
            Extra metadata entries; they take precedence over the default
            ``"WrittenBy"`` entry.

        """
        metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})

        domain = None
        if (reference := self.reference) is not None:
            domain = io.x5.X5Domain(
                grid=True,
                size=getattr(reference, "shape", (0, 0, 0)),
                mapping=reference.affine,
                coordinates="cartesian",
            )

        # One "space" axis per spatial dimension, followed by the vector axis
        kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)

        return io.x5.X5Transform(
            type="nonlinear",
            subtype="densefield",
            representation="displacements" if self.is_deltas else "deformations",
            metadata=metadata,
            transform=self._deltas if self.is_deltas else self._field,
            dimension_kinds=kinds,
            domain=domain,
        )

    @classmethod
    def from_filename(cls, filename, fmt="X5", x5_position=0):
        """
        Load a dense field transform from a file.

        Parameters
        ----------
        filename : :obj:`str` or path-like
            File to read.
        fmt : :obj:`str`
            ``"X5"`` (default), ``"afni"``, ``"itk"`` (aliases ``"ants"``, ``"elastix"``,
            mirroring those accepted by :meth:`to_filename`) or ``"fsl"``, case-insensitive.
        x5_position : :obj:`int`
            Index of the transform to pick from an X5 file.

        Raises
        ------
        NotImplementedError
            If ``fmt`` is not one of the supported formats.

        """
        fmt = fmt.upper()
        if fmt == "X5":
            return from_x5(load_x5(filename), x5_position=x5_position)

        _factory = {
            "AFNI": io.afni.AFNIDisplacementsField,
            "ITK": io.itk.ITKDisplacementsField,
            "ANTS": io.itk.ITKDisplacementsField,
            "ELASTIX": io.itk.ITKDisplacementsField,
            "FSL": io.fsl.FSLDisplacementsField,
        }
        if fmt not in _factory:
            raise NotImplementedError(f"Unsupported format <{fmt}>")

        return cls(_factory[fmt].from_filename(filename))
# Module-level shortcut: ``load(...)`` is ``DenseFieldTransform.from_filename(...)``.
load = DenseFieldTransform.from_filename
[docs]
class BSplineFieldTransform(TransformBase):
    """Represent a nonlinear transform parameterized by BSpline basis."""

    __slots__ = ["_coeffs", "_knots", "_weights", "_order", "_moving"]

    def __init__(self, coefficients, reference=None, order=3):
        """
        Create a smooth deformation field using B-Spline basis.

        Parameters
        ----------
        coefficients : spatial image or path
            Image holding the B-Spline coefficients; its grid defines the knots.
        reference : optional
            Defines the domain of the transform.
        order : :obj:`int`
            Order of the B-Spline basis (default: 3).

        Raises
        ------
        TransformError
            If a reference is given and the number of components of the coefficients
            differs from the number of dimensions of the reference.

        """
        super().__init__()
        self._order = order

        coefficients = _ensure_image(coefficients)
        self._coeffs = np.asanyarray(coefficients.dataobj)
        self._knots = ImageGrid(coefficients)
        # Lazily computed by ``to_field`` for the transform's own reference
        self._weights = None

        if reference is not None:
            self.reference = reference

            if coefficients.shape[-1] != self.reference.ndim:
                raise TransformError(
                    "Number of components of the coefficients does "
                    "not match the number of dimensions"
                )

    def __eq__(self, other):
        """
        Overload equals operator.

        Two transforms compare equal when they have the same order and their
        coefficients have the same shape and are numerically close. A warning is
        issued when they match but the references do not.

        Examples
        --------
        >>> xfm1 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
        >>> xfm2 = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
        >>> xfm1 == xfm2
        True

        >>> xfm2._coeffs[:, :, :] = 0  # Let's zero all coefficients
        >>> xfm1 == xfm2
        False

        >>> xfm2 = BSplineFieldTransform(
        ...     test_dir / "someones_bspline_coefficients.nii.gz",
        ...     order=4,
        ... )
        >>> xfm1 == xfm2
        False

        >>> xfm1 == TransformBase()
        False

        >>> xfm1 == DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
        False

        """
        if not hasattr(other, "_coeffs") or self._coeffs.shape != other._coeffs.shape:
            return False

        _eq = self._order == other._order
        _eq = _eq and np.allclose(self._coeffs, other._coeffs)
        if _eq and self._reference != other._reference:
            warnings.warn("Coefficients are equal, but references do not match.")
        return _eq

    @property
    def ndim(self):
        """Get the dimensions of the transform."""
        return self._coeffs.ndim - 1

    @classmethod
    def from_filename(cls, filename, fmt="X5", x5_position=0):
        """
        Load a transform from a file (only the X5 format is supported).

        Parameters
        ----------
        filename : :obj:`str` or path-like
            File to read.
        fmt : :obj:`str`
            Must be ``"X5"`` (case-insensitive).
        x5_position : :obj:`int`
            Index of the transform to pick from the X5 file.

        Raises
        ------
        NotImplementedError
            If ``fmt`` is not ``"X5"``.

        """
        fmt = fmt.upper()
        if fmt != "X5":
            raise NotImplementedError(f"Unsupported format <{fmt}>")

        return from_x5(load_x5(filename), x5_position=x5_position)

    def to_field(self, reference=None, dtype="float32"):
        """
        Generate a dense displacements field from this B-Spline field.

        Parameters
        ----------
        reference : spatial image or path, optional
            Grid on which the field is evaluated; defaults to the transform's reference.
        dtype : data-type
            Data type of the generated field (default: ``"float32"``).

        Returns
        -------
        :obj:`DenseFieldTransform`
            The field evaluated on the selected grid.

        Raises
        ------
        TransformError
            If no reference is available.

        """
        _ref = (
            self.reference if reference is None else ImageGrid(_ensure_image(reference))
        )
        if _ref is None:
            raise TransformError("A reference must be defined")

        if reference is None:
            # Weights for the transform's own reference are computed once and cached
            if self._weights is None:
                self._weights = grid_bspline_weights(_ref, self._knots)
            weights = self._weights
        else:
            # Weights depend on the sampling grid: never reuse (nor overwrite) the
            # cached ones when the caller supplies an explicit reference.
            weights = grid_bspline_weights(_ref, self._knots)

        field = np.zeros((_ref.npoints, self.ndim))
        for d in range(self.ndim):
            #  1 x Nvox :          (1 x K) @ (K x Nvox)
            field[:, d] = self._coeffs[..., d].reshape(-1) @ weights

        return DenseFieldTransform(
            field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
        )

    def to_x5(self, metadata=None):
        """
        Return an :class:`~nitransforms.io.x5.X5Transform` representation.

        Parameters
        ----------
        metadata : :obj:`dict`, optional
            Extra metadata entries; they take precedence over the default
            ``"WrittenBy"`` entry.

        """
        metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})

        domain = None
        if (reference := self.reference) is not None:
            domain = io.x5.X5Domain(
                grid=True,
                size=getattr(reference, "shape", (0, 0, 0)),
                mapping=reference.affine,
                coordinates="cartesian",
            )

        # One "space" axis per spatial dimension, followed by the vector axis
        kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)

        return io.x5.X5Transform(
            type="nonlinear",
            subtype="bspline",
            representation="coefficients",
            metadata=metadata,
            transform=self._coeffs,
            dimension_kinds=kinds,
            domain=domain,
            # The affine of the knots grid travels along with the coefficients
            additional_parameters=self._knots.affine,
        )

    def map(self, x, inverse=False):
        r"""
        Apply the transformation to a list of physical coordinate points.

        .. math::
            \mathbf{y} = \mathbf{x} + \Psi^3(\mathbf{k}, \mathbf{x}),
            \label{eq:1}\tag{1}

        Parameters
        ----------
        x : N x D numpy.ndarray
            Input RAS+ coordinates (i.e., physical coordinates).
        inverse : bool
            If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.

        Returns
        -------
        y : N x D numpy.ndarray
            Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).

        Raises
        ------
        NotImplementedError
            If ``inverse`` is ``True``.

        Examples
        --------
        >>> xfm = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
        >>> xfm.reference = test_dir / "someones_anatomy.nii.gz"
        >>> xfm.map([-6.5, -36., -19.5]).tolist()  # doctest: +ELLIPSIS
        [[-6.5, -36.475114..., -19.5]]

        >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()  # doctest: +ELLIPSIS
        [[-6.5, -36.475114..., -19.5], [-1.0, -42.03878957..., -11.25]]

        """
        # Only the forward mapping is computed below; fail loudly rather than
        # returning it when the inverse is requested (same as DenseFieldTransform.map).
        if inverse is True:
            raise NotImplementedError

        vfunc = partial(
            _map_xyz,
            reference=self.reference,
            knots=self._knots,
            coeffs=self._coeffs,
        )
        # Points are mapped one at a time; a single point is promoted to 1 x D
        return np.array([vfunc(_x).tolist() for _x in np.atleast_2d(x)])
[docs]
def from_x5(x5_list, x5_position=0):
    """
    Create a transform from a list of :class:`~nitransforms.io.x5.X5Transform` objects.

    Parameters
    ----------
    x5_list : :obj:`list`
        The X5 transforms, e.g., as read from an X5 file.
    x5_position : :obj:`int`
        Index of the element of ``x5_list`` to convert (default: 0).

    Returns
    -------
    :obj:`BSplineFieldTransform` or :obj:`DenseFieldTransform`
        A B-Spline transform if the subtype of the selected element is ``"bspline"``,
        a dense field otherwise.

    """
    x5_xfm = x5_list[x5_position]
    is_bspline = x5_xfm.subtype == "bspline"

    # ``to_x5`` stores ``domain=None`` for transforms without a reference, so the
    # domain is optional here; the constructors decide whether they can do without.
    reference = None
    if (domain := x5_xfm.domain) is not None:
        Domain = namedtuple("Domain", "affine shape")
        reference = Domain(domain.mapping, domain.size)

    if is_bspline:
        # Rebuild the coefficients image: the data plus the affine of the knots grid
        coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
        return BSplineFieldTransform(coefficients, reference=reference)

    return DenseFieldTransform(
        x5_xfm.transform,
        reference=reference,
        is_deltas=x5_xfm.representation == "displacements",
    )
def _map_xyz(x, reference, knots, coeffs):
    """
    Map a single point through the B-Spline field.

    Parameters
    ----------
    x : sequence of D floats
        Physical coordinates of the point.
    reference
        Not used by this function.
    knots
        Grid of the B-Spline knots; its ``inverse`` and ``shape`` are read.
    coeffs : :obj:`numpy.ndarray`
        Coefficients array, with the components of the vector on the last axis.

    Returns
    -------
    The input coordinates plus the interpolated displacement.

    """
    dims = len(x)

    # Where the point falls on the knots grid, in (continuous) index coordinates
    loc = (knots.inverse @ _as_homogeneous(x).squeeze())[:dims]

    # Knots within distance 2.0 of the point (the span over which the cubic
    # B-Spline is taken to be nonzero; a different order would need another span)
    lower = np.ceil(loc - 2).astype(int)
    upper = np.floor(loc + 2).astype(int)

    # Per-axis index ranges of that window, clipped to the extent of the grid
    axes = [
        np.arange(max(lo, 0), min(hi, extent - 1) + 1)
        for lo, hi, extent in zip(lower, upper, knots.shape)
    ]
    support = tuple(np.meshgrid(*axes, indexing="ij"))
    knot_ijk = np.array(support).reshape((dims, -1))

    # Per-axis absolute distances (in knot-grid units) from the point to each knot
    offsets = np.abs(knot_ijk.T - loc)

    # On a regular grid only a handful of distinct distances occur: evaluate the
    # basis once per distinct value and scatter the results back
    levels, lookup = np.unique(offsets.reshape(-1), return_inverse=True)
    basis = _cubic_bspline(levels)[lookup].reshape(offsets.shape)

    # Tensor-product B-Spline: multiply the per-axis weights of each knot
    weights = basis.prod(1)

    # Coefficients of the knots inside the window, one row per knot
    local_coeffs = coeffs[support].reshape((-1, dims))

    # Displacement = weighted sum of the coefficients, added to the input location
    return x + local_coeffs.T @ weights