Source code for wsireg.utils.tform_utils
import json
from pathlib import Path
from typing import Tuple, Union
import itk
import numpy as np
import SimpleITK as sitk
from wsireg.parameter_maps.transformations import (
BASE_AFF_TFORM,
BASE_RIG_TFORM,
)
from wsireg.reg_transforms.reg_transform import RegTransform
from wsireg.utils.itk_im_conversions import (
itk_image_to_sitk_image,
sitk_image_to_itk_image,
)
from wsireg.utils.reg_utils import json_to_pmap_dict
NUMERIC_ELX_PARAMETERS = {
"CenterOfRotationPoint": np.float64,
"DefaultPixelValue": np.float64,
"Direction": np.float64,
"FixedImageDimension": np.int64,
"Index": np.int64,
"MovingImageDimension": np.int64,
"NumberOfParameters": np.int64,
"Origin": np.float64,
"Size": np.int64,
"Spacing": np.float64,
"TransformParameters": np.float64,
}
ELX_LINEAR_TRANSFORMS = [
"AffineTransform",
"EulerTransform",
"SimilarityTransform",
]
ELX_TO_ITK_INTERPOLATORS = {
"FinalNearestNeighborInterpolator": sitk.sitkNearestNeighbor,
"FinalLinearInterpolator": sitk.sitkLinear,
"FinalBSplineInterpolator": sitk.sitkBSpline,
}
[docs]
def prepare_tform_dict(tform_dict, shape_tform=False):
tform_dict_out = {}
for k, v in tform_dict.items():
if k == "initial":
tform_dict_out["initial"] = v
else:
tforms = []
for tform in v:
if "invert" in list(tform.keys()):
if shape_tform is False:
tforms.append(tform["image"])
else:
tforms.append(tform["invert"])
else:
tforms.append(tform)
tform_dict_out[k] = tforms
return tform_dict_out
[docs]
def transform_2d_image_itkelx(
image, transformation_maps, writer="sitk", **zarr_kwargs
):
"""
Transform 2D images with multiple models and return the transformed image
or write the transformed image to disk as a .tif file.
Multichannel or multicomponent images (RGB) have to be transformed a single channel at a time
This function takes care of performing those transformations and reconstructing the image in the same
data type as the input
Parameters
----------
image : SimpleITK.Image
Image to be transformed
transformation_maps : list
list of SimpleElastix ParameterMaps to used for transformation
Returns
-------
Transformed SimpleITK.Image
"""
if transformation_maps is not None:
tfx = itk.TransformixFilter.New()
# TODO: add mask cropping here later
# print("mask cropping")
# tmap = sitk.ReadParameterFile(transformation_maps[0])
# x_min = int(float(tmap["MinimumX"][0]))
# x_max = int(float(tmap["MaximumX"][0]))
# y_min = int(float(tmap["MinimumY"][0]))
# y_max = int(float(tmap["MaximumY"][0]))
# image = image[x_min:x_max, y_min:y_max]
# origin = np.repeat(0, len(image.GetSize()))
# image.SetOrigin(tuple([int(i) for i in origin]))
# else:
transform_pobj = itk.ParameterObject.New()
for idx, tmap in enumerate(transformation_maps):
if isinstance(tmap, str):
tmap = sitk.ReadParameterFile(tmap)
if idx == 0:
tmap["InitialTransformParametersFileName"] = (
"NoInitialTransform",
)
transform_pobj.AddParameterMap(tmap)
else:
tmap["InitialTransformParametersFileName"] = (
"NoInitialTransform",
)
transform_pobj.AddParameterMap(tmap)
tfx.SetTransformParameterObject(transform_pobj)
tfx.LogToConsoleOn()
tfx.LogToFileOff()
else:
tfx = None
# if tfx is None:
# xy_final_size = np.array(image.GetSize(), dtype=np.uint32)
# else:
# xy_final_size = np.array(
# transformation_maps[-1]["Size"], dtype=np.uint32
# )
if writer == "sitk" or writer is None:
return transform_image_itkelx_to_sitk(image, tfx)
elif writer == "zarr":
return
else:
raise ValueError("writer type {} not recognized".format(writer))
[docs]
def transform_image_to_sitk(image, tfx):
# manage transformation/casting if data is multichannel or RGB
# data is always returned in the same PixelIDType as it is entered
pixel_id = image.GetPixelID()
if tfx is not None:
if pixel_id in list(range(1, 13)) and image.GetDepth() == 0:
tfx.SetMovingImage(image)
image = tfx.Execute()
image = sitk.Cast(image, pixel_id)
elif pixel_id in list(range(1, 13)) and image.GetDepth() > 0:
images = []
for chan in range(image.GetDepth()):
tfx.SetMovingImage(image[:, :, chan])
images.append(sitk.Cast(tfx.Execute(), pixel_id))
image = sitk.JoinSeries(images)
image = sitk.Cast(image, pixel_id)
elif pixel_id > 12:
images = []
for idx in range(image.GetNumberOfComponentsPerPixel()):
im = sitk.VectorIndexSelectionCast(image, idx)
pixel_id_nonvec = im.GetPixelID()
tfx.SetMovingImage(im)
images.append(sitk.Cast(tfx.Execute(), pixel_id_nonvec))
del im
image = sitk.Compose(images)
image = sitk.Cast(image, pixel_id)
return image
[docs]
def transform_image_itkelx_to_sitk(image, tfx):
# manage transformation/casting if data is multichannel or RGB
# data is always returned in the same PixelIDType as it is entered
pixel_id = image.GetPixelID()
if tfx is not None:
if pixel_id in list(range(1, 13)) and image.GetDepth() == 0:
image = sitk_image_to_itk_image(image, cast_to_float32=True)
tfx.SetMovingImage(image)
tfx.UpdateLargestPossibleRegion()
image = tfx.GetOutput()
image = itk_image_to_sitk_image(image)
image = sitk.Cast(image, pixel_id)
elif pixel_id in list(range(1, 13)) and image.GetDepth() > 0:
images = []
for chan in range(image.GetDepth()):
image = sitk_image_to_itk_image(
image[:, :, chan], cast_to_float32=True
)
tfx.SetMovingImage(image)
tfx.UpdateLargestPossibleRegion()
image = tfx.GetOutput()
image = itk_image_to_sitk_image(image)
image = sitk.Cast(image, pixel_id)
images.append(image)
image = sitk.JoinSeries(images)
image = sitk.Cast(image, pixel_id)
elif pixel_id > 12:
images = []
for idx in range(image.GetNumberOfComponentsPerPixel()):
im = sitk.VectorIndexSelectionCast(image, idx)
pixel_id_nonvec = im.GetPixelID()
im = sitk_image_to_itk_image(im, cast_to_float32=True)
tfx.SetMovingImage(im)
tfx.UpdateLargestPossibleRegion()
im = tfx.GetOutput()
im = itk_image_to_sitk_image(im)
im = sitk.Cast(im, pixel_id_nonvec)
images.append(im)
del im
image = sitk.Compose(images)
image = sitk.Cast(image, pixel_id)
return image
[docs]
def apply_transform_dict_itkelx(
image_fp,
image_res,
tform_dict_in,
prepro_dict=None,
is_shape_mask=False,
writer="sitk",
**im_tform_kwargs,
):
"""
Apply a complex series of transformations in a python dictionary to an image
Parameters
----------
image_fp : str
file path to the image to be transformed, it will be read in it's entirety
image_res : float
pixel resolution of image to be transformed
tform_dict : dict of lists
dict of SimpleElastix transformations stored in lists, may contain an "initial" transforms (preprocessing transforms)
these will be applied first, then the key order of the dict will determine the rest of the transformations
is_shape_mask : bool
whether the image being transformed is a shape mask (determines import)
Returns
-------
image: itk.Image
image that has been transformed
"""
if is_shape_mask is False:
if isinstance(image_fp, sitk.Image):
image = image_fp
# else:
# image = RegImage(
# image_fp, image_res, prepro_dict=prepro_dict
# ).image
else:
image = sitk.GetImageFromArray(image_fp)
del image_fp
image.SetSpacing((image_res, image_res))
if tform_dict_in is None:
if writer == "zarr":
image = transform_2d_image_itkelx(
image,
None,
writer="zarr",
zarr_store_dir=im_tform_kwargs["zarr_store_dir"],
channel_names=im_tform_kwargs["channel_names"],
channel_colors=im_tform_kwargs["channel_colors"],
)
else:
image = transform_2d_image_itkelx(image, None)
else:
tform_dict = tform_dict_in.copy()
if tform_dict.get("registered") is None and tform_dict.get(0) is None:
tform_dict["registered"] = tform_dict["initial"]
tform_dict.pop("initial", None)
if isinstance(tform_dict.get("registered"), list) is False:
tform_dict["registered"] = [tform_dict["registered"]]
for idx in range(len(tform_dict["registered"])):
tform_dict[idx] = [tform_dict["registered"][idx]]
tform_dict.pop("registered", None)
else:
tform_dict = prepare_tform_dict(tform_dict, shape_tform=False)
if "initial" in tform_dict:
for initial_tform in tform_dict["initial"]:
if isinstance(initial_tform, list) is False:
initial_tform = [initial_tform]
for tform in initial_tform:
image = transform_2d_image_itkelx(image, [tform])
tform_dict.pop("initial", None)
for k, v in tform_dict.items():
if writer == "zarr" and k == list(tform_dict.keys())[-1]:
image = transform_2d_image_itkelx(
image,
v,
writer="zarr",
zarr_store_dir=im_tform_kwargs["zarr_store_dir"],
channel_names=im_tform_kwargs["channel_names"],
channel_colors=im_tform_kwargs["channel_colors"],
)
else:
image = transform_2d_image_itkelx(image, v)
return image
[docs]
def compute_rot_bound(image, angle=30):
"""
compute the bounds of an image after by an angle
Parameters
----------
image : sitk.Image
SimpleITK image that will be rotated angle
angle : float
angle of rotation in degrees, rotates counter-clockwise if positive
Returns
-------
tuple of the rotated image's size in x and y
"""
w, h = image.GetSize()[0], image.GetSize()[1]
theta = np.radians(angle)
c, s = np.abs(np.cos(theta)), np.abs(np.sin(theta))
bound_w = (h * s) + (w * c)
bound_h = (h * c) + (w * s)
return bound_w, bound_h
[docs]
def gen_rigid_tform_rot(image, spacing, angle):
"""
generate a SimpleElastix transformation parameter Map to rotate image by angle
Parameters
----------
image : sitk.Image
SimpleITK image that will be rotated
spacing : float
Physical spacing of the SimpleITK image
angle : float
angle of rotation in degrees, rotates counter-clockwise if positive
Returns
-------
SimpleITK.ParameterMap of rotation transformation (EulerTransform)
"""
tform = BASE_RIG_TFORM.copy()
image.SetSpacing((spacing, spacing))
bound_w, bound_h = compute_rot_bound(image, angle=angle)
rot_cent_pt = image.TransformContinuousIndexToPhysicalPoint(
((bound_w - 1) / 2, (bound_h - 1) / 2)
)
c_x, c_y = (image.GetSize()[0] - 1) / 2, (image.GetSize()[1] - 1) / 2
c_x_phy, c_y_phy = image.TransformContinuousIndexToPhysicalPoint(
(c_x, c_y)
)
t_x = rot_cent_pt[0] - c_x_phy
t_y = rot_cent_pt[1] - c_y_phy
tform["Spacing"] = [str(spacing), str(spacing)]
tform["Size"] = [str(int(np.ceil(bound_w))), str(int(np.ceil(bound_h)))]
tform["CenterOfRotationPoint"] = [str(rot_cent_pt[0]), str(rot_cent_pt[1])]
tform["TransformParameters"] = [
str(np.radians(angle)),
str(-1 * t_x),
str(-1 * t_y),
]
return tform
[docs]
def gen_rigid_translation(
image, spacing, translation_x, translation_y, size_x, size_y
):
"""
generate a SimpleElastix transformation parameter Map to rotate image by angle
Parameters
----------
image : sitk.Image
SimpleITK image that will be rotated
spacing : float
Physical spacing of the SimpleITK image
Returns
-------
SimpleITK.ParameterMap of rotation transformation (EulerTransform)
"""
tform = BASE_RIG_TFORM.copy()
image.SetSpacing((spacing, spacing))
bound_w, bound_h = compute_rot_bound(image, angle=0)
rot_cent_pt = image.TransformContinuousIndexToPhysicalPoint(
((bound_w - 1) / 2, (bound_h - 1) / 2)
)
(
translation_x,
translation_y,
) = image.TransformContinuousIndexToPhysicalPoint(
(float(translation_x), float(translation_y))
)
# c_x, c_y = (image.GetSize()[0] - 1) / 2, (image.GetSize()[1] - 1) / 2
tform["Spacing"] = [str(spacing), str(spacing)]
tform["Size"] = [str(size_x), str(size_y)]
tform["CenterOfRotationPoint"] = [str(rot_cent_pt[0]), str(rot_cent_pt[1])]
tform["TransformParameters"] = [
str(0),
str(translation_x),
str(translation_y),
]
return tform
[docs]
def gen_rig_to_original(original_size, crop_transform):
crop_transform["Size"] = [str(original_size[0]), str(original_size[1])]
tform_params = [float(t) for t in crop_transform["TransformParameters"]]
crop_transform["TransformParameters"] = [
str(0),
str(tform_params[1] * -1),
str(tform_params[2] * -1),
]
return crop_transform
[docs]
def gen_aff_tform_flip(image, spacing, flip="h"):
"""
generate a SimpleElastix transformation parameter Map to horizontally or vertically flip image
Parameters
----------
image : sitk.Image
SimpleITK image that will be rotated
spacing : float
Physical spacing of the SimpleITK image
flip : str
"h" or "v" for horizontal or vertical flipping, respectively
Returns
-------
SimpleITK.ParameterMap of flipping transformation (AffineTransform)
"""
tform = BASE_AFF_TFORM.copy()
image.SetSpacing((spacing, spacing))
bound_w, bound_h = compute_rot_bound(image, angle=0)
rot_cent_pt = image.TransformContinuousIndexToPhysicalPoint(
((bound_w - 1) / 2, (bound_h - 1) / 2)
)
tform["Spacing"] = [str(spacing), str(spacing)]
tform["Size"] = [str(int(bound_w)), str(int(bound_h))]
tform["CenterOfRotationPoint"] = [str(rot_cent_pt[0]), str(rot_cent_pt[1])]
if flip == "h":
tform_params = ["-1", "0", "0", "1", "0", "0"]
elif flip == "v":
tform_params = ["1", "0", "0", "-1", "0", "0"]
tform["TransformParameters"] = tform_params
return tform
[docs]
def make_composite_itk(itk_tforms):
itk_composite = sitk.CompositeTransform(2)
for t in itk_tforms:
itk_composite.AddTransform(t.itk_transform)
return itk_composite
[docs]
def get_final_tform(parameter_data):
if (
isinstance(parameter_data, str)
and Path(parameter_data).suffix == ".json"
):
parameter_data = json.load(open(parameter_data, "r"))
final_key = list(parameter_data.keys())[-1]
final_tform = parameter_data[final_key][-1]
return final_tform
[docs]
def collate_wsireg_transforms(parameter_data):
if (
isinstance(parameter_data, str)
and Path(parameter_data).suffix == ".json"
):
parameter_data = json.load(open(parameter_data, "r"))
parameter_data_list = []
for k, v in parameter_data.items():
if k == "initial":
if isinstance(v, dict):
parameter_data_list.append([v])
elif isinstance(v, list):
for init_tform in v:
parameter_data_list.append([init_tform])
else:
sub_tform = []
if isinstance(v, dict):
sub_tform.append(v)
elif isinstance(v, list):
sub_tform += v
sub_tform = sub_tform[::-1]
parameter_data_list.append(sub_tform)
flat_pmap_list = [
item for sublist in parameter_data_list for item in sublist
]
if all([isinstance(t, dict) for t in flat_pmap_list]):
flat_pmap_list = [RegTransform(t) for t in flat_pmap_list]
return flat_pmap_list
[docs]
def wsireg_transforms_to_itk_composite(parameter_data):
reg_transforms = collate_wsireg_transforms(parameter_data)
composite_tform = make_composite_itk(reg_transforms)
return composite_tform, reg_transforms
[docs]
def prepare_wsireg_transform_data(transform_data):
if isinstance(transform_data, str) is True:
transform_data = json_to_pmap_dict(transform_data)
if transform_data is not None:
(
composite_transform,
itk_transforms,
) = wsireg_transforms_to_itk_composite(transform_data)
final_tform = itk_transforms[-1]
return composite_transform, itk_transforms, final_tform
[docs]
def wsireg_transforms_to_resampler(final_tform):
resampler = sitk.ResampleImageFilter()
resampler.SetOutputOrigin(final_tform.output_origin)
resampler.SetSize(final_tform.output_size)
resampler.SetOutputDirection(final_tform.output_direction)
resampler.SetOutputSpacing(final_tform.output_spacing)
interpolator = ELX_TO_ITK_INTERPOLATORS.get(
final_tform.resample_interpolator
)
resampler.SetInterpolator(interpolator)
return resampler
[docs]
def sitk_transform_image(image, final_tform, composite_transform):
resampler = wsireg_transforms_to_resampler(final_tform)
resampler.SetTransform(composite_transform)
image = resampler.Execute(image)
return image
[docs]
def identity_elx_transform(
image_size: Tuple[int, int],
image_spacing: Union[Tuple[int, int], Tuple[float, float]],
):
identity = BASE_RIG_TFORM
identity.update({"Size": [str(i) for i in image_size]})
identity.update({"Spacing": [str(i) for i in image_spacing]})
return identity