Source code for nitransforms.base

# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the NiBabel package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from pathlib import Path
import numpy as np
import h5py
import warnings
from nibabel.loadsave import load as _nbload
from nibabel import funcs as _nbfuncs
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image
from scipy import ndimage as ndi

# Relative tolerance (``rtol``) handed to ``np.allclose`` when
# ``ImageGrid.__eq__`` compares the affines of two grids.
EQUALITY_TOL = 1e-5


class TransformError(TypeError):
    """
    Exception raised for transform-specific failures.

    It derives from :class:`TypeError`, so handlers written for ``TypeError``
    also catch it.

    """
class SpatialReference:
    """Factory to create spatial references."""

    @staticmethod
    def factory(dataset):
        """
        Build the spatial reference object that matches ``dataset``.

        ``dataset`` is first handed to :class:`SampledSpatialData`; when that
        constructor raises ``ValueError`` the dataset is wrapped in an
        :class:`ImageGrid` instead. Any other exception raised by either
        constructor propagates to the caller.

        Parameters
        ----------
        dataset : object
            Whatever :class:`SampledSpatialData` or :class:`ImageGrid` accept.

        Returns
        -------
        reference : SampledSpatialData or ImageGrid
            The reference wrapping ``dataset``.

        """
        try:
            reference = SampledSpatialData(dataset)
        except ValueError:
            # Not an irregular sample: treat it as gridded data.
            reference = ImageGrid(dataset)
        return reference
class SampledSpatialData:
    """Represent sampled spatial data: regularly gridded (images) and surfaces."""

    __slots__ = ["_ndim", "_coords", "_npoints", "_shape"]

    def __init__(self, dataset):
        """
        Create a sampling reference.

        Parameters
        ----------
        dataset : SampledSpatialData, str, Path or loaded object
            * another ``SampledSpatialData``: its coordinates are copied;
            * a ``str``/``Path``: loaded through nibabel's ``load`` first;
            * an object exposing ``numDA`` (treated as Gifti): the arrays
              tagged with the ``pointset`` intent are stacked row-wise.

        Raises
        ------
        TypeError
            The Gifti-like object has no ``pointset`` arrays.
        NotImplementedError
            ``dataset`` is a ``Cifti2Image``.
        ValueError
            ``dataset`` matches none of the supported inputs.

        """
        self._shape = None

        if isinstance(dataset, SampledSpatialData):
            # Copy, so later changes to the source do not leak into this object.
            self._coords = dataset.ndcoords.copy()
        else:
            if isinstance(dataset, (str, Path)):
                dataset = _nbload(str(dataset))

            if not hasattr(dataset, "numDA"):
                # Without ``numDA`` this is not a Gifti-like object.
                if isinstance(dataset, Cifti2Image):
                    raise NotImplementedError
                raise ValueError(
                    "Dataset could not be interpreted as an irregular sample."
                )

            pointsets = dataset.get_arrays_from_intent(INTENT_CODES["pointset"])
            if not pointsets:
                raise TypeError(
                    "Input Gifti file does not contain reference coordinates."
                )
            self._coords = np.vstack([darray.data for darray in pointsets])

        # Coordinates are stored with one row per point, one column per axis.
        self._npoints, self._ndim = self._coords.shape

    @property
    def npoints(self):
        """Access the total number of sampled points."""
        return self._npoints

    @property
    def ndim(self):
        """Access the number of dimensions."""
        return self._ndim

    @property
    def ndcoords(self):
        """List the physical coordinates of this sample."""
        return self._coords

    @property
    def shape(self):
        """Access the space's size of each dimension (``None`` unless a subclass sets it)."""
        return self._shape
