# Source code for wsireg.reg_transforms.reg_transform_seq
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import SimpleITK as sitk
from wsireg.reg_transforms.reg_transform import RegTransform
from wsireg.utils.tform_utils import ELX_TO_ITK_INTERPOLATORS
class RegTransformSeq:
    """Class to concatenate and compose sequences of transformations.

    Keeps an ordered list of ``RegTransform`` objects (``reg_transforms``) and
    a parallel list of integer sequence indices (``transform_seq_idx``). From
    these it builds a 2D ``sitk.CompositeTransform`` and a configured
    ``sitk.ResampleImageFilter`` whose output grid (origin, direction, size,
    spacing, interpolator) is taken from the last transform in the list.
    """

    # Per-instance lists are created in ``__init__`` so that no list object is
    # ever shared between instances through a mutable class-level default.
    reg_transforms: List[RegTransform]
    reg_transforms_itk_order: List[RegTransform]
    resampler: Optional[sitk.ResampleImageFilter] = None
    composed_linear_mats: Optional[Dict[str, np.ndarray]] = None

    def __init__(
        self,
        reg_transforms: Optional[
            Union[
                str,
                Path,
                Dict[str, List[str]],
                List[RegTransform],
                RegTransform,
            ]
        ] = None,
        transform_seq_idx: Optional[List[int]] = None,
    ) -> None:
        """
        Parameters
        ----------
        reg_transforms: path to wsireg transforms .json, wsireg transform dict,
            RegTransform, list of RegTransform, or None
            RegTransforms to be composed. If falsy, an empty sequence is created.
        transform_seq_idx: list of int, optional
            Order in sequence of each transform (used only when
            ``reg_transforms`` is a RegTransform or a list of them).
            Transforms sharing an index are reversed when the composite ITK
            transform is built, as a sequence of elastix transforms would be;
            a pre-reg transform with its own index is not reversed.
            If None, every transform gets its own index.
        """
        self.reg_transforms = []
        self.reg_transforms_itk_order = []
        self._transform_seq_idx = []
        self._composite_transform = None
        self._output_size = None
        self._output_spacing = None
        self._n_transforms = 0
        if reg_transforms:
            self.add_transforms(
                reg_transforms, transform_seq_idx=transform_seq_idx
            )

    def add_transforms(
        self,
        transforms: Union[str, Path, dict, List[RegTransform], RegTransform],
        transform_seq_idx: Optional[List[int]] = None,
    ) -> None:
        """
        Add transforms to the end of the sequence and rebuild derived data.

        Parameters
        ----------
        transforms: path to wsireg transforms .json, wsireg transform dict,
            RegTransform or list of RegTransform
        transform_seq_idx: list of int, optional
            Order in sequence of each transform in ``transforms``. Ignored when
            ``transforms`` is a path or dict (indices are read from it). If a
            pre-reg transform, it will not be reversed like a sequence of
            elastix transforms would be to make the composite ITK transform.
            If None, each transform gets its own sequence index.

        Raises
        ------
        ValueError
            If ``transform_seq_idx`` length does not match the number of
            transforms.
        TypeError
            If ``transforms`` is of an unsupported type.
        """
        if isinstance(transforms, (str, Path, dict)):
            tform_list, new_seq_idx = _read_wsireg_transform(transforms)
            new_transforms = [RegTransform(t) for t in tform_list]
        elif isinstance(transforms, (list, RegTransform)):
            if isinstance(transforms, RegTransform):
                new_transforms = [transforms]
            else:
                new_transforms = list(transforms)
            if transform_seq_idx is None:
                # no ordering given: treat each transform as its own step
                new_seq_idx = list(range(len(new_transforms)))
            else:
                new_seq_idx = list(transform_seq_idx)
            if len(new_seq_idx) != len(new_transforms):
                raise ValueError(
                    f"transform_seq_idx has {len(new_seq_idx)} entries but "
                    f"{len(new_transforms)} transforms were given"
                )
        else:
            raise TypeError(
                f"Unsupported transforms type: {type(transforms).__name__}"
            )

        # the setter appends (offset past the current max), it does not replace
        self.transform_seq_idx = new_seq_idx
        self.reg_transforms = self.reg_transforms + new_transforms
        self._update_transform_properties()

    @property
    def composite_transform(self) -> sitk.CompositeTransform:
        """Composite ITK transform from transformation sequence"""
        return self._composite_transform

    @composite_transform.setter
    def composite_transform(self, transforms):
        self._composite_transform = transforms

    @property
    def transform_seq_idx(self) -> List[int]:
        """Transformation sequence for all combined transformations."""
        return self._transform_seq_idx

    @transform_seq_idx.setter
    def transform_seq_idx(self, transform_seq: List[int]) -> None:
        # NOTE: assignment *appends* the given indices, shifted past the
        # current maximum so they never collide with existing sequence steps.
        if self._transform_seq_idx:
            reindex_val = max(self._transform_seq_idx) + 1
        else:
            reindex_val = 0
        shifted = [int(x) + reindex_val for x in transform_seq]
        self._transform_seq_idx = self._transform_seq_idx + shifted

    @property
    def n_transforms(self) -> int:
        """Number of transformations in sequence."""
        return len(self.reg_transforms)

    @property
    def output_size(self) -> Tuple[int, int]:
        """Output size of image resampled by transform, initially determined
        from the last transformation in the chain. Setting it does not rebuild
        the resampler."""
        return self._output_size

    @output_size.setter
    def output_size(self, new_size: Tuple[int, int]) -> None:
        self._output_size = new_size

    @property
    def output_spacing(self) -> Union[Tuple[float, float], Tuple[int, int]]:
        """Output spacing of image resampled by transform, initially determined
        from the last transformation in the chain. Setting it does not rebuild
        the resampler or change ``output_size``; use ``set_output_spacing``
        for that."""
        return self._output_spacing

    @output_spacing.setter
    def output_spacing(
        self, new_spacing: Union[Tuple[float, float], Tuple[int, int]]
    ) -> None:
        self._output_spacing = new_spacing

    def set_output_spacing(
        self, spacing: Union[Tuple[float, float], Tuple[int, int]]
    ) -> None:
        """
        Set the output spacing of the resampler to any desired pixel spacing.
        The output size is rescaled (rounded up) to cover the same extent and
        the resampler is rebuilt.

        Parameters
        ----------
        spacing: tuple of float
            Spacing to set the new image. Will also change the output size to match.
        """
        output_size_scaling = np.asarray(
            self._output_spacing, dtype=float
        ) / np.asarray(spacing, dtype=float)
        new_output_size = np.ceil(
            np.multiply(self._output_size, output_size_scaling)
        )
        self._output_spacing = spacing
        self._output_size = tuple(int(i) for i in new_output_size)
        self._build_resampler()

    def _update_transform_properties(self) -> None:
        """Refresh output grid and derived ITK objects from the last transform."""
        self._n_transforms = len(self.reg_transforms)
        if not self.reg_transforms:
            # nothing to derive an output grid from
            return
        last_tform = self.reg_transforms[-1]
        self._output_size = last_tform.output_size
        self._output_spacing = last_tform.output_spacing
        self._build_transform_data()

    def _build_transform_data(self) -> None:
        """Rebuild the composite transform, then the resampler that uses it."""
        self._build_composite_transform(
            self.reg_transforms, self.transform_seq_idx
        )
        self._build_resampler()

    def _build_composite_transform(
        self, reg_transforms, reg_transform_seq_idx
    ) -> None:
        """Build the ITK composite transform from transforms and sequence indices.

        Sequence steps are visited in ascending index order. Within a step,
        transforms sharing an index are added in reverse list order (a single
        transform is unaffected).
        """
        # Must be an ndarray: comparing a plain list to an int yields a single
        # ``False`` rather than an element-wise mask.
        seq_idx = np.asarray(reg_transform_seq_idx)
        composite_index: List[int] = []
        for unique_idx in np.unique(seq_idx):
            in_seq_tform_idx = np.flatnonzero(seq_idx == unique_idx)
            composite_index.extend(int(i) for i in in_seq_tform_idx[::-1])

        composite_transform = sitk.CompositeTransform(2)
        for tform_idx in composite_index:
            composite_transform.AddTransform(
                reg_transforms[tform_idx].itk_transform
            )
        self._composite_transform = composite_transform
        self.reg_transforms_itk_order = [
            reg_transforms[i] for i in composite_index
        ]

    def _build_resampler(self) -> None:
        """Configure a ResampleImageFilter from the last transform's output grid."""
        last_tform = self.reg_transforms[-1]
        interpolator = ELX_TO_ITK_INTERPOLATORS.get(
            last_tform.resample_interpolator
        )
        if interpolator is None:
            raise ValueError(
                "Unsupported resample interpolator: "
                f"{last_tform.resample_interpolator!r}"
            )
        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputOrigin(last_tform.output_origin)
        resampler.SetOutputDirection(last_tform.output_direction)
        resampler.SetSize(self.output_size)
        resampler.SetOutputSpacing(self.output_spacing)
        resampler.SetInterpolator(interpolator)
        resampler.SetTransform(self.composite_transform)
        self.resampler = resampler

    def transform_points(
        self, pt_data: np.ndarray, px_idx=True, source_res=1, output_idx=True
    ) -> np.ndarray:
        """
        Transform point sets using the transformation chain.

        Each point is passed through every transform's ``inverse_transform``
        in ``reg_transforms`` list order.

        Parameters
        ----------
        pt_data: np.ndarray
            Point data in xy order, one point per row
        px_idx: bool
            Whether point data is in pixel (True) or physical coordinate space
        source_res: float
            spacing of the pixels associated with pt_data if they are not in
            physical coordinate space
        output_idx: bool
            return transformed points as pixel indices in the output_spacing's
            reference space (each axis divided by its output spacing)

        Returns
        -------
        tformed_pts: np.ndarray
            Transformed points, shape (n_points, 2) — (0, 2) if pt_data is empty
        """
        tformed_pts = []
        for pt in pt_data:
            t_pt = np.asarray(pt, dtype=float)
            if px_idx:
                # pixel indices -> physical coordinates
                t_pt = t_pt * source_res
            for tform in self.reg_transforms:
                # SimpleITK expects a plain sequence of Python floats
                t_pt = np.asarray(
                    tform.inverse_transform.TransformPoint(
                        tuple(float(c) for c in t_pt)
                    ),
                    dtype=float,
                )
            if output_idx:
                # per-axis so anisotropic output spacing is handled correctly
                t_pt = t_pt / np.asarray(self._output_spacing, dtype=float)
            tformed_pts.append(t_pt)
        if not tformed_pts:
            return np.empty((0, 2), dtype=float)
        return np.stack(tformed_pts)

    def append(self, other) -> None:
        """
        Concatenate transformation sequences.

        Parameters
        ----------
        other: RegTransformSeq
            Append a RegTransformSeq to another; its sequence indices are
            shifted past this sequence's current maximum.
        """
        self.add_transforms(other.reg_transforms, other.transform_seq_idx)

    def _write_transforms(self, output_path: Union[str, Path]):
        # TODO(review): not implemented; currently a no-op that returns None.
        return
