Source code for nitransforms.manip

# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the NiBabel package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""

import os
from collections.abc import Iterable
import numpy as np

import h5py
from nitransforms.base import (
    TransformBase,
    TransformError,
)
from nitransforms.io import itk, x5 as x5io
from nitransforms.io.x5 import from_filename as load_x5
from nitransforms.linear import (  # noqa: F401
    Affine,
    from_x5 as linear_from_x5,
)
from nitransforms.nonlinear import (  # noqa: F401
    DenseFieldTransform,
    from_x5 as nonlinear_from_x5,
)


class TransformChain(TransformBase):
    """
    Implements the concatenation of transforms.

    The chain keeps an ordered list of transforms in ``_transforms``
    (``None`` until some transforms are assigned). Mapping coordinates
    applies the transforms one after another, in list order.
    """

    __slots__ = ("_transforms",)

    def __init__(self, transforms=None):
        """
        Initialize a chain of transforms.

        Parameters
        ----------
        transforms : transform, chain, or iterable of transforms, optional
            Initial contents of the chain. When ``None`` the chain is
            left empty.

        """
        super().__init__()
        self._transforms = None
        if transforms is not None:
            self.transforms = transforms

    def __add__(self, b):
        """
        Compose this and other transforms.

        Note that ``b`` is appended to this chain *in place* and the
        chain itself is returned; no new chain object is created.

        Example
        -------
        >>> T1 = TransformBase()
        >>> added = T1 + TransformBase() + TransformBase()
        >>> isinstance(added, TransformChain)
        True

        >>> len(added.transforms)
        3

        """
        self.append(b)
        return self

    def __getitem__(self, i):
        """
        Enable indexed access of transform chains.

        Example
        -------
        >>> T1 = TransformBase()
        >>> chain = T1 + TransformBase()
        >>> chain[0] is T1
        True

        """
        return self.transforms[i]

    def __len__(self):
        """Enable using len(); an empty chain has length zero."""
        return len(self._transforms or ())

    @property
    def ndim(self):
        """Get the number of dimensions (the largest among the chained transforms)."""
        return max(x.ndim for x in self._transforms)

    @property
    def transforms(self):
        """Get the internal list of transforms."""
        return self._transforms

    @transforms.setter
    def transforms(self, value):
        chain = _as_chain(value)
        # Store a private copy so that two chains never share one list
        # (otherwise appending to one would silently modify the other).
        self._transforms = None if chain is None else list(chain)
        # An empty chain has no first transform to inherit a reference from.
        if self._transforms:
            first = self._transforms[0]
            if first.reference:
                self.reference = first.reference

    def append(self, x):
        """
        Concatenate one element to the chain.

        Parameters
        ----------
        x : transform, chain, or iterable of transforms
            The element(s) to add at the end of this chain.

        Example
        -------
        >>> chain = TransformChain(transforms=TransformBase())
        >>> chain.append((TransformBase(), TransformBase()))
        >>> len(chain)
        3

        """
        # ``or []`` lets an empty chain (``_transforms is None``) grow too.
        self.transforms = (self._transforms or []) + (_as_chain(x) or [])

    def insert(self, i, x):
        """
        Insert an item at a given position.

        Parameters
        ----------
        i : :obj:`int`
            Position before which the new element(s) are inserted.
        x : transform, chain, or iterable of transforms
            The element(s) to insert.

        Example
        -------
        >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
        >>> chain.insert(1, TransformBase())
        >>> len(chain)
        3

        >>> chain.insert(1, TransformChain(chain))
        >>> len(chain)
        6

        """
        current = self._transforms or []
        self.transforms = current[:i] + (_as_chain(x) or []) + current[i:]

    def map(self, x, inverse=False):
        """
        Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.

        Parameters
        ----------
        x : coordinates
            The input forwarded to the ``map`` method of each transform.
        inverse : :obj:`bool`
            When ``True``, the transforms are visited in reverse order and
            each one is asked for its inverse mapping.

        Raises
        ------
        TransformError
            If the chain is empty.

        Example
        -------
        >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
        >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
        [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

        >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
        [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

        >>> TransformChain()((0., 0., 0.))  # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        TransformError:

        """
        if not self.transforms:
            raise TransformError("Cannot apply an empty transforms chain.")

        transforms = reversed(self.transforms) if inverse else self.transforms
        for xfm in transforms:
            x = xfm.map(x, inverse=inverse)

        return x

    def asaffine(self, indices=None):
        """
        Combine a succession of linear transforms into one.

        Example
        -------
        >>> chain = TransformChain(transforms=[
        ...     Affine.from_matvec(vec=(2, -10, 3)),
        ...     Affine.from_matvec(vec=(-2, 10, -3)),
        ... ])
        >>> chain.asaffine()
        array([[1., 0., 0., 0.],
               [0., 1., 0., 0.],
               [0., 0., 1., 0.],
               [0., 0., 0., 1.]])

        >>> chain = TransformChain(transforms=[
        ...     Affine.from_matvec(vec=(1, 2, 3)),
        ...     Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
        ... ])
        >>> chain.asaffine()
        array([[0., 1., 0., 2.],
               [0., 0., 1., 3.],
               [1., 0., 0., 1.],
               [0., 0., 0., 1.]])

        >>> np.allclose(
        ...     chain.map((4, -2, 1)),
        ...     chain.asaffine().map((4, -2, 1)),
        ... )
        True

        Parameters
        ----------
        indices : :obj:`numpy.array_like`
            The indices of the values to extract.

        Raises
        ------
        TransformError
            If the chain is empty.

        """
        if not self.transforms:
            raise TransformError("Cannot combine an empty transforms chain.")

        affines = (
            self.transforms if indices is None else np.take(self.transforms, indices)
        )
        # Later transforms are multiplied on the left of the running result.
        retval = affines[0]
        for xfm in affines[1:]:
            retval = xfm @ retval
        return retval

    @classmethod
    def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0):
        """
        Load a transform file.

        Parameters
        ----------
        filename : path-like
            The file to read.
        fmt : :obj:`str`
            File format. ``"X5"`` (case-insensitive) selects the X5 reader;
            for any other value, files ending in ``.h5`` are read as ITK
            composite transforms.
        reference : optional
            NOTE(review): currently not honored -- the ``.h5`` branch resets
            it to ``None`` and the X5 branch never reads it.
        moving : optional
            Not used by this method.
        x5_chain : :obj:`int` or None
            Name (index) of the dataset under the ``TransformChain`` group
            describing the chain to load. When ``None``, the plain list of
            transforms stored in the file is returned instead of a chain.

        Returns
        -------
        chain or :obj:`list`
            The loaded chain, or the list of all transforms found in the
            X5 file when ``x5_chain`` is ``None``.

        Raises
        ------
        TransformError
            If the X5 file has no transforms or no ``TransformChain`` group.
        NotImplementedError
            If the file is neither requested as X5 nor ends in ``.h5``.

        """
        if fmt and fmt.upper() == "X5":
            # Each X5 node names its own type; dispatch to the matching
            # ``<type>_from_x5`` loader imported at module level.
            xfm_list = [
                globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename)
            ]
            if not xfm_list:
                raise TransformError("Empty transform group")

            if x5_chain is None:
                return xfm_list

            with h5py.File(str(filename), "r") as f:
                chain_grp = f.get("TransformChain")
                if chain_grp is None:
                    raise TransformError("X5 file contains no TransformChain")
                chain_path = chain_grp[str(x5_chain)][()]

            # The dataset may be read back as bytes or as str.
            if isinstance(chain_path, bytes):
                chain_path = chain_path.decode()

            # The chain is stored as "/"-separated indices into the transform list.
            return cls([xfm_list[int(idx)] for idx in chain_path.split("/")])

        if str(filename).endswith(".h5"):
            # NOTE(review): the caller-supplied ``reference`` is discarded here.
            reference = None
            retval = []
            xforms = itk.ITKCompositeH5.from_filename(filename)
            for xfmobj in xforms:
                # Prepending reverses the order in which the file lists the
                # transforms (presumably ITK stores them last-applied first
                # -- TODO confirm).
                if isinstance(xfmobj, itk.ITKLinearTransform):
                    retval.insert(0, Affine(xfmobj.to_ras(), reference=reference))
                else:
                    retval.insert(0, DenseFieldTransform(xfmobj))

            return cls(retval)

        raise NotImplementedError

    def to_filename(self, filename, fmt="X5"):
        """
        Store the transform chain in X5 format.

        If ``filename`` already exists, transforms already stored in it are
        reused (matched with ``==``) and only the missing ones are written;
        the chain definition is added as a new dataset under the
        ``TransformChain`` group.

        Parameters
        ----------
        filename : path-like
            Output file; created if it does not exist, updated otherwise.
        fmt : :obj:`str`
            Output format; only ``"X5"`` (case-insensitive) is supported.

        Returns
        -------
        filename
            The ``filename`` argument, unchanged.

        Raises
        ------
        NotImplementedError
            If ``fmt`` is not X5.

        """
        if not fmt or fmt.upper() != "X5":
            raise NotImplementedError("Only X5 format is supported for chains")

        file_exists = os.path.exists(filename)
        existing = self.from_filename(filename, x5_chain=None) if file_exists else []

        xfm_chain = []  # index (within the file) of every transform of this chain
        new_xfms = []  # (index, transform) pairs that still need to be written
        for xfm in self.transforms:
            for eidx, existing_xfm in enumerate(existing):
                if xfm == existing_xfm:
                    xfm_chain.append(eidx)
                    break
            else:
                # Not in the file yet: it takes the next free index.
                new_idx = len(existing)
                xfm_chain.append(new_idx)
                new_xfms.append((new_idx, xfm))
                existing.append(xfm)

        mode = "r+" if file_exists else "w"
        with h5py.File(str(filename), mode) as f:
            if "Format" not in f.attrs:
                f.attrs["Format"] = "X5"
                f.attrs["Version"] = np.uint16(1)

            tg = f.require_group("TransformGroup")
            for idx, node in new_xfms:
                g = tg.create_group(str(idx))
                x5io._write_x5_group(g, node.to_x5())

            # Chains are numbered consecutively: the new one gets the next name.
            cg = f.require_group("TransformChain")
            cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain))

        return filename
def _as_chain(x):
    """
    Convert a value into a list of transforms.

    Parameters
    ----------
    x : transform, chain, iterable of transforms, or any other object
        The value to normalize.

    Returns
    -------
    :obj:`list`
        A new list: the contents of ``x`` when it is a chain or another
        iterable, or a one-element list wrapping ``x`` otherwise.

    """
    if isinstance(x, TransformChain):
        # Return a copy: handing out the chain's own list would let the
        # caller mutate that chain by accident. An empty chain (whose
        # ``transforms`` is ``None``) yields an empty list.
        return list(x.transforms or ())
    if isinstance(x, TransformBase):
        return [x]
    if isinstance(x, Iterable):
        return list(x)
    return [x]


# Module-level shortcut to the chain loader.
load = TransformChain.from_filename