diff --git a/README.md b/README.md index 5ef9a9f..2970f58 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ > Tiled processing of arbitrarily large images — any image, any function. -``` +```text ┌──────┬──────┬──────┐ fn(tile) → labels ┌──────┬──────┬──────┐ │ tile │ tile │ tile │ ─────────────────────► │ 1 │ 2 │ 3 │ ├──────┼──────┼──────┤ ├──────┼──────┼──────┤ @@ -295,6 +295,7 @@ Full docs, guides and tutorials: **** - dask[array], numpy, zarr, scipy Optional: + - `psutil` — accurate RAM sizing for `tile_shape="auto"` - `nvidia-ml-py` — accurate GPU VRAM sizing - `tqdm` — progress bars diff --git a/docs/examples/stardist.md b/docs/examples/stardist.md index 795b043..5e641aa 100644 --- a/docs/examples/stardist.md +++ b/docs/examples/stardist.md @@ -47,18 +47,18 @@ tile_process( Load the model **outside** the `fn` closure. If you load it inside, it will be re-initialised (and potentially re-downloaded) once per tile. - For distributed execution, use `functools.partial` with a cached model: +For distributed execution, use `functools.partial` with a cached model: - ```python - from functools import lru_cache +```python +from functools import lru_cache - @lru_cache(maxsize=1) - def _get_model(): - return StarDist2D.from_pretrained("2D_versatile_fluo") +@lru_cache(maxsize=1) +def _get_model(): + return StarDist2D.from_pretrained("2D_versatile_fluo") - def stardist_fn(tile): - model = _get_model() - ... - ``` +def stardist_fn(tile): + model = _get_model() + ... +``` diff --git a/docs/getting_started.md b/docs/getting_started.md index bccd506..d026fd3 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -46,11 +46,11 @@ patchworks can be installed from PyPI on all operating systems, for Python ≥ 3 ## The one function you need -```python -from patchworks import tile_process + ```python + from patchworks import tile_process -result = tile_process(image, fn) -``` + result = tile_process(image, fn) + ``` `tile_process(image, fn)` splits `image` into tiles, runs `fn` on each tile, and returns a globally consistent label array. @@ -65,17 +65,17 @@ and returns a globally consistent label array. patchworks is method-agnostic. Your function receives a NumPy array (one tile) and must return an integer label array of the same shape: -```python -import numpy as np + ```python + import numpy as np -def my_fn(tile: np.ndarray) -> np.ndarray: - from skimage.filters import threshold_otsu - from skimage.measure import label + def my_fn(tile: np.ndarray) -> np.ndarray: + from skimage.filters import threshold_otsu + from skimage.measure import label - binary = tile > threshold_otsu(tile) - return label(binary).astype("int32") -``` + binary = tile > threshold_otsu(tile) + return label(binary).astype("int32") + ``` The function is called independently on every tile. patchworks ensures that objects spanning tile boundaries are merged into a single label. @@ -155,14 +155,14 @@ objects spanning tile boundaries are merged into a single label. Methods like Cellpose and StarDist need spatial context at tile boundaries. Use `overlap` (in voxels) so boundary objects are fully visible: -```python -result = tile_process( - "image.zarr", - my_fn, - tile_shape=(1, 2048, 2048), - overlap=20, # 20-voxel halo on every side -) -``` + ```python + result = tile_process( + "image.zarr", + my_fn, + tile_shape=(1, 2048, 2048), + overlap=20, # 20-voxel halo on every side + ) + ``` !!! info "How overlap works" Each tile is expanded by `overlap` voxels on every side before calling `fn`. @@ -173,22 +173,22 @@ result = tile_process( ## Use Cellpose -```python -from patchworks import tile_process -from patchworks.plugins.cellpose import cellpose_fn - -fn = cellpose_fn("cyto3", gpu=True, diameter=30) - -tile_process( - "image.zarr", - fn, - channel=0, - tile_shape=(1, 2048, 2048), - overlap=20, - write_to="labels.zarr", - progress=True, -) -``` + ```python + from patchworks import tile_process + from patchworks.plugins.cellpose import cellpose_fn + + fn = cellpose_fn("cyto3", gpu=True, diameter=30) + + tile_process( + "image.zarr", + fn, + channel=0, + tile_shape=(1, 2048, 2048), + overlap=20, + write_to="labels.zarr", + progress=True, + ) + ``` See the [Cellpose 2-D example](examples/cellpose_2d.md) for the full workflow. diff --git a/docs/guide/gpu_distributed.md b/docs/guide/gpu_distributed.md index a3ff925..44d0dc4 100644 --- a/docs/guide/gpu_distributed.md +++ b/docs/guide/gpu_distributed.md @@ -60,7 +60,7 @@ in the same process as the kernel. When your segmentation function holds the Python GIL (every PyTorch/CUDA `eval` does), the worker thread can't send heartbeats. The scheduler declares it dead, and the merge fails: -``` +```python FutureCancelledError: lost dependencies ``` diff --git a/docs/guide/merging.md b/docs/guide/merging.md index bf9fb07..d206133 100644 --- a/docs/guide/merging.md +++ b/docs/guide/merging.md @@ -9,7 +9,7 @@ even though it's the same cell. patchworks solves this with a zarr-native merge algorithm: -``` +```text Tile A labels: Tile B labels: After merge: ┌────────────┐ ┌────────────┐ ┌──────────────────────┐ │ 3 1 2 │ │ 1 4 2 │ │ 3 1 2 │ 501 5 502│ @@ -32,7 +32,7 @@ Each tile's labels are written to a temporary zarr once. This is critical: without staging, any downstream operation that reads the label array re-runs your segmentation function. The merge internally reads labels multiple times. -``` +```text tile_process calls fn once per tile → staged zarr │ merge reads from staged zarr (no fn calls) diff --git a/docs/guide/ome_zarr_napari.md b/docs/guide/ome_zarr_napari.md index 1291c92..5ee97af 100644 --- a/docs/guide/ome_zarr_napari.md +++ b/docs/guide/ome_zarr_napari.md @@ -76,6 +76,28 @@ and streaming the downsampled result out through dask with bounded chunks. The graph never chains level-on-level and no whole plane/volume is held in RAM, so terabyte images convert in bounded memory. +### Sharding (fewer files) + +A big array becomes tens of thousands of tiny chunk files, which strain +filesystems and object stores. Sharding packs many chunks into one **shard** +file (zarr v3), cutting the file count ~100×: + +```python +to_ome_zarr("scan.ims", "scan.zarr", shard=True) # auto ~512 MB shards +to_ome_zarr("scan.ims", "scan.zarr", shard=(1, 16, 2048, 2048)) # explicit +``` + +Default is `shard=False` for maximum reader compatibility — sharding is +zarr-v3-only, so older tools may not read it (your zarr/napari stack does). +A sharded write holds ~one shard per worker in RAM, so very large shards cost +memory. + +### Progress + +All write steps show a dask progress bar **by default** (`progress=True`), so +you can see how long a conversion will take. Pass `progress=False` to silence +it. + !!! note "Install the readers you need" `pip install "patchworks[bioio]"` pulls `bioio` plus the `bioio-bioformats` catch-all reader (needs a JVM). For speed, add native readers for your diff --git a/docs/guide/pitfalls.md b/docs/guide/pitfalls.md index b3e8cbb..db54d70 100644 --- a/docs/guide/pitfalls.md +++ b/docs/guide/pitfalls.md @@ -30,7 +30,7 @@ single-GPU runs — patchworks pins it to 1 thread automatically). patchworks detects in-process clients at startup and raises immediately: -``` +```python RuntimeError: Active Dask client uses an in-process worker (processes=False). This breaks the label merge when fn holds the GIL. Use a process-based cluster instead: diff --git a/docs/guide/skip_empty.md b/docs/guide/skip_empty.md index 98b0b5a..5e22023 100644 --- a/docs/guide/skip_empty.md +++ b/docs/guide/skip_empty.md @@ -88,6 +88,6 @@ tile_process( After a `tile_process` run with `skip_empty=True`, the log reports exactly how many tiles ran your function: -``` +```text INFO patchworks._core: skip_empty: 486/2200 tiles ran fn, 1714 skipped (max<=412.0) ``` diff --git a/docs/guide/tiling.md b/docs/guide/tiling.md index b2e2b7f..9049e52 100644 --- a/docs/guide/tiling.md +++ b/docs/guide/tiling.md @@ -13,6 +13,7 @@ peak RAM during segmentation is approximately one tile's worth of data. ## Choosing a tile size The right tile size depends on: + - Your available RAM (or GPU VRAM) - The minimum context your segmentation method needs (objects should fit fully inside a tile, or you need overlap) @@ -62,7 +63,7 @@ Methods that need spatial context (Cellpose, StarDist, U-Net) produce wrong results near tile edges: objects at the boundary are cut off. Overlap fixes this by expanding each tile by `overlap` voxels on every side. -``` +```text No overlap: With overlap=20: ┌──────────┐ ┌──────────────────┐ │ │ │ ░░░░░░░░░░░░░░ │ @@ -86,4 +87,4 @@ No overlap: With overlap=20: automatically clips the depth per axis, so z-tiles of size 1 (typical in 2-D Cellpose mode) get `depth=0` in z even if you pass `overlap=20`. - Axes that are too small for the requested overlap simply get a smaller halo. + Axes that are too small for the requested overlap simply get a smaller halo. diff --git a/docs/index.md b/docs/index.md index 2d1bfd4..f5bd9bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,7 +2,7 @@ **Tiled processing of arbitrarily large images — any image, any function.** -``` +```text ┌──────┬──────┬──────┐ ┌──────┬──────┬──────┐ │ │ │ │ fn(tile) → IDs │ 1 │ 2 │ 3 │ │ │ │ │ ───────────────► │ │ │ │ diff --git a/src/patchworks/_chunks.py b/src/patchworks/_chunks.py index 0659c3a..de353d6 100644 --- a/src/patchworks/_chunks.py +++ b/src/patchworks/_chunks.py @@ -49,6 +49,14 @@ def auto_overlap(diameter: float, safety: float = 1.0) -> int: def _get_available_memory() -> int: + """Return available system RAM in bytes. + + Returns + ------- + int + Available memory via ``psutil``, or an 8 GiB fallback if it is not + installed. + """ try: import psutil @@ -103,7 +111,14 @@ def safe_worker_count( def _get_gpu_memory() -> int: - """Return free GPU VRAM in bytes. Falls back to 8 GiB default.""" + """Return free GPU VRAM in bytes. + + Returns + ------- + int + Free VRAM of GPU 0 via ``nvidia-ml-py``, or an 8 GiB fallback if the + query fails. + """ try: import pynvml diff --git a/src/patchworks/_cluster.py b/src/patchworks/_cluster.py index 583121c..71b75d0 100644 --- a/src/patchworks/_cluster.py +++ b/src/patchworks/_cluster.py @@ -9,7 +9,14 @@ def _distributed_client(): - """Return the active dask.distributed Client, or None.""" + """Return the active dask.distributed Client, or None. + + Returns + ------- + distributed.Client or None + The current client, or ``None`` if none is active / distributed is not + installed. + """ try: from dask.distributed import get_client @@ -19,12 +26,22 @@ def _distributed_client(): def _client_is_in_process(client) -> bool: - """True if *client* runs its worker in this process (processes=False). + """Whether *client* runs its worker in this process (``processes=False``). An in-process worker shares the GIL. A long task that holds the GIL (e.g. a Cellpose/torch eval) starves the worker heartbeat, the scheduler declares it dead, and the P2P merge barrier drops its inputs → "FutureCancelledError: lost dependencies". + + Parameters + ---------- + client : distributed.Client + The client to inspect. + + Returns + ------- + bool + True if any worker address uses the ``inproc://`` transport. """ try: for addr in client.scheduler_info().get("workers", {}): diff --git a/src/patchworks/_core.py b/src/patchworks/_core.py index c4077a3..f8a4793 100644 --- a/src/patchworks/_core.py +++ b/src/patchworks/_core.py @@ -23,7 +23,23 @@ def _stage_to_zarr( arr: da.Array, path: str, component: str, show_progress: bool ) -> None: - """Write *arr* to zarr *path/component*, never loading it into RAM.""" + """Write *arr* to zarr ``path/component``, never loading it into RAM. + + Parameters + ---------- + arr : da.Array + Array to materialise to disk. + path : str + Zarr store path. + component : str + Array name within the store. + show_progress : bool + Show a progress bar while computing. + + Returns + ------- + None + """ import dask lazy_write = arr.to_zarr( @@ -57,7 +73,7 @@ def tile_process( level: int = 0, use_gpu: bool = False, max_workers: int | None = None, - progress: bool = False, + progress: bool = True, write_to: Union[str, Path, None] = None, output_component: str = "labels", pyramid_levels: int = 5, @@ -123,7 +139,8 @@ def tile_process( every core. Ignored when a distributed client is active (it manages its own concurrency). progress: - Show a progress bar during the tile-writing and relabel steps. + Show progress bars for staging, the label write and the pyramid + (default ``True``). Set ``False`` to silence them. write_to: Explicit output zarr store path. Overrides the default behaviour: the merged labels are written here as a single-resolution array named @@ -297,6 +314,20 @@ def tile_process( _skip_thr = _auto_empty_threshold(image_for_threshold, channel, level) def active_fn(block, block_info=None): + """Run *fn* on one tile, or return zeros for an empty tile. + + Parameters + ---------- + block : np.ndarray + One image tile. + block_info : dict or None + Dask block metadata (used for logging the tile location). + + Returns + ------- + np.ndarray + Integer labels, or an all-zero tile when skipped. + """ loc = block_info[0].get("chunk-location") if block_info else "?" if skip_empty and block.size and block.max() <= _skip_thr: if verbose: @@ -398,6 +429,12 @@ def active_fn(block, block_info=None): # that figure instead. def _cleanup_stage(): + """Delete the temporary stage store unless ``keep_stage`` is set. + + Returns + ------- + None + """ if not keep_stage: import shutil @@ -457,6 +494,7 @@ def _cleanup_stage(): name=output_component, n_levels=pyramid_levels, downscale=pyramid_downscale, + progress=progress, overwrite=True, ) shutil.rmtree(os.path.dirname(_merge_out), ignore_errors=True) diff --git a/src/patchworks/_io.py b/src/patchworks/_io.py index a364467..6f0ed78 100644 --- a/src/patchworks/_io.py +++ b/src/patchworks/_io.py @@ -75,6 +75,16 @@ def _otsu_threshold(sample: np.ndarray) -> float: Operates on the full distribution including zeros — zeros are background pixels and must be included so Otsu can find the signal/background boundary. + + Parameters + ---------- + sample : np.ndarray + Flat intensity sample. + + Returns + ------- + float + The Otsu threshold, or ``0.0`` when the sample is degenerate. """ try: from skimage.filters import threshold_otsu @@ -89,7 +99,22 @@ def _otsu_threshold(sample: np.ndarray) -> float: def _auto_empty_threshold( image: da.Array, channel: int | None, level: int ) -> float: - """Pick an empty-tile threshold from a cheap bounded sample (Otsu).""" + """Pick an empty-tile threshold from a cheap bounded sample (Otsu). + + Parameters + ---------- + image : da.Array + Image to sample. + channel : int or None + Channel hint (kept for signature symmetry). + level : int + Pyramid level hint (kept for signature symmetry). + + Returns + ------- + float + Otsu threshold over a few small centred windows. + """ n = image.ndim win = [min(64 if i >= n - 3 else s, s) for i, s in enumerate(image.shape)] win = [min(w, 256) if i >= n - 2 else w for i, w in enumerate(win)] diff --git a/src/patchworks/_merge.py b/src/patchworks/_merge.py index e0eed43..109c1be 100644 --- a/src/patchworks/_merge.py +++ b/src/patchworks/_merge.py @@ -55,6 +55,25 @@ def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp): + """Initialise a merge worker process with the shared paths and LUT. + + Parameters + ---------- + lut_path : str + Path to the relabel lookup table (loaded memory-mapped, read-only). + staged_path : str + Path to the staged-labels zarr store. + staged_comp : str + Component name within the staged store. + out_path : str + Path to the output zarr store. + out_comp : str + Component name within the output store. + + Returns + ------- + None + """ global _merge_lut, _merge_lut_path, _merge_staged_path, _merge_staged_comp global _merge_out_path, _merge_out_comp _merge_lut = np.load( @@ -68,6 +87,17 @@ def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp): def _relabel_chunk_worker(chunk_slice: tuple) -> None: + """Apply the relabel LUT to one chunk and write it to the output store. + + Parameters + ---------- + chunk_slice : tuple + The slice selecting this chunk in both stores. + + Returns + ------- + None + """ src = zarr.open_group(_merge_staged_path, mode="r")[_merge_staged_comp] dst = zarr.open_group(_merge_out_path, mode="r+")[_merge_out_comp] block = np.asarray(src[chunk_slice], dtype=np.int64) @@ -87,6 +117,20 @@ def _relabel_chunk_worker(chunk_slice: tuple) -> None: def _boundary_face_specs( shape: tuple[int, ...], chunk_shape: tuple[int, ...] ) -> list[tuple[int, int]]: + """Enumerate interior chunk boundaries to scan for touching labels. + + Parameters + ---------- + shape : tuple of int + Array shape. + chunk_shape : tuple of int + Chunk shape. + + Returns + ------- + list of tuple of int + ``(axis, position)`` pairs, one per interior chunk boundary. + """ specs = [] for ax, (s, cs) in enumerate(zip(shape, chunk_shape)): pos = cs @@ -105,6 +149,21 @@ def _scan_touching_pairs( is bounded to one chunk (~200 MB). Reading the full face at once (slice(None) on face axes) would allocate face_area × 8 bytes in one shot — e.g. 37888 × 27392 × 8 = 8 GiB for a single z-face (OOM on real datasets). + + Parameters + ---------- + zarr_path : str + Path to the staged-labels zarr store. + component : str + Component name within the store. + chunk_shape : tuple of int + Chunk shape (sets the per-read column size). + + Returns + ------- + np.ndarray + ``(N, 2)`` int64 array of unique label pairs touching across a + boundary. """ root = zarr.open_group(zarr_path, mode="r") arr = root[component] @@ -135,7 +194,20 @@ def _scan_touching_pairs( def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray: - """Touching-pairs → scipy connected components → relabeling LUT.""" + """Build a relabel LUT from touching pairs via connected components. + + Parameters + ---------- + pairs : np.ndarray + ``(N, 2)`` array of touching label pairs. + max_label : int + Largest label id present. + + Returns + ------- + np.ndarray + Lookup table mapping each old label to its merged (component) id. + """ if max_label > _LUT_WARN_THRESHOLD: logger.warning( "_build_relabel_lut: max_label=%d → LUT ~%.0f MB. " @@ -168,6 +240,24 @@ def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray: def _create_zarr_label_array( group: zarr.Group, name: str, shape: tuple, chunks: tuple ) -> zarr.Array: + """Create (replacing any existing) an int32 label array in *group*. + + Parameters + ---------- + group : zarr.Group + Parent group. + name : str + Array name (may be a nested path). + shape : tuple + Array shape. + chunks : tuple + Chunk shape. + + Returns + ------- + zarr.Array + The newly created array (works on zarr v2 and v3). + """ if name in group: del group[name] if _ZARR_V3: @@ -192,6 +282,25 @@ def zarr_native_merge( Scales to 2000+ chunks where the dask_image approach stalls (O(n_chunks²) graph). Reads *staged_path/staged_component*, merges touching cross-boundary labels, writes result to *out_path/out_component*. No dask task graph. + + Parameters + ---------- + staged_path : str + Path to the staged-labels zarr store. + staged_component : str + Component name within the staged store. + out_path : str + Path to the output zarr store. + out_component : str + Component name within the output store. + n_workers : int + Number of worker processes for the parallel relabel. + show_progress : bool + Show a progress bar over the relabel chunks. + + Returns + ------- + None """ root = zarr.open_group(staged_path, mode="r") arr = root[staged_component] diff --git a/src/patchworks/_relabel.py b/src/patchworks/_relabel.py index 2e27e18..5b5dcac 100644 --- a/src/patchworks/_relabel.py +++ b/src/patchworks/_relabel.py @@ -20,6 +20,16 @@ def relabel_sequential_array(labels: np.ndarray) -> np.ndarray: Background (0) stays 0. Runs in one ``np.unique`` + a lookup-table gather, i.e. O(voxels) — unlike dask's ``relabel_sequential`` which is O(n_chunks²). + Parameters + ---------- + labels : np.ndarray + Integer label array (may have gappy ids). + + Returns + ------- + np.ndarray + Labels remapped to a contiguous ``0, 1, … N`` range. + Examples -------- >>> relabel_sequential_array(np.array([0, 500000, 500000, 7])) diff --git a/src/patchworks/plugins/cellpose.py b/src/patchworks/plugins/cellpose.py index 3eb213f..14fa741 100644 --- a/src/patchworks/plugins/cellpose.py +++ b/src/patchworks/plugins/cellpose.py @@ -40,6 +40,12 @@ def _require_cellpose(): + """Raise an actionable ImportError if cellpose is not installed. + + Returns + ------- + None + """ if _cellpose_models is None: raise ImportError( "cellpose is not installed. Install it with:\n" @@ -126,6 +132,30 @@ def _make_config( do_3D: bool = False, **cellpose_kwargs: Any, ) -> dict[str, Any]: + """Build a picklable Cellpose configuration dict. + + Parameters + ---------- + model : str + Cellpose model type. + gpu : bool + Run on the GPU. + channels : list of int or None + Cellpose-3 ``[cyto, nucleus]`` channels; defaults to ``[0, 0]``. + channel_axis : int or None + Cellpose-4 channel axis. + diameter : float or None + Expected cell diameter in pixels. + do_3D : bool + Segment in 3-D. + **cellpose_kwargs : Any + Extra arguments forwarded to ``model.eval()``. + + Returns + ------- + dict + The configuration consumed by :func:`_get_model` and :func:`_run`. + """ return { "model": model, "gpu": gpu, @@ -138,7 +168,18 @@ def _make_config( def _get_model(cellpose_dict: dict[str, Any]) -> Any: - """Return a worker-local cached Cellpose model.""" + """Return a worker-local cached Cellpose model. + + Parameters + ---------- + cellpose_dict : dict + Configuration from :func:`_make_config`. + + Returns + ------- + Any + A Cellpose model instance (cached per ``(model, gpu)`` per process). + """ _require_cellpose() key = (cellpose_dict["model"], cellpose_dict.get("gpu", False)) if key not in _model_cache: @@ -156,7 +197,20 @@ def _get_model(cellpose_dict: dict[str, Any]) -> Any: def _run(block: np.ndarray, cellpose_dict: dict[str, Any]) -> np.ndarray: - """Segment one tile with a cached Cellpose model.""" + """Segment one tile with a cached Cellpose model. + + Parameters + ---------- + block : np.ndarray + One image tile. + cellpose_dict : dict + Configuration from :func:`_make_config`. + + Returns + ------- + np.ndarray + Integer (``int32``) label array of the same spatial shape. + """ model = _get_model(cellpose_dict) do_3D = cellpose_dict["do_3D"] diff --git a/src/patchworks/plugins/napari.py b/src/patchworks/plugins/napari.py index 12c86e3..189f259 100644 --- a/src/patchworks/plugins/napari.py +++ b/src/patchworks/plugins/napari.py @@ -37,6 +37,13 @@ def _require_napari(): + """Import and return napari, or raise an actionable ImportError. + + Returns + ------- + module + The imported ``napari`` module. + """ try: import napari except ImportError as exc: @@ -48,10 +55,34 @@ def _require_napari(): def _is_zarr(src: Any) -> bool: + """Whether *src* is a path ending in ``.zarr``. + + Parameters + ---------- + src : Any + Candidate source. + + Returns + ------- + bool + True for a str/Path ending in ``.zarr``. + """ return isinstance(src, (str, Path)) and str(src).endswith(".zarr") def _has_multiscales(path: Union[str, Path]) -> bool: + """Whether a zarr group carries NGFF ``multiscales`` metadata. + + Parameters + ---------- + path : str or Path + Zarr group path. + + Returns + ------- + bool + True if the group has a ``multiscales`` attribute. + """ root = zarr.open_group(str(path), mode="r") return "multiscales" in root.attrs @@ -59,7 +90,20 @@ def _has_multiscales(path: Union[str, Path]) -> bool: def _multiscale_levels( path: Union[str, Path], channel: int | None ) -> list[da.Array]: - """Return every pyramid level as a lazy dask array (napari multi-scale).""" + """Return every pyramid level as a lazy dask array (napari multi-scale). + + Parameters + ---------- + path : str or Path + OME-ZARR group path. + channel : int or None + Channel to select, or ``None`` to keep all channels. + + Returns + ------- + list of da.Array + One lazy array per resolution level. + """ root = zarr.open_group(str(path), mode="r") datasets = root.attrs["multiscales"][0]["datasets"] return [ @@ -71,7 +115,20 @@ def _multiscale_levels( def _resolve_image( source: Union[da.Array, str, Path], channel: int | None ) -> Union[da.Array, list[da.Array]]: - """Resolve *source* into data napari can display (lazily).""" + """Resolve *source* into data napari can display (lazily). + + Parameters + ---------- + source : da.Array, str or Path + OME-ZARR store, other image file, or an in-memory array. + channel : int or None + Channel to display, or ``None`` to keep all channels. + + Returns + ------- + da.Array or list of da.Array + A single array, or a multi-scale list for an OME-ZARR pyramid. + """ if _is_zarr(source): if _has_multiscales(source): return _multiscale_levels(source, channel) @@ -86,7 +143,18 @@ def _resolve_image( def _inner_label_names(store: Union[str, Path]) -> list[str]: - """Names registered under an OME-ZARR's NGFF ``labels/`` group, if any.""" + """List label images registered under an OME-ZARR's ``labels/`` group. + + Parameters + ---------- + store : str or Path + OME-ZARR store path. + + Returns + ------- + list of str + Registered label-image names (empty if there are none). + """ try: grp = zarr.open_group(f"{store}/labels", mode="r") except Exception: @@ -97,7 +165,20 @@ def _inner_label_names(store: Union[str, Path]) -> list[str]: def _resolve_labels( source: Union[da.Array, str, Path], component: str ) -> Union[da.Array, list[da.Array]]: - """Resolve a label *source* into integer data for an Labels layer.""" + """Resolve a label *source* into integer data for a Labels layer. + + Parameters + ---------- + source : da.Array, str or Path + Label store (plain or multi-scale) or an in-memory array. + component : str + Array name inside a plain-zarr label store. + + Returns + ------- + da.Array or list of da.Array + Integer (``int32``) labels; a list for a multi-scale store. + """ if _is_zarr(source): if _has_multiscales(source): levels = _multiscale_levels(source, None) diff --git a/src/patchworks/plugins/ome_zarr.py b/src/patchworks/plugins/ome_zarr.py index 68c2975..b2b2ef5 100644 --- a/src/patchworks/plugins/ome_zarr.py +++ b/src/patchworks/plugins/ome_zarr.py @@ -37,6 +37,7 @@ from __future__ import annotations import logging +import math from pathlib import Path from typing import Union @@ -66,6 +67,16 @@ def _default_axes(ndim: int) -> str: """Assign trailing OME axis names to an unlabelled array. A 2-D array becomes ``"yx"``, 3-D ``"zyx"``, 4-D ``"czyx"``. + + Parameters + ---------- + ndim : int + Number of array dimensions. + + Returns + ------- + str + The axes string for those trailing dimensions. """ if ndim > len(_DEFAULT_ORDER): raise ValueError( @@ -75,13 +86,38 @@ def _default_axes(ndim: int) -> str: def _axis_type(name: str) -> str: + """Map an axis letter to its NGFF axis type. + + Parameters + ---------- + name : str + Axis letter (``z``/``y``/``x``/``c``/``t``). + + Returns + ------- + str + ``"space"``, ``"time"`` or ``"channel"``. + """ if name in _SPATIAL_AXES: return "space" return "time" if name == "t" else "channel" def _axes_meta(axes: str, calibrated: bool) -> list[dict]: - """NGFF ``axes`` metadata; spatial axes get a µm unit when calibrated.""" + """Build NGFF ``axes`` metadata for an axes string. + + Parameters + ---------- + axes : str + One letter per axis. + calibrated : bool + When True, spatial axes carry a micrometer unit. + + Returns + ------- + list of dict + One ``{name, type[, unit]}`` entry per axis. + """ meta = [] for a in axes: entry = {"name": a, "type": _axis_type(a)} @@ -92,19 +128,244 @@ def _axes_meta(axes: str, calibrated: bool) -> list[dict]: def _strides(axes: str, downscale: int) -> tuple[int, ...]: - """Per-axis stride: downsample X/Y only; Z, C and T stay at 1.""" + """Per-axis downsampling stride for one pyramid step. + + Parameters + ---------- + axes : str + One letter per axis. + downscale : int + Downsampling factor for X/Y. + + Returns + ------- + tuple of int + ``downscale`` for X/Y axes, ``1`` for Z/C/T. + """ return tuple(downscale if a in _DOWNSAMPLE_AXES else 1 for a in axes) def _default_chunks(shape: tuple[int, ...], axes: str) -> tuple[int, ...]: - """Bounded chunk shape so writing a level never blows up RAM.""" + """Bounded chunk shape so writing a level never blows up RAM. + + Parameters + ---------- + shape : tuple of int + Array shape. + axes : str + One letter per axis (selects the per-axis cap). + + Returns + ------- + tuple of int + Per-axis chunk size, capped by ``_CHUNK_CAP``. + """ return tuple(min(s, _CHUNK_CAP.get(a, s)) for s, a in zip(shape, axes)) +ShardSpec = Union[bool, tuple[int, ...]] +_SHARD_TARGET_BYTES = 512 * 1024**2 # aim for ~512 MB shards +_ZARR_V3 = int(zarr.__version__.split(".")[0]) >= 3 + + +def _effective_shard( + requested: tuple[int, ...], + chunks: tuple[int, ...], + shape: tuple[int, ...], +) -> tuple[int, ...]: + """Clamp a requested shard shape to a valid one. + + A shard must be a whole multiple of the inner chunk and should not exceed + the array's chunk-padded extent. + + Parameters + ---------- + requested : tuple of int + Desired shard shape. + chunks : tuple of int + Inner chunk shape. + shape : tuple of int + Array shape. + + Returns + ------- + tuple of int + A shard shape that is a chunk-multiple within the array. + """ + out = [] + for r, c, s in zip(requested, chunks, shape): + cap = math.ceil(s / c) * c # array dim padded up to a whole chunk + out.append(min(max(c, (r // c) * c), cap)) + return tuple(out) + + +def _auto_shard( + chunks: tuple[int, ...], shape: tuple[int, ...], dtype +) -> tuple[int, ...]: + """Pick a shard shape of roughly ``_SHARD_TARGET_BYTES``. + + Grows the two largest axes equally until the shard reaches the target size, + then clamps to a valid chunk-multiple. + + Parameters + ---------- + chunks : tuple of int + Inner chunk shape. + shape : tuple of int + Array shape. + dtype : data-type + Array dtype, used to size the shard in bytes. + + Returns + ------- + tuple of int + The chosen shard shape. + """ + itemsize = np.dtype(dtype).itemsize + base = itemsize + for c in chunks: + base *= c + big = sorted(range(len(chunks)), key=lambda i: shape[i], reverse=True)[:2] + factor = max(1, int((_SHARD_TARGET_BYTES / max(1, base)) ** 0.5)) + shard = list(chunks) + for i in big: + shard[i] = chunks[i] * factor + return _effective_shard(tuple(shard), chunks, shape) + + +def _shard_for( + shard: ShardSpec, + chunks: tuple[int, ...], + shape: tuple[int, ...], + dtype, +) -> Union[tuple[int, ...], None]: + """Resolve the ``shard`` argument into a concrete shard shape. + + Parameters + ---------- + shard : bool or tuple of int + ``False`` → no sharding; ``True`` → auto; + a tuple → an explicit shard shape. + chunks : tuple of int + Inner chunk shape. + shape : tuple of int + Array shape. + dtype : data-type + Array dtype. + + Returns + ------- + tuple of int or None + The shard shape, or ``None`` when not sharding + (also when zarr is older than v3). + """ + if not shard: + return None + if not _ZARR_V3: + logger.warning("sharding requires zarr v3; writing unsharded.") + return None + if shard is True: + return _auto_shard(chunks, shape, dtype) + return _effective_shard(tuple(shard), chunks, shape) + + +def _progress_ctx(progress: bool, label: str): + """Return a progress-bar context manager. + + Parameters + ---------- + progress : bool + Whether to show a dask progress bar. + label : str + Name logged just before the bar. + + Returns + ------- + contextmanager + A ``ProgressBar`` when *progress* is set, else a no-op + context manager. + """ + if not progress: + from contextlib import nullcontext + + return nullcontext() + from dask.diagnostics import ProgressBar + + logger.info("writing %s …", label) + return ProgressBar() + + +def _to_zarr_level( + arr: da.Array, + group_path: str, + component: str, + shard: ShardSpec, + progress: bool = True, +) -> None: + """Write one array to ``group_path/component``, optionally sharded. + + Without sharding, ``da.to_zarr`` writes chunk by chunk. With sharding that + is unsafe — many chunks share one shard file and per-chunk writes race — + so we create the sharded array explicitly (inner *chunks* + *shards*) and + store with the dask blocks rechunked to the **shard** size, so each task + writes one whole shard atomically. + + Parameters + ---------- + arr : da.Array + Array to write (its chunk size becomes the inner chunk). + group_path : str + Path of the parent zarr group. + component : str + Array name within the group. + shard : bool or tuple of int + Sharding request; see :func:`_shard_for`. + progress : bool + Show a dask progress bar for the write. + + Returns + ------- + None + """ + inner = arr.chunksize + sh = _shard_for(shard, inner, arr.shape, arr.dtype) + ctx = _progress_ctx(progress, f"{Path(group_path).name}/{component}") + if not sh: + with ctx: + da.to_zarr(arr, group_path, component=component, overwrite=True) + return + grp = zarr.open_group(group_path, mode="a") + if component in grp: + del grp[component] + z = grp.create_array( + name=component, + shape=arr.shape, + chunks=inner, + shards=sh, + dtype=arr.dtype, + ) + with ctx: + arr.rechunk(sh).store(z, lock=True, compute=True) + + def _normalize_pixel_size( pixel_size: Union[PixelSize, tuple, None], axes: str ) -> PixelSize: - """Coerce a pixel-size dict/tuple into ``{axis: size}`` for spatial axes.""" + """Coerce a pixel-size dict/tuple into ``{axis: size}``. + + Parameters + ---------- + pixel_size : dict, tuple or None + Voxel size as a per-axis dict or a tuple + aligned to the spatial axes. + axes : str + One letter per axis. + + Returns + ------- + dict + ``{axis: size}`` for the spatial axes present (empty if none given). + """ if not pixel_size: return {} if isinstance(pixel_size, dict): @@ -115,11 +376,38 @@ def _normalize_pixel_size( def _base_scale(axes: str, pixel_size: PixelSize) -> list[float]: - """Level-0 NGFF scale per axis: physical size for spatial, 1.0 else.""" + """Build the level-0 NGFF scale vector. + + Parameters + ---------- + axes : str + One letter per axis. + pixel_size : dict + ``{axis: size}`` for spatial axes. + + Returns + ------- + list of float + Physical size per spatial axis, ``1.0`` for C/T. + """ return [float(pixel_size.get(a, 1.0)) for a in axes] def _dataset(name: str, scale: list[float]) -> dict: + """Build one NGFF ``multiscales`` dataset entry. + + Parameters + ---------- + name : str + Component path of the level (e.g. ``"0"``). + scale : list of float + Per-axis scale (physical size × downsample factor). + + Returns + ------- + dict + A dataset dict with its ``path`` and ``coordinateTransformations``. + """ return { "path": name, "coordinateTransformations": [ @@ -139,6 +427,8 @@ def _write_pyramid( base_scale: list[float], base_name: str = "0", write_base: bool = True, + shard: ShardSpec = False, + progress: bool = True, ) -> list[dict]: """Write pyramid levels into *group_path* and return NGFF datasets. @@ -147,16 +437,43 @@ def _write_pyramid( whole-volume recomputation, no OOM. Level 0 is named *base_name*; when *write_base* is False it is assumed to already exist (used by :func:`add_pyramid`). + + Parameters + ---------- + arr : da.Array + Full-resolution array. + axes : str + One letter per axis. + group_path : str + Path of the zarr group to write into. + n_levels : int + Maximum number of levels including full resolution. + downscale : int + Per-level X/Y downsampling factor. + chunks : tuple of int or None + Chunk shape, or a bounded default. + base_scale : list of float + Level-0 physical scale per axis. + base_name : str + Component name of level 0. + write_base : bool + Write level 0, or assume it already exists. + shard : bool or tuple of int + Sharding request (see :func:`_shard_for`). + progress : bool + Show a per-level progress bar. + + Returns + ------- + list of dict + One NGFF dataset entry per written level. """ strides = _strides(axes, downscale) if write_base: base_chunks = chunks or _default_chunks(arr.shape, axes) - da.to_zarr( - arr.rechunk(base_chunks), - group_path, - component=base_name, - overwrite=True, + _to_zarr_level( + arr.rechunk(base_chunks), group_path, base_name, shard, progress ) datasets = [_dataset(base_name, base_scale)] @@ -170,7 +487,7 @@ def _write_pyramid( src = da.from_zarr(group_path, component=prev_name) nxt = src[tuple(slice(None, None, st) for st in strides)] nxt = nxt.rechunk(chunks or _default_chunks(nxt.shape, axes)) - da.to_zarr(nxt, group_path, component=str(i), overwrite=True) + _to_zarr_level(nxt, group_path, str(i), shard, progress) scale = [base_scale[k] * (strides[k] ** i) for k in range(len(axes))] datasets.append(_dataset(str(i), scale)) logger.info("pyramid level %d: shape=%s", i, nxt.shape) @@ -187,7 +504,25 @@ def _write_multiscales( *, calibrated: bool, ) -> None: - """Write NGFF ``multiscales`` metadata onto *group_path*.""" + """Write NGFF ``multiscales`` metadata onto *group_path*. + + Parameters + ---------- + group_path : str + Path of the zarr group to annotate. + axes : str + One letter per axis. + datasets : list of dict + Per-level dataset entries (see :func:`_dataset`). + name : str + Multiscales name. + calibrated : bool + Whether spatial axes carry a micrometer unit. + + Returns + ------- + None + """ group = zarr.open_group(group_path, mode="a") group.attrs["multiscales"] = [ { @@ -200,7 +535,21 @@ def _write_multiscales( def _read_zarr_calibration(store: Union[str, Path], axes: str) -> PixelSize: - """Read level-0 spatial scale from an existing OME-ZARR, if any.""" + """Read level-0 spatial scale from an existing OME-ZARR, if any. + + Parameters + ---------- + store : str or Path + Path of the OME-ZARR group. + axes : str + One letter per axis (unused for parsing, kept for symmetry). + + Returns + ------- + dict + ``{axis: size}`` for spatial axes with a non-unit scale (empty if + the store has no multiscales metadata). + """ try: root = zarr.open_group(str(store), mode="r") ms = root.attrs["multiscales"][0] @@ -216,7 +565,23 @@ def _read_zarr_calibration(store: Union[str, Path], axes: str) -> PixelSize: def _open_bioio(path: str, scene: int) -> tuple[da.Array, str, PixelSize]: - """Open *path* with bioio → ``(array, axes, pixel_size)``, all lazy.""" + """Open *path* with bioio, lazily. + + Singleton non-spatial axes (T/C of size 1) are dropped. + + Parameters + ---------- + path : str + Image file path. + scene : int + Scene index for multi-scene files. + + Returns + ------- + tuple + ``(array, axes, pixel_size)`` — a lazy dask array, its axes string + and a ``{axis: micrometers}`` calibration dict. + """ try: from bioio import BioImage except ImportError as exc: @@ -256,7 +621,24 @@ def _open_bioio(path: str, scene: int) -> tuple[da.Array, str, PixelSize]: def _open_imaris(path: str, level: int = 0) -> tuple[da.Array, str, PixelSize]: - """Open an Imaris ``.ims`` *level* lazily → ``(array, axes, pixel_size)``.""" + """Open one Imaris ``.ims`` resolution level lazily. + + Reads the underlying HDF5 datasets directly (own handle, crops the Imaris + chunk padding) and stacks the per-(timepoint, channel) 3-D arrays. + + Parameters + ---------- + path : str + Imaris ``.ims`` file path. + level : int + Resolution level to read (0 = full resolution). + + Returns + ------- + tuple + ``(array, axes, pixel_size)`` — a lazy dask array, its axes string + and a ``{axis: micrometers}`` calibration dict. + """ try: from imaris_ims_file_reader.ims import ims except ImportError as exc: @@ -326,12 +708,34 @@ def _write_imaris_pyramid( *, chunks: Union[tuple[int, ...], None], overwrite: bool, + shard: ShardSpec = False, + progress: bool = True, ) -> str: """Copy an Imaris file's own resolution levels into an OME-ZARR. Each Imaris ``ResolutionLevel`` is written as a pyramid level with its own physical scale, so no downsampling is recomputed. Lazy (h5py-backed) reads stream straight to disk. + + Parameters + ---------- + path : str + Imaris ``.ims`` file path. + out : str + Destination ``.zarr`` store path. + chunks : tuple of int or None + Chunk shape, or a bounded default. + overwrite : bool + Overwrite an existing store. + shard : bool or tuple of int + Sharding request (see :func:`_shard_for`). + progress : bool + Show a per-level progress bar. + + Returns + ------- + str + The path to the written store. """ from imaris_ims_file_reader.ims import ims @@ -346,11 +750,12 @@ def _write_imaris_pyramid( arr, axes, ps = _open_imaris(path, level=level) scale = _base_scale(axes, ps) calibrated = calibrated or bool(ps) - da.to_zarr( + _to_zarr_level( arr.rechunk(chunks or _default_chunks(arr.shape, axes)), out, - component=str(level), - overwrite=True, + str(level), + shard, + progress, ) datasets.append(_dataset(str(level), scale)) logger.info("imaris level %d copied: shape=%s", level, arr.shape) @@ -365,7 +770,26 @@ def _to_dask( axes: Union[str, None], scene: int, ) -> tuple[da.Array, str, PixelSize]: - """Resolve *source* into a lazy ``(array, axes, pixel_size)`` triple.""" + """Resolve *source* into a lazy ``(array, axes, pixel_size)`` triple. + + Dispatches by type: dask/NumPy arrays pass through; ``.zarr`` paths use the + OME-ZARR loader; ``.ims`` paths use the Imaris reader; anything else uses + bioio. + + Parameters + ---------- + source : da.Array, np.ndarray, str or Path + Array or path to resolve. + axes : str or None + Explicit axes, or ``None`` to infer them. + scene : int + Scene index for bioio inputs. + + Returns + ------- + tuple + ``(array, axes, pixel_size)``. + """ if isinstance(source, da.Array): return source, axes or _default_axes(source.ndim), {} if isinstance(source, np.ndarray): @@ -394,7 +818,9 @@ def to_ome_zarr( n_levels: int = 5, downscale: int = 2, chunks: Union[tuple[int, ...], None] = None, + shard: ShardSpec = False, reuse_pyramid: bool = False, + progress: bool = True, overwrite: bool = False, ) -> str: """Write *source* as a pyramidal, calibrated OME-ZARR store. @@ -427,6 +853,15 @@ def to_ome_zarr( Per-level X/Y downsampling factor (default 2). chunks : tuple of int, optional Chunk shape for the written levels. ``None`` → a bounded default. + shard : bool or tuple of int, optional + Pack many chunks into one shard file (zarr v3), cutting the file count + ~100× on huge arrays. ``False`` (default) → unsharded, maximum reader + compatibility. ``True`` → auto-pick a ~512 MB shard. A tuple sets an + explicit shard shape (clamped to a chunk multiple). Sharded writes hold + ~one shard per worker in RAM. Requires zarr v3 (ignored otherwise). + progress : bool, optional + Show a per-level dask progress bar (default ``True``). Set ``False`` to + silence it. reuse_pyramid : bool, optional *Imaris ``.ims`` only.* Copy the file's **own** resolution levels instead of rebuilding the pyramid (faster, no recompute), keeping each @@ -460,7 +895,12 @@ def to_ome_zarr( ): try: return _write_imaris_pyramid( - str(source), str(out_path), chunks=chunks, overwrite=overwrite + str(source), + str(out_path), + chunks=chunks, + overwrite=overwrite, + shard=shard, + progress=progress, ) except Exception as exc: logger.warning( @@ -487,6 +927,8 @@ def to_ome_zarr( downscale=downscale, chunks=chunks, base_scale=base_scale, + shard=shard, + progress=progress, ) _write_multiscales(out, axes, datasets, Path(out).stem, calibrated=bool(ps)) return out @@ -501,6 +943,8 @@ def add_pyramid( n_levels: int = 5, downscale: int = 2, chunks: Union[tuple[int, ...], None] = None, + shard: ShardSpec = False, + progress: bool = True, ) -> str: """Add downsampled pyramid levels to an existing single-resolution zarr. @@ -552,6 +996,8 @@ def add_pyramid( base_scale=base_scale, base_name=base, write_base=False, + shard=shard, + progress=progress, ) _write_multiscales(gp, axes, datasets, Path(gp).stem, calibrated=bool(ps)) return gp @@ -566,6 +1012,8 @@ def register_labels( n_levels: int = 5, downscale: int = 2, chunks: Union[tuple[int, ...], None] = None, + shard: ShardSpec = False, + progress: bool = True, ) -> str: """Pyramidalise and register an existing ``labels//0`` base level. @@ -594,6 +1042,8 @@ def register_labels( n_levels=n_levels, downscale=downscale, chunks=chunks, + shard=shard, + progress=progress, ) grp = zarr.open_group(group, mode="a") grp.attrs["image-label"] = {"version": _NGFF_VERSION} @@ -616,6 +1066,8 @@ def write_labels( n_levels: int = 5, downscale: int = 2, chunks: Union[tuple[int, ...], None] = None, + shard: ShardSpec = False, + progress: bool = True, overwrite: bool = False, ) -> str: """Store *labels* inside *image_store* under the NGFF ``labels/`` group. @@ -648,7 +1100,7 @@ def write_labels( label_group = f"{store}/labels/{name}" base = arr.rechunk(chunks or _default_chunks(arr.shape, axes)) - da.to_zarr(base, label_group, component="0", overwrite=True) + _to_zarr_level(base, label_group, "0", shard, progress) return register_labels( store, name, @@ -657,4 +1109,6 @@ def write_labels( n_levels=n_levels, downscale=downscale, chunks=chunks, + shard=shard, + progress=progress, ) diff --git a/tests/test_ome_zarr.py b/tests/test_ome_zarr.py index db12cbf..837fef9 100644 --- a/tests/test_ome_zarr.py +++ b/tests/test_ome_zarr.py @@ -139,3 +139,42 @@ def test_reuse_pyramid_ignored_for_arrays(tmp_path): reuse_pyramid=True, ) assert load_ome_zarr(out, channel=None, level=1).shape == (8, 4, 4) + + +def test_sharding(tmp_path): + """shard=True/tuple writes zarr-v3 shards; data round-trips intact.""" + import zarr as _zarr + + a = np.arange(4 * 64 * 64, dtype="uint16").reshape(4, 64, 64) + + out = to_ome_zarr( + a, + tmp_path / "s.zarr", + axes="zyx", + n_levels=2, + chunks=(2, 16, 16), + shard=True, + ) + z0 = _zarr.open_array(f"{out}/0", mode="r") + assert z0.chunks == (2, 16, 16) + assert z0.shards is not None and z0.shards != z0.chunks + assert np.array_equal( + np.asarray(load_ome_zarr(out, channel=None, level=0)), a + ) + + out2 = to_ome_zarr( + a, + tmp_path / "e.zarr", + axes="zyx", + n_levels=1, + chunks=(2, 16, 16), + shard=(2, 32, 32), + ) + assert _zarr.open_array(f"{out2}/0", mode="r").shards == (2, 32, 32) + + out3 = to_ome_zarr( + a, tmp_path / "n.zarr", axes="zyx", n_levels=1, chunks=(2, 16, 16) + ) + assert ( + getattr(_zarr.open_array(f"{out3}/0", mode="r"), "shards", None) is None + )