Source code for wsireg.reg_images.sitk_reg_image

import warnings

import SimpleITK as sitk

from wsireg.reg_images.reg_image import RegImage
from wsireg.utils.im_utils import (
    ensure_dask_array,
    get_sitk_image_info,
    guess_rgb,
)


[docs] class SitkRegImage(RegImage): def __init__( self, image, image_res, mask=None, pre_reg_transforms=None, preprocessing=None, channel_names=None, channel_colors=None, ): super(SitkRegImage, self).__init__(preprocessing) self._path = image self._image_res = image_res self.reader = "sitk" ( self._shape, self._im_dtype, ) = self._get_image_info() self._is_rgb = guess_rgb(self._shape) self._n_ch = self._shape[2] if self.is_rgb else self._shape[0] if mask: self._mask = self.read_mask(mask) self.pre_reg_transforms = pre_reg_transforms self._channel_names = channel_names self._channel_colors = channel_colors self.original_size_transform = None def _get_image_info(self): im_dims, im_dtype = get_sitk_image_info(self._path) im_dims = (int(im_dims[0]), int(im_dims[1]), int(im_dims[2])) return im_dims, im_dtype
[docs] def read_reg_image(self): """ Read and preprocess the image for registration. """ reg_image = sitk.ReadImage(self._path) if self.preprocessing.as_uint8 is True and reg_image.GetPixelID() != 1: reg_image = sitk.Cast( sitk.RescaleIntensity(reg_image), sitk.sitkUInt8 ) self.preprocess_image(reg_image)
def _read_full_image(self): self._dask_image = ensure_dask_array( sitk.GetArrayFromImage(sitk.ReadImage(self._path)) ) rechunk_size = ( (2048, 2048, self.n_ch) if self.is_rgb else (self.n_ch, 2048, 2048) ) self._dask_image = self._dask_image.rechunk(rechunk_size)
[docs] def read_single_channel(self, channel_idx: int): """ Read in a single channel for transformation by plane. Parameters ---------- channel_idx: int Index of the channel to be read Returns ------- image: np.ndarray Numpy array of the selected channel to be read """ if channel_idx > (self.n_ch - 1): warnings.warn( "channel_idx exceeds number of channels, reading channel at channel_idx == 0" ) channel_idx = 0 if self._is_rgb: image = self._dask_image[:, :, channel_idx].compute() else: image = self._dask_image[channel_idx, :, :].compute() return image