class ImageGrid(SampledSpatialData):
    """Class to represent spaces of gridded data (images)."""

    __slots__ = ["_affine", "_inverse", "_ndindex", "_header"]

    def __init__(self, image):
        """
        Create a gridded sampling reference.

        Parameters
        ----------
        image : str, Path or object exposing ``affine`` and ``shape``
            A ``str``/``Path`` is loaded with nibabel and squeezed first.
            Optional ``header``, ``ndim``, ``npoints`` and ``inverse``
            attributes are reused when the object provides them (which is
            the case for another ``ImageGrid``).

        """
        if isinstance(image, (str, Path)):
            image = _nbfuncs.squeeze_image(_nbload(str(image)))

        self._affine = image.affine
        self._shape = image.shape
        self._header = getattr(image, "header", None)

        self._ndim = getattr(image, "ndim", len(image.shape))
        if self._ndim >= 4:
            # Only the first three axes define the sampling grid.
            self._shape = image.shape[:3]
            self._ndim = 3

        self._npoints = getattr(image, "npoints", np.prod(self._shape))
        # Index and coordinate tables are built lazily on first access.
        self._ndindex = None
        self._coords = None

        # Invert the affine only when the source does not already carry it.
        inverse = getattr(image, "inverse", None)
        self._inverse = np.linalg.inv(image.affine) if inverse is None else inverse

    @property
    def affine(self):
        """Access the indexes-to-RAS affine."""
        return self._affine

    @property
    def header(self):
        """Access the original reference's header (``None`` if it had none)."""
        return self._header

    @property
    def inverse(self):
        """Access the RAS-to-indexes affine."""
        return self._inverse

    @property
    def ndindex(self):
        """List the indexes corresponding to the space grid (one column per point)."""
        if self._ndindex is None:
            axes = [np.arange(size) for size in self._shape]
            self._ndindex = np.array(np.meshgrid(*axes, indexing="ij")).reshape(
                self._ndim, self._npoints
            )
        return self._ndindex

    @property
    def ndcoords(self):
        """
        List the physical coordinates of this gridded space samples.

        The grid indexes are extended with a row of ones, multiplied by the
        affine, and the first three rows of the product are kept, so the
        result has one column per grid point. It is computed once and cached.

        """
        if self._coords is None:
            homogeneous = np.vstack((self.ndindex, np.ones((1, self._npoints))))
            self._coords = np.tensordot(self._affine, homogeneous, axes=1)[:3, ...]
        return self._coords

    def ras(self, ijk):
        """Get RAS+ coordinates from input indexes."""
        return _apply_affine(ijk, self._affine, self._ndim)

    def index(self, x):
        """Get the image array's indexes corresponding to coordinates."""
        return _apply_affine(x, self._inverse, self._ndim)

    def _to_hdf5(self, group):
        """Write the grid type, dimensionality, affine and shape into ``group``."""
        group.attrs["Type"] = "image"
        group.attrs["ndim"] = self.ndim
        group.create_dataset("affine", data=self.affine)
        group.create_dataset("shape", data=self.shape)

    def __eq__(self, other):
        """
        Overload equals operator.

        Two grids are equal when their affines agree within ``EQUALITY_TOL``
        (relative tolerance) and their shapes are identical. Operands that
        expose neither ``affine`` nor ``shape`` yield ``NotImplemented`` so
        Python falls back to its default comparison instead of raising
        ``AttributeError``.

        """
        try:
            other_affine = other.affine
            other_shape = other.shape
        except AttributeError:
            return NotImplemented
        return (
            np.allclose(self.affine, other_affine, rtol=EQUALITY_TOL)
            and self.shape == other_shape
        )

    def __ne__(self, other):
        """Overload not equal operator."""
        return not self == other
