# Source code for wsireg.reg_images.tifffile_reg_image

import warnings
from typing import List, Tuple

import dask.array as da
import numpy as np
import SimpleITK as sitk
from ome_types import from_xml
from tifffile import TiffFile

from wsireg.reg_images.reg_image import RegImage
from wsireg.utils.im_utils import (
    get_tifffile_info,
    guess_rgb,
    preprocess_dask_array,
    tifffile_to_dask,
)


class TiffFileRegImage(RegImage):
    """Registration image read from a TIFF file through ``tifffile``.

    Geometry and dtype are taken from the largest series in the file and the
    pixel data is held as a dask array (pyramid level 0).

    Parameters
    ----------
    image_fp
        Path to the TIFF file.
    image_res
        Image resolution, stored on ``_image_res``.
    mask
        Optional mask passed to ``self.read_mask``; ignored when ``None``.
    pre_reg_transforms
        Stored unchanged on ``pre_reg_transforms``.
    preprocessing
        Preprocessing options, forwarded to ``RegImage.__init__``.
    channel_names, channel_colors
        Stored unchanged on ``_channel_names`` / ``_channel_colors``.
    """

    def __init__(
        self,
        image_fp,
        image_res,
        mask=None,
        pre_reg_transforms=None,
        preprocessing=None,
        channel_names=None,
        channel_colors=None,
    ):
        super().__init__(preprocessing)
        self._path = image_fp
        self._image_res = image_res
        # Kept on the instance because _get_image_info / _get_dim_info read
        # ``self.tf.series`` and ``self.tf.ome_metadata``.
        # NOTE(review): this class never closes the handle.
        self.tf = TiffFile(self._path)
        self.reader = "tifffile"

        (
            self._shape,
            self._im_dtype,
            self.largest_series,
        ) = self._get_image_info()

        # Must run before _get_dask_image, which reads _is_rgb/_is_interleaved.
        self._get_dim_info()

        self._dask_image = self._get_dask_image()

        # ``is not None`` instead of truthiness: ``if mask:`` raises
        # "truth value of an array is ambiguous" for array-valued masks.
        if mask is not None:
            self._mask = self.read_mask(mask)

        self.pre_reg_transforms = pre_reg_transforms

        self._channel_names = channel_names
        self._channel_colors = channel_colors
        self.original_size_transform = None

    def _get_image_info(self) -> Tuple[Tuple[int, int, int], np.dtype, int]:
        """Return ``(shape, dtype, largest_series)`` of the image.

        Emits a warning when the file holds more than one series, since only
        the largest one is used. The first three dimensions reported by
        ``get_tifffile_info`` are converted to plain ``int``.

        Returns
        -------
        tuple
            ``((d0, d1, d2), dtype, largest_series_index)``.
        """
        if len(self.tf.series) > 1:
            warnings.warn(
                "The tiff contains multiple series, "
                "the largest series will be read by default"
            )

        im_dims, im_dtype, largest_series = get_tifffile_info(self._path)
        # NOTE(review): assumes get_tifffile_info reports at least 3 dims
        # (2D planes padded upstream) -- confirm in wsireg.utils.im_utils.
        im_dims = (int(im_dims[0]), int(im_dims[1]), int(im_dims[2]))
        return im_dims, im_dtype, largest_series

    def _get_dim_info(self) -> None:
        """Set ``_is_rgb``, ``_is_interleaved``, ``_channel_axis`` and ``_n_ch``.

        With OME-XML present, ``_is_rgb`` comes from ``samples_per_pixel`` of
        the first channel of the largest series; without it, everything is
        guessed from the shape via ``guess_rgb``. ``_channel_axis`` indexes
        ``self._shape`` (the layout as reported on disk), not the dask array
        after any axis reordering in ``_get_dask_image``.
        """
        if not self._shape:
            return

        if self.tf.ome_metadata:
            self.ome_metadata = from_xml(self.tf.ome_metadata)
            pixels = self.ome_metadata.images[self.largest_series].pixels
            spp = pixels.channels[0].samples_per_pixel
            # A missing SamplesPerPixel (None) counts as "not RGB".
            self._is_rgb = bool(spp and spp > 1)

            if guess_rgb(self._shape):
                # Same convention as the non-OME branch: a shape accepted by
                # guess_rgb is treated as channels-last. Previously this was
                # additionally gated on the OME ``interleaved`` flag, which
                # left _channel_axis/_is_interleaved unset (and crashed
                # below) whenever that flag was false; downstream indexing
                # depends on the actual array layout, so follow the shape.
                self._is_interleaved = True
                self._channel_axis = len(self._shape) - 1
            else:
                self._is_interleaved = False
                self._channel_axis = 0
        else:
            self._is_rgb = guess_rgb(self._shape)
            self._is_interleaved = self._is_rgb
            self._channel_axis = len(self._shape) - 1 if self._is_rgb else 0

        self._n_ch = self._shape[self._channel_axis]

    def _get_dask_image(self) -> da.Array:
        """Open level 0 of the largest series as a dask array.

        A 2D result gets a leading singleton axis. For RGB data that is not
        interleaved, the first axis is moved to the end
        (``da.rollaxis(..., 0, 3)``) so the channels sit last.

        Returns
        -------
        dask.array.Array
            The (possibly reshaped / reordered) image.
        """
        dask_image = tifffile_to_dask(self._path, self.largest_series, level=0)
        if dask_image.ndim == 2:
            dask_image = dask_image.reshape(1, *dask_image.shape)

        if self._is_rgb and not self._is_interleaved:
            dask_image = da.rollaxis(dask_image, 0, 3)

        return dask_image

    def read_reg_image(self) -> None:
        """
        Read and preprocess the image for registration.

        The dask image is run through ``preprocess_dask_array``. When
        ``preprocessing.as_uint8`` is ``True`` and the result is not already
        ``sitkUInt8``, it is intensity-rescaled and cast to ``sitkUInt8``.
        The result is handed to ``self.preprocess_image``; nothing is
        returned.
        """
        reg_image = preprocess_dask_array(self._dask_image, self.preprocessing)

        if (
            self.preprocessing is not None
            and self.preprocessing.as_uint8 is True
            and reg_image.GetPixelID() != sitk.sitkUInt8
        ):
            reg_image = sitk.RescaleIntensity(reg_image)
            reg_image = sitk.Cast(reg_image, sitk.sitkUInt8)

        self.preprocess_image(reg_image)

    def read_single_channel(self, channel_idx: int) -> np.ndarray:
        """
        Read in a single channel for transformation by plane.

        Parameters
        ----------
        channel_idx: int
            Index of the channel to be read. Values past the last channel
            trigger a warning and fall back to channel 0.

        Returns
        -------
        image: np.ndarray
            Numpy array of the selected channel to be read
        """
        if channel_idx > (self.n_ch - 1):
            warnings.warn(
                "channel_idx exceeds number of channels, reading channel at channel_idx == 0"
            )
            channel_idx = 0

        # The dask image is channels-last when it is RGB (non-interleaved RGB
        # is rolled in _get_dask_image) OR when it is interleaved. Checking
        # only _is_rgb sliced interleaved non-RGB data (e.g. OME files with no
        # SamplesPerPixel) along the row axis instead of the channel axis.
        if self._is_rgb or self._is_interleaved:
            image = self._dask_image[:, :, channel_idx].compute()
        else:
            image = self._dask_image[channel_idx, :, :].compute()
        return image