Source code for wsireg.writers.tiled_ome_tiff_writer

import multiprocessing
import random
import string
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional, Tuple, Union

import dask.array as da
import numpy as np
import SimpleITK as sitk
import zarr
from tifffile import TiffWriter
from tiler import Tiler
from tqdm import tqdm
from wsireg.reg_images.reg_image import RegImage
from wsireg.reg_transforms.reg_transform_seq import RegTransformSeq
from wsireg.utils.im_utils import (
    format_channel_names,
    get_pyramid_info,
    prepare_ome_xml_str,
)
from wsireg.utils.tform_utils import ELX_TO_ITK_INTERPOLATORS


class OmeTiffTiledWriter:
    """
    Class for transforming, then writing whole slide images tile-by-tile,
    allowing memory-efficient transformation of images.

    The image to be transformed has to have a dask representation that is
    itself tiled because the writer finds the position of each write tile in
    fixed space in moving space and reads only the necessary portion of the
    image to perform the transformation.

    Tiles are stored in a temporary `zarr` store that is deleted after
    writing to OME-TIFF. Uses the Tiler library to manage "virtual tiling"
    for transformation.

    Parameters
    ----------
    reg_image: RegImage
        wsireg RegImage that has a dask store that is chunked in XY
        (typical of WSIs). The .czi reader does not work!
    reg_transform_seq: RegTransformSeq
        wsireg registration transform sequence to be applied to the image
    tile_size: int
        Tile size of the output image
    zarr_tile_size: int
        Tile size used in the zarr intermediate
    moving_tile_padding: int
        How much additional padding to pull from moving for each transformed
        tile. This ensures that the interpolation is correctly performed
        during resampling. Non-rigid transforms may need more padding.

    Attributes
    ----------
    reg_image: RegImage
        RegImage to be transformed
    reg_transform_seq: RegTransformSeq
        RegTransformSeq to be used in transformation
    tile_shape: tuple of ints
        Shape of OME-TIFF tiles going to disk
    zarr_tile_shape: tuple of ints
        Shape of zarr tiles going to disk temporarily
    moving_tile_padding: int
        Tile padding used at read-in for interpolation
    """

    def __init__(
        self,
        reg_image: RegImage,
        reg_transform_seq: RegTransformSeq,
        tile_size: int = 512,
        zarr_tile_size: int = 2048,
        moving_tile_padding: int = 128,
    ):
        self._fixed_tile_positions: List[Tuple[int, int, int, int]] = []
        self._fixed_tile_positions_phys: List[
            Tuple[float, float, float, float]
        ] = []
        self._moving_tile_positions: List[Tuple[int, int, int, int]] = []
        self._moving_tile_positions_phys: List[
            Tuple[float, float, float, float]
        ] = []
        self._tiler: Optional[Tiler] = None

        self.reg_image: RegImage = reg_image
        self.reg_transform_seq: RegTransformSeq = reg_transform_seq
        self.tile_shape = (tile_size, tile_size)
        self.zarr_tile_shape = (zarr_tile_size, zarr_tile_size)
        # Validate chunking before any tiling work is done.
        self._check_dask_array_chunk_sizes(self.reg_image.dask_image)
        self.moving_tile_padding = moving_tile_padding
        self._build_transformation_tiles()

    @property
    def fixed_tile_positions(self) -> List[Tuple[int, int, int, int]]:
        """List of tile positions on the fixed image in pixels
        first np.ndarray is top-left x,y coordinate
        second np.ndarray is bottom-right x,y coordinate"""
        return self._fixed_tile_positions

    @property
    def fixed_tile_positions_phys(
        self,
    ) -> List[Tuple[float, float, float, float]]:
        """List of tile positions on the fixed image in physical coordinate
        space
        first np.ndarray is top-left x,y coordinate
        second np.ndarray is bottom-right x,y coordinate"""
        return self._fixed_tile_positions_phys

    @property
    def moving_tile_positions(self) -> List[Tuple[int, int, int, int]]:
        """Transformed coordinates of fixed tile positions to moving, pixels
        first np.ndarray is top-left x,y coordinate
        second np.ndarray is bottom-right x,y coordinate"""
        return self._moving_tile_positions

    @property
    def moving_tile_positions_phys(
        self,
    ) -> List[Tuple[float, float, float, float]]:
        """Transformed coordinates of fixed tile positions to moving,
        physical
        first np.ndarray is top-left x,y coordinate
        second np.ndarray is bottom-right x,y coordinate"""
        return self._moving_tile_positions_phys

    @property
    def tiler(self) -> Tiler:
        """Tiler instance to manage fixed output tiling from image shape."""
        return self._tiler

    def _create_tiler(self):
        """Create the Tiler instance."""
        self._tiler = Tiler(
            self.reg_transform_seq.output_size,
            self.zarr_tile_shape,
            overlap=0,
            mode="irregular",
        )

    def _check_dask_array_chunk_sizes(self, dask_image: da.Array) -> None:
        """Check if dask image has an acceptable chunk-size for tiled writing.

        Raises
        ------
        ValueError
            If any YX chunk dimension exceeds the zarr tile shape.
        """
        # RGB images are stored YXS, everything else is treated as CYX.
        yx_chunks = (
            dask_image.chunksize[:2]
            if self.reg_image.is_rgb
            else dask_image.chunksize[1:]
        )
        if np.any(np.asarray(yx_chunks) > np.asarray(self.zarr_tile_shape)):
            raise ValueError(
                f"Dask image chunksize for image {str(self.reg_image.path)} "
                "is too large for tiled writing and effectively memory use is not "
                "compared to plane-by-plane writing."
            )

    def _build_transformation_tiles(self):
        """Method to reinitialize tiler if there are changes."""
        self._create_tiler()
        self._get_fixed_tile_positions()
        self._get_fixed_tile_positions_phys()
        self._get_moving_tile_positions()

    def _get_and_clip_fixed_tile(
        self, tile_idx: int, output_size: Tuple[int, int], order=(1, 0)
    ):
        """Method to ensure tiles do not go beyond the output shape of the
        fixed target image.

        ``order`` maps the two bounding-box axes onto indices of
        ``output_size``.
        """
        tile_pos = self.tiler.get_tile_bbox(tile_idx)
        if tile_pos[1][0] > output_size[order[0]]:
            tile_pos[1][0] = output_size[order[0]]
        if tile_pos[1][1] > output_size[order[1]]:
            tile_pos[1][1] = output_size[order[1]]
        return tile_pos

    def _get_fixed_tile_positions(self):
        """Find the tile positions on the fixed image."""
        self._fixed_tile_positions = [
            self._get_and_clip_fixed_tile(
                i, self.reg_transform_seq.output_size, order=[0, 1]
            )
            for i in range(self.tiler.n_tiles)
        ]

    def _get_fixed_tile_positions_phys(self):
        """Fixed tile pixel indices to physical coordinates used in ITK
        transforms."""
        self._fixed_tile_positions_phys = [
            (
                f[0] * self.reg_transform_seq.output_spacing,
                f[1] * self.reg_transform_seq.output_spacing,
            )
            for f in self._fixed_tile_positions
        ]

    def _get_moving_tile_positions(self):
        """Method to transform tile positions in fixed to moving so that each
        write tile in fixed has a corresponding read region in moving."""
        # Start from scratch on every rebuild; otherwise stale entries from a
        # previous tiling would be zipped against the new fixed positions.
        moving_positions_phys = []
        moving_positions_px = []
        reversed_transforms = self.reg_transform_seq.reg_transforms[::-1]

        for fixed_tile_pos in self._fixed_tile_positions_phys:
            corners_phys = []
            corners_px = []
            for corner_idx, corner in enumerate(fixed_tile_pos):
                # Pad on a copy so the stored fixed physical positions stay
                # unpadded: grow the top-left corner outwards (minus) and the
                # bottom-right corner outwards (plus).
                if corner_idx == 0:
                    padded_corner = corner - self.moving_tile_padding
                elif corner_idx == 1:
                    padded_corner = corner + self.moving_tile_padding
                else:
                    padded_corner = corner

                # With no transforms the point maps to itself.
                t_pt = padded_corner.tolist()
                for t in reversed_transforms:
                    t_pt = t.itk_transform.TransformPoint(t_pt)

                t_pt = np.array(t_pt)
                corners_phys.append(t_pt)
                corners_px.append(t_pt / self.reg_image.image_res)

            moving_positions_phys.append(tuple(corners_phys))
            moving_positions_px.append(tuple(corners_px))

        self._moving_tile_positions_phys = moving_positions_phys
        self._moving_tile_positions = moving_positions_px

    def set_output_spacing(
        self, output_spacing: Tuple[Union[int, float], Union[int, float]]
    ) -> None:
        """
        Sets the output spacing of the resampled image and will change output
        shape accordingly

        Parameters
        ----------
        output_spacing: Tuple[Union[int,float], Union[int,float]]
            Spacing of grid for resampling. Will default to target image
            spacing
        """
        self.reg_transform_seq.set_output_spacing(output_spacing)
        self._build_transformation_tiles()

    def set_tile_size(self, tile_size: int) -> None:
        """
        Set the internal tile size of the OME-TIFF to be written.

        Parameters
        ----------
        tile_size: int
            tile size in pixels in x and y for the OME-TIFF
        """
        self.tile_shape = (tile_size, tile_size)

    def set_zarr_tile_size(self, tile_size: int) -> None:
        """
        Set the tile size for the zarr intermediate.

        Parameters
        ----------
        tile_size: int
            tile size in pixels in x and y for the temporary zarr store
        """
        self.zarr_tile_shape = (tile_size, tile_size)
        self._build_transformation_tiles()

    def _create_tile_resampler(
        self, tile_origin: Tuple[float, float]
    ) -> sitk.ResampleImageFilter:
        """
        Build each tile's resampler.

        Parameters
        ----------
        tile_origin: Tuple[float, float]
            Position of the tile in physical coordinates

        Returns
        -------
        resampler: sitk.ResampleImageFilter
            resampler for an individual fixed tile
        """
        final_transform = self.reg_transform_seq.reg_transforms[-1]

        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputOrigin(tile_origin)
        resampler.SetOutputDirection(final_transform.output_direction)
        resampler.SetSize(self.zarr_tile_shape)
        resampler.SetOutputSpacing(self.reg_transform_seq.output_spacing)
        interpolator = ELX_TO_ITK_INTERPOLATORS.get(
            final_transform.resample_interpolator
        )
        resampler.SetInterpolator(interpolator)
        resampler.SetTransform(self.reg_transform_seq.composite_transform)
        return resampler

    def write_tiles_to_zarr_store(
        self,
        temp_zarr_store: zarr.TempStore,
        max_workers: Optional[int] = None,
    ):
        """
        Write tiles to a temporary zarr store.

        Parameters
        ----------
        temp_zarr_store: zarr.TempStore
            Temporary store where the dataset will go.
        max_workers: int, optional
            Number of worker threads. ``1`` runs sequentially; ``None`` uses
            ``multiprocessing.cpu_count()`` threads.

        Returns
        -------
        resample_zarray: zarr.Array
            zarr store contained transformed images
        """
        zgrp = zarr.open(temp_zarr_store)
        out_x, out_y = (
            self.reg_transform_seq.output_size[0],
            self.reg_transform_seq.output_size[1],
        )

        if self.reg_image.is_rgb:
            resample_zarray = zgrp.create_dataset(
                random_str(),
                shape=(out_y, out_x, self.reg_image.shape[-1]),
                chunks=self.tile_shape,
                dtype=self.reg_image.im_dtype,
            )
        else:
            resample_zarray = zgrp.create_dataset(
                random_str(),
                shape=(self.reg_image.n_ch, out_y, out_x),
                chunks=(1,) + self.tile_shape,
                dtype=self.reg_image.im_dtype,
            )

        self._transform_write_tile_set(
            resample_zarray, max_workers=max_workers
        )
        return resample_zarray

    def _transform_write_tile(self, data):
        """Worker function to transform and place tile in zarr store.

        ``data`` is the tuple built by ``_transform_write_tile_set``:
        (zarr array, channel index, fixed tile pixel bbox, fixed tile
        physical origin, moving tile pixel corners).
        """
        (
            resample_zarray,
            ch_idx,
            fixed_tile_position,
            fixed_tile_origin,
            moving_tile_corners,
        ) = data

        tile_resampler = self._create_tile_resampler(fixed_tile_origin)
        x_size, y_size = self._get_image_size()
        x_max, x_min, y_max, y_min = self._get_moving_tile_slice(
            moving_tile_corners, x_size, y_size
        )
        tile_resampled = self._resample_tile(
            ch_idx, tile_resampler, x_max, x_min, y_max, y_min
        )
        # Nothing to write when the read region lies outside the moving image.
        if tile_resampled is None:
            return

        (
            x_max_fixed,
            x_min_fixed,
            y_max_fixed,
            y_min_fixed,
        ) = self._get_fixed_slice(fixed_tile_position)
        x_max, y_max = self._correct_end_moving_slices(
            x_max_fixed,
            x_min_fixed,
            y_max_fixed,
            y_min_fixed,
        )
        tile_array = sitk.GetArrayFromImage(tile_resampled)
        if self.reg_image.is_rgb:
            resample_zarray[
                y_min_fixed:y_max_fixed, x_min_fixed:x_max_fixed, :
            ] = tile_array[:y_max, :x_max, :]
        else:
            resample_zarray[
                ch_idx, y_min_fixed:y_max_fixed, x_min_fixed:x_max_fixed
            ] = tile_array[:y_max, :x_max]

    def _get_image_size(self) -> Tuple[int, int]:
        """Get moving image size for tile delineation"""
        # RGB images are YXS; other images are CYX.
        if self.reg_image.is_rgb:
            y_size, x_size = self.reg_image.shape[0], self.reg_image.shape[1]
        else:
            y_size, x_size = self.reg_image.shape[1], self.reg_image.shape[2]
        return x_size, y_size

    def _resample_tile(
        self,
        ch_idx: int,
        tile_resampler: sitk.ResampleImageFilter,
        x_max: int,
        x_min: int,
        y_max: int,
        y_min: int,
    ) -> Optional[sitk.Image]:
        """Resample tile or don't if it is outside of the moving image space.

        Returns ``None`` when the clamped read region is empty along either
        axis, i.e. the tile falls entirely outside the moving image.
        """
        # Regions clamped to either image edge collapse to zero width/height;
        # there are no source pixels to resample from in that case.
        if x_min == x_max or y_min == y_max:
            return None

        if self.reg_image.is_rgb:
            image = self.reg_image.dask_image[y_min:y_max, x_min:x_max, :]
            image = sitk.GetImageFromArray(image, isVector=True)
        elif self.reg_image.n_ch == 1:
            image = da.squeeze(self.reg_image.dask_image)[
                y_min:y_max, x_min:x_max
            ]
            image = sitk.GetImageFromArray(image, isVector=False)
        else:
            image = self.reg_image.dask_image[ch_idx, y_min:y_max, x_min:x_max]
            image = sitk.GetImageFromArray(image, isVector=False)

        image.SetSpacing((self.reg_image.image_res, self.reg_image.image_res))
        # Place the cropped region at its true location in moving space.
        image.SetOrigin(
            image.TransformIndexToPhysicalPoint([int(x_min), int(y_min)])
        )
        return tile_resampler.Execute(image)

    def _correct_end_moving_slices(
        self,
        x_max_fixed: int,
        x_min_fixed: int,
        y_max_fixed: int,
        y_min_fixed: int,
    ) -> Tuple[int, int]:
        """Correct tiles that extend past the size of the fixed coordinate
        space by returning the usable (x, y) extent of the tile."""
        y_max = y_max_fixed - y_min_fixed
        x_max = x_max_fixed - x_min_fixed
        return x_max, y_max

    def _get_fixed_slice(
        self, fixed_tile_position: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[int, int, int, int]:
        """Get tile slice in fixed tile pixels."""
        top_left, bottom_right = fixed_tile_position[0], fixed_tile_position[1]
        x_min_fixed, y_min_fixed = top_left[0], top_left[1]
        x_max_fixed, y_max_fixed = bottom_right[0], bottom_right[1]
        return x_max_fixed, x_min_fixed, y_max_fixed, y_min_fixed

    def _get_moving_tile_slice(
        self,
        moving_tile_corners: Tuple[np.ndarray, np.ndarray],
        x_size: int,
        y_size: int,
    ) -> Tuple[int, int, int, int]:
        """Get tile slice in moving tile pixels."""

        def _clamp_ceil(value, upper):
            # Clamp to [0, upper], then round up to a whole pixel index.
            value = value if value >= 0 else 0
            value = value if value <= upper else upper
            return np.ceil(value).astype(int)

        x_min = _clamp_ceil(moving_tile_corners[0][0], x_size)
        x_max = _clamp_ceil(moving_tile_corners[1][0], x_size)
        y_min = _clamp_ceil(moving_tile_corners[0][1], y_size)
        y_max = _clamp_ceil(moving_tile_corners[1][1], y_size)

        # catch changing positions of x and y when there are coordinate flips
        if y_min > y_max:
            y_min, y_max = y_max, y_min
        if x_min > x_max:
            x_min, x_max = x_max, x_min

        return x_max, x_min, y_max, y_min

    def _transform_write_tile_set(
        self, resample_zarray: zarr.Array, max_workers: Optional[int] = None
    ):
        """Function to loop over all channels and tile positions and write to
        zarr"""
        use_multiprocessing = max_workers != 1
        if use_multiprocessing and not max_workers:
            max_workers = multiprocessing.cpu_count()

        n_ch = 1 if self.reg_image.is_rgb else self.reg_image.n_ch
        # Per-axis spacing, matching how the fixed physical positions are
        # computed in _get_fixed_tile_positions_phys.
        output_spacing = np.asarray(self.reg_transform_seq.output_spacing)

        all_tile_args = []
        for ch_idx in range(n_ch):
            for ft_pos, mt_pos in tqdm(
                zip(
                    self._fixed_tile_positions,
                    self._moving_tile_positions,
                ),
                total=len(self._fixed_tile_positions),
                desc="Writing zarr tiles",
                unit=" tile",
                disable=use_multiprocessing,
            ):
                tile_origin = ft_pos[0] * output_spacing
                tile_args = (
                    resample_zarray,
                    ch_idx,
                    ft_pos,
                    tuple(tile_origin.astype(float)),
                    mt_pos,
                )
                all_tile_args.append(tile_args)
                if not use_multiprocessing:
                    self._transform_write_tile(tile_args)

        if use_multiprocessing:
            with ThreadPoolExecutor(max_workers) as executor:
                # list() drains the iterator so worker exceptions surface.
                _ = list(
                    tqdm(
                        executor.map(
                            self._transform_write_tile, all_tile_args
                        ),
                        total=len(all_tile_args),
                        desc="Writing zarr tiles",
                        unit=" tile",
                    )
                )

    def _prepare_image_info(
        self,
        image_name,
        write_pyramid=True,
    ):
        """Prepare info for pyramidalization and create OME-TIFF.

        Returns
        -------
        tuple
            ``(n_pyr_levels, subifds, out_tile_shape, omexml)``
        """
        x_size, y_size = self.reg_transform_seq.output_size
        x_spacing, y_spacing = self.reg_transform_seq.output_spacing

        out_tile_shape = self.tile_shape
        # protect against too large tile size
        # NOTE(review): write_image_by_tile currently writes with
        # self.tile_shape and does not consume out_tile_shape -- confirm
        # whether the reduced shape was meant to be used there.
        while (
            y_size / out_tile_shape[0] <= 1
            or x_size / out_tile_shape[0] <= 1
        ):
            out_tile_shape = (out_tile_shape[0] // 2, out_tile_shape[1] // 2)

        pyr_levels, _ = get_pyramid_info(
            y_size, x_size, self.reg_image.n_ch, self.tile_shape[0]
        )
        n_pyr_levels = len(pyr_levels)

        channel_names = format_channel_names(
            self.reg_image.channel_names, self.reg_image.n_ch
        )

        omexml = prepare_ome_xml_str(
            y_size,
            x_size,
            self.reg_image.n_ch,
            self.reg_image.im_dtype,
            self.reg_image.is_rgb,
            PhysicalSizeX=x_spacing,
            PhysicalSizeY=y_spacing,
            PhysicalSizeXUnit="µm",
            PhysicalSizeYUnit="µm",
            Name=image_name,
            Channel=None if self.reg_image.is_rgb else {"Name": channel_names},
        )

        subifds = n_pyr_levels - 1 if write_pyramid is True else None

        return n_pyr_levels, subifds, out_tile_shape, omexml

    def _transformed_tile_generator(self, d_array: da.Array, ch_idx: int):
        """Create generator of tifffile tiles for OME-TIFF."""
        tile_y, tile_x = self.tile_shape[0], self.tile_shape[1]
        out_shape = (
            d_array.shape[:2] if self.reg_image.is_rgb else d_array.shape[1:]
        )
        # Tiles are yielded row-major, one computed numpy tile at a time.
        for y in range(0, out_shape[0], tile_y):
            for x in range(0, out_shape[1], tile_x):
                if self.reg_image.is_rgb:
                    yield d_array[
                        y : y + tile_y,
                        x : x + tile_x,
                        :,
                    ].compute()
                else:
                    yield d_array[
                        ch_idx,
                        y : y + tile_y,
                        x : x + tile_x,
                    ].compute()

    def write_image_by_tile(
        self,
        image_name: str,
        output_dir: Union[Path, str] = "",
        write_pyramid: bool = True,
        compression: Optional[str] = "default",
        zarr_temp_dir: Optional[Union[str, Path]] = None,
        max_workers: Optional[int] = None,
    ) -> Optional[str]:
        """
        Write images to OME-TIFF from temp zarr store with data.

        Parameters
        ----------
        image_name: str
            file path stem of the image to be written
        output_dir: Union[str,Path]
            directory where image is to be written
        write_pyramid: bool
            whether to write a pyramid or single layer
        compression: str
            Use compression. "default" will be lossless "deflate" for non-rgb
            images and "jpeg" for RGB images
        zarr_temp_dir: Path or str
            Directory to store the temporary zarr data
            (mostly used for debugging)
        max_workers: int, optional
            Forwarded to ``write_tiles_to_zarr_store``. ``1`` runs
            sequentially; ``None`` uses one thread per CPU.

        Returns
        -------
        output_file_name: str or None
            Path to written image file. ``None`` if writing failed; the
            exception is printed and the temporary store is still cleared.
        """
        zstr = zarr.TempStore(dir=zarr_temp_dir)
        try:
            resample_zarray = self.write_tiles_to_zarr_store(
                zstr, max_workers=max_workers
            )
            output_file_name = str(Path(output_dir) / f"{image_name}.ome.tiff")

            if compression == "default":
                print("using default compression")
                compression = "jpeg" if self.reg_image.is_rgb else "deflate"

            (
                n_pyr_levels,
                subifds,
                out_tile_shape,
                omexml,
            ) = self._prepare_image_info(
                image_name, write_pyramid=write_pyramid
            )

            print(f"saving to {output_file_name}")
            dask_image = da.from_zarr(resample_zarray)
            options = dict(
                tile=self.tile_shape,
                compression=compression,
                photometric="rgb" if self.reg_image.is_rgb else "minisblack",
                metadata=None,
            )

            with TiffWriter(output_file_name, bigtiff=True) as tif:
                if self.reg_image.is_rgb:
                    print(
                        f"writing base layer RGB - shape: {dask_image.shape}"
                    )
                    tile_iterator = self._transformed_tile_generator(
                        dask_image, 0
                    )
                    tif.write(
                        tile_iterator,
                        subifds=subifds,
                        description=omexml,
                        shape=dask_image.shape,
                        dtype=dask_image.dtype,
                        **options,
                    )
                    if write_pyramid:
                        for pyr_idx in range(1, n_pyr_levels):
                            sub_res = compute_sub_res(
                                dask_image,
                                pyr_idx,
                                self.tile_shape[0],
                                self.reg_image.is_rgb,
                                self.reg_image.im_dtype,
                            )
                            print(
                                f"pyr {pyr_idx} : RGB-shape: {sub_res.shape}"
                            )
                            sub_res_tile_iterator = (
                                self._transformed_tile_generator(sub_res, 0)
                            )
                            tif.write(
                                sub_res_tile_iterator,
                                shape=sub_res.shape,
                                dtype=self.reg_image.im_dtype,
                                **options,
                                subfiletype=1,
                            )
                else:
                    for channel_idx in range(self.reg_image.n_ch):
                        # The OME-XML is attached to the first plane only.
                        description = omexml if channel_idx == 0 else None
                        print(
                            f"writing channel {channel_idx} - shape: {dask_image.shape[1:]}"
                        )
                        tile_iterator = self._transformed_tile_generator(
                            dask_image, channel_idx
                        )
                        tif.write(
                            tile_iterator,
                            subifds=subifds,
                            description=description,
                            shape=dask_image.shape[1:],
                            dtype=dask_image.dtype,
                            **options,
                        )
                        if write_pyramid:
                            for pyr_idx in range(1, n_pyr_levels):
                                sub_res = compute_sub_res(
                                    dask_image,
                                    pyr_idx,
                                    self.tile_shape[0],
                                    self.reg_image.is_rgb,
                                    self.reg_image.im_dtype,
                                )
                                sub_res_tile_iterator = (
                                    self._transformed_tile_generator(
                                        sub_res, channel_idx
                                    )
                                )
                                tif.write(
                                    sub_res_tile_iterator,
                                    shape=sub_res.shape[1:],
                                    dtype=dask_image.dtype,
                                    **options,
                                    subfiletype=1,
                                )
            return output_file_name
        except Exception as e:
            # Failures are reported rather than raised; the caller gets None.
            print(e)
            return None
        finally:
            # Always clear temporary storage. Clearing the store object itself
            # also works when the zarr array was never created.
            try:
                zstr.clear()
            except FileNotFoundError:
                pass
def compute_sub_res(
    zarray: da.Array,
    pyr_level: int,
    tile_size: int,
    is_rgb: bool,
    im_dtype: np.dtype,
) -> da.Array:
    """
    Compute factor-of-2 sub-resolutions from dask array for pyramidalization
    using dask.

    Parameters
    ----------
    zarray: da.Array
        Dask array to be downsampled
    pyr_level: int
        level of the pyramid. 0 = base, 1 = 2x downsampled,
        2 = 4x downsampled...
    tile_size: int
        Size of tiles in dask array after downsampling
    is_rgb: bool
        whether dask array is RGB interleaved
    im_dtype: np.dtype
        dtype of the output da.Array

    Returns
    -------
    resampled_zarray_subres: da.Array
        Dask array (unprocessed) to be written
    """
    factor = 2**pyr_level

    if is_rgb:
        # Interleaved layout (Y, X, S): downsample Y and X, keep samples.
        resampling_axis = {0: factor, 1: factor, 2: 1}
        # Keep all samples of a pixel in one chunk, whatever their count.
        tiling = (tile_size, tile_size, zarray.shape[-1])
    else:
        # Planar layout (C, Y, X): keep channels, downsample Y and X.
        resampling_axis = {0: 1, 1: factor, 2: factor}
        tiling = (1, tile_size, tile_size)

    # trim_excess drops edge pixels that do not fill a whole block.
    resampled_zarray_subres = da.coarsen(
        np.mean,
        zarray,
        resampling_axis,
        trim_excess=True,
    )
    # np.mean yields floats; cast back to the image dtype before rechunking.
    resampled_zarray_subres = resampled_zarray_subres.astype(im_dtype)
    resampled_zarray_subres = resampled_zarray_subres.rechunk(tiling)

    return resampled_zarray_subres
def random_str() -> str:
    """Return a random 10-character lowercase ASCII name for the zarr array."""
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(10):
        picked.append(random.choice(alphabet))
    return ''.join(picked)