Source code for wsireg.reg_images.czi_reg_image

import warnings
from typing import Tuple

import dask.array as da
import numpy as np
import SimpleITK as sitk

from wsireg.reg_images.reg_image import RegImage
from wsireg.utils.im_utils import CziRegImageReader, guess_rgb


[docs] class CziRegImage(RegImage): def __init__( self, image, image_res, mask=None, pre_reg_transforms=None, preprocessing=None, channel_names=None, channel_colors=None, ): super(CziRegImage, self).__init__(preprocessing) self._path = image self._image_res = image_res self.czi = CziRegImageReader(self._path) self.reader = "czi" scene_idx = self.czi.axes.index('S') if self.czi.shape[scene_idx] > 1: raise ValueError('multi scene czis not allowed at this time') ( self._shape, self._im_dtype, ) = self._get_image_info() self._is_rgb = guess_rgb(self._shape) self._n_ch = self._shape[2] if self.is_rgb else self._shape[0] self._dask_image = self._prepare_dask_image() if mask: self._mask = self.read_mask(mask) self.pre_reg_transforms = pre_reg_transforms self._channel_names = channel_names self._channel_colors = channel_colors self.original_size_transform = None def _get_image_info(self): # if RGB need to get 0 if self.czi.shape[-1] > 1: ch_dim_idx = self.czi.axes.index('0') else: ch_dim_idx = self.czi.axes.index('C') y_dim_idx = self.czi.axes.index('Y') x_dim_idx = self.czi.axes.index('X') if self.czi.shape[-1] > 1: im_dims = np.array(self.czi.shape)[ [y_dim_idx, x_dim_idx, ch_dim_idx] ] else: im_dims = np.array(self.czi.shape)[ [ch_dim_idx, y_dim_idx, x_dim_idx] ] im_dtype = self.czi.dtype im_dims = (int(im_dims[0]), int(im_dims[1]), int(im_dims[2])) return im_dims, im_dtype def _prepare_dask_image(self) -> da.Array: ch_dim = self._shape[1:] if not self._is_rgb else self._shape[:2] chunks = ((1,) * self._n_ch, (ch_dim[0],), (ch_dim[1],)) dask_image = da.map_blocks( self._czi_read_single_channel, chunks=chunks, dtype=self.im_dtype, meta=np.array((), dtype=self._im_dtype), ) return dask_image def _czi_read_single_channel(self, block_id: Tuple[int, ...]): channel_idx = block_id[0] if self.is_rgb is False: image = self.czi.sub_asarray( channel_idx=[channel_idx], ) else: image = self.czi.sub_asarray_rgb( channel_idx=[channel_idx], greyscale=False ) return np.expand_dims(np.squeeze(image), axis=0)
[docs] def read_reg_image(self): """ Read and preprocess the image for registration. For the Zeiss CZI reader, this involves grayscaling RGB on read or reading only a subset of the channel images. """ if self.is_rgb: reg_image = self.czi.sub_asarray_rgb(greyscale=True) else: reg_image = self.czi.sub_asarray( channel_idx=self.preprocessing.ch_indices, as_uint8=self.preprocessing.as_uint8, ) reg_image = np.squeeze(reg_image) reg_image = sitk.GetImageFromArray(reg_image) self.preprocess_image(reg_image)
[docs] def read_single_channel(self, channel_idx: int): """ Read in a single channel for transformation by plane. Parameters ---------- channel_idx: int Index of the channel to be read Returns ------- image: np.ndarray Numpy array of the selected channel to be read """ if channel_idx > (self.n_ch - 1): warnings.warn( "channel_idx exceeds number of channels, reading channel at channel_idx == 0" ) channel_idx = 0 image = self._dask_image[channel_idx, :, :].compute() return image