def _read_wsireg_transform(
    parameter_data: Union[str, Path, Dict[Any, Any]]
) -> Tuple[List[Dict[str, List[str]]], List[int]]:
    """Convert a wsireg transform dict (or a .json file of one) to a flat list.

    Parameters
    ----------
    parameter_data: str, Path or dict
        Path to a wsireg transforms .json file, or the already-loaded dict.
        Values are either a single transform dict or a list of them.

    Returns
    -------
    transform_list: list of dict
        All transform parameter dicts, in file (dict iteration) order.
    transform_list_seq_id: list of int
        Sequence index for each entry of ``transform_list``. Every transform
        under the ``"initial"`` key gets its own index. The transforms in any
        other list value share a single index (one registration step). Values
        that are neither dict nor list are skipped.
    """
    if isinstance(parameter_data, (str, Path)):
        # context manager so the file handle is closed deterministically
        with open(parameter_data, "r") as fh:
            parameter_data_in = json.load(fh)
    else:
        parameter_data_in = parameter_data

    transform_list = []
    transform_list_seq_id = []
    seq_idx = 0
    for key, value in parameter_data_in.items():
        if isinstance(value, dict):
            transform_list.append(value)
            transform_list_seq_id.append(seq_idx)
            seq_idx += 1
        elif isinstance(value, list):
            if key == "initial":
                # each pre-registration transform is its own sequence step
                for init_tform in value:
                    transform_list.append(init_tform)
                    transform_list_seq_id.append(seq_idx)
                    seq_idx += 1
            else:
                # all transforms of one registration step share an index
                transform_list.extend(value)
                transform_list_seq_id.extend([seq_idx] * len(value))
                seq_idx += 1
    return transform_list, transform_list_seq_id