class TransformBase:
    """Abstract image class to represent transforms."""

    __slots__ = ("_reference", "_ndim",)

    def __init__(self, reference=None):
        """
        Instantiate a transform.

        Parameters
        ----------
        reference : object, optional
            When truthy, it is wrapped in an ``ImageGrid`` and stored as the
            reference space of this transform.

        """
        self._reference = None
        if reference:
            self.reference = reference

    def __call__(self, x, inverse=False):
        """Apply y = f(x) by delegating to :meth:`map`."""
        return self.map(x, inverse=inverse)

    def __add__(self, b):
        """
        Compose this and other transforms.

        Example
        -------
        >>> T1 = TransformBase()
        >>> added = T1 + TransformBase()
        >>> len(added.transforms)
        2

        """
        # Imported here rather than at module level.
        from .manip import TransformChain

        return TransformChain(transforms=[self, b])

    @property
    def reference(self):
        """
        Access a reference space where data will be resampled onto.

        Emits a warning (and returns ``None``) when no reference was set.

        """
        if self._reference is None:
            warnings.warn("Reference space not set")
        return self._reference

    @reference.setter
    def reference(self, image):
        self._reference = ImageGrid(image)

    @property
    def ndim(self):
        """Access the dimensions of the reference space (always raises here)."""
        raise TypeError("TransformBase has no dimensions")

    def apply(
        self,
        spatialimage,
        reference=None,
        order=3,
        mode="constant",
        cval=0.0,
        prefilter=True,
        output_dtype=None,
    ):
        """
        Apply a transformation to an image, resampling on the reference spatial object.

        Parameters
        ----------
        spatialimage : `spatialimage`
            The image object containing the data to be resampled in reference
            space. A ``str``/``Path`` is loaded with nibabel first.
        reference : spatial object, optional
            The image, surface, or combination thereof containing the coordinates
            of samples that will be sampled. Defaults to the transform's own
            reference.
        order : int, optional
            The order of the spline interpolation, default is 3.
            The order has to be in the range 0-5.
        mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional
            Determines how the input image is extended when the resampling
            overflows a border. Default is 'constant'.
        cval : float, optional
            Constant value for ``mode='constant'``. Default is 0.0.
        prefilter: bool, optional
            Determines if the image's data array is prefiltered with
            a spline filter before interpolation. The default is ``True``,
            which will create a temporary *float64* array of filtered values
            if *order > 1*. If setting this to ``False``, the output will be
            slightly blurred if *order > 1*, unless the input is prefiltered,
            i.e. it is the result of calling the spline filter on the original
            input.
        output_dtype: dtype specifier, optional
            The dtype of the returned array or image, if specified.
            If ``None``, the default behavior is to use the effective dtype of
            the input image. If slope and/or intercept are defined, the effective
            dtype is float64, otherwise it is equivalent to the input image's
            ``get_data_dtype()`` (on-disk type).
            If ``reference`` is defined, then the return value is an image, with
            a data array of the effective dtype but with the on-disk dtype set to
            the input image's on-disk dtype.

        Returns
        -------
        resampled : `spatialimage` or ndarray
            The data imaged after resampling to reference space. An image of
            the same class as ``spatialimage`` when the reference is an
            ``ImageGrid``; the bare resampled array otherwise.

        Raises
        ------
        TransformError
            Neither ``reference`` nor the transform's own reference is set.

        """
        if isinstance(reference, (str, Path)):
            reference = _nbload(str(reference))

        _ref = (
            self.reference if reference is None else SpatialReference.factory(reference)
        )

        if _ref is None:
            raise TransformError("Cannot apply transform without reference")

        if isinstance(spatialimage, (str, Path)):
            spatialimage = _nbload(str(spatialimage))

        data = np.asanyarray(spatialimage.dataobj)
        # Map the reference coordinates through the transform, then convert
        # them into array indexes of the input image.
        targets = ImageGrid(spatialimage).index(  # data should be an image
            _as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
        )

        resampled = ndi.map_coordinates(
            data,
            targets.T,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )

        if isinstance(_ref, ImageGrid):  # If reference is grid, reshape
            hdr = None
            if _ref.header is not None:
                hdr = _ref.header.copy()
                # Test against None explicitly: ``np.dtype`` instances of
                # unstructured types are falsy, so ``output_dtype or ...``
                # would silently discard a dtype passed as ``np.dtype(...)``.
                hdr.set_data_dtype(
                    spatialimage.get_data_dtype()
                    if output_dtype is None
                    else output_dtype
                )
            moved = spatialimage.__class__(
                resampled.reshape(_ref.shape),
                _ref.affine,
                hdr,
            )
            return moved

        return resampled

    def map(self, x, inverse=False):
        r"""
        Apply :math:`y = f(x)`.

        TransformBase implements the identity transform.

        Parameters
        ----------
        x : N x D numpy.ndarray
            Input RAS+ coordinates (i.e., physical coordinates).
        inverse : bool
            If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.

        Returns
        -------
        y : N x D numpy.ndarray
            Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).

        """
        return x

    def to_filename(self, filename, fmt="X5"):
        """
        Store the transform in BIDS-Transforms HDF5 file format (.x5).

        Parameters
        ----------
        filename : str or path-like
            Destination file; it is opened in write mode.
        fmt : str, optional
            Accepted for interface compatibility; this implementation does not
            consult it and always writes X5.

        Returns
        -------
        filename
            The ``filename`` argument, unchanged.

        """
        with h5py.File(filename, "w") as out_file:
            out_file.attrs["Format"] = "X5"
            out_file.attrs["Version"] = np.uint16(1)
            root = out_file.create_group("/0")
            self._to_hdf5(root)

        return filename

    def _to_hdf5(self, x5_root):
        """Serialize this object into the x5 file format."""
        raise NotImplementedError
def _as_homogeneous(xyz, dtype="float32", dim=3):
    """
    Append a trailing one to each coordinate, unless it is already there.

    The input is promoted to a 2D array of ``dtype`` (one row per point).
    Rows that already have ``dim + 1`` entries are returned untouched.

    Examples
    --------
    >>> _as_homogeneous((4, 5), dtype='int8', dim=2).tolist()
    [[4, 5, 1]]

    >>> _as_homogeneous((4, 5, 6),dtype='int8').tolist()
    [[4, 5, 6, 1]]

    >>> _as_homogeneous((4, 5, 6, 1),dtype='int8').tolist()
    [[4, 5, 6, 1]]

    >>> _as_homogeneous([(1, 2, 3), (4, 5, 6)]).tolist()
    [[1.0, 2.0, 3.0, 1.0], [4.0, 5.0, 6.0, 1.0]]


    """
    points = np.atleast_2d(np.array(xyz, dtype=dtype))
    if points.shape[-1] == dim + 1:
        # Already homogeneous.
        return points

    ones_column = np.ones((points.shape[0], 1), dtype=dtype)
    return np.concatenate((points, ones_column), axis=1)


def _apply_affine(x, affine, dim):
    """
    Map coordinates through ``affine`` and keep the first ``dim`` components.

    ``x`` is converted to homogeneous coordinates, multiplied by ``affine``
    and returned with one row per input point. Used for both directions
    (indexes-to-RAS and RAS-to-indexes), depending on the affine supplied.

    """
    homogeneous = _as_homogeneous(x, dim=dim)
    mapped = affine.dot(homogeneous.T)
    return mapped[:dim, ...].T