Source code for wsireg.writers.merge_ome_tiff_writer

from pathlib import Path
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import SimpleITK as sitk
from tifffile import TiffWriter

from wsireg.reg_images.reg_image import RegImage
from wsireg.reg_images.merge_reg_image import MergeRegImage
from wsireg.reg_transforms.reg_transform_seq import RegTransformSeq
from wsireg.utils.im_utils import (
    SITK_TO_NP_DTYPE,
    format_channel_names,
    get_pyramid_info,
    prepare_ome_xml_str,
)


[docs] class MergeOmeTiffWriter: x_size: Optional[int] = None y_size: Optional[int] = None y_spacing: Optional[Union[int, float]] = None x_spacing: Optional[Union[int, float]] = None tile_size: int = 512 pyr_levels: Optional[List[Tuple[int, int]]] = None n_pyr_levels: Optional[int] = None PhysicalSizeY: Optional[Union[int, float]] = None PhysicalSizeX: Optional[Union[int, float]] = None subifds: Optional[int] = None compression: str = "deflate" def __init__( self, reg_images: MergeRegImage, reg_transform_seqs: Optional[List[RegTransformSeq]] = None, ): """ Class for writing multiple images wiuth and without transforms to a singel OME-TIFF. Parameters ---------- reg_image: MergeRegImage MergeRegImage to be transformed reg_transform_seqs: List of RegTransformSeq or None Registration transformation sequences for each wsireg image to be merged Attibutes --------- x_size: int Size of the merged image after transformation in x y_size: int Size of the merged image after transformation in y y_spacing: float Pixel spacing in microns after transformation in y x_spacing: float Pixel spacing in microns after transformation in x tile_size: int Size of tiles to be written pyr_levels: list of tuples of int: Size of downsampled images in pyramid n_pyr_levels: int Number of downsamples in pyramid PhysicalSizeY: float physical size of image in micron for OME-TIFF in Y PhysicalSizeX: float physical size of image in micron for OME-TIFF in X subifds: int Number of sub-resolutions for pyramidal OME-TIFF compression: str tifffile string to pass to compression argument, defaults to "deflate" for minisblack and "jpeg" for RGB type images """ self.reg_image = reg_images self.reg_transform_seqs = reg_transform_seqs def _length_checks(self, sub_image_names): """Make sure incoming data is kosher in dimensions""" if isinstance(sub_image_names, list) is False: if sub_image_names is None: sub_image_names = [ "" for i in range(len(self.reg_image.images)) ] else: raise ValueError( "MergeRegImage requires a list of image names for each image to merge" ) if self.reg_transform_seqs is None: transformations = [None for i in range(len(self.reg_image.images))] else: transformations = self.reg_transform_seqs if len(transformations) != len(self.reg_image.images): raise ValueError( "MergeRegImage number of transforms does not match number of images" ) def _create_channel_names(self, sub_image_names): """Create channel names for merge data.""" def prepare_channel_names(sub_image_name, channel_names): return [f"{sub_image_name} - {c}" for c in channel_names] self.reg_image.channel_names = [ prepare_channel_names(im_name, cnames) for im_name, cnames in zip( sub_image_names, self.reg_image.channel_names ) ] self.reg_image.channel_names = [ item for sublist in self.reg_image.channel_names for item in sublist ] def _transform_check(self): """Check that all transforms as currently loaded output to the same size/resolution""" out_size = [] out_spacing = [] if not self.reg_transform_seqs: rts = [None for _ in range(len(self.reg_image.images))] else: rts = self.reg_transform_seqs for im, t in zip(self.reg_image.images, rts): if t: out_size.append(t.reg_transforms[-1].output_size) out_spacing.append(t.reg_transforms[-1].output_spacing) else: out_im_size = ( (im.shape[0], im.shape[1]) if im.is_rgb else (im.shape[1], im.shape[2]) ) out_im_spacing = im.image_res out_size.append(out_im_size) out_spacing.append(out_im_spacing) if all(out_spacing) is False: raise ValueError( "MergeRegImage all transforms output spacings and untransformed image spacings must match" ) if all(out_size) is False: raise ValueError( "MergeRegImage all transforms output sizes and untransformed image sizes must match" ) def _prepare_image_info( self, reg_image: RegImage, image_name: str, im_dtype: np.dtype, reg_transform_seq: Optional[RegTransformSeq] = None, write_pyramid: bool = True, tile_size: int = 512, compression: str = "default", ): """Prepare OME-XML and other data needed for saving""" if reg_transform_seq: self.x_size, self.y_size = reg_transform_seq.output_size self.x_spacing, self.y_spacing = reg_transform_seq.output_spacing else: self.y_size, self.x_size = ( (reg_image.shape[0], reg_image.shape[1]) if reg_image.is_rgb else (reg_image.shape[1], reg_image.shape[2]) ) self.y_spacing, self.x_spacing = ( reg_image.image_res, reg_image.image_res, ) self.tile_size = tile_size # protect against too large tile size while ( self.y_size / self.tile_size <= 1 or self.x_size / self.tile_size <= 1 ): self.tile_size = self.tile_size // 2 self.pyr_levels, _ = get_pyramid_info( self.y_size, self.x_size, reg_image.n_ch, self.tile_size ) self.n_pyr_levels = len(self.pyr_levels) if reg_transform_seq: self.PhysicalSizeY = self.y_spacing self.PhysicalSizeX = self.x_spacing else: self.PhysicalSizeY = reg_image.image_res self.PhysicalSizeX = reg_image.image_res channel_names = format_channel_names( self.reg_image.channel_names, self.reg_image.n_ch ) self.omexml = prepare_ome_xml_str( self.y_size, self.x_size, len(channel_names), im_dtype, False, PhysicalSizeX=self.PhysicalSizeX, PhysicalSizeY=self.PhysicalSizeY, PhysicalSizeXUnit="µm", PhysicalSizeYUnit="µm", Name=image_name, Channel={"Name": channel_names}, ) self.subifds = self.n_pyr_levels - 1 if write_pyramid is True else None if compression == "default": print("using default compression") self.compression = "deflate" else: self.compression = compression def _get_merge_dtype(self): """Determine data type for merger. Will default to the largest dtype. If one image is np.uint8 and another np.uint16, the image at np.uint8 will be cast to np.uint16""" dtype_max_size = [ np.iinfo(r.im_dtype).max for r in self.reg_image.images ] merge_dtype_np = self.reg_image.images[ np.argmax(dtype_max_size) ].im_dtype for k, v in SITK_TO_NP_DTYPE.items(): if k < 12: if v == merge_dtype_np: merge_dtype_sitk = k return merge_dtype_sitk, merge_dtype_np
[docs] def merge_write_image_by_plane( self, image_name: str, sub_image_names: List[str], output_dir: Union[Path, str] = "", write_pyramid: bool = True, tile_size: int = 512, compression: Optional[str] = "default", ) -> str: """ Write merged OME-TIFF image plane-by-plane to disk. RGB images will be de-interleaved with RGB channels written as separate planes. Parameters ---------- image_name: str Name to be written WITHOUT extension for example if image_name = "cool_image" the file would be "cool_image.ome.tiff" sub_image_names: list of str Names added before each channel of a given image to distinguish it. output_dir: Path or str Directory where the image will be saved write_pyramid: bool Whether to write the OME-TIFF with sub-resolutions or not tile_size: int What size to write OME-TIFF tiles to disk compression: str tifffile string to pass to compression argument, defaults to "deflate" for minisblack and "jpeg" for RGB type images Returns ------- output_file_name: str File path to the written OME-TIFF """ merge_dtype_sitk, merge_dtype_np = self._get_merge_dtype() self._length_checks(sub_image_names) self._create_channel_names(sub_image_names) self._transform_check() output_file_name = str(Path(output_dir) / f"{image_name}.ome.tiff") self._prepare_image_info( self.reg_image.images[0], image_name, merge_dtype_np, reg_transform_seq=self.reg_transform_seqs[0], write_pyramid=write_pyramid, tile_size=tile_size, compression=compression, ) print(f"saving to {output_file_name}") with TiffWriter(output_file_name, bigtiff=True) as tif: for m_idx, merge_image in enumerate(self.reg_image.images): if self.reg_image.images[m_idx].reader == "sitk": full_image = sitk.ReadImage( self.reg_image.images[m_idx].image_filepath ) merge_n_ch = merge_image.n_ch for channel_idx in range(merge_n_ch): if self.reg_image.images[m_idx].reader != "sitk": image = self.reg_image.images[ m_idx ].read_single_channel(channel_idx) image = np.squeeze(image) image = sitk.GetImageFromArray(image) image.SetSpacing( ( self.reg_image.images[m_idx].image_res, self.reg_image.images[m_idx].image_res, ) ) else: if len(full_image.GetSize()) > 2: image = full_image[:, :, channel_idx] else: image = full_image if image.GetPixelIDValue() != merge_dtype_sitk: image = sitk.Cast(image, merge_dtype_sitk) if self.reg_transform_seqs[m_idx]: image = self.reg_transform_seqs[ m_idx ].resampler.Execute(image) if isinstance(image, sitk.Image): image = sitk.GetArrayFromImage(image) options = dict( tile=(tile_size, tile_size), compression=self.compression, photometric="minisblack", metadata=None, ) # write OME-XML to the ImageDescription tag of the first page description = ( self.omexml if channel_idx == 0 and m_idx == 0 else None ) # write channel data print( f" writing subimage index {m_idx} : {sub_image_names[m_idx]} - " f"channel index - {channel_idx} - shape: {image.shape}" ) tif.write( image, subifds=self.subifds, description=description, **options, ) if write_pyramid: for pyr_idx in range(1, self.n_pyr_levels): resize_shape = ( self.pyr_levels[pyr_idx][0], self.pyr_levels[pyr_idx][1], ) image = cv2.resize( image, resize_shape, cv2.INTER_LINEAR, ) tif.write(image, **options, subfiletype=1) return output_file_name