Source code for wsireg.reg_transforms.reg_transform

from warnings import warn
from typing import Optional
import numpy as np
import SimpleITK as sitk

from wsireg.utils.tform_conversion import convert_to_itk


[docs] class RegTransform: """Container for elastix transform that manages inversion and other metadata. Converts elastix transformation dict to it's SimpleITK representation Attributes ---------- elastix_transform: dict elastix transform stored in a python dict itk_transform: sitk.Transform elastix transform in SimpleITK container output_spacing: list of float Spacing of the targeted image during registration output_size: list of int Size of the targeted image during registration output_direction: list of float Direction of the targeted image during registration (not relevant for 2D applications) output_origin: list of float Origin of the targeted image during registration resampler_interpolator: str elastix interpolator setting for resampling the image is_linear: bool Whether the given transform is linear or non-linear (non-rigid) inverse_transform: sitk.Transform or None Inverse of the itk transform used for transforming from moving to fixed space Only calculated for non-rigid transforms when called by `compute_inverse_nonlinear` as the process is quite memory and computationally intensive """ def __init__(self, elastix_transform): """ Parameters ---------- elastix_transform: dict elastix transform stored in a python dict """ self.elastix_transform: dict = elastix_transform self.itk_transform: sitk.Transform = convert_to_itk( self.elastix_transform ) self.output_spacing = [ float(p) for p in self.elastix_transform.get("Spacing") ] self.output_size = [int(p) for p in self.elastix_transform.get("Size")] self.output_origin = [ float(p) for p in self.elastix_transform.get("Origin") ] self.output_direction = [ float(p) for p in self.elastix_transform.get("Direction") ] self.resample_interpolator = self.elastix_transform.get( "ResampleInterpolator" )[0] self.is_linear = self.itk_transform.IsLinear() if self.is_linear is True: self.inverse_transform = self.itk_transform.GetInverse() transform_name = self.itk_transform.GetName() if transform_name == "Euler2DTransform": self.inverse_transform = sitk.Euler2DTransform( self.inverse_transform ) elif transform_name == "AffineTransform": self.inverse_transform = sitk.AffineTransform( self.inverse_transform ) elif transform_name == "Similarity2DTransform": self.inverse_transform = sitk.Similarity2DTransform( self.inverse_transform ) else: self.inverse_transform = None
[docs] def compute_inverse_nonlinear(self) -> None: """Compute the inverse of a BSpline transform using ITK""" tform_to_dfield = sitk.TransformToDisplacementFieldFilter() tform_to_dfield.SetOutputSpacing(self.output_spacing) tform_to_dfield.SetOutputOrigin(self.output_origin) tform_to_dfield.SetOutputDirection(self.output_direction) tform_to_dfield.SetSize(self.output_size) displacement_field = tform_to_dfield.Execute(self.itk_transform) displacement_field = sitk.InvertDisplacementField(displacement_field) displacement_field = sitk.DisplacementFieldTransform( displacement_field ) self.inverse_transform = displacement_field
[docs] def as_np_matrix( self, use_np_ordering: bool = False, n_dim: int = 3, use_inverse: bool = False, to_px_idx: bool = False, ) -> Optional[np.ndarray]: """ Creates a affine transform matrix as np.ndarray whether the center of rotation is 0,0. Optionally in physical or pixel coordinates. Parameters ---------- use_np_ordering: bool Use numpy ordering of yx (napari-compatible) n_dim: int Number of dimensions in the affine matrix, using 3 creates a 3x3 array use_inverse: bool return the inverse affine transformation to_px_idx: bool return the transformation matrix specified in pixels or physical (microns) Returns ------- full_matrix: np.ndarray Affine transformation matrix """ if self.is_linear: if use_np_ordering is True: order = slice(None, None, -1) else: order = slice(None, None, 1) if use_inverse is True: transform = self.inverse_transform else: transform = self.itk_transform # pull transform values tmatrix = np.array(transform.GetMatrix()[order]).reshape(2, 2) center = np.array(transform.GetCenter()[order]) translation = np.array(transform.GetTranslation()[order]) if to_px_idx is True: phys_to_index = 1 / np.asarray(self.output_spacing).astype( np.float64 ) center *= phys_to_index translation *= phys_to_index # construct matrix full_matrix = np.eye(n_dim) full_matrix[0:2, 0:2] = tmatrix full_matrix[0:2, n_dim - 1] = ( -np.dot(tmatrix, center) + translation + center ) return full_matrix else: warn( "Non-linear transformations can not be represented converted" "to homogenous matrix" ) return None