diff --git a/README.md b/README.md
index 5ef9a9f..2970f58 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
> Tiled processing of arbitrarily large images — any image, any function.
-```
+```text
┌──────┬──────┬──────┐ fn(tile) → labels ┌──────┬──────┬──────┐
│ tile │ tile │ tile │ ─────────────────────► │ 1 │ 2 │ 3 │
├──────┼──────┼──────┤ ├──────┼──────┼──────┤
@@ -295,6 +295,7 @@ Full docs, guides and tutorials: ****
- dask[array], numpy, zarr, scipy
Optional:
+
- `psutil` — accurate RAM sizing for `tile_shape="auto"`
- `nvidia-ml-py` — accurate GPU VRAM sizing
- `tqdm` — progress bars
diff --git a/docs/examples/stardist.md b/docs/examples/stardist.md
index 795b043..5e641aa 100644
--- a/docs/examples/stardist.md
+++ b/docs/examples/stardist.md
@@ -47,18 +47,18 @@ tile_process(
Load the model **outside** the `fn` closure. If you load it inside,
it will be re-initialised (and potentially re-downloaded) once per tile.
- For distributed execution, use `functools.partial` with a cached model:
+    For distributed execution, use `functools.partial` with a cached model:
- ```python
- from functools import lru_cache
+    ```python
+    from functools import lru_cache
- @lru_cache(maxsize=1)
- def _get_model():
- return StarDist2D.from_pretrained("2D_versatile_fluo")
+    @lru_cache(maxsize=1)
+    def _get_model():
+        return StarDist2D.from_pretrained("2D_versatile_fluo")
- def stardist_fn(tile):
- model = _get_model()
- ...
- ```
+    def stardist_fn(tile):
+        model = _get_model()
+        ...
+    ```
diff --git a/docs/getting_started.md b/docs/getting_started.md
index bccd506..d026fd3 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -46,11 +46,11 @@ patchworks can be installed from PyPI on all operating systems, for Python ≥ 3
## The one function you need
-```python
-from patchworks import tile_process
+```python
+from patchworks import tile_process
-result = tile_process(image, fn)
-```
+result = tile_process(image, fn)
+```
`tile_process(image, fn)` splits `image` into tiles, runs `fn` on each tile,
and returns a globally consistent label array.
@@ -65,17 +65,17 @@ and returns a globally consistent label array.
patchworks is method-agnostic. Your function receives a NumPy array (one tile)
and must return an integer label array of the same shape:
-```python
-import numpy as np
+```python
+import numpy as np
-def my_fn(tile: np.ndarray) -> np.ndarray:
- from skimage.filters import threshold_otsu
- from skimage.measure import label
+def my_fn(tile: np.ndarray) -> np.ndarray:
+    from skimage.filters import threshold_otsu
+    from skimage.measure import label
- binary = tile > threshold_otsu(tile)
- return label(binary).astype("int32")
-```
+    binary = tile > threshold_otsu(tile)
+    return label(binary).astype("int32")
+```
The function is called independently on every tile. patchworks ensures that
objects spanning tile boundaries are merged into a single label.
@@ -155,14 +155,14 @@ objects spanning tile boundaries are merged into a single label.
Methods like Cellpose and StarDist need spatial context at tile boundaries.
Use `overlap` (in voxels) so boundary objects are fully visible:
-```python
-result = tile_process(
- "image.zarr",
- my_fn,
- tile_shape=(1, 2048, 2048),
- overlap=20, # 20-voxel halo on every side
-)
-```
+```python
+result = tile_process(
+    "image.zarr",
+    my_fn,
+    tile_shape=(1, 2048, 2048),
+    overlap=20,  # 20-voxel halo on every side
+)
+```
!!! info "How overlap works"
Each tile is expanded by `overlap` voxels on every side before calling `fn`.
@@ -173,22 +173,22 @@ result = tile_process(
## Use Cellpose
-```python
-from patchworks import tile_process
-from patchworks.plugins.cellpose import cellpose_fn
-
-fn = cellpose_fn("cyto3", gpu=True, diameter=30)
-
-tile_process(
- "image.zarr",
- fn,
- channel=0,
- tile_shape=(1, 2048, 2048),
- overlap=20,
- write_to="labels.zarr",
- progress=True,
-)
-```
+```python
+from patchworks import tile_process
+from patchworks.plugins.cellpose import cellpose_fn
+
+fn = cellpose_fn("cyto3", gpu=True, diameter=30)
+
+tile_process(
+    "image.zarr",
+    fn,
+    channel=0,
+    tile_shape=(1, 2048, 2048),
+    overlap=20,
+    write_to="labels.zarr",
+    progress=True,
+)
+```
See the [Cellpose 2-D example](examples/cellpose_2d.md) for the full workflow.
diff --git a/docs/guide/gpu_distributed.md b/docs/guide/gpu_distributed.md
index a3ff925..44d0dc4 100644
--- a/docs/guide/gpu_distributed.md
+++ b/docs/guide/gpu_distributed.md
@@ -60,7 +60,7 @@ in the same process as the kernel. When your segmentation function holds the
Python GIL (every PyTorch/CUDA `eval` does), the worker thread can't send
heartbeats. The scheduler declares it dead, and the merge fails:
-```
+```text
FutureCancelledError: lost dependencies
```
diff --git a/docs/guide/merging.md b/docs/guide/merging.md
index bf9fb07..d206133 100644
--- a/docs/guide/merging.md
+++ b/docs/guide/merging.md
@@ -9,7 +9,7 @@ even though it's the same cell.
patchworks solves this with a zarr-native merge algorithm:
-```
+```text
Tile A labels: Tile B labels: After merge:
┌────────────┐ ┌────────────┐ ┌──────────────────────┐
│ 3 1 2 │ │ 1 4 2 │ │ 3 1 2 │ 501 5 502│
@@ -32,7 +32,7 @@ Each tile's labels are written to a temporary zarr once. This is critical:
without staging, any downstream operation that reads the label array re-runs
your segmentation function. The merge internally reads labels multiple times.
-```
+```text
tile_process calls fn once per tile → staged zarr
│
merge reads from staged zarr (no fn calls)
diff --git a/docs/guide/ome_zarr_napari.md b/docs/guide/ome_zarr_napari.md
index 1291c92..5ee97af 100644
--- a/docs/guide/ome_zarr_napari.md
+++ b/docs/guide/ome_zarr_napari.md
@@ -76,6 +76,28 @@ and streaming the downsampled result out through dask with bounded chunks. The
graph never chains level-on-level and no whole plane/volume is held in RAM, so
terabyte images convert in bounded memory.
+### Sharding (fewer files)
+
+A big array becomes tens of thousands of tiny chunk files, which strain
+filesystems and object stores. Sharding packs many chunks into one **shard**
+file (zarr v3), cutting the file count ~100×:
+
+```python
+to_ome_zarr("scan.ims", "scan.zarr", shard=True) # auto ~512 MB shards
+to_ome_zarr("scan.ims", "scan.zarr", shard=(1, 16, 2048, 2048)) # explicit
+```
+
+Default is `shard=False` for maximum reader compatibility — sharding is
+zarr-v3-only, so tools without zarr v3 support cannot read it (a current zarr/napari stack can).
+A sharded write holds ~one shard per worker in RAM, so very large shards cost
+memory.
+
+### Progress
+
+All write steps show a dask progress bar **by default** (`progress=True`), so
+you can see how long a conversion will take. Pass `progress=False` to silence
+it.
+
!!! note "Install the readers you need"
`pip install "patchworks[bioio]"` pulls `bioio` plus the `bioio-bioformats`
catch-all reader (needs a JVM). For speed, add native readers for your
diff --git a/docs/guide/pitfalls.md b/docs/guide/pitfalls.md
index b3e8cbb..db54d70 100644
--- a/docs/guide/pitfalls.md
+++ b/docs/guide/pitfalls.md
@@ -30,7 +30,7 @@ single-GPU runs — patchworks pins it to 1 thread automatically).
patchworks detects in-process clients at startup and raises immediately:
-```
+```text
RuntimeError: Active Dask client uses an in-process worker (processes=False).
This breaks the label merge when fn holds the GIL. Use a process-based
cluster instead:
diff --git a/docs/guide/skip_empty.md b/docs/guide/skip_empty.md
index 98b0b5a..5e22023 100644
--- a/docs/guide/skip_empty.md
+++ b/docs/guide/skip_empty.md
@@ -88,6 +88,6 @@ tile_process(
After a `tile_process` run with `skip_empty=True`, the log reports exactly
how many tiles ran your function:
-```
+```text
INFO patchworks._core: skip_empty: 486/2200 tiles ran fn, 1714 skipped (max<=412.0)
```
diff --git a/docs/guide/tiling.md b/docs/guide/tiling.md
index b2e2b7f..9049e52 100644
--- a/docs/guide/tiling.md
+++ b/docs/guide/tiling.md
@@ -13,6 +13,7 @@ peak RAM during segmentation is approximately one tile's worth of data.
## Choosing a tile size
The right tile size depends on:
+
- Your available RAM (or GPU VRAM)
- The minimum context your segmentation method needs (objects should fit fully
inside a tile, or you need overlap)
@@ -62,7 +63,7 @@ Methods that need spatial context (Cellpose, StarDist, U-Net) produce wrong
results near tile edges: objects at the boundary are cut off. Overlap fixes this
by expanding each tile by `overlap` voxels on every side.
-```
+```text
No overlap: With overlap=20:
┌──────────┐ ┌──────────────────┐
│ │ │ ░░░░░░░░░░░░░░ │
@@ -86,4 +87,4 @@ No overlap: With overlap=20:
automatically clips the depth per axis, so z-tiles of size 1 (typical in
2-D Cellpose mode) get `depth=0` in z even if you pass `overlap=20`.
- Axes that are too small for the requested overlap simply get a smaller halo.
+ Axes that are too small for the requested overlap simply get a smaller halo.
diff --git a/docs/index.md b/docs/index.md
index 2d1bfd4..f5bd9bf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -2,7 +2,7 @@
**Tiled processing of arbitrarily large images — any image, any function.**
-```
+```text
┌──────┬──────┬──────┐ ┌──────┬──────┬──────┐
│ │ │ │ fn(tile) → IDs │ 1 │ 2 │ 3 │
│ │ │ │ ───────────────► │ │ │ │
diff --git a/src/patchworks/_chunks.py b/src/patchworks/_chunks.py
index 0659c3a..de353d6 100644
--- a/src/patchworks/_chunks.py
+++ b/src/patchworks/_chunks.py
@@ -49,6 +49,14 @@ def auto_overlap(diameter: float, safety: float = 1.0) -> int:
def _get_available_memory() -> int:
+ """Return available system RAM in bytes.
+
+ Returns
+ -------
+ int
+ Available memory via ``psutil``, or an 8 GiB fallback if it is not
+ installed.
+ """
try:
import psutil
@@ -103,7 +111,14 @@ def safe_worker_count(
def _get_gpu_memory() -> int:
- """Return free GPU VRAM in bytes. Falls back to 8 GiB default."""
+ """Return free GPU VRAM in bytes.
+
+ Returns
+ -------
+ int
+ Free VRAM of GPU 0 via ``nvidia-ml-py``, or an 8 GiB fallback if the
+ query fails.
+ """
try:
import pynvml
diff --git a/src/patchworks/_cluster.py b/src/patchworks/_cluster.py
index 583121c..71b75d0 100644
--- a/src/patchworks/_cluster.py
+++ b/src/patchworks/_cluster.py
@@ -9,7 +9,14 @@
def _distributed_client():
- """Return the active dask.distributed Client, or None."""
+ """Return the active dask.distributed Client, or None.
+
+ Returns
+ -------
+ distributed.Client or None
+ The current client, or ``None`` if none is active / distributed is not
+ installed.
+ """
try:
from dask.distributed import get_client
@@ -19,12 +26,22 @@ def _distributed_client():
def _client_is_in_process(client) -> bool:
- """True if *client* runs its worker in this process (processes=False).
+ """Whether *client* runs its worker in this process (``processes=False``).
An in-process worker shares the GIL. A long task that holds the GIL
(e.g. a Cellpose/torch eval) starves the worker heartbeat, the scheduler
declares it dead, and the P2P merge barrier drops its inputs →
"FutureCancelledError: lost dependencies".
+
+ Parameters
+ ----------
+ client : distributed.Client
+ The client to inspect.
+
+ Returns
+ -------
+ bool
+ True if any worker address uses the ``inproc://`` transport.
"""
try:
for addr in client.scheduler_info().get("workers", {}):
diff --git a/src/patchworks/_core.py b/src/patchworks/_core.py
index c4077a3..f8a4793 100644
--- a/src/patchworks/_core.py
+++ b/src/patchworks/_core.py
@@ -23,7 +23,23 @@
def _stage_to_zarr(
arr: da.Array, path: str, component: str, show_progress: bool
) -> None:
- """Write *arr* to zarr *path/component*, never loading it into RAM."""
+ """Write *arr* to zarr ``path/component``, never loading it into RAM.
+
+ Parameters
+ ----------
+ arr : da.Array
+ Array to materialise to disk.
+ path : str
+ Zarr store path.
+ component : str
+ Array name within the store.
+ show_progress : bool
+ Show a progress bar while computing.
+
+ Returns
+ -------
+ None
+ """
import dask
lazy_write = arr.to_zarr(
@@ -57,7 +73,7 @@ def tile_process(
level: int = 0,
use_gpu: bool = False,
max_workers: int | None = None,
- progress: bool = False,
+ progress: bool = True,
write_to: Union[str, Path, None] = None,
output_component: str = "labels",
pyramid_levels: int = 5,
@@ -123,7 +139,8 @@ def tile_process(
every core. Ignored when a distributed client is active (it manages its
own concurrency).
progress:
- Show a progress bar during the tile-writing and relabel steps.
+ Show progress bars for staging, the label write and the pyramid
+ (default ``True``). Set ``False`` to silence them.
write_to:
Explicit output zarr store path. Overrides the default behaviour: the
merged labels are written here as a single-resolution array named
@@ -297,6 +314,20 @@ def tile_process(
_skip_thr = _auto_empty_threshold(image_for_threshold, channel, level)
def active_fn(block, block_info=None):
+ """Run *fn* on one tile, or return zeros for an empty tile.
+
+ Parameters
+ ----------
+ block : np.ndarray
+ One image tile.
+ block_info : dict or None
+ Dask block metadata (used for logging the tile location).
+
+ Returns
+ -------
+ np.ndarray
+ Integer labels, or an all-zero tile when skipped.
+ """
loc = block_info[0].get("chunk-location") if block_info else "?"
if skip_empty and block.size and block.max() <= _skip_thr:
if verbose:
@@ -398,6 +429,12 @@ def active_fn(block, block_info=None):
# that figure instead.
def _cleanup_stage():
+ """Delete the temporary stage store unless ``keep_stage`` is set.
+
+ Returns
+ -------
+ None
+ """
if not keep_stage:
import shutil
@@ -457,6 +494,7 @@ def _cleanup_stage():
name=output_component,
n_levels=pyramid_levels,
downscale=pyramid_downscale,
+ progress=progress,
overwrite=True,
)
shutil.rmtree(os.path.dirname(_merge_out), ignore_errors=True)
diff --git a/src/patchworks/_io.py b/src/patchworks/_io.py
index a364467..6f0ed78 100644
--- a/src/patchworks/_io.py
+++ b/src/patchworks/_io.py
@@ -75,6 +75,16 @@ def _otsu_threshold(sample: np.ndarray) -> float:
Operates on the full distribution including zeros — zeros are background
pixels and must be included so Otsu can find the signal/background boundary.
+
+ Parameters
+ ----------
+ sample : np.ndarray
+ Flat intensity sample.
+
+ Returns
+ -------
+ float
+ The Otsu threshold, or ``0.0`` when the sample is degenerate.
"""
try:
from skimage.filters import threshold_otsu
@@ -89,7 +99,22 @@ def _otsu_threshold(sample: np.ndarray) -> float:
def _auto_empty_threshold(
image: da.Array, channel: int | None, level: int
) -> float:
- """Pick an empty-tile threshold from a cheap bounded sample (Otsu)."""
+ """Pick an empty-tile threshold from a cheap bounded sample (Otsu).
+
+ Parameters
+ ----------
+ image : da.Array
+ Image to sample.
+ channel : int or None
+ Channel hint (kept for signature symmetry).
+ level : int
+ Pyramid level hint (kept for signature symmetry).
+
+ Returns
+ -------
+ float
+ Otsu threshold over a few small centred windows.
+ """
n = image.ndim
win = [min(64 if i >= n - 3 else s, s) for i, s in enumerate(image.shape)]
win = [min(w, 256) if i >= n - 2 else w for i, w in enumerate(win)]
diff --git a/src/patchworks/_merge.py b/src/patchworks/_merge.py
index e0eed43..109c1be 100644
--- a/src/patchworks/_merge.py
+++ b/src/patchworks/_merge.py
@@ -55,6 +55,25 @@
def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp):
+ """Initialise a merge worker process with the shared paths and LUT.
+
+ Parameters
+ ----------
+ lut_path : str
+ Path to the relabel lookup table (loaded memory-mapped, read-only).
+ staged_path : str
+ Path to the staged-labels zarr store.
+ staged_comp : str
+ Component name within the staged store.
+ out_path : str
+ Path to the output zarr store.
+ out_comp : str
+ Component name within the output store.
+
+ Returns
+ -------
+ None
+ """
global _merge_lut, _merge_lut_path, _merge_staged_path, _merge_staged_comp
global _merge_out_path, _merge_out_comp
_merge_lut = np.load(
@@ -68,6 +87,17 @@ def _init_worker(lut_path, staged_path, staged_comp, out_path, out_comp):
def _relabel_chunk_worker(chunk_slice: tuple) -> None:
+ """Apply the relabel LUT to one chunk and write it to the output store.
+
+ Parameters
+ ----------
+ chunk_slice : tuple
+ The slice selecting this chunk in both stores.
+
+ Returns
+ -------
+ None
+ """
src = zarr.open_group(_merge_staged_path, mode="r")[_merge_staged_comp]
dst = zarr.open_group(_merge_out_path, mode="r+")[_merge_out_comp]
block = np.asarray(src[chunk_slice], dtype=np.int64)
@@ -87,6 +117,20 @@ def _relabel_chunk_worker(chunk_slice: tuple) -> None:
def _boundary_face_specs(
shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> list[tuple[int, int]]:
+ """Enumerate interior chunk boundaries to scan for touching labels.
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Array shape.
+ chunk_shape : tuple of int
+ Chunk shape.
+
+ Returns
+ -------
+ list of tuple of int
+ ``(axis, position)`` pairs, one per interior chunk boundary.
+ """
specs = []
for ax, (s, cs) in enumerate(zip(shape, chunk_shape)):
pos = cs
@@ -105,6 +149,21 @@ def _scan_touching_pairs(
is bounded to one chunk (~200 MB). Reading the full face at once
(slice(None) on face axes) would allocate face_area × 8 bytes in one shot —
e.g. 37888 × 27392 × 8 = 8 GiB for a single z-face (OOM on real datasets).
+
+ Parameters
+ ----------
+ zarr_path : str
+ Path to the staged-labels zarr store.
+ component : str
+ Component name within the store.
+ chunk_shape : tuple of int
+ Chunk shape (sets the per-read column size).
+
+ Returns
+ -------
+ np.ndarray
+ ``(N, 2)`` int64 array of unique label pairs touching across a
+ boundary.
"""
root = zarr.open_group(zarr_path, mode="r")
arr = root[component]
@@ -135,7 +194,20 @@ def _scan_touching_pairs(
def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray:
- """Touching-pairs → scipy connected components → relabeling LUT."""
+ """Build a relabel LUT from touching pairs via connected components.
+
+ Parameters
+ ----------
+ pairs : np.ndarray
+ ``(N, 2)`` array of touching label pairs.
+ max_label : int
+ Largest label id present.
+
+ Returns
+ -------
+ np.ndarray
+ Lookup table mapping each old label to its merged (component) id.
+ """
if max_label > _LUT_WARN_THRESHOLD:
logger.warning(
"_build_relabel_lut: max_label=%d → LUT ~%.0f MB. "
@@ -168,6 +240,24 @@ def _build_relabel_lut(pairs: np.ndarray, max_label: int) -> np.ndarray:
def _create_zarr_label_array(
group: zarr.Group, name: str, shape: tuple, chunks: tuple
) -> zarr.Array:
+ """Create (replacing any existing) an int32 label array in *group*.
+
+ Parameters
+ ----------
+ group : zarr.Group
+ Parent group.
+ name : str
+ Array name (may be a nested path).
+ shape : tuple
+ Array shape.
+ chunks : tuple
+ Chunk shape.
+
+ Returns
+ -------
+ zarr.Array
+ The newly created array (works on zarr v2 and v3).
+ """
if name in group:
del group[name]
if _ZARR_V3:
@@ -192,6 +282,25 @@ def zarr_native_merge(
Scales to 2000+ chunks where the dask_image approach stalls (O(n_chunks²)
graph). Reads *staged_path/staged_component*, merges touching cross-boundary
labels, writes result to *out_path/out_component*. No dask task graph.
+
+ Parameters
+ ----------
+ staged_path : str
+ Path to the staged-labels zarr store.
+ staged_component : str
+ Component name within the staged store.
+ out_path : str
+ Path to the output zarr store.
+ out_component : str
+ Component name within the output store.
+ n_workers : int
+ Number of worker processes for the parallel relabel.
+ show_progress : bool
+ Show a progress bar over the relabel chunks.
+
+ Returns
+ -------
+ None
"""
root = zarr.open_group(staged_path, mode="r")
arr = root[staged_component]
diff --git a/src/patchworks/_relabel.py b/src/patchworks/_relabel.py
index 2e27e18..5b5dcac 100644
--- a/src/patchworks/_relabel.py
+++ b/src/patchworks/_relabel.py
@@ -20,6 +20,16 @@ def relabel_sequential_array(labels: np.ndarray) -> np.ndarray:
Background (0) stays 0. Runs in one ``np.unique`` + a lookup-table gather,
i.e. O(voxels) — unlike dask's ``relabel_sequential`` which is O(n_chunks²).
+ Parameters
+ ----------
+ labels : np.ndarray
+ Integer label array (may have gappy ids).
+
+ Returns
+ -------
+ np.ndarray
+ Labels remapped to a contiguous ``0, 1, … N`` range.
+
Examples
--------
>>> relabel_sequential_array(np.array([0, 500000, 500000, 7]))
diff --git a/src/patchworks/plugins/cellpose.py b/src/patchworks/plugins/cellpose.py
index 3eb213f..14fa741 100644
--- a/src/patchworks/plugins/cellpose.py
+++ b/src/patchworks/plugins/cellpose.py
@@ -40,6 +40,12 @@
def _require_cellpose():
+ """Raise an actionable ImportError if cellpose is not installed.
+
+ Returns
+ -------
+ None
+ """
if _cellpose_models is None:
raise ImportError(
"cellpose is not installed. Install it with:\n"
@@ -126,6 +132,30 @@ def _make_config(
do_3D: bool = False,
**cellpose_kwargs: Any,
) -> dict[str, Any]:
+ """Build a picklable Cellpose configuration dict.
+
+ Parameters
+ ----------
+ model : str
+ Cellpose model type.
+ gpu : bool
+ Run on the GPU.
+ channels : list of int or None
+ Cellpose-3 ``[cyto, nucleus]`` channels; defaults to ``[0, 0]``.
+ channel_axis : int or None
+ Cellpose-4 channel axis.
+ diameter : float or None
+ Expected cell diameter in pixels.
+ do_3D : bool
+ Segment in 3-D.
+ **cellpose_kwargs : Any
+ Extra arguments forwarded to ``model.eval()``.
+
+ Returns
+ -------
+ dict
+ The configuration consumed by :func:`_get_model` and :func:`_run`.
+ """
return {
"model": model,
"gpu": gpu,
@@ -138,7 +168,18 @@ def _make_config(
def _get_model(cellpose_dict: dict[str, Any]) -> Any:
- """Return a worker-local cached Cellpose model."""
+ """Return a worker-local cached Cellpose model.
+
+ Parameters
+ ----------
+ cellpose_dict : dict
+ Configuration from :func:`_make_config`.
+
+ Returns
+ -------
+ Any
+ A Cellpose model instance (cached per ``(model, gpu)`` per process).
+ """
_require_cellpose()
key = (cellpose_dict["model"], cellpose_dict.get("gpu", False))
if key not in _model_cache:
@@ -156,7 +197,20 @@ def _get_model(cellpose_dict: dict[str, Any]) -> Any:
def _run(block: np.ndarray, cellpose_dict: dict[str, Any]) -> np.ndarray:
- """Segment one tile with a cached Cellpose model."""
+ """Segment one tile with a cached Cellpose model.
+
+ Parameters
+ ----------
+ block : np.ndarray
+ One image tile.
+ cellpose_dict : dict
+ Configuration from :func:`_make_config`.
+
+ Returns
+ -------
+ np.ndarray
+ Integer (``int32``) label array of the same spatial shape.
+ """
model = _get_model(cellpose_dict)
do_3D = cellpose_dict["do_3D"]
diff --git a/src/patchworks/plugins/napari.py b/src/patchworks/plugins/napari.py
index 12c86e3..189f259 100644
--- a/src/patchworks/plugins/napari.py
+++ b/src/patchworks/plugins/napari.py
@@ -37,6 +37,13 @@
def _require_napari():
+ """Import and return napari, or raise an actionable ImportError.
+
+ Returns
+ -------
+ module
+ The imported ``napari`` module.
+ """
try:
import napari
except ImportError as exc:
@@ -48,10 +55,34 @@ def _require_napari():
def _is_zarr(src: Any) -> bool:
+ """Whether *src* is a path ending in ``.zarr``.
+
+ Parameters
+ ----------
+ src : Any
+ Candidate source.
+
+ Returns
+ -------
+ bool
+ True for a str/Path ending in ``.zarr``.
+ """
return isinstance(src, (str, Path)) and str(src).endswith(".zarr")
def _has_multiscales(path: Union[str, Path]) -> bool:
+ """Whether a zarr group carries NGFF ``multiscales`` metadata.
+
+ Parameters
+ ----------
+ path : str or Path
+ Zarr group path.
+
+ Returns
+ -------
+ bool
+ True if the group has a ``multiscales`` attribute.
+ """
root = zarr.open_group(str(path), mode="r")
return "multiscales" in root.attrs
@@ -59,7 +90,20 @@ def _has_multiscales(path: Union[str, Path]) -> bool:
def _multiscale_levels(
path: Union[str, Path], channel: int | None
) -> list[da.Array]:
- """Return every pyramid level as a lazy dask array (napari multi-scale)."""
+ """Return every pyramid level as a lazy dask array (napari multi-scale).
+
+ Parameters
+ ----------
+ path : str or Path
+ OME-ZARR group path.
+ channel : int or None
+ Channel to select, or ``None`` to keep all channels.
+
+ Returns
+ -------
+ list of da.Array
+ One lazy array per resolution level.
+ """
root = zarr.open_group(str(path), mode="r")
datasets = root.attrs["multiscales"][0]["datasets"]
return [
@@ -71,7 +115,20 @@ def _multiscale_levels(
def _resolve_image(
source: Union[da.Array, str, Path], channel: int | None
) -> Union[da.Array, list[da.Array]]:
- """Resolve *source* into data napari can display (lazily)."""
+ """Resolve *source* into data napari can display (lazily).
+
+ Parameters
+ ----------
+ source : da.Array, str or Path
+ OME-ZARR store, other image file, or an in-memory array.
+ channel : int or None
+ Channel to display, or ``None`` to keep all channels.
+
+ Returns
+ -------
+ da.Array or list of da.Array
+ A single array, or a multi-scale list for an OME-ZARR pyramid.
+ """
if _is_zarr(source):
if _has_multiscales(source):
return _multiscale_levels(source, channel)
@@ -86,7 +143,18 @@ def _resolve_image(
def _inner_label_names(store: Union[str, Path]) -> list[str]:
- """Names registered under an OME-ZARR's NGFF ``labels/`` group, if any."""
+ """List label images registered under an OME-ZARR's ``labels/`` group.
+
+ Parameters
+ ----------
+ store : str or Path
+ OME-ZARR store path.
+
+ Returns
+ -------
+ list of str
+ Registered label-image names (empty if there are none).
+ """
try:
grp = zarr.open_group(f"{store}/labels", mode="r")
except Exception:
@@ -97,7 +165,20 @@ def _inner_label_names(store: Union[str, Path]) -> list[str]:
def _resolve_labels(
source: Union[da.Array, str, Path], component: str
) -> Union[da.Array, list[da.Array]]:
- """Resolve a label *source* into integer data for an Labels layer."""
+ """Resolve a label *source* into integer data for a Labels layer.
+
+ Parameters
+ ----------
+ source : da.Array, str or Path
+ Label store (plain or multi-scale) or an in-memory array.
+ component : str
+ Array name inside a plain-zarr label store.
+
+ Returns
+ -------
+ da.Array or list of da.Array
+ Integer (``int32``) labels; a list for a multi-scale store.
+ """
if _is_zarr(source):
if _has_multiscales(source):
levels = _multiscale_levels(source, None)
diff --git a/src/patchworks/plugins/ome_zarr.py b/src/patchworks/plugins/ome_zarr.py
index 68c2975..b2b2ef5 100644
--- a/src/patchworks/plugins/ome_zarr.py
+++ b/src/patchworks/plugins/ome_zarr.py
@@ -37,6 +37,7 @@
from __future__ import annotations
import logging
+import math
from pathlib import Path
from typing import Union
@@ -66,6 +67,16 @@ def _default_axes(ndim: int) -> str:
"""Assign trailing OME axis names to an unlabelled array.
A 2-D array becomes ``"yx"``, 3-D ``"zyx"``, 4-D ``"czyx"``.
+
+ Parameters
+ ----------
+ ndim : int
+ Number of array dimensions.
+
+ Returns
+ -------
+ str
+ The axes string for those trailing dimensions.
"""
if ndim > len(_DEFAULT_ORDER):
raise ValueError(
@@ -75,13 +86,38 @@ def _default_axes(ndim: int) -> str:
def _axis_type(name: str) -> str:
+ """Map an axis letter to its NGFF axis type.
+
+ Parameters
+ ----------
+ name : str
+ Axis letter (``z``/``y``/``x``/``c``/``t``).
+
+ Returns
+ -------
+ str
+ ``"space"``, ``"time"`` or ``"channel"``.
+ """
if name in _SPATIAL_AXES:
return "space"
return "time" if name == "t" else "channel"
def _axes_meta(axes: str, calibrated: bool) -> list[dict]:
- """NGFF ``axes`` metadata; spatial axes get a µm unit when calibrated."""
+ """Build NGFF ``axes`` metadata for an axes string.
+
+ Parameters
+ ----------
+ axes : str
+ One letter per axis.
+ calibrated : bool
+ When True, spatial axes carry a micrometer unit.
+
+ Returns
+ -------
+ list of dict
+ One ``{name, type[, unit]}`` entry per axis.
+ """
meta = []
for a in axes:
entry = {"name": a, "type": _axis_type(a)}
@@ -92,19 +128,244 @@ def _axes_meta(axes: str, calibrated: bool) -> list[dict]:
def _strides(axes: str, downscale: int) -> tuple[int, ...]:
- """Per-axis stride: downsample X/Y only; Z, C and T stay at 1."""
+ """Per-axis downsampling stride for one pyramid step.
+
+ Parameters
+ ----------
+ axes : str
+ One letter per axis.
+ downscale : int
+ Downsampling factor for X/Y.
+
+ Returns
+ -------
+ tuple of int
+ ``downscale`` for X/Y axes, ``1`` for Z/C/T.
+ """
return tuple(downscale if a in _DOWNSAMPLE_AXES else 1 for a in axes)
def _default_chunks(shape: tuple[int, ...], axes: str) -> tuple[int, ...]:
- """Bounded chunk shape so writing a level never blows up RAM."""
+ """Bounded chunk shape so writing a level never blows up RAM.
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Array shape.
+ axes : str
+ One letter per axis (selects the per-axis cap).
+
+ Returns
+ -------
+ tuple of int
+ Per-axis chunk size, capped by ``_CHUNK_CAP``.
+ """
return tuple(min(s, _CHUNK_CAP.get(a, s)) for s, a in zip(shape, axes))
+ShardSpec = Union[bool, tuple[int, ...]]
+_SHARD_TARGET_BYTES = 512 * 1024**2 # aim for ~512 MB shards
+_ZARR_V3 = int(zarr.__version__.split(".")[0]) >= 3
+
+
+def _effective_shard(
+ requested: tuple[int, ...],
+ chunks: tuple[int, ...],
+ shape: tuple[int, ...],
+) -> tuple[int, ...]:
+ """Clamp a requested shard shape to a valid one.
+
+ A shard must be a whole multiple of the inner chunk and should not exceed
+ the array's chunk-padded extent.
+
+ Parameters
+ ----------
+ requested : tuple of int
+ Desired shard shape.
+ chunks : tuple of int
+ Inner chunk shape.
+ shape : tuple of int
+ Array shape.
+
+ Returns
+ -------
+ tuple of int
+ A shard shape that is a chunk-multiple within the array.
+ """
+ out = []
+ for r, c, s in zip(requested, chunks, shape):
+ cap = math.ceil(s / c) * c # array dim padded up to a whole chunk
+ out.append(min(max(c, (r // c) * c), cap))
+ return tuple(out)
+
+
+def _auto_shard(
+ chunks: tuple[int, ...], shape: tuple[int, ...], dtype
+) -> tuple[int, ...]:
+ """Pick a shard shape of roughly ``_SHARD_TARGET_BYTES``.
+
+ Grows the two largest axes equally until the shard reaches the target size,
+ then clamps to a valid chunk-multiple.
+
+ Parameters
+ ----------
+ chunks : tuple of int
+ Inner chunk shape.
+ shape : tuple of int
+ Array shape.
+ dtype : data-type
+ Array dtype, used to size the shard in bytes.
+
+ Returns
+ -------
+ tuple of int
+ The chosen shard shape.
+ """
+ itemsize = np.dtype(dtype).itemsize
+ base = itemsize
+ for c in chunks:
+ base *= c
+ big = sorted(range(len(chunks)), key=lambda i: shape[i], reverse=True)[:2]
+ factor = max(1, int((_SHARD_TARGET_BYTES / max(1, base)) ** 0.5))
+ shard = list(chunks)
+ for i in big:
+ shard[i] = chunks[i] * factor
+ return _effective_shard(tuple(shard), chunks, shape)
+
+
+def _shard_for(
+ shard: ShardSpec,
+ chunks: tuple[int, ...],
+ shape: tuple[int, ...],
+ dtype,
+) -> Union[tuple[int, ...], None]:
+ """Resolve the ``shard`` argument into a concrete shard shape.
+
+ Parameters
+ ----------
+ shard : bool or tuple of int
+ ``False`` → no sharding; ``True`` → auto;
+ a tuple → an explicit shard shape.
+ chunks : tuple of int
+ Inner chunk shape.
+ shape : tuple of int
+ Array shape.
+ dtype : data-type
+ Array dtype.
+
+ Returns
+ -------
+ tuple of int or None
+ The shard shape, or ``None`` when not sharding
+ (also when zarr is older than v3).
+ """
+ if not shard:
+ return None
+ if not _ZARR_V3:
+ logger.warning("sharding requires zarr v3; writing unsharded.")
+ return None
+ if shard is True:
+ return _auto_shard(chunks, shape, dtype)
+ return _effective_shard(tuple(shard), chunks, shape)
+
+
+def _progress_ctx(progress: bool, label: str):
+ """Return a progress-bar context manager.
+
+ Parameters
+ ----------
+ progress : bool
+ Whether to show a dask progress bar.
+ label : str
+ Name logged just before the bar.
+
+ Returns
+ -------
+ contextmanager
+ A ``ProgressBar`` when *progress* is set, else a no-op
+ context manager.
+ """
+ if not progress:
+ from contextlib import nullcontext
+
+ return nullcontext()
+ from dask.diagnostics import ProgressBar
+
+ logger.info("writing %s …", label)
+ return ProgressBar()
+
+
+def _to_zarr_level(
+ arr: da.Array,
+ group_path: str,
+ component: str,
+ shard: ShardSpec,
+ progress: bool = True,
+) -> None:
+ """Write one array to ``group_path/component``, optionally sharded.
+
+ Without sharding, ``da.to_zarr`` writes chunk by chunk. With sharding that
+ is unsafe — many chunks share one shard file and per-chunk writes race —
+ so we create the sharded array explicitly (inner *chunks* + *shards*) and
+ store with the dask blocks rechunked to the **shard** size, so each task
+ writes one whole shard atomically.
+
+ Parameters
+ ----------
+ arr : da.Array
+ Array to write (its chunk size becomes the inner chunk).
+ group_path : str
+ Path of the parent zarr group.
+ component : str
+ Array name within the group.
+ shard : bool or tuple of int
+ Sharding request; see :func:`_shard_for`.
+ progress : bool
+ Show a dask progress bar for the write.
+
+ Returns
+ -------
+ None
+ """
+ inner = arr.chunksize
+ sh = _shard_for(shard, inner, arr.shape, arr.dtype)
+ ctx = _progress_ctx(progress, f"{Path(group_path).name}/{component}")
+ if not sh:
+ with ctx:
+ da.to_zarr(arr, group_path, component=component, overwrite=True)
+ return
+ grp = zarr.open_group(group_path, mode="a")
+ if component in grp:
+ del grp[component]
+ z = grp.create_array(
+ name=component,
+ shape=arr.shape,
+ chunks=inner,
+ shards=sh,
+ dtype=arr.dtype,
+ )
+ with ctx:
+ arr.rechunk(sh).store(z, lock=True, compute=True)
+
+
def _normalize_pixel_size(
pixel_size: Union[PixelSize, tuple, None], axes: str
) -> PixelSize:
- """Coerce a pixel-size dict/tuple into ``{axis: size}`` for spatial axes."""
+ """Coerce a pixel-size dict/tuple into ``{axis: size}``.
+
+ Parameters
+ ----------
+ pixel_size : dict, tuple or None
+ Voxel size as a per-axis dict or a tuple
+ aligned to the spatial axes.
+ axes : str
+ One letter per axis.
+
+ Returns
+ -------
+ dict
+ ``{axis: size}`` for the spatial axes present (empty if none given).
+ """
if not pixel_size:
return {}
if isinstance(pixel_size, dict):
@@ -115,11 +376,38 @@ def _normalize_pixel_size(
def _base_scale(axes: str, pixel_size: PixelSize) -> list[float]:
- """Level-0 NGFF scale per axis: physical size for spatial, 1.0 else."""
+ """Build the level-0 NGFF scale vector.
+
+ Parameters
+ ----------
+ axes : str
+ One letter per axis.
+ pixel_size : dict
+ ``{axis: size}`` for spatial axes.
+
+ Returns
+ -------
+ list of float
+        Physical size for each axis in *pixel_size*, ``1.0`` for the rest.
+ """
return [float(pixel_size.get(a, 1.0)) for a in axes]
def _dataset(name: str, scale: list[float]) -> dict:
+ """Build one NGFF ``multiscales`` dataset entry.
+
+ Parameters
+ ----------
+ name : str
+ Component path of the level (e.g. ``"0"``).
+ scale : list of float
+ Per-axis scale (physical size × downsample factor).
+
+ Returns
+ -------
+ dict
+ A dataset dict with its ``path`` and ``coordinateTransformations``.
+ """
return {
"path": name,
"coordinateTransformations": [
@@ -139,6 +427,8 @@ def _write_pyramid(
base_scale: list[float],
base_name: str = "0",
write_base: bool = True,
+ shard: ShardSpec = False,
+ progress: bool = True,
) -> list[dict]:
"""Write pyramid levels into *group_path* and return NGFF datasets.
@@ -147,16 +437,43 @@ def _write_pyramid(
whole-volume recomputation, no OOM. Level 0 is named *base_name*; when
*write_base* is False it is assumed to already exist (used by
:func:`add_pyramid`).
+
+ Parameters
+ ----------
+ arr : da.Array
+ Full-resolution array.
+ axes : str
+ One letter per axis.
+ group_path : str
+ Path of the zarr group to write into.
+ n_levels : int
+ Maximum number of levels including full resolution.
+ downscale : int
+ Per-level X/Y downsampling factor.
+ chunks : tuple of int or None
+        Chunk shape; ``None`` selects a bounded default.
+ base_scale : list of float
+ Level-0 physical scale per axis.
+ base_name : str
+ Component name of level 0.
+ write_base : bool
+ Write level 0, or assume it already exists.
+ shard : bool or tuple of int
+ Sharding request (see :func:`_shard_for`).
+ progress : bool
+ Show a per-level progress bar.
+
+ Returns
+ -------
+ list of dict
+ One NGFF dataset entry per written level.
"""
strides = _strides(axes, downscale)
if write_base:
base_chunks = chunks or _default_chunks(arr.shape, axes)
- da.to_zarr(
- arr.rechunk(base_chunks),
- group_path,
- component=base_name,
- overwrite=True,
+ _to_zarr_level(
+ arr.rechunk(base_chunks), group_path, base_name, shard, progress
)
datasets = [_dataset(base_name, base_scale)]
@@ -170,7 +487,7 @@ def _write_pyramid(
src = da.from_zarr(group_path, component=prev_name)
nxt = src[tuple(slice(None, None, st) for st in strides)]
nxt = nxt.rechunk(chunks or _default_chunks(nxt.shape, axes))
- da.to_zarr(nxt, group_path, component=str(i), overwrite=True)
+ _to_zarr_level(nxt, group_path, str(i), shard, progress)
scale = [base_scale[k] * (strides[k] ** i) for k in range(len(axes))]
datasets.append(_dataset(str(i), scale))
logger.info("pyramid level %d: shape=%s", i, nxt.shape)
@@ -187,7 +504,25 @@ def _write_multiscales(
*,
calibrated: bool,
) -> None:
- """Write NGFF ``multiscales`` metadata onto *group_path*."""
+ """Write NGFF ``multiscales`` metadata onto *group_path*.
+
+ Parameters
+ ----------
+ group_path : str
+ Path of the zarr group to annotate.
+ axes : str
+ One letter per axis.
+ datasets : list of dict
+ Per-level dataset entries (see :func:`_dataset`).
+ name : str
+ Multiscales name.
+ calibrated : bool
+ Whether spatial axes carry a micrometer unit.
+
+ Returns
+ -------
+ None
+ """
group = zarr.open_group(group_path, mode="a")
group.attrs["multiscales"] = [
{
@@ -200,7 +535,21 @@ def _write_multiscales(
def _read_zarr_calibration(store: Union[str, Path], axes: str) -> PixelSize:
- """Read level-0 spatial scale from an existing OME-ZARR, if any."""
+ """Read level-0 spatial scale from an existing OME-ZARR, if any.
+
+ Parameters
+ ----------
+ store : str or Path
+ Path of the OME-ZARR group.
+ axes : str
+ One letter per axis (unused for parsing, kept for symmetry).
+
+ Returns
+ -------
+ dict
+ ``{axis: size}`` for spatial axes with a non-unit scale (empty if
+ the store has no multiscales metadata).
+ """
try:
root = zarr.open_group(str(store), mode="r")
ms = root.attrs["multiscales"][0]
@@ -216,7 +565,23 @@ def _read_zarr_calibration(store: Union[str, Path], axes: str) -> PixelSize:
def _open_bioio(path: str, scene: int) -> tuple[da.Array, str, PixelSize]:
- """Open *path* with bioio → ``(array, axes, pixel_size)``, all lazy."""
+ """Open *path* with bioio, lazily.
+
+ Singleton non-spatial axes (T/C of size 1) are dropped.
+
+ Parameters
+ ----------
+ path : str
+ Image file path.
+ scene : int
+ Scene index for multi-scene files.
+
+ Returns
+ -------
+ tuple
+ ``(array, axes, pixel_size)`` — a lazy dask array, its axes string
+ and a ``{axis: micrometers}`` calibration dict.
+ """
try:
from bioio import BioImage
except ImportError as exc:
@@ -256,7 +621,24 @@ def _open_bioio(path: str, scene: int) -> tuple[da.Array, str, PixelSize]:
def _open_imaris(path: str, level: int = 0) -> tuple[da.Array, str, PixelSize]:
- """Open an Imaris ``.ims`` *level* lazily → ``(array, axes, pixel_size)``."""
+ """Open one Imaris ``.ims`` resolution level lazily.
+
+ Reads the underlying HDF5 datasets directly (own handle, crops the Imaris
+ chunk padding) and stacks the per-(timepoint, channel) 3-D arrays.
+
+ Parameters
+ ----------
+ path : str
+ Imaris ``.ims`` file path.
+ level : int
+ Resolution level to read (0 = full resolution).
+
+ Returns
+ -------
+ tuple
+ ``(array, axes, pixel_size)`` — a lazy dask array, its axes string
+ and a ``{axis: micrometers}`` calibration dict.
+ """
try:
from imaris_ims_file_reader.ims import ims
except ImportError as exc:
@@ -326,12 +708,34 @@ def _write_imaris_pyramid(
*,
chunks: Union[tuple[int, ...], None],
overwrite: bool,
+ shard: ShardSpec = False,
+ progress: bool = True,
) -> str:
"""Copy an Imaris file's own resolution levels into an OME-ZARR.
Each Imaris ``ResolutionLevel`` is written as a pyramid level with its own
physical scale, so no downsampling is recomputed. Lazy (h5py-backed) reads
stream straight to disk.
+
+ Parameters
+ ----------
+ path : str
+ Imaris ``.ims`` file path.
+ out : str
+ Destination ``.zarr`` store path.
+ chunks : tuple of int or None
+        Chunk shape; ``None`` selects a bounded default.
+ overwrite : bool
+ Overwrite an existing store.
+ shard : bool or tuple of int
+ Sharding request (see :func:`_shard_for`).
+ progress : bool
+ Show a per-level progress bar.
+
+ Returns
+ -------
+ str
+ The path to the written store.
"""
from imaris_ims_file_reader.ims import ims
@@ -346,11 +750,12 @@ def _write_imaris_pyramid(
arr, axes, ps = _open_imaris(path, level=level)
scale = _base_scale(axes, ps)
calibrated = calibrated or bool(ps)
- da.to_zarr(
+ _to_zarr_level(
arr.rechunk(chunks or _default_chunks(arr.shape, axes)),
out,
- component=str(level),
- overwrite=True,
+ str(level),
+ shard,
+ progress,
)
datasets.append(_dataset(str(level), scale))
logger.info("imaris level %d copied: shape=%s", level, arr.shape)
@@ -365,7 +770,26 @@ def _to_dask(
axes: Union[str, None],
scene: int,
) -> tuple[da.Array, str, PixelSize]:
- """Resolve *source* into a lazy ``(array, axes, pixel_size)`` triple."""
+ """Resolve *source* into a lazy ``(array, axes, pixel_size)`` triple.
+
+ Dispatches by type: dask/NumPy arrays pass through; ``.zarr`` paths use the
+ OME-ZARR loader; ``.ims`` paths use the Imaris reader; anything else uses
+ bioio.
+
+ Parameters
+ ----------
+ source : da.Array, np.ndarray, str or Path
+ Array or path to resolve.
+ axes : str or None
+ Explicit axes, or ``None`` to infer them.
+ scene : int
+ Scene index for bioio inputs.
+
+ Returns
+ -------
+ tuple
+ ``(array, axes, pixel_size)``.
+ """
if isinstance(source, da.Array):
return source, axes or _default_axes(source.ndim), {}
if isinstance(source, np.ndarray):
@@ -394,7 +818,9 @@ def to_ome_zarr(
n_levels: int = 5,
downscale: int = 2,
chunks: Union[tuple[int, ...], None] = None,
+ shard: ShardSpec = False,
reuse_pyramid: bool = False,
+ progress: bool = True,
overwrite: bool = False,
) -> str:
"""Write *source* as a pyramidal, calibrated OME-ZARR store.
@@ -427,6 +853,15 @@ def to_ome_zarr(
Per-level X/Y downsampling factor (default 2).
chunks : tuple of int, optional
Chunk shape for the written levels. ``None`` → a bounded default.
+ shard : bool or tuple of int, optional
+ Pack many chunks into one shard file (zarr v3), cutting the file count
+ ~100× on huge arrays. ``False`` (default) → unsharded, maximum reader
+ compatibility. ``True`` → auto-pick a ~512 MB shard. A tuple sets an
+ explicit shard shape (clamped to a chunk multiple). Sharded writes hold
+ ~one shard per worker in RAM. Requires zarr v3 (ignored otherwise).
+ progress : bool, optional
+ Show a per-level dask progress bar (default ``True``). Set ``False`` to
+ silence it.
reuse_pyramid : bool, optional
*Imaris ``.ims`` only.* Copy the file's **own** resolution levels
instead of rebuilding the pyramid (faster, no recompute), keeping each
@@ -460,7 +895,12 @@ def to_ome_zarr(
):
try:
return _write_imaris_pyramid(
- str(source), str(out_path), chunks=chunks, overwrite=overwrite
+ str(source),
+ str(out_path),
+ chunks=chunks,
+ overwrite=overwrite,
+ shard=shard,
+ progress=progress,
)
except Exception as exc:
logger.warning(
@@ -487,6 +927,8 @@ def to_ome_zarr(
downscale=downscale,
chunks=chunks,
base_scale=base_scale,
+ shard=shard,
+ progress=progress,
)
_write_multiscales(out, axes, datasets, Path(out).stem, calibrated=bool(ps))
return out
@@ -501,6 +943,8 @@ def add_pyramid(
n_levels: int = 5,
downscale: int = 2,
chunks: Union[tuple[int, ...], None] = None,
+ shard: ShardSpec = False,
+ progress: bool = True,
) -> str:
"""Add downsampled pyramid levels to an existing single-resolution zarr.
@@ -552,6 +996,8 @@ def add_pyramid(
base_scale=base_scale,
base_name=base,
write_base=False,
+ shard=shard,
+ progress=progress,
)
_write_multiscales(gp, axes, datasets, Path(gp).stem, calibrated=bool(ps))
return gp
@@ -566,6 +1012,8 @@ def register_labels(
n_levels: int = 5,
downscale: int = 2,
chunks: Union[tuple[int, ...], None] = None,
+ shard: ShardSpec = False,
+ progress: bool = True,
) -> str:
"""Pyramidalise and register an existing ``labels//0`` base level.
@@ -594,6 +1042,8 @@ def register_labels(
n_levels=n_levels,
downscale=downscale,
chunks=chunks,
+ shard=shard,
+ progress=progress,
)
grp = zarr.open_group(group, mode="a")
grp.attrs["image-label"] = {"version": _NGFF_VERSION}
@@ -616,6 +1066,8 @@ def write_labels(
n_levels: int = 5,
downscale: int = 2,
chunks: Union[tuple[int, ...], None] = None,
+ shard: ShardSpec = False,
+ progress: bool = True,
overwrite: bool = False,
) -> str:
"""Store *labels* inside *image_store* under the NGFF ``labels/`` group.
@@ -648,7 +1100,7 @@ def write_labels(
label_group = f"{store}/labels/{name}"
base = arr.rechunk(chunks or _default_chunks(arr.shape, axes))
- da.to_zarr(base, label_group, component="0", overwrite=True)
+ _to_zarr_level(base, label_group, "0", shard, progress)
return register_labels(
store,
name,
@@ -657,4 +1109,6 @@ def write_labels(
n_levels=n_levels,
downscale=downscale,
chunks=chunks,
+ shard=shard,
+ progress=progress,
)
diff --git a/tests/test_ome_zarr.py b/tests/test_ome_zarr.py
index db12cbf..837fef9 100644
--- a/tests/test_ome_zarr.py
+++ b/tests/test_ome_zarr.py
@@ -139,3 +139,42 @@ def test_reuse_pyramid_ignored_for_arrays(tmp_path):
reuse_pyramid=True,
)
assert load_ome_zarr(out, channel=None, level=1).shape == (8, 4, 4)
+
+
+def test_sharding(tmp_path):
+    """Sharding on request yields sharded arrays without altering the data."""
+    import zarr as _zarr
+
+    data = np.arange(4 * 64 * 64, dtype="uint16").reshape(4, 64, 64)
+    inner = (2, 16, 16)
+
+    auto = to_ome_zarr(
+        data,
+        tmp_path / "s.zarr",
+        axes="zyx",
+        n_levels=2,
+        chunks=inner,
+        shard=True,
+    )
+    level0 = _zarr.open_array(f"{auto}/0", mode="r")
+    assert level0.chunks == inner
+    assert level0.shards is not None and level0.shards != level0.chunks
+    restored = np.asarray(load_ome_zarr(auto, channel=None, level=0))
+    assert np.array_equal(restored, data)
+
+    explicit = to_ome_zarr(
+        data,
+        tmp_path / "e.zarr",
+        axes="zyx",
+        n_levels=1,
+        chunks=inner,
+        shard=(2, 32, 32),
+    )
+    assert _zarr.open_array(f"{explicit}/0", mode="r").shards == (2, 32, 32)
+
+    plain = to_ome_zarr(
+        data, tmp_path / "n.zarr", axes="zyx", n_levels=1, chunks=inner
+    )
+    # getattr: tolerate array objects that expose no ``shards`` attribute.
+    unsharded = _zarr.open_array(f"{plain}/0", mode="r")
+    assert getattr(unsharded, "shards", None) is None