From 75f66cfb4e710e2966581c99628f47152a43598a Mon Sep 17 00:00:00 2001 From: guipenedo Date: Fri, 22 May 2026 01:26:50 +0200 Subject: [PATCH 01/39] Add Zarr reader and writer --- src/refiner/pipeline/pipeline.py | 19 +++- src/refiner/pipeline/sinks/__init__.py | 2 + src/refiner/pipeline/sinks/zarr.py | 123 +++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 src/refiner/pipeline/sinks/zarr.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 6ace95d2..533434d5 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -431,6 +431,23 @@ def write_parquet( ) ) + def write_zarr( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ) -> "RefinerPipeline": + return self.with_sink( + ZarrSink( + output=output, + arrays=arrays, + episode_ends_path=episode_ends_path, + overwrite=overwrite, + ) + ) + def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index 8c1f26df..f0623f21 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,6 +2,7 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink +from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -10,4 +11,5 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", + "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py new file mode 100644 index 00000000..e3beb498 --- /dev/null +++ b/src/refiner/pipeline/sinks/zarr.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any, cast + +import numpy as np +import pyarrow as pa + +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.data.row import Row +from refiner.pipeline.data.tabular import Tabular +from refiner.pipeline.sinks.base import BaseSink +from refiner.robotics.row import RoboticsRow +from refiner.utils import check_required_dependencies + + +class ZarrSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ): + self.output = DataFolder.resolve(output) + self.arrays = dict(arrays) if arrays is not None else None + self.episode_ends_path = episode_ends_path + self.overwrite = overwrite + self._chunks: dict[str, list[np.ndarray]] = {} + self._episode_ends: list[int] = [] + + def write_shard_block(self, shard_id: str, block: Block) -> int: + del shard_id + rows = list(block) if isinstance(block, Tabular) else block + for row in rows: + self._write_row(row) + return len(rows) + + def _write_row(self, row: Row) -> None: + arrays = self.arrays or _default_robotics_arrays(row) + lengths: list[int] = [] + for zarr_path, source_key in arrays.items(): + value = _row_value(row, source_key) + if value is None: + continue + array = _as_array(value) + if array.ndim == 0: + array = array.reshape(1) + lengths.append(int(array.shape[0])) + self._chunks.setdefault(zarr_path, []).append(array) + if lengths and self.episode_ends_path is not None: + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + end = (self._episode_ends[-1] if self._episode_ends else 0) + length + self._episode_ends.append(end) + + def close(self) -> None: + if not self._chunks and not self._episode_ends: + return + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + import zarr + + mode = "w" if self.overwrite else "w-" + root = zarr.open_group(self.output.abs_path(), mode=mode) + for path, chunks in self._chunks.items(): + root.create_dataset(path, data=np.concatenate(chunks, axis=0)) + if self.episode_ends_path is not None: + root.create_dataset( + self.episode_ends_path, + data=np.asarray(self._episode_ends, dtype=np.int64), + ) + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr", + "writer", + { + "path": self.output.abs_path(), + "arrays": dict(self.arrays) if self.arrays is not None else None, + "episode_ends_path": self.episode_ends_path, + "overwrite": self.overwrite, + }, + ) + + +def _default_robotics_arrays(row: Row) -> dict[str, str]: + if not isinstance(row, RoboticsRow): + raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") + arrays: dict[str, str] = {} + if row.actions is not None: + arrays["data/action"] = "action" + if row.states is not None: + arrays["data/observation.state"] = "observation.state" + if row.timestamps is not None: + arrays["data/timestamp"] = "timestamp" + return arrays + + +def _row_value(row: Row, key: str) -> Any: + if isinstance(row, RoboticsRow): + if key == "action": + return row.actions + if key == "observation.state": + return row.states + if key == "timestamp": + return row.timestamps + if key.startswith("observation."): + return row.observations(key) + return row[key] + + +def _as_array(value: Any) -> np.ndarray: + if isinstance(value, pa.ChunkedArray | pa.Array): + return np.asarray(value.to_pylist()) + if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): + return np.asarray(list(cast(Iterable[Any], value))) + return np.asarray(value) + + +__all__ = ["ZarrSink"] From a6d303d8fe7bee18ef98d0567adad02760b85569 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:29:18 +0200 Subject: [PATCH 02/39] Keep Zarr PR reader-only --- src/refiner/pipeline/pipeline.py | 19 +--- src/refiner/pipeline/sinks/__init__.py | 2 - src/refiner/pipeline/sinks/zarr.py | 123 ------------------------- 3 files changed, 1 insertion(+), 143 deletions(-) delete mode 100644 src/refiner/pipeline/sinks/zarr.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 533434d5..6ace95d2 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -431,23 +431,6 @@ def write_parquet( ) ) - def write_zarr( - self, - output: DataFolderLike, - *, - arrays: Mapping[str, str] | None = None, - episode_ends_path: str | None = "meta/episode_ends", - overwrite: bool = True, - ) -> "RefinerPipeline": - return self.with_sink( - ZarrSink( - output=output, - arrays=arrays, - episode_ends_path=episode_ends_path, - overwrite=overwrite, - ) - ) - def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index f0623f21..8c1f26df 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,7 +2,6 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink -from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -11,5 +10,4 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", - "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py deleted file mode 100644 index e3beb498..00000000 --- a/src/refiner/pipeline/sinks/zarr.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -from typing import Any, cast - -import numpy as np -import pyarrow as pa - -from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.pipeline.data.block import Block -from refiner.pipeline.data.row import Row -from refiner.pipeline.data.tabular import Tabular -from refiner.pipeline.sinks.base import BaseSink -from refiner.robotics.row import RoboticsRow -from refiner.utils import check_required_dependencies - - -class ZarrSink(BaseSink): - def __init__( - self, - output: DataFolderLike, - *, - arrays: Mapping[str, str] | None = None, - episode_ends_path: str | None = "meta/episode_ends", - overwrite: bool = True, - ): - self.output = DataFolder.resolve(output) - self.arrays = dict(arrays) if arrays is not None else None - self.episode_ends_path = episode_ends_path - self.overwrite = overwrite - self._chunks: dict[str, list[np.ndarray]] = {} - self._episode_ends: list[int] = [] - - def write_shard_block(self, shard_id: str, block: Block) -> int: - del shard_id - rows = list(block) if isinstance(block, Tabular) else block - for row in rows: - self._write_row(row) - return len(rows) - - def _write_row(self, row: Row) -> None: - arrays = self.arrays or _default_robotics_arrays(row) - lengths: list[int] = [] - for zarr_path, source_key in arrays.items(): - value = _row_value(row, source_key) - if value is None: - continue - array = _as_array(value) - if array.ndim == 0: - array = array.reshape(1) - lengths.append(int(array.shape[0])) - self._chunks.setdefault(zarr_path, []).append(array) - if lengths and self.episode_ends_path is not None: - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError("Zarr arrays for one row must have matching lengths") - end = (self._episode_ends[-1] if self._episode_ends else 0) + length - self._episode_ends.append(end) - - def close(self) -> None: - if not self._chunks and not self._episode_ends: - return - check_required_dependencies("write_zarr", ["zarr"], dist="zarr") - import zarr - - mode = "w" if self.overwrite else "w-" - root = zarr.open_group(self.output.abs_path(), mode=mode) - for path, chunks in self._chunks.items(): - root.create_dataset(path, data=np.concatenate(chunks, axis=0)) - if self.episode_ends_path is not None: - root.create_dataset( - self.episode_ends_path, - data=np.asarray(self._episode_ends, dtype=np.int64), - ) - - def describe(self) -> tuple[str, str, dict[str, object]]: - return ( - "write_zarr", - "writer", - { - "path": self.output.abs_path(), - "arrays": dict(self.arrays) if self.arrays is not None else None, - "episode_ends_path": self.episode_ends_path, - "overwrite": self.overwrite, - }, - ) - - -def _default_robotics_arrays(row: Row) -> dict[str, str]: - if not isinstance(row, RoboticsRow): - raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") - arrays: dict[str, str] = {} - if row.actions is not None: - arrays["data/action"] = "action" - if row.states is not None: - arrays["data/observation.state"] = "observation.state" - if row.timestamps is not None: - arrays["data/timestamp"] = "timestamp" - return arrays - - -def _row_value(row: Row, key: str) -> Any: - if isinstance(row, RoboticsRow): - if key == "action": - return row.actions - if key == "observation.state": - return row.states - if key == "timestamp": - return row.timestamps - if key.startswith("observation."): - return row.observations(key) - return row[key] - - -def _as_array(value: Any) -> np.ndarray: - if isinstance(value, pa.ChunkedArray | pa.Array): - return np.asarray(value.to_pylist()) - if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): - return np.asarray(list(cast(Iterable[Any], value))) - return np.asarray(value) - - -__all__ = ["ZarrSink"] From fb1161ab16a49e1b15e53fd3d00a4f42d579bb0a Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:33:56 +0200 Subject: [PATCH 03/39] Add Zarr writer --- src/refiner/pipeline/pipeline.py | 19 +++- src/refiner/pipeline/sinks/__init__.py | 2 + src/refiner/pipeline/sinks/zarr.py | 123 +++++++++++++++++++++++++ tests/readers/test_zarr_reader.py | 33 +++++++ 4 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 src/refiner/pipeline/sinks/zarr.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 6ace95d2..533434d5 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -431,6 +431,23 @@ def write_parquet( ) ) + def write_zarr( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ) -> "RefinerPipeline": + return self.with_sink( + ZarrSink( + output=output, + arrays=arrays, + episode_ends_path=episode_ends_path, + overwrite=overwrite, + ) + ) + def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index 8c1f26df..f0623f21 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,6 +2,7 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink +from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -10,4 +11,5 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", + "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py new file mode 100644 index 00000000..e3beb498 --- /dev/null +++ b/src/refiner/pipeline/sinks/zarr.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any, cast + +import numpy as np +import pyarrow as pa + +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.data.row import Row +from refiner.pipeline.data.tabular import Tabular +from refiner.pipeline.sinks.base import BaseSink +from refiner.robotics.row import RoboticsRow +from refiner.utils import check_required_dependencies + + +class ZarrSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ): + self.output = DataFolder.resolve(output) + self.arrays = dict(arrays) if arrays is not None else None + self.episode_ends_path = episode_ends_path + self.overwrite = overwrite + self._chunks: dict[str, list[np.ndarray]] = {} + self._episode_ends: list[int] = [] + + def write_shard_block(self, shard_id: str, block: Block) -> int: + del shard_id + rows = list(block) if isinstance(block, Tabular) else block + for row in rows: + self._write_row(row) + return len(rows) + + def _write_row(self, row: Row) -> None: + arrays = self.arrays or _default_robotics_arrays(row) + lengths: list[int] = [] + for zarr_path, source_key in arrays.items(): + value = _row_value(row, source_key) + if value is None: + continue + array = _as_array(value) + if array.ndim == 0: + array = array.reshape(1) + lengths.append(int(array.shape[0])) + self._chunks.setdefault(zarr_path, []).append(array) + if lengths and self.episode_ends_path is not None: + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + end = (self._episode_ends[-1] if self._episode_ends else 0) + length + self._episode_ends.append(end) + + def close(self) -> None: + if not self._chunks and not self._episode_ends: + return + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + import zarr + + mode = "w" if self.overwrite else "w-" + root = zarr.open_group(self.output.abs_path(), mode=mode) + for path, chunks in self._chunks.items(): + root.create_dataset(path, data=np.concatenate(chunks, axis=0)) + if self.episode_ends_path is not None: + root.create_dataset( + self.episode_ends_path, + data=np.asarray(self._episode_ends, dtype=np.int64), + ) + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr", + "writer", + { + "path": self.output.abs_path(), + "arrays": dict(self.arrays) if self.arrays is not None else None, + "episode_ends_path": self.episode_ends_path, + "overwrite": self.overwrite, + }, + ) + + +def _default_robotics_arrays(row: Row) -> dict[str, str]: + if not isinstance(row, RoboticsRow): + raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") + arrays: dict[str, str] = {} + if row.actions is not None: + arrays["data/action"] = "action" + if row.states is not None: + arrays["data/observation.state"] = "observation.state" + if row.timestamps is not None: + arrays["data/timestamp"] = "timestamp" + return arrays + + +def _row_value(row: Row, key: str) -> Any: + if isinstance(row, RoboticsRow): + if key == "action": + return row.actions + if key == "observation.state": + return row.states + if key == "timestamp": + return row.timestamps + if key.startswith("observation."): + return row.observations(key) + return row[key] + + +def _as_array(value: Any) -> np.ndarray: + if isinstance(value, pa.ChunkedArray | pa.Array): + return np.asarray(value.to_pylist()) + if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): + return np.asarray(list(cast(Iterable[Any], value))) + return np.asarray(value) + + +__all__ = ["ZarrSink"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 3948983b..3e553759 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -637,6 +637,7 @@ def test_read_zarr_rejects_scalar_arrays_in_row_ends_mode(tmp_path: Path) -> Non def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" + zarr_out = tmp_path / "roundtrip.zarr" _write_policy_zarr(path) ( @@ -673,3 +674,35 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: episodes.sort(key=lambda episode: int(episode.episode_id)) assert [episode.num_frames for episode in episodes] == [2, 3] assert [episode.task for episode in episodes] == ["push tee", "push tee"] + + ( + mdr.read_lerobot(str(lerobot_out)) + .write_zarr( + str(zarr_out), + arrays={ + "data/action": "action", + "data/state": "observation.state", + }, + ) + .launch_local( + name="lerobot-to-zarr", num_workers=1, rundir=str(tmp_path / "run2") + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + + assert row["episode_ends"].tolist() == [2, 5] + np.testing.assert_allclose( + row["action"], np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]]) + ) + np.testing.assert_allclose( + row["state"], np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]]) + ) From 4c12e60951254de4823d8db5daff93ddccf68528 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:40:52 +0200 Subject: [PATCH 04/39] Stream Zarr writer outputs --- src/refiner/pipeline/pipeline.py | 2 + src/refiner/pipeline/sinks/zarr.py | 83 +++++++++++++++++++++--------- tests/readers/test_zarr_reader.py | 3 +- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 533434d5..43ac63b6 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -437,6 +437,7 @@ def write_zarr( *, arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", + store_template: str = "{shard_id}__w{worker_id}.zarr", overwrite: bool = True, ) -> "RefinerPipeline": return self.with_sink( @@ -444,6 +445,7 @@ def write_zarr( output=output, arrays=arrays, episode_ends_path=episode_ends_path, + store_template=store_template, overwrite=overwrite, ) ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index e3beb498..105bb540 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field from typing import Any, cast import numpy as np @@ -9,10 +10,17 @@ from refiner.io.datafolder import DataFolder, DataFolderLike from refiner.pipeline.data.block import Block from refiner.pipeline.data.row import Row -from refiner.pipeline.data.tabular import Tabular from refiner.pipeline.sinks.base import BaseSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies +from refiner.worker.context import get_active_worker_token + + +@dataclass +class _ShardStore: + root: Any + arrays: dict[str, Any] = field(default_factory=dict) + row_end: int = 0 class ZarrSink(BaseSink): @@ -22,56 +30,84 @@ def __init__( *, arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", + store_template: str = "{shard_id}__w{worker_id}.zarr", overwrite: bool = True, ): + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None self.episode_ends_path = episode_ends_path + self.store_template = store_template self.overwrite = overwrite - self._chunks: dict[str, list[np.ndarray]] = {} - self._episode_ends: list[int] = [] + self._stores: dict[str, _ShardStore] = {} def write_shard_block(self, shard_id: str, block: Block) -> int: - del shard_id - rows = list(block) if isinstance(block, Tabular) else block + rows = block + count = 0 for row in rows: - self._write_row(row) - return len(rows) + self._write_row(shard_id, row) + count += 1 + return count - def _write_row(self, row: Row) -> None: + def _write_row(self, shard_id: str, row: Row) -> None: arrays = self.arrays or _default_robotics_arrays(row) + store = self._store(shard_id) lengths: list[int] = [] for zarr_path, source_key in arrays.items(): value = _row_value(row, source_key) if value is None: - continue + raise ValueError(f"Zarr source value is missing: {source_key}") array = _as_array(value) if array.ndim == 0: array = array.reshape(1) lengths.append(int(array.shape[0])) - self._chunks.setdefault(zarr_path, []).append(array) + self._append_array(store, zarr_path, array) if lengths and self.episode_ends_path is not None: length = lengths[0] if any(item != length for item in lengths): raise ValueError("Zarr arrays for one row must have matching lengths") - end = (self._episode_ends[-1] if self._episode_ends else 0) + length - self._episode_ends.append(end) + store.row_end += length + self._append_array( + store, + self.episode_ends_path, + np.asarray([store.row_end], dtype=np.int64), + ) - def close(self) -> None: - if not self._chunks and not self._episode_ends: - return - check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + def _store(self, shard_id: str) -> _ShardStore: + relpath = self.store_template.format( + shard_id=shard_id, + worker_id=get_active_worker_token(), + ) + store = self._stores.get(relpath) + if store is not None: + return store import zarr mode = "w" if self.overwrite else "w-" - root = zarr.open_group(self.output.abs_path(), mode=mode) - for path, chunks in self._chunks.items(): - root.create_dataset(path, data=np.concatenate(chunks, axis=0)) - if self.episode_ends_path is not None: - root.create_dataset( - self.episode_ends_path, - data=np.asarray(self._episode_ends, dtype=np.int64), + store = _ShardStore(zarr.open_group(self.output.abs_path(relpath), mode=mode)) + self._stores[relpath] = store + return store + + def _append_array( + self, + store: _ShardStore, + path: str, + array: np.ndarray, + ) -> None: + dataset = store.arrays.get(path) + if dataset is None: + chunks = (max(1, min(int(array.shape[0]), 1024)), *array.shape[1:]) + dataset = store.root.create_dataset( + path, + shape=(0, *array.shape[1:]), + chunks=chunks, + dtype=array.dtype, ) + store.arrays[path] = dataset + dataset.append(array, axis=0) + + def close(self) -> None: + self._stores.clear() def describe(self) -> tuple[str, str, dict[str, object]]: return ( @@ -81,6 +117,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "path": self.output.abs_path(), "arrays": dict(self.arrays) if self.arrays is not None else None, "episode_ends_path": self.episode_ends_path, + "store_template": self.store_template, "overwrite": self.overwrite, }, ) diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 3e553759..df095e14 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -689,8 +689,9 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: ) ) + zarr_store = next(zarr_out.glob("*.zarr")) row = mdr.read_zarr( - zarr_out, + zarr_store, arrays={ "action": "data/action", "state": "data/state", From cf2e0cccc341e313871bb7f34204638ea7dea990 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:55:36 +0200 Subject: [PATCH 05/39] Harden Zarr writer outputs --- src/refiner/pipeline/sinks/reducer/file.py | 16 ++++++- src/refiner/pipeline/sinks/zarr.py | 27 +++++++++-- tests/pipeline/test_sinks.py | 36 +++++++++++++++ tests/readers/test_zarr_reader.py | 54 ++++++++++++++++------ 4 files changed, 112 insertions(+), 21 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 2225301e..24584fcb 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -67,11 +67,13 @@ def __init__( filename_template: str, reducer_name: str, assets_subdir: str | None = None, + recursive: bool = False, ) -> None: self.output = DataFolder.resolve(output) self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir + self.recursive = recursive self._managed_path_pattern = _compile_managed_path_pattern(filename_template) self._cleanup_ran = False @@ -90,6 +92,8 @@ def describe(self) -> tuple[str, str, dict[str, object]]: } if self.assets_subdir is not None: args["assets_subdir"] = self.assets_subdir + if self.recursive: + args["recursive"] = True return ( self.reducer_name, "writer", @@ -127,6 +131,7 @@ def _run_cleanup(self) -> None: ) removed_asset_attempts: set[str] = set() + removed_managed_paths: set[str] = set() # Extra template fields are treated as structure only. Authority is decided # solely from the finalized (shard_id, worker_id) pair extracted from each # managed path. @@ -151,13 +156,20 @@ def _run_cleanup(self) -> None: continue continue - match = self._managed_path_pattern.fullmatch(rel_path) + managed_path = rel_path + match = self._managed_path_pattern.fullmatch(managed_path) + if match is None and self.recursive: + managed_path = rel_path.split("/", maxsplit=1)[0] + match = self._managed_path_pattern.fullmatch(managed_path) if match is None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue + if managed_path in removed_managed_paths: + continue + removed_managed_paths.add(managed_path) try: - self.output.rm(rel_path) + self.output.rm(managed_path, recursive=self.recursive) except FileNotFoundError: continue diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 105bb540..dd19ef3a 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -8,9 +8,11 @@ import pyarrow as pa from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.io.zarr import zarr_store from refiner.pipeline.data.block import Block from refiner.pipeline.data.row import Row from refiner.pipeline.sinks.base import BaseSink +from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies from refiner.worker.context import get_active_worker_token @@ -42,9 +44,8 @@ def __init__( self._stores: dict[str, _ShardStore] = {} def write_shard_block(self, shard_id: str, block: Block) -> int: - rows = block count = 0 - for row in rows: + for row in block: self._write_row(shard_id, row) count += 1 return count @@ -81,10 +82,14 @@ def _store(self, shard_id: str) -> _ShardStore: store = self._stores.get(relpath) if store is not None: return store + mode = "w" if self.overwrite else "w-" import zarr - mode = "w" if self.overwrite else "w-" - store = _ShardStore(zarr.open_group(self.output.abs_path(relpath), mode=mode)) + store = _ShardStore( + zarr.open_group( + store=zarr_store(self.output, relpath, mode=mode), mode=mode + ) + ) self._stores[relpath] = store return store @@ -122,6 +127,14 @@ def describe(self) -> tuple[str, str, dict[str, object]]: }, ) + def build_reducer(self) -> BaseSink | None: + return FileCleanupReducerSink( + output=self.output, + filename_template=self.store_template, + reducer_name="write_zarr_reduce", + recursive=True, + ) + def _default_robotics_arrays(row: Row) -> dict[str, str]: if not isinstance(row, RoboticsRow): @@ -150,7 +163,11 @@ def _row_value(row: Row, key: str) -> Any: def _as_array(value: Any) -> np.ndarray: - if isinstance(value, pa.ChunkedArray | pa.Array): + if isinstance(value, pa.ChunkedArray): + return _as_array(value.combine_chunks()) + if isinstance(value, pa.Array): + if pa.types.is_primitive(value.type): + return value.to_numpy(zero_copy_only=False) return np.asarray(value.to_pylist()) if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): return np.asarray(list(cast(Iterable[Any], value))) diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index 3aaee883..aac9d6d4 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -929,6 +929,42 @@ def test_file_cleanup_reducer_ignores_extra_template_fields(tmp_path) -> None: assert unmanaged_file.exists() +def test_file_cleanup_reducer_removes_non_finalized_directories(tmp_path) -> None: + output_dir = tmp_path / "zarr-cleanup" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_dir = output_dir / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" + loser_dir = output_dir / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_dir.exists() + assert not loser_dir.exists() + + def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( tmp_path, monkeypatch ) -> None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index df095e14..e84f030d 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -53,14 +53,6 @@ def _write_policy_zarr(path: Path) -> None: root.attrs["task"] = "push tee" -def test_read_zarr_rejects_reserved_file_path_output_name(tmp_path: Path) -> None: - path = tmp_path / "policy.zarr" - _write_policy_zarr(path) - - with pytest.raises(ValueError, match="reserved output names"): - mdr.read_zarr(path, arrays={"file_path": "data/action"}) - - def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -637,7 +629,6 @@ def test_read_zarr_rejects_scalar_arrays_in_row_ends_mode(tmp_path: Path) -> Non def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" - zarr_out = tmp_path / "roundtrip.zarr" _write_policy_zarr(path) ( @@ -675,6 +666,40 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: assert [episode.num_frames for episode in episodes] == [2, 3] assert [episode.task for episode in episodes] == ["push tee", "push tee"] + +def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + lerobot_out = tmp_path / "lerobot" + zarr_out = tmp_path / "roundtrip.zarr" + _write_policy_zarr(path) + + ( + mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + .to_robot_rows( + episode_id_key="row_index", + task_key="task", + action_key="action", + state_key="observation.state", + video_keys={"observation.images.front": "frames"}, + fps=10, + robot_type="pusht", + ) + .write_lerobot(str(lerobot_out), max_video_prepare_in_flight=1) + .launch_local( + name="zarr-to-lerobot", num_workers=1, rundir=str(tmp_path / "run1") + ) + ) + ( mdr.read_lerobot(str(lerobot_out)) .write_zarr( @@ -700,10 +725,11 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: file_path_column=None, ).take(1)[0] - assert row["episode_ends"].tolist() == [2, 5] - np.testing.assert_allclose( - row["action"], np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]]) - ) + episode_ends = row["episode_ends"].tolist() + assert episode_ends[-1] == 5 + assert sorted(np.diff([0, *episode_ends]).tolist()) == [2, 3] + assert row["action"].shape == (5, 1) np.testing.assert_allclose( - row["state"], np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]]) + np.sort(row["state"].reshape(-1)), + np.asarray([10.0, 10.1, 20.0, 20.1, 20.2]), ) From e25d100b85329c63c70399b04833ab2b73f3e37a Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:06:43 +0200 Subject: [PATCH 06/39] Batch and clean Zarr writer stores --- src/refiner/pipeline/sinks/reducer/file.py | 9 +++- src/refiner/pipeline/sinks/zarr.py | 48 +++++++++++++++------- tests/pipeline/test_sinks.py | 42 +++++++++++++++++++ 3 files changed, 83 insertions(+), 16 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 24584fcb..ccf7c6c8 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -159,8 +159,13 @@ def _run_cleanup(self) -> None: managed_path = rel_path match = self._managed_path_pattern.fullmatch(managed_path) if match is None and self.recursive: - managed_path = rel_path.split("/", maxsplit=1)[0] - match = self._managed_path_pattern.fullmatch(managed_path) + parts = rel_path.split("/") + for index in range(1, len(parts)): + candidate = "/".join(parts[:index]) + match = self._managed_path_pattern.fullmatch(candidate) + if match is not None: + managed_path = candidate + break if match is None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index dd19ef3a..c676682a 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -44,15 +44,32 @@ def __init__( self._stores: dict[str, _ShardStore] = {} def write_shard_block(self, shard_id: str, block: Block) -> int: + arrays_by_path: dict[str, list[np.ndarray]] = {} + episode_lengths: list[int] = [] count = 0 for row in block: - self._write_row(shard_id, row) + length = self._collect_row(row, arrays_by_path) + if length is not None: + episode_lengths.append(length) count += 1 + if arrays_by_path: + store = self._store(shard_id) + for zarr_path, arrays in arrays_by_path.items(): + self._append_array(store, zarr_path, np.concatenate(arrays, axis=0)) + if self.episode_ends_path is not None and episode_lengths: + episode_ends = store.row_end + np.cumsum( + np.asarray(episode_lengths, dtype=np.int64) + ) + store.row_end = int(episode_ends[-1]) + self._append_array(store, self.episode_ends_path, episode_ends) return count - def _write_row(self, shard_id: str, row: Row) -> None: + def _collect_row( + self, + row: Row, + arrays_by_path: dict[str, list[np.ndarray]], + ) -> int | None: arrays = self.arrays or _default_robotics_arrays(row) - store = self._store(shard_id) lengths: list[int] = [] for zarr_path, source_key in arrays.items(): value = _row_value(row, source_key) @@ -62,17 +79,13 @@ def _write_row(self, shard_id: str, row: Row) -> None: if array.ndim == 0: array = array.reshape(1) lengths.append(int(array.shape[0])) - self._append_array(store, zarr_path, array) - if lengths and self.episode_ends_path is not None: - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError("Zarr arrays for one row must have matching lengths") - store.row_end += length - self._append_array( - store, - self.episode_ends_path, - np.asarray([store.row_end], dtype=np.int64), - ) + arrays_by_path.setdefault(zarr_path, []).append(array) + if not lengths: + return None + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + return length def _store(self, shard_id: str) -> _ShardStore: relpath = self.store_template.format( @@ -93,6 +106,13 @@ def _store(self, shard_id: str) -> _ShardStore: self._stores[relpath] = store return store + def on_shard_complete(self, shard_id: str) -> None: + relpath = self.store_template.format( + shard_id=shard_id, + worker_id=get_active_worker_token(), + ) + self._stores.pop(relpath, None) + def _append_array( self, store: _ShardStore, diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index aac9d6d4..b2e6d5a8 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -965,6 +965,48 @@ def test_file_cleanup_reducer_removes_non_finalized_directories(tmp_path) -> Non assert not loser_dir.exists() +def test_file_cleanup_reducer_removes_non_finalized_nested_directories( + tmp_path, +) -> None: + output_dir = tmp_path / "zarr-cleanup-nested" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_dir = ( + output_dir / "split" / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" + ) + loser_dir = ( + output_dir / "split" / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" + ) + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="split/{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_dir.exists() + assert not loser_dir.exists() + + def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( tmp_path, monkeypatch ) -> None: From 917b4964274fd12c780c177fc5d3d49afd8ca227 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:14:12 +0200 Subject: [PATCH 07/39] Stabilize default Zarr writer arrays --- src/refiner/pipeline/sinks/zarr.py | 10 +++++++++- tests/readers/test_zarr_reader.py | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index c676682a..277ebe40 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -42,6 +42,7 @@ def __init__( self.store_template = store_template self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} + self._default_arrays: dict[str, str] | None = None def write_shard_block(self, shard_id: str, block: Block) -> int: arrays_by_path: dict[str, list[np.ndarray]] = {} @@ -69,7 +70,7 @@ def _collect_row( row: Row, arrays_by_path: dict[str, list[np.ndarray]], ) -> int | None: - arrays = self.arrays or _default_robotics_arrays(row) + arrays = self._arrays_for_row(row) lengths: list[int] = [] for zarr_path, source_key in arrays.items(): value = _row_value(row, source_key) @@ -87,6 +88,13 @@ def _collect_row( raise ValueError("Zarr arrays for one row must have matching lengths") return length + def _arrays_for_row(self, row: Row) -> dict[str, str]: + if self.arrays is not None: + return self.arrays + if self._default_arrays is None: + self._default_arrays = _default_robotics_arrays(row) + return self._default_arrays + def _store(self, shard_id: str) -> _ShardStore: relpath = self.store_template.format( shard_id=shard_id, diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index e84f030d..93607f9e 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -14,6 +14,7 @@ from refiner.robotics.row import RoboticsRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor +from refiner.pipeline.sinks.zarr import ZarrSink def _open_test_zarr(path: Path, *, mode: Literal["r", "r+", "a", "w", "w-"]): @@ -733,3 +734,24 @@ def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: np.sort(row["state"].reshape(-1)), np.asarray([10.0, 10.1, 20.0, 20.1, 20.2]), ) + + +def test_write_zarr_rejects_rows_missing_inferred_default_arrays( + tmp_path: Path, +) -> None: + output = tmp_path / "missing-default-array.zarr" + rows = list( + mdr.from_items( + [ + {"action": [[0.0]], "observation.state": [[1.0]]}, + {"action": [[0.1]]}, + ] + ).to_robot_rows( + action_key="action", + state_key="observation.state", + timestamp_key=None, + ) + ) + + with pytest.raises(ValueError, match="observation.state"): + ZarrSink(str(output)).write_block(rows) From 0457fc1155ece46eb57e0b6ceca2acca686fb55f Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:26:50 +0200 Subject: [PATCH 08/39] Validate Zarr writer schemas --- src/refiner/pipeline/sinks/zarr.py | 21 ++++++++++- tests/readers/test_zarr_reader.py | 59 +++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 277ebe40..208a9962 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -39,6 +39,8 @@ def __init__( self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None self.episode_ends_path = episode_ends_path + if self.arrays is not None: + _validate_array_paths(self.arrays, episode_ends_path) self.store_template = store_template self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} @@ -91,8 +93,15 @@ def _collect_row( def _arrays_for_row(self, row: Row) -> dict[str, str]: if self.arrays is not None: return self.arrays + default_arrays = _default_robotics_arrays(row) if self._default_arrays is None: - self._default_arrays = _default_robotics_arrays(row) + self._default_arrays = default_arrays + _validate_array_paths(self._default_arrays, self.episode_ends_path) + elif default_arrays != self._default_arrays: + raise ValueError( + "Zarr default arrays changed across rows; pass arrays=... " + "to write an explicit stable schema" + ) return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: @@ -177,6 +186,16 @@ def _default_robotics_arrays(row: Row) -> dict[str, str]: return arrays +def _validate_array_paths( + arrays: Mapping[str, str], + episode_ends_path: str | None, +) -> None: + if episode_ends_path is not None and episode_ends_path in arrays: + raise ValueError( + f"Zarr array path collides with episode_ends_path: {episode_ends_path}" + ) + + def _row_value(row: Row, key: str) -> Any: if isinstance(row, RoboticsRow): if key == "action": diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 93607f9e..03c1f060 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -715,23 +715,27 @@ def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: ) ) - zarr_store = next(zarr_out.glob("*.zarr")) - row = mdr.read_zarr( - zarr_store, - arrays={ - "action": "data/action", - "state": "data/state", - "episode_ends": "meta/episode_ends", - }, - file_path_column=None, - ).take(1)[0] + rows = [ + mdr.read_zarr( + zarr_store, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + for zarr_store in sorted(zarr_out.glob("*.zarr")) + ] - episode_ends = row["episode_ends"].tolist() + episode_ends = np.concatenate([row["episode_ends"] for row in rows]).tolist() assert episode_ends[-1] == 5 assert sorted(np.diff([0, *episode_ends]).tolist()) == [2, 3] - assert row["action"].shape == (5, 1) + action = np.concatenate([row["action"] for row in rows]) + state = np.concatenate([row["state"] for row in rows]) + assert action.shape == (5, 1) np.testing.assert_allclose( - np.sort(row["state"].reshape(-1)), + np.sort(state.reshape(-1)), np.asarray([10.0, 10.1, 20.0, 20.1, 20.2]), ) @@ -753,5 +757,32 @@ def test_write_zarr_rejects_rows_missing_inferred_default_arrays( ) ) - with pytest.raises(ValueError, match="observation.state"): + with pytest.raises(ValueError, match="default arrays changed"): ZarrSink(str(output)).write_block(rows) + + +def test_write_zarr_rejects_late_default_arrays(tmp_path: Path) -> None: + output = tmp_path / "late-default-array.zarr" + rows = list( + mdr.from_items( + [ + {"action": [[0.0]]}, + {"action": [[0.1]], "observation.state": [[1.1]]}, + ] + ).to_robot_rows( + action_key="action", + state_key="observation.state", + timestamp_key=None, + ) + ) + + with pytest.raises(ValueError, match="default arrays changed"): + ZarrSink(str(output)).write_block(rows) + + +def test_write_zarr_rejects_episode_ends_path_collision(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="collides with episode_ends_path"): + ZarrSink( + str(tmp_path / "collision.zarr"), + arrays={"meta/episode_ends": "action"}, + ) From 7836a3fd2ca3baf028fc34c2a289af659b08bfba Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:56:00 +0200 Subject: [PATCH 09/39] Stream Zarr writer appends --- src/refiner/pipeline/sinks/reducer/file.py | 36 +++++++++-- src/refiner/pipeline/sinks/zarr.py | 71 ++++++++++++++-------- tests/readers/test_zarr_reader.py | 63 +++++++++++++++++++ 3 files changed, 142 insertions(+), 28 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index ccf7c6c8..7d26f45e 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -57,6 +57,19 @@ def _compile_managed_path_pattern(filename_template: str) -> re.Pattern[str]: return re.compile("^" + "".join(parts) + "$") +def _managed_listing_prefix(filename_template: str) -> str: + literal_prefix = "" + for literal_text, field_name, _format_spec, _conversion in Formatter().parse( + filename_template + ): + literal_prefix += literal_text + if field_name is not None: + break + if "/" not in literal_prefix: + return "" + return literal_prefix.rsplit("/", maxsplit=1)[0] + + class FileCleanupReducerSink(BaseSink): """Delete non-finalized deterministic file-sink outputs.""" @@ -75,6 +88,7 @@ def __init__( self.assets_subdir = assets_subdir self.recursive = recursive self._managed_path_pattern = _compile_managed_path_pattern(filename_template) + self._managed_listing_prefix = _managed_listing_prefix(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -119,10 +133,7 @@ def _run_cleanup(self) -> None: for row in get_finalized_workers(stage_index=stage_index - 1) } - try: - listed_paths = self.output.find("") - except FileNotFoundError: - listed_paths = [] + listed_paths = list(self._listed_cleanup_paths()) assets_prefix = ( f"{self.assets_subdir.rstrip('/')}/" @@ -178,5 +189,22 @@ def _run_cleanup(self) -> None: except FileNotFoundError: continue + def _listed_cleanup_paths(self) -> list[str]: + if not self.recursive or self.assets_subdir is not None: + try: + return self.output.find("") + except FileNotFoundError: + return [] + + try: + paths = self.output.ls(self._managed_listing_prefix, detail=False) + except FileNotFoundError: + return [] + return [ + path + for path in paths + if isinstance(path, str) and not path.rstrip("/").endswith("/.") + ] + __all__ = ["FileCleanupReducerSink"] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 208a9962..c1d4d856 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -15,6 +15,7 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies +from refiner.video import VideoFrameArray from refiner.worker.context import get_active_worker_token @@ -47,32 +48,27 @@ def __init__( self._default_arrays: dict[str, str] | None = None def write_shard_block(self, shard_id: str, block: Block) -> int: - arrays_by_path: dict[str, list[np.ndarray]] = {} - episode_lengths: list[int] = [] count = 0 for row in block: - length = self._collect_row(row, arrays_by_path) - if length is not None: - episode_lengths.append(length) + row_arrays, length = self._row_arrays(row) + if row_arrays: + store = self._store(shard_id) + self._validate_store_append(store, row_arrays) + for zarr_path, array in row_arrays.items(): + self._append_array(store, zarr_path, array) + if self.episode_ends_path is not None and length is not None: + store.row_end += length + self._append_array( + store, + self.episode_ends_path, + np.asarray([store.row_end], dtype=np.int64), + ) count += 1 - if arrays_by_path: - store = self._store(shard_id) - for zarr_path, arrays in arrays_by_path.items(): - self._append_array(store, zarr_path, np.concatenate(arrays, axis=0)) - if self.episode_ends_path is not None and episode_lengths: - episode_ends = store.row_end + np.cumsum( - np.asarray(episode_lengths, dtype=np.int64) - ) - store.row_end = int(episode_ends[-1]) - self._append_array(store, self.episode_ends_path, episode_ends) return count - def _collect_row( - self, - row: Row, - arrays_by_path: dict[str, list[np.ndarray]], - ) -> int | None: + def _row_arrays(self, row: Row) -> tuple[dict[str, np.ndarray], int | None]: arrays = self._arrays_for_row(row) + row_arrays: dict[str, np.ndarray] = {} lengths: list[int] = [] for zarr_path, source_key in arrays.items(): value = _row_value(row, source_key) @@ -82,13 +78,13 @@ def _collect_row( if array.ndim == 0: array = array.reshape(1) lengths.append(int(array.shape[0])) - arrays_by_path.setdefault(zarr_path, []).append(array) + row_arrays[zarr_path] = array if not lengths: - return None + return {}, None length = lengths[0] if any(item != length for item in lengths): raise ValueError("Zarr arrays for one row must have matching lengths") - return length + return row_arrays, length def _arrays_for_row(self, row: Row) -> dict[str, str]: if self.arrays is not None: @@ -148,6 +144,22 @@ def _append_array( store.arrays[path] = dataset dataset.append(array, axis=0) + def _validate_store_append( + self, + store: _ShardStore, + row_arrays: Mapping[str, np.ndarray], + ) -> None: + for path, array in row_arrays.items(): + dataset = store.arrays.get(path) + if dataset is None: + continue + if tuple(dataset.shape[1:]) != tuple(array.shape[1:]): + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + if dataset.dtype != array.dtype: + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + def close(self) -> None: self._stores.clear() @@ -205,7 +217,18 @@ def _row_value(row: Row, key: str) -> Any: if key == "timestamp": return row.timestamps if key.startswith("observation."): - return row.observations(key) + try: + return row.observations(key) + except KeyError: + video = row.videos.get(key) + if isinstance(video, VideoFrameArray): + return np.asarray(list(video.iter_frame_arrays())) + if video is not None: + raise ValueError( + "write_zarr can only materialize video observations backed " + f"by frame arrays, got {type(video).__name__} for {key!r}" + ) + raise return row[key] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 03c1f060..daaf66f6 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -12,6 +12,7 @@ import refiner as mdr from refiner.io.datafolder import DataFolder from refiner.robotics.row import RoboticsRow +from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor from refiner.pipeline.sinks.zarr import ZarrSink @@ -786,3 +787,65 @@ def test_write_zarr_rejects_episode_ends_path_collision(tmp_path: Path) -> None: str(tmp_path / "collision.zarr"), arrays={"meta/episode_ends": "action"}, ) + + +def test_write_zarr_rejects_shape_drift_before_appending_bad_row( + tmp_path: Path, +) -> None: + output = tmp_path / "shape-mismatch.zarr" + rows: list[Row] = [ + DictRow({"action": [[0.0]], "state": [[1.0, 2.0]]}, shard_id="shard"), + DictRow({"action": [[0.1]], "state": [[1.1]]}, shard_id="shard"), + ] + + with pytest.raises(ValueError, match="matching trailing shapes"): + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/state": "state", + }, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "state": "data/state"}, + file_path_column=None, + ).take(1)[0] + assert row["action"].shape == (1, 1) + assert row["state"].shape == (1, 2) + + +def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: + output = tmp_path / "video.zarr" + frames = np.arange(2 * 4 * 4 * 3, dtype=np.uint8).reshape(2, 4, 4, 3) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "rgb": "data/rgb"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_array_equal(row["rgb"], frames) + np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) From c03209cfd145be7a6833a9ee10c247bb4ae96ac4 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:59:36 +0200 Subject: [PATCH 10/39] Document Zarr writer --- docs/reading-and-writing.md | 37 +++++++++++++++++++++++++++++++++++-- src/refiner/io/zarr.py | 25 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 src/refiner/io/zarr.py diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 22ff835e..5aeaa3e8 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -441,6 +441,7 @@ Built-in sinks: | `.write_jsonl(output, ...)` | JSON Lines files | one output file per worker/shard according to the filename template | | `.write_parquet(output, ...)` | Parquet files | columnar output with optional compression | | `.write_lerobot(output, ...)` | LeRobot-compatible robotics datasets | materializes frame/video assets and dataset metadata | +| `.write_zarr(output, ...)` | Zarr stores | one store per shard/worker according to the store template | Example: @@ -488,11 +489,43 @@ pipeline = pipeline.map( pipeline = pipeline.cast(video=mdr.datatype.video_path()) ``` +Use `write_zarr(...)` when you want chunked array output, usually for robotics +episode rows or replay-buffer style data: + +```python +import refiner as mdr + +( + mdr.read_lerobot("hf://datasets/user/robot-data") + .write_zarr( + "s3://my-bucket/robot-data-zarr/", + arrays={ + "data/action": "action", + "data/state": "observation.state", + }, + ) +) +``` + +The `arrays` mapping is from output Zarr path to source row key. For +`RoboticsRow` inputs, omitting `arrays` writes the available default robotics +arrays: actions, states, and timestamps. The default schema is inferred once and +later rows must expose the same fields. + +By default, `write_zarr(...)` also writes cumulative episode boundaries to +`meta/episode_ends`. Set `episode_ends_path=None` to omit them. + +Launched runs write isolated stores per shard/worker using +`store_template="{shard_id}__w{worker_id}.zarr"`. This avoids concurrent workers +mutating the same Zarr group. Read the resulting stores individually or merge +them in a later workflow if you need a single physical store. + When you run a writer through `launch_local(...)` or `launch_cloud(...)`, some sinks add a reducer stage after the main writer stage. For `write_jsonl(...)` and `write_parquet(...)`, that reducer removes stale shard/worker files and -uploaded asset attempt folders, keeping only finalized outputs. The output -prefix should therefore be dedicated to Refiner-managed files. +uploaded asset attempt folders, keeping only finalized outputs. `write_zarr(...)` +also removes stale shard/worker store directories. The output prefix should +therefore be dedicated to Refiner-managed files. ## What Python Functions Actually See diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py new file mode 100644 index 00000000..bff3958d --- /dev/null +++ b/src/refiner/io/zarr.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Literal + +from refiner.io.datafolder import DataFolder + + +def zarr_store( + folder: DataFolder, + path: str = "", + *, + mode: Literal["r", "w", "w-", "a"] = "r", +): + import zarr + + create = mode in {"w", "w-", "a"} + return zarr.storage.FSStore( + folder._join(path), + fs=folder.fs, + mode=mode, + create=create, + ) + + +__all__ = ["zarr_store"] From e5226fc9d934401160638429b372ec2a4d7250b2 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 01:32:41 +0200 Subject: [PATCH 11/39] Stream video arrays in Zarr writer --- docs/reading-and-writing.md | 9 +- src/refiner/pipeline/pipeline.py | 4 + src/refiner/pipeline/sinks/zarr.py | 287 +++++++++++++++++++++++----- src/refiner/video/decode.py | 2 +- src/refiner/video/types.py | 21 +- src/refiner/video/writer.py | 2 +- src/refiner/worker/lifecycle.py | 29 ++- tests/readers/test_zarr_reader.py | 105 ++++++++++ tests/robotics/test_robotics_row.py | 9 +- tests/test_video_decode.py | 8 +- 10 files changed, 418 insertions(+), 58 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 5aeaa3e8..1a9f7b9b 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -510,15 +510,18 @@ import refiner as mdr The `arrays` mapping is from output Zarr path to source row key. For `RoboticsRow` inputs, omitting `arrays` writes the available default robotics arrays: actions, states, and timestamps. The default schema is inferred once and -later rows must expose the same fields. +later rows must expose the same fields. Video sources selected through `arrays` +are decoded as RGB frame arrays and appended in bounded batches controlled by +`video_frame_batch_size`. By default, `write_zarr(...)` also writes cumulative episode boundaries to `meta/episode_ends`. Set `episode_ends_path=None` to omit them. Launched runs write isolated stores per shard/worker using `store_template="{shard_id}__w{worker_id}.zarr"`. This avoids concurrent workers -mutating the same Zarr group. Read the resulting stores individually or merge -them in a later workflow if you need a single physical store. +mutating the same Zarr group. Read the resulting stores individually, or set +`reduce_to_single_store=True` to add a reducer stage that streams the shard-local +stores into one final Zarr group at the requested output path. When you run a writer through `launch_local(...)` or `launch_cloud(...)`, some sinks add a reducer stage after the main writer stage. For `write_jsonl(...)` diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 43ac63b6..3fde27f1 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -438,6 +438,8 @@ def write_zarr( arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", + video_frame_batch_size: int = 64, + reduce_to_single_store: bool = False, overwrite: bool = True, ) -> "RefinerPipeline": return self.with_sink( @@ -446,6 +448,8 @@ def write_zarr( arrays=arrays, episode_ends_path=episode_ends_path, store_template=store_template, + video_frame_batch_size=video_frame_batch_size, + reduce_to_single_store=reduce_to_single_store, overwrite=overwrite, ) ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index c1d4d856..8ff0f872 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -7,6 +7,7 @@ import numpy as np import pyarrow as pa +from refiner.execution.asyncio.runtime import submit from refiner.io.datafolder import DataFolder, DataFolderLike from refiner.io.zarr import zarr_store from refiner.pipeline.data.block import Block @@ -15,8 +16,12 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies -from refiner.video import VideoFrameArray -from refiner.worker.context import get_active_worker_token +from refiner.video import VideoSource +from refiner.worker.context import ( + get_active_stage_index, + get_active_worker_token, + get_finalized_workers, +) @dataclass @@ -34,15 +39,21 @@ def __init__( arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", + video_frame_batch_size: int = 64, + reduce_to_single_store: bool = False, overwrite: bool = True, ): check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + if video_frame_batch_size <= 0: + raise ValueError("video_frame_batch_size must be greater than zero") self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None self.episode_ends_path = episode_ends_path if self.arrays is not None: _validate_array_paths(self.arrays, episode_ends_path) self.store_template = store_template + self.video_frame_batch_size = video_frame_batch_size + self.reduce_to_single_store = reduce_to_single_store self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} self._default_arrays: dict[str, str] | None = None @@ -50,41 +61,68 @@ def __init__( def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 for row in block: - row_arrays, length = self._row_arrays(row) - if row_arrays: - store = self._store(shard_id) - self._validate_store_append(store, row_arrays) - for zarr_path, array in row_arrays.items(): - self._append_array(store, zarr_path, array) - if self.episode_ends_path is not None and length is not None: - store.row_end += length - self._append_array( - store, - self.episode_ends_path, - np.asarray([store.row_end], dtype=np.int64), - ) + submit(self._write_row(shard_id, row)).result() count += 1 return count - def _row_arrays(self, row: Row) -> tuple[dict[str, np.ndarray], int | None]: + async def _write_row(self, shard_id: str, row: Row) -> None: arrays = self._arrays_for_row(row) row_arrays: dict[str, np.ndarray] = {} + row_videos: list[tuple[str, VideoSource]] = [] lengths: list[int] = [] + store: _ShardStore | None = None for zarr_path, source_key in arrays.items(): value = _row_value(row, source_key) if value is None: raise ValueError(f"Zarr source value is missing: {source_key}") + if store is None: + store = self._store(shard_id) + if isinstance(value, VideoSource): + row_videos.append((zarr_path, value)) + continue array = _as_array(value) if array.ndim == 0: array = array.reshape(1) lengths.append(int(array.shape[0])) row_arrays[zarr_path] = array + if store is not None: + for zarr_path, array in row_arrays.items(): + self._validate_array_append(store, zarr_path, array) + for zarr_path, array in row_arrays.items(): + self._append_array(store, zarr_path, array) + for zarr_path, video in row_videos: + lengths.append(await self._append_video(store, zarr_path, video)) if not lengths: - return {}, None + return length = lengths[0] if any(item != length for item in lengths): raise ValueError("Zarr arrays for one row must have matching lengths") - return row_arrays, length + if store is not None and self.episode_ends_path is not None: + store.row_end += length + self._append_array( + store, + self.episode_ends_path, + np.asarray([store.row_end], dtype=np.int64), + ) + + async def _append_video( + self, + store: _ShardStore, + path: str, + video: VideoSource, + ) -> int: + batch: list[np.ndarray] = [] + count = 0 + async for frame in video.iter_frame_arrays(): + batch.append(np.asarray(frame)) + if len(batch) >= self.video_frame_batch_size: + self._append_array(store, path, np.stack(batch, axis=0)) + count += len(batch) + batch.clear() + if batch: + self._append_array(store, path, np.stack(batch, axis=0)) + count += len(batch) + return count def _arrays_for_row(self, row: Row) -> dict[str, str]: if self.arrays is not None: @@ -101,10 +139,7 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: - relpath = self.store_template.format( - shard_id=shard_id, - worker_id=get_active_worker_token(), - ) + relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) if store is not None: return store @@ -119,12 +154,15 @@ def _store(self, shard_id: str) -> _ShardStore: self._stores[relpath] = store return store - def on_shard_complete(self, shard_id: str) -> None: + def _store_relpath(self, shard_id: str) -> str: relpath = self.store_template.format( shard_id=shard_id, worker_id=get_active_worker_token(), ) - self._stores.pop(relpath, None) + return f"_parts/{relpath}" if self.reduce_to_single_store else relpath + + def on_shard_complete(self, shard_id: str) -> None: + self._stores.pop(self._store_relpath(shard_id), None) def _append_array( self, @@ -142,23 +180,25 @@ def _append_array( dtype=array.dtype, ) store.arrays[path] = dataset + else: + self._validate_array_append(store, path, array) dataset.append(array, axis=0) - def _validate_store_append( + def _validate_array_append( self, store: _ShardStore, - row_arrays: Mapping[str, np.ndarray], + path: str, + array: np.ndarray, ) -> None: - for path, array in row_arrays.items(): - dataset = store.arrays.get(path) - if dataset is None: - continue - if tuple(dataset.shape[1:]) != tuple(array.shape[1:]): - raise ValueError( - f"Zarr arrays for {path!r} must have matching trailing shapes" - ) - if dataset.dtype != array.dtype: - raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + dataset = store.arrays.get(path) + if dataset is None: + return + if tuple(dataset.shape[1:]) != tuple(array.shape[1:]): + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + if dataset.dtype != array.dtype: + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") def close(self) -> None: self._stores.clear() @@ -172,11 +212,20 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "arrays": dict(self.arrays) if self.arrays is not None else None, "episode_ends_path": self.episode_ends_path, "store_template": self.store_template, + "video_frame_batch_size": self.video_frame_batch_size, + "reduce_to_single_store": self.reduce_to_single_store, "overwrite": self.overwrite, }, ) def build_reducer(self) -> BaseSink | None: + if self.reduce_to_single_store: + return ZarrMergeReducerSink( + output=self.output, + store_template=self.store_template, + episode_ends_path=self.episode_ends_path, + overwrite=self.overwrite, + ) return FileCleanupReducerSink( output=self.output, filename_template=self.store_template, @@ -185,6 +234,118 @@ def build_reducer(self) -> BaseSink | None: ) +class ZarrMergeReducerSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + store_template: str, + episode_ends_path: str | None, + overwrite: bool, + ) -> None: + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + self.output = DataFolder.resolve(output) + self.store_template = store_template + self.episode_ends_path = episode_ends_path + self.overwrite = overwrite + self._merged = False + + @property + def counts_output_rows(self) -> bool: + return False + + def write_shard_block(self, shard_id, block) -> None: + del shard_id, block + self._merge() + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr_reduce", + "writer", + { + "path": self.output.abs_path(), + "store_template": self.store_template, + "reduce_to_single_store": True, + }, + ) + + def _merge(self) -> None: + if self._merged: + return + self._merged = True + + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + + import zarr + + final = zarr.open_group( + store=zarr_store(self.output, "", mode="a"), + mode="a", + ) + if self.overwrite: + _clear_final_group(final) + + stores = sorted( + get_finalized_workers(stage_index=stage_index - 1), + key=lambda row: ( + row.global_ordinal is None, + row.global_ordinal if row.global_ordinal is not None else row.shard_id, + ), + ) + row_offset = 0 + arrays: dict[str, Any] = {} + for row in stores: + relpath = self._part_relpath(row.shard_id, row.worker_token) + source = zarr.open_group( + store=zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + for path in _iter_zarr_arrays(source): + source_array = source[path] + if path == self.episode_ends_path: + if source_array.shape[0] == 0: + continue + values = np.asarray(source_array[:], dtype=np.int64) + _append_reduced_array( + final, + arrays, + path, + values + row_offset, + source_array, + ) + row_offset += int(values[-1]) + continue + for start in range( + 0, int(source_array.shape[0]), _merge_batch_size(source_array) + ): + end = min( + int(source_array.shape[0]), + start + _merge_batch_size(source_array), + ) + _append_reduced_array( + final, + arrays, + path, + np.asarray(source_array[start:end]), + source_array, + ) + + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + + def _part_relpath(self, shard_id: str, worker_token: str) -> str: + return "_parts/" + self.store_template.format( + shard_id=shard_id, + worker_id=worker_token, + ) + + def _default_robotics_arrays(row: Row) -> dict[str, str]: if not isinstance(row, RoboticsRow): raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") @@ -221,14 +382,9 @@ def _row_value(row: Row, key: str) -> Any: return row.observations(key) except KeyError: video = row.videos.get(key) - if isinstance(video, VideoFrameArray): - return np.asarray(list(video.iter_frame_arrays())) - if video is not None: - raise ValueError( - "write_zarr can only materialize video observations backed " - f"by frame arrays, got {type(video).__name__} for {key!r}" - ) - raise + if video is None: + raise + return video return row[key] @@ -244,4 +400,45 @@ def _as_array(value: Any) -> np.ndarray: return np.asarray(value) -__all__ = ["ZarrSink"] +def _clear_final_group(group: Any) -> None: + for key in sorted({*group.array_keys(), *group.group_keys()}): + if key != "_parts": + del group[key] + + +def _iter_zarr_arrays(group: Any, prefix: str = "") -> Iterable[str]: + for key in sorted(group.array_keys()): + yield f"{prefix}/{key}" if prefix else key + for key in sorted(group.group_keys()): + child_prefix = f"{prefix}/{key}" if prefix else key + yield from _iter_zarr_arrays(group[key], child_prefix) + + +def _merge_batch_size(array: Any) -> int: + chunks = getattr(array, "chunks", None) + if isinstance(chunks, tuple) and chunks and isinstance(chunks[0], int): + return max(1, int(chunks[0])) + return max(1, min(int(array.shape[0]), 1024)) + + +def _append_reduced_array( + root: Any, + arrays: dict[str, Any], + path: str, + values: np.ndarray, + source_array: Any, +) -> None: + dataset = arrays.get(path) + if dataset is None: + dataset = root.create_dataset( + path, + shape=(0, *values.shape[1:]), + chunks=getattr(source_array, "chunks", None), + dtype=source_array.dtype, + compressor=getattr(source_array, "compressor", None), + ) + arrays[path] = dataset + dataset.append(values, axis=0) + + +__all__ = ["ZarrMergeReducerSink", "ZarrSink"] diff --git a/src/refiner/video/decode.py b/src/refiner/video/decode.py index d07b28ae..4f3ac953 100644 --- a/src/refiner/video/decode.py +++ b/src/refiner/video/decode.py @@ -63,7 +63,7 @@ async def export_clip( fps=video.fps, movflags=None, ) - writer.append_frame_arrays(video.iter_frame_arrays()) + writer.append_frame_arrays(video.frame_arrays) writer.close() return output_file.getvalue() encoded_video = cast("VideoFile | VideoBytes", video) diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index a5d67368..344da07e 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -2,7 +2,7 @@ import io import math -from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from dataclasses import dataclass, field from fractions import Fraction from typing import IO, TYPE_CHECKING, Any, Protocol, runtime_checkable @@ -28,6 +28,8 @@ def clipped( def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: ... + def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: ... + def iter_frame_windows( self, *, @@ -108,6 +110,10 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) + async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: + async for frame in self.iter_frames(): + yield frame.frame.to_ndarray(format="rgb24") + def iter_frame_windows( self, *, @@ -172,6 +178,10 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) + async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: + async for frame in self.iter_frames(): + yield frame.frame.to_ndarray(format="rgb24") + def iter_frame_windows( self, *, @@ -233,8 +243,13 @@ def frame_count(self) -> int: def duration_s(self) -> float: return self.frame_count / float(self.fps) - def iter_frame_arrays(self) -> Iterator[np.ndarray]: - yield from self._array + @property + def frame_arrays(self) -> np.ndarray: + return self._array + + async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: + for frame in self._array: + yield frame def clipped( self, diff --git a/src/refiner/video/writer.py b/src/refiner/video/writer.py index 888cd231..3228788b 100644 --- a/src/refiner/video/writer.py +++ b/src/refiner/video/writer.py @@ -141,7 +141,7 @@ def _commit_frame_arrays_sync( writer = self._ensure_transcode_writer(video.fps) file_index = self._next_file_index from_timestamp, to_timestamp = writer.append_frame_arrays( - video.iter_frame_arrays(), + video.frame_arrays, frame_observer=frame_observer, ) if writer.stream is None: diff --git a/src/refiner/worker/lifecycle.py b/src/refiner/worker/lifecycle.py index 89f09791..c77ddb26 100644 --- a/src/refiner/worker/lifecycle.py +++ b/src/refiner/worker/lifecycle.py @@ -13,6 +13,7 @@ class FinalizedShardWorker(msgspec.Struct, frozen=True): shard_id: str worker_id: str + global_ordinal: int | None = None @property def worker_token(self) -> str: @@ -53,11 +54,25 @@ def read_finalized_workers( except Exception: continue shard_id = payload.get("shard_id") if isinstance(payload, dict) else None + global_ordinal = ( + payload.get("global_ordinal") if isinstance(payload, dict) else None + ) if isinstance(shard_id, str): rows.append( - FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id) + FinalizedShardWorker( + shard_id=shard_id, + worker_id=worker_id, + global_ordinal=( + global_ordinal if isinstance(global_ordinal, int) else None + ), + ) ) - rows.sort(key=lambda row: row.shard_id) + rows.sort( + key=lambda row: ( + row.global_ordinal is None, + row.global_ordinal if row.global_ordinal is not None else row.shard_id, + ) + ) return rows @@ -90,7 +105,15 @@ def complete(self, shard: Shard) -> None: ) path.parent.mkdir(parents=True, exist_ok=True) with path.open("a", encoding="utf-8") as handle: - handle.write(json.dumps({"shard_id": shard.id}, sort_keys=True)) + handle.write( + json.dumps( + { + "global_ordinal": shard.global_ordinal, + "shard_id": shard.id, + }, + sort_keys=True, + ) + ) handle.write("\n") return None diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index daaf66f6..d3f0b8ac 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -11,6 +11,7 @@ import refiner as mdr from refiner.io.datafolder import DataFolder +from refiner.io import DataFile from refiner.robotics.row import RoboticsRow from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row @@ -55,6 +56,27 @@ def _write_policy_zarr(path: Path) -> None: root.attrs["task"] = "push tee" +def _write_video(path: Path, *, num_frames: int = 3, fps: int = 5) -> None: + import av + + with av.open(str(path), mode="w") as container: + stream = container.add_stream("mpeg4", rate=fps) + stream.width = 4 + stream.height = 4 + stream.pix_fmt = "yuv420p" + + for value in range(num_frames): + frame = av.VideoFrame.from_ndarray( + np.full((4, 4, 3), value, dtype=np.uint8), + format="rgb24", + ) + for packet in stream.encode(frame): + container.mux(packet) + + for packet in stream.encode(None): + container.mux(packet) + + def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -741,6 +763,46 @@ def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: ) +def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: + zarr_out = tmp_path / "single.zarr" + + ( + mdr.from_items( + [ + {"action": [[0.0], [0.1]], "state": [[1.0], [1.1]]}, + {"action": [[0.2]], "state": [[1.2]]}, + ], + items_per_shard=1, + ) + .write_zarr( + str(zarr_out), + arrays={ + "data/action": "action", + "data/state": "state", + }, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-store", num_workers=1, rundir=str(tmp_path / "run") + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + + np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) + np.testing.assert_allclose(row["state"], [[1.0], [1.1], [1.2]]) + assert row["episode_ends"].tolist() == [2, 3] + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_rejects_rows_missing_inferred_default_arrays( tmp_path: Path, ) -> None: @@ -849,3 +911,46 @@ def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: ).take(1)[0] np.testing.assert_array_equal(row["rgb"], frames) np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) + + +def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: + source = tmp_path / "source.mp4" + output = tmp_path / "encoded-video.zarr" + _write_video(source, num_frames=3, fps=5) + + rows = list( + mdr.from_items( + [ + { + "episode_id": "episode-1", + "clip": mdr.video.VideoFile(DataFile.resolve(source)), + "action": [[0.0], [0.1], [0.2]], + } + ] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "clip"}, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + video_frame_batch_size=2, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "rgb": "data/rgb"}, + file_path_column=None, + ).take(1)[0] + assert row["rgb"].shape == (3, 4, 4, 3) + assert row["rgb"].dtype == np.uint8 + np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index b4e026e8..78b85654 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any, cast import numpy as np @@ -120,12 +121,18 @@ def test_to_robot_rows_uses_video_frame_array_asset_schema() -> None: video = robotics_row.videos["camera"] assert isinstance(video, VideoFrameArray) - video_frames = list(video.iter_frame_arrays()) + video_frames = asyncio.run( + _collect_frame_arrays(video), + ) assert len(video_frames) == 2 assert video_frames[0].shape == (4, 5, 3) assert video.fps == 12 +async def _collect_frame_arrays(video): + return [frame async for frame in video.iter_frame_arrays()] + + def test_to_robot_rows_accepts_unmapped_key_iterables() -> None: row = DictRow( { diff --git a/tests/test_video_decode.py b/tests/test_video_decode.py index e8765888..4e828feb 100644 --- a/tests/test_video_decode.py +++ b/tests/test_video_decode.py @@ -35,6 +35,10 @@ async def _collect_frames(video: mdr.video.VideoSource): return [frame async for frame in video.iter_frames()] +async def _collect_frame_arrays(video: mdr.video.VideoSource): + return [frame async for frame in video.iter_frame_arrays()] + + async def _collect_windows( video: mdr.video.VideoSource, *, @@ -74,7 +78,7 @@ def test_video_frame_array_clip_returns_frame_view() -> None: clipped = video.clipped(from_timestamp_s=0.2, to_timestamp_s=0.5) assert isinstance(clipped, mdr.video.VideoFrameArray) - clipped_frames = list(clipped.iter_frame_arrays()) + clipped_frames = asyncio.run(_collect_frame_arrays(clipped)) assert len(clipped_frames) == 3 assert clipped_frames[0].shape == (4, 4, 3) assert [int(frame[0, 0, 0]) for frame in clipped_frames] == [2, 3, 4] @@ -85,9 +89,11 @@ def test_video_frame_array_iter_frames() -> None: video = mdr.video.VideoFrameArray(frames, fps=5) decoded = asyncio.run(_collect_frames(video)) + arrays = asyncio.run(_collect_frame_arrays(video)) assert [frame.index for frame in decoded] == [0, 1, 2] assert [frame.timestamp_s for frame in decoded] == [0.0, 0.2, 0.4] + assert [int(frame[0, 0, 0]) for frame in arrays] == [0, 1, 2] def test_video_stream_writer_accepts_video_frame_array(tmp_path) -> None: From 75d8a52b708c3fbf77ebc2797356d72a90189645 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 01:43:46 +0200 Subject: [PATCH 12/39] Harden Zarr writer streaming --- src/refiner/io/zarr.py | 14 +- src/refiner/pipeline/pipeline.py | 2 +- src/refiner/pipeline/sinks/reducer/file.py | 21 ++- src/refiner/pipeline/sinks/zarr.py | 142 ++++++++++++++----- src/refiner/pipeline/sources/readers/zarr.py | 15 +- src/refiner/video/types.py | 21 ++- tests/pipeline/test_sinks.py | 40 ++++++ tests/readers/test_zarr_reader.py | 35 +++++ tests/robotics/test_robotics_row.py | 5 +- tests/test_video_decode.py | 3 +- 10 files changed, 238 insertions(+), 60 deletions(-) diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py index bff3958d..90312f42 100644 --- a/src/refiner/io/zarr.py +++ b/src/refiner/io/zarr.py @@ -1,6 +1,8 @@ from __future__ import annotations +from collections.abc import Iterator from typing import Literal +from typing import Any from refiner.io.datafolder import DataFolder @@ -22,4 +24,14 @@ def zarr_store( ) -__all__ = ["zarr_store"] +def iter_zarr_array_paths(group: Any, prefix: str = "") -> Iterator[str]: + items = group.items() if hasattr(group, "items") else group.members() + for name, item in items: + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from iter_zarr_array_paths(item, path) + + +__all__ = ["iter_zarr_array_paths", "zarr_store"] diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 3fde27f1..566b789c 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -438,7 +438,7 @@ def write_zarr( arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", - video_frame_batch_size: int = 64, + video_frame_batch_size: int = 8, reduce_to_single_store: bool = False, overwrite: bool = True, ) -> "RefinerPipeline": diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 7d26f45e..456cd912 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -70,6 +70,10 @@ def _managed_listing_prefix(filename_template: str) -> str: return literal_prefix.rsplit("/", maxsplit=1)[0] +def _path_depth(path: str) -> int: + return len([part for part in path.split("/") if part]) + + class FileCleanupReducerSink(BaseSink): """Delete non-finalized deterministic file-sink outputs.""" @@ -89,6 +93,7 @@ def __init__( self.recursive = recursive self._managed_path_pattern = _compile_managed_path_pattern(filename_template) self._managed_listing_prefix = _managed_listing_prefix(filename_template) + self._managed_path_depth = _path_depth(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -196,10 +201,18 @@ def _listed_cleanup_paths(self) -> list[str]: except FileNotFoundError: return [] - try: - paths = self.output.ls(self._managed_listing_prefix, detail=False) - except FileNotFoundError: - return [] + paths = [self._managed_listing_prefix] + depth = max( + 1, self._managed_path_depth - _path_depth(self._managed_listing_prefix) + ) + for _ in range(depth): + next_paths: list[str] = [] + for path in paths: + try: + next_paths.extend(self.output.ls(path, detail=False)) + except FileNotFoundError: + continue + paths = next_paths return [ path for path in paths diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 8ff0f872..ea477763 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -9,7 +9,7 @@ from refiner.execution.asyncio.runtime import submit from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.io.zarr import zarr_store +from refiner.io.zarr import iter_zarr_array_paths, zarr_store from refiner.pipeline.data.block import Block from refiner.pipeline.data.row import Row from refiner.pipeline.sinks.base import BaseSink @@ -23,12 +23,15 @@ get_finalized_workers, ) +_DEFAULT_ARRAY_CHUNK_LENGTH = 1024 + @dataclass class _ShardStore: root: Any arrays: dict[str, Any] = field(default_factory=dict) row_end: int = 0 + next_temp_index: int = 0 class ZarrSink(BaseSink): @@ -39,7 +42,7 @@ def __init__( arrays: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", - video_frame_batch_size: int = 64, + video_frame_batch_size: int = 8, reduce_to_single_store: bool = False, overwrite: bool = True, ): @@ -61,11 +64,11 @@ def __init__( def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 for row in block: - submit(self._write_row(shard_id, row)).result() + self._write_row(shard_id, row) count += 1 return count - async def _write_row(self, shard_id: str, row: Row) -> None: + def _write_row(self, shard_id: str, row: Row) -> None: arrays = self._arrays_for_row(row) row_arrays: dict[str, np.ndarray] = {} row_videos: list[tuple[str, VideoSource]] = [] @@ -85,43 +88,81 @@ async def _write_row(self, shard_id: str, row: Row) -> None: array = array.reshape(1) lengths.append(int(array.shape[0])) row_arrays[zarr_path] = array - if store is not None: - for zarr_path, array in row_arrays.items(): - self._validate_array_append(store, zarr_path, array) - for zarr_path, array in row_arrays.items(): - self._append_array(store, zarr_path, array) - for zarr_path, video in row_videos: - lengths.append(await self._append_video(store, zarr_path, video)) if not lengths: - return - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError("Zarr arrays for one row must have matching lengths") - if store is not None and self.episode_ends_path is not None: - store.row_end += length - self._append_array( - store, - self.episode_ends_path, - np.asarray([store.row_end], dtype=np.int64), - ) + expected_length = None + else: + expected_length = lengths[0] + if any(item != expected_length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + temp_videos: list[tuple[str, str, int]] = [] + if store is not None: + try: + for zarr_path, video in row_videos: + temp_path = self._temp_path(store, zarr_path) + temp_videos.append((zarr_path, temp_path, 0)) + video_length = submit( + self._append_video( + store, + temp_path, + video, + expected_length=expected_length, + ) + ).result() + lengths.append(video_length) + temp_videos[-1] = (zarr_path, temp_path, video_length) + if lengths: + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError( + "Zarr arrays for one row must have matching lengths" + ) + for zarr_path, array in row_arrays.items(): + self._validate_array_append(store, zarr_path, array) + for zarr_path, temp_path, _ in temp_videos: + self._validate_array_append( + store, zarr_path, store.arrays[temp_path] + ) + for zarr_path, array in row_arrays.items(): + self._append_array(store, zarr_path, array) + for zarr_path, temp_path, _ in temp_videos: + self._copy_temp_array(store, temp_path, zarr_path) + if lengths and self.episode_ends_path is not None: + store.row_end += lengths[0] + self._append_array( + store, + self.episode_ends_path, + np.asarray([store.row_end], dtype=np.int64), + ) + finally: + for _, temp_path, _ in temp_videos: + self._drop_array(store, temp_path) + return async def _append_video( self, store: _ShardStore, path: str, video: VideoSource, + *, + expected_length: int | None = None, ) -> int: batch: list[np.ndarray] = [] count = 0 - async for frame in video.iter_frame_arrays(): + async for frame in _iter_video_frame_arrays(video): batch.append(np.asarray(frame)) if len(batch) >= self.video_frame_batch_size: self._append_array(store, path, np.stack(batch, axis=0)) count += len(batch) batch.clear() + if expected_length is not None and count > expected_length: + raise ValueError( + "Zarr arrays for one row must have matching lengths" + ) if batch: self._append_array(store, path, np.stack(batch, axis=0)) count += len(batch) + if expected_length is not None and count != expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") return count def _arrays_for_row(self, row: Row) -> dict[str, str]: @@ -164,6 +205,11 @@ def _store_relpath(self, shard_id: str) -> str: def on_shard_complete(self, shard_id: str) -> None: self._stores.pop(self._store_relpath(shard_id), None) + def _temp_path(self, store: _ShardStore, path: str) -> str: + temp_path = f"__tmp/{store.next_temp_index}/{path}" + store.next_temp_index += 1 + return temp_path + def _append_array( self, store: _ShardStore, @@ -172,7 +218,7 @@ def _append_array( ) -> None: dataset = store.arrays.get(path) if dataset is None: - chunks = (max(1, min(int(array.shape[0]), 1024)), *array.shape[1:]) + chunks = (_DEFAULT_ARRAY_CHUNK_LENGTH, *array.shape[1:]) dataset = store.root.create_dataset( path, shape=(0, *array.shape[1:]), @@ -200,6 +246,29 @@ def _validate_array_append( if dataset.dtype != array.dtype: raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + def _copy_temp_array(self, store: _ShardStore, temp_path: str, path: str) -> None: + source = store.arrays[temp_path] + for start in range(0, int(source.shape[0]), _merge_batch_size(source)): + end = min(int(source.shape[0]), start + _merge_batch_size(source)) + self._append_array(store, path, np.asarray(source[start:end])) + + def _drop_array(self, store: _ShardStore, path: str) -> None: + store.arrays.pop(path, None) + if path.startswith("__tmp/"): + path = "/".join(path.split("/")[:2]) + for key in list(store.arrays): + if key == path or key.startswith(f"{path}/"): + store.arrays.pop(key, None) + try: + del store.root[path] + except (KeyError, FileNotFoundError): + pass + if path.startswith("__tmp/"): + try: + del store.root["__tmp"] + except (KeyError, FileNotFoundError): + pass + def close(self) -> None: self._stores.clear() @@ -304,7 +373,7 @@ def _merge(self) -> None: store=zarr_store(self.output, relpath, mode="r"), mode="r", ) - for path in _iter_zarr_arrays(source): + for path in iter_zarr_array_paths(source): source_array = source[path] if path == self.episode_ends_path: if source_array.shape[0] == 0: @@ -400,20 +469,27 @@ def _as_array(value: Any) -> np.ndarray: return np.asarray(value) +async def _iter_video_frame_arrays(video: VideoSource): + iter_frame_arrays = getattr(video, "iter_frame_arrays", None) + if callable(iter_frame_arrays): + frames = iter_frame_arrays() + if hasattr(frames, "__aiter__"): + async for frame in frames: + yield frame + return + for frame in frames: + yield frame + return + async for frame in video.iter_frames(): + yield frame.frame.to_ndarray(format="rgb24") + + def _clear_final_group(group: Any) -> None: for key in sorted({*group.array_keys(), *group.group_keys()}): if key != "_parts": del group[key] -def _iter_zarr_arrays(group: Any, prefix: str = "") -> Iterable[str]: - for key in sorted(group.array_keys()): - yield f"{prefix}/{key}" if prefix else key - for key in sorted(group.group_keys()): - child_prefix = f"{prefix}/{key}" if prefix else key - yield from _iter_zarr_arrays(group[key], child_prefix) - - def _merge_batch_size(array: Any) -> int: chunks = getattr(array, "chunks", None) if isinstance(chunks, tuple) and chunks and isinstance(chunks[0], int): diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 91787107..058014e0 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -13,6 +13,7 @@ from refiner.io.datafile import DataFile, DataFileLike from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.io.zarr import iter_zarr_array_paths from refiner.pipeline.data.datatype import ( DTypeMapping, dtype_to_plan, @@ -309,7 +310,9 @@ def _row_metadata(self, *, index: int | None) -> dict[str, Any]: def _selected_arrays(self, group: Any) -> dict[str, Any]: if self.arrays is None: paths = { - path: path for path in _iter_array_paths(group) if path != self.row_ends + path: path + for path in iter_zarr_array_paths(group) + if path != self.row_ends } _validate_output_names( paths, @@ -562,14 +565,4 @@ def _leading_item_bytes(array: Any) -> int: return max(1, int(array.dtype.itemsize) * int(prod(trailing_shape or (1,)))) -def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: - items = group.items() if hasattr(group, "items") else group.members() - for name, item in items: - path = f"{prefix}/{name}" if prefix else name - if hasattr(item, "shape"): - yield path - else: - yield from _iter_array_paths(item, path) - - __all__ = ["ZarrReader"] diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index 344da07e..8921ec4d 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -2,7 +2,7 @@ import io import math -from collections.abc import AsyncIterator, Mapping, Sequence +from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from dataclasses import dataclass, field from fractions import Fraction from typing import IO, TYPE_CHECKING, Any, Protocol, runtime_checkable @@ -28,8 +28,6 @@ def clipped( def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: ... - def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: ... - def iter_frame_windows( self, *, @@ -247,9 +245,8 @@ def duration_s(self) -> float: def frame_arrays(self) -> np.ndarray: return self._array - async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: - for frame in self._array: - yield frame + def iter_frame_arrays(self) -> "_FrameArrayView": + return _FrameArrayView(self._array) def clipped( self, @@ -344,6 +341,18 @@ def video_from_storage_value( return None +@dataclass(frozen=True, slots=True) +class _FrameArrayView: + frames: np.ndarray + + def __iter__(self) -> Iterator[np.ndarray]: + yield from self.frames + + async def __aiter__(self) -> AsyncIterator[np.ndarray]: + for frame in self.frames: + yield frame + + __all__ = [ "VideoBytes", "VideoFile", diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index b2e6d5a8..570f68fd 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -1007,6 +1007,46 @@ def test_file_cleanup_reducer_removes_non_finalized_nested_directories( assert not loser_dir.exists() +def test_file_cleanup_reducer_removes_dynamic_nested_directories(tmp_path) -> None: + output_dir = tmp_path / "zarr-cleanup-dynamic-nested" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_dir = ( + output_dir / "split" / shard_id / f"{worker_token_for(winner_worker_id)}.zarr" + ) + loser_dir = ( + output_dir / "split" / shard_id / f"{worker_token_for(loser_worker_id)}.zarr" + ) + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="split/{shard_id}/{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_dir.exists() + assert not loser_dir.exists() + + def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( tmp_path, monkeypatch ) -> None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index d3f0b8ac..61f4813b 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -954,3 +954,38 @@ def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: assert row["rgb"].shape == (3, 4, 4, 3) assert row["rgb"].dtype == np.uint8 np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) + + +def test_write_zarr_rejects_video_length_mismatch_before_final_append( + tmp_path: Path, +) -> None: + output = tmp_path / "video-length-mismatch.zarr" + frames = np.zeros((3, 4, 4, 3), dtype=np.uint8) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + with pytest.raises(ValueError, match="matching lengths"): + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + video_frame_batch_size=2, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + root = _open_test_zarr(zarr_store, mode="r") + assert "data/action" not in root + assert "data/rgb" not in root + assert "__tmp" not in root diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index 78b85654..11c5d6d0 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -121,9 +121,8 @@ def test_to_robot_rows_uses_video_frame_array_asset_schema() -> None: video = robotics_row.videos["camera"] assert isinstance(video, VideoFrameArray) - video_frames = asyncio.run( - _collect_frame_arrays(video), - ) + assert len(list(video.iter_frame_arrays())) == 2 + video_frames = asyncio.run(_collect_frame_arrays(video)) assert len(video_frames) == 2 assert video_frames[0].shape == (4, 5, 3) assert video.fps == 12 diff --git a/tests/test_video_decode.py b/tests/test_video_decode.py index 4e828feb..ea5da14a 100644 --- a/tests/test_video_decode.py +++ b/tests/test_video_decode.py @@ -35,7 +35,7 @@ async def _collect_frames(video: mdr.video.VideoSource): return [frame async for frame in video.iter_frames()] -async def _collect_frame_arrays(video: mdr.video.VideoSource): +async def _collect_frame_arrays(video): return [frame async for frame in video.iter_frame_arrays()] @@ -78,6 +78,7 @@ def test_video_frame_array_clip_returns_frame_view() -> None: clipped = video.clipped(from_timestamp_s=0.2, to_timestamp_s=0.5) assert isinstance(clipped, mdr.video.VideoFrameArray) + assert len(list(clipped.iter_frame_arrays())) == 3 clipped_frames = asyncio.run(_collect_frame_arrays(clipped)) assert len(clipped_frames) == 3 assert clipped_frames[0].shape == (4, 4, 3) From db2c4538097a71050accaa4437cd11d0cc773374 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 02:13:14 +0200 Subject: [PATCH 13/39] Optimize Zarr writer merge IO --- src/refiner/pipeline/pipeline.py | 2 + src/refiner/pipeline/sinks/zarr.py | 212 +++++++++++-------- src/refiner/pipeline/sources/readers/zarr.py | 4 +- src/refiner/video/types.py | 24 +-- tests/readers/test_zarr_reader.py | 102 +++++++++ tests/robotics/test_robotics_row.py | 8 +- tests/test_video_decode.py | 8 +- 7 files changed, 234 insertions(+), 126 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 566b789c..8909802c 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -439,6 +439,7 @@ def write_zarr( episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, + array_chunk_bytes: int = 8 * 1024 * 1024, reduce_to_single_store: bool = False, overwrite: bool = True, ) -> "RefinerPipeline": @@ -449,6 +450,7 @@ def write_zarr( episode_ends_path=episode_ends_path, store_template=store_template, video_frame_batch_size=video_frame_batch_size, + array_chunk_bytes=array_chunk_bytes, reduce_to_single_store=reduce_to_single_store, overwrite=overwrite, ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index ea477763..aea1b55b 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -2,6 +2,7 @@ from collections.abc import Iterable, Mapping from dataclasses import dataclass, field +from string import Formatter from typing import Any, cast import numpy as np @@ -16,14 +17,15 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies -from refiner.video import VideoSource +from refiner.video import VideoFrameArray, VideoSource from refiner.worker.context import ( + get_active_job_id, get_active_stage_index, get_active_worker_token, get_finalized_workers, ) -_DEFAULT_ARRAY_CHUNK_LENGTH = 1024 +_DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 @dataclass @@ -31,7 +33,6 @@ class _ShardStore: root: Any arrays: dict[str, Any] = field(default_factory=dict) row_end: int = 0 - next_temp_index: int = 0 class ZarrSink(BaseSink): @@ -43,12 +44,16 @@ def __init__( episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, + array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, reduce_to_single_store: bool = False, overwrite: bool = True, ): check_required_dependencies("write_zarr", ["zarr"], dist="zarr") if video_frame_batch_size <= 0: raise ValueError("video_frame_batch_size must be greater than zero") + if array_chunk_bytes <= 0: + raise ValueError("array_chunk_bytes must be greater than zero") + _validate_store_template(store_template) self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None self.episode_ends_path = episode_ends_path @@ -56,6 +61,7 @@ def __init__( _validate_array_paths(self.arrays, episode_ends_path) self.store_template = store_template self.video_frame_batch_size = video_frame_batch_size + self.array_chunk_bytes = array_chunk_bytes self.reduce_to_single_store = reduce_to_single_store self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} @@ -94,48 +100,61 @@ def _write_row(self, shard_id: str, row: Row) -> None: expected_length = lengths[0] if any(item != expected_length for item in lengths): raise ValueError("Zarr arrays for one row must have matching lengths") - temp_videos: list[tuple[str, str, int]] = [] if store is not None: + rollback_lengths: dict[str, int | None] = {} + for zarr_path in [*row_arrays, *(path for path, _ in row_videos)]: + dataset = store.arrays.get(zarr_path) + rollback_lengths[zarr_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) try: + for zarr_path, array in row_arrays.items(): + self._validate_array_append(store, zarr_path, array) + for zarr_path, array in row_arrays.items(): + self._append_array(store, zarr_path, array) for zarr_path, video in row_videos: - temp_path = self._temp_path(store, zarr_path) - temp_videos.append((zarr_path, temp_path, 0)) video_length = submit( self._append_video( store, - temp_path, + zarr_path, video, expected_length=expected_length, ) ).result() lengths.append(video_length) - temp_videos[-1] = (zarr_path, temp_path, video_length) if lengths: length = lengths[0] if any(item != length for item in lengths): raise ValueError( "Zarr arrays for one row must have matching lengths" ) - for zarr_path, array in row_arrays.items(): - self._validate_array_append(store, zarr_path, array) - for zarr_path, temp_path, _ in temp_videos: - self._validate_array_append( - store, zarr_path, store.arrays[temp_path] - ) - for zarr_path, array in row_arrays.items(): - self._append_array(store, zarr_path, array) - for zarr_path, temp_path, _ in temp_videos: - self._copy_temp_array(store, temp_path, zarr_path) if lengths and self.episode_ends_path is not None: + dataset = store.arrays.get(self.episode_ends_path) + rollback_lengths[self.episode_ends_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) store.row_end += lengths[0] self._append_array( store, self.episode_ends_path, np.asarray([store.row_end], dtype=np.int64), ) - finally: - for _, temp_path, _ in temp_videos: - self._drop_array(store, temp_path) + except Exception: + for zarr_path, length in rollback_lengths.items(): + if length is None: + self._drop_array(store, zarr_path) + continue + dataset = store.arrays.get(zarr_path) + if dataset is not None: + dataset.resize((length, *dataset.shape[1:])) + if self.episode_ends_path is not None: + dataset = store.arrays.get(self.episode_ends_path) + store.row_end = ( + 0 + if dataset is None or dataset.shape[0] == 0 + else int(dataset[-1]) + ) + raise return async def _append_video( @@ -146,19 +165,30 @@ async def _append_video( *, expected_length: int | None = None, ) -> int: + if isinstance(video, VideoFrameArray): + frames = video.frame_arrays + if expected_length is not None and int(frames.shape[0]) != expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") + for start in range(0, int(frames.shape[0]), self.video_frame_batch_size): + end = min(int(frames.shape[0]), start + self.video_frame_batch_size) + self._append_array(store, path, frames[start:end]) + return int(frames.shape[0]) + batch: list[np.ndarray] = [] count = 0 - async for frame in _iter_video_frame_arrays(video): - batch.append(np.asarray(frame)) + async for frame in video.iter_frames(): + batch.append(frame.frame.to_ndarray(format="rgb24")) if len(batch) >= self.video_frame_batch_size: - self._append_array(store, path, np.stack(batch, axis=0)) - count += len(batch) - batch.clear() - if expected_length is not None and count > expected_length: + if expected_length is not None and count + len(batch) > expected_length: raise ValueError( "Zarr arrays for one row must have matching lengths" ) + self._append_array(store, path, np.stack(batch, axis=0)) + count += len(batch) + batch.clear() if batch: + if expected_length is not None and count + len(batch) > expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") self._append_array(store, path, np.stack(batch, axis=0)) count += len(batch) if expected_length is not None and count != expected_length: @@ -200,16 +230,13 @@ def _store_relpath(self, shard_id: str) -> str: shard_id=shard_id, worker_id=get_active_worker_token(), ) - return f"_parts/{relpath}" if self.reduce_to_single_store else relpath + if self.reduce_to_single_store: + return _part_store_relpath(relpath) + return relpath def on_shard_complete(self, shard_id: str) -> None: self._stores.pop(self._store_relpath(shard_id), None) - def _temp_path(self, store: _ShardStore, path: str) -> str: - temp_path = f"__tmp/{store.next_temp_index}/{path}" - store.next_temp_index += 1 - return temp_path - def _append_array( self, store: _ShardStore, @@ -218,11 +245,10 @@ def _append_array( ) -> None: dataset = store.arrays.get(path) if dataset is None: - chunks = (_DEFAULT_ARRAY_CHUNK_LENGTH, *array.shape[1:]) dataset = store.root.create_dataset( path, shape=(0, *array.shape[1:]), - chunks=chunks, + chunks=_chunk_shape(array, self.array_chunk_bytes), dtype=array.dtype, ) store.arrays[path] = dataset @@ -246,28 +272,12 @@ def _validate_array_append( if dataset.dtype != array.dtype: raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") - def _copy_temp_array(self, store: _ShardStore, temp_path: str, path: str) -> None: - source = store.arrays[temp_path] - for start in range(0, int(source.shape[0]), _merge_batch_size(source)): - end = min(int(source.shape[0]), start + _merge_batch_size(source)) - self._append_array(store, path, np.asarray(source[start:end])) - def _drop_array(self, store: _ShardStore, path: str) -> None: store.arrays.pop(path, None) - if path.startswith("__tmp/"): - path = "/".join(path.split("/")[:2]) - for key in list(store.arrays): - if key == path or key.startswith(f"{path}/"): - store.arrays.pop(key, None) try: del store.root[path] except (KeyError, FileNotFoundError): pass - if path.startswith("__tmp/"): - try: - del store.root["__tmp"] - except (KeyError, FileNotFoundError): - pass def close(self) -> None: self._stores.clear() @@ -282,6 +292,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "episode_ends_path": self.episode_ends_path, "store_template": self.store_template, "video_frame_batch_size": self.video_frame_batch_size, + "array_chunk_bytes": self.array_chunk_bytes, "reduce_to_single_store": self.reduce_to_single_store, "overwrite": self.overwrite, }, @@ -293,6 +304,7 @@ def build_reducer(self) -> BaseSink | None: output=self.output, store_template=self.store_template, episode_ends_path=self.episode_ends_path, + array_chunk_bytes=self.array_chunk_bytes, overwrite=self.overwrite, ) return FileCleanupReducerSink( @@ -310,12 +322,14 @@ def __init__( *, store_template: str, episode_ends_path: str | None, + array_chunk_bytes: int, overwrite: bool, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") self.output = DataFolder.resolve(output) self.store_template = store_template self.episode_ends_path = episode_ends_path + self.array_chunk_bytes = array_chunk_bytes self.overwrite = overwrite self._merged = False @@ -334,6 +348,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: { "path": self.output.abs_path(), "store_template": self.store_template, + "array_chunk_bytes": self.array_chunk_bytes, "reduce_to_single_store": True, }, ) @@ -369,6 +384,8 @@ def _merge(self) -> None: arrays: dict[str, Any] = {} for row in stores: relpath = self._part_relpath(row.shard_id, row.worker_token) + if not self.output.exists(relpath): + continue source = zarr.open_group( store=zarr_store(self.output, relpath, mode="r"), mode="r", @@ -378,23 +395,24 @@ def _merge(self) -> None: if path == self.episode_ends_path: if source_array.shape[0] == 0: continue - values = np.asarray(source_array[:], dtype=np.int64) - _append_reduced_array( - final, - arrays, - path, - values + row_offset, - source_array, - ) - row_offset += int(values[-1]) + part_last = row_offset + batch_size = _batch_length(source_array, self.array_chunk_bytes) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end], dtype=np.int64) + _append_reduced_array( + final, + arrays, + path, + values + row_offset, + source_array, + ) + part_last = int(values[-1]) + row_offset += part_last continue - for start in range( - 0, int(source_array.shape[0]), _merge_batch_size(source_array) - ): - end = min( - int(source_array.shape[0]), - start + _merge_batch_size(source_array), - ) + batch_size = _batch_length(source_array, self.array_chunk_bytes) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) _append_reduced_array( final, arrays, @@ -404,14 +422,21 @@ def _merge(self) -> None: ) try: - self.output.rm("_parts", recursive=True) + self.output.rm(f"_parts/{get_active_job_id()}", recursive=True) except FileNotFoundError: pass + try: + if not self.output.ls("_parts"): + self.output.rmdir("_parts") + except (FileNotFoundError, OSError, ValueError): + pass def _part_relpath(self, shard_id: str, worker_token: str) -> str: - return "_parts/" + self.store_template.format( - shard_id=shard_id, - worker_id=worker_token, + return _part_store_relpath( + self.store_template.format( + shard_id=shard_id, + worker_id=worker_token, + ) ) @@ -438,6 +463,22 @@ def _validate_array_paths( ) +def _validate_store_template(store_template: str) -> None: + fields = { + field_name + for _literal_text, field_name, _format_spec, _conversion in Formatter().parse( + store_template + ) + if field_name is not None + } + missing_fields = {"shard_id", "worker_id"}.difference(fields) + if missing_fields: + raise ValueError( + "store_template requires fields: " + + ", ".join(f"{{{field_name}}}" for field_name in sorted(missing_fields)) + ) + + def _row_value(row: Row, key: str) -> Any: if isinstance(row, RoboticsRow): if key == "action": @@ -469,19 +510,8 @@ def _as_array(value: Any) -> np.ndarray: return np.asarray(value) -async def _iter_video_frame_arrays(video: VideoSource): - iter_frame_arrays = getattr(video, "iter_frame_arrays", None) - if callable(iter_frame_arrays): - frames = iter_frame_arrays() - if hasattr(frames, "__aiter__"): - async for frame in frames: - yield frame - return - for frame in frames: - yield frame - return - async for frame in video.iter_frames(): - yield frame.frame.to_ndarray(format="rgb24") +def _part_store_relpath(relpath: str) -> str: + return f"_parts/{get_active_job_id()}/{relpath}" def _clear_final_group(group: Any) -> None: @@ -490,11 +520,15 @@ def _clear_final_group(group: Any) -> None: del group[key] -def _merge_batch_size(array: Any) -> int: - chunks = getattr(array, "chunks", None) - if isinstance(chunks, tuple) and chunks and isinstance(chunks[0], int): - return max(1, int(chunks[0])) - return max(1, min(int(array.shape[0]), 1024)) +def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: + return (_batch_length(array, target_bytes), *array.shape[1:]) + + +def _batch_length(array: Any, target_bytes: int) -> int: + dtype = np.dtype(array.dtype) + row_values = int(np.prod(tuple(array.shape[1:]), dtype=np.int64)) + row_bytes = max(1, dtype.itemsize * max(1, row_values)) + return max(1, target_bytes // row_bytes) def _append_reduced_array( diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 058014e0..e315f1b6 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -13,7 +13,7 @@ from refiner.io.datafile import DataFile, DataFileLike from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.io.zarr import iter_zarr_array_paths +from refiner.io.zarr import iter_zarr_array_paths, zarr_store from refiner.pipeline.data.datatype import ( DTypeMapping, dtype_to_plan, @@ -288,7 +288,7 @@ def _open_group(self) -> Any: handle.close() else: assert self.root is not None - store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + store = zarr_store(self.root, mode="r") yield zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index 8921ec4d..13b04539 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -108,10 +108,6 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) - async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: - async for frame in self.iter_frames(): - yield frame.frame.to_ndarray(format="rgb24") - def iter_frame_windows( self, *, @@ -176,10 +172,6 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) - async def iter_frame_arrays(self) -> AsyncIterator[np.ndarray]: - async for frame in self.iter_frames(): - yield frame.frame.to_ndarray(format="rgb24") - def iter_frame_windows( self, *, @@ -245,8 +237,8 @@ def duration_s(self) -> float: def frame_arrays(self) -> np.ndarray: return self._array - def iter_frame_arrays(self) -> "_FrameArrayView": - return _FrameArrayView(self._array) + def iter_frame_arrays(self) -> Iterator[np.ndarray]: + yield from self._array def clipped( self, @@ -341,18 +333,6 @@ def video_from_storage_value( return None -@dataclass(frozen=True, slots=True) -class _FrameArrayView: - frames: np.ndarray - - def __iter__(self) -> Iterator[np.ndarray]: - yield from self.frames - - async def __aiter__(self) -> AsyncIterator[np.ndarray]: - for frame in self.frames: - yield frame - - __all__ = [ "VideoBytes", "VideoFile", diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 61f4813b..e2ae1a8b 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -803,6 +803,79 @@ def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() +def test_write_zarr_rejects_store_template_without_worker_id(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="store_template requires fields"): + ZarrSink(str(tmp_path / "template.zarr"), store_template="{shard_id}.zarr") + + +def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-overwrite.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-overwrite-first", + num_workers=1, + rundir=str(tmp_path / "run-first"), + ) + ) + + stale_part = zarr_out / "_parts" / "old-job" / "old__wold.zarr" + stale_part.mkdir(parents=True) + (stale_part / ".zgroup").write_text('{"zarr_format": 2}', encoding="utf-8") + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-overwrite-second", + num_workers=1, + rundir=str(tmp_path / "run-second"), + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert stale_part.exists() + + +def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-empty-shards.zarr" + + ( + mdr.from_items( + [{"action": [[0.0]]}, {"action": [[0.1]]}], + items_per_shard=1, + ) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-empty-shards", + num_workers=2, + rundir=str(tmp_path / "run-empty"), + ) + ) + + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_rejects_rows_missing_inferred_default_arrays( tmp_path: Path, ) -> None: @@ -913,6 +986,35 @@ def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) +def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> None: + output = tmp_path / "video-chunks.zarr" + frames = np.zeros((2, 4, 4, 3), dtype=np.uint8) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + array_chunk_bytes=50, + ).write_block(rows) + + root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") + assert root["data/rgb"].chunks == (1, 4, 4, 3) + + def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: source = tmp_path / "source.mp4" output = tmp_path / "encoded-video.zarr" diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index 11c5d6d0..ca6121d2 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -1,6 +1,4 @@ from __future__ import annotations - -import asyncio from typing import Any, cast import numpy as np @@ -122,16 +120,12 @@ def test_to_robot_rows_uses_video_frame_array_asset_schema() -> None: video = robotics_row.videos["camera"] assert isinstance(video, VideoFrameArray) assert len(list(video.iter_frame_arrays())) == 2 - video_frames = asyncio.run(_collect_frame_arrays(video)) + video_frames = list(video.iter_frame_arrays()) assert len(video_frames) == 2 assert video_frames[0].shape == (4, 5, 3) assert video.fps == 12 -async def _collect_frame_arrays(video): - return [frame async for frame in video.iter_frame_arrays()] - - def test_to_robot_rows_accepts_unmapped_key_iterables() -> None: row = DictRow( { diff --git a/tests/test_video_decode.py b/tests/test_video_decode.py index ea5da14a..2f085608 100644 --- a/tests/test_video_decode.py +++ b/tests/test_video_decode.py @@ -35,10 +35,6 @@ async def _collect_frames(video: mdr.video.VideoSource): return [frame async for frame in video.iter_frames()] -async def _collect_frame_arrays(video): - return [frame async for frame in video.iter_frame_arrays()] - - async def _collect_windows( video: mdr.video.VideoSource, *, @@ -79,7 +75,7 @@ def test_video_frame_array_clip_returns_frame_view() -> None: assert isinstance(clipped, mdr.video.VideoFrameArray) assert len(list(clipped.iter_frame_arrays())) == 3 - clipped_frames = asyncio.run(_collect_frame_arrays(clipped)) + clipped_frames = list(clipped.iter_frame_arrays()) assert len(clipped_frames) == 3 assert clipped_frames[0].shape == (4, 4, 3) assert [int(frame[0, 0, 0]) for frame in clipped_frames] == [2, 3, 4] @@ -90,7 +86,7 @@ def test_video_frame_array_iter_frames() -> None: video = mdr.video.VideoFrameArray(frames, fps=5) decoded = asyncio.run(_collect_frames(video)) - arrays = asyncio.run(_collect_frame_arrays(video)) + arrays = list(video.iter_frame_arrays()) assert [frame.index for frame in decoded] == [0, 1, 2] assert [frame.timestamp_s for frame in decoded] == [0.0, 0.2, 0.4] From d4ade58cf2c8651ba1148053d8bda35fd081119b Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 02:51:30 +0200 Subject: [PATCH 14/39] Expose Zarr reducer batch sizing --- docs/reading-and-writing.md | 6 +- src/refiner/pipeline/pipeline.py | 2 + src/refiner/pipeline/sinks/zarr.py | 95 +++++++++++--- tests/readers/test_zarr_reader.py | 198 ++++++++++++++++++++++++++++- 4 files changed, 283 insertions(+), 18 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 1a9f7b9b..bc0c555f 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -512,7 +512,11 @@ The `arrays` mapping is from output Zarr path to source row key. For arrays: actions, states, and timestamps. The default schema is inferred once and later rows must expose the same fields. Video sources selected through `arrays` are decoded as RGB frame arrays and appended in bounded batches controlled by -`video_frame_batch_size`. +`video_frame_batch_size`. `array_chunk_bytes` controls the target chunk size for +new arrays. When `reduce_to_single_store=True`, the reducer copies shard-local +arrays into the final store in read/write batches controlled by +`reduce_array_batch_bytes`; by default it uses the same value as +`array_chunk_bytes`. By default, `write_zarr(...)` also writes cumulative episode boundaries to `meta/episode_ends`. Set `episode_ends_path=None` to omit them. diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 8909802c..5afcde8e 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -440,6 +440,7 @@ def write_zarr( store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, array_chunk_bytes: int = 8 * 1024 * 1024, + reduce_array_batch_bytes: int | None = None, reduce_to_single_store: bool = False, overwrite: bool = True, ) -> "RefinerPipeline": @@ -451,6 +452,7 @@ def write_zarr( store_template=store_template, video_frame_batch_size=video_frame_batch_size, array_chunk_bytes=array_chunk_bytes, + reduce_array_batch_bytes=reduce_array_batch_bytes, reduce_to_single_store=reduce_to_single_store, overwrite=overwrite, ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index aea1b55b..6ca6c606 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -45,6 +45,7 @@ def __init__( store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, + reduce_array_batch_bytes: int | None = None, reduce_to_single_store: bool = False, overwrite: bool = True, ): @@ -53,19 +54,29 @@ def __init__( raise ValueError("video_frame_batch_size must be greater than zero") if array_chunk_bytes <= 0: raise ValueError("array_chunk_bytes must be greater than zero") + if reduce_array_batch_bytes is not None and reduce_array_batch_bytes <= 0: + raise ValueError("reduce_array_batch_bytes must be greater than zero") _validate_store_template(store_template) self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None self.episode_ends_path = episode_ends_path if self.arrays is not None: + if not self.arrays: + raise ValueError("write_zarr arrays must not be empty") _validate_array_paths(self.arrays, episode_ends_path) self.store_template = store_template self.video_frame_batch_size = video_frame_batch_size self.array_chunk_bytes = array_chunk_bytes + self.reduce_array_batch_bytes = ( + array_chunk_bytes + if reduce_array_batch_bytes is None + else reduce_array_batch_bytes + ) self.reduce_to_single_store = reduce_to_single_store self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} self._default_arrays: dict[str, str] | None = None + self._checked_no_overwrite = False def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 @@ -201,6 +212,10 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: default_arrays = _default_robotics_arrays(row) if self._default_arrays is None: self._default_arrays = default_arrays + if not self._default_arrays: + raise ValueError( + "write_zarr inferred no default robotics arrays; pass arrays=..." + ) _validate_array_paths(self._default_arrays, self.episode_ends_path) elif default_arrays != self._default_arrays: raise ValueError( @@ -210,6 +225,7 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: + self._check_no_overwrite_output() relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) if store is not None: @@ -225,6 +241,17 @@ def _store(self, shard_id: str) -> _ShardStore: self._stores[relpath] = store return store + def _check_no_overwrite_output(self) -> None: + if self.overwrite or self._checked_no_overwrite: + return + self._checked_no_overwrite = True + try: + entries = self.output.ls("", detail=False) + except FileNotFoundError: + return + if entries: + raise ValueError("write_zarr output already exists and overwrite=False") + def _store_relpath(self, shard_id: str) -> str: relpath = self.store_template.format( shard_id=shard_id, @@ -293,6 +320,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "store_template": self.store_template, "video_frame_batch_size": self.video_frame_batch_size, "array_chunk_bytes": self.array_chunk_bytes, + "reduce_array_batch_bytes": self.reduce_array_batch_bytes, "reduce_to_single_store": self.reduce_to_single_store, "overwrite": self.overwrite, }, @@ -300,11 +328,11 @@ def describe(self) -> tuple[str, str, dict[str, object]]: def build_reducer(self) -> BaseSink | None: if self.reduce_to_single_store: - return ZarrMergeReducerSink( + return _ZarrMergeReducerSink( output=self.output, store_template=self.store_template, episode_ends_path=self.episode_ends_path, - array_chunk_bytes=self.array_chunk_bytes, + reduce_array_batch_bytes=self.reduce_array_batch_bytes, overwrite=self.overwrite, ) return FileCleanupReducerSink( @@ -315,21 +343,21 @@ def build_reducer(self) -> BaseSink | None: ) -class ZarrMergeReducerSink(BaseSink): +class _ZarrMergeReducerSink(BaseSink): def __init__( self, output: DataFolderLike, *, store_template: str, episode_ends_path: str | None, - array_chunk_bytes: int, + reduce_array_batch_bytes: int, overwrite: bool, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") self.output = DataFolder.resolve(output) self.store_template = store_template self.episode_ends_path = episode_ends_path - self.array_chunk_bytes = array_chunk_bytes + self.reduce_array_batch_bytes = reduce_array_batch_bytes self.overwrite = overwrite self._merged = False @@ -339,7 +367,11 @@ def counts_output_rows(self) -> bool: def write_shard_block(self, shard_id, block) -> None: del shard_id, block - self._merge() + try: + self._merge() + except Exception: + self._remove_current_parts() + raise def describe(self) -> tuple[str, str, dict[str, object]]: return ( @@ -348,7 +380,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: { "path": self.output.abs_path(), "store_template": self.store_template, - "array_chunk_bytes": self.array_chunk_bytes, + "reduce_array_batch_bytes": self.reduce_array_batch_bytes, "reduce_to_single_store": True, }, ) @@ -372,6 +404,8 @@ def _merge(self) -> None: ) if self.overwrite: _clear_final_group(final) + elif _group_has_payload(final): + raise ValueError("write_zarr output already exists and overwrite=False") stores = sorted( get_finalized_workers(stage_index=stage_index - 1), @@ -382,6 +416,7 @@ def _merge(self) -> None: ) row_offset = 0 arrays: dict[str, Any] = {} + payload_paths: set[str] | None = None for row in stores: relpath = self._part_relpath(row.shard_id, row.worker_token) if not self.output.exists(relpath): @@ -390,13 +425,26 @@ def _merge(self) -> None: store=zarr_store(self.output, relpath, mode="r"), mode="r", ) - for path in iter_zarr_array_paths(source): + source_paths = set(iter_zarr_array_paths(source)) + source_payload_paths = { + path for path in source_paths if path != self.episode_ends_path + } + if payload_paths is None: + payload_paths = source_payload_paths + elif source_payload_paths != payload_paths: + raise ValueError( + "Zarr part stores must contain the same payload arrays" + ) + for path in sorted(source_paths): source_array = source[path] if path == self.episode_ends_path: if source_array.shape[0] == 0: continue part_last = row_offset - batch_size = _batch_length(source_array, self.array_chunk_bytes) + batch_size = _batch_length( + source_array, + self.reduce_array_batch_bytes, + ) for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) values = np.asarray(source_array[start:end], dtype=np.int64) @@ -410,7 +458,7 @@ def _merge(self) -> None: part_last = int(values[-1]) row_offset += part_last continue - batch_size = _batch_length(source_array, self.array_chunk_bytes) + batch_size = _batch_length(source_array, self.reduce_array_batch_bytes) for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) _append_reduced_array( @@ -421,10 +469,9 @@ def _merge(self) -> None: source_array, ) - try: - self.output.rm(f"_parts/{get_active_job_id()}", recursive=True) - except FileNotFoundError: - pass + self._remove_current_parts() + if self.overwrite: + self._remove_stale_parts() try: if not self.output.ls("_parts"): self.output.rmdir("_parts") @@ -439,6 +486,20 @@ def _part_relpath(self, shard_id: str, worker_token: str) -> str: ) ) + def _remove_current_parts(self) -> None: + try: + self.output.rm(f"_parts/{get_active_job_id()}", recursive=True) + except FileNotFoundError: + pass + + def _remove_stale_parts(self) -> None: + try: + for path in self.output.ls("_parts", detail=False): + if path != f"_parts/{get_active_job_id()}": + self.output.rm(path, recursive=True) + except FileNotFoundError: + pass + def _default_robotics_arrays(row: Row) -> dict[str, str]: if not isinstance(row, RoboticsRow): @@ -520,6 +581,10 @@ def _clear_final_group(group: Any) -> None: del group[key] +def _group_has_payload(group: Any) -> bool: + return any(key != "_parts" for key in {*group.array_keys(), *group.group_keys()}) + + def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: return (_batch_length(array, target_bytes), *array.shape[1:]) @@ -551,4 +616,4 @@ def _append_reduced_array( dataset.append(values, axis=0) -__all__ = ["ZarrMergeReducerSink", "ZarrSink"] +__all__ = ["ZarrSink"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index e2ae1a8b..95816bbe 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -10,13 +10,27 @@ import zarr import refiner as mdr +from refiner.cli.run.local import LocalLaunchResumeError from refiner.io.datafolder import DataFolder from refiner.io import DataFile from refiner.robotics.row import RoboticsRow from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor -from refiner.pipeline.sinks.zarr import ZarrSink +from refiner.pipeline.sinks.zarr import _ZarrMergeReducerSink, ZarrSink +from refiner.worker.context import set_active_run_context, worker_token_for +from refiner.worker.lifecycle import FinalizedShardWorker, RuntimeLifecycle + + +class _FinalizedWorkersRuntime: + def __init__(self, rows: list[FinalizedShardWorker]) -> None: + self._rows = rows + + def finalized_workers( + self, *, stage_index: int | None = None + ) -> list[FinalizedShardWorker]: + assert stage_index == 0 + return self._rows def _open_test_zarr(path: Path, *, mode: Literal["r", "r+", "a", "w", "w-"]): @@ -56,6 +70,12 @@ def _write_policy_zarr(path: Path) -> None: root.attrs["task"] = "push tee" +def _write_part_zarr(path: Path, arrays: dict[str, np.ndarray]) -> None: + root = _open_test_zarr(path, mode="w") + for name, data in arrays.items(): + _create_array(root, name, data=data) + + def _write_video(path: Path, *, num_frames: int = 3, fps: int = 5) -> None: import av @@ -808,6 +828,30 @@ def test_write_zarr_rejects_store_template_without_worker_id(tmp_path: Path) -> ZarrSink(str(tmp_path / "template.zarr"), store_template="{shard_id}.zarr") +def test_write_zarr_rejects_invalid_reduce_batch_bytes(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="reduce_array_batch_bytes"): + ZarrSink(str(tmp_path / "bad-batch.zarr"), reduce_array_batch_bytes=0) + + +def test_write_zarr_rejects_empty_array_mapping(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="arrays must not be empty"): + ZarrSink(str(tmp_path / "empty-arrays.zarr"), arrays={}) + + +def test_write_zarr_rejects_empty_default_robotics_arrays(tmp_path: Path) -> None: + rows = list( + mdr.from_items([{"episode_id": "episode-1"}]).to_robot_rows( + episode_id_key="episode_id", + action_key=None, + state_key=None, + timestamp_key=None, + ) + ) + + with pytest.raises(ValueError, match="inferred no default robotics arrays"): + ZarrSink(str(tmp_path / "empty-defaults.zarr")).write_block(rows) + + def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) -> None: zarr_out = tmp_path / "single-overwrite.zarr" @@ -849,7 +893,92 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - file_path_column=None, ).take(1)[0] np.testing.assert_allclose(row["action"], [[1.0]]) - assert stale_part.exists() + assert not stale_part.exists() + + +def test_write_zarr_single_store_rejects_existing_output_when_not_overwriting( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-no-overwrite.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-no-overwrite-first", + num_workers=1, + rundir=str(tmp_path / "run-no-overwrite-first"), + ) + ) + + with pytest.raises(LocalLaunchResumeError): + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + overwrite=False, + ) + .launch_local( + name="zarr-single-no-overwrite-second", + num_workers=1, + rundir=str(tmp_path / "run-no-overwrite-second"), + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[0.0]]) + + +def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr(str(zarr_out), arrays={"data/action": "action"}) + .launch_local( + name="zarr-sharded-no-overwrite-first", + num_workers=1, + rundir=str(tmp_path / "run-sharded-no-overwrite-first"), + ) + ) + + with pytest.raises(LocalLaunchResumeError): + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ) + .launch_local( + name="zarr-sharded-no-overwrite-second", + num_workers=1, + rundir=str(tmp_path / "run-sharded-no-overwrite-second"), + ) + ) + + rows = [ + mdr.read_zarr( + store, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + for store in zarr_out.glob("*.zarr") + ] + assert len(rows) == 1 + np.testing.assert_allclose(rows[0]["action"], [[0.0]]) def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: @@ -876,6 +1005,71 @@ def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_rejects_inconsistent_part_payloads( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-inconsistent-parts.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out + / "_parts" + / "local" + / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out + / "_parts" + / "local" + / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "data/state": np.asarray([[2.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="same payload arrays"): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_rejects_rows_missing_inferred_default_arrays( tmp_path: Path, ) -> None: From 6f17c04cd266d3b1c508a53f0ec1b0e788109a85 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 03:46:51 +0200 Subject: [PATCH 15/39] Harden Zarr writer reduction --- src/refiner/pipeline/sinks/reducer/file.py | 38 +- src/refiner/pipeline/sinks/zarr.py | 351 +++++++++++------ src/refiner/video/decode.py | 2 +- src/refiner/video/types.py | 4 - src/refiner/video/writer.py | 2 +- tests/readers/test_zarr_reader.py | 413 ++++++++++++++++++++- tests/robotics/test_robotics_row.py | 2 +- tests/test_video_decode.py | 1 - 8 files changed, 666 insertions(+), 147 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 456cd912..324d3480 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -85,15 +85,15 @@ def __init__( reducer_name: str, assets_subdir: str | None = None, recursive: bool = False, + overwrite: bool = True, ) -> None: self.output = DataFolder.resolve(output) self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir self.recursive = recursive + self.overwrite = overwrite self._managed_path_pattern = _compile_managed_path_pattern(filename_template) - self._managed_listing_prefix = _managed_listing_prefix(filename_template) - self._managed_path_depth = _path_depth(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -113,6 +113,8 @@ def describe(self) -> tuple[str, str, dict[str, object]]: args["assets_subdir"] = self.assets_subdir if self.recursive: args["recursive"] = True + if not self.overwrite: + args["overwrite"] = False return ( self.reducer_name, "writer", @@ -146,8 +148,8 @@ def _run_cleanup(self) -> None: else None ) - removed_asset_attempts: set[str] = set() - removed_managed_paths: set[str] = set() + stale_asset_attempts: set[str] = set() + stale_managed_paths: set[str] = set() # Extra template fields are treated as structure only. Authority is decided # solely from the finalized (shard_id, worker_id) pair extracted from each # managed path. @@ -161,15 +163,10 @@ def _run_cleanup(self) -> None: match = ASSET_ATTEMPT_DIR_RE.fullmatch(attempt_dir) if match is None: continue + asset_path = f"{assets_prefix}{attempt_dir}" if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue - if attempt_dir in removed_asset_attempts: - continue - removed_asset_attempts.add(attempt_dir) - try: - self.output.rm(f"{assets_prefix}{attempt_dir}", recursive=True) - except FileNotFoundError: - continue + stale_asset_attempts.add(asset_path) continue managed_path = rel_path @@ -186,11 +183,19 @@ def _run_cleanup(self) -> None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue - if managed_path in removed_managed_paths: + stale_managed_paths.add(managed_path) + + if not self.overwrite and (stale_asset_attempts or stale_managed_paths): + raise ValueError(f"{self.reducer_name} output already exists") + + for path in sorted(stale_asset_attempts): + try: + self.output.rm(path, recursive=True) + except FileNotFoundError: continue - removed_managed_paths.add(managed_path) + for path in sorted(stale_managed_paths): try: - self.output.rm(managed_path, recursive=self.recursive) + self.output.rm(path, recursive=self.recursive) except FileNotFoundError: continue @@ -201,9 +206,10 @@ def _listed_cleanup_paths(self) -> list[str]: except FileNotFoundError: return [] - paths = [self._managed_listing_prefix] + listing_prefix = _managed_listing_prefix(self.filename_template) + paths = [listing_prefix] depth = max( - 1, self._managed_path_depth - _path_depth(self._managed_listing_prefix) + 1, _path_depth(self.filename_template) - _path_depth(listing_prefix) ) for _ in range(depth): next_paths: list[str] = [] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 6ca6c606..9e894de4 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -19,13 +19,14 @@ from refiner.utils import check_required_dependencies from refiner.video import VideoFrameArray, VideoSource from refiner.worker.context import ( - get_active_job_id, get_active_stage_index, get_active_worker_token, get_finalized_workers, ) _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 +_MAX_INITIAL_CHUNK_ROWS = 1024 +_DONE_MARKER_RELPATH = "_refiner/write_zarr.done" @dataclass @@ -35,6 +36,12 @@ class _ShardStore: row_end: int = 0 +@dataclass(frozen=True) +class _PartStore: + relpath: str + paths: set[str] + + class ZarrSink(BaseSink): def __init__( self, @@ -177,19 +184,29 @@ async def _append_video( expected_length: int | None = None, ) -> int: if isinstance(video, VideoFrameArray): - frames = video.frame_arrays - if expected_length is not None and int(frames.shape[0]) != expected_length: + if expected_length is not None and video.frame_count != expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") - for start in range(0, int(frames.shape[0]), self.video_frame_batch_size): - end = min(int(frames.shape[0]), start + self.video_frame_batch_size) - self._append_array(store, path, frames[start:end]) - return int(frames.shape[0]) + batch: list[np.ndarray] = [] + batch_limit: int | None = None + for frame in video.iter_frame_arrays(): + batch.append(frame) + if batch_limit is None: + batch_limit = self._video_batch_limit(frame) + if len(batch) >= batch_limit: + self._append_array(store, path, np.stack(batch, axis=0)) + batch.clear() + if batch: + self._append_array(store, path, np.stack(batch, axis=0)) + return video.frame_count batch: list[np.ndarray] = [] + batch_limit: int | None = None count = 0 async for frame in video.iter_frames(): batch.append(frame.frame.to_ndarray(format="rgb24")) - if len(batch) >= self.video_frame_batch_size: + if batch_limit is None: + batch_limit = self._video_batch_limit(batch[0]) + if len(batch) >= batch_limit: if expected_length is not None and count + len(batch) > expected_length: raise ValueError( "Zarr arrays for one row must have matching lengths" @@ -206,6 +223,14 @@ async def _append_video( raise ValueError("Zarr arrays for one row must have matching lengths") return count + def _video_batch_limit(self, frame: np.ndarray) -> int: + return max( + self.video_frame_batch_size, + _batch_length_for_shape( + (1, *frame.shape), frame.dtype, self.array_chunk_bytes + ), + ) + def _arrays_for_row(self, row: Row) -> dict[str, str]: if self.arrays is not None: return self.arrays @@ -242,14 +267,14 @@ def _store(self, shard_id: str) -> _ShardStore: return store def _check_no_overwrite_output(self) -> None: - if self.overwrite or self._checked_no_overwrite: + if ( + self.overwrite + or self._checked_no_overwrite + or not self.reduce_to_single_store + ): return self._checked_no_overwrite = True - try: - entries = self.output.ls("", detail=False) - except FileNotFoundError: - return - if entries: + if _output_has_payload(self.output): raise ValueError("write_zarr output already exists and overwrite=False") def _store_relpath(self, shard_id: str) -> str: @@ -262,26 +287,30 @@ def _store_relpath(self, shard_id: str) -> str: return relpath def on_shard_complete(self, shard_id: str) -> None: + if ( + self.reduce_to_single_store + and self._store_relpath(shard_id) not in self._stores + ): + with self.output.open(self._empty_marker_relpath(shard_id), mode="wb"): + pass self._stores.pop(self._store_relpath(shard_id), None) + def _empty_marker_relpath(self, shard_id: str) -> str: + return self._store_relpath(shard_id) + ".empty" + def _append_array( self, store: _ShardStore, path: str, array: np.ndarray, ) -> None: - dataset = store.arrays.get(path) - if dataset is None: - dataset = store.root.create_dataset( - path, - shape=(0, *array.shape[1:]), - chunks=_chunk_shape(array, self.array_chunk_bytes), - dtype=array.dtype, - ) - store.arrays[path] = dataset - else: - self._validate_array_append(store, path, array) - dataset.append(array, axis=0) + _append_zarr_array( + store.root, + store.arrays, + path, + array, + chunks=_chunk_shape(array, self.array_chunk_bytes), + ) def _validate_array_append( self, @@ -292,12 +321,7 @@ def _validate_array_append( dataset = store.arrays.get(path) if dataset is None: return - if tuple(dataset.shape[1:]) != tuple(array.shape[1:]): - raise ValueError( - f"Zarr arrays for {path!r} must have matching trailing shapes" - ) - if dataset.dtype != array.dtype: - raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + _validate_array_schema(path, dataset, array) def _drop_array(self, store: _ShardStore, path: str) -> None: store.arrays.pop(path, None) @@ -340,6 +364,7 @@ def build_reducer(self) -> BaseSink | None: filename_template=self.store_template, reducer_name="write_zarr_reduce", recursive=True, + overwrite=self.overwrite, ) @@ -367,11 +392,7 @@ def counts_output_rows(self) -> bool: def write_shard_block(self, shard_id, block) -> None: del shard_id, block - try: - self._merge() - except Exception: - self._remove_current_parts() - raise + self._merge() def describe(self) -> tuple[str, str, dict[str, object]]: return ( @@ -396,6 +417,17 @@ def _merge(self) -> None: "write_zarr_reduce requires an active reducer stage with a prior writer stage" ) + expected_parts = self._expected_parts(stage_index) + if self.output.exists(_DONE_MARKER_RELPATH) and not any( + self.output.exists(relpath) or self.output.exists(f"{relpath}.empty") + for relpath in expected_parts + ): + return + + parts = self._collect_parts(expected_parts) + if not self.overwrite and _output_has_payload(self.output): + raise ValueError("write_zarr output already exists and overwrite=False") + import zarr final = zarr.open_group( @@ -404,80 +436,84 @@ def _merge(self) -> None: ) if self.overwrite: _clear_final_group(final) - elif _group_has_payload(final): - raise ValueError("write_zarr output already exists and overwrite=False") - stores = sorted( - get_finalized_workers(stage_index=stage_index - 1), - key=lambda row: ( - row.global_ordinal is None, - row.global_ordinal if row.global_ordinal is not None else row.shard_id, - ), - ) - row_offset = 0 - arrays: dict[str, Any] = {} - payload_paths: set[str] | None = None - for row in stores: - relpath = self._part_relpath(row.shard_id, row.worker_token) - if not self.output.exists(relpath): - continue - source = zarr.open_group( - store=zarr_store(self.output, relpath, mode="r"), - mode="r", - ) - source_paths = set(iter_zarr_array_paths(source)) - source_payload_paths = { - path for path in source_paths if path != self.episode_ends_path - } - if payload_paths is None: - payload_paths = source_payload_paths - elif source_payload_paths != payload_paths: - raise ValueError( - "Zarr part stores must contain the same payload arrays" + try: + row_offset = 0 + arrays: dict[str, Any] = {} + for part in parts: + source = zarr.open_group( + store=zarr_store(self.output, part.relpath, mode="r"), + mode="r", ) - for path in sorted(source_paths): - source_array = source[path] - if path == self.episode_ends_path: - if source_array.shape[0] == 0: + for path in sorted(part.paths): + source_array = source[path] + if path == self.episode_ends_path: + if source_array.shape[0] == 0: + continue + part_last = row_offset + batch_size = _batch_length( + source_array, + self.reduce_array_batch_bytes, + ) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end], dtype=np.int64) + _append_zarr_array( + final, + arrays, + path, + values + row_offset, + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + part_last = int(values[-1]) + row_offset += part_last continue - part_last = row_offset batch_size = _batch_length( - source_array, - self.reduce_array_batch_bytes, + source_array, self.reduce_array_batch_bytes ) + if source_array.shape[0] == 0: + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[:0]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + continue for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) - values = np.asarray(source_array[start:end], dtype=np.int64) - _append_reduced_array( + _append_zarr_array( final, arrays, path, - values + row_offset, - source_array, + np.asarray(source_array[start:end]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), ) - part_last = int(values[-1]) - row_offset += part_last - continue - batch_size = _batch_length(source_array, self.reduce_array_batch_bytes) - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) - _append_reduced_array( - final, - arrays, - path, - np.asarray(source_array[start:end]), - source_array, - ) + except Exception: + if not self.overwrite: + _clear_final_group(final) + raise - self._remove_current_parts() - if self.overwrite: - self._remove_stale_parts() + with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): + pass + self._remove_parts() try: if not self.output.ls("_parts"): self.output.rmdir("_parts") except (FileNotFoundError, OSError, ValueError): pass + def _expected_parts(self, stage_index: int) -> list[str]: + return [ + self._part_relpath(row.shard_id, row.worker_token) + for row in _sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1), + ) + ] + def _part_relpath(self, shard_id: str, worker_token: str) -> str: return _part_store_relpath( self.store_template.format( @@ -486,17 +522,51 @@ def _part_relpath(self, shard_id: str, worker_token: str) -> str: ) ) - def _remove_current_parts(self) -> None: - try: - self.output.rm(f"_parts/{get_active_job_id()}", recursive=True) - except FileNotFoundError: - pass + def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: + import zarr - def _remove_stale_parts(self) -> None: + parts: list[_PartStore] = [] + payload_paths: set[str] | None = None + schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + for relpath in expected_parts: + if not self.output.exists(relpath): + if self.output.exists(f"{relpath}.empty"): + continue + raise ValueError(f"Zarr part store is missing: {relpath}") + source = zarr.open_group( + store=zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + source_paths = set(iter_zarr_array_paths(source)) + if not source_paths: + continue + source_payload_paths = { + path for path in source_paths if path != self.episode_ends_path + } + if payload_paths is None: + payload_paths = source_payload_paths + elif source_payload_paths != payload_paths: + raise ValueError( + "Zarr part stores must contain the same payload arrays" + ) + for path in source_paths: + source_array = source[path] + schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) + previous = schemas.setdefault(path, schema) + if previous != schema: + if previous[0] != schema[0]: + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + raise ValueError( + f"Zarr arrays for {path!r} must have matching dtypes" + ) + parts.append(_PartStore(relpath=relpath, paths=source_paths)) + return parts + + def _remove_parts(self) -> None: try: - for path in self.output.ls("_parts", detail=False): - if path != f"_parts/{get_active_job_id()}": - self.output.rm(path, recursive=True) + self.output.rm("_parts", recursive=True) except FileNotFoundError: pass @@ -572,47 +642,104 @@ def _as_array(value: Any) -> np.ndarray: def _part_store_relpath(relpath: str) -> str: - return f"_parts/{get_active_job_id()}/{relpath}" + return f"_parts/{relpath}" + + +def _sort_finalized_workers(rows: Iterable[Any]) -> list[Any]: + return sorted( + rows, + key=lambda row: ( + row.global_ordinal is None, + row.global_ordinal if row.global_ordinal is not None else row.shard_id, + ), + ) def _clear_final_group(group: Any) -> None: for key in sorted({*group.array_keys(), *group.group_keys()}): if key != "_parts": del group[key] + group.attrs.clear() def _group_has_payload(group: Any) -> bool: - return any(key != "_parts" for key in {*group.array_keys(), *group.group_keys()}) + return bool(group.attrs) or any( + key != "_parts" for key in {*group.array_keys(), *group.group_keys()} + ) + + +def _output_has_payload(output: DataFolder) -> bool: + import zarr + + try: + entries = output.ls("", detail=False) + except FileNotFoundError: + return False + non_part_entries = [ + entry + for entry in entries + if str(entry).split("/", maxsplit=1)[0] not in {"_parts", "_refiner"} + ] + if not non_part_entries: + return False + try: + group = zarr.open_group(store=zarr_store(output, "", mode="r"), mode="r") + except Exception: + return True + return _group_has_payload(group) def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: - return (_batch_length(array, target_bytes), *array.shape[1:]) + chunk_rows = min( + _batch_length(array, target_bytes), + max(int(array.shape[0]), _MAX_INITIAL_CHUNK_ROWS), + ) + return (chunk_rows, *array.shape[1:]) def _batch_length(array: Any, target_bytes: int) -> int: - dtype = np.dtype(array.dtype) - row_values = int(np.prod(tuple(array.shape[1:]), dtype=np.int64)) + return _batch_length_for_shape(tuple(array.shape), array.dtype, target_bytes) + + +def _batch_length_for_shape( + shape: tuple[int, ...], + dtype: np.dtype[Any] | type[Any], + target_bytes: int, +) -> int: + dtype = np.dtype(dtype) + row_values = int(np.prod(shape[1:], dtype=np.int64)) row_bytes = max(1, dtype.itemsize * max(1, row_values)) return max(1, target_bytes // row_bytes) -def _append_reduced_array( +def _validate_array_schema(path: str, dataset: Any, values: np.ndarray) -> None: + if tuple(dataset.shape[1:]) != tuple(values.shape[1:]): + raise ValueError(f"Zarr arrays for {path!r} must have matching trailing shapes") + if dataset.dtype != values.dtype: + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + + +def _append_zarr_array( root: Any, arrays: dict[str, Any], path: str, values: np.ndarray, - source_array: Any, + *, + chunks: tuple[int, ...] | None = None, + compressor: Any = None, ) -> None: dataset = arrays.get(path) if dataset is None: dataset = root.create_dataset( path, shape=(0, *values.shape[1:]), - chunks=getattr(source_array, "chunks", None), - dtype=source_array.dtype, - compressor=getattr(source_array, "compressor", None), + chunks=chunks, + dtype=values.dtype, + compressor=compressor, ) arrays[path] = dataset + else: + _validate_array_schema(path, dataset, values) dataset.append(values, axis=0) diff --git a/src/refiner/video/decode.py b/src/refiner/video/decode.py index 4f3ac953..d07b28ae 100644 --- a/src/refiner/video/decode.py +++ b/src/refiner/video/decode.py @@ -63,7 +63,7 @@ async def export_clip( fps=video.fps, movflags=None, ) - writer.append_frame_arrays(video.frame_arrays) + writer.append_frame_arrays(video.iter_frame_arrays()) writer.close() return output_file.getvalue() encoded_video = cast("VideoFile | VideoBytes", video) diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index 13b04539..a5d67368 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -233,10 +233,6 @@ def frame_count(self) -> int: def duration_s(self) -> float: return self.frame_count / float(self.fps) - @property - def frame_arrays(self) -> np.ndarray: - return self._array - def iter_frame_arrays(self) -> Iterator[np.ndarray]: yield from self._array diff --git a/src/refiner/video/writer.py b/src/refiner/video/writer.py index 3228788b..888cd231 100644 --- a/src/refiner/video/writer.py +++ b/src/refiner/video/writer.py @@ -141,7 +141,7 @@ def _commit_frame_arrays_sync( writer = self._ensure_transcode_writer(video.fps) file_index = self._next_file_index from_timestamp, to_timestamp = writer.append_frame_arrays( - video.frame_arrays, + video.iter_frame_arrays(), frame_observer=frame_observer, ) if writer.stream is None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 95816bbe..62737308 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -823,6 +823,29 @@ def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_preserves_empty_payload_arrays(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-empty-payload.zarr" + + ( + mdr.from_items([{"action": np.empty((0, 2), dtype=np.float32)}]) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-empty-payload", + num_workers=1, + rundir=str(tmp_path / "run-empty-payload"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert "data/action" in root + assert root["data/action"].shape == (0, 2) + assert root["meta/episode_ends"][:].tolist() == [0] + + def test_write_zarr_rejects_store_template_without_worker_id(tmp_path: Path) -> None: with pytest.raises(ValueError, match="store_template requires fields"): ZarrSink(str(tmp_path / "template.zarr"), store_template="{shard_id}.zarr") @@ -869,7 +892,7 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - ) ) - stale_part = zarr_out / "_parts" / "old-job" / "old__wold.zarr" + stale_part = zarr_out / "_parts" / "old__wold.zarr" stale_part.mkdir(parents=True) (stale_part / ".zgroup").write_text('{"zarr_format": 2}', encoding="utf-8") @@ -939,6 +962,33 @@ def test_write_zarr_single_store_rejects_existing_output_when_not_overwriting( np.testing.assert_allclose(row["action"], [[0.0]]) +def test_write_zarr_single_store_rejects_attrs_only_output_when_not_overwriting( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-attrs-no-overwrite.zarr" + root = _open_test_zarr(zarr_out, mode="w") + root.attrs["task"] = "old" + + with pytest.raises(LocalLaunchResumeError): + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + overwrite=False, + ) + .launch_local( + name="zarr-single-attrs-no-overwrite", + num_workers=1, + rundir=str(tmp_path / "run-attrs-no-overwrite"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert dict(root.attrs) == {"task": "old"} + + def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( tmp_path: Path, ) -> None: @@ -977,8 +1027,77 @@ def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( ).take(1)[0] for store in zarr_out.glob("*.zarr") ] - assert len(rows) == 1 - np.testing.assert_allclose(rows[0]["action"], [[0.0]]) + assert sorted(float(row["action"][0][0]) for row in rows) == [0.0, 1.0] + + +def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-retry.zarr" + shard_id = "0123456789ab" + loser_worker_id = "loser" + winner_worker_id = "winner" + loser = zarr_out / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" + winner = zarr_out / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" + _write_part_zarr(loser, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) + _write_part_zarr(winner, {"data/action": np.asarray([[1.0]], dtype=np.float32)}) + + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="output already exists"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + assert loser.exists() + assert winner.exists() + + +def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-fresh-no-overwrite.zarr" + + ( + mdr.from_items( + [{"action": [[0.0]]}, {"action": [[1.0]]}], + items_per_shard=1, + ) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ) + .launch_local( + name="zarr-sharded-fresh-no-overwrite", + num_workers=2, + rundir=str(tmp_path / "run-sharded-fresh-no-overwrite"), + ) + ) + + rows = [ + mdr.read_zarr( + store, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + for store in sorted(zarr_out.glob("*.zarr")) + ] + assert len(rows) == 2 + assert sorted(float(row["action"][0][0]) for row in rows) == [0.0, 1.0] def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: @@ -1005,6 +1124,40 @@ def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_skips_mixed_empty_shards(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-mixed-empty-shards.zarr" + + ( + mdr.from_items( + [{"action": [[0.0]]}, {"action": [[1.0]]}], + items_per_shard=1, + ) + .filter(lambda row: float(row["action"][0][0]) > 0.0) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-mixed-empty-shards", + num_workers=2, + rundir=str(tmp_path / "run-mixed-empty"), + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={ + "action": "data/action", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert row["episode_ends"].tolist() == [1] + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_single_store_rejects_inconsistent_part_payloads( tmp_path: Path, ) -> None: @@ -1012,16 +1165,10 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( first_worker = "worker-a" second_worker = "worker-b" first_part = ( - zarr_out - / "_parts" - / "local" - / f"shard-a__w{worker_token_for(first_worker)}.zarr" + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" ) second_part = ( - zarr_out - / "_parts" - / "local" - / f"shard-b__w{worker_token_for(second_worker)}.zarr" + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" ) _write_part_zarr( first_part, @@ -1068,6 +1215,232 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( reduce_array_batch_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) + assert first_part.exists() + assert second_part.exists() + + +def test_write_zarr_single_store_rejects_missing_finalized_part( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-missing-part.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id="worker-a", + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="part store is missing"): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[9.0]]) + assert row["episode_ends"].tolist() == [1] + + +def test_write_zarr_single_store_completed_merge_is_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-completed-retry.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + marker = zarr_out / "_refiner" / "write_zarr.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id="worker-a", + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[9.0]]) + assert row["episode_ends"].tolist() == [1] + + +def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-resume-stable.zarr" + worker_id = "original-worker" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + part, + { + "data/action": np.asarray([[4.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="resumed-job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[4.0]]) + assert row["episode_ends"].tolist() == [1] + + +def test_write_zarr_single_store_overwrite_clears_root_attrs( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-overwrite-attrs.zarr" + root = _open_test_zarr(zarr_out, mode="w") + root.attrs["task"] = "old" + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-overwrite-attrs", + num_workers=1, + rundir=str(tmp_path / "run-overwrite-attrs"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert dict(root.attrs) == {} + + +def test_write_zarr_single_store_rejects_part_dtype_drift( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-dtype-drift.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], dtype=np.float64), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="matching dtypes"): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + assert first_part.exists() + assert second_part.exists() def test_write_zarr_rejects_rows_missing_inferred_default_arrays( @@ -1209,6 +1582,24 @@ def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> assert root["data/rgb"].chunks == (1, 4, 4, 3) +def test_write_zarr_caps_low_dimensional_initial_chunks(tmp_path: Path) -> None: + output = tmp_path / "small-array-chunks.zarr" + ZarrSink( + str(output), + arrays={"data/action": "action"}, + ).write_block( + [ + DictRow( + {"action": np.asarray([[1.0]], dtype=np.float32)}, + shard_id="shard", + ) + ] + ) + + root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") + assert root["data/action"].chunks == (1024, 1) + + def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: source = tmp_path / "source.mp4" output = tmp_path / "encoded-video.zarr" diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index ca6121d2..b4e026e8 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Any, cast import numpy as np @@ -119,7 +120,6 @@ def test_to_robot_rows_uses_video_frame_array_asset_schema() -> None: video = robotics_row.videos["camera"] assert isinstance(video, VideoFrameArray) - assert len(list(video.iter_frame_arrays())) == 2 video_frames = list(video.iter_frame_arrays()) assert len(video_frames) == 2 assert video_frames[0].shape == (4, 5, 3) diff --git a/tests/test_video_decode.py b/tests/test_video_decode.py index 2f085608..a9dba9ba 100644 --- a/tests/test_video_decode.py +++ b/tests/test_video_decode.py @@ -74,7 +74,6 @@ def test_video_frame_array_clip_returns_frame_view() -> None: clipped = video.clipped(from_timestamp_s=0.2, to_timestamp_s=0.5) assert isinstance(clipped, mdr.video.VideoFrameArray) - assert len(list(clipped.iter_frame_arrays())) == 3 clipped_frames = list(clipped.iter_frame_arrays()) assert len(clipped_frames) == 3 assert clipped_frames[0].shape == (4, 4, 3) From 7d5bb18a0d5e5ec68534cad503805c514783ffe7 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 04:25:22 +0200 Subject: [PATCH 16/39] Harden Zarr writer retry semantics --- src/refiner/pipeline/sinks/reducer/file.py | 9 +- src/refiner/pipeline/sinks/zarr.py | 191 ++++++++++++--- src/refiner/worker/lifecycle.py | 18 +- tests/pipeline/test_sinks.py | 44 ++++ tests/readers/test_zarr_reader.py | 257 ++++++++++++++++++++- tests/worker/test_runner.py | 16 +- 6 files changed, 489 insertions(+), 46 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 324d3480..445e486f 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -85,14 +85,12 @@ def __init__( reducer_name: str, assets_subdir: str | None = None, recursive: bool = False, - overwrite: bool = True, ) -> None: self.output = DataFolder.resolve(output) self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir self.recursive = recursive - self.overwrite = overwrite self._managed_path_pattern = _compile_managed_path_pattern(filename_template) self._cleanup_ran = False @@ -113,8 +111,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: args["assets_subdir"] = self.assets_subdir if self.recursive: args["recursive"] = True - if not self.overwrite: - args["overwrite"] = False return ( self.reducer_name, "writer", @@ -185,9 +181,6 @@ def _run_cleanup(self) -> None: continue stale_managed_paths.add(managed_path) - if not self.overwrite and (stale_asset_attempts or stale_managed_paths): - raise ValueError(f"{self.reducer_name} output already exists") - for path in sorted(stale_asset_attempts): try: self.output.rm(path, recursive=True) @@ -216,7 +209,7 @@ def _listed_cleanup_paths(self) -> list[str]: for path in paths: try: next_paths.extend(self.output.ls(path, detail=False)) - except FileNotFoundError: + except (FileNotFoundError, NotADirectoryError, OSError, ValueError): continue paths = next_paths return [ diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 9e894de4..afedef99 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -23,10 +23,13 @@ get_active_worker_token, get_finalized_workers, ) +from refiner.worker.lifecycle import sort_finalized_workers _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 _MAX_INITIAL_CHUNK_ROWS = 1024 _DONE_MARKER_RELPATH = "_refiner/write_zarr.done" +_PUBLISH_STARTED_MARKER_RELPATH = "_refiner/write_zarr_publish.started" +_PUBLISH_DONE_MARKER_RELPATH = "_refiner/write_zarr_publish.done" @dataclass @@ -84,6 +87,8 @@ def __init__( self._stores: dict[str, _ShardStore] = {} self._default_arrays: dict[str, str] | None = None self._checked_no_overwrite = False + self._cleared_publish_markers = False + self._cleared_merge_marker = False def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 @@ -186,6 +191,10 @@ async def _append_video( if isinstance(video, VideoFrameArray): if expected_length is not None and video.frame_count != expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") + if video.frame_count == 0: + empty = np.asarray(video.frames, dtype=np.uint8) + self._append_array(store, path, empty[:0]) + return 0 batch: list[np.ndarray] = [] batch_limit: int | None = None for frame in video.iter_frame_arrays(): @@ -224,7 +233,7 @@ async def _append_video( return count def _video_batch_limit(self, frame: np.ndarray) -> int: - return max( + return min( self.video_frame_batch_size, _batch_length_for_shape( (1, *frame.shape), frame.dtype, self.array_chunk_bytes @@ -250,6 +259,10 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: + if self.reduce_to_single_store: + self._clear_merge_marker_once() + if not self.overwrite and not self.reduce_to_single_store: + self._clear_publish_markers_once() self._check_no_overwrite_output() relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) @@ -267,30 +280,51 @@ def _store(self, shard_id: str) -> _ShardStore: return store def _check_no_overwrite_output(self) -> None: - if ( - self.overwrite - or self._checked_no_overwrite - or not self.reduce_to_single_store - ): + if self.overwrite or self._checked_no_overwrite: return self._checked_no_overwrite = True if _output_has_payload(self.output): raise ValueError("write_zarr output already exists and overwrite=False") + def _clear_publish_markers_once(self) -> None: + if self._cleared_publish_markers: + return + self._cleared_publish_markers = True + for marker in ( + _PUBLISH_STARTED_MARKER_RELPATH, + _PUBLISH_DONE_MARKER_RELPATH, + ): + try: + self.output.rm(marker) + except FileNotFoundError: + pass + + def _clear_merge_marker_once(self) -> None: + if self._cleared_merge_marker: + return + self._cleared_merge_marker = True + try: + self.output.rm(_DONE_MARKER_RELPATH) + except FileNotFoundError: + pass + def _store_relpath(self, shard_id: str) -> str: relpath = self.store_template.format( shard_id=shard_id, worker_id=get_active_worker_token(), ) - if self.reduce_to_single_store: + if self.reduce_to_single_store or not self.overwrite: return _part_store_relpath(relpath) return relpath def on_shard_complete(self, shard_id: str) -> None: - if ( - self.reduce_to_single_store - and self._store_relpath(shard_id) not in self._stores - ): + if self.reduce_to_single_store: + self._clear_merge_marker_once() + if not self.overwrite and not self.reduce_to_single_store: + self._clear_publish_markers_once() + if (self.reduce_to_single_store or not self.overwrite) and self._store_relpath( + shard_id + ) not in self._stores: with self.output.open(self._empty_marker_relpath(shard_id), mode="wb"): pass self._stores.pop(self._store_relpath(shard_id), None) @@ -359,15 +393,109 @@ def build_reducer(self) -> BaseSink | None: reduce_array_batch_bytes=self.reduce_array_batch_bytes, overwrite=self.overwrite, ) + if not self.overwrite: + return _ZarrPublishPartsReducerSink( + output=self.output, + store_template=self.store_template, + ) return FileCleanupReducerSink( output=self.output, filename_template=self.store_template, reducer_name="write_zarr_reduce", recursive=True, - overwrite=self.overwrite, ) +class _ZarrPublishPartsReducerSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + store_template: str, + ) -> None: + self.output = DataFolder.resolve(output) + self.store_template = store_template + self._published = False + + @property + def counts_output_rows(self) -> bool: + return False + + def write_shard_block(self, shard_id, block) -> None: + del shard_id, block + if self._published: + return + self._published = True + + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_publish requires an active reducer stage with a prior writer stage" + ) + + parts = [ + self.store_template.format( + shard_id=row.shard_id, worker_id=row.worker_token + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1) + ) + ] + if self.output.exists(_PUBLISH_DONE_MARKER_RELPATH): + self._remove_parts() + return + + if _output_has_existing_store(self.output): + if not self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): + self._remove_parts() + raise ValueError("write_zarr output already exists and overwrite=False") + self._remove_publish_targets(parts) + + with self.output.open(_PUBLISH_STARTED_MARKER_RELPATH, mode="wb"): + pass + + try: + for final_relpath in parts: + part_relpath = _part_store_relpath(final_relpath) + if not self.output.exists(part_relpath): + if self.output.exists(part_relpath + ".empty"): + continue + raise ValueError(f"Zarr part store is missing: {part_relpath}") + target_parent = final_relpath.rsplit("/", maxsplit=1)[0] + if target_parent != final_relpath: + self.output.makedirs(target_parent, exist_ok=True) + self.output.copy( + part_relpath, + final_relpath, + recursive=True, + on_error="raise", + ) + except Exception: + self._remove_publish_targets(parts) + raise + + with self.output.open(_PUBLISH_DONE_MARKER_RELPATH, mode="wb"): + pass + try: + self.output.rm(_PUBLISH_STARTED_MARKER_RELPATH) + except FileNotFoundError: + pass + self._remove_parts() + + def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: + for relpath in relpaths: + try: + self.output.rm(relpath, recursive=True) + except FileNotFoundError: + continue + + def _remove_parts(self) -> None: + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + + class _ZarrMergeReducerSink(BaseSink): def __init__( self, @@ -418,14 +546,12 @@ def _merge(self) -> None: ) expected_parts = self._expected_parts(stage_index) - if self.output.exists(_DONE_MARKER_RELPATH) and not any( - self.output.exists(relpath) or self.output.exists(f"{relpath}.empty") - for relpath in expected_parts - ): + if self.output.exists(_DONE_MARKER_RELPATH): + self._remove_parts() return parts = self._collect_parts(expected_parts) - if not self.overwrite and _output_has_payload(self.output): + if not self.overwrite and _output_has_existing_store(self.output): raise ValueError("write_zarr output already exists and overwrite=False") import zarr @@ -509,7 +635,7 @@ def _merge(self) -> None: def _expected_parts(self, stage_index: int) -> list[str]: return [ self._part_relpath(row.shard_id, row.worker_token) - for row in _sort_finalized_workers( + for row in sort_finalized_workers( get_finalized_workers(stage_index=stage_index - 1), ) ] @@ -645,16 +771,6 @@ def _part_store_relpath(relpath: str) -> str: return f"_parts/{relpath}" -def _sort_finalized_workers(rows: Iterable[Any]) -> list[Any]: - return sorted( - rows, - key=lambda row: ( - row.global_ordinal is None, - row.global_ordinal if row.global_ordinal is not None else row.shard_id, - ), - ) - - def _clear_final_group(group: Any) -> None: for key in sorted({*group.array_keys(), *group.group_keys()}): if key != "_parts": @@ -689,6 +805,25 @@ def _output_has_payload(output: DataFolder) -> bool: return _group_has_payload(group) +def _output_has_existing_store(output: DataFolder) -> bool: + if _output_has_payload(output): + return True + if output.exists(_DONE_MARKER_RELPATH) or output.exists( + _PUBLISH_DONE_MARKER_RELPATH + ): + return True + try: + entries = output.ls("", detail=False) + except FileNotFoundError: + return False + for entry in entries: + root = str(entry).split("/", maxsplit=1)[0] + if root in {"_parts", "_refiner"}: + continue + return True + return False + + def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: chunk_rows = min( _batch_length(array, target_bytes), diff --git a/src/refiner/worker/lifecycle.py b/src/refiner/worker/lifecycle.py index c77ddb26..2d081876 100644 --- a/src/refiner/worker/lifecycle.py +++ b/src/refiner/worker/lifecycle.py @@ -67,13 +67,16 @@ def read_finalized_workers( ), ) ) - rows.sort( - key=lambda row: ( - row.global_ordinal is None, - row.global_ordinal if row.global_ordinal is not None else row.shard_id, - ) - ) - return rows + return sort_finalized_workers(rows) + + +def sort_finalized_workers( + rows: Iterable[FinalizedShardWorker], +) -> list[FinalizedShardWorker]: + rows = list(rows) + if any(row.global_ordinal is None for row in rows): + return sorted(rows, key=lambda row: row.shard_id) + return sorted(rows, key=lambda row: row.global_ordinal) class LocalRuntimeLifecycle: @@ -135,4 +138,5 @@ def finalized_workers( "LocalRuntimeLifecycle", "RuntimeLifecycle", "read_finalized_workers", + "sort_finalized_workers", ] diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index 570f68fd..d347003f 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -1047,6 +1047,50 @@ def test_file_cleanup_reducer_removes_dynamic_nested_directories(tmp_path) -> No assert not loser_dir.exists() +def test_file_cleanup_reducer_ignores_files_during_recursive_traversal( + tmp_path, +) -> None: + output_dir = tmp_path / "zarr-cleanup-mixed" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_dir = ( + output_dir / "split" / shard_id / f"{worker_token_for(winner_worker_id)}.zarr" + ) + loser_dir = ( + output_dir / "split" / shard_id / f"{worker_token_for(loser_worker_id)}.zarr" + ) + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + (output_dir / "split" / "README.txt").write_text("notes", encoding="utf-8") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="split/{shard_id}/{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_dir.exists() + assert not loser_dir.exists() + assert (output_dir / "split" / "README.txt").read_text(encoding="utf-8") == "notes" + + def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( tmp_path, monkeypatch ) -> None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 62737308..2c23bb3a 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -989,6 +989,43 @@ def test_write_zarr_single_store_rejects_attrs_only_output_when_not_overwriting( assert dict(root.attrs) == {"task": "old"} +def test_write_zarr_single_store_rejects_empty_existing_store_when_not_overwriting( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-empty-existing-no-overwrite.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-empty-existing-first", + num_workers=1, + rundir=str(tmp_path / "run-empty-existing-first"), + ) + ) + + with pytest.raises(LocalLaunchResumeError): + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + overwrite=False, + ) + .launch_local( + name="zarr-single-empty-existing-second", + num_workers=1, + rundir=str(tmp_path / "run-empty-existing-second"), + ) + ) + + def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( tmp_path: Path, ) -> None: @@ -1027,7 +1064,8 @@ def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( ).take(1)[0] for store in zarr_out.glob("*.zarr") ] - assert sorted(float(row["action"][0][0]) for row in rows) == [0.0, 1.0] + assert sorted(float(row["action"][0][0]) for row in rows) == [0.0] + assert not (zarr_out / "_parts").exists() def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( @@ -1066,6 +1104,147 @@ def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( assert winner.exists() +def test_write_zarr_non_reduced_no_overwrite_rejects_missing_finalized_part( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-missing-part.zarr" + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id="shard-a", worker_id="worker-a")] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="part store is missing"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + +def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-empty-part.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ) + .launch_local( + name="zarr-sharded-empty-no-overwrite", + num_workers=1, + rundir=str(tmp_path / "run-sharded-empty-no-overwrite"), + ) + ) + + assert not list(zarr_out.glob("*.zarr")) + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_non_reduced_no_overwrite_retry_removes_partial_publish( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-partial-publish.zarr" + shard_id = "0123456789ab" + worker_id = "worker-a" + relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + zarr_out / "_parts" / relpath, + {"data/action": np.asarray([[1.0]], dtype=np.float32)}, + ) + _write_part_zarr( + zarr_out / relpath, + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + marker = zarr_out / "_refiner" / "write_zarr_publish.started" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out / relpath, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_non_reduced_no_overwrite_completed_publish_is_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-complete-retry.zarr" + shard_id = "0123456789ab" + worker_id = "worker-a" + relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + zarr_out / relpath, + {"data/action": np.asarray([[1.0]], dtype=np.float32)}, + ) + _write_part_zarr( + zarr_out / "_parts" / relpath, + {"data/action": np.asarray([[1.0]], dtype=np.float32)}, + ) + marker = zarr_out / "_refiner" / "write_zarr_publish.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out / relpath, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( tmp_path: Path, ) -> None: @@ -1098,6 +1277,8 @@ def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( ] assert len(rows) == 2 assert sorted(float(row["action"][0][0]) for row in rows) == [0.0, 1.0] + assert not (zarr_out / "_parts").exists() + assert not (zarr_out / "_refiner" / "write_zarr_publish.started").exists() def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: @@ -1124,6 +1305,45 @@ def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_empty_overwrite_ignores_stale_done_marker( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-empty-overwrite-stale-done.zarr" + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-stale-done-first", + num_workers=1, + rundir=str(tmp_path / "run-stale-done-first"), + ) + ) + + ( + mdr.from_items([{"action": [[2.0]]}], items_per_shard=1) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-stale-done-second", + num_workers=1, + rundir=str(tmp_path / "run-stale-done-second"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert "data/action" not in root + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_single_store_skips_mixed_empty_shards(tmp_path: Path) -> None: zarr_out = tmp_path / "single-mixed-empty-shards.zarr" @@ -1269,6 +1489,7 @@ def test_write_zarr_single_store_completed_merge_is_retryable( tmp_path: Path, ) -> None: zarr_out = tmp_path / "single-completed-retry.zarr" + worker_id = "worker-a" _write_part_zarr( zarr_out, { @@ -1276,6 +1497,13 @@ def test_write_zarr_single_store_completed_merge_is_retryable( "meta/episode_ends": np.asarray([1], dtype=np.int64), }, ) + _write_part_zarr( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) marker = zarr_out / "_refiner" / "write_zarr.done" marker.parent.mkdir(parents=True) marker.write_bytes(b"") @@ -1283,7 +1511,7 @@ def test_write_zarr_single_store_completed_merge_is_retryable( [ FinalizedShardWorker( shard_id="shard-a", - worker_id="worker-a", + worker_id=worker_id, global_ordinal=0, ) ] @@ -1311,6 +1539,7 @@ def test_write_zarr_single_store_completed_merge_is_retryable( ).take(1)[0] np.testing.assert_allclose(row["action"], [[9.0]]) assert row["episode_ends"].tolist() == [1] + assert not (zarr_out / "_parts").exists() def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: @@ -1553,6 +1782,30 @@ def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) +def test_write_zarr_materializes_empty_frame_array_videos(tmp_path: Path) -> None: + output = tmp_path / "empty-video.zarr" + frames = np.empty((0, 4, 5, 3), dtype=np.uint8) + rows = list( + mdr.from_items([{"episode_id": "episode-1", "frames": frames}]).to_robot_rows( + episode_id_key="episode_id", + action_key=None, + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + ZarrSink( + str(output), + arrays={"data/rgb": "observation.images.front"}, + ).write_block(rows) + + root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") + assert root["data/rgb"].shape == frames.shape + assert root["meta/episode_ends"][:].tolist() == [0] + + def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> None: output = tmp_path / "video-chunks.zarr" frames = np.zeros((2, 4, 4, 3), dtype=np.uint8) diff --git a/tests/worker/test_runner.py b/tests/worker/test_runner.py index bfc9ae91..c2f6249d 100644 --- a/tests/worker/test_runner.py +++ b/tests/worker/test_runner.py @@ -19,7 +19,7 @@ from refiner.pipeline.sources.readers.base import BaseReader from refiner.pipeline.data.row import DictRow, Row from refiner.worker.metrics.api import log_gauge -from refiner.worker.lifecycle import FinalizedShardWorker +from refiner.worker.lifecycle import FinalizedShardWorker, sort_finalized_workers class _FakeReader(BaseReader): @@ -71,6 +71,20 @@ def _shard(path: str, start: int, end: int) -> Shard: return Shard.from_file_parts([FilePart(path=path, start=start, end=end)]) +def test_sort_finalized_workers_uses_legacy_order_when_any_ordinal_is_missing() -> None: + rows = [ + FinalizedShardWorker("shard-c", "worker-c", global_ordinal=0), + FinalizedShardWorker("shard-a", "worker-a"), + FinalizedShardWorker("shard-b", "worker-b", global_ordinal=1), + ] + + assert [row.shard_id for row in sort_finalized_workers(rows)] == [ + "shard-a", + "shard-b", + "shard-c", + ] + + class _NoopTelemetryEmitter: def emit_user_counter(self, **kwargs) -> None: del kwargs From 71431cd5ffc435304f930d4ead9cf1a9b013241f Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 04:32:49 +0200 Subject: [PATCH 17/39] Tighten Zarr no-overwrite checks --- src/refiner/pipeline/sinks/zarr.py | 31 +++++++++++------- tests/readers/test_zarr_reader.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index afedef99..b4da44df 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -259,11 +259,11 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: - if self.reduce_to_single_store: + self._check_no_overwrite_output() + if self.reduce_to_single_store and self.overwrite: self._clear_merge_marker_once() if not self.overwrite and not self.reduce_to_single_store: self._clear_publish_markers_once() - self._check_no_overwrite_output() relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) if store is not None: @@ -283,7 +283,7 @@ def _check_no_overwrite_output(self) -> None: if self.overwrite or self._checked_no_overwrite: return self._checked_no_overwrite = True - if _output_has_payload(self.output): + if _output_has_existing_store(self.output): raise ValueError("write_zarr output already exists and overwrite=False") def _clear_publish_markers_once(self) -> None: @@ -318,7 +318,8 @@ def _store_relpath(self, shard_id: str) -> str: return relpath def on_shard_complete(self, shard_id: str) -> None: - if self.reduce_to_single_store: + self._check_no_overwrite_output() + if self.reduce_to_single_store and self.overwrite: self._clear_merge_marker_once() if not self.overwrite and not self.reduce_to_single_store: self._clear_publish_markers_once() @@ -721,13 +722,21 @@ def _validate_array_paths( def _validate_store_template(store_template: str) -> None: - fields = { - field_name - for _literal_text, field_name, _format_spec, _conversion in Formatter().parse( - store_template - ) - if field_name is not None - } + fields: set[str] = set() + for _literal_text, field_name, format_spec, conversion in Formatter().parse( + store_template + ): + if field_name is None: + continue + if conversion is not None or format_spec: + raise ValueError( + "store_template only supports plain {shard_id} and {worker_id} fields" + ) + if field_name not in {"shard_id", "worker_id"}: + raise ValueError( + "store_template only supports plain {shard_id} and {worker_id} fields" + ) + fields.add(field_name) missing_fields = {"shard_id", "worker_id"}.difference(fields) if missing_fields: raise ValueError( diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 2c23bb3a..74337999 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -851,6 +851,19 @@ def test_write_zarr_rejects_store_template_without_worker_id(tmp_path: Path) -> ZarrSink(str(tmp_path / "template.zarr"), store_template="{shard_id}.zarr") +def test_write_zarr_rejects_unsupported_store_template_fields(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="store_template only supports"): + ZarrSink( + str(tmp_path / "extra-template.zarr"), + store_template="{shard_id}__w{worker_id}__{part}.zarr", + ) + with pytest.raises(ValueError, match="store_template only supports"): + ZarrSink( + str(tmp_path / "format-template.zarr"), + store_template="{shard_id:>12}__w{worker_id}.zarr", + ) + + def test_write_zarr_rejects_invalid_reduce_batch_bytes(tmp_path: Path) -> None: with pytest.raises(ValueError, match="reduce_array_batch_bytes"): ZarrSink(str(tmp_path / "bad-batch.zarr"), reduce_array_batch_bytes=0) @@ -1153,6 +1166,45 @@ def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( assert not (zarr_out / "_parts").exists() +def test_write_zarr_non_reduced_rejects_empty_existing_output_when_not_overwriting( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-empty-existing-no-overwrite.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ) + .launch_local( + name="zarr-sharded-empty-existing-first", + num_workers=1, + rundir=str(tmp_path / "run-sharded-empty-existing-first"), + ) + ) + + with pytest.raises(LocalLaunchResumeError): + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ) + .launch_local( + name="zarr-sharded-empty-existing-second", + num_workers=1, + rundir=str(tmp_path / "run-sharded-empty-existing-second"), + ) + ) + + assert not list(zarr_out.glob("*.zarr")) + assert not (zarr_out / "_parts").exists() + + def test_write_zarr_non_reduced_no_overwrite_retry_removes_partial_publish( tmp_path: Path, ) -> None: From a3fa1673aedc5b556e936c8d46166cd5536c15b8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 04:40:56 +0200 Subject: [PATCH 18/39] Fix Zarr reducer edge cases --- src/refiner/pipeline/sinks/zarr.py | 120 ++++++++++++++++++-- tests/readers/test_zarr_reader.py | 173 +++++++++++++++++++++++++++++ 2 files changed, 283 insertions(+), 10 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index b4da44df..a6a49c2b 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -28,6 +28,7 @@ _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 _MAX_INITIAL_CHUNK_ROWS = 1024 _DONE_MARKER_RELPATH = "_refiner/write_zarr.done" +_MERGE_STARTED_MARKER_RELPATH = "_refiner/write_zarr.started" _PUBLISH_STARTED_MARKER_RELPATH = "_refiner/write_zarr_publish.started" _PUBLISH_DONE_MARKER_RELPATH = "_refiner/write_zarr_publish.done" @@ -399,13 +400,44 @@ def build_reducer(self) -> BaseSink | None: output=self.output, store_template=self.store_template, ) - return FileCleanupReducerSink( + return _ZarrCleanupReducerSink( output=self.output, - filename_template=self.store_template, + store_template=self.store_template, + ) + + +class _ZarrCleanupReducerSink(BaseSink): + def __init__(self, output: DataFolderLike, *, store_template: str) -> None: + self.output = DataFolder.resolve(output) + self.store_template = store_template + self._cleanup = FileCleanupReducerSink( + output=self.output, + filename_template=store_template, reducer_name="write_zarr_reduce", recursive=True, ) + @property + def counts_output_rows(self) -> bool: + return False + + def write_shard_block(self, shard_id, block) -> None: + self._cleanup.write_shard_block(shard_id, block) + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + relpaths = [ + self.store_template.format( + shard_id=row.shard_id, worker_id=row.worker_token + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1) + ) + ] + _validate_zarr_stores(self.output, relpaths) + class _ZarrPublishPartsReducerSink(BaseSink): def __init__( @@ -442,8 +474,17 @@ def write_shard_block(self, shard_id, block) -> None: get_finalized_workers(stage_index=stage_index - 1) ) ] + if not parts: + if _output_has_existing_store(self.output): + self._remove_parts(best_effort=True) + raise ValueError("write_zarr output already exists and overwrite=False") + with self.output.open(_PUBLISH_DONE_MARKER_RELPATH, mode="wb"): + pass + self._remove_parts(best_effort=True) + return + if self.output.exists(_PUBLISH_DONE_MARKER_RELPATH): - self._remove_parts() + self._remove_parts(best_effort=True) return if _output_has_existing_store(self.output): @@ -479,9 +520,9 @@ def write_shard_block(self, shard_id, block) -> None: pass try: self.output.rm(_PUBLISH_STARTED_MARKER_RELPATH) - except FileNotFoundError: + except (FileNotFoundError, OSError, ValueError): pass - self._remove_parts() + self._remove_parts(best_effort=True) def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: for relpath in relpaths: @@ -490,11 +531,14 @@ def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: except FileNotFoundError: continue - def _remove_parts(self) -> None: + def _remove_parts(self, *, best_effort: bool = False) -> None: try: self.output.rm("_parts", recursive=True) except FileNotFoundError: pass + except (OSError, ValueError): + if not best_effort: + raise class _ZarrMergeReducerSink(BaseSink): @@ -547,13 +591,30 @@ def _merge(self) -> None: ) expected_parts = self._expected_parts(stage_index) + if not expected_parts: + if not self.overwrite and _output_has_existing_store(self.output): + raise ValueError("write_zarr output already exists and overwrite=False") + import zarr + + final = zarr.open_group( + store=zarr_store(self.output, "", mode="a"), + mode="a", + ) + if self.overwrite: + _clear_final_group(final) + with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): + pass + self._remove_parts(best_effort=True) + return + if self.output.exists(_DONE_MARKER_RELPATH): - self._remove_parts() + self._remove_parts(best_effort=True) return parts = self._collect_parts(expected_parts) if not self.overwrite and _output_has_existing_store(self.output): - raise ValueError("write_zarr output already exists and overwrite=False") + if not self.output.exists(_MERGE_STARTED_MARKER_RELPATH): + raise ValueError("write_zarr output already exists and overwrite=False") import zarr @@ -563,6 +624,9 @@ def _merge(self) -> None: ) if self.overwrite: _clear_final_group(final) + elif not self.output.exists(_MERGE_STARTED_MARKER_RELPATH): + with self.output.open(_MERGE_STARTED_MARKER_RELPATH, mode="wb"): + pass try: row_offset = 0 @@ -626,7 +690,11 @@ def _merge(self) -> None: with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): pass - self._remove_parts() + try: + self.output.rm(_MERGE_STARTED_MARKER_RELPATH) + except (FileNotFoundError, OSError, ValueError): + pass + self._remove_parts(best_effort=True) try: if not self.output.ls("_parts"): self.output.rmdir("_parts") @@ -691,11 +759,14 @@ def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: parts.append(_PartStore(relpath=relpath, paths=source_paths)) return parts - def _remove_parts(self) -> None: + def _remove_parts(self, *, best_effort: bool = False) -> None: try: self.output.rm("_parts", recursive=True) except FileNotFoundError: pass + except (OSError, ValueError): + if not best_effort: + raise def _default_robotics_arrays(row: Row) -> dict[str, str]: @@ -780,6 +851,35 @@ def _part_store_relpath(relpath: str) -> str: return f"_parts/{relpath}" +def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: + import zarr + + payload_paths: set[str] | None = None + schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + for relpath in relpaths: + if not output.exists(relpath): + continue + source = zarr.open_group( + store=zarr_store(output, relpath, mode="r"), + mode="r", + ) + source_paths = set(iter_zarr_array_paths(source)) + if payload_paths is None: + payload_paths = source_paths + elif source_paths != payload_paths: + raise ValueError("Zarr stores must contain the same arrays") + for path in source_paths: + source_array = source[path] + schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) + previous = schemas.setdefault(path, schema) + if previous != schema: + if previous[0] != schema[0]: + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + + def _clear_final_group(group: Any) -> None: for key in sorted({*group.array_keys(), *group.group_keys()}): if key != "_parts": diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 74337999..8861bc5a 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1333,6 +1333,60 @@ def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( assert not (zarr_out / "_refiner" / "write_zarr_publish.started").exists() +def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-schema-drift.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + first = zarr_out / f"shard-a__w{worker_token_for(first_worker)}.zarr" + second = zarr_out / f"shard-b__w{worker_token_for(second_worker)}.zarr" + _write_part_zarr( + first, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second, + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "data/state": np.asarray([[2.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="same arrays"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: zarr_out = tmp_path / "single-empty-shards.zarr" @@ -1594,6 +1648,125 @@ def test_write_zarr_single_store_completed_merge_is_retryable( assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_zero_shard_overwrite_ignores_stale_done_marker( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-zero-shard-overwrite.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + marker = zarr_out / "_refiner" / "write_zarr.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime([]) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + assert "data/action" not in root + + +def test_write_zarr_single_store_zero_shard_no_overwrite_rejects_existing_output( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-zero-shard-no-overwrite.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + marker = zarr_out / "_refiner" / "write_zarr.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime([]) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="output already exists"): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=False, + ).write_block([DictRow({}, shard_id="reduce")]) + + +def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-no-overwrite-started-retry.zarr" + worker_id = "worker-a" + root = _open_test_zarr(zarr_out, mode="w") + assert dict(root.attrs) == {} + _write_part_zarr( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", + { + "data/action": np.asarray([[3.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + marker = zarr_out / "_refiner" / "write_zarr.started" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=False, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[3.0]]) + assert row["episode_ends"].tolist() == [1] + + def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: zarr_out = tmp_path / "single-resume-stable.zarr" worker_id = "original-worker" From 942392b02bffb9cc7ac8abb7c5728c199707ce17 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 04:46:56 +0200 Subject: [PATCH 19/39] Guard Zarr reducer reserved paths --- src/refiner/pipeline/sinks/zarr.py | 16 +++- tests/readers/test_zarr_reader.py | 116 +++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index a6a49c2b..ec79bafe 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -497,6 +497,9 @@ def write_shard_block(self, shard_id, block) -> None: pass try: + _validate_zarr_stores( + self.output, (_part_store_relpath(relpath) for relpath in parts) + ) for final_relpath in parts: part_relpath = _part_store_relpath(final_relpath) if not self.output.exists(part_relpath): @@ -624,7 +627,9 @@ def _merge(self) -> None: ) if self.overwrite: _clear_final_group(final) - elif not self.output.exists(_MERGE_STARTED_MARKER_RELPATH): + elif self.output.exists(_MERGE_STARTED_MARKER_RELPATH): + _clear_final_group(final) + else: with self.output.open(_MERGE_STARTED_MARKER_RELPATH, mode="wb"): pass @@ -786,6 +791,8 @@ def _validate_array_paths( arrays: Mapping[str, str], episode_ends_path: str | None, ) -> None: + for path in arrays: + _validate_public_zarr_path(path, "Zarr array path") if episode_ends_path is not None and episode_ends_path in arrays: raise ValueError( f"Zarr array path collides with episode_ends_path: {episode_ends_path}" @@ -793,6 +800,7 @@ def _validate_array_paths( def _validate_store_template(store_template: str) -> None: + _validate_public_zarr_path(store_template, "store_template") fields: set[str] = set() for _literal_text, field_name, format_spec, conversion in Formatter().parse( store_template @@ -816,6 +824,12 @@ def _validate_store_template(store_template: str) -> None: ) +def _validate_public_zarr_path(path: str, label: str) -> None: + root = str(path).lstrip("/").split("/", maxsplit=1)[0] + if root in {"_parts", "_refiner"}: + raise ValueError(f"{label} must not use reserved root: {root}") + + def _row_value(row: Row, key: str) -> Any: if isinstance(row, RoboticsRow): if key == "action": diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 8861bc5a..79a50d0c 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -864,6 +864,19 @@ def test_write_zarr_rejects_unsupported_store_template_fields(tmp_path: Path) -> ) +def test_write_zarr_rejects_reserved_paths(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="reserved root"): + ZarrSink( + str(tmp_path / "reserved-array.zarr"), + arrays={"_parts/action": "action"}, + ) + with pytest.raises(ValueError, match="reserved root"): + ZarrSink( + str(tmp_path / "reserved-template.zarr"), + store_template="_refiner/{shard_id}__w{worker_id}.zarr", + ) + + def test_write_zarr_rejects_invalid_reduce_batch_bytes(tmp_path: Path) -> None: with pytest.raises(ValueError, match="reduce_array_batch_bytes"): ZarrSink(str(tmp_path / "bad-batch.zarr"), reduce_array_batch_bytes=0) @@ -1387,6 +1400,56 @@ def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( reducer.write_block([DictRow({}, shard_id="reduce")]) +def test_write_zarr_no_overwrite_rejects_part_schema_drift_before_publish( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-schema-drift.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + _write_part_zarr( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr", + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + _write_part_zarr( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr", + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "data/state": np.asarray([[2.0]], dtype=np.float32), + }, + ) + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="same arrays"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + assert not list(zarr_out.glob("*.zarr")) + + def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: zarr_out = tmp_path / "single-empty-shards.zarr" @@ -1767,6 +1830,59 @@ def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( assert row["episode_ends"].tolist() == [1] +def test_write_zarr_single_store_no_overwrite_started_merge_replaces_partial_output( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-no-overwrite-started-partial.zarr" + worker_id = "worker-a" + _write_part_zarr( + zarr_out, + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + _write_part_zarr( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", + { + "data/action": np.asarray([[3.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + marker = zarr_out / "_refiner" / "write_zarr.started" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=False, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[3.0]]) + assert row["episode_ends"].tolist() == [1] + + def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: zarr_out = tmp_path / "single-resume-stable.zarr" worker_id = "original-worker" From 8a98fb44cddb0d19533a6ee8f05a045238bd5782 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 06:13:32 +0200 Subject: [PATCH 20/39] Harden Zarr writer retries --- src/refiner/pipeline/sinks/reducer/file.py | 10 +- src/refiner/pipeline/sinks/zarr.py | 302 ++++++++++++++++----- tests/pipeline/test_sinks.py | 67 +++++ tests/readers/test_zarr_reader.py | 297 +++++++++++++++++++- 4 files changed, 605 insertions(+), 71 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 445e486f..5d67fa69 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -166,9 +166,13 @@ def _run_cleanup(self) -> None: continue managed_path = rel_path + marker_path = None + if rel_path.endswith(".empty"): + managed_path = rel_path[: -len(".empty")] + marker_path = rel_path match = self._managed_path_pattern.fullmatch(managed_path) if match is None and self.recursive: - parts = rel_path.split("/") + parts = managed_path.split("/") for index in range(1, len(parts)): candidate = "/".join(parts[:index]) match = self._managed_path_pattern.fullmatch(candidate) @@ -179,7 +183,7 @@ def _run_cleanup(self) -> None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue - stale_managed_paths.add(managed_path) + stale_managed_paths.add(marker_path or managed_path) for path in sorted(stale_asset_attempts): try: @@ -209,7 +213,7 @@ def _listed_cleanup_paths(self) -> list[str]: for path in paths: try: next_paths.extend(self.output.ls(path, detail=False)) - except (FileNotFoundError, NotADirectoryError, OSError, ValueError): + except (FileNotFoundError, NotADirectoryError): continue paths = next_paths return [ diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index ec79bafe..dda4e393 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -93,31 +93,120 @@ def __init__( def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 + pending_store: _ShardStore | None = None + pending_arrays: dict[str, list[np.ndarray]] = {} + pending_lengths: list[int] = [] + pending_bytes = 0 + + def flush_pending() -> None: + nonlocal pending_store, pending_arrays, pending_lengths, pending_bytes + if pending_store is None or not pending_arrays: + return + store = pending_store + rollback_lengths: dict[str, int | None] = {} + previous_row_end = store.row_end + for zarr_path in pending_arrays: + dataset = store.arrays.get(zarr_path) + rollback_lengths[zarr_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) + if self.episode_ends_path is not None: + dataset = store.arrays.get(self.episode_ends_path) + rollback_lengths[self.episode_ends_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) + try: + combined = { + zarr_path: ( + arrays[0] if len(arrays) == 1 else np.concatenate(arrays) + ) + for zarr_path, arrays in pending_arrays.items() + } + for zarr_path, array in combined.items(): + self._validate_array_append(store, zarr_path, array) + for zarr_path, array in combined.items(): + self._append_array(store, zarr_path, array) + if self.episode_ends_path is not None: + row_ends = ( + np.cumsum(np.asarray(pending_lengths, dtype=np.int64)) + + store.row_end + ) + self._append_array(store, self.episode_ends_path, row_ends) + store.row_end = int(row_ends[-1]) + except Exception: + for zarr_path, length in rollback_lengths.items(): + if length is None: + self._drop_array(store, zarr_path) + continue + dataset = store.arrays.get(zarr_path) + if dataset is not None: + dataset.resize((length, *dataset.shape[1:])) + store.row_end = previous_row_end + raise + finally: + pending_store = None + pending_arrays = {} + pending_lengths = [] + pending_bytes = 0 + for row in block: - self._write_row(shard_id, row) + try: + arrays = self._arrays_for_row(row) + row_arrays, row_videos, lengths = self._row_values(row, arrays) + except Exception: + flush_pending() + raise + + if row_videos: + flush_pending() + self._write_row_values(shard_id, row_arrays, row_videos, lengths) + count += 1 + continue + + if lengths: + length = lengths[0] + if any(item != length for item in lengths): + flush_pending() + raise ValueError( + "Zarr arrays for one row must have matching lengths" + ) + row_bytes = sum(array.nbytes for array in row_arrays.values()) + if ( + pending_arrays + and pending_bytes + row_bytes > self.array_chunk_bytes + ): + flush_pending() + if pending_arrays and len(pending_lengths) >= _MAX_INITIAL_CHUNK_ROWS: + flush_pending() + if pending_arrays and set(row_arrays) != set(pending_arrays): + flush_pending() + if pending_arrays and any( + pending_arrays[zarr_path][0].shape[1:] != array.shape[1:] + or pending_arrays[zarr_path][0].dtype != array.dtype + for zarr_path, array in row_arrays.items() + ): + flush_pending() + store = self._store(shard_id) + if pending_store is None: + pending_store = store + for zarr_path, array in row_arrays.items(): + pending_arrays.setdefault(zarr_path, []).append(array) + pending_lengths.append(length) + pending_bytes += row_bytes count += 1 + flush_pending() return count - def _write_row(self, shard_id: str, row: Row) -> None: - arrays = self._arrays_for_row(row) - row_arrays: dict[str, np.ndarray] = {} - row_videos: list[tuple[str, VideoSource]] = [] - lengths: list[int] = [] + def _write_row_values( + self, + shard_id: str, + row_arrays: dict[str, np.ndarray], + row_videos: list[tuple[str, VideoSource]], + lengths: list[int], + ) -> None: store: _ShardStore | None = None - for zarr_path, source_key in arrays.items(): - value = _row_value(row, source_key) - if value is None: - raise ValueError(f"Zarr source value is missing: {source_key}") - if store is None: - store = self._store(shard_id) - if isinstance(value, VideoSource): - row_videos.append((zarr_path, value)) - continue - array = _as_array(value) - if array.ndim == 0: - array = array.reshape(1) - lengths.append(int(array.shape[0])) - row_arrays[zarr_path] = array + if row_arrays or row_videos: + store = self._store(shard_id) if not lengths: expected_length = None else: @@ -179,7 +268,28 @@ def _write_row(self, shard_id: str, row: Row) -> None: else int(dataset[-1]) ) raise - return + + def _row_values( + self, + row: Row, + arrays: Mapping[str, str], + ) -> tuple[dict[str, np.ndarray], list[tuple[str, VideoSource]], list[int]]: + row_arrays: dict[str, np.ndarray] = {} + row_videos: list[tuple[str, VideoSource]] = [] + lengths: list[int] = [] + for zarr_path, source_key in arrays.items(): + value = _row_value(row, source_key) + if value is None: + raise ValueError(f"Zarr source value is missing: {source_key}") + if isinstance(value, VideoSource): + row_videos.append((zarr_path, value)) + continue + array = _as_array(value) + if array.ndim == 0: + array = array.reshape(1) + lengths.append(int(array.shape[0])) + row_arrays[zarr_path] = array + return row_arrays, row_videos, lengths async def _append_video( self, @@ -203,10 +313,20 @@ async def _append_video( if batch_limit is None: batch_limit = self._video_batch_limit(frame) if len(batch) >= batch_limit: - self._append_array(store, path, np.stack(batch, axis=0)) + self._append_array( + store, + path, + np.stack(batch, axis=0), + chunks=(batch_limit, *frame.shape), + ) batch.clear() if batch: - self._append_array(store, path, np.stack(batch, axis=0)) + self._append_array( + store, + path, + np.stack(batch, axis=0), + chunks=(batch_limit or len(batch), *batch[0].shape), + ) return video.frame_count batch: list[np.ndarray] = [] @@ -221,13 +341,23 @@ async def _append_video( raise ValueError( "Zarr arrays for one row must have matching lengths" ) - self._append_array(store, path, np.stack(batch, axis=0)) + self._append_array( + store, + path, + np.stack(batch, axis=0), + chunks=(batch_limit, *batch[0].shape), + ) count += len(batch) batch.clear() if batch: if expected_length is not None and count + len(batch) > expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") - self._append_array(store, path, np.stack(batch, axis=0)) + self._append_array( + store, + path, + np.stack(batch, axis=0), + chunks=(batch_limit or len(batch), *batch[0].shape), + ) count += len(batch) if expected_length is not None and count != expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") @@ -269,7 +399,7 @@ def _store(self, shard_id: str) -> _ShardStore: store = self._stores.get(relpath) if store is not None: return store - mode = "w" if self.overwrite else "w-" + mode = "w" if self.overwrite or relpath.startswith("_parts/") else "w-" import zarr store = _ShardStore( @@ -277,6 +407,10 @@ def _store(self, shard_id: str) -> _ShardStore: store=zarr_store(self.output, relpath, mode=mode), mode=mode ) ) + try: + self.output.rm(self._empty_marker_relpath(shard_id)) + except FileNotFoundError: + pass self._stores[relpath] = store return store @@ -314,6 +448,7 @@ def _store_relpath(self, shard_id: str) -> str: shard_id=shard_id, worker_id=get_active_worker_token(), ) + _validate_public_zarr_path(relpath, "rendered store path") if self.reduce_to_single_store or not self.overwrite: return _part_store_relpath(relpath) return relpath @@ -324,12 +459,15 @@ def on_shard_complete(self, shard_id: str) -> None: self._clear_merge_marker_once() if not self.overwrite and not self.reduce_to_single_store: self._clear_publish_markers_once() - if (self.reduce_to_single_store or not self.overwrite) and self._store_relpath( - shard_id - ) not in self._stores: + relpath = self._store_relpath(shard_id) + if relpath not in self._stores: + try: + self.output.rm(relpath, recursive=True) + except FileNotFoundError: + pass with self.output.open(self._empty_marker_relpath(shard_id), mode="wb"): pass - self._stores.pop(self._store_relpath(shard_id), None) + self._stores.pop(relpath, None) def _empty_marker_relpath(self, shard_id: str) -> str: return self._store_relpath(shard_id) + ".empty" @@ -339,13 +477,15 @@ def _append_array( store: _ShardStore, path: str, array: np.ndarray, + *, + chunks: tuple[int, ...] | None = None, ) -> None: _append_zarr_array( store.root, store.arrays, path, array, - chunks=_chunk_shape(array, self.array_chunk_bytes), + chunks=chunks or _chunk_shape(array, self.array_chunk_bytes), ) def _validate_array_append( @@ -437,6 +577,11 @@ def write_shard_block(self, shard_id, block) -> None: ) ] _validate_zarr_stores(self.output, relpaths) + for relpath in relpaths: + try: + self.output.rm(f"{relpath}.empty") + except FileNotFoundError: + pass class _ZarrPublishPartsReducerSink(BaseSink): @@ -475,21 +620,31 @@ def write_shard_block(self, shard_id, block) -> None: ) ] if not parts: + if self.output.exists( + _PUBLISH_DONE_MARKER_RELPATH + ) and not _output_has_payload(self.output): + _remove_parts(self.output, best_effort=True) + return if _output_has_existing_store(self.output): - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) raise ValueError("write_zarr output already exists and overwrite=False") with self.output.open(_PUBLISH_DONE_MARKER_RELPATH, mode="wb"): pass - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) return if self.output.exists(_PUBLISH_DONE_MARKER_RELPATH): - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) return - if _output_has_existing_store(self.output): + has_existing_output = _output_has_existing_store(self.output) + if has_existing_output and self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): + _validate_zarr_stores( + self.output, (_part_store_relpath(relpath) for relpath in parts) + ) + if has_existing_output: if not self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): - self._remove_parts() + _remove_parts(self.output) raise ValueError("write_zarr output already exists and overwrite=False") self._remove_publish_targets(parts) @@ -525,7 +680,7 @@ def write_shard_block(self, shard_id, block) -> None: self.output.rm(_PUBLISH_STARTED_MARKER_RELPATH) except (FileNotFoundError, OSError, ValueError): pass - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: for relpath in relpaths: @@ -534,15 +689,6 @@ def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: except FileNotFoundError: continue - def _remove_parts(self, *, best_effort: bool = False) -> None: - try: - self.output.rm("_parts", recursive=True) - except FileNotFoundError: - pass - except (OSError, ValueError): - if not best_effort: - raise - class _ZarrMergeReducerSink(BaseSink): def __init__( @@ -595,6 +741,11 @@ def _merge(self) -> None: expected_parts = self._expected_parts(stage_index) if not expected_parts: + if self.output.exists(_DONE_MARKER_RELPATH) and not _output_has_payload( + self.output + ): + _remove_parts(self.output, best_effort=True) + return if not self.overwrite and _output_has_existing_store(self.output): raise ValueError("write_zarr output already exists and overwrite=False") import zarr @@ -607,11 +758,11 @@ def _merge(self) -> None: _clear_final_group(final) with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): pass - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) return if self.output.exists(_DONE_MARKER_RELPATH): - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) return parts = self._collect_parts(expected_parts) @@ -699,7 +850,7 @@ def _merge(self) -> None: self.output.rm(_MERGE_STARTED_MARKER_RELPATH) except (FileNotFoundError, OSError, ValueError): pass - self._remove_parts(best_effort=True) + _remove_parts(self.output, best_effort=True) try: if not self.output.ls("_parts"): self.output.rmdir("_parts") @@ -715,12 +866,12 @@ def _expected_parts(self, stage_index: int) -> list[str]: ] def _part_relpath(self, shard_id: str, worker_token: str) -> str: - return _part_store_relpath( - self.store_template.format( - shard_id=shard_id, - worker_id=worker_token, - ) + relpath = self.store_template.format( + shard_id=shard_id, + worker_id=worker_token, ) + _validate_public_zarr_path(relpath, "rendered store path") + return _part_store_relpath(relpath) def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: import zarr @@ -743,6 +894,14 @@ def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: source_payload_paths = { path for path in source_paths if path != self.episode_ends_path } + if ( + self.episode_ends_path is not None + and source_payload_paths + and self.episode_ends_path not in source_paths + ): + raise ValueError( + f"Zarr part stores must contain {self.episode_ends_path!r}" + ) if payload_paths is None: payload_paths = source_payload_paths elif source_payload_paths != payload_paths: @@ -764,15 +923,6 @@ def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: parts.append(_PartStore(relpath=relpath, paths=source_paths)) return parts - def _remove_parts(self, *, best_effort: bool = False) -> None: - try: - self.output.rm("_parts", recursive=True) - except FileNotFoundError: - pass - except (OSError, ValueError): - if not best_effort: - raise - def _default_robotics_arrays(row: Row) -> dict[str, str]: if not isinstance(row, RoboticsRow): @@ -793,6 +943,8 @@ def _validate_array_paths( ) -> None: for path in arrays: _validate_public_zarr_path(path, "Zarr array path") + if episode_ends_path is not None: + _validate_public_zarr_path(episode_ends_path, "episode_ends_path") if episode_ends_path is not None and episode_ends_path in arrays: raise ValueError( f"Zarr array path collides with episode_ends_path: {episode_ends_path}" @@ -825,7 +977,13 @@ def _validate_store_template(store_template: str) -> None: def _validate_public_zarr_path(path: str, label: str) -> None: - root = str(path).lstrip("/").split("/", maxsplit=1)[0] + path = str(path) + if path.startswith("/"): + raise ValueError(f"{label} must be relative") + parts = [part for part in path.split("/") if part] + if any(part in {".", ".."} for part in parts): + raise ValueError(f"{label} must not contain '.' or '..' segments") + root = parts[0] if parts else "" if root in {"_parts", "_refiner"}: raise ValueError(f"{label} must not use reserved root: {root}") @@ -865,6 +1023,16 @@ def _part_store_relpath(relpath: str) -> str: return f"_parts/{relpath}" +def _remove_parts(output: DataFolder, *, best_effort: bool = False) -> None: + try: + output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + except (OSError, ValueError): + if not best_effort: + raise + + def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: import zarr @@ -872,7 +1040,9 @@ def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} for relpath in relpaths: if not output.exists(relpath): - continue + if output.exists(f"{relpath}.empty"): + continue + raise ValueError(f"Zarr store is missing: {relpath}") source = zarr.open_group( store=zarr_store(output, relpath, mode="r"), mode="r", @@ -941,7 +1111,7 @@ def _output_has_existing_store(output: DataFolder) -> bool: return False for entry in entries: root = str(entry).split("/", maxsplit=1)[0] - if root in {"_parts", "_refiner"}: + if root in {"_parts", "_refiner", ".zgroup", ".zattrs", ".zmetadata"}: continue return True return False diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index d347003f..cb396ba4 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -965,6 +965,45 @@ def test_file_cleanup_reducer_removes_non_finalized_directories(tmp_path) -> Non assert not loser_dir.exists() +def test_file_cleanup_reducer_removes_non_finalized_empty_markers(tmp_path) -> None: + output_dir = tmp_path / "zarr-cleanup-empty-markers" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_marker = ( + output_dir / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr.empty" + ) + loser_marker = ( + output_dir / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr.empty" + ) + output_dir.mkdir(parents=True) + winner_marker.write_bytes(b"") + loser_marker.write_bytes(b"") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_marker.exists() + assert not loser_marker.exists() + + def test_file_cleanup_reducer_removes_non_finalized_nested_directories( tmp_path, ) -> None: @@ -1091,6 +1130,34 @@ def test_file_cleanup_reducer_ignores_files_during_recursive_traversal( assert (output_dir / "split" / "README.txt").read_text(encoding="utf-8") == "notes" +def test_file_cleanup_reducer_propagates_recursive_listing_errors( + tmp_path, monkeypatch +) -> None: + output_dir = tmp_path / "zarr-cleanup-list-error" + output_dir.mkdir() + reducer = FileCleanupReducerSink( + output_dir, + filename_template="{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + recursive=True, + ) + + def fail_ls(*_args, **_kwargs): + raise OSError("list failed") + + monkeypatch.setattr(reducer.output, "ls", fail_ls) + + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), + ): + with pytest.raises(OSError, match="list failed"): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( tmp_path, monkeypatch ) -> None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 79a50d0c..f6e36a6b 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -17,7 +17,11 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor -from refiner.pipeline.sinks.zarr import _ZarrMergeReducerSink, ZarrSink +from refiner.pipeline.sinks.zarr import ( + _ZarrCleanupReducerSink, + _ZarrMergeReducerSink, + ZarrSink, +) from refiner.worker.context import set_active_run_context, worker_token_for from refiner.worker.lifecycle import FinalizedShardWorker, RuntimeLifecycle @@ -864,12 +868,51 @@ def test_write_zarr_rejects_unsupported_store_template_fields(tmp_path: Path) -> ) +def test_write_zarr_rejects_path_traversal(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="must not contain"): + ZarrSink( + str(tmp_path / "escape.zarr"), + store_template="../escape/{shard_id}__w{worker_id}.zarr", + ) + with pytest.raises(ValueError, match="must be relative"): + ZarrSink( + str(tmp_path / "absolute.zarr"), + store_template="/tmp/{shard_id}__w{worker_id}.zarr", + ) + with pytest.raises(ValueError, match="must not contain"): + ZarrSink( + str(tmp_path / "array-escape.zarr"), + arrays={"../action": "action"}, + ) + + +def test_write_zarr_rejects_rendered_path_traversal(tmp_path: Path) -> None: + with set_active_run_context( + job_id="local", + stage_index=0, + worker_id="worker-a", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), + ): + with pytest.raises(ValueError, match="must not contain"): + ZarrSink( + str(tmp_path / "rendered-escape.zarr"), + arrays={"data/action": "action"}, + ).write_block([DictRow({"action": [[1.0]]}, shard_id="../escape")]) + + def test_write_zarr_rejects_reserved_paths(tmp_path: Path) -> None: with pytest.raises(ValueError, match="reserved root"): ZarrSink( str(tmp_path / "reserved-array.zarr"), arrays={"_parts/action": "action"}, ) + with pytest.raises(ValueError, match="reserved root"): + ZarrSink( + str(tmp_path / "reserved-episode-ends.zarr"), + arrays={"data/action": "action"}, + episode_ends_path="_parts/episode_ends", + ) with pytest.raises(ValueError, match="reserved root"): ZarrSink( str(tmp_path / "reserved-template.zarr"), @@ -1151,10 +1194,55 @@ def test_write_zarr_non_reduced_no_overwrite_rejects_missing_finalized_part( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - with pytest.raises(ValueError, match="part store is missing"): + with pytest.raises(ValueError, match="Zarr store is missing"): reducer.write_block([DictRow({}, shard_id="reduce")]) +def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-missing-finalized-store.zarr" + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id="shard-a", worker_id="worker-a")] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="Zarr store is missing"): + _ZarrCleanupReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + ).write_block([DictRow({}, shard_id="reduce")]) + + +def test_write_zarr_empty_shard_completion_removes_stale_store( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "empty-shard-removes-stale-store.zarr" + worker_id = "worker-a" + stale = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr(stale, {"data/action": np.asarray([[9.0]], dtype=np.float32)}) + + with set_active_run_context( + job_id="local", + stage_index=0, + worker_id=worker_id, + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), + ): + ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + ).on_shard_complete("shard-a") + + assert not stale.exists() + assert stale.with_name(stale.name + ".empty").exists() + + def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( tmp_path: Path, ) -> None: @@ -1179,6 +1267,30 @@ def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( assert not (zarr_out / "_parts").exists() +def test_write_zarr_non_reduced_no_overwrite_empty_publish_is_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-empty-publish-retry.zarr" + marker = zarr_out / "_refiner" / "write_zarr_publish.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), + ): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_non_reduced_rejects_empty_existing_output_when_not_overwriting( tmp_path: Path, ) -> None: @@ -1264,6 +1376,78 @@ def test_write_zarr_non_reduced_no_overwrite_retry_removes_partial_publish( assert not (zarr_out / "_parts").exists() +def test_write_zarr_non_reduced_no_overwrite_retry_keeps_final_when_part_missing( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-retry-missing-part.zarr" + shard_id = "0123456789ab" + worker_id = "worker-a" + relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + zarr_out / relpath, + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + marker = zarr_out / "_refiner" / "write_zarr_publish.started" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="Zarr store is missing"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out / relpath, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[0.0]]) + + +def test_write_zarr_non_reduced_no_overwrite_replaces_stale_staging_part( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-stale-staging.zarr" + shard_id = "shard-a" + worker_id = "worker-a" + part = zarr_out / "_parts" / f"{shard_id}__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr(part, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) + + with set_active_run_context( + job_id="local", + stage_index=0, + worker_id=worker_id, + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), + ): + ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).write_block([DictRow({"action": [[1.0]]}, shard_id=shard_id)]) + + row = mdr.read_zarr( + part, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + + def test_write_zarr_non_reduced_no_overwrite_completed_publish_is_retryable( tmp_path: Path, ) -> None: @@ -1608,6 +1792,42 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( assert second_part.exists() +def test_write_zarr_single_store_rejects_part_missing_episode_ends( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-missing-episode-ends.zarr" + worker = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker)}.zarr" + _write_part_zarr( + part, + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker, + global_ordinal=0, + ) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="meta/episode_ends"): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_single_store_rejects_missing_finalized_part( tmp_path: Path, ) -> None: @@ -1779,6 +1999,31 @@ def test_write_zarr_single_store_zero_shard_no_overwrite_rejects_existing_output ).write_block([DictRow({}, shard_id="reduce")]) +def test_write_zarr_single_store_zero_shard_no_overwrite_done_marker_is_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-zero-shard-no-overwrite-done-retry.zarr" + marker = zarr_out / "_refiner" / "write_zarr.done" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + runtime = _FinalizedWorkersRuntime([]) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=False, + ).write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( tmp_path: Path, ) -> None: @@ -1830,6 +2075,54 @@ def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( assert row["episode_ends"].tolist() == [1] +def test_write_zarr_single_store_no_overwrite_recovers_empty_root_without_marker( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-no-overwrite-empty-root-retry.zarr" + worker_id = "worker-a" + root = _open_test_zarr(zarr_out, mode="w") + assert list(root.array_keys()) == [] + _write_part_zarr( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", + { + "data/action": np.asarray([[3.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrMergeReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + reduce_array_batch_bytes=1024, + overwrite=False, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[3.0]]) + assert row["episode_ends"].tolist() == [1] + + def test_write_zarr_single_store_no_overwrite_started_merge_replaces_partial_output( tmp_path: Path, ) -> None: From 1124099ed242c5dc337f603d2f590749fe891cce Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 06:29:42 +0200 Subject: [PATCH 21/39] Close Zarr writer retry gaps --- src/refiner/pipeline/sinks/zarr.py | 33 ++++++++++++---- tests/readers/test_zarr_reader.py | 63 ++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index dda4e393..8be7eefc 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -359,6 +359,8 @@ async def _append_video( chunks=(batch_limit or len(batch), *batch[0].shape), ) count += len(batch) + if count == 0: + raise ValueError("Zarr video source produced no frames") if expected_length is not None and count != expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") return count @@ -444,11 +446,11 @@ def _clear_merge_marker_once(self) -> None: pass def _store_relpath(self, shard_id: str) -> str: - relpath = self.store_template.format( + relpath = _render_store_relpath( + self.store_template, shard_id=shard_id, worker_id=get_active_worker_token(), ) - _validate_public_zarr_path(relpath, "rendered store path") if self.reduce_to_single_store or not self.overwrite: return _part_store_relpath(relpath) return relpath @@ -569,8 +571,10 @@ def write_shard_block(self, shard_id, block) -> None: "write_zarr_reduce requires an active reducer stage with a prior writer stage" ) relpaths = [ - self.store_template.format( - shard_id=row.shard_id, worker_id=row.worker_token + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, ) for row in sort_finalized_workers( get_finalized_workers(stage_index=stage_index - 1) @@ -612,8 +616,10 @@ def write_shard_block(self, shard_id, block) -> None: ) parts = [ - self.store_template.format( - shard_id=row.shard_id, worker_id=row.worker_token + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, ) for row in sort_finalized_workers( get_finalized_workers(stage_index=stage_index - 1) @@ -866,11 +872,11 @@ def _expected_parts(self, stage_index: int) -> list[str]: ] def _part_relpath(self, shard_id: str, worker_token: str) -> str: - relpath = self.store_template.format( + relpath = _render_store_relpath( + self.store_template, shard_id=shard_id, worker_id=worker_token, ) - _validate_public_zarr_path(relpath, "rendered store path") return _part_store_relpath(relpath) def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: @@ -976,6 +982,17 @@ def _validate_store_template(store_template: str) -> None: ) +def _render_store_relpath( + store_template: str, + *, + shard_id: str, + worker_id: str, +) -> str: + relpath = store_template.format(shard_id=shard_id, worker_id=worker_id) + _validate_public_zarr_path(relpath, "rendered store path") + return relpath + + def _validate_public_zarr_path(path: str, label: str) -> None: path = str(path) if path.startswith("/"): diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index f6e36a6b..43acfb5b 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -37,6 +37,22 @@ def finalized_workers( return self._rows +class _EmptyVideoSource: + def clipped(self, **_kwargs): + return self + + async def iter_frames(self): + if False: + yield None + + async def iter_frame_windows(self, **_kwargs): + if False: + yield None + + async def write_to(self, writer, **_kwargs): + raise NotImplementedError + + def _open_test_zarr(path: Path, *, mode: Literal["r", "r+", "a", "w", "w-"]): kwargs: dict[str, Any] = {"mode": mode, "zarr_format": 2} try: @@ -1198,6 +1214,31 @@ def test_write_zarr_non_reduced_no_overwrite_rejects_missing_finalized_part( reducer.write_block([DictRow({}, shard_id="reduce")]) +def test_write_zarr_non_reduced_no_overwrite_rejects_unsafe_finalized_part_path( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-unsafe-finalized.zarr" + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id="../escape", worker_id="worker-a")] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="must not contain"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( tmp_path: Path, ) -> None: @@ -2530,6 +2571,28 @@ def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) +def test_write_zarr_rejects_empty_encoded_video_source(tmp_path: Path) -> None: + output = tmp_path / "empty-encoded-video.zarr" + + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "clip": _EmptyVideoSource()}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key=None, + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "clip"}, + ) + ) + + with pytest.raises(ValueError, match="produced no frames"): + ZarrSink( + str(output), + arrays={"data/rgb": "observation.images.front"}, + ).write_block(rows) + + def test_write_zarr_rejects_video_length_mismatch_before_final_append( tmp_path: Path, ) -> None: From 45cf8e95c0c7a5148d004ca06d3e5fdf67dbb876 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 07:02:58 +0200 Subject: [PATCH 22/39] Harden Zarr overwrite cleanup --- src/refiner/io/zarr.py | 2 +- src/refiner/pipeline/sinks/zarr.py | 43 +++++-- tests/readers/test_zarr_reader.py | 180 +++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 10 deletions(-) diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py index 90312f42..c1918438 100644 --- a/src/refiner/io/zarr.py +++ b/src/refiner/io/zarr.py @@ -11,7 +11,7 @@ def zarr_store( folder: DataFolder, path: str = "", *, - mode: Literal["r", "w", "w-", "a"] = "r", + mode: Literal["r", "r+", "w", "w-", "a"] = "r", ): import zarr diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 8be7eefc..11a3cefc 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -581,11 +581,34 @@ def write_shard_block(self, shard_id, block) -> None: ) ] _validate_zarr_stores(self.output, relpaths) - for relpath in relpaths: - try: - self.output.rm(f"{relpath}.empty") - except FileNotFoundError: - pass + _remove_parts(self.output) + self._clear_root_payload_except(relpaths) + + def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: + import zarr + + keep_paths = set(relpaths) + try: + root = zarr.open_group(store=zarr_store(self.output, "", mode="r+")) + except Exception: + return + + def clear_group(group: Any, prefix: str = "") -> None: + group_keys = set(group.group_keys()) + for key in sorted({*group.array_keys(), *group_keys}): + path = f"{prefix}/{key}" if prefix else key + if path == "_refiner" or path.startswith("_refiner/"): + continue + if path in keep_paths: + continue + if any(keep_path.startswith(f"{path}/") for keep_path in keep_paths): + if key in group_keys: + clear_group(group[key], path) + continue + del group[key] + group.attrs.clear() + + clear_group(root) class _ZarrPublishPartsReducerSink(BaseSink): @@ -644,13 +667,14 @@ def write_shard_block(self, shard_id, block) -> None: return has_existing_output = _output_has_existing_store(self.output) + parts_validated = False if has_existing_output and self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): _validate_zarr_stores( self.output, (_part_store_relpath(relpath) for relpath in parts) ) + parts_validated = True if has_existing_output: if not self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): - _remove_parts(self.output) raise ValueError("write_zarr output already exists and overwrite=False") self._remove_publish_targets(parts) @@ -658,9 +682,10 @@ def write_shard_block(self, shard_id, block) -> None: pass try: - _validate_zarr_stores( - self.output, (_part_store_relpath(relpath) for relpath in parts) - ) + if not parts_validated: + _validate_zarr_stores( + self.output, (_part_store_relpath(relpath) for relpath in parts) + ) for final_relpath in parts: part_relpath = _part_store_relpath(final_relpath) if not self.output.exists(part_relpath): diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 43acfb5b..dc636942 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1004,6 +1004,100 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - assert not stale_part.exists() +def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-overwrites-single-store.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-sharded-overwrite-first", + num_workers=1, + rundir=str(tmp_path / "run-sharded-overwrite-first"), + ) + ) + stale_part = zarr_out / "_parts" / "old__wold.zarr" + stale_part.mkdir(parents=True) + (stale_part / ".zgroup").write_text('{"zarr_format": 2}', encoding="utf-8") + + ( + mdr.from_items([{"action": [[1.0], [2.0]]}], items_per_shard=1) + .write_zarr(str(zarr_out), arrays={"data/action": "action"}) + .launch_local( + name="zarr-sharded-overwrite-second", + num_workers=1, + rundir=str(tmp_path / "run-sharded-overwrite-second"), + ) + ) + + assert not (zarr_out / "data").exists() + assert not (zarr_out / "meta").exists() + assert not (zarr_out / "_parts").exists() + stores = sorted(zarr_out.glob("*.zarr")) + assert len(stores) == 1 + row = mdr.read_zarr( + stores[0], + arrays={ + "action": "data/action", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0], [2.0]]) + assert row["episode_ends"].tolist() == [2] + + +def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-overwrites-nested-single-store.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"split/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-sharded-nested-overwrite-first", + num_workers=1, + rundir=str(tmp_path / "run-sharded-nested-overwrite-first"), + ) + ) + + ( + mdr.from_items([{"action": [[1.0], [2.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + store_template="split/{shard_id}__w{worker_id}.zarr", + ) + .launch_local( + name="zarr-sharded-nested-overwrite-second", + num_workers=1, + rundir=str(tmp_path / "run-sharded-nested-overwrite-second"), + ) + ) + + assert not (zarr_out / "split" / "action").exists() + assert not (zarr_out / "meta").exists() + stores = sorted((zarr_out / "split").glob("*.zarr")) + assert len(stores) == 1 + row = mdr.read_zarr( + stores[0], + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0], [2.0]]) + + def test_write_zarr_single_store_rejects_existing_output_when_not_overwriting( tmp_path: Path, ) -> None: @@ -1153,6 +1247,64 @@ def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( assert not (zarr_out / "_parts").exists() +def test_write_zarr_non_reduced_no_overwrite_preserves_parts_on_conflict_for_resume( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-no-overwrite-conflict-resume.zarr" + shard_id = "0123456789ab" + worker_id = "worker-a" + relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" + part = zarr_out / "_parts" / relpath + existing = zarr_out / "existing.zarr" + _write_part_zarr(part, {"data/action": np.asarray([[1.0]], dtype=np.float32)}) + _write_part_zarr(existing, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) + + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + with pytest.raises(ValueError, match="output already exists"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + assert part.exists() + shutil.rmtree(existing) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + overwrite=False, + ).build_reducer() + assert reducer is not None + reducer.write_block([DictRow({}, shard_id="reduce")]) + + assert not part.exists() + row = mdr.read_zarr( + zarr_out / relpath, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + + def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( tmp_path: Path, ) -> None: @@ -1284,6 +1436,34 @@ def test_write_zarr_empty_shard_completion_removes_stale_store( assert stale.with_name(stale.name + ".empty").exists() +def test_write_zarr_non_reduced_cleanup_keeps_empty_markers_retryable( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-empty-cleanup-retry.zarr" + worker_id = "worker-a" + marker = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr.empty" + marker.parent.mkdir(parents=True) + marker.write_bytes(b"") + + runtime = _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id="shard-a", worker_id=worker_id)] + ) + for _ in range(2): + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrCleanupReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + ).write_block([DictRow({}, shard_id="reduce")]) + + assert marker.exists() + + def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( tmp_path: Path, ) -> None: From d58b5a5bea45224cbeda35a6ee42931e201b515f Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 21:46:26 +0200 Subject: [PATCH 23/39] Default Zarr writes to single store --- docs/reading-and-writing.md | 13 +++--- src/refiner/pipeline/pipeline.py | 26 ++++++++++-- src/refiner/pipeline/sinks/zarr.py | 25 +++-------- tests/readers/test_zarr_reader.py | 68 +++++++++++++++++++++--------- 4 files changed, 85 insertions(+), 47 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index bc0c555f..d462eb45 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -513,19 +513,18 @@ arrays: actions, states, and timestamps. The default schema is inferred once and later rows must expose the same fields. Video sources selected through `arrays` are decoded as RGB frame arrays and appended in bounded batches controlled by `video_frame_batch_size`. `array_chunk_bytes` controls the target chunk size for -new arrays. When `reduce_to_single_store=True`, the reducer copies shard-local -arrays into the final store in read/write batches controlled by -`reduce_array_batch_bytes`; by default it uses the same value as -`array_chunk_bytes`. +new arrays and the reducer read/write batch size when shard-local stores are +merged into a final store. By default, `write_zarr(...)` also writes cumulative episode boundaries to `meta/episode_ends`. Set `episode_ends_path=None` to omit them. Launched runs write isolated stores per shard/worker using `store_template="{shard_id}__w{worker_id}.zarr"`. This avoids concurrent workers -mutating the same Zarr group. Read the resulting stores individually, or set -`reduce_to_single_store=True` to add a reducer stage that streams the shard-local -stores into one final Zarr group at the requested output path. +mutating the same Zarr group. By default, a reducer stage streams those +shard-local stores into one final Zarr group at the requested output path. Set +`reduce_to_single_store=False` to keep the isolated stores and read them +individually. When you run a writer through `launch_local(...)` or `launch_cloud(...)`, some sinks add a reducer stage after the main writer stage. For `write_jsonl(...)` diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 5afcde8e..b73458cd 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -440,10 +440,31 @@ def write_zarr( store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, array_chunk_bytes: int = 8 * 1024 * 1024, - reduce_array_batch_bytes: int | None = None, - reduce_to_single_store: bool = False, + reduce_to_single_store: bool = True, overwrite: bool = True, ) -> "RefinerPipeline": + """Write rows to Zarr array stores. + + Args: + output: Output folder or URL prefix for the Zarr store(s). + arrays: Mapping from output Zarr array path to source row key. If + omitted for ``RoboticsRow`` inputs, writes the available default + robotics arrays: actions, states, and timestamps. + episode_ends_path: Output Zarr path for cumulative row/episode end + offsets. Set to None to omit episode boundaries. + store_template: Per-shard store path template. Must include + ``{shard_id}`` and ``{worker_id}``. + video_frame_batch_size: Maximum decoded video frames to append per + video write batch. + array_chunk_bytes: Target byte size for chunks created for newly + written arrays and for read/write batches when reducing shard + stores into a single store. + reduce_to_single_store: If True, add a reducer stage that merges + shard-local stores into one Zarr group at ``output``. Defaults + to True. + overwrite: If True, replace Refiner-managed output at the target. + If False, fail when final output already exists. + """ return self.with_sink( ZarrSink( output=output, @@ -452,7 +473,6 @@ def write_zarr( store_template=store_template, video_frame_batch_size=video_frame_batch_size, array_chunk_bytes=array_chunk_bytes, - reduce_array_batch_bytes=reduce_array_batch_bytes, reduce_to_single_store=reduce_to_single_store, overwrite=overwrite, ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 11a3cefc..23eb9581 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -56,8 +56,7 @@ def __init__( store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, - reduce_array_batch_bytes: int | None = None, - reduce_to_single_store: bool = False, + reduce_to_single_store: bool = True, overwrite: bool = True, ): check_required_dependencies("write_zarr", ["zarr"], dist="zarr") @@ -65,8 +64,6 @@ def __init__( raise ValueError("video_frame_batch_size must be greater than zero") if array_chunk_bytes <= 0: raise ValueError("array_chunk_bytes must be greater than zero") - if reduce_array_batch_bytes is not None and reduce_array_batch_bytes <= 0: - raise ValueError("reduce_array_batch_bytes must be greater than zero") _validate_store_template(store_template) self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None @@ -78,11 +75,6 @@ def __init__( self.store_template = store_template self.video_frame_batch_size = video_frame_batch_size self.array_chunk_bytes = array_chunk_bytes - self.reduce_array_batch_bytes = ( - array_chunk_bytes - if reduce_array_batch_bytes is None - else reduce_array_batch_bytes - ) self.reduce_to_single_store = reduce_to_single_store self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} @@ -522,7 +514,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "store_template": self.store_template, "video_frame_batch_size": self.video_frame_batch_size, "array_chunk_bytes": self.array_chunk_bytes, - "reduce_array_batch_bytes": self.reduce_array_batch_bytes, "reduce_to_single_store": self.reduce_to_single_store, "overwrite": self.overwrite, }, @@ -534,7 +525,7 @@ def build_reducer(self) -> BaseSink | None: output=self.output, store_template=self.store_template, episode_ends_path=self.episode_ends_path, - reduce_array_batch_bytes=self.reduce_array_batch_bytes, + array_chunk_bytes=self.array_chunk_bytes, overwrite=self.overwrite, ) if not self.overwrite: @@ -728,14 +719,14 @@ def __init__( *, store_template: str, episode_ends_path: str | None, - reduce_array_batch_bytes: int, + array_chunk_bytes: int, overwrite: bool, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") self.output = DataFolder.resolve(output) self.store_template = store_template self.episode_ends_path = episode_ends_path - self.reduce_array_batch_bytes = reduce_array_batch_bytes + self.array_chunk_bytes = array_chunk_bytes self.overwrite = overwrite self._merged = False @@ -754,7 +745,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: { "path": self.output.abs_path(), "store_template": self.store_template, - "reduce_array_batch_bytes": self.reduce_array_batch_bytes, + "array_chunk_bytes": self.array_chunk_bytes, "reduce_to_single_store": True, }, ) @@ -831,7 +822,7 @@ def _merge(self) -> None: part_last = row_offset batch_size = _batch_length( source_array, - self.reduce_array_batch_bytes, + self.array_chunk_bytes, ) for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) @@ -847,9 +838,7 @@ def _merge(self) -> None: part_last = int(values[-1]) row_offset += part_last continue - batch_size = _batch_length( - source_array, self.reduce_array_batch_bytes - ) + batch_size = _batch_length(source_array, self.array_chunk_bytes) if source_array.shape[0] == 0: _append_zarr_array( final, diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index dc636942..a45141ce 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -772,6 +772,7 @@ def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: "data/action": "action", "data/state": "observation.state", }, + reduce_to_single_store=False, ) .launch_local( name="lerobot-to-zarr", num_workers=1, rundir=str(tmp_path / "run2") @@ -936,11 +937,6 @@ def test_write_zarr_rejects_reserved_paths(tmp_path: Path) -> None: ) -def test_write_zarr_rejects_invalid_reduce_batch_bytes(tmp_path: Path) -> None: - with pytest.raises(ValueError, match="reduce_array_batch_bytes"): - ZarrSink(str(tmp_path / "bad-batch.zarr"), reduce_array_batch_bytes=0) - - def test_write_zarr_rejects_empty_array_mapping(tmp_path: Path) -> None: with pytest.raises(ValueError, match="arrays must not be empty"): ZarrSink(str(tmp_path / "empty-arrays.zarr"), arrays={}) @@ -1028,7 +1024,11 @@ def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( ( mdr.from_items([{"action": [[1.0], [2.0]]}], items_per_shard=1) - .write_zarr(str(zarr_out), arrays={"data/action": "action"}) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=False, + ) .launch_local( name="zarr-sharded-overwrite-second", num_workers=1, @@ -1078,6 +1078,7 @@ def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( str(zarr_out), arrays={"data/action": "action"}, store_template="split/{shard_id}__w{worker_id}.zarr", + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-nested-overwrite-second", @@ -1212,7 +1213,11 @@ def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( ( mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .write_zarr(str(zarr_out), arrays={"data/action": "action"}) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=False, + ) .launch_local( name="zarr-sharded-no-overwrite-first", num_workers=1, @@ -1227,6 +1232,7 @@ def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-no-overwrite-second", @@ -1273,6 +1279,7 @@ def test_write_zarr_non_reduced_no_overwrite_preserves_parts_on_conflict_for_res str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None with pytest.raises(ValueError, match="output already exists"): @@ -1292,6 +1299,7 @@ def test_write_zarr_non_reduced_no_overwrite_preserves_parts_on_conflict_for_res str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None reducer.write_block([DictRow({}, shard_id="reduce")]) @@ -1321,6 +1329,7 @@ def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None @@ -1349,6 +1358,7 @@ def test_write_zarr_non_reduced_no_overwrite_rejects_missing_finalized_part( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None @@ -1374,6 +1384,7 @@ def test_write_zarr_non_reduced_no_overwrite_rejects_unsafe_finalized_part_path( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None @@ -1430,6 +1441,7 @@ def test_write_zarr_empty_shard_completion_removes_stale_store( ZarrSink( str(zarr_out), arrays={"data/action": "action"}, + reduce_to_single_store=False, ).on_shard_complete("shard-a") assert not stale.exists() @@ -1476,6 +1488,7 @@ def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-empty-no-overwrite", @@ -1499,6 +1512,7 @@ def test_write_zarr_non_reduced_no_overwrite_empty_publish_is_retryable( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None @@ -1524,6 +1538,7 @@ def test_write_zarr_non_reduced_rejects_empty_existing_output_when_not_overwriti str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-empty-existing-first", @@ -1539,6 +1554,7 @@ def test_write_zarr_non_reduced_rejects_empty_existing_output_when_not_overwriti str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-empty-existing-second", @@ -1574,6 +1590,7 @@ def test_write_zarr_non_reduced_no_overwrite_retry_removes_partial_publish( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None runtime = _FinalizedWorkersRuntime( @@ -1616,6 +1633,7 @@ def test_write_zarr_non_reduced_no_overwrite_retry_keeps_final_when_part_missing str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None runtime = _FinalizedWorkersRuntime( @@ -1659,6 +1677,7 @@ def test_write_zarr_non_reduced_no_overwrite_replaces_stale_staging_part( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).write_block([DictRow({"action": [[1.0]]}, shard_id=shard_id)]) row = mdr.read_zarr( @@ -1692,6 +1711,7 @@ def test_write_zarr_non_reduced_no_overwrite_completed_publish_is_retryable( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None runtime = _FinalizedWorkersRuntime( @@ -1729,6 +1749,7 @@ def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ) .launch_local( name="zarr-sharded-fresh-no-overwrite", @@ -1777,6 +1798,7 @@ def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( reducer = ZarrSink( str(zarr_out), arrays={"data/action": "action"}, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None runtime = _FinalizedWorkersRuntime( @@ -1826,6 +1848,7 @@ def test_write_zarr_no_overwrite_rejects_part_schema_drift_before_publish( str(zarr_out), arrays={"data/action": "action"}, overwrite=False, + reduce_to_single_store=False, ).build_reducer() assert reducer is not None runtime = _FinalizedWorkersRuntime( @@ -2006,7 +2029,7 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() @@ -2044,7 +2067,7 @@ def test_write_zarr_single_store_rejects_part_missing_episode_ends( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2082,7 +2105,7 @@ def test_write_zarr_single_store_rejects_missing_finalized_part( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2138,7 +2161,7 @@ def test_write_zarr_single_store_completed_merge_is_retryable( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2179,7 +2202,7 @@ def test_write_zarr_single_store_zero_shard_overwrite_ignores_stale_done_marker( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2215,7 +2238,7 @@ def test_write_zarr_single_store_zero_shard_no_overwrite_rejects_existing_output str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2240,7 +2263,7 @@ def test_write_zarr_single_store_zero_shard_no_overwrite_done_marker_is_retryabl str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2283,7 +2306,7 @@ def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2331,7 +2354,7 @@ def test_write_zarr_single_store_no_overwrite_recovers_empty_root_without_marker str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2384,7 +2407,7 @@ def test_write_zarr_single_store_no_overwrite_started_merge_replaces_partial_out str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2429,7 +2452,7 @@ def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2520,7 +2543,7 @@ def test_write_zarr_single_store_rejects_part_dtype_drift( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", - reduce_array_batch_bytes=1024, + array_chunk_bytes=1024, overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() @@ -2591,6 +2614,7 @@ def test_write_zarr_rejects_shape_drift_before_appending_bad_row( "data/action": "action", "data/state": "state", }, + reduce_to_single_store=False, ).write_block(rows) zarr_store = next(output.glob("*.zarr")) @@ -2625,6 +2649,7 @@ def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: "data/action": "action", "data/rgb": "observation.images.front", }, + reduce_to_single_store=False, ).write_block(rows) zarr_store = next(output.glob("*.zarr")) @@ -2654,6 +2679,7 @@ def test_write_zarr_materializes_empty_frame_array_videos(tmp_path: Path) -> Non ZarrSink( str(output), arrays={"data/rgb": "observation.images.front"}, + reduce_to_single_store=False, ).write_block(rows) root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") @@ -2684,6 +2710,7 @@ def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> "data/rgb": "observation.images.front", }, array_chunk_bytes=50, + reduce_to_single_store=False, ).write_block(rows) root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") @@ -2695,6 +2722,7 @@ def test_write_zarr_caps_low_dimensional_initial_chunks(tmp_path: Path) -> None: ZarrSink( str(output), arrays={"data/action": "action"}, + reduce_to_single_store=False, ).write_block( [ DictRow( @@ -2738,6 +2766,7 @@ def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: "data/rgb": "observation.images.front", }, video_frame_batch_size=2, + reduce_to_single_store=False, ).write_block(rows) zarr_store = next(output.glob("*.zarr")) @@ -2799,6 +2828,7 @@ def test_write_zarr_rejects_video_length_mismatch_before_final_append( "data/rgb": "observation.images.front", }, video_frame_batch_size=2, + reduce_to_single_store=False, ).write_block(rows) zarr_store = next(output.glob("*.zarr")) From a745da4d593668ff367a4fc48a829081aa62d1d8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 22:58:43 +0200 Subject: [PATCH 24/39] Inline Zarr IO helpers --- docs/reading-and-writing.md | 5 +-- src/refiner/io/zarr.py | 37 ------------------ src/refiner/pipeline/sinks/zarr.py | 41 ++++++++++++++------ src/refiner/pipeline/sources/readers/zarr.py | 16 +++++--- 4 files changed, 43 insertions(+), 56 deletions(-) delete mode 100644 src/refiner/io/zarr.py diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index d462eb45..479f5f29 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -529,9 +529,8 @@ individually. When you run a writer through `launch_local(...)` or `launch_cloud(...)`, some sinks add a reducer stage after the main writer stage. For `write_jsonl(...)` and `write_parquet(...)`, that reducer removes stale shard/worker files and -uploaded asset attempt folders, keeping only finalized outputs. `write_zarr(...)` -also removes stale shard/worker store directories. The output prefix should -therefore be dedicated to Refiner-managed files. +uploaded asset attempt folders, keeping only finalized outputs. The output +prefix should therefore be dedicated to Refiner-managed files. ## What Python Functions Actually See diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py deleted file mode 100644 index c1918438..00000000 --- a/src/refiner/io/zarr.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator -from typing import Literal -from typing import Any - -from refiner.io.datafolder import DataFolder - - -def zarr_store( - folder: DataFolder, - path: str = "", - *, - mode: Literal["r", "r+", "w", "w-", "a"] = "r", -): - import zarr - - create = mode in {"w", "w-", "a"} - return zarr.storage.FSStore( - folder._join(path), - fs=folder.fs, - mode=mode, - create=create, - ) - - -def iter_zarr_array_paths(group: Any, prefix: str = "") -> Iterator[str]: - items = group.items() if hasattr(group, "items") else group.members() - for name, item in items: - path = f"{prefix}/{name}" if prefix else name - if hasattr(item, "shape"): - yield path - else: - yield from iter_zarr_array_paths(item, path) - - -__all__ = ["iter_zarr_array_paths", "zarr_store"] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 23eb9581..70d30e42 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -10,7 +10,6 @@ from refiner.execution.asyncio.runtime import submit from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.io.zarr import iter_zarr_array_paths, zarr_store from refiner.pipeline.data.block import Block from refiner.pipeline.data.row import Row from refiner.pipeline.sinks.base import BaseSink @@ -398,7 +397,7 @@ def _store(self, shard_id: str) -> _ShardStore: store = _ShardStore( zarr.open_group( - store=zarr_store(self.output, relpath, mode=mode), mode=mode + store=_zarr_store(self.output, relpath, mode=mode), mode=mode ) ) try: @@ -580,7 +579,7 @@ def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: keep_paths = set(relpaths) try: - root = zarr.open_group(store=zarr_store(self.output, "", mode="r+")) + root = zarr.open_group(store=_zarr_store(self.output, "", mode="r+")) except Exception: return @@ -773,7 +772,7 @@ def _merge(self) -> None: import zarr final = zarr.open_group( - store=zarr_store(self.output, "", mode="a"), + store=_zarr_store(self.output, "", mode="a"), mode="a", ) if self.overwrite: @@ -795,7 +794,7 @@ def _merge(self) -> None: import zarr final = zarr.open_group( - store=zarr_store(self.output, "", mode="a"), + store=_zarr_store(self.output, "", mode="a"), mode="a", ) if self.overwrite: @@ -811,7 +810,7 @@ def _merge(self) -> None: arrays: dict[str, Any] = {} for part in parts: source = zarr.open_group( - store=zarr_store(self.output, part.relpath, mode="r"), + store=_zarr_store(self.output, part.relpath, mode="r"), mode="r", ) for path in sorted(part.paths): @@ -905,10 +904,10 @@ def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: continue raise ValueError(f"Zarr part store is missing: {relpath}") source = zarr.open_group( - store=zarr_store(self.output, relpath, mode="r"), + store=_zarr_store(self.output, relpath, mode="r"), mode="r", ) - source_paths = set(iter_zarr_array_paths(source)) + source_paths = set(_iter_array_paths(source)) if not source_paths: continue source_payload_paths = { @@ -1054,6 +1053,26 @@ def _part_store_relpath(relpath: str) -> str: return f"_parts/{relpath}" +def _zarr_store(output: DataFolder, path: str = "", *, mode: str = "r"): + import zarr + + return zarr.storage.FSStore( + output._join(path), + fs=output.fs, + mode=mode, + create=mode in {"w", "w-", "a"}, + ) + + +def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: + for name, item in group.items(): + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + def _remove_parts(output: DataFolder, *, best_effort: bool = False) -> None: try: output.rm("_parts", recursive=True) @@ -1075,10 +1094,10 @@ def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: continue raise ValueError(f"Zarr store is missing: {relpath}") source = zarr.open_group( - store=zarr_store(output, relpath, mode="r"), + store=_zarr_store(output, relpath, mode="r"), mode="r", ) - source_paths = set(iter_zarr_array_paths(source)) + source_paths = set(_iter_array_paths(source)) if payload_paths is None: payload_paths = source_paths elif source_paths != payload_paths: @@ -1123,7 +1142,7 @@ def _output_has_payload(output: DataFolder) -> bool: if not non_part_entries: return False try: - group = zarr.open_group(store=zarr_store(output, "", mode="r"), mode="r") + group = zarr.open_group(store=_zarr_store(output, "", mode="r"), mode="r") except Exception: return True return _group_has_payload(group) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index e315f1b6..587c94d0 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -13,7 +13,6 @@ from refiner.io.datafile import DataFile, DataFileLike from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.io.zarr import iter_zarr_array_paths, zarr_store from refiner.pipeline.data.datatype import ( DTypeMapping, dtype_to_plan, @@ -288,7 +287,7 @@ def _open_group(self) -> Any: handle.close() else: assert self.root is not None - store = zarr_store(self.root, mode="r") + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") yield zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: @@ -310,9 +309,7 @@ def _row_metadata(self, *, index: int | None) -> dict[str, Any]: def _selected_arrays(self, group: Any) -> dict[str, Any]: if self.arrays is None: paths = { - path: path - for path in iter_zarr_array_paths(group) - if path != self.row_ends + path: path for path in _iter_array_paths(group) if path != self.row_ends } _validate_output_names( paths, @@ -565,4 +562,13 @@ def _leading_item_bytes(array: Any) -> int: return max(1, int(array.dtype.itemsize) * int(prod(trailing_shape or (1,)))) +def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: + for name, item in group.items(): + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + __all__ = ["ZarrReader"] From af4fcfe822317eee1dceea6723d1163d5a611cb1 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 23:10:40 +0200 Subject: [PATCH 25/39] Remove Zarr overwrite mode --- src/refiner/pipeline/pipeline.py | 4 - src/refiner/pipeline/sinks/zarr.py | 349 ++--------- tests/readers/test_zarr_reader.py | 957 ++--------------------------- 3 files changed, 93 insertions(+), 1217 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index b73458cd..d6483dc4 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -441,7 +441,6 @@ def write_zarr( video_frame_batch_size: int = 8, array_chunk_bytes: int = 8 * 1024 * 1024, reduce_to_single_store: bool = True, - overwrite: bool = True, ) -> "RefinerPipeline": """Write rows to Zarr array stores. @@ -462,8 +461,6 @@ def write_zarr( reduce_to_single_store: If True, add a reducer stage that merges shard-local stores into one Zarr group at ``output``. Defaults to True. - overwrite: If True, replace Refiner-managed output at the target. - If False, fail when final output already exists. """ return self.with_sink( ZarrSink( @@ -474,7 +471,6 @@ def write_zarr( video_frame_batch_size=video_frame_batch_size, array_chunk_bytes=array_chunk_bytes, reduce_to_single_store=reduce_to_single_store, - overwrite=overwrite, ) ) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 70d30e42..3df55693 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -26,10 +26,6 @@ _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 _MAX_INITIAL_CHUNK_ROWS = 1024 -_DONE_MARKER_RELPATH = "_refiner/write_zarr.done" -_MERGE_STARTED_MARKER_RELPATH = "_refiner/write_zarr.started" -_PUBLISH_STARTED_MARKER_RELPATH = "_refiner/write_zarr_publish.started" -_PUBLISH_DONE_MARKER_RELPATH = "_refiner/write_zarr_publish.done" @dataclass @@ -56,7 +52,6 @@ def __init__( video_frame_batch_size: int = 8, array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, reduce_to_single_store: bool = True, - overwrite: bool = True, ): check_required_dependencies("write_zarr", ["zarr"], dist="zarr") if video_frame_batch_size <= 0: @@ -75,12 +70,8 @@ def __init__( self.video_frame_batch_size = video_frame_batch_size self.array_chunk_bytes = array_chunk_bytes self.reduce_to_single_store = reduce_to_single_store - self.overwrite = overwrite self._stores: dict[str, _ShardStore] = {} self._default_arrays: dict[str, str] | None = None - self._checked_no_overwrite = False - self._cleared_publish_markers = False - self._cleared_merge_marker = False def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 @@ -383,22 +374,14 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self._default_arrays def _store(self, shard_id: str) -> _ShardStore: - self._check_no_overwrite_output() - if self.reduce_to_single_store and self.overwrite: - self._clear_merge_marker_once() - if not self.overwrite and not self.reduce_to_single_store: - self._clear_publish_markers_once() relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) if store is not None: return store - mode = "w" if self.overwrite or relpath.startswith("_parts/") else "w-" import zarr store = _ShardStore( - zarr.open_group( - store=_zarr_store(self.output, relpath, mode=mode), mode=mode - ) + zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") ) try: self.output.rm(self._empty_marker_relpath(shard_id)) @@ -407,51 +390,17 @@ def _store(self, shard_id: str) -> _ShardStore: self._stores[relpath] = store return store - def _check_no_overwrite_output(self) -> None: - if self.overwrite or self._checked_no_overwrite: - return - self._checked_no_overwrite = True - if _output_has_existing_store(self.output): - raise ValueError("write_zarr output already exists and overwrite=False") - - def _clear_publish_markers_once(self) -> None: - if self._cleared_publish_markers: - return - self._cleared_publish_markers = True - for marker in ( - _PUBLISH_STARTED_MARKER_RELPATH, - _PUBLISH_DONE_MARKER_RELPATH, - ): - try: - self.output.rm(marker) - except FileNotFoundError: - pass - - def _clear_merge_marker_once(self) -> None: - if self._cleared_merge_marker: - return - self._cleared_merge_marker = True - try: - self.output.rm(_DONE_MARKER_RELPATH) - except FileNotFoundError: - pass - def _store_relpath(self, shard_id: str) -> str: relpath = _render_store_relpath( self.store_template, shard_id=shard_id, worker_id=get_active_worker_token(), ) - if self.reduce_to_single_store or not self.overwrite: + if self.reduce_to_single_store: return _part_store_relpath(relpath) return relpath def on_shard_complete(self, shard_id: str) -> None: - self._check_no_overwrite_output() - if self.reduce_to_single_store and self.overwrite: - self._clear_merge_marker_once() - if not self.overwrite and not self.reduce_to_single_store: - self._clear_publish_markers_once() relpath = self._store_relpath(shard_id) if relpath not in self._stores: try: @@ -514,7 +463,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "video_frame_batch_size": self.video_frame_batch_size, "array_chunk_bytes": self.array_chunk_bytes, "reduce_to_single_store": self.reduce_to_single_store, - "overwrite": self.overwrite, }, ) @@ -525,12 +473,6 @@ def build_reducer(self) -> BaseSink | None: store_template=self.store_template, episode_ends_path=self.episode_ends_path, array_chunk_bytes=self.array_chunk_bytes, - overwrite=self.overwrite, - ) - if not self.overwrite: - return _ZarrPublishPartsReducerSink( - output=self.output, - store_template=self.store_template, ) return _ZarrCleanupReducerSink( output=self.output, @@ -601,116 +543,6 @@ def clear_group(group: Any, prefix: str = "") -> None: clear_group(root) -class _ZarrPublishPartsReducerSink(BaseSink): - def __init__( - self, - output: DataFolderLike, - *, - store_template: str, - ) -> None: - self.output = DataFolder.resolve(output) - self.store_template = store_template - self._published = False - - @property - def counts_output_rows(self) -> bool: - return False - - def write_shard_block(self, shard_id, block) -> None: - del shard_id, block - if self._published: - return - self._published = True - - stage_index = get_active_stage_index() - if stage_index is None or stage_index <= 0: - raise ValueError( - "write_zarr_publish requires an active reducer stage with a prior writer stage" - ) - - parts = [ - _render_store_relpath( - self.store_template, - shard_id=row.shard_id, - worker_id=row.worker_token, - ) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1) - ) - ] - if not parts: - if self.output.exists( - _PUBLISH_DONE_MARKER_RELPATH - ) and not _output_has_payload(self.output): - _remove_parts(self.output, best_effort=True) - return - if _output_has_existing_store(self.output): - _remove_parts(self.output, best_effort=True) - raise ValueError("write_zarr output already exists and overwrite=False") - with self.output.open(_PUBLISH_DONE_MARKER_RELPATH, mode="wb"): - pass - _remove_parts(self.output, best_effort=True) - return - - if self.output.exists(_PUBLISH_DONE_MARKER_RELPATH): - _remove_parts(self.output, best_effort=True) - return - - has_existing_output = _output_has_existing_store(self.output) - parts_validated = False - if has_existing_output and self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): - _validate_zarr_stores( - self.output, (_part_store_relpath(relpath) for relpath in parts) - ) - parts_validated = True - if has_existing_output: - if not self.output.exists(_PUBLISH_STARTED_MARKER_RELPATH): - raise ValueError("write_zarr output already exists and overwrite=False") - self._remove_publish_targets(parts) - - with self.output.open(_PUBLISH_STARTED_MARKER_RELPATH, mode="wb"): - pass - - try: - if not parts_validated: - _validate_zarr_stores( - self.output, (_part_store_relpath(relpath) for relpath in parts) - ) - for final_relpath in parts: - part_relpath = _part_store_relpath(final_relpath) - if not self.output.exists(part_relpath): - if self.output.exists(part_relpath + ".empty"): - continue - raise ValueError(f"Zarr part store is missing: {part_relpath}") - target_parent = final_relpath.rsplit("/", maxsplit=1)[0] - if target_parent != final_relpath: - self.output.makedirs(target_parent, exist_ok=True) - self.output.copy( - part_relpath, - final_relpath, - recursive=True, - on_error="raise", - ) - except Exception: - self._remove_publish_targets(parts) - raise - - with self.output.open(_PUBLISH_DONE_MARKER_RELPATH, mode="wb"): - pass - try: - self.output.rm(_PUBLISH_STARTED_MARKER_RELPATH) - except (FileNotFoundError, OSError, ValueError): - pass - _remove_parts(self.output, best_effort=True) - - def _remove_publish_targets(self, relpaths: Iterable[str]) -> None: - for relpath in relpaths: - try: - self.output.rm(relpath, recursive=True) - except FileNotFoundError: - continue - - class _ZarrMergeReducerSink(BaseSink): def __init__( self, @@ -719,15 +551,14 @@ def __init__( store_template: str, episode_ends_path: str | None, array_chunk_bytes: int, - overwrite: bool, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") self.output = DataFolder.resolve(output) self.store_template = store_template self.episode_ends_path = episode_ends_path self.array_chunk_bytes = array_chunk_bytes - self.overwrite = overwrite self._merged = False + self._remove_parts_on_complete = False @property def counts_output_rows(self) -> bool: @@ -762,34 +593,17 @@ def _merge(self) -> None: expected_parts = self._expected_parts(stage_index) if not expected_parts: - if self.output.exists(_DONE_MARKER_RELPATH) and not _output_has_payload( - self.output - ): - _remove_parts(self.output, best_effort=True) - return - if not self.overwrite and _output_has_existing_store(self.output): - raise ValueError("write_zarr output already exists and overwrite=False") import zarr final = zarr.open_group( store=_zarr_store(self.output, "", mode="a"), mode="a", ) - if self.overwrite: - _clear_final_group(final) - with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): - pass - _remove_parts(self.output, best_effort=True) - return - - if self.output.exists(_DONE_MARKER_RELPATH): - _remove_parts(self.output, best_effort=True) + _clear_final_group(final) + self._remove_parts_on_complete = True return parts = self._collect_parts(expected_parts) - if not self.overwrite and _output_has_existing_store(self.output): - if not self.output.exists(_MERGE_STARTED_MARKER_RELPATH): - raise ValueError("write_zarr output already exists and overwrite=False") import zarr @@ -797,78 +611,67 @@ def _merge(self) -> None: store=_zarr_store(self.output, "", mode="a"), mode="a", ) - if self.overwrite: - _clear_final_group(final) - elif self.output.exists(_MERGE_STARTED_MARKER_RELPATH): - _clear_final_group(final) - else: - with self.output.open(_MERGE_STARTED_MARKER_RELPATH, mode="wb"): - pass + _clear_final_group(final) - try: - row_offset = 0 - arrays: dict[str, Any] = {} - for part in parts: - source = zarr.open_group( - store=_zarr_store(self.output, part.relpath, mode="r"), - mode="r", - ) - for path in sorted(part.paths): - source_array = source[path] - if path == self.episode_ends_path: - if source_array.shape[0] == 0: - continue - part_last = row_offset - batch_size = _batch_length( - source_array, - self.array_chunk_bytes, - ) - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) - values = np.asarray(source_array[start:end], dtype=np.int64) - _append_zarr_array( - final, - arrays, - path, - values + row_offset, - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) - part_last = int(values[-1]) - row_offset += part_last - continue - batch_size = _batch_length(source_array, self.array_chunk_bytes) + row_offset = 0 + arrays: dict[str, Any] = {} + for part in parts: + source = zarr.open_group( + store=_zarr_store(self.output, part.relpath, mode="r"), + mode="r", + ) + for path in sorted(part.paths): + source_array = source[path] + if path == self.episode_ends_path: if source_array.shape[0] == 0: - _append_zarr_array( - final, - arrays, - path, - np.asarray(source_array[:0]), - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) continue + part_last = row_offset + batch_size = _batch_length( + source_array, + self.array_chunk_bytes, + ) for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end], dtype=np.int64) _append_zarr_array( final, arrays, path, - np.asarray(source_array[start:end]), + values + row_offset, chunks=getattr(source_array, "chunks", None), compressor=getattr(source_array, "compressor", None), ) - except Exception: - if not self.overwrite: - _clear_final_group(final) - raise + part_last = int(values[-1]) + row_offset += part_last + continue + batch_size = _batch_length(source_array, self.array_chunk_bytes) + if source_array.shape[0] == 0: + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[:0]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + continue + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[start:end]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) - with self.output.open(_DONE_MARKER_RELPATH, mode="wb"): - pass - try: - self.output.rm(_MERGE_STARTED_MARKER_RELPATH) - except (FileNotFoundError, OSError, ValueError): - pass + self._remove_parts_on_complete = True + + def on_shard_complete(self, shard_id: str) -> None: + del shard_id + if not self._remove_parts_on_complete: + return _remove_parts(self.output, best_effort=True) try: if not self.output.ls("_parts"): @@ -1121,52 +924,6 @@ def _clear_final_group(group: Any) -> None: group.attrs.clear() -def _group_has_payload(group: Any) -> bool: - return bool(group.attrs) or any( - key != "_parts" for key in {*group.array_keys(), *group.group_keys()} - ) - - -def _output_has_payload(output: DataFolder) -> bool: - import zarr - - try: - entries = output.ls("", detail=False) - except FileNotFoundError: - return False - non_part_entries = [ - entry - for entry in entries - if str(entry).split("/", maxsplit=1)[0] not in {"_parts", "_refiner"} - ] - if not non_part_entries: - return False - try: - group = zarr.open_group(store=_zarr_store(output, "", mode="r"), mode="r") - except Exception: - return True - return _group_has_payload(group) - - -def _output_has_existing_store(output: DataFolder) -> bool: - if _output_has_payload(output): - return True - if output.exists(_DONE_MARKER_RELPATH) or output.exists( - _PUBLISH_DONE_MARKER_RELPATH - ): - return True - try: - entries = output.ls("", detail=False) - except FileNotFoundError: - return False - for entry in entries: - root = str(entry).split("/", maxsplit=1)[0] - if root in {"_parts", "_refiner", ".zgroup", ".zattrs", ".zmetadata"}: - continue - return True - return False - - def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: chunk_rows = min( _batch_length(array, target_bytes), diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index a45141ce..1bf21105 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -10,7 +10,6 @@ import zarr import refiner as mdr -from refiner.cli.run.local import LocalLaunchResumeError from refiner.io.datafolder import DataFolder from refiner.io import DataFile from refiner.robotics.row import RoboticsRow @@ -956,8 +955,8 @@ def test_write_zarr_rejects_empty_default_robotics_arrays(tmp_path: Path) -> Non ZarrSink(str(tmp_path / "empty-defaults.zarr")).write_block(rows) -def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) -> None: - zarr_out = tmp_path / "single-overwrite.zarr" +def test_write_zarr_single_store_replace_ignores_stale_parts(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-replace.zarr" ( mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) @@ -967,7 +966,7 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - reduce_to_single_store=True, ) .launch_local( - name="zarr-single-overwrite-first", + name="zarr-single-replace-first", num_workers=1, rundir=str(tmp_path / "run-first"), ) @@ -985,7 +984,7 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - reduce_to_single_store=True, ) .launch_local( - name="zarr-single-overwrite-second", + name="zarr-single-replace-second", num_workers=1, rundir=str(tmp_path / "run-second"), ) @@ -1000,10 +999,10 @@ def test_write_zarr_single_store_overwrite_ignores_stale_parts(tmp_path: Path) - assert not stale_part.exists() -def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( +def test_write_zarr_sharded_replace_removes_single_store_payload_and_parts( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "sharded-overwrites-single-store.zarr" + zarr_out = tmp_path / "sharded-replaces-single-store.zarr" ( mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) @@ -1013,9 +1012,9 @@ def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( reduce_to_single_store=True, ) .launch_local( - name="zarr-sharded-overwrite-first", + name="zarr-sharded-replace-first", num_workers=1, - rundir=str(tmp_path / "run-sharded-overwrite-first"), + rundir=str(tmp_path / "run-sharded-replace-first"), ) ) stale_part = zarr_out / "_parts" / "old__wold.zarr" @@ -1030,9 +1029,9 @@ def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( reduce_to_single_store=False, ) .launch_local( - name="zarr-sharded-overwrite-second", + name="zarr-sharded-replace-second", num_workers=1, - rundir=str(tmp_path / "run-sharded-overwrite-second"), + rundir=str(tmp_path / "run-sharded-replace-second"), ) ) @@ -1053,10 +1052,10 @@ def test_write_zarr_sharded_overwrite_removes_single_store_payload_and_parts( assert row["episode_ends"].tolist() == [2] -def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( +def test_write_zarr_sharded_replace_clears_payload_under_store_prefix( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "sharded-overwrites-nested-single-store.zarr" + zarr_out = tmp_path / "sharded-replaces-nested-single-store.zarr" ( mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) @@ -1066,9 +1065,9 @@ def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( reduce_to_single_store=True, ) .launch_local( - name="zarr-sharded-nested-overwrite-first", + name="zarr-sharded-nested-replace-first", num_workers=1, - rundir=str(tmp_path / "run-sharded-nested-overwrite-first"), + rundir=str(tmp_path / "run-sharded-nested-replace-first"), ) ) @@ -1081,9 +1080,9 @@ def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( reduce_to_single_store=False, ) .launch_local( - name="zarr-sharded-nested-overwrite-second", + name="zarr-sharded-nested-replace-second", num_workers=1, - rundir=str(tmp_path / "run-sharded-nested-overwrite-second"), + rundir=str(tmp_path / "run-sharded-nested-replace-second"), ) ) @@ -1099,309 +1098,6 @@ def test_write_zarr_sharded_overwrite_clears_payload_under_store_prefix( np.testing.assert_allclose(row["action"], [[1.0], [2.0]]) -def test_write_zarr_single_store_rejects_existing_output_when_not_overwriting( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-no-overwrite.zarr" - - ( - mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=True, - ) - .launch_local( - name="zarr-single-no-overwrite-first", - num_workers=1, - rundir=str(tmp_path / "run-no-overwrite-first"), - ) - ) - - with pytest.raises(LocalLaunchResumeError): - ( - mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=True, - overwrite=False, - ) - .launch_local( - name="zarr-single-no-overwrite-second", - num_workers=1, - rundir=str(tmp_path / "run-no-overwrite-second"), - ) - ) - - row = mdr.read_zarr( - zarr_out, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[0.0]]) - - -def test_write_zarr_single_store_rejects_attrs_only_output_when_not_overwriting( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-attrs-no-overwrite.zarr" - root = _open_test_zarr(zarr_out, mode="w") - root.attrs["task"] = "old" - - with pytest.raises(LocalLaunchResumeError): - ( - mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=True, - overwrite=False, - ) - .launch_local( - name="zarr-single-attrs-no-overwrite", - num_workers=1, - rundir=str(tmp_path / "run-attrs-no-overwrite"), - ) - ) - - root = _open_test_zarr(zarr_out, mode="r") - assert dict(root.attrs) == {"task": "old"} - - -def test_write_zarr_single_store_rejects_empty_existing_store_when_not_overwriting( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-empty-existing-no-overwrite.zarr" - - ( - mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .filter(lambda row: False) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=True, - ) - .launch_local( - name="zarr-single-empty-existing-first", - num_workers=1, - rundir=str(tmp_path / "run-empty-existing-first"), - ) - ) - - with pytest.raises(LocalLaunchResumeError): - ( - mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=True, - overwrite=False, - ) - .launch_local( - name="zarr-single-empty-existing-second", - num_workers=1, - rundir=str(tmp_path / "run-empty-existing-second"), - ) - ) - - -def test_write_zarr_rejects_existing_non_reduced_output_when_not_overwriting( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite.zarr" - - ( - mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-no-overwrite-first", - num_workers=1, - rundir=str(tmp_path / "run-sharded-no-overwrite-first"), - ) - ) - - with pytest.raises(LocalLaunchResumeError): - ( - mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-no-overwrite-second", - num_workers=1, - rundir=str(tmp_path / "run-sharded-no-overwrite-second"), - ) - ) - - rows = [ - mdr.read_zarr( - store, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - for store in zarr_out.glob("*.zarr") - ] - assert sorted(float(row["action"][0][0]) for row in rows) == [0.0] - assert not (zarr_out / "_parts").exists() - - -def test_write_zarr_non_reduced_no_overwrite_preserves_parts_on_conflict_for_resume( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-conflict-resume.zarr" - shard_id = "0123456789ab" - worker_id = "worker-a" - relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" - part = zarr_out / "_parts" / relpath - existing = zarr_out / "existing.zarr" - _write_part_zarr(part, {"data/action": np.asarray([[1.0]], dtype=np.float32)}) - _write_part_zarr(existing, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) - - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - with pytest.raises(ValueError, match="output already exists"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - assert part.exists() - shutil.rmtree(existing) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - reducer.write_block([DictRow({}, shard_id="reduce")]) - - assert not part.exists() - row = mdr.read_zarr( - zarr_out / relpath, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[1.0]]) - - -def test_write_zarr_non_reduced_no_overwrite_preserves_finalized_retry_output( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-retry.zarr" - shard_id = "0123456789ab" - loser_worker_id = "loser" - winner_worker_id = "winner" - loser = zarr_out / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" - winner = zarr_out / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" - _write_part_zarr(loser, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) - _write_part_zarr(winner, {"data/action": np.asarray([[1.0]], dtype=np.float32)}) - - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="output already exists"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - assert loser.exists() - assert winner.exists() - - -def test_write_zarr_non_reduced_no_overwrite_rejects_missing_finalized_part( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-missing-part.zarr" - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id="shard-a", worker_id="worker-a")] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="Zarr store is missing"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - -def test_write_zarr_non_reduced_no_overwrite_rejects_unsafe_finalized_part_path( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-unsafe-finalized.zarr" - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id="../escape", worker_id="worker-a")] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="must not contain"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( tmp_path: Path, ) -> None: @@ -1466,310 +1162,14 @@ def test_write_zarr_non_reduced_cleanup_keeps_empty_markers_retryable( stage_index=1, worker_id="reducer", worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - _ZarrCleanupReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - ).write_block([DictRow({}, shard_id="reduce")]) - - assert marker.exists() - - -def test_write_zarr_non_reduced_no_overwrite_skips_empty_parts( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-empty-part.zarr" - - ( - mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .filter(lambda row: False) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-empty-no-overwrite", - num_workers=1, - rundir=str(tmp_path / "run-sharded-empty-no-overwrite"), - ) - ) - - assert not list(zarr_out.glob("*.zarr")) - assert not (zarr_out / "_parts").exists() - - -def test_write_zarr_non_reduced_no_overwrite_empty_publish_is_retryable( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-empty-publish-retry.zarr" - marker = zarr_out / "_refiner" / "write_zarr_publish.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), - ): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - -def test_write_zarr_non_reduced_rejects_empty_existing_output_when_not_overwriting( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-empty-existing-no-overwrite.zarr" - - ( - mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) - .filter(lambda row: False) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-empty-existing-first", - num_workers=1, - rundir=str(tmp_path / "run-sharded-empty-existing-first"), - ) - ) - - with pytest.raises(LocalLaunchResumeError): - ( - mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-empty-existing-second", - num_workers=1, - rundir=str(tmp_path / "run-sharded-empty-existing-second"), - ) - ) - - assert not list(zarr_out.glob("*.zarr")) - assert not (zarr_out / "_parts").exists() - - -def test_write_zarr_non_reduced_no_overwrite_retry_removes_partial_publish( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-partial-publish.zarr" - shard_id = "0123456789ab" - worker_id = "worker-a" - relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" - _write_part_zarr( - zarr_out / "_parts" / relpath, - {"data/action": np.asarray([[1.0]], dtype=np.float32)}, - ) - _write_part_zarr( - zarr_out / relpath, - {"data/action": np.asarray([[0.0]], dtype=np.float32)}, - ) - marker = zarr_out / "_refiner" / "write_zarr_publish.started" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out / relpath, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[1.0]]) - assert not (zarr_out / "_parts").exists() - - -def test_write_zarr_non_reduced_no_overwrite_retry_keeps_final_when_part_missing( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-retry-missing-part.zarr" - shard_id = "0123456789ab" - worker_id = "worker-a" - relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" - _write_part_zarr( - zarr_out / relpath, - {"data/action": np.asarray([[0.0]], dtype=np.float32)}, - ) - marker = zarr_out / "_refiner" / "write_zarr_publish.started" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="Zarr store is missing"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out / relpath, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[0.0]]) - - -def test_write_zarr_non_reduced_no_overwrite_replaces_stale_staging_part( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-stale-staging.zarr" - shard_id = "shard-a" - worker_id = "worker-a" - part = zarr_out / "_parts" / f"{shard_id}__w{worker_token_for(worker_id)}.zarr" - _write_part_zarr(part, {"data/action": np.asarray([[0.0]], dtype=np.float32)}) - - with set_active_run_context( - job_id="local", - stage_index=0, - worker_id=worker_id, - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, _FinalizedWorkersRuntime([])), - ): - ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).write_block([DictRow({"action": [[1.0]]}, shard_id=shard_id)]) - - row = mdr.read_zarr( - part, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[1.0]]) - - -def test_write_zarr_non_reduced_no_overwrite_completed_publish_is_retryable( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-complete-retry.zarr" - shard_id = "0123456789ab" - worker_id = "worker-a" - relpath = f"{shard_id}__w{worker_token_for(worker_id)}.zarr" - _write_part_zarr( - zarr_out / relpath, - {"data/action": np.asarray([[1.0]], dtype=np.float32)}, - ) - _write_part_zarr( - zarr_out / "_parts" / relpath, - {"data/action": np.asarray([[1.0]], dtype=np.float32)}, - ) - marker = zarr_out / "_refiner" / "write_zarr_publish.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - runtime = _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=worker_id)] - ) - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - reducer.write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out / relpath, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[1.0]]) - assert not (zarr_out / "_parts").exists() - - -def test_write_zarr_allows_fresh_non_reduced_multiworker_no_overwrite( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-fresh-no-overwrite.zarr" - - ( - mdr.from_items( - [{"action": [[0.0]]}, {"action": [[1.0]]}], - items_per_shard=1, - ) - .write_zarr( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ) - .launch_local( - name="zarr-sharded-fresh-no-overwrite", - num_workers=2, - rundir=str(tmp_path / "run-sharded-fresh-no-overwrite"), - ) - ) + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + _ZarrCleanupReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + ).write_block([DictRow({}, shard_id="reduce")]) - rows = [ - mdr.read_zarr( - store, - arrays={"action": "data/action"}, - file_path_column=None, - ).take(1)[0] - for store in sorted(zarr_out.glob("*.zarr")) - ] - assert len(rows) == 2 - assert sorted(float(row["action"][0][0]) for row in rows) == [0.0, 1.0] - assert not (zarr_out / "_parts").exists() - assert not (zarr_out / "_refiner" / "write_zarr_publish.started").exists() + assert marker.exists() def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( @@ -1827,57 +1227,6 @@ def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( reducer.write_block([DictRow({}, shard_id="reduce")]) -def test_write_zarr_no_overwrite_rejects_part_schema_drift_before_publish( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "sharded-no-overwrite-schema-drift.zarr" - first_worker = "worker-a" - second_worker = "worker-b" - _write_part_zarr( - zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr", - {"data/action": np.asarray([[0.0]], dtype=np.float32)}, - ) - _write_part_zarr( - zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr", - { - "data/action": np.asarray([[1.0]], dtype=np.float32), - "data/state": np.asarray([[2.0]], dtype=np.float32), - }, - ) - reducer = ZarrSink( - str(zarr_out), - arrays={"data/action": "action"}, - overwrite=False, - reduce_to_single_store=False, - ).build_reducer() - assert reducer is not None - runtime = _FinalizedWorkersRuntime( - [ - FinalizedShardWorker( - shard_id="shard-a", - worker_id=first_worker, - global_ordinal=0, - ), - FinalizedShardWorker( - shard_id="shard-b", - worker_id=second_worker, - global_ordinal=1, - ), - ] - ) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="same arrays"): - reducer.write_block([DictRow({}, shard_id="reduce")]) - assert not list(zarr_out.glob("*.zarr")) - - def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: zarr_out = tmp_path / "single-empty-shards.zarr" @@ -1902,10 +1251,10 @@ def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: assert not (zarr_out / "_parts").exists() -def test_write_zarr_single_store_empty_overwrite_ignores_stale_done_marker( +def test_write_zarr_single_store_empty_replace_ignores_stale_done_marker( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "single-empty-overwrite-stale-done.zarr" + zarr_out = tmp_path / "single-empty-replace-stale-done.zarr" ( mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) @@ -2030,7 +1379,6 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() assert second_part.exists() @@ -2068,7 +1416,6 @@ def test_write_zarr_single_store_rejects_part_missing_episode_ends( store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -2106,7 +1453,6 @@ def test_write_zarr_single_store_rejects_missing_finalized_part( store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) row = mdr.read_zarr( @@ -2118,28 +1464,19 @@ def test_write_zarr_single_store_rejects_missing_finalized_part( assert row["episode_ends"].tolist() == [1] -def test_write_zarr_single_store_completed_merge_is_retryable( +def test_write_zarr_single_store_removes_parts_only_on_completion( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "single-completed-retry.zarr" + zarr_out = tmp_path / "single-complete-cleans-parts.zarr" worker_id = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr" _write_part_zarr( - zarr_out, - { - "data/action": np.asarray([[9.0]], dtype=np.float32), - "meta/episode_ends": np.asarray([1], dtype=np.int64), - }, - ) - _write_part_zarr( - zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", + part, { "data/action": np.asarray([[9.0]], dtype=np.float32), "meta/episode_ends": np.asarray([1], dtype=np.int64), }, ) - marker = zarr_out / "_refiner" / "write_zarr.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") runtime = _FinalizedWorkersRuntime( [ FinalizedShardWorker( @@ -2157,13 +1494,15 @@ def test_write_zarr_single_store_completed_merge_is_retryable( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - _ZarrMergeReducerSink( + reducer = _ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, - ).write_block([DictRow({}, shard_id="reduce")]) + ) + reducer.write_block([DictRow({}, shard_id="reduce")]) + assert part.exists() + reducer.on_shard_complete("reduce") row = mdr.read_zarr( zarr_out, @@ -2175,10 +1514,10 @@ def test_write_zarr_single_store_completed_merge_is_retryable( assert not (zarr_out / "_parts").exists() -def test_write_zarr_single_store_zero_shard_overwrite_ignores_stale_done_marker( +def test_write_zarr_single_store_zero_shard_replace_clears_existing_output( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "single-zero-shard-overwrite.zarr" + zarr_out = tmp_path / "single-zero-shard-replace.zarr" _write_part_zarr( zarr_out, { @@ -2186,9 +1525,6 @@ def test_write_zarr_single_store_zero_shard_overwrite_ignores_stale_done_marker( "meta/episode_ends": np.asarray([1], dtype=np.int64), }, ) - marker = zarr_out / "_refiner" / "write_zarr.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") runtime = _FinalizedWorkersRuntime([]) with set_active_run_context( @@ -2203,223 +1539,12 @@ def test_write_zarr_single_store_zero_shard_overwrite_ignores_stale_done_marker( store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) root = _open_test_zarr(zarr_out, mode="r") assert "data/action" not in root -def test_write_zarr_single_store_zero_shard_no_overwrite_rejects_existing_output( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-zero-shard-no-overwrite.zarr" - _write_part_zarr( - zarr_out, - { - "data/action": np.asarray([[9.0]], dtype=np.float32), - "meta/episode_ends": np.asarray([1], dtype=np.int64), - }, - ) - marker = zarr_out / "_refiner" / "write_zarr.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - runtime = _FinalizedWorkersRuntime([]) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - with pytest.raises(ValueError, match="output already exists"): - _ZarrMergeReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - episode_ends_path="meta/episode_ends", - array_chunk_bytes=1024, - overwrite=False, - ).write_block([DictRow({}, shard_id="reduce")]) - - -def test_write_zarr_single_store_zero_shard_no_overwrite_done_marker_is_retryable( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-zero-shard-no-overwrite-done-retry.zarr" - marker = zarr_out / "_refiner" / "write_zarr.done" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - runtime = _FinalizedWorkersRuntime([]) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - _ZarrMergeReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - episode_ends_path="meta/episode_ends", - array_chunk_bytes=1024, - overwrite=False, - ).write_block([DictRow({}, shard_id="reduce")]) - - -def test_write_zarr_single_store_no_overwrite_started_merge_is_retryable( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-no-overwrite-started-retry.zarr" - worker_id = "worker-a" - root = _open_test_zarr(zarr_out, mode="w") - assert dict(root.attrs) == {} - _write_part_zarr( - zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", - { - "data/action": np.asarray([[3.0]], dtype=np.float32), - "meta/episode_ends": np.asarray([1], dtype=np.int64), - }, - ) - marker = zarr_out / "_refiner" / "write_zarr.started" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - runtime = _FinalizedWorkersRuntime( - [ - FinalizedShardWorker( - shard_id="shard-a", - worker_id=worker_id, - global_ordinal=0, - ) - ] - ) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - _ZarrMergeReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - episode_ends_path="meta/episode_ends", - array_chunk_bytes=1024, - overwrite=False, - ).write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out, - arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[3.0]]) - assert row["episode_ends"].tolist() == [1] - - -def test_write_zarr_single_store_no_overwrite_recovers_empty_root_without_marker( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-no-overwrite-empty-root-retry.zarr" - worker_id = "worker-a" - root = _open_test_zarr(zarr_out, mode="w") - assert list(root.array_keys()) == [] - _write_part_zarr( - zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", - { - "data/action": np.asarray([[3.0]], dtype=np.float32), - "meta/episode_ends": np.asarray([1], dtype=np.int64), - }, - ) - runtime = _FinalizedWorkersRuntime( - [ - FinalizedShardWorker( - shard_id="shard-a", - worker_id=worker_id, - global_ordinal=0, - ) - ] - ) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - _ZarrMergeReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - episode_ends_path="meta/episode_ends", - array_chunk_bytes=1024, - overwrite=False, - ).write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out, - arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[3.0]]) - assert row["episode_ends"].tolist() == [1] - - -def test_write_zarr_single_store_no_overwrite_started_merge_replaces_partial_output( - tmp_path: Path, -) -> None: - zarr_out = tmp_path / "single-no-overwrite-started-partial.zarr" - worker_id = "worker-a" - _write_part_zarr( - zarr_out, - {"data/action": np.asarray([[0.0]], dtype=np.float32)}, - ) - _write_part_zarr( - zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr", - { - "data/action": np.asarray([[3.0]], dtype=np.float32), - "meta/episode_ends": np.asarray([1], dtype=np.int64), - }, - ) - marker = zarr_out / "_refiner" / "write_zarr.started" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") - runtime = _FinalizedWorkersRuntime( - [ - FinalizedShardWorker( - shard_id="shard-a", - worker_id=worker_id, - global_ordinal=0, - ) - ] - ) - - with set_active_run_context( - job_id="local", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast(RuntimeLifecycle, runtime), - ): - _ZarrMergeReducerSink( - str(zarr_out), - store_template="{shard_id}__w{worker_id}.zarr", - episode_ends_path="meta/episode_ends", - array_chunk_bytes=1024, - overwrite=False, - ).write_block([DictRow({}, shard_id="reduce")]) - - row = mdr.read_zarr( - zarr_out, - arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, - file_path_column=None, - ).take(1)[0] - np.testing.assert_allclose(row["action"], [[3.0]]) - assert row["episode_ends"].tolist() == [1] - - def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: zarr_out = tmp_path / "single-resume-stable.zarr" worker_id = "original-worker" @@ -2453,7 +1578,6 @@ def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) row = mdr.read_zarr( @@ -2465,10 +1589,10 @@ def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None assert row["episode_ends"].tolist() == [1] -def test_write_zarr_single_store_overwrite_clears_root_attrs( +def test_write_zarr_single_store_replace_clears_root_attrs( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "single-overwrite-attrs.zarr" + zarr_out = tmp_path / "single-replace-attrs.zarr" root = _open_test_zarr(zarr_out, mode="w") root.attrs["task"] = "old" @@ -2480,9 +1604,9 @@ def test_write_zarr_single_store_overwrite_clears_root_attrs( reduce_to_single_store=True, ) .launch_local( - name="zarr-single-overwrite-attrs", + name="zarr-single-replace-attrs", num_workers=1, - rundir=str(tmp_path / "run-overwrite-attrs"), + rundir=str(tmp_path / "run-replace-attrs"), ) ) @@ -2544,7 +1668,6 @@ def test_write_zarr_single_store_rejects_part_dtype_drift( store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, - overwrite=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() assert second_part.exists() From f36cc8653131859b97fca0f01de190694c8fcbc6 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 23:24:04 +0200 Subject: [PATCH 26/39] Simplify Zarr write cleanup --- src/refiner/pipeline/sinks/base.py | 4 + src/refiner/pipeline/sinks/reducer/file.py | 89 +++++---- src/refiner/pipeline/sinks/zarr.py | 200 +++++++++------------ src/refiner/worker/runner.py | 10 ++ tests/pipeline/test_sinks.py | 39 ---- tests/readers/test_zarr_reader.py | 20 +-- tests/worker/test_runner.py | 35 ++++ 7 files changed, 181 insertions(+), 216 deletions(-) diff --git a/src/refiner/pipeline/sinks/base.py b/src/refiner/pipeline/sinks/base.py index b481523e..f63634cc 100644 --- a/src/refiner/pipeline/sinks/base.py +++ b/src/refiner/pipeline/sinks/base.py @@ -82,6 +82,10 @@ def on_shard_complete(self, shard_id: str) -> None: """ del shard_id + def on_shard_finalized(self, shard_id: str) -> None: + """Run cleanup after the shard has been marked complete.""" + del shard_id + def close(self) -> None: """Finalize sink resources after all shard work is complete. diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 5d67fa69..01211c79 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -57,23 +57,6 @@ def _compile_managed_path_pattern(filename_template: str) -> re.Pattern[str]: return re.compile("^" + "".join(parts) + "$") -def _managed_listing_prefix(filename_template: str) -> str: - literal_prefix = "" - for literal_text, field_name, _format_spec, _conversion in Formatter().parse( - filename_template - ): - literal_prefix += literal_text - if field_name is not None: - break - if "/" not in literal_prefix: - return "" - return literal_prefix.rsplit("/", maxsplit=1)[0] - - -def _path_depth(path: str) -> int: - return len([part for part in path.split("/") if part]) - - class FileCleanupReducerSink(BaseSink): """Delete non-finalized deterministic file-sink outputs.""" @@ -136,7 +119,45 @@ def _run_cleanup(self) -> None: for row in get_finalized_workers(stage_index=stage_index - 1) } - listed_paths = list(self._listed_cleanup_paths()) + if not self.recursive or self.assets_subdir is not None: + try: + listed_paths = self.output.find("") + except FileNotFoundError: + listed_paths = [] + else: + literal_prefix = "" + for ( + literal_text, + field_name, + _format_spec, + _conversion, + ) in Formatter().parse(self.filename_template): + literal_prefix += literal_text + if field_name is not None: + break + listing_prefix = ( + "" + if "/" not in literal_prefix + else literal_prefix.rsplit("/", maxsplit=1)[0] + ) + paths = [listing_prefix] + template_depth = len( + [part for part in self.filename_template.split("/") if part] + ) + prefix_depth = len([part for part in listing_prefix.split("/") if part]) + for _ in range(max(1, template_depth - prefix_depth)): + next_paths: list[str] = [] + for path in paths: + try: + next_paths.extend(self.output.ls(path, detail=False)) + except (FileNotFoundError, NotADirectoryError): + continue + paths = next_paths + listed_paths = [ + path + for path in paths + if isinstance(path, str) and not path.rstrip("/").endswith("/.") + ] assets_prefix = ( f"{self.assets_subdir.rstrip('/')}/" @@ -166,10 +187,6 @@ def _run_cleanup(self) -> None: continue managed_path = rel_path - marker_path = None - if rel_path.endswith(".empty"): - managed_path = rel_path[: -len(".empty")] - marker_path = rel_path match = self._managed_path_pattern.fullmatch(managed_path) if match is None and self.recursive: parts = managed_path.split("/") @@ -183,7 +200,7 @@ def _run_cleanup(self) -> None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue - stale_managed_paths.add(marker_path or managed_path) + stale_managed_paths.add(managed_path) for path in sorted(stale_asset_attempts): try: @@ -196,31 +213,5 @@ def _run_cleanup(self) -> None: except FileNotFoundError: continue - def _listed_cleanup_paths(self) -> list[str]: - if not self.recursive or self.assets_subdir is not None: - try: - return self.output.find("") - except FileNotFoundError: - return [] - - listing_prefix = _managed_listing_prefix(self.filename_template) - paths = [listing_prefix] - depth = max( - 1, _path_depth(self.filename_template) - _path_depth(listing_prefix) - ) - for _ in range(depth): - next_paths: list[str] = [] - for path in paths: - try: - next_paths.extend(self.output.ls(path, detail=False)) - except (FileNotFoundError, NotADirectoryError): - continue - paths = next_paths - return [ - path - for path in paths - if isinstance(path, str) and not path.rstrip("/").endswith("/.") - ] - __all__ = ["FileCleanupReducerSink"] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 3df55693..2547ef4e 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -35,12 +35,6 @@ class _ShardStore: row_end: int = 0 -@dataclass(frozen=True) -class _PartStore: - relpath: str - paths: set[str] - - class ZarrSink(BaseSink): def __init__( self, @@ -186,70 +180,66 @@ def _write_row_values( row_videos: list[tuple[str, VideoSource]], lengths: list[int], ) -> None: - store: _ShardStore | None = None - if row_arrays or row_videos: - store = self._store(shard_id) + store = self._store(shard_id) if not lengths: expected_length = None else: expected_length = lengths[0] if any(item != expected_length for item in lengths): raise ValueError("Zarr arrays for one row must have matching lengths") - if store is not None: - rollback_lengths: dict[str, int | None] = {} - for zarr_path in [*row_arrays, *(path for path, _ in row_videos)]: - dataset = store.arrays.get(zarr_path) - rollback_lengths[zarr_path] = ( - None if dataset is None else int(dataset.shape[0]) - ) - try: - for zarr_path, array in row_arrays.items(): - self._validate_array_append(store, zarr_path, array) - for zarr_path, array in row_arrays.items(): - self._append_array(store, zarr_path, array) - for zarr_path, video in row_videos: - video_length = submit( - self._append_video( - store, - zarr_path, - video, - expected_length=expected_length, - ) - ).result() - lengths.append(video_length) - if lengths: - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError( - "Zarr arrays for one row must have matching lengths" - ) - if lengths and self.episode_ends_path is not None: - dataset = store.arrays.get(self.episode_ends_path) - rollback_lengths[self.episode_ends_path] = ( - None if dataset is None else int(dataset.shape[0]) - ) - store.row_end += lengths[0] - self._append_array( + + rollback_lengths: dict[str, int | None] = {} + for zarr_path in [*row_arrays, *(path for path, _ in row_videos)]: + dataset = store.arrays.get(zarr_path) + rollback_lengths[zarr_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) + try: + for zarr_path, array in row_arrays.items(): + self._validate_array_append(store, zarr_path, array) + for zarr_path, array in row_arrays.items(): + self._append_array(store, zarr_path, array) + for zarr_path, video in row_videos: + video_length = submit( + self._append_video( store, - self.episode_ends_path, - np.asarray([store.row_end], dtype=np.int64), + zarr_path, + video, + expected_length=expected_length, ) - except Exception: - for zarr_path, length in rollback_lengths.items(): - if length is None: - self._drop_array(store, zarr_path) - continue - dataset = store.arrays.get(zarr_path) - if dataset is not None: - dataset.resize((length, *dataset.shape[1:])) - if self.episode_ends_path is not None: - dataset = store.arrays.get(self.episode_ends_path) - store.row_end = ( - 0 - if dataset is None or dataset.shape[0] == 0 - else int(dataset[-1]) + ).result() + lengths.append(video_length) + if lengths: + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError( + "Zarr arrays for one row must have matching lengths" ) - raise + if lengths and self.episode_ends_path is not None: + dataset = store.arrays.get(self.episode_ends_path) + rollback_lengths[self.episode_ends_path] = ( + None if dataset is None else int(dataset.shape[0]) + ) + store.row_end += lengths[0] + self._append_array( + store, + self.episode_ends_path, + np.asarray([store.row_end], dtype=np.int64), + ) + except Exception: + for zarr_path, length in rollback_lengths.items(): + if length is None: + self._drop_array(store, zarr_path) + continue + dataset = store.arrays.get(zarr_path) + if dataset is not None: + dataset.resize((length, *dataset.shape[1:])) + if self.episode_ends_path is not None: + dataset = store.arrays.get(self.episode_ends_path) + store.row_end = ( + 0 if dataset is None or dataset.shape[0] == 0 else int(dataset[-1]) + ) + raise def _row_values( self, @@ -383,10 +373,6 @@ def _store(self, shard_id: str) -> _ShardStore: store = _ShardStore( zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") ) - try: - self.output.rm(self._empty_marker_relpath(shard_id)) - except FileNotFoundError: - pass self._stores[relpath] = store return store @@ -397,22 +383,16 @@ def _store_relpath(self, shard_id: str) -> str: worker_id=get_active_worker_token(), ) if self.reduce_to_single_store: - return _part_store_relpath(relpath) + return f"_parts/{relpath}" return relpath def on_shard_complete(self, shard_id: str) -> None: relpath = self._store_relpath(shard_id) if relpath not in self._stores: - try: - self.output.rm(relpath, recursive=True) - except FileNotFoundError: - pass - with self.output.open(self._empty_marker_relpath(shard_id), mode="wb"): - pass - self._stores.pop(relpath, None) + import zarr - def _empty_marker_relpath(self, shard_id: str) -> str: - return self._store_relpath(shard_id) + ".empty" + zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") + self._stores.pop(relpath, None) def _append_array( self, @@ -558,7 +538,6 @@ def __init__( self.episode_ends_path = episode_ends_path self.array_chunk_bytes = array_chunk_bytes self._merged = False - self._remove_parts_on_complete = False @property def counts_output_rows(self) -> bool: @@ -583,7 +562,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: def _merge(self) -> None: if self._merged: return - self._merged = True stage_index = get_active_stage_index() if stage_index is None or stage_index <= 0: @@ -591,7 +569,17 @@ def _merge(self) -> None: "write_zarr_reduce requires an active reducer stage with a prior writer stage" ) - expected_parts = self._expected_parts(stage_index) + expected_parts = [ + "_parts/" + + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1), + ) + ] if not expected_parts: import zarr @@ -600,7 +588,7 @@ def _merge(self) -> None: mode="a", ) _clear_final_group(final) - self._remove_parts_on_complete = True + self._merged = True return parts = self._collect_parts(expected_parts) @@ -615,12 +603,12 @@ def _merge(self) -> None: row_offset = 0 arrays: dict[str, Any] = {} - for part in parts: + for relpath, paths in parts: source = zarr.open_group( - store=_zarr_store(self.output, part.relpath, mode="r"), + store=_zarr_store(self.output, relpath, mode="r"), mode="r", ) - for path in sorted(part.paths): + for path in sorted(paths): source_array = source[path] if path == self.episode_ends_path: if source_array.shape[0] == 0: @@ -665,46 +653,29 @@ def _merge(self) -> None: chunks=getattr(source_array, "chunks", None), compressor=getattr(source_array, "compressor", None), ) + self._merged = True - self._remove_parts_on_complete = True - - def on_shard_complete(self, shard_id: str) -> None: + def on_shard_finalized(self, shard_id: str) -> None: del shard_id - if not self._remove_parts_on_complete: + if not self._merged: return - _remove_parts(self.output, best_effort=True) + _remove_parts(self.output) try: if not self.output.ls("_parts"): self.output.rmdir("_parts") except (FileNotFoundError, OSError, ValueError): pass - def _expected_parts(self, stage_index: int) -> list[str]: - return [ - self._part_relpath(row.shard_id, row.worker_token) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1), - ) - ] - - def _part_relpath(self, shard_id: str, worker_token: str) -> str: - relpath = _render_store_relpath( - self.store_template, - shard_id=shard_id, - worker_id=worker_token, - ) - return _part_store_relpath(relpath) - - def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: + def _collect_parts( + self, expected_parts: Iterable[str] + ) -> list[tuple[str, set[str]]]: import zarr - parts: list[_PartStore] = [] + parts: list[tuple[str, set[str]]] = [] payload_paths: set[str] | None = None schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} for relpath in expected_parts: if not self.output.exists(relpath): - if self.output.exists(f"{relpath}.empty"): - continue raise ValueError(f"Zarr part store is missing: {relpath}") source = zarr.open_group( store=_zarr_store(self.output, relpath, mode="r"), @@ -742,7 +713,7 @@ def _collect_parts(self, expected_parts: Iterable[str]) -> list[_PartStore]: raise ValueError( f"Zarr arrays for {path!r} must have matching dtypes" ) - parts.append(_PartStore(relpath=relpath, paths=source_paths)) + parts.append((relpath, source_paths)) return parts @@ -852,10 +823,6 @@ def _as_array(value: Any) -> np.ndarray: return np.asarray(value) -def _part_store_relpath(relpath: str) -> str: - return f"_parts/{relpath}" - - def _zarr_store(output: DataFolder, path: str = "", *, mode: str = "r"): import zarr @@ -876,14 +843,11 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: yield from _iter_array_paths(item, path) -def _remove_parts(output: DataFolder, *, best_effort: bool = False) -> None: +def _remove_parts(output: DataFolder) -> None: try: output.rm("_parts", recursive=True) except FileNotFoundError: pass - except (OSError, ValueError): - if not best_effort: - raise def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: @@ -893,14 +857,14 @@ def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} for relpath in relpaths: if not output.exists(relpath): - if output.exists(f"{relpath}.empty"): - continue raise ValueError(f"Zarr store is missing: {relpath}") source = zarr.open_group( store=_zarr_store(output, relpath, mode="r"), mode="r", ) source_paths = set(_iter_array_paths(source)) + if not source_paths: + continue if payload_paths is None: payload_paths = source_paths elif source_paths != payload_paths: diff --git a/src/refiner/worker/runner.py b/src/refiner/worker/runner.py index b046a6cc..4468f1ff 100644 --- a/src/refiner/worker/runner.py +++ b/src/refiner/worker/runner.py @@ -130,6 +130,16 @@ def _complete_shard(shard_id: str) -> None: sink.on_shard_complete(shard_id) self.user_metrics_emitter.force_flush_user_metrics() self.runtime_lifecycle.complete(shard) + try: + with set_active_step_index(sink_step_index): + sink.on_shard_finalized(shard_id) + except Exception as e: # noqa: BLE001 + logger.warning( + "post-completion sink cleanup failed shard_id={}: {}: {}", + shard.id, + type(e).__name__, + e, + ) with inflight_lock: inflight_by_id.pop(shard_id, None) source_done_shards.discard(shard_id) diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index cb396ba4..84237e3c 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -965,45 +965,6 @@ def test_file_cleanup_reducer_removes_non_finalized_directories(tmp_path) -> Non assert not loser_dir.exists() -def test_file_cleanup_reducer_removes_non_finalized_empty_markers(tmp_path) -> None: - output_dir = tmp_path / "zarr-cleanup-empty-markers" - shard_id = "0123456789ab" - winner_worker_id = "worker-2" - loser_worker_id = "worker-1" - winner_marker = ( - output_dir / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr.empty" - ) - loser_marker = ( - output_dir / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr.empty" - ) - output_dir.mkdir(parents=True) - winner_marker.write_bytes(b"") - loser_marker.write_bytes(b"") - - reducer = FileCleanupReducerSink( - output_dir, - filename_template="{shard_id}__w{worker_id}.zarr", - reducer_name="cleanup_zarr", - recursive=True, - ) - with set_active_run_context( - job_id="job", - stage_index=1, - worker_id="reducer", - worker_name=None, - runtime_lifecycle=cast( - RuntimeLifecycle, - _FinalizedWorkersRuntime( - [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] - ), - ), - ): - reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) - - assert winner_marker.exists() - assert not loser_marker.exists() - - def test_file_cleanup_reducer_removes_non_finalized_nested_directories( tmp_path, ) -> None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 1bf21105..ecbe687e 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1119,10 +1119,10 @@ def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( ).write_block([DictRow({}, shard_id="reduce")]) -def test_write_zarr_empty_shard_completion_removes_stale_store( +def test_write_zarr_empty_shard_completion_replaces_stale_store( tmp_path: Path, ) -> None: - zarr_out = tmp_path / "empty-shard-removes-stale-store.zarr" + zarr_out = tmp_path / "empty-shard-replaces-stale-store.zarr" worker_id = "worker-a" stale = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr" _write_part_zarr(stale, {"data/action": np.asarray([[9.0]], dtype=np.float32)}) @@ -1140,18 +1140,18 @@ def test_write_zarr_empty_shard_completion_removes_stale_store( reduce_to_single_store=False, ).on_shard_complete("shard-a") - assert not stale.exists() - assert stale.with_name(stale.name + ".empty").exists() + root = _open_test_zarr(stale, mode="r") + assert not list(root.array_keys()) + assert not list(root.group_keys()) -def test_write_zarr_non_reduced_cleanup_keeps_empty_markers_retryable( +def test_write_zarr_non_reduced_cleanup_keeps_empty_stores_retryable( tmp_path: Path, ) -> None: zarr_out = tmp_path / "sharded-empty-cleanup-retry.zarr" worker_id = "worker-a" - marker = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr.empty" - marker.parent.mkdir(parents=True) - marker.write_bytes(b"") + empty_store = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr" + _open_test_zarr(empty_store, mode="w") runtime = _FinalizedWorkersRuntime( [FinalizedShardWorker(shard_id="shard-a", worker_id=worker_id)] @@ -1169,7 +1169,7 @@ def test_write_zarr_non_reduced_cleanup_keeps_empty_markers_retryable( store_template="{shard_id}__w{worker_id}.zarr", ).write_block([DictRow({}, shard_id="reduce")]) - assert marker.exists() + assert empty_store.exists() def test_write_zarr_rejects_sharded_schema_drift_after_cleanup( @@ -1502,7 +1502,7 @@ def test_write_zarr_single_store_removes_parts_only_on_completion( ) reducer.write_block([DictRow({}, shard_id="reduce")]) assert part.exists() - reducer.on_shard_complete("reduce") + reducer.on_shard_finalized("reduce") row = mdr.read_zarr( zarr_out, diff --git a/tests/worker/test_runner.py b/tests/worker/test_runner.py index c2f6249d..50661272 100644 --- a/tests/worker/test_runner.py +++ b/tests/worker/test_runner.py @@ -467,6 +467,41 @@ def test_worker_completes_shards_only_after_sink_drain() -> None: assert runtime_lifecycle.completed_ids == [shard.id] +def test_worker_runs_post_completion_sink_hook_after_runtime_complete() -> None: + shard = _shard("p", 0, 1) + events: list[str] = [] + + class _OrderedRuntimeLifecycle(_FakeRuntimeLifecycle): + def complete(self, shard: Shard) -> None: + super().complete(shard) + events.append("runtime_complete") + + class _OrderedSink(_RecordingSink): + def on_shard_complete(self, shard_id: str) -> None: + super().on_shard_complete(shard_id) + events.append("sink_complete") + + def on_shard_finalized(self, shard_id: str) -> None: + events.append("sink_finalized") + + runtime_lifecycle = _OrderedRuntimeLifecycle([shard]) + sink = _OrderedSink() + worker = Worker( + pipeline=RefinerPipeline( + source=_FakeReader({shard.id: [DictRow({"x": 1})]}) + ).with_sink(sink), + job_id="job", + stage_index=0, + worker_id=runtime_lifecycle.worker_id, + runtime_lifecycle=runtime_lifecycle, + ) + + stats = worker.run() + + assert stats.completed == 1 + assert events == ["sink_complete", "runtime_complete", "sink_finalized"] + + def test_worker_metrics_use_correct_step_indexes_for_all_block_types( monkeypatch: pytest.MonkeyPatch, tmp_path ) -> None: From eba19e0c3e24b2b7d8d2b62daecc822a6c027296 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:15:15 +0200 Subject: [PATCH 27/39] Remove recursive cleanup mode --- src/refiner/pipeline/sinks/reducer/file.py | 93 ++++++++++------------ src/refiner/pipeline/sinks/zarr.py | 1 - tests/pipeline/test_sinks.py | 17 ++-- 3 files changed, 51 insertions(+), 60 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 01211c79..7a895ac7 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -67,13 +67,11 @@ def __init__( filename_template: str, reducer_name: str, assets_subdir: str | None = None, - recursive: bool = False, ) -> None: self.output = DataFolder.resolve(output) self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir - self.recursive = recursive self._managed_path_pattern = _compile_managed_path_pattern(filename_template) self._cleanup_ran = False @@ -92,8 +90,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: } if self.assets_subdir is not None: args["assets_subdir"] = self.assets_subdir - if self.recursive: - args["recursive"] = True return ( self.reducer_name, "writer", @@ -119,45 +115,47 @@ def _run_cleanup(self) -> None: for row in get_finalized_workers(stage_index=stage_index - 1) } - if not self.recursive or self.assets_subdir is not None: + literal_prefix = "" + for ( + literal_text, + field_name, + _format_spec, + _conversion, + ) in Formatter().parse(self.filename_template): + literal_prefix += literal_text + if field_name is not None: + break + listing_prefix = ( + "" + if "/" not in literal_prefix + else literal_prefix.rsplit("/", maxsplit=1)[0] + ) + paths = [listing_prefix] + template_depth = len( + [part for part in self.filename_template.split("/") if part] + ) + prefix_depth = len([part for part in listing_prefix.split("/") if part]) + for _ in range(max(1, template_depth - prefix_depth)): + next_paths: list[str] = [] + for path in paths: + try: + next_paths.extend(self.output.ls(path, detail=False)) + except (FileNotFoundError, NotADirectoryError): + continue + paths = next_paths + listed_paths = [ + path + for path in paths + if isinstance(path, str) and not path.rstrip("/").endswith("/.") + ] + + if self.assets_subdir is not None: try: - listed_paths = self.output.find("") + asset_paths = self.output.find(self.assets_subdir) except FileNotFoundError: - listed_paths = [] + asset_paths = [] else: - literal_prefix = "" - for ( - literal_text, - field_name, - _format_spec, - _conversion, - ) in Formatter().parse(self.filename_template): - literal_prefix += literal_text - if field_name is not None: - break - listing_prefix = ( - "" - if "/" not in literal_prefix - else literal_prefix.rsplit("/", maxsplit=1)[0] - ) - paths = [listing_prefix] - template_depth = len( - [part for part in self.filename_template.split("/") if part] - ) - prefix_depth = len([part for part in listing_prefix.split("/") if part]) - for _ in range(max(1, template_depth - prefix_depth)): - next_paths: list[str] = [] - for path in paths: - try: - next_paths.extend(self.output.ls(path, detail=False)) - except (FileNotFoundError, NotADirectoryError): - continue - paths = next_paths - listed_paths = [ - path - for path in paths - if isinstance(path, str) and not path.rstrip("/").endswith("/.") - ] + asset_paths = [] assets_prefix = ( f"{self.assets_subdir.rstrip('/')}/" @@ -170,7 +168,7 @@ def _run_cleanup(self) -> None: # Extra template fields are treated as structure only. Authority is decided # solely from the finalized (shard_id, worker_id) pair extracted from each # managed path. - for rel_path in listed_paths: + for rel_path in asset_paths: if not isinstance(rel_path, str) or not rel_path or rel_path == ".": continue if assets_prefix is not None and ( @@ -186,16 +184,11 @@ def _run_cleanup(self) -> None: stale_asset_attempts.add(asset_path) continue + for rel_path in listed_paths: + if not isinstance(rel_path, str) or not rel_path or rel_path == ".": + continue managed_path = rel_path match = self._managed_path_pattern.fullmatch(managed_path) - if match is None and self.recursive: - parts = managed_path.split("/") - for index in range(1, len(parts)): - candidate = "/".join(parts[:index]) - match = self._managed_path_pattern.fullmatch(candidate) - if match is not None: - managed_path = candidate - break if match is None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: @@ -209,7 +202,7 @@ def _run_cleanup(self) -> None: continue for path in sorted(stale_managed_paths): try: - self.output.rm(path, recursive=self.recursive) + self.output.rm(path, recursive=True) except FileNotFoundError: continue diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 2547ef4e..e1ba7e02 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -468,7 +468,6 @@ def __init__(self, output: DataFolderLike, *, store_template: str) -> None: output=self.output, filename_template=store_template, reducer_name="write_zarr_reduce", - recursive=True, ) @property diff --git a/tests/pipeline/test_sinks.py b/tests/pipeline/test_sinks.py index 84237e3c..07ef0d55 100644 --- a/tests/pipeline/test_sinks.py +++ b/tests/pipeline/test_sinks.py @@ -945,7 +945,6 @@ def test_file_cleanup_reducer_removes_non_finalized_directories(tmp_path) -> Non output_dir, filename_template="{shard_id}__w{worker_id}.zarr", reducer_name="cleanup_zarr", - recursive=True, ) with set_active_run_context( job_id="job", @@ -987,7 +986,6 @@ def test_file_cleanup_reducer_removes_non_finalized_nested_directories( output_dir, filename_template="split/{shard_id}__w{worker_id}.zarr", reducer_name="cleanup_zarr", - recursive=True, ) with set_active_run_context( job_id="job", @@ -1027,7 +1025,6 @@ def test_file_cleanup_reducer_removes_dynamic_nested_directories(tmp_path) -> No output_dir, filename_template="split/{shard_id}/{worker_id}.zarr", reducer_name="cleanup_zarr", - recursive=True, ) with set_active_run_context( job_id="job", @@ -1047,7 +1044,7 @@ def test_file_cleanup_reducer_removes_dynamic_nested_directories(tmp_path) -> No assert not loser_dir.exists() -def test_file_cleanup_reducer_ignores_files_during_recursive_traversal( +def test_file_cleanup_reducer_ignores_files_during_template_listing( tmp_path, ) -> None: output_dir = tmp_path / "zarr-cleanup-mixed" @@ -1070,7 +1067,6 @@ def test_file_cleanup_reducer_ignores_files_during_recursive_traversal( output_dir, filename_template="split/{shard_id}/{worker_id}.zarr", reducer_name="cleanup_zarr", - recursive=True, ) with set_active_run_context( job_id="job", @@ -1091,7 +1087,7 @@ def test_file_cleanup_reducer_ignores_files_during_recursive_traversal( assert (output_dir / "split" / "README.txt").read_text(encoding="utf-8") == "notes" -def test_file_cleanup_reducer_propagates_recursive_listing_errors( +def test_file_cleanup_reducer_propagates_template_listing_errors( tmp_path, monkeypatch ) -> None: output_dir = tmp_path / "zarr-cleanup-list-error" @@ -1100,7 +1096,6 @@ def test_file_cleanup_reducer_propagates_recursive_listing_errors( output_dir, filename_template="{shard_id}__w{worker_id}.zarr", reducer_name="cleanup_zarr", - recursive=True, ) def fail_ls(*_args, **_kwargs): @@ -1141,8 +1136,12 @@ def test_file_cleanup_reducer_tolerates_duplicate_listed_paths( ) monkeypatch.setattr( reducer.output, - "find", - lambda _: [winner_path.name, winner_path.name, loser_path.name], + "ls", + lambda *_args, **_kwargs: [ + winner_path.name, + winner_path.name, + loser_path.name, + ], ) with set_active_run_context( From 0114d8f8a6c6f895534bf7a495b7e3b9912ffaa8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:24:36 +0200 Subject: [PATCH 28/39] Simplify file cleanup reducer --- src/refiner/pipeline/sinks/reducer/file.py | 95 ++++++++-------------- 1 file changed, 33 insertions(+), 62 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 7a895ac7..4f818e83 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -20,7 +20,7 @@ _DEFAULT_FIELD_PATTERN = r"[^/]+" -def _compile_managed_path_pattern(filename_template: str) -> re.Pattern[str]: +def _compile_output_path_pattern(filename_template: str) -> re.Pattern[str]: parts: list[str] = [] seen_fields: set[str] = set() @@ -72,7 +72,7 @@ def __init__( self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir - self._managed_path_pattern = _compile_managed_path_pattern(filename_template) + self._output_path_pattern = _compile_output_path_pattern(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -115,27 +115,20 @@ def _run_cleanup(self) -> None: for row in get_finalized_workers(stage_index=stage_index - 1) } + template_parts = [part for part in self.filename_template.split("/") if part] literal_prefix = "" - for ( - literal_text, - field_name, - _format_spec, - _conversion, - ) in Formatter().parse(self.filename_template): + for literal_text, field_name, _format_spec, _conversion in Formatter().parse( + self.filename_template + ): literal_prefix += literal_text if field_name is not None: break listing_prefix = ( - "" - if "/" not in literal_prefix - else literal_prefix.rsplit("/", maxsplit=1)[0] + "" if "/" not in literal_prefix else literal_prefix.rsplit("/", 1)[0] ) paths = [listing_prefix] - template_depth = len( - [part for part in self.filename_template.split("/") if part] - ) - prefix_depth = len([part for part in listing_prefix.split("/") if part]) - for _ in range(max(1, template_depth - prefix_depth)): + prefix_parts = [part for part in listing_prefix.split("/") if part] + for _ in range(max(1, len(template_parts) - len(prefix_parts))): next_paths: list[str] = [] for path in paths: try: @@ -143,64 +136,42 @@ def _run_cleanup(self) -> None: except (FileNotFoundError, NotADirectoryError): continue paths = next_paths - listed_paths = [ - path - for path in paths - if isinstance(path, str) and not path.rstrip("/").endswith("/.") - ] + + paths_to_delete: set[str] = set() + # Extra template fields are structure only. Authority is decided from + # the finalized (shard_id, worker_id) pair extracted from the path. + for rel_path in paths: + if not isinstance(rel_path, str) or not rel_path or rel_path == ".": + continue + if rel_path.rstrip("/").endswith("/."): + continue + match = self._output_path_pattern.fullmatch(rel_path) + if match is None: + continue + if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: + continue + paths_to_delete.add(rel_path) if self.assets_subdir is not None: + asset_prefix = f"{self.assets_subdir.rstrip('/')}/" try: asset_paths = self.output.find(self.assets_subdir) except FileNotFoundError: asset_paths = [] - else: - asset_paths = [] - - assets_prefix = ( - f"{self.assets_subdir.rstrip('/')}/" - if self.assets_subdir is not None - else None - ) - - stale_asset_attempts: set[str] = set() - stale_managed_paths: set[str] = set() - # Extra template fields are treated as structure only. Authority is decided - # solely from the finalized (shard_id, worker_id) pair extracted from each - # managed path. - for rel_path in asset_paths: - if not isinstance(rel_path, str) or not rel_path or rel_path == ".": - continue - if assets_prefix is not None and ( - rel_path == self.assets_subdir or rel_path.startswith(assets_prefix) - ): - attempt_dir = rel_path[len(assets_prefix) :].split("/", maxsplit=1)[0] + for rel_path in asset_paths: + if not isinstance(rel_path, str) or not rel_path.startswith( + asset_prefix + ): + continue + attempt_dir = rel_path[len(asset_prefix) :].split("/", maxsplit=1)[0] match = ASSET_ATTEMPT_DIR_RE.fullmatch(attempt_dir) if match is None: continue - asset_path = f"{assets_prefix}{attempt_dir}" if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: continue - stale_asset_attempts.add(asset_path) - continue + paths_to_delete.add(f"{asset_prefix}{attempt_dir}") - for rel_path in listed_paths: - if not isinstance(rel_path, str) or not rel_path or rel_path == ".": - continue - managed_path = rel_path - match = self._managed_path_pattern.fullmatch(managed_path) - if match is None: - continue - if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: - continue - stale_managed_paths.add(managed_path) - - for path in sorted(stale_asset_attempts): - try: - self.output.rm(path, recursive=True) - except FileNotFoundError: - continue - for path in sorted(stale_managed_paths): + for path in sorted(paths_to_delete): try: self.output.rm(path, recursive=True) except FileNotFoundError: From f28e3f7350834c29e5ff8fe4641d07ee549ef3ba Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:29:59 +0200 Subject: [PATCH 29/39] Prune file cleanup listing --- src/refiner/pipeline/sinks/reducer/file.py | 67 ++++++++++++---------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 4f818e83..9f5b732e 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -20,32 +20,36 @@ _DEFAULT_FIELD_PATTERN = r"[^/]+" -def _compile_output_path_pattern(filename_template: str) -> re.Pattern[str]: - parts: list[str] = [] +def _compile_output_path_patterns(filename_template: str) -> list[re.Pattern[str]]: + path_parts: list[str] = [] + patterns: list[re.Pattern[str]] = [] seen_fields: set[str] = set() - for literal_text, field_name, format_spec, conversion in Formatter().parse( - filename_template - ): - parts.append(re.escape(literal_text)) - if field_name is None: - continue - if conversion is not None or format_spec: - raise ValueError( - "filename_template reducer matching only supports plain " - "named fields without conversion or format specifiers" - ) - if not field_name.isidentifier(): - raise ValueError( - "filename_template reducer matching only supports plain named fields" - ) - if field_name in seen_fields: - # Repeated fields in the template must resolve to the same path segment. - parts.append(f"(?P={field_name})") - continue - pattern = _FIELD_PATTERNS.get(field_name, _DEFAULT_FIELD_PATTERN) - parts.append(f"(?P<{field_name}>{pattern})") - seen_fields.add(field_name) + for segment in (part for part in filename_template.split("/") if part): + segment_parts: list[str] = [] + for literal_text, field_name, format_spec, conversion in Formatter().parse( + segment + ): + segment_parts.append(re.escape(literal_text)) + if field_name is None: + continue + if conversion is not None or format_spec: + raise ValueError( + "filename_template reducer matching only supports plain " + "named fields without conversion or format specifiers" + ) + if not field_name.isidentifier(): + raise ValueError( + "filename_template reducer matching only supports plain named fields" + ) + if field_name in seen_fields: + segment_parts.append(f"(?P={field_name})") + continue + pattern = _FIELD_PATTERNS.get(field_name, _DEFAULT_FIELD_PATTERN) + segment_parts.append(f"(?P<{field_name}>{pattern})") + seen_fields.add(field_name) + path_parts.append("".join(segment_parts)) + patterns.append(re.compile("^" + "/".join(path_parts) + "$")) missing_fields = sorted(_REQUIRED_TEMPLATE_FIELDS.difference(seen_fields)) if missing_fields: @@ -54,7 +58,7 @@ def _compile_output_path_pattern(filename_template: str) -> re.Pattern[str]: + ", ".join(f"{{{field_name}}}" for field_name in missing_fields) ) - return re.compile("^" + "".join(parts) + "$") + return patterns class FileCleanupReducerSink(BaseSink): @@ -72,7 +76,7 @@ def __init__( self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir - self._output_path_pattern = _compile_output_path_pattern(filename_template) + self._output_path_patterns = _compile_output_path_patterns(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -115,7 +119,6 @@ def _run_cleanup(self) -> None: for row in get_finalized_workers(stage_index=stage_index - 1) } - template_parts = [part for part in self.filename_template.split("/") if part] literal_prefix = "" for literal_text, field_name, _format_spec, _conversion in Formatter().parse( self.filename_template @@ -128,11 +131,15 @@ def _run_cleanup(self) -> None: ) paths = [listing_prefix] prefix_parts = [part for part in listing_prefix.split("/") if part] - for _ in range(max(1, len(template_parts) - len(prefix_parts))): + for pattern in self._output_path_patterns[len(prefix_parts) :]: next_paths: list[str] = [] for path in paths: try: - next_paths.extend(self.output.ls(path, detail=False)) + next_paths.extend( + item + for item in self.output.ls(path, detail=False) + if isinstance(item, str) and pattern.fullmatch(item) + ) except (FileNotFoundError, NotADirectoryError): continue paths = next_paths @@ -145,7 +152,7 @@ def _run_cleanup(self) -> None: continue if rel_path.rstrip("/").endswith("/."): continue - match = self._output_path_pattern.fullmatch(rel_path) + match = self._output_path_patterns[-1].fullmatch(rel_path) if match is None: continue if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: From 8476546e601e69b2dbef4b8f90e2ec40060f21f5 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:46:18 +0200 Subject: [PATCH 30/39] Split Zarr reducer sinks --- .../pipeline/sinks/reducer/__init__.py | 6 + src/refiner/pipeline/sinks/reducer/zarr.py | 333 +++++++++++++++++ src/refiner/pipeline/sinks/zarr.py | 345 +----------------- tests/readers/test_zarr_reader.py | 26 +- 4 files changed, 369 insertions(+), 341 deletions(-) create mode 100644 src/refiner/pipeline/sinks/reducer/zarr.py diff --git a/src/refiner/pipeline/sinks/reducer/__init__.py b/src/refiner/pipeline/sinks/reducer/__init__.py index 748cc262..c9c65905 100644 --- a/src/refiner/pipeline/sinks/reducer/__init__.py +++ b/src/refiner/pipeline/sinks/reducer/__init__.py @@ -1,7 +1,13 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.pipeline.sinks.reducer.lerobot import LeRobotMetaReduceSink +from refiner.pipeline.sinks.reducer.zarr import ( + ZarrCleanupReducerSink, + ZarrMergeReducerSink, +) __all__ = [ "FileCleanupReducerSink", "LeRobotMetaReduceSink", + "ZarrCleanupReducerSink", + "ZarrMergeReducerSink", ] diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py new file mode 100644 index 00000000..2a20862c --- /dev/null +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import numpy as np + +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.sinks.base import BaseSink +from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink +from refiner.pipeline.sinks.zarr import ( + _append_zarr_array, + _batch_length, + _render_store_relpath, + _zarr_store, +) +from refiner.utils import check_required_dependencies +from refiner.worker.context import get_active_stage_index, get_finalized_workers +from refiner.worker.lifecycle import sort_finalized_workers + + +class ZarrCleanupReducerSink(BaseSink): + def __init__(self, output: DataFolderLike, *, store_template: str) -> None: + self.output = DataFolder.resolve(output) + self.store_template = store_template + self._cleanup = FileCleanupReducerSink( + output=self.output, + filename_template=store_template, + reducer_name="write_zarr_reduce", + ) + + @property + def counts_output_rows(self) -> bool: + return False + + def write_shard_block(self, shard_id: str, block: Block) -> None: + self._cleanup.write_shard_block(shard_id, block) + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + relpaths = [ + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1) + ) + ] + _validate_zarr_stores(self.output, relpaths) + _remove_parts(self.output) + self._clear_root_payload_except(relpaths) + + def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: + import zarr + + keep_paths = set(relpaths) + try: + root = zarr.open_group(store=_zarr_store(self.output, "", mode="r+")) + except Exception: + return + + def clear_group(group: Any, prefix: str = "") -> None: + group_keys = set(group.group_keys()) + for key in sorted({*group.array_keys(), *group_keys}): + path = f"{prefix}/{key}" if prefix else key + if path == "_refiner" or path.startswith("_refiner/"): + continue + if path in keep_paths: + continue + if any(keep_path.startswith(f"{path}/") for keep_path in keep_paths): + if key in group_keys: + clear_group(group[key], path) + continue + del group[key] + group.attrs.clear() + + clear_group(root) + + +class ZarrMergeReducerSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + store_template: str, + episode_ends_path: str | None, + array_chunk_bytes: int, + ) -> None: + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + self.output = DataFolder.resolve(output) + self.store_template = store_template + self.episode_ends_path = episode_ends_path + self.array_chunk_bytes = array_chunk_bytes + self._merged = False + + @property + def counts_output_rows(self) -> bool: + return False + + def write_shard_block(self, shard_id: str, block: Block) -> None: + del shard_id, block + self._merge() + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr_reduce", + "writer", + { + "path": self.output.abs_path(), + "store_template": self.store_template, + "array_chunk_bytes": self.array_chunk_bytes, + "reduce_to_single_store": True, + }, + ) + + def _merge(self) -> None: + if self._merged: + return + + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + + expected_parts = [ + "_parts/" + + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1), + ) + ] + if not expected_parts: + import zarr + + final = zarr.open_group( + store=_zarr_store(self.output, "", mode="a"), + mode="a", + ) + _clear_final_group(final) + self._merged = True + return + + parts = self._collect_parts(expected_parts) + + import zarr + + final = zarr.open_group( + store=_zarr_store(self.output, "", mode="a"), + mode="a", + ) + _clear_final_group(final) + + row_offset = 0 + arrays: dict[str, Any] = {} + for relpath, paths in parts: + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + for path in sorted(paths): + source_array = source[path] + if path == self.episode_ends_path: + if source_array.shape[0] == 0: + continue + part_last = row_offset + batch_size = _batch_length( + source_array, + self.array_chunk_bytes, + ) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end], dtype=np.int64) + _append_zarr_array( + final, + arrays, + path, + values + row_offset, + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + part_last = int(values[-1]) + row_offset += part_last + continue + batch_size = _batch_length(source_array, self.array_chunk_bytes) + if source_array.shape[0] == 0: + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[:0]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + continue + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[start:end]), + chunks=getattr(source_array, "chunks", None), + compressor=getattr(source_array, "compressor", None), + ) + self._merged = True + + def on_shard_finalized(self, shard_id: str) -> None: + del shard_id + if not self._merged: + return + _remove_parts(self.output) + try: + if not self.output.ls("_parts"): + self.output.rmdir("_parts") + except (FileNotFoundError, OSError, ValueError): + pass + + def _collect_parts( + self, expected_parts: Iterable[str] + ) -> list[tuple[str, set[str]]]: + import zarr + + parts: list[tuple[str, set[str]]] = [] + payload_paths: set[str] | None = None + schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + for relpath in expected_parts: + if not self.output.exists(relpath): + raise ValueError(f"Zarr part store is missing: {relpath}") + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + source_paths = set(_iter_array_paths(source)) + if not source_paths: + continue + source_payload_paths = { + path for path in source_paths if path != self.episode_ends_path + } + if ( + self.episode_ends_path is not None + and source_payload_paths + and self.episode_ends_path not in source_paths + ): + raise ValueError( + f"Zarr part stores must contain {self.episode_ends_path!r}" + ) + if payload_paths is None: + payload_paths = source_payload_paths + elif source_payload_paths != payload_paths: + raise ValueError( + "Zarr part stores must contain the same payload arrays" + ) + for path in source_paths: + source_array = source[path] + schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) + previous = schemas.setdefault(path, schema) + if previous != schema: + if previous[0] != schema[0]: + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + raise ValueError( + f"Zarr arrays for {path!r} must have matching dtypes" + ) + parts.append((relpath, source_paths)) + return parts + + +def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: + for name, item in group.items(): + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + +def _remove_parts(output: DataFolder) -> None: + try: + output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + + +def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: + import zarr + + payload_paths: set[str] | None = None + schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + for relpath in relpaths: + if not output.exists(relpath): + raise ValueError(f"Zarr store is missing: {relpath}") + source = zarr.open_group( + store=_zarr_store(output, relpath, mode="r"), + mode="r", + ) + source_paths = set(_iter_array_paths(source)) + if not source_paths: + continue + if payload_paths is None: + payload_paths = source_paths + elif source_paths != payload_paths: + raise ValueError("Zarr stores must contain the same arrays") + for path in source_paths: + source_array = source[path] + schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) + previous = schemas.setdefault(path, schema) + if previous != schema: + if previous[0] != schema[0]: + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + + +def _clear_final_group(group: Any) -> None: + for key in sorted({*group.array_keys(), *group.group_keys()}): + if key != "_parts": + del group[key] + group.attrs.clear() + + +__all__ = ["ZarrCleanupReducerSink", "ZarrMergeReducerSink"] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index e1ba7e02..34c8fa62 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -13,23 +13,17 @@ from refiner.pipeline.data.block import Block from refiner.pipeline.data.row import Row from refiner.pipeline.sinks.base import BaseSink -from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies from refiner.video import VideoFrameArray, VideoSource -from refiner.worker.context import ( - get_active_stage_index, - get_active_worker_token, - get_finalized_workers, -) -from refiner.worker.lifecycle import sort_finalized_workers +from refiner.worker.context import get_active_worker_token _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 _MAX_INITIAL_CHUNK_ROWS = 1024 @dataclass -class _ShardStore: +class _ZarrWriteState: root: Any arrays: dict[str, Any] = field(default_factory=dict) row_end: int = 0 @@ -64,12 +58,12 @@ def __init__( self.video_frame_batch_size = video_frame_batch_size self.array_chunk_bytes = array_chunk_bytes self.reduce_to_single_store = reduce_to_single_store - self._stores: dict[str, _ShardStore] = {} + self._stores: dict[str, _ZarrWriteState] = {} self._default_arrays: dict[str, str] | None = None def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 - pending_store: _ShardStore | None = None + pending_store: _ZarrWriteState | None = None pending_arrays: dict[str, list[np.ndarray]] = {} pending_lengths: list[int] = [] pending_bytes = 0 @@ -265,7 +259,7 @@ def _row_values( async def _append_video( self, - store: _ShardStore, + store: _ZarrWriteState, path: str, video: VideoSource, *, @@ -363,14 +357,14 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: ) return self._default_arrays - def _store(self, shard_id: str) -> _ShardStore: + def _store(self, shard_id: str) -> _ZarrWriteState: relpath = self._store_relpath(shard_id) store = self._stores.get(relpath) if store is not None: return store import zarr - store = _ShardStore( + store = _ZarrWriteState( zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") ) self._stores[relpath] = store @@ -396,7 +390,7 @@ def on_shard_complete(self, shard_id: str) -> None: def _append_array( self, - store: _ShardStore, + store: _ZarrWriteState, path: str, array: np.ndarray, *, @@ -412,7 +406,7 @@ def _append_array( def _validate_array_append( self, - store: _ShardStore, + store: _ZarrWriteState, path: str, array: np.ndarray, ) -> None: @@ -421,7 +415,7 @@ def _validate_array_append( return _validate_array_schema(path, dataset, array) - def _drop_array(self, store: _ShardStore, path: str) -> None: + def _drop_array(self, store: _ZarrWriteState, path: str) -> None: store.arrays.pop(path, None) try: del store.root[path] @@ -447,275 +441,24 @@ def describe(self) -> tuple[str, str, dict[str, object]]: ) def build_reducer(self) -> BaseSink | None: + from refiner.pipeline.sinks.reducer.zarr import ( + ZarrCleanupReducerSink, + ZarrMergeReducerSink, + ) + if self.reduce_to_single_store: - return _ZarrMergeReducerSink( + return ZarrMergeReducerSink( output=self.output, store_template=self.store_template, episode_ends_path=self.episode_ends_path, array_chunk_bytes=self.array_chunk_bytes, ) - return _ZarrCleanupReducerSink( + return ZarrCleanupReducerSink( output=self.output, store_template=self.store_template, ) -class _ZarrCleanupReducerSink(BaseSink): - def __init__(self, output: DataFolderLike, *, store_template: str) -> None: - self.output = DataFolder.resolve(output) - self.store_template = store_template - self._cleanup = FileCleanupReducerSink( - output=self.output, - filename_template=store_template, - reducer_name="write_zarr_reduce", - ) - - @property - def counts_output_rows(self) -> bool: - return False - - def write_shard_block(self, shard_id, block) -> None: - self._cleanup.write_shard_block(shard_id, block) - stage_index = get_active_stage_index() - if stage_index is None or stage_index <= 0: - raise ValueError( - "write_zarr_reduce requires an active reducer stage with a prior writer stage" - ) - relpaths = [ - _render_store_relpath( - self.store_template, - shard_id=row.shard_id, - worker_id=row.worker_token, - ) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1) - ) - ] - _validate_zarr_stores(self.output, relpaths) - _remove_parts(self.output) - self._clear_root_payload_except(relpaths) - - def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: - import zarr - - keep_paths = set(relpaths) - try: - root = zarr.open_group(store=_zarr_store(self.output, "", mode="r+")) - except Exception: - return - - def clear_group(group: Any, prefix: str = "") -> None: - group_keys = set(group.group_keys()) - for key in sorted({*group.array_keys(), *group_keys}): - path = f"{prefix}/{key}" if prefix else key - if path == "_refiner" or path.startswith("_refiner/"): - continue - if path in keep_paths: - continue - if any(keep_path.startswith(f"{path}/") for keep_path in keep_paths): - if key in group_keys: - clear_group(group[key], path) - continue - del group[key] - group.attrs.clear() - - clear_group(root) - - -class _ZarrMergeReducerSink(BaseSink): - def __init__( - self, - output: DataFolderLike, - *, - store_template: str, - episode_ends_path: str | None, - array_chunk_bytes: int, - ) -> None: - check_required_dependencies("write_zarr", ["zarr"], dist="zarr") - self.output = DataFolder.resolve(output) - self.store_template = store_template - self.episode_ends_path = episode_ends_path - self.array_chunk_bytes = array_chunk_bytes - self._merged = False - - @property - def counts_output_rows(self) -> bool: - return False - - def write_shard_block(self, shard_id, block) -> None: - del shard_id, block - self._merge() - - def describe(self) -> tuple[str, str, dict[str, object]]: - return ( - "write_zarr_reduce", - "writer", - { - "path": self.output.abs_path(), - "store_template": self.store_template, - "array_chunk_bytes": self.array_chunk_bytes, - "reduce_to_single_store": True, - }, - ) - - def _merge(self) -> None: - if self._merged: - return - - stage_index = get_active_stage_index() - if stage_index is None or stage_index <= 0: - raise ValueError( - "write_zarr_reduce requires an active reducer stage with a prior writer stage" - ) - - expected_parts = [ - "_parts/" - + _render_store_relpath( - self.store_template, - shard_id=row.shard_id, - worker_id=row.worker_token, - ) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1), - ) - ] - if not expected_parts: - import zarr - - final = zarr.open_group( - store=_zarr_store(self.output, "", mode="a"), - mode="a", - ) - _clear_final_group(final) - self._merged = True - return - - parts = self._collect_parts(expected_parts) - - import zarr - - final = zarr.open_group( - store=_zarr_store(self.output, "", mode="a"), - mode="a", - ) - _clear_final_group(final) - - row_offset = 0 - arrays: dict[str, Any] = {} - for relpath, paths in parts: - source = zarr.open_group( - store=_zarr_store(self.output, relpath, mode="r"), - mode="r", - ) - for path in sorted(paths): - source_array = source[path] - if path == self.episode_ends_path: - if source_array.shape[0] == 0: - continue - part_last = row_offset - batch_size = _batch_length( - source_array, - self.array_chunk_bytes, - ) - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) - values = np.asarray(source_array[start:end], dtype=np.int64) - _append_zarr_array( - final, - arrays, - path, - values + row_offset, - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) - part_last = int(values[-1]) - row_offset += part_last - continue - batch_size = _batch_length(source_array, self.array_chunk_bytes) - if source_array.shape[0] == 0: - _append_zarr_array( - final, - arrays, - path, - np.asarray(source_array[:0]), - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) - continue - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) - _append_zarr_array( - final, - arrays, - path, - np.asarray(source_array[start:end]), - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) - self._merged = True - - def on_shard_finalized(self, shard_id: str) -> None: - del shard_id - if not self._merged: - return - _remove_parts(self.output) - try: - if not self.output.ls("_parts"): - self.output.rmdir("_parts") - except (FileNotFoundError, OSError, ValueError): - pass - - def _collect_parts( - self, expected_parts: Iterable[str] - ) -> list[tuple[str, set[str]]]: - import zarr - - parts: list[tuple[str, set[str]]] = [] - payload_paths: set[str] | None = None - schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} - for relpath in expected_parts: - if not self.output.exists(relpath): - raise ValueError(f"Zarr part store is missing: {relpath}") - source = zarr.open_group( - store=_zarr_store(self.output, relpath, mode="r"), - mode="r", - ) - source_paths = set(_iter_array_paths(source)) - if not source_paths: - continue - source_payload_paths = { - path for path in source_paths if path != self.episode_ends_path - } - if ( - self.episode_ends_path is not None - and source_payload_paths - and self.episode_ends_path not in source_paths - ): - raise ValueError( - f"Zarr part stores must contain {self.episode_ends_path!r}" - ) - if payload_paths is None: - payload_paths = source_payload_paths - elif source_payload_paths != payload_paths: - raise ValueError( - "Zarr part stores must contain the same payload arrays" - ) - for path in source_paths: - source_array = source[path] - schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) - previous = schemas.setdefault(path, schema) - if previous != schema: - if previous[0] != schema[0]: - raise ValueError( - f"Zarr arrays for {path!r} must have matching trailing shapes" - ) - raise ValueError( - f"Zarr arrays for {path!r} must have matching dtypes" - ) - parts.append((relpath, source_paths)) - return parts - - def _default_robotics_arrays(row: Row) -> dict[str, str]: if not isinstance(row, RoboticsRow): raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") @@ -833,60 +576,6 @@ def _zarr_store(output: DataFolder, path: str = "", *, mode: str = "r"): ) -def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: - for name, item in group.items(): - path = f"{prefix}/{name}" if prefix else name - if hasattr(item, "shape"): - yield path - else: - yield from _iter_array_paths(item, path) - - -def _remove_parts(output: DataFolder) -> None: - try: - output.rm("_parts", recursive=True) - except FileNotFoundError: - pass - - -def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: - import zarr - - payload_paths: set[str] | None = None - schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} - for relpath in relpaths: - if not output.exists(relpath): - raise ValueError(f"Zarr store is missing: {relpath}") - source = zarr.open_group( - store=_zarr_store(output, relpath, mode="r"), - mode="r", - ) - source_paths = set(_iter_array_paths(source)) - if not source_paths: - continue - if payload_paths is None: - payload_paths = source_paths - elif source_paths != payload_paths: - raise ValueError("Zarr stores must contain the same arrays") - for path in source_paths: - source_array = source[path] - schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) - previous = schemas.setdefault(path, schema) - if previous != schema: - if previous[0] != schema[0]: - raise ValueError( - f"Zarr arrays for {path!r} must have matching trailing shapes" - ) - raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") - - -def _clear_final_group(group: Any) -> None: - for key in sorted({*group.array_keys(), *group.group_keys()}): - if key != "_parts": - del group[key] - group.attrs.clear() - - def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: chunk_rows = min( _batch_length(array, target_bytes), diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index ecbe687e..c6283667 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -16,11 +16,11 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor -from refiner.pipeline.sinks.zarr import ( - _ZarrCleanupReducerSink, - _ZarrMergeReducerSink, - ZarrSink, +from refiner.pipeline.sinks.reducer.zarr import ( + ZarrCleanupReducerSink, + ZarrMergeReducerSink, ) +from refiner.pipeline.sinks.zarr import ZarrSink from refiner.worker.context import set_active_run_context, worker_token_for from refiner.worker.lifecycle import FinalizedShardWorker, RuntimeLifecycle @@ -1113,7 +1113,7 @@ def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="Zarr store is missing"): - _ZarrCleanupReducerSink( + ZarrCleanupReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", ).write_block([DictRow({}, shard_id="reduce")]) @@ -1164,7 +1164,7 @@ def test_write_zarr_non_reduced_cleanup_keeps_empty_stores_retryable( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - _ZarrCleanupReducerSink( + ZarrCleanupReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", ).write_block([DictRow({}, shard_id="reduce")]) @@ -1374,7 +1374,7 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="same payload arrays"): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1411,7 +1411,7 @@ def test_write_zarr_single_store_rejects_part_missing_episode_ends( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="meta/episode_ends"): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1448,7 +1448,7 @@ def test_write_zarr_single_store_rejects_missing_finalized_part( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="part store is missing"): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1494,7 +1494,7 @@ def test_write_zarr_single_store_removes_parts_only_on_completion( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - reducer = _ZarrMergeReducerSink( + reducer = ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1534,7 +1534,7 @@ def test_write_zarr_single_store_zero_shard_replace_clears_existing_output( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1573,7 +1573,7 @@ def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", @@ -1663,7 +1663,7 @@ def test_write_zarr_single_store_rejects_part_dtype_drift( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="matching dtypes"): - _ZarrMergeReducerSink( + ZarrMergeReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", From d80c8bfa12470fc0396127e4411492111d69e4fa Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:52:20 +0200 Subject: [PATCH 31/39] Unify Zarr reducer sink --- .../pipeline/sinks/reducer/__init__.py | 8 +-- src/refiner/pipeline/sinks/reducer/zarr.py | 70 ++++++++----------- src/refiner/pipeline/sinks/zarr.py | 17 ++--- tests/readers/test_zarr_reader.py | 30 ++++---- 4 files changed, 53 insertions(+), 72 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/__init__.py b/src/refiner/pipeline/sinks/reducer/__init__.py index c9c65905..9363a80f 100644 --- a/src/refiner/pipeline/sinks/reducer/__init__.py +++ b/src/refiner/pipeline/sinks/reducer/__init__.py @@ -1,13 +1,9 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.pipeline.sinks.reducer.lerobot import LeRobotMetaReduceSink -from refiner.pipeline.sinks.reducer.zarr import ( - ZarrCleanupReducerSink, - ZarrMergeReducerSink, -) +from refiner.pipeline.sinks.reducer.zarr import ZarrReducerSink __all__ = [ "FileCleanupReducerSink", "LeRobotMetaReduceSink", - "ZarrCleanupReducerSink", - "ZarrMergeReducerSink", + "ZarrReducerSink", ] diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index 2a20862c..b1caca89 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -7,7 +7,6 @@ from refiner.io.datafolder import DataFolder, DataFolderLike from refiner.pipeline.data.block import Block -from refiner.pipeline.sinks.base import BaseSink from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.pipeline.sinks.zarr import ( _append_zarr_array, @@ -20,22 +19,36 @@ from refiner.worker.lifecycle import sort_finalized_workers -class ZarrCleanupReducerSink(BaseSink): - def __init__(self, output: DataFolderLike, *, store_template: str) -> None: - self.output = DataFolder.resolve(output) - self.store_template = store_template - self._cleanup = FileCleanupReducerSink( - output=self.output, - filename_template=store_template, +class ZarrReducerSink(FileCleanupReducerSink): + def __init__( + self, + output: DataFolderLike, + *, + store_template: str, + episode_ends_path: str | None = None, + array_chunk_bytes: int = 8 * 1024 * 1024, + reduce_to_single_store: bool = False, + ) -> None: + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + super().__init__( + output=output, + filename_template=( + f"_parts/{store_template}" if reduce_to_single_store else store_template + ), reducer_name="write_zarr_reduce", ) - - @property - def counts_output_rows(self) -> bool: - return False + self.store_template = store_template + self.episode_ends_path = episode_ends_path + self.array_chunk_bytes = array_chunk_bytes + self.reduce_to_single_store = reduce_to_single_store + self._merged = False def write_shard_block(self, shard_id: str, block: Block) -> None: - self._cleanup.write_shard_block(shard_id, block) + super().write_shard_block(shard_id, block) + if self.reduce_to_single_store: + self._merge() + return + stage_index = get_active_stage_index() if stage_index is None or stage_index <= 0: raise ValueError( @@ -81,31 +94,6 @@ def clear_group(group: Any, prefix: str = "") -> None: clear_group(root) - -class ZarrMergeReducerSink(BaseSink): - def __init__( - self, - output: DataFolderLike, - *, - store_template: str, - episode_ends_path: str | None, - array_chunk_bytes: int, - ) -> None: - check_required_dependencies("write_zarr", ["zarr"], dist="zarr") - self.output = DataFolder.resolve(output) - self.store_template = store_template - self.episode_ends_path = episode_ends_path - self.array_chunk_bytes = array_chunk_bytes - self._merged = False - - @property - def counts_output_rows(self) -> bool: - return False - - def write_shard_block(self, shard_id: str, block: Block) -> None: - del shard_id, block - self._merge() - def describe(self) -> tuple[str, str, dict[str, object]]: return ( "write_zarr_reduce", @@ -114,7 +102,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: "path": self.output.abs_path(), "store_template": self.store_template, "array_chunk_bytes": self.array_chunk_bytes, - "reduce_to_single_store": True, + "reduce_to_single_store": self.reduce_to_single_store, }, ) @@ -216,7 +204,7 @@ def _merge(self) -> None: def on_shard_finalized(self, shard_id: str) -> None: del shard_id - if not self._merged: + if not self.reduce_to_single_store or not self._merged: return _remove_parts(self.output) try: @@ -330,4 +318,4 @@ def _clear_final_group(group: Any) -> None: group.attrs.clear() -__all__ = ["ZarrCleanupReducerSink", "ZarrMergeReducerSink"] +__all__ = ["ZarrReducerSink"] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 34c8fa62..9601f319 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -441,21 +441,14 @@ def describe(self) -> tuple[str, str, dict[str, object]]: ) def build_reducer(self) -> BaseSink | None: - from refiner.pipeline.sinks.reducer.zarr import ( - ZarrCleanupReducerSink, - ZarrMergeReducerSink, - ) + from refiner.pipeline.sinks.reducer.zarr import ZarrReducerSink - if self.reduce_to_single_store: - return ZarrMergeReducerSink( - output=self.output, - store_template=self.store_template, - episode_ends_path=self.episode_ends_path, - array_chunk_bytes=self.array_chunk_bytes, - ) - return ZarrCleanupReducerSink( + return ZarrReducerSink( output=self.output, store_template=self.store_template, + episode_ends_path=self.episode_ends_path, + array_chunk_bytes=self.array_chunk_bytes, + reduce_to_single_store=self.reduce_to_single_store, ) diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index c6283667..937d0d02 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -16,10 +16,7 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor -from refiner.pipeline.sinks.reducer.zarr import ( - ZarrCleanupReducerSink, - ZarrMergeReducerSink, -) +from refiner.pipeline.sinks.reducer.zarr import ZarrReducerSink from refiner.pipeline.sinks.zarr import ZarrSink from refiner.worker.context import set_active_run_context, worker_token_for from refiner.worker.lifecycle import FinalizedShardWorker, RuntimeLifecycle @@ -1113,7 +1110,7 @@ def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="Zarr store is missing"): - ZarrCleanupReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", ).write_block([DictRow({}, shard_id="reduce")]) @@ -1164,7 +1161,7 @@ def test_write_zarr_non_reduced_cleanup_keeps_empty_stores_retryable( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - ZarrCleanupReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", ).write_block([DictRow({}, shard_id="reduce")]) @@ -1374,11 +1371,12 @@ def test_write_zarr_single_store_rejects_inconsistent_part_payloads( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="same payload arrays"): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() assert second_part.exists() @@ -1411,11 +1409,12 @@ def test_write_zarr_single_store_rejects_part_missing_episode_ends( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="meta/episode_ends"): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) @@ -1448,11 +1447,12 @@ def test_write_zarr_single_store_rejects_missing_finalized_part( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="part store is missing"): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) row = mdr.read_zarr( @@ -1494,11 +1494,12 @@ def test_write_zarr_single_store_removes_parts_only_on_completion( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - reducer = ZarrMergeReducerSink( + reducer = ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ) reducer.write_block([DictRow({}, shard_id="reduce")]) assert part.exists() @@ -1534,11 +1535,12 @@ def test_write_zarr_single_store_zero_shard_replace_clears_existing_output( worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) root = _open_test_zarr(zarr_out, mode="r") @@ -1573,11 +1575,12 @@ def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None worker_name=None, runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) row = mdr.read_zarr( @@ -1663,11 +1666,12 @@ def test_write_zarr_single_store_rejects_part_dtype_drift( runtime_lifecycle=cast(RuntimeLifecycle, runtime), ): with pytest.raises(ValueError, match="matching dtypes"): - ZarrMergeReducerSink( + ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", episode_ends_path="meta/episode_ends", array_chunk_bytes=1024, + reduce_to_single_store=True, ).write_block([DictRow({}, shard_id="reduce")]) assert first_part.exists() assert second_part.exists() From a84e93b6100bfe7673e39b4267f3b79ca71b9f10 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:56:10 +0200 Subject: [PATCH 32/39] Simplify reducer cleanup code --- src/refiner/pipeline/sinks/reducer/file.py | 28 ++-- src/refiner/pipeline/sinks/reducer/zarr.py | 184 ++++++++------------- 2 files changed, 76 insertions(+), 136 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 9f5b732e..f32b079f 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -112,10 +112,7 @@ def _run_cleanup(self) -> None: ) keep_pairs = { - ( - row.shard_id, - row.worker_token, - ) + (row.shard_id, row.worker_token) for row in get_finalized_workers(stage_index=stage_index - 1) } @@ -138,7 +135,7 @@ def _run_cleanup(self) -> None: next_paths.extend( item for item in self.output.ls(path, detail=False) - if isinstance(item, str) and pattern.fullmatch(item) + if pattern.fullmatch(item) ) except (FileNotFoundError, NotADirectoryError): continue @@ -148,16 +145,11 @@ def _run_cleanup(self) -> None: # Extra template fields are structure only. Authority is decided from # the finalized (shard_id, worker_id) pair extracted from the path. for rel_path in paths: - if not isinstance(rel_path, str) or not rel_path or rel_path == ".": - continue - if rel_path.rstrip("/").endswith("/."): - continue match = self._output_path_patterns[-1].fullmatch(rel_path) if match is None: continue - if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: - continue - paths_to_delete.add(rel_path) + if (match.group("shard_id"), match.group("worker_id")) not in keep_pairs: + paths_to_delete.add(rel_path) if self.assets_subdir is not None: asset_prefix = f"{self.assets_subdir.rstrip('/')}/" @@ -166,17 +158,17 @@ def _run_cleanup(self) -> None: except FileNotFoundError: asset_paths = [] for rel_path in asset_paths: - if not isinstance(rel_path, str) or not rel_path.startswith( - asset_prefix - ): + if not rel_path.startswith(asset_prefix): continue attempt_dir = rel_path[len(asset_prefix) :].split("/", maxsplit=1)[0] match = ASSET_ATTEMPT_DIR_RE.fullmatch(attempt_dir) if match is None: continue - if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: - continue - paths_to_delete.add(f"{asset_prefix}{attempt_dir}") + if ( + match.group("shard_id"), + match.group("worker_id"), + ) not in keep_pairs: + paths_to_delete.add(f"{asset_prefix}{attempt_dir}") for path in sorted(paths_to_delete): try: diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index b1caca89..f60e3ddc 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -49,22 +49,8 @@ def write_shard_block(self, shard_id: str, block: Block) -> None: self._merge() return - stage_index = get_active_stage_index() - if stage_index is None or stage_index <= 0: - raise ValueError( - "write_zarr_reduce requires an active reducer stage with a prior writer stage" - ) - relpaths = [ - _render_store_relpath( - self.store_template, - shard_id=row.shard_id, - worker_id=row.worker_token, - ) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1) - ) - ] - _validate_zarr_stores(self.output, relpaths) + relpaths = self._finalized_store_paths() + self._collect_stores(relpaths, for_merge=False) _remove_parts(self.output) self._clear_root_payload_except(relpaths) @@ -110,26 +96,10 @@ def _merge(self) -> None: if self._merged: return - stage_index = get_active_stage_index() - if stage_index is None or stage_index <= 0: - raise ValueError( - "write_zarr_reduce requires an active reducer stage with a prior writer stage" - ) + expected_parts = self._finalized_store_paths(prefix="_parts/") + import zarr - expected_parts = [ - "_parts/" - + _render_store_relpath( - self.store_template, - shard_id=row.shard_id, - worker_id=row.worker_token, - ) - for row in sort_finalized_workers( - get_finalized_workers(stage_index=stage_index - 1), - ) - ] if not expected_parts: - import zarr - final = zarr.open_group( store=_zarr_store(self.output, "", mode="a"), mode="a", @@ -138,10 +108,7 @@ def _merge(self) -> None: self._merged = True return - parts = self._collect_parts(expected_parts) - - import zarr - + parts = self._collect_stores(expected_parts, for_merge=True) final = zarr.open_group( store=_zarr_store(self.output, "", mode="a"), mode="a", @@ -157,49 +124,37 @@ def _merge(self) -> None: ) for path in sorted(paths): source_array = source[path] - if path == self.episode_ends_path: - if source_array.shape[0] == 0: - continue - part_last = row_offset - batch_size = _batch_length( - source_array, - self.array_chunk_bytes, - ) - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) - values = np.asarray(source_array[start:end], dtype=np.int64) - _append_zarr_array( - final, - arrays, - path, - values + row_offset, - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), - ) - part_last = int(values[-1]) - row_offset += part_last + chunks = getattr(source_array, "chunks", None) + compressor = getattr(source_array, "compressor", None) + if source_array.shape[0] == 0 and path == self.episode_ends_path: continue - batch_size = _batch_length(source_array, self.array_chunk_bytes) if source_array.shape[0] == 0: _append_zarr_array( final, arrays, path, np.asarray(source_array[:0]), - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), + chunks=chunks, + compressor=compressor, ) continue - for start in range(0, int(source_array.shape[0]), batch_size): - end = min(int(source_array.shape[0]), start + batch_size) + + part_end = 0 + for values in _array_batches(source_array, self.array_chunk_bytes): + if path == self.episode_ends_path: + values = np.asarray(values, dtype=np.int64) + part_end = int(values[-1]) + values = values + row_offset _append_zarr_array( final, arrays, path, - np.asarray(source_array[start:end]), - chunks=getattr(source_array, "chunks", None), - compressor=getattr(source_array, "compressor", None), + values, + chunks=chunks, + compressor=compressor, ) + if path == self.episode_ends_path: + row_offset += part_end self._merged = True def on_shard_finalized(self, shard_id: str) -> None: @@ -207,23 +162,40 @@ def on_shard_finalized(self, shard_id: str) -> None: if not self.reduce_to_single_store or not self._merged: return _remove_parts(self.output) - try: - if not self.output.ls("_parts"): - self.output.rmdir("_parts") - except (FileNotFoundError, OSError, ValueError): - pass - def _collect_parts( - self, expected_parts: Iterable[str] + def _finalized_store_paths(self, prefix: str = "") -> list[str]: + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + return [ + prefix + + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1) + ) + ] + + def _collect_stores( + self, + relpaths: Iterable[str], + *, + for_merge: bool, ) -> list[tuple[str, set[str]]]: import zarr - parts: list[tuple[str, set[str]]] = [] + stores: list[tuple[str, set[str]]] = [] payload_paths: set[str] | None = None schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} - for relpath in expected_parts: + for relpath in relpaths: if not self.output.exists(relpath): - raise ValueError(f"Zarr part store is missing: {relpath}") + kind = "part store" if for_merge else "store" + raise ValueError(f"Zarr {kind} is missing: {relpath}") source = zarr.open_group( store=_zarr_store(self.output, relpath, mode="r"), mode="r", @@ -231,23 +203,24 @@ def _collect_parts( source_paths = set(_iter_array_paths(source)) if not source_paths: continue - source_payload_paths = { - path for path in source_paths if path != self.episode_ends_path - } if ( - self.episode_ends_path is not None - and source_payload_paths + for_merge + and self.episode_ends_path is not None + and (source_paths - {self.episode_ends_path}) and self.episode_ends_path not in source_paths ): raise ValueError( f"Zarr part stores must contain {self.episode_ends_path!r}" ) + source_payload_paths = source_paths + if for_merge and self.episode_ends_path is not None: + source_payload_paths = source_paths - {self.episode_ends_path} if payload_paths is None: payload_paths = source_payload_paths elif source_payload_paths != payload_paths: - raise ValueError( - "Zarr part stores must contain the same payload arrays" - ) + kind = "part stores" if for_merge else "stores" + payload = " payload" if for_merge else "" + raise ValueError(f"Zarr {kind} must contain the same{payload} arrays") for path in source_paths: source_array = source[path] schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) @@ -260,8 +233,8 @@ def _collect_parts( raise ValueError( f"Zarr arrays for {path!r} must have matching dtypes" ) - parts.append((relpath, source_paths)) - return parts + stores.append((relpath, source_paths)) + return stores def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: @@ -273,6 +246,12 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: yield from _iter_array_paths(item, path) +def _array_batches(array: Any, max_bytes: int) -> Iterable[np.ndarray]: + batch_size = _batch_length(array, max_bytes) + for start in range(0, int(array.shape[0]), batch_size): + yield np.asarray(array[start : min(int(array.shape[0]), start + batch_size)]) + + def _remove_parts(output: DataFolder) -> None: try: output.rm("_parts", recursive=True) @@ -280,37 +259,6 @@ def _remove_parts(output: DataFolder) -> None: pass -def _validate_zarr_stores(output: DataFolder, relpaths: Iterable[str]) -> None: - import zarr - - payload_paths: set[str] | None = None - schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} - for relpath in relpaths: - if not output.exists(relpath): - raise ValueError(f"Zarr store is missing: {relpath}") - source = zarr.open_group( - store=_zarr_store(output, relpath, mode="r"), - mode="r", - ) - source_paths = set(_iter_array_paths(source)) - if not source_paths: - continue - if payload_paths is None: - payload_paths = source_paths - elif source_paths != payload_paths: - raise ValueError("Zarr stores must contain the same arrays") - for path in source_paths: - source_array = source[path] - schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) - previous = schemas.setdefault(path, schema) - if previous != schema: - if previous[0] != schema[0]: - raise ValueError( - f"Zarr arrays for {path!r} must have matching trailing shapes" - ) - raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") - - def _clear_final_group(group: Any) -> None: for key in sorted({*group.array_keys(), *group.group_keys()}): if key != "_parts": From 47f121ea8e9434971fa3b98f5529e07a9356afb0 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 00:59:53 +0200 Subject: [PATCH 33/39] Tighten Zarr reducer implementation --- src/refiner/pipeline/sinks/reducer/zarr.py | 83 +++++++++------------- 1 file changed, 34 insertions(+), 49 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index f60e3ddc..74ff6473 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -5,10 +5,11 @@ import numpy as np -from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.io.datafolder import DataFolderLike from refiner.pipeline.data.block import Block from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.pipeline.sinks.zarr import ( + _DEFAULT_ARRAY_CHUNK_BYTES, _append_zarr_array, _batch_length, _render_store_relpath, @@ -26,7 +27,7 @@ def __init__( *, store_template: str, episode_ends_path: str | None = None, - array_chunk_bytes: int = 8 * 1024 * 1024, + array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, reduce_to_single_store: bool = False, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") @@ -51,7 +52,10 @@ def write_shard_block(self, shard_id: str, block: Block) -> None: relpaths = self._finalized_store_paths() self._collect_stores(relpaths, for_merge=False) - _remove_parts(self.output) + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass self._clear_root_payload_except(relpaths) def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: @@ -71,10 +75,11 @@ def clear_group(group: Any, prefix: str = "") -> None: continue if path in keep_paths: continue - if any(keep_path.startswith(f"{path}/") for keep_path in keep_paths): - if key in group_keys: - clear_group(group[key], path) - continue + if key in group_keys and any( + keep_path.startswith(f"{path}/") for keep_path in keep_paths + ): + clear_group(group[key], path) + continue del group[key] group.attrs.clear() @@ -99,21 +104,15 @@ def _merge(self) -> None: expected_parts = self._finalized_store_paths(prefix="_parts/") import zarr - if not expected_parts: - final = zarr.open_group( - store=_zarr_store(self.output, "", mode="a"), - mode="a", - ) - _clear_final_group(final) - self._merged = True - return - parts = self._collect_stores(expected_parts, for_merge=True) final = zarr.open_group( store=_zarr_store(self.output, "", mode="a"), mode="a", ) - _clear_final_group(final) + for key in sorted({*final.array_keys(), *final.group_keys()}): + if key != "_parts": + del final[key] + final.attrs.clear() row_offset = 0 arrays: dict[str, Any] = {} @@ -140,7 +139,10 @@ def _merge(self) -> None: continue part_end = 0 - for values in _array_batches(source_array, self.array_chunk_bytes): + batch_size = _batch_length(source_array, self.array_chunk_bytes) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end]) if path == self.episode_ends_path: values = np.asarray(values, dtype=np.int64) part_end = int(values[-1]) @@ -161,7 +163,10 @@ def on_shard_finalized(self, shard_id: str) -> None: del shard_id if not self.reduce_to_single_store or not self._merged: return - _remove_parts(self.output) + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass def _finalized_store_paths(self, prefix: str = "") -> list[str]: stage_index = get_active_stage_index() @@ -192,6 +197,7 @@ def _collect_stores( stores: list[tuple[str, set[str]]] = [] payload_paths: set[str] | None = None schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + episode_path = self.episode_ends_path if for_merge else None for relpath in relpaths: if not self.output.exists(relpath): kind = "part store" if for_merge else "store" @@ -203,18 +209,17 @@ def _collect_stores( source_paths = set(_iter_array_paths(source)) if not source_paths: continue + source_payload_paths = ( + source_paths - {episode_path} + if episode_path is not None + else source_paths + ) if ( - for_merge - and self.episode_ends_path is not None - and (source_paths - {self.episode_ends_path}) - and self.episode_ends_path not in source_paths + episode_path is not None + and source_payload_paths + and episode_path not in source_paths ): - raise ValueError( - f"Zarr part stores must contain {self.episode_ends_path!r}" - ) - source_payload_paths = source_paths - if for_merge and self.episode_ends_path is not None: - source_payload_paths = source_paths - {self.episode_ends_path} + raise ValueError(f"Zarr part stores must contain {episode_path!r}") if payload_paths is None: payload_paths = source_payload_paths elif source_payload_paths != payload_paths: @@ -246,24 +251,4 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: yield from _iter_array_paths(item, path) -def _array_batches(array: Any, max_bytes: int) -> Iterable[np.ndarray]: - batch_size = _batch_length(array, max_bytes) - for start in range(0, int(array.shape[0]), batch_size): - yield np.asarray(array[start : min(int(array.shape[0]), start + batch_size)]) - - -def _remove_parts(output: DataFolder) -> None: - try: - output.rm("_parts", recursive=True) - except FileNotFoundError: - pass - - -def _clear_final_group(group: Any) -> None: - for key in sorted({*group.array_keys(), *group.group_keys()}): - if key != "_parts": - del group[key] - group.attrs.clear() - - __all__ = ["ZarrReducerSink"] From 538800d4a2ae7b9ed4d7c616c37c6982a53a15f4 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 01:04:06 +0200 Subject: [PATCH 34/39] Simplify Zarr writer batching --- src/refiner/pipeline/sinks/zarr.py | 162 +++++++++++------------------ 1 file changed, 62 insertions(+), 100 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 9601f319..ec384ddd 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -73,13 +73,13 @@ def flush_pending() -> None: if pending_store is None or not pending_arrays: return store = pending_store - rollback_lengths: dict[str, int | None] = {} previous_row_end = store.row_end - for zarr_path in pending_arrays: - dataset = store.arrays.get(zarr_path) - rollback_lengths[zarr_path] = ( - None if dataset is None else int(dataset.shape[0]) - ) + rollback_lengths = { + zarr_path: None + if (dataset := store.arrays.get(zarr_path)) is None + else int(dataset.shape[0]) + for zarr_path in pending_arrays + } if self.episode_ends_path is not None: dataset = store.arrays.get(self.episode_ends_path) rollback_lengths[self.episode_ends_path] = ( @@ -92,8 +92,6 @@ def flush_pending() -> None: ) for zarr_path, arrays in pending_arrays.items() } - for zarr_path, array in combined.items(): - self._validate_array_append(store, zarr_path, array) for zarr_path, array in combined.items(): self._append_array(store, zarr_path, array) if self.episode_ends_path is not None: @@ -141,19 +139,15 @@ def flush_pending() -> None: "Zarr arrays for one row must have matching lengths" ) row_bytes = sum(array.nbytes for array in row_arrays.values()) - if ( - pending_arrays - and pending_bytes + row_bytes > self.array_chunk_bytes - ): - flush_pending() - if pending_arrays and len(pending_lengths) >= _MAX_INITIAL_CHUNK_ROWS: - flush_pending() - if pending_arrays and set(row_arrays) != set(pending_arrays): - flush_pending() - if pending_arrays and any( - pending_arrays[zarr_path][0].shape[1:] != array.shape[1:] - or pending_arrays[zarr_path][0].dtype != array.dtype - for zarr_path, array in row_arrays.items() + if pending_arrays and ( + pending_bytes + row_bytes > self.array_chunk_bytes + or len(pending_lengths) >= _MAX_INITIAL_CHUNK_ROWS + or set(row_arrays) != set(pending_arrays) + or any( + pending_arrays[zarr_path][0].shape[1:] != array.shape[1:] + or pending_arrays[zarr_path][0].dtype != array.dtype + for zarr_path, array in row_arrays.items() + ) ): flush_pending() store = self._store(shard_id) @@ -175,22 +169,20 @@ def _write_row_values( lengths: list[int], ) -> None: store = self._store(shard_id) - if not lengths: - expected_length = None - else: - expected_length = lengths[0] - if any(item != expected_length for item in lengths): - raise ValueError("Zarr arrays for one row must have matching lengths") + expected_length = lengths[0] if lengths else None + if expected_length is not None and any( + item != expected_length for item in lengths + ): + raise ValueError("Zarr arrays for one row must have matching lengths") - rollback_lengths: dict[str, int | None] = {} - for zarr_path in [*row_arrays, *(path for path, _ in row_videos)]: - dataset = store.arrays.get(zarr_path) - rollback_lengths[zarr_path] = ( - None if dataset is None else int(dataset.shape[0]) - ) + previous_row_end = store.row_end + rollback_lengths = { + zarr_path: None + if (dataset := store.arrays.get(zarr_path)) is None + else int(dataset.shape[0]) + for zarr_path in [*row_arrays, *(path for path, _ in row_videos)] + } try: - for zarr_path, array in row_arrays.items(): - self._validate_array_append(store, zarr_path, array) for zarr_path, array in row_arrays.items(): self._append_array(store, zarr_path, array) for zarr_path, video in row_videos: @@ -214,12 +206,13 @@ def _write_row_values( rollback_lengths[self.episode_ends_path] = ( None if dataset is None else int(dataset.shape[0]) ) - store.row_end += lengths[0] + row_end = store.row_end + lengths[0] self._append_array( store, self.episode_ends_path, - np.asarray([store.row_end], dtype=np.int64), + np.asarray([row_end], dtype=np.int64), ) + store.row_end = row_end except Exception: for zarr_path, length in rollback_lengths.items(): if length is None: @@ -228,11 +221,7 @@ def _write_row_values( dataset = store.arrays.get(zarr_path) if dataset is not None: dataset.resize((length, *dataset.shape[1:])) - if self.episode_ends_path is not None: - dataset = store.arrays.get(self.episode_ends_path) - store.row_end = ( - 0 if dataset is None or dataset.shape[0] == 0 else int(dataset[-1]) - ) + store.row_end = previous_row_end raise def _row_values( @@ -265,57 +254,20 @@ async def _append_video( *, expected_length: int | None = None, ) -> int: - if isinstance(video, VideoFrameArray): - if expected_length is not None and video.frame_count != expected_length: - raise ValueError("Zarr arrays for one row must have matching lengths") - if video.frame_count == 0: - empty = np.asarray(video.frames, dtype=np.uint8) - self._append_array(store, path, empty[:0]) - return 0 - batch: list[np.ndarray] = [] - batch_limit: int | None = None - for frame in video.iter_frame_arrays(): - batch.append(frame) - if batch_limit is None: - batch_limit = self._video_batch_limit(frame) - if len(batch) >= batch_limit: - self._append_array( - store, - path, - np.stack(batch, axis=0), - chunks=(batch_limit, *frame.shape), - ) - batch.clear() - if batch: - self._append_array( - store, - path, - np.stack(batch, axis=0), - chunks=(batch_limit or len(batch), *batch[0].shape), - ) - return video.frame_count - batch: list[np.ndarray] = [] batch_limit: int | None = None count = 0 - async for frame in video.iter_frames(): - batch.append(frame.frame.to_ndarray(format="rgb24")) + + def append_frame(frame: np.ndarray) -> None: + nonlocal batch_limit + batch.append(frame) if batch_limit is None: - batch_limit = self._video_batch_limit(batch[0]) - if len(batch) >= batch_limit: - if expected_length is not None and count + len(batch) > expected_length: - raise ValueError( - "Zarr arrays for one row must have matching lengths" - ) - self._append_array( - store, - path, - np.stack(batch, axis=0), - chunks=(batch_limit, *batch[0].shape), - ) - count += len(batch) - batch.clear() - if batch: + batch_limit = self._video_batch_limit(frame) + + def flush_batch() -> None: + nonlocal count + if not batch: + return if expected_length is not None and count + len(batch) > expected_length: raise ValueError("Zarr arrays for one row must have matching lengths") self._append_array( @@ -325,6 +277,27 @@ async def _append_video( chunks=(batch_limit or len(batch), *batch[0].shape), ) count += len(batch) + batch.clear() + + if isinstance(video, VideoFrameArray): + if expected_length is not None and video.frame_count != expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") + if video.frame_count == 0: + empty = np.asarray(video.frames, dtype=np.uint8) + self._append_array(store, path, empty[:0]) + return 0 + for frame in video.iter_frame_arrays(): + append_frame(frame) + if batch_limit is not None and len(batch) >= batch_limit: + flush_batch() + flush_batch() + return count + + async for frame in video.iter_frames(): + append_frame(frame.frame.to_ndarray(format="rgb24")) + if batch_limit is not None and len(batch) >= batch_limit: + flush_batch() + flush_batch() if count == 0: raise ValueError("Zarr video source produced no frames") if expected_length is not None and count != expected_length: @@ -404,17 +377,6 @@ def _append_array( chunks=chunks or _chunk_shape(array, self.array_chunk_bytes), ) - def _validate_array_append( - self, - store: _ZarrWriteState, - path: str, - array: np.ndarray, - ) -> None: - dataset = store.arrays.get(path) - if dataset is None: - return - _validate_array_schema(path, dataset, array) - def _drop_array(self, store: _ZarrWriteState, path: str) -> None: store.arrays.pop(path, None) try: From 2e797106b98ef9e3cec7db3c8361d19f30c2e3f1 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 01:17:18 +0200 Subject: [PATCH 35/39] Consolidate Zarr writer validation --- src/refiner/pipeline/sinks/zarr.py | 139 ++++++++++++++--------------- 1 file changed, 67 insertions(+), 72 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index ec384ddd..11ed42e5 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from string import Formatter -from typing import Any, cast +from typing import Any import numpy as np import pyarrow as pa @@ -74,16 +74,10 @@ def flush_pending() -> None: return store = pending_store previous_row_end = store.row_end - rollback_lengths = { - zarr_path: None - if (dataset := store.arrays.get(zarr_path)) is None - else int(dataset.shape[0]) - for zarr_path in pending_arrays - } + rollback_lengths = self._snapshot_array_lengths(store, pending_arrays) if self.episode_ends_path is not None: - dataset = store.arrays.get(self.episode_ends_path) - rollback_lengths[self.episode_ends_path] = ( - None if dataset is None else int(dataset.shape[0]) + rollback_lengths.update( + self._snapshot_array_lengths(store, [self.episode_ends_path]) ) try: combined = { @@ -102,13 +96,7 @@ def flush_pending() -> None: self._append_array(store, self.episode_ends_path, row_ends) store.row_end = int(row_ends[-1]) except Exception: - for zarr_path, length in rollback_lengths.items(): - if length is None: - self._drop_array(store, zarr_path) - continue - dataset = store.arrays.get(zarr_path) - if dataset is not None: - dataset.resize((length, *dataset.shape[1:])) + self._restore_array_lengths(store, rollback_lengths) store.row_end = previous_row_end raise finally: @@ -131,13 +119,12 @@ def flush_pending() -> None: count += 1 continue - if lengths: - length = lengths[0] - if any(item != length for item in lengths): - flush_pending() - raise ValueError( - "Zarr arrays for one row must have matching lengths" - ) + try: + length = _matching_length(lengths) + except Exception: + flush_pending() + raise + if length is not None: row_bytes = sum(array.nbytes for array in row_arrays.values()) if pending_arrays and ( pending_bytes + row_bytes > self.array_chunk_bytes @@ -169,19 +156,13 @@ def _write_row_values( lengths: list[int], ) -> None: store = self._store(shard_id) - expected_length = lengths[0] if lengths else None - if expected_length is not None and any( - item != expected_length for item in lengths - ): - raise ValueError("Zarr arrays for one row must have matching lengths") + expected_length = _matching_length(lengths) previous_row_end = store.row_end - rollback_lengths = { - zarr_path: None - if (dataset := store.arrays.get(zarr_path)) is None - else int(dataset.shape[0]) - for zarr_path in [*row_arrays, *(path for path, _ in row_videos)] - } + rollback_lengths = self._snapshot_array_lengths( + store, + [*row_arrays, *(path for path, _ in row_videos)], + ) try: for zarr_path, array in row_arrays.items(): self._append_array(store, zarr_path, array) @@ -195,18 +176,12 @@ def _write_row_values( ) ).result() lengths.append(video_length) - if lengths: - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError( - "Zarr arrays for one row must have matching lengths" - ) - if lengths and self.episode_ends_path is not None: - dataset = store.arrays.get(self.episode_ends_path) - rollback_lengths[self.episode_ends_path] = ( - None if dataset is None else int(dataset.shape[0]) + length = _matching_length(lengths) + if length is not None and self.episode_ends_path is not None: + rollback_lengths.update( + self._snapshot_array_lengths(store, [self.episode_ends_path]) ) - row_end = store.row_end + lengths[0] + row_end = store.row_end + length self._append_array( store, self.episode_ends_path, @@ -214,13 +189,7 @@ def _write_row_values( ) store.row_end = row_end except Exception: - for zarr_path, length in rollback_lengths.items(): - if length is None: - self._drop_array(store, zarr_path) - continue - dataset = store.arrays.get(zarr_path) - if dataset is not None: - dataset.resize((length, *dataset.shape[1:])) + self._restore_array_lengths(store, rollback_lengths) store.row_end = previous_row_end raise @@ -262,7 +231,12 @@ def append_frame(frame: np.ndarray) -> None: nonlocal batch_limit batch.append(frame) if batch_limit is None: - batch_limit = self._video_batch_limit(frame) + batch_limit = min( + self.video_frame_batch_size, + _batch_length_for_shape( + (1, *frame.shape), frame.dtype, self.array_chunk_bytes + ), + ) def flush_batch() -> None: nonlocal count @@ -304,14 +278,6 @@ def flush_batch() -> None: raise ValueError("Zarr arrays for one row must have matching lengths") return count - def _video_batch_limit(self, frame: np.ndarray) -> int: - return min( - self.video_frame_batch_size, - _batch_length_for_shape( - (1, *frame.shape), frame.dtype, self.array_chunk_bytes - ), - ) - def _arrays_for_row(self, row: Row) -> dict[str, str]: if self.arrays is not None: return self.arrays @@ -349,9 +315,7 @@ def _store_relpath(self, shard_id: str) -> str: shard_id=shard_id, worker_id=get_active_worker_token(), ) - if self.reduce_to_single_store: - return f"_parts/{relpath}" - return relpath + return f"_parts/{relpath}" if self.reduce_to_single_store else relpath def on_shard_complete(self, shard_id: str) -> None: relpath = self._store_relpath(shard_id) @@ -377,12 +341,34 @@ def _append_array( chunks=chunks or _chunk_shape(array, self.array_chunk_bytes), ) - def _drop_array(self, store: _ZarrWriteState, path: str) -> None: - store.arrays.pop(path, None) - try: - del store.root[path] - except (KeyError, FileNotFoundError): - pass + def _snapshot_array_lengths( + self, + store: _ZarrWriteState, + paths: Iterable[str], + ) -> dict[str, int | None]: + return { + path: None + if (dataset := store.arrays.get(path)) is None + else int(dataset.shape[0]) + for path in paths + } + + def _restore_array_lengths( + self, + store: _ZarrWriteState, + lengths: Mapping[str, int | None], + ) -> None: + for path, length in lengths.items(): + if length is None: + store.arrays.pop(path, None) + try: + del store.root[path] + except (KeyError, FileNotFoundError): + pass + continue + dataset = store.arrays.get(path) + if dataset is not None: + dataset.resize((length, *dataset.shape[1:])) def close(self) -> None: self._stores.clear() @@ -516,10 +502,19 @@ def _as_array(value: Any) -> np.ndarray: return value.to_numpy(zero_copy_only=False) return np.asarray(value.to_pylist()) if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): - return np.asarray(list(cast(Iterable[Any], value))) + return np.asarray(list(value)) return np.asarray(value) +def _matching_length(lengths: list[int]) -> int | None: + if not lengths: + return None + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + return length + + def _zarr_store(output: DataFolder, path: str = "", *, mode: str = "r"): import zarr From e15c7d38f17ee25f60966094313c209109050308 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 01:21:18 +0200 Subject: [PATCH 36/39] Align Zarr reducer default --- src/refiner/pipeline/sinks/reducer/zarr.py | 2 +- tests/readers/test_zarr_reader.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index 74ff6473..2fdfa280 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -28,7 +28,7 @@ def __init__( store_template: str, episode_ends_path: str | None = None, array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, - reduce_to_single_store: bool = False, + reduce_to_single_store: bool = True, ) -> None: check_required_dependencies("write_zarr", ["zarr"], dist="zarr") super().__init__( diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 937d0d02..2a498188 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1113,6 +1113,7 @@ def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store( ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", + reduce_to_single_store=False, ).write_block([DictRow({}, shard_id="reduce")]) @@ -1164,6 +1165,7 @@ def test_write_zarr_non_reduced_cleanup_keeps_empty_stores_retryable( ZarrReducerSink( str(zarr_out), store_template="{shard_id}__w{worker_id}.zarr", + reduce_to_single_store=False, ).write_block([DictRow({}, shard_id="reduce")]) assert empty_store.exists() From 5c0689851b7eaa94658ca09296127bbe8d0f2e2d Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 15:22:14 +0200 Subject: [PATCH 37/39] Preserve Zarr writer attrs --- src/refiner/pipeline/pipeline.py | 4 ++++ src/refiner/pipeline/sinks/reducer/zarr.py | 16 +++++++++------- src/refiner/pipeline/sinks/zarr.py | 21 ++++++++++++++++++++- tests/readers/test_zarr_reader.py | 7 +++++-- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index d6483dc4..96016a6a 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -436,6 +436,7 @@ def write_zarr( output: DataFolderLike, *, arrays: Mapping[str, str] | None = None, + attrs: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, @@ -449,6 +450,8 @@ def write_zarr( arrays: Mapping from output Zarr array path to source row key. If omitted for ``RoboticsRow`` inputs, writes the available default robotics arrays: actions, states, and timestamps. + attrs: Mapping from output Zarr root attribute name to source row key. + Attribute values must be stable across rows in each output store. episode_ends_path: Output Zarr path for cumulative row/episode end offsets. Set to None to omit episode boundaries. store_template: Per-shard store path template. Must include @@ -466,6 +469,7 @@ def write_zarr( ZarrSink( output=output, arrays=arrays, + attrs=attrs, episode_ends_path=episode_ends_path, store_template=store_template, video_frame_batch_size=video_frame_batch_size, diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index 2fdfa280..49484328 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -42,10 +42,9 @@ def __init__( self.episode_ends_path = episode_ends_path self.array_chunk_bytes = array_chunk_bytes self.reduce_to_single_store = reduce_to_single_store - self._merged = False def write_shard_block(self, shard_id: str, block: Block) -> None: - super().write_shard_block(shard_id, block) + self._run_cleanup() if self.reduce_to_single_store: self._merge() return @@ -98,9 +97,6 @@ def describe(self) -> tuple[str, str, dict[str, object]]: ) def _merge(self) -> None: - if self._merged: - return - expected_parts = self._finalized_store_paths(prefix="_parts/") import zarr @@ -116,11 +112,18 @@ def _merge(self) -> None: row_offset = 0 arrays: dict[str, Any] = {} + final_attrs: dict[str, Any] | None = None for relpath, paths in parts: source = zarr.open_group( store=_zarr_store(self.output, relpath, mode="r"), mode="r", ) + source_attrs = dict(source.attrs) + if final_attrs is None: + final_attrs = source_attrs + final.attrs.update(source_attrs) + elif source_attrs != final_attrs: + raise ValueError("Zarr part store attrs differ") for path in sorted(paths): source_array = source[path] chunks = getattr(source_array, "chunks", None) @@ -157,11 +160,10 @@ def _merge(self) -> None: ) if path == self.episode_ends_path: row_offset += part_end - self._merged = True def on_shard_finalized(self, shard_id: str) -> None: del shard_id - if not self.reduce_to_single_store or not self._merged: + if not self.reduce_to_single_store: return try: self.output.rm("_parts", recursive=True) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 11ed42e5..28fe6ee8 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -35,6 +35,7 @@ def __init__( output: DataFolderLike, *, arrays: Mapping[str, str] | None = None, + attrs: Mapping[str, str] | None = None, episode_ends_path: str | None = "meta/episode_ends", store_template: str = "{shard_id}__w{worker_id}.zarr", video_frame_batch_size: int = 8, @@ -49,6 +50,7 @@ def __init__( _validate_store_template(store_template) self.output = DataFolder.resolve(output) self.arrays = dict(arrays) if arrays is not None else None + self.attrs = dict(attrs) if attrs is not None else None self.episode_ends_path = episode_ends_path if self.arrays is not None: if not self.arrays: @@ -115,7 +117,7 @@ def flush_pending() -> None: if row_videos: flush_pending() - self._write_row_values(shard_id, row_arrays, row_videos, lengths) + self._write_row_values(shard_id, row, row_arrays, row_videos, lengths) count += 1 continue @@ -138,6 +140,7 @@ def flush_pending() -> None: ): flush_pending() store = self._store(shard_id) + self._write_attrs(store, row) if pending_store is None: pending_store = store for zarr_path, array in row_arrays.items(): @@ -148,14 +151,29 @@ def flush_pending() -> None: flush_pending() return count + def _write_attrs(self, store: _ZarrWriteState, row: Row) -> None: + if self.attrs is None: + return + for attr_name, source_key in self.attrs.items(): + value = _row_value(row, source_key) + if isinstance(value, np.generic): + value = value.item() + elif isinstance(value, np.ndarray): + value = value.tolist() + if attr_name in store.root.attrs and store.root.attrs[attr_name] != value: + raise ValueError(f"Zarr attr {attr_name!r} changed across rows") + store.root.attrs[attr_name] = value + def _write_row_values( self, shard_id: str, + row: Row, row_arrays: dict[str, np.ndarray], row_videos: list[tuple[str, VideoSource]], lengths: list[int], ) -> None: store = self._store(shard_id) + self._write_attrs(store, row) expected_length = _matching_length(lengths) previous_row_end = store.row_end @@ -380,6 +398,7 @@ def describe(self) -> tuple[str, str, dict[str, object]]: { "path": self.output.abs_path(), "arrays": dict(self.arrays) if self.arrays is not None else None, + "attrs": dict(self.attrs) if self.attrs is not None else None, "episode_ends_path": self.episode_ends_path, "store_template": self.store_template, "video_frame_batch_size": self.video_frame_batch_size, diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 2a498188..60cafcf0 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -806,8 +806,8 @@ def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: ( mdr.from_items( [ - {"action": [[0.0], [0.1]], "state": [[1.0], [1.1]]}, - {"action": [[0.2]], "state": [[1.2]]}, + {"action": [[0.0], [0.1]], "state": [[1.0], [1.1]], "task": "push"}, + {"action": [[0.2]], "state": [[1.2]], "task": "push"}, ], items_per_shard=1, ) @@ -817,6 +817,7 @@ def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: "data/action": "action", "data/state": "state", }, + attrs={"task": "task"}, reduce_to_single_store=True, ) .launch_local( @@ -831,12 +832,14 @@ def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None: "state": "data/state", "episode_ends": "meta/episode_ends", }, + attrs={"task": "task"}, file_path_column=None, ).take(1)[0] np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) np.testing.assert_allclose(row["state"], [[1.0], [1.1], [1.2]]) assert row["episode_ends"].tolist() == [2, 3] + assert row["task"] == "push" assert not (zarr_out / "_parts").exists() From f7981f2fdabf160655249e0ae13f31ec6e0c3b3b Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 15:41:44 +0200 Subject: [PATCH 38/39] Simplify Zarr video writes --- src/refiner/pipeline/sinks/zarr.py | 61 +++++++++++------------------- src/refiner/video/types.py | 14 +++++++ tests/readers/test_zarr_reader.py | 21 +++++----- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 28fe6ee8..55055ea2 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -15,7 +15,7 @@ from refiner.pipeline.sinks.base import BaseSink from refiner.robotics.row import RoboticsRow from refiner.utils import check_required_dependencies -from refiner.video import VideoFrameArray, VideoSource +from refiner.video import VideoSource from refiner.worker.context import get_active_worker_token _DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 @@ -65,18 +65,20 @@ def __init__( def write_shard_block(self, shard_id: str, block: Block) -> int: count = 0 - pending_store: _ZarrWriteState | None = None pending_arrays: dict[str, list[np.ndarray]] = {} pending_lengths: list[int] = [] pending_bytes = 0 def flush_pending() -> None: - nonlocal pending_store, pending_arrays, pending_lengths, pending_bytes - if pending_store is None or not pending_arrays: + nonlocal pending_arrays, pending_lengths, pending_bytes + if not pending_arrays: return - store = pending_store + store = self._store(shard_id) previous_row_end = store.row_end - rollback_lengths = self._snapshot_array_lengths(store, pending_arrays) + rollback_lengths = self._snapshot_array_lengths( + store, + pending_arrays.keys(), + ) if self.episode_ends_path is not None: rollback_lengths.update( self._snapshot_array_lengths(store, [self.episode_ends_path]) @@ -102,7 +104,6 @@ def flush_pending() -> None: store.row_end = previous_row_end raise finally: - pending_store = None pending_arrays = {} pending_lengths = [] pending_bytes = 0 @@ -110,7 +111,7 @@ def flush_pending() -> None: for row in block: try: arrays = self._arrays_for_row(row) - row_arrays, row_videos, lengths = self._row_values(row, arrays) + row_arrays, row_videos, lengths = self._split_row_values(row, arrays) except Exception: flush_pending() raise @@ -141,8 +142,6 @@ def flush_pending() -> None: flush_pending() store = self._store(shard_id) self._write_attrs(store, row) - if pending_store is None: - pending_store = store for zarr_path, array in row_arrays.items(): pending_arrays.setdefault(zarr_path, []).append(array) pending_lengths.append(length) @@ -211,7 +210,7 @@ def _write_row_values( store.row_end = previous_row_end raise - def _row_values( + def _split_row_values( self, row: Row, arrays: Mapping[str, str], @@ -245,17 +244,6 @@ async def _append_video( batch_limit: int | None = None count = 0 - def append_frame(frame: np.ndarray) -> None: - nonlocal batch_limit - batch.append(frame) - if batch_limit is None: - batch_limit = min( - self.video_frame_batch_size, - _batch_length_for_shape( - (1, *frame.shape), frame.dtype, self.array_chunk_bytes - ), - ) - def flush_batch() -> None: nonlocal count if not batch: @@ -271,23 +259,18 @@ def flush_batch() -> None: count += len(batch) batch.clear() - if isinstance(video, VideoFrameArray): - if expected_length is not None and video.frame_count != expected_length: - raise ValueError("Zarr arrays for one row must have matching lengths") - if video.frame_count == 0: - empty = np.asarray(video.frames, dtype=np.uint8) - self._append_array(store, path, empty[:0]) - return 0 - for frame in video.iter_frame_arrays(): - append_frame(frame) - if batch_limit is not None and len(batch) >= batch_limit: - flush_batch() - flush_batch() - return count - - async for frame in video.iter_frames(): - append_frame(frame.frame.to_ndarray(format="rgb24")) - if batch_limit is not None and len(batch) >= batch_limit: + async for frame in video.iter_numpy_frames(): + limit = batch_limit + if limit is None: + limit = min( + self.video_frame_batch_size, + _batch_length_for_shape( + (1, *frame.shape), frame.dtype, self.array_chunk_bytes + ), + ) + batch_limit = limit + batch.append(frame) + if len(batch) >= limit: flush_batch() flush_batch() if count == 0: diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index a5d67368..4414d347 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -28,6 +28,8 @@ def clipped( def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: ... + def iter_numpy_frames(self) -> AsyncIterator[np.ndarray]: ... + def iter_frame_windows( self, *, @@ -108,6 +110,10 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) + async def iter_numpy_frames(self) -> AsyncIterator[np.ndarray]: + async for frame in self.iter_frames(): + yield frame.frame.to_ndarray(format="rgb24") + def iter_frame_windows( self, *, @@ -172,6 +178,10 @@ def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: return iter_encoded_frames(self) + async def iter_numpy_frames(self) -> AsyncIterator[np.ndarray]: + async for frame in self.iter_frames(): + yield frame.frame.to_ndarray(format="rgb24") + def iter_frame_windows( self, *, @@ -236,6 +246,10 @@ def duration_s(self) -> float: def iter_frame_arrays(self) -> Iterator[np.ndarray]: yield from self._array + async def iter_numpy_frames(self) -> AsyncIterator[np.ndarray]: + for frame in self._array: + yield frame + def clipped( self, *, diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 60cafcf0..8d55eb58 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -41,6 +41,10 @@ async def iter_frames(self): if False: yield None + async def iter_numpy_frames(self): + if False: + yield None + async def iter_frame_windows(self, **_kwargs): if False: yield None @@ -1794,7 +1798,7 @@ def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) -def test_write_zarr_materializes_empty_frame_array_videos(tmp_path: Path) -> None: +def test_write_zarr_rejects_empty_frame_array_videos(tmp_path: Path) -> None: output = tmp_path / "empty-video.zarr" frames = np.empty((0, 4, 5, 3), dtype=np.uint8) rows = list( @@ -1808,15 +1812,12 @@ def test_write_zarr_materializes_empty_frame_array_videos(tmp_path: Path) -> Non ) ) - ZarrSink( - str(output), - arrays={"data/rgb": "observation.images.front"}, - reduce_to_single_store=False, - ).write_block(rows) - - root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") - assert root["data/rgb"].shape == frames.shape - assert root["meta/episode_ends"][:].tolist() == [0] + with pytest.raises(ValueError, match="produced no frames"): + ZarrSink( + str(output), + arrays={"data/rgb": "observation.images.front"}, + reduce_to_single_store=False, + ).write_block(rows) def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> None: From 11f9851f67604a9a5a4c5b13c95681f233f58eb1 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Mon, 25 May 2026 15:57:46 +0200 Subject: [PATCH 39/39] Address Zarr writer review comments --- src/refiner/pipeline/sinks/reducer/zarr.py | 49 ++++-- src/refiner/pipeline/sinks/zarr.py | 47 ++++-- tests/readers/test_zarr_reader.py | 171 +++++++++++++++++++++ 3 files changed, 239 insertions(+), 28 deletions(-) diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py index 49484328..da56c105 100644 --- a/src/refiner/pipeline/sinks/reducer/zarr.py +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -101,6 +101,18 @@ def _merge(self) -> None: import zarr parts = self._collect_stores(expected_parts, for_merge=True) + final_attrs: dict[str, Any] | None = None + for relpath, _paths in parts: + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + source_attrs = dict(source.attrs) + if final_attrs is None: + final_attrs = source_attrs + elif source_attrs != final_attrs: + raise ValueError("Zarr part store attrs differ") + final = zarr.open_group( store=_zarr_store(self.output, "", mode="a"), mode="a", @@ -109,26 +121,41 @@ def _merge(self) -> None: if key != "_parts": del final[key] final.attrs.clear() + if final_attrs is not None: + final.attrs.update(final_attrs) row_offset = 0 arrays: dict[str, Any] = {} - final_attrs: dict[str, Any] | None = None for relpath, paths in parts: source = zarr.open_group( store=_zarr_store(self.output, relpath, mode="r"), mode="r", ) - source_attrs = dict(source.attrs) - if final_attrs is None: - final_attrs = source_attrs - final.attrs.update(source_attrs) - elif source_attrs != final_attrs: - raise ValueError("Zarr part store attrs differ") + episode_path = self.episode_ends_path + payload_paths = paths - ( + {episode_path} if episode_path is not None else set() + ) + payload_lengths = {int(source[path].shape[0]) for path in payload_paths} + if len(payload_lengths) > 1: + raise ValueError( + "Zarr part store payload arrays must have matching lengths" + ) + payload_rows = next(iter(payload_lengths), None) + part_end = 0 + if episode_path is not None and episode_path in paths: + episode_ends = source[episode_path] + if episode_ends.shape[0] > 0: + part_end = int(np.asarray(episode_ends[-1])) + if payload_rows is not None and part_end != payload_rows: + raise ValueError( + "Zarr part store episode_ends final value does not match " + "payload row count" + ) for path in sorted(paths): source_array = source[path] chunks = getattr(source_array, "chunks", None) compressor = getattr(source_array, "compressor", None) - if source_array.shape[0] == 0 and path == self.episode_ends_path: + if source_array.shape[0] == 0 and path == episode_path: continue if source_array.shape[0] == 0: _append_zarr_array( @@ -141,14 +168,12 @@ def _merge(self) -> None: ) continue - part_end = 0 batch_size = _batch_length(source_array, self.array_chunk_bytes) for start in range(0, int(source_array.shape[0]), batch_size): end = min(int(source_array.shape[0]), start + batch_size) values = np.asarray(source_array[start:end]) - if path == self.episode_ends_path: + if path == episode_path: values = np.asarray(values, dtype=np.int64) - part_end = int(values[-1]) values = values + row_offset _append_zarr_array( final, @@ -158,7 +183,7 @@ def _merge(self) -> None: chunks=chunks, compressor=compressor, ) - if path == self.episode_ends_path: + if path == episode_path: row_offset += part_end def on_shard_finalized(self, shard_id: str) -> None: diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py index 55055ea2..96d0b7fe 100644 --- a/src/refiner/pipeline/sinks/zarr.py +++ b/src/refiner/pipeline/sinks/zarr.py @@ -49,13 +49,20 @@ def __init__( raise ValueError("array_chunk_bytes must be greater than zero") _validate_store_template(store_template) self.output = DataFolder.resolve(output) - self.arrays = dict(arrays) if arrays is not None else None + self.episode_ends_path = ( + _normalize_public_zarr_path(episode_ends_path, "episode_ends_path") + if episode_ends_path is not None + else None + ) + self.arrays = ( + _normalize_array_paths(arrays, self.episode_ends_path) + if arrays is not None + else None + ) self.attrs = dict(attrs) if attrs is not None else None - self.episode_ends_path = episode_ends_path if self.arrays is not None: if not self.arrays: raise ValueError("write_zarr arrays must not be empty") - _validate_array_paths(self.arrays, episode_ends_path) self.store_template = store_template self.video_frame_batch_size = video_frame_batch_size self.array_chunk_bytes = array_chunk_bytes @@ -284,12 +291,14 @@ def _arrays_for_row(self, row: Row) -> dict[str, str]: return self.arrays default_arrays = _default_robotics_arrays(row) if self._default_arrays is None: - self._default_arrays = default_arrays + self._default_arrays = _normalize_array_paths( + default_arrays, + self.episode_ends_path, + ) if not self._default_arrays: raise ValueError( "write_zarr inferred no default robotics arrays; pass arrays=..." ) - _validate_array_paths(self._default_arrays, self.episode_ends_path) elif default_arrays != self._default_arrays: raise ValueError( "Zarr default arrays changed across rows; pass arrays=... " @@ -415,22 +424,25 @@ def _default_robotics_arrays(row: Row) -> dict[str, str]: return arrays -def _validate_array_paths( +def _normalize_array_paths( arrays: Mapping[str, str], episode_ends_path: str | None, -) -> None: - for path in arrays: - _validate_public_zarr_path(path, "Zarr array path") - if episode_ends_path is not None: - _validate_public_zarr_path(episode_ends_path, "episode_ends_path") - if episode_ends_path is not None and episode_ends_path in arrays: +) -> dict[str, str]: + normalized: dict[str, str] = {} + for path, source_key in arrays.items(): + normalized_path = _normalize_public_zarr_path(path, "Zarr array path") + if normalized_path in normalized: + raise ValueError(f"Duplicate Zarr array path: {normalized_path}") + normalized[normalized_path] = source_key + if episode_ends_path is not None and episode_ends_path in normalized: raise ValueError( f"Zarr array path collides with episode_ends_path: {episode_ends_path}" ) + return normalized def _validate_store_template(store_template: str) -> None: - _validate_public_zarr_path(store_template, "store_template") + _normalize_public_zarr_path(store_template, "store_template") fields: set[str] = set() for _literal_text, field_name, format_spec, conversion in Formatter().parse( store_template @@ -461,20 +473,23 @@ def _render_store_relpath( worker_id: str, ) -> str: relpath = store_template.format(shard_id=shard_id, worker_id=worker_id) - _validate_public_zarr_path(relpath, "rendered store path") + _normalize_public_zarr_path(relpath, "rendered store path") return relpath -def _validate_public_zarr_path(path: str, label: str) -> None: +def _normalize_public_zarr_path(path: str, label: str) -> str: path = str(path) if path.startswith("/"): raise ValueError(f"{label} must be relative") parts = [part for part in path.split("/") if part] + if not parts: + raise ValueError(f"{label} must not be empty") if any(part in {".", ".."} for part in parts): raise ValueError(f"{label} must not contain '.' or '..' segments") - root = parts[0] if parts else "" + root = parts[0] if root in {"_parts", "_refiner"}: raise ValueError(f"{label} must not use reserved root: {root}") + return "/".join(parts) def _row_value(row: Row, key: str) -> Any: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 8d55eb58..24bccb34 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -904,6 +904,17 @@ def test_write_zarr_rejects_path_traversal(tmp_path: Path) -> None: str(tmp_path / "array-escape.zarr"), arrays={"../action": "action"}, ) + with pytest.raises(ValueError, match="must not be empty"): + ZarrSink( + str(tmp_path / "empty-array-path.zarr"), + arrays={"": "action"}, + ) + with pytest.raises(ValueError, match="must not be empty"): + ZarrSink( + str(tmp_path / "empty-episode-ends.zarr"), + arrays={"data/action": "action"}, + episode_ends_path="", + ) def test_write_zarr_rejects_rendered_path_traversal(tmp_path: Path) -> None: @@ -1330,6 +1341,51 @@ def test_write_zarr_single_store_skips_mixed_empty_shards(tmp_path: Path) -> Non assert not (zarr_out / "_parts").exists() +def test_write_zarr_single_store_offsets_batched_episode_ends( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-batched-episode-ends.zarr" + workers = ["worker-a", "worker-b"] + for shard_id, worker in enumerate(workers): + part = ( + zarr_out / "_parts" / f"shard-{shard_id}__w{worker_token_for(worker)}.zarr" + ) + _write_part_zarr( + part, + { + "data/action": np.arange(2, dtype=np.float32).reshape(2, 1), + "meta/episode_ends": np.asarray([1, 2], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id=f"shard-{shard_id}", + worker_id=worker, + global_ordinal=shard_id, + ) + for shard_id, worker in enumerate(workers) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=8, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + assert root["meta/episode_ends"][:].tolist() == [1, 2, 3, 4] + + def test_write_zarr_single_store_rejects_inconsistent_part_payloads( tmp_path: Path, ) -> None: @@ -1427,6 +1483,45 @@ def test_write_zarr_single_store_rejects_part_missing_episode_ends( ).write_block([DictRow({}, shard_id="reduce")]) +def test_write_zarr_single_store_rejects_part_row_end_mismatch( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-row-end-mismatch.zarr" + worker = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker)}.zarr" + _write_part_zarr( + part, + { + "data/action": np.asarray([[0.0], [1.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker, + global_ordinal=0, + ) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="episode_ends final value"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + def test_write_zarr_single_store_rejects_missing_finalized_part( tmp_path: Path, ) -> None: @@ -1626,6 +1721,77 @@ def test_write_zarr_single_store_replace_clears_root_attrs( assert dict(root.attrs) == {} +def test_write_zarr_single_store_rejects_part_attr_drift_before_clearing( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-attr-drift.zarr" + _write_part_zarr( + zarr_out, + {"data/action": np.asarray([[99.0]], dtype=np.float32)}, + ) + root = _open_test_zarr(zarr_out, mode="r+") + root.attrs["task"] = "old" + + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _open_test_zarr(first_part, mode="r+").attrs["task"] = "first" + _open_test_zarr(second_part, mode="r+").attrs["task"] = "second" + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="attrs differ"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + np.testing.assert_allclose(root["data/action"][:], [[99.0]]) + assert dict(root.attrs) == {"task": "old"} + + def test_write_zarr_single_store_rejects_part_dtype_drift( tmp_path: Path, ) -> None: @@ -1732,6 +1898,11 @@ def test_write_zarr_rejects_episode_ends_path_collision(tmp_path: Path) -> None: str(tmp_path / "collision.zarr"), arrays={"meta/episode_ends": "action"}, ) + with pytest.raises(ValueError, match="collides with episode_ends_path"): + ZarrSink( + str(tmp_path / "normalized-collision.zarr"), + arrays={"meta//episode_ends": "action"}, + ) def test_write_zarr_rejects_shape_drift_before_appending_bad_row(