From 3b553e31eb24abd6901d44ba85201e5afb458a0e Mon Sep 17 00:00:00 2001 From: guipenedo Date: Fri, 22 May 2026 01:26:50 +0200 Subject: [PATCH 01/39] Add Zarr reader and writer --- pyproject.toml | 4 + src/refiner/__init__.py | 2 + src/refiner/pipeline/__init__.py | 2 + src/refiner/pipeline/pipeline.py | 51 +++- src/refiner/pipeline/sinks/__init__.py | 2 + src/refiner/pipeline/sinks/zarr.py | 123 +++++++++ src/refiner/pipeline/sources/__init__.py | 2 + .../pipeline/sources/readers/__init__.py | 2 + src/refiner/pipeline/sources/readers/zarr.py | 150 +++++++++++ tests/readers/test_zarr_reader.py | 123 +++++++++ uv.lock | 242 +++++++++++++++++- 11 files changed, 692 insertions(+), 11 deletions(-) create mode 100644 src/refiner/pipeline/sinks/zarr.py create mode 100644 src/refiner/pipeline/sources/readers/zarr.py create mode 100644 tests/readers/test_zarr_reader.py diff --git a/pyproject.toml b/pyproject.toml index b359e587..dc3ba3e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ text = [ hdf5 = [ "h5py", ] +zarr = [ + "zarr>=2.18", +] s3 = [ "s3fs", ] @@ -51,6 +54,7 @@ testing = [ "macrodata-refiner[huggingface]", "macrodata-refiner[hdf5]", "macrodata-refiner[robotics]", + "macrodata-refiner[zarr]", "macrodata-refiner[text]", "macrodata-refiner[s3]", "pytest>=8.0.0", diff --git a/src/refiner/__init__.py b/src/refiner/__init__.py index fd298b97..bd74ec19 100644 --- a/src/refiner/__init__.py +++ b/src/refiner/__init__.py @@ -20,6 +20,7 @@ read_lerobot, read_parquet, read_videos, + read_zarr, SUPPORTED_CUDA_VERSIONS, SUPPORTED_GPU_TYPES, task, @@ -54,6 +55,7 @@ "read_lerobot", "read_parquet", "read_videos", + "read_zarr", "from_items", "from_source", "task", diff --git a/src/refiner/pipeline/__init__.py b/src/refiner/pipeline/__init__.py index de41089d..a37de383 100644 --- a/src/refiner/pipeline/__init__.py +++ b/src/refiner/pipeline/__init__.py @@ -13,6 +13,7 @@ read_lerobot, read_parquet, read_videos, + read_zarr, task, ) from refiner.pipeline.resources import ( @@ -41,6 +42,7 @@ "read_lerobot", "read_parquet", "read_videos", + "read_zarr", "from_items", "from_source", "task", diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 611b0f6f..20e2cae2 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -40,9 +40,14 @@ Hdf5Reader, JsonReader, ParquetReader, + ZarrReader, ) from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader from refiner.pipeline.sources.readers.hdf5 import MissingPolicy, PathSelection +from refiner.pipeline.sources.readers.zarr import ( + MissingPolicy as ZarrMissingPolicy, + PathSelection as ZarrPathSelection, +) from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -431,6 +436,23 @@ def write_parquet( ) ) + def write_zarr( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ) -> "RefinerPipeline": + return self.with_sink( + ZarrSink( + output=output, + arrays=arrays, + episode_ends_path=episode_ends_path, + overwrite=overwrite, + ) + ) + def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) @@ -809,6 +831,33 @@ def read_hdf5( ) +def read_zarr( + input: DataFolderLike, + *, + arrays: ZarrPathSelection | None = None, + attrs: ZarrPathSelection | None = None, + file_path_column: str | None = "file_path", + missing_policy: ZarrMissingPolicy = "error", + dtypes: DTypeMapping | None = None, +) -> RefinerPipeline: + """Create a pipeline with a Zarr reader source. + + The reader emits one row for the Zarr group. Select arrays with `arrays` + and pass `episode_ends_key` to `to_robot_rows(...)` for Diffusion + Policy-style dataset arrays. + """ + return RefinerPipeline( + source=ZarrReader( + input, + arrays=arrays, + attrs=attrs, + file_path_column=file_path_column, + missing_policy=missing_policy, + dtypes=dtypes, + ) + ) + + def read_parquet( inputs: DataFileSetLike, *, diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index 8c1f26df..f0623f21 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,6 +2,7 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink +from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -10,4 +11,5 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", + "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py new file mode 100644 index 00000000..e3beb498 --- /dev/null +++ b/src/refiner/pipeline/sinks/zarr.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any, cast + +import numpy as np +import pyarrow as pa + +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.data.row import Row +from refiner.pipeline.data.tabular import Tabular +from refiner.pipeline.sinks.base import BaseSink +from refiner.robotics.row import RoboticsRow +from refiner.utils import check_required_dependencies + + +class ZarrSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + overwrite: bool = True, + ): + self.output = DataFolder.resolve(output) + self.arrays = dict(arrays) if arrays is not None else None + self.episode_ends_path = episode_ends_path + self.overwrite = overwrite + self._chunks: dict[str, list[np.ndarray]] = {} + self._episode_ends: list[int] = [] + + def write_shard_block(self, shard_id: str, block: Block) -> int: + del shard_id + rows = list(block) if isinstance(block, Tabular) else block + for row in rows: + self._write_row(row) + return len(rows) + + def _write_row(self, row: Row) -> None: + arrays = self.arrays or _default_robotics_arrays(row) + lengths: list[int] = [] + for zarr_path, source_key in arrays.items(): + value = _row_value(row, source_key) + if value is None: + continue + array = _as_array(value) + if array.ndim == 0: + array = array.reshape(1) + lengths.append(int(array.shape[0])) + self._chunks.setdefault(zarr_path, []).append(array) + if lengths and self.episode_ends_path is not None: + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + end = (self._episode_ends[-1] if self._episode_ends else 0) + length + self._episode_ends.append(end) + + def close(self) -> None: + if not self._chunks and not self._episode_ends: + return + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + import zarr + + mode = "w" if self.overwrite else "w-" + root = zarr.open_group(self.output.abs_path(), mode=mode) + for path, chunks in self._chunks.items(): + root.create_dataset(path, data=np.concatenate(chunks, axis=0)) + if self.episode_ends_path is not None: + root.create_dataset( + self.episode_ends_path, + data=np.asarray(self._episode_ends, dtype=np.int64), + ) + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr", + "writer", + { + "path": self.output.abs_path(), + "arrays": dict(self.arrays) if self.arrays is not None else None, + "episode_ends_path": self.episode_ends_path, + "overwrite": self.overwrite, + }, + ) + + +def _default_robotics_arrays(row: Row) -> dict[str, str]: + if not isinstance(row, RoboticsRow): + raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") + arrays: dict[str, str] = {} + if row.actions is not None: + arrays["data/action"] = "action" + if row.states is not None: + arrays["data/observation.state"] = "observation.state" + if row.timestamps is not None: + arrays["data/timestamp"] = "timestamp" + return arrays + + +def _row_value(row: Row, key: str) -> Any: + if isinstance(row, RoboticsRow): + if key == "action": + return row.actions + if key == "observation.state": + return row.states + if key == "timestamp": + return row.timestamps + if key.startswith("observation."): + return row.observations(key) + return row[key] + + +def _as_array(value: Any) -> np.ndarray: + if isinstance(value, pa.ChunkedArray | pa.Array): + return np.asarray(value.to_pylist()) + if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): + return np.asarray(list(cast(Iterable[Any], value))) + return np.asarray(value) + + +__all__ = ["ZarrSink"] diff --git a/src/refiner/pipeline/sources/__init__.py b/src/refiner/pipeline/sources/__init__.py index 4faa9083..d9aa4ee6 100644 --- a/src/refiner/pipeline/sources/__init__.py +++ b/src/refiner/pipeline/sources/__init__.py @@ -8,6 +8,7 @@ JsonReader, LeRobotEpisodeReader, ParquetReader, + ZarrReader, ) __all__ = [ @@ -20,4 +21,5 @@ "JsonReader", "LeRobotEpisodeReader", "ParquetReader", + "ZarrReader", ] diff --git a/src/refiner/pipeline/sources/readers/__init__.py b/src/refiner/pipeline/sources/readers/__init__.py index 1304d351..00da8fe1 100644 --- a/src/refiner/pipeline/sources/readers/__init__.py +++ b/src/refiner/pipeline/sources/readers/__init__.py @@ -6,6 +6,7 @@ from refiner.pipeline.sources.readers.json import JsonReader from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader from refiner.pipeline.sources.readers.parquet import ParquetReader +from refiner.pipeline.sources.readers.zarr import ZarrReader from refiner.robotics.lerobot_format import LeRobotRow __all__ = [ @@ -18,4 +19,5 @@ "LeRobotEpisodeReader", "LeRobotRow", "ParquetReader", + "ZarrReader", ] diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py new file mode 100644 index 00000000..3b670210 --- /dev/null +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from collections.abc import Iterator, Mapping, Sequence +from typing import Any, Literal, cast + +import pyarrow as pa + +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.datatype import ( + DTypeMapping, + dtype_to_plan, + schema_with_dtypes, +) +from refiner.pipeline.data.row import DictRow +from refiner.pipeline.data.shard import Shard +from refiner.pipeline.sources.base import BaseSource, SourceUnit +from refiner.utils import check_required_dependencies + +MissingPolicy = Literal["error", "drop_row", "set_null"] +PathSelection = Mapping[str, str] | Sequence[str] | str + + +def _selection_map(value: PathSelection | None) -> dict[str, str]: + if value is None: + return {} + if isinstance(value, str): + return {value.rsplit("/", 1)[-1]: value} + if isinstance(value, Mapping): + return dict(cast(Mapping[str, str], value)) + out: dict[str, str] = {} + for path in value: + name = path.rsplit("/", 1)[-1] + if name in out: + raise ValueError( + "Zarr path selections must have unique derived column names; " + f"use an explicit mapping for duplicate name {name!r}" + ) + out[name] = path + return out + + +def _decode_value(value: Any) -> Any: + if hasattr(value, "shape") and value.shape == (): + return _decode_value(value.item()) + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return value + return value + + +class ZarrReader(BaseSource): + """Read one Zarr group as one row of selected arrays and attributes.""" + + name = "read_zarr" + + def __init__( + self, + input: DataFolderLike, + *, + arrays: PathSelection | None = None, + attrs: PathSelection | None = None, + file_path_column: str | None = "file_path", + missing_policy: MissingPolicy = "error", + dtypes: DTypeMapping | None = None, + ): + self.root = DataFolder.resolve(input) + self.arrays = _selection_map(arrays) + self.attrs = _selection_map(attrs) + self.file_path_column = file_path_column + self.missing_policy = missing_policy + self.dtypes = dtypes + if missing_policy not in ("error", "drop_row", "set_null"): + raise ValueError( + "missing_policy must be one of 'error', 'drop_row', or 'set_null'" + ) + + @property + def schema(self) -> pa.Schema | None: + return schema_with_dtypes(None, self.dtypes) + + def describe(self) -> dict[str, Any]: + return { + "path": self.root.abs_path(), + "arrays": dict(self.arrays), + "attrs": dict(self.attrs), + "file_path_column": self.file_path_column, + "missing_policy": self.missing_policy, + "dtypes": ( + {key: dtype_to_plan(dtype) for key, dtype in self.dtypes.items()} + if self.dtypes + else None + ), + } + + def list_shards(self) -> list[Shard]: + path = self.root.abs_path() + return [ + Shard.from_row_range( + start=0, + end=1, + global_ordinal=0, + start_key=path, + end_key=path, + ) + ] + + def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: + del shard + check_required_dependencies("read_zarr", ["zarr"], dist="zarr") + import zarr + + group = zarr.open_group(self.root.abs_path(), mode="r") + arrays = self.arrays or {path: path for path in _iter_array_paths(group)} + row: dict[str, Any] = {} + if self.file_path_column is not None: + row[self.file_path_column] = self.root.abs_path() + for output_name, path in arrays.items(): + try: + row[output_name] = group[path][:] + except KeyError: + if self.missing_policy == "drop_row": + return + if self.missing_policy == "set_null": + row[output_name] = None + continue + raise KeyError(f"Zarr array not found: {path}") from None + for output_name, attr_name in self.attrs.items(): + if attr_name not in group.attrs: + if self.missing_policy == "drop_row": + return + if self.missing_policy == "set_null": + row[output_name] = None + continue + raise KeyError(f"Zarr attr not found: {attr_name}") + row[output_name] = _decode_value(group.attrs[attr_name]) + yield DictRow(row) + + +def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: + for name, item in group.items(): + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + +__all__ = ["MissingPolicy", "PathSelection", "ZarrReader"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py new file mode 100644 index 00000000..a74372f3 --- /dev/null +++ b/tests/readers/test_zarr_reader.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from pathlib import Path +from typing import cast + +import numpy as np +import zarr + +import refiner as mdr +from refiner.robotics.row import RoboticsRow + + +def _write_policy_zarr(path: Path) -> None: + root = zarr.open_group(str(path), mode="w") + root.create_dataset( + "data/action", + data=np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]], dtype=np.float32), + ) + root.create_dataset( + "data/state", + data=np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]], dtype=np.float32), + ) + root.create_dataset( + "data/rgb", + data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), + ) + root.create_dataset("meta/episode_ends", data=np.asarray([2, 5], dtype=np.int64)) + root.attrs["dataset_id"] = "pusht" + root.attrs["task"] = "push tee" + + +def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + row = mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + attrs={"task": "task"}, + file_path_column=None, + ).take(1)[0] + + assert row["task"] == "push tee" + assert row["episode_ends"].tolist() == [2, 5] + np.testing.assert_allclose(row["action"][:2], [[0.0], [0.1]]) + + +def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + lerobot_out = tmp_path / "lerobot" + zarr_out = tmp_path / "roundtrip.zarr" + _write_policy_zarr(path) + + ( + mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + "episode_ends": "meta/episode_ends", + }, + attrs={"dataset_id": "dataset_id", "task": "task"}, + file_path_column=None, + ) + .to_robot_rows( + episode_id_key="dataset_id", + task_key="task", + episode_ends_key="episode_ends", + action_key="action", + state_key="observation.state", + video_keys={"observation.images.front": "frames"}, + fps=10, + robot_type="pusht", + ) + .write_lerobot(str(lerobot_out), max_video_prepare_in_flight=1) + .launch_local( + name="zarr-to-lerobot", num_workers=1, rundir=str(tmp_path / "run1") + ) + ) + + episodes = [ + cast(RoboticsRow, row) + for row in mdr.read_lerobot(str(lerobot_out)).materialize() + ] + assert [episode.num_frames for episode in episodes] == [2, 3] + assert [episode.task for episode in episodes] == ["push tee", "push tee"] + + ( + mdr.read_lerobot(str(lerobot_out)) + .write_zarr( + str(zarr_out), + arrays={ + "data/action": "action", + "data/state": "observation.state", + }, + ) + .launch_local( + name="lerobot-to-zarr", num_workers=1, rundir=str(tmp_path / "run2") + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + + assert row["episode_ends"].tolist() == [2, 5] + np.testing.assert_allclose( + row["action"], np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]]) + ) + np.testing.assert_allclose( + row["state"], np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]]) + ) diff --git a/uv.lock b/uv.lock index 30083db1..731b1849 100644 --- a/uv.lock +++ b/uv.lock @@ -5,9 +5,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.11'", ] @@ -204,6 +207,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "asciitree" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/6a/885bc91484e1aa8f618f6f0228d76d0e67000b0fdd6090673b777e311913/asciitree-0.3.3.tar.gz", hash = "sha256:4aa4b9b649f85e3fcb343363d97564aa1fb62e249677f2e18a96765145cc0f6e", size = 3951, upload-time = "2016-09-05T19:10:42.681Z" } + [[package]] name = "async-timeout" version = "5.0.1" @@ -609,6 +618,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] +[[package]] +name = "donfig" +version = "0.8.1.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -621,6 +642,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "fasteners" +version = "0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/18/7881a99ba5244bfc82f06017316ffe93217dbbbcfa52b887caa1d4f2a6d3/fasteners-0.20.tar.gz", hash = "sha256:55dce8792a41b56f727ba6e123fcaee77fd87e638a6863cec00007bfea84c8d8", size = 25087, upload-time = "2025-08-11T10:19:37.785Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/ac/e5d886f892666d2d1e5cb8c1a41146e1d79ae8896477b1153a21711d3b44/fasteners-0.20-py3-none-any.whl", hash = "sha256:9422c40d1e350e4259f509fb2e608d6bc43c0136f79a00db1b49046029d0b3b7", size = 18702, upload-time = "2025-08-11T10:19:35.716Z" }, +] + [[package]] name = "filelock" version = "3.25.2" @@ -765,6 +795,41 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "google-crc32c" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, + { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, + { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, + { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, + { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, + { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, + { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, + { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, + { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, + { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, + { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, + { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, + { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, + { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, + { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, + { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, + { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, + { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1001,6 +1066,9 @@ all = [ { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] hdf5 = [ { name = "h5py" }, @@ -1026,6 +1094,9 @@ testing = [ { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] text = [ { name = "warcio" }, @@ -1033,6 +1104,11 @@ text = [ video = [ { name = "av" }, ] +zarr = [ + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] [package.dev-dependencies] dev = [ @@ -1060,6 +1136,7 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["testing"], marker = "extra == 'all'" }, { name = "macrodata-refiner", extras = ["text"], marker = "extra == 'testing'" }, { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, + { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, { name = "numpy" }, { name = "orjson" }, @@ -1068,8 +1145,9 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18" }, ] -provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "s3", "testing", "all"] +provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] [package.metadata.requires-dev] dev = [ @@ -1326,6 +1404,79 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, ] +[[package]] +name = "numcodecs" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215, upload-time = "2024-10-09T16:28:00.188Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/c0/6d72cde772bcec196b7188731d41282993b2958440f77fdf0db216f722da/numcodecs-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:96add4f783c5ce57cc7e650b6cac79dd101daf887c479a00a29bc1487ced180b", size = 1580012, upload-time = "2024-10-09T16:27:19.069Z" }, + { url = "https://files.pythonhosted.org/packages/94/1d/f81fc1fa9210bbea97258242393a1f9feab4f6d8fb201f81f76003005e4b/numcodecs-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:237b7171609e868a20fd313748494444458ccd696062f67e198f7f8f52000c15", size = 1176919, upload-time = "2024-10-09T16:27:21.634Z" }, + { url = "https://files.pythonhosted.org/packages/16/e4/b9ec2f4dfc34ecf724bc1beb96a9f6fa9b91801645688ffadacd485089da/numcodecs-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96e42f73c31b8c24259c5fac6adba0c3ebf95536e37749dc6c62ade2989dca28", size = 8625842, upload-time = "2024-10-09T16:27:24.168Z" }, + { url = "https://files.pythonhosted.org/packages/fe/90/299952e1477954ec4f92813fa03e743945e3ff711bb4f6c9aace431cb3da/numcodecs-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:eda7d7823c9282e65234731fd6bd3986b1f9e035755f7fed248d7d366bb291ab", size = 828638, upload-time = "2024-10-09T16:27:27.063Z" }, + { url = "https://files.pythonhosted.org/packages/f0/78/34b8e869ef143e88d62e8231f4dbfcad85e5c41302a11fc5bd2228a13df5/numcodecs-0.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2eda97dd2f90add98df6d295f2c6ae846043396e3d51a739ca5db6c03b5eb666", size = 1580199, upload-time = "2024-10-09T16:27:29.336Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cf/f70797d86bb585d258d1e6993dced30396f2044725b96ce8bcf87a02be9c/numcodecs-0.13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2a86f5367af9168e30f99727ff03b27d849c31ad4522060dde0bce2923b3a8bc", size = 1177203, upload-time = "2024-10-09T16:27:31.011Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b5/d14ad69b63fde041153dfd05d7181a49c0d4864de31a7a1093c8370da957/numcodecs-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233bc7f26abce24d57e44ea8ebeb5cd17084690b4e7409dd470fdb75528d615f", size = 8868743, upload-time = "2024-10-09T16:27:32.833Z" }, + { url = "https://files.pythonhosted.org/packages/13/d4/27a7b5af0b33f6d61e198faf177fbbf3cb83ff10d9d1a6857b7efc525ad5/numcodecs-0.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:796b3e6740107e4fa624cc636248a1580138b3f1c579160f260f76ff13a4261b", size = 829603, upload-time = "2024-10-09T16:27:35.415Z" }, + { url = "https://files.pythonhosted.org/packages/37/3a/bc09808425e7d3df41e5fc73fc7a802c429ba8c6b05e55f133654ade019d/numcodecs-0.13.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5195bea384a6428f8afcece793860b1ab0ae28143c853f0b2b20d55a8947c917", size = 1575806, upload-time = "2024-10-09T16:27:37.804Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/dc74d0bfdf9ec192332a089d199f1e543e747c556b5659118db7a437dcca/numcodecs-0.13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3501a848adaddce98a71a262fee15cd3618312692aa419da77acd18af4a6a3f6", size = 1178233, upload-time = "2024-10-09T16:27:40.169Z" }, + { url = "https://files.pythonhosted.org/packages/d4/ce/434e8e3970b8e92ae9ab6d9db16cb9bc7aa1cd02e17c11de6848224100a1/numcodecs-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2230484e6102e5fa3cc1a5dd37ca1f92dfbd183d91662074d6f7574e3e8f53", size = 8857827, upload-time = "2024-10-09T16:27:42.743Z" }, + { url = "https://files.pythonhosted.org/packages/83/e7/1d8b1b266a92f9013c755b1c146c5ad71a2bff147ecbc67f86546a2e4d6a/numcodecs-0.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5db4824ebd5389ea30e54bc8aeccb82d514d28b6b68da6c536b8fa4596f4bca", size = 826539, upload-time = "2024-10-09T16:27:44.808Z" }, + { url = "https://files.pythonhosted.org/packages/83/8b/06771dead2cc4a8ae1ea9907737cf1c8d37a323392fa28f938a586373468/numcodecs-0.13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7a60d75179fd6692e301ddfb3b266d51eb598606dcae7b9fc57f986e8d65cb43", size = 1571660, upload-time = "2024-10-09T16:27:47.125Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ea/d925bf85f92dfe4635356018da9fe4bfecb07b1c72f62b01c1bc47f936b1/numcodecs-0.13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f593c7506b0ab248961a3b13cb148cc6e8355662ff124ac591822310bc55ecf", size = 1169925, upload-time = "2024-10-09T16:27:49.512Z" }, + { url = "https://files.pythonhosted.org/packages/0f/d6/643a3839d571d8e439a2c77dc4b0b8cab18d96ac808e4a81dbe88e959ab6/numcodecs-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80d3071465f03522e776a31045ddf2cfee7f52df468b977ed3afdd7fe5869701", size = 8814257, upload-time = "2024-10-09T16:27:52.059Z" }, + { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887, upload-time = "2024-10-09T16:27:55.039Z" }, +] + +[[package]] +name = "numcodecs" +version = "0.16.5" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/44/bd/8a391e7c356366224734efd24da929cc4796fff468bfb179fe1af6548535/numcodecs-0.16.5.tar.gz", hash = "sha256:0d0fb60852f84c0bd9543cc4d2ab9eefd37fc8efcc410acd4777e62a1d300318", size = 6276387, upload-time = "2025-11-21T02:49:48.986Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/85/1ac101a40ead81eaa1c7dc49a8827a30e2e436211b43ebdc63c590eb1347/numcodecs-0.16.5-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:78382dcea50622f2ef1e6e7a71dbe7f861d8fe376b27b7c297c26907304fef1e", size = 1621795, upload-time = "2025-11-21T02:49:17.418Z" }, + { url = "https://files.pythonhosted.org/packages/0e/cc/0d97ef55dda48cb0f93d7b92d761208e7a99bd2eea6b0e859426e6a99a21/numcodecs-0.16.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2d04a19cb57a3c519b4127ac377cca6471aee1990d7c18f5b1e3a4fe1306689", size = 1153030, upload-time = "2025-11-21T02:49:19.089Z" }, + { url = "https://files.pythonhosted.org/packages/5e/41/e120ee1b390730ac5987cde2afd82e2b8442cec315ab40b94b0373e93e73/numcodecs-0.16.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c043af648eb280cd61785c99c22ff5c3c3460f906eb51a8511327c4f5111b283", size = 8510503, upload-time = "2025-11-21T02:49:20.324Z" }, + { url = "https://files.pythonhosted.org/packages/54/4b/195ac84cc8f6077b4f0f421e8daee21b7f1bd88cb7716414234379fe68ec/numcodecs-0.16.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c398919ef2eb0e56b8e97456f622640bfd3deed06de3acc976989cbcb22628a3", size = 9123428, upload-time = "2025-11-21T02:49:22.328Z" }, + { url = "https://files.pythonhosted.org/packages/0f/5b/af02c417954f46e5c7bd5163ac251f535877d909fce54861c99ae197f6f6/numcodecs-0.16.5-cp311-cp311-win_amd64.whl", hash = "sha256:3820860ed302d4d84a1c66e70981ff959d5eb712555be4e7d8ced49888594773", size = 801542, upload-time = "2025-11-21T02:49:24.265Z" }, + { url = "https://files.pythonhosted.org/packages/75/cc/55420f3641a67f78392dc0bc5d02cb9eb0a9dcebf2848d1ac77253ca61fa/numcodecs-0.16.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:24e675dc8d1550cd976a99479b87d872cb142632c75cc402fea04c08c4898523", size = 1656287, upload-time = "2025-11-21T02:49:25.755Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:94ddfa4341d1a3ab99989d13b01b5134abb687d3dab2ead54b450aefe4ad5bd6", size = 1148899, upload-time = "2025-11-21T02:49:26.87Z" }, + { url = "https://files.pythonhosted.org/packages/97/1e/98aaddf272552d9fef1f0296a9939d1487914a239e98678f6b20f8b0a5c8/numcodecs-0.16.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b554ab9ecf69de7ca2b6b5e8bc696bd9747559cb4dd5127bd08d7a28bec59c3a", size = 8534814, upload-time = "2025-11-21T02:49:28.547Z" }, + { url = "https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad1a379a45bd3491deab8ae6548313946744f868c21d5340116977ea3be5b1d6", size = 9173471, upload-time = "2025-11-21T02:49:30.444Z" }, + { url = "https://files.pythonhosted.org/packages/1c/20/2fdec87fc7f8cec950d2b0bea603c12dc9f05b4966dc5924ba5a36a61bf6/numcodecs-0.16.5-cp312-cp312-win_amd64.whl", hash = "sha256:845a9857886ffe4a3172ba1c537ae5bcc01e65068c31cf1fce1a844bd1da050f", size = 801412, upload-time = "2025-11-21T02:49:32.123Z" }, + { url = "https://files.pythonhosted.org/packages/38/38/071ced5a5fd1c85ba0e14ba721b66b053823e5176298c2f707e50bed11d9/numcodecs-0.16.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25be3a516ab677dad890760d357cfe081a371d9c0a2e9a204562318ac5969de3", size = 1654359, upload-time = "2025-11-21T02:49:33.673Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c0/5f84ba7525577c1b9909fc2d06ef11314825fc4ad4378f61d0e4c9883b4a/numcodecs-0.16.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0107e839ef75b854e969cb577e140b1aadb9847893937636582d23a2a4c6ce50", size = 1144237, upload-time = "2025-11-21T02:49:35.294Z" }, + { url = "https://files.pythonhosted.org/packages/0b/00/787ea5f237b8ea7bc67140c99155f9c00b5baf11c49afc5f3bfefa298f95/numcodecs-0.16.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:015a7c859ecc2a06e2a548f64008c0ec3aaecabc26456c2c62f4278d8fc20597", size = 8483064, upload-time = "2025-11-21T02:49:36.454Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:84230b4b9dad2392f2a84242bd6e3e659ac137b5a1ce3571d6965fca673e0903", size = 9126063, upload-time = "2025-11-21T02:49:38.018Z" }, + { url = "https://files.pythonhosted.org/packages/27/72/6663cc0382ddbb866136c255c837bcb96cc7ce5e83562efec55e1b995941/numcodecs-0.16.5-cp313-cp313-win_amd64.whl", hash = "sha256:5088145502ad1ebf677ec47d00eb6f0fd600658217db3e0c070c321c85d6cf3d", size = 799275, upload-time = "2025-11-21T02:49:39.558Z" }, + { url = "https://files.pythonhosted.org/packages/3c/9e/38e7ca8184c958b51f45d56a4aeceb1134ecde2d8bd157efadc98502cc42/numcodecs-0.16.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b05647b8b769e6bc8016e9fd4843c823ce5c9f2337c089fb5c9c4da05e5275de", size = 1654721, upload-time = "2025-11-21T02:49:40.602Z" }, + { url = "https://files.pythonhosted.org/packages/a1/37/260fa42e7b2b08e6e00ad632f8dd620961a60a459426c26cea390f8c68d0/numcodecs-0.16.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3832bd1b5af8bb3e413076b7d93318c8e7d7b68935006b9fa36ca057d1725a8f", size = 1146887, upload-time = "2025-11-21T02:49:41.721Z" }, + { url = "https://files.pythonhosted.org/packages/4e/15/e2e1151b5a8b14a15dfd4bb4abccce7fff7580f39bc34092780088835f3a/numcodecs-0.16.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f7b7d24f103187f53135bed28bb9f0ed6b2e14c604664726487bb6d7c882e1", size = 8476987, upload-time = "2025-11-21T02:49:43.363Z" }, + { url = "https://files.pythonhosted.org/packages/6d/30/16a57fc4d9fb0ba06c600408bd6634f2f1753c54a7a351c99c5e09b51ee2/numcodecs-0.16.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aec9736d81b70f337d89c4070ee3ffeff113f386fd789492fa152d26a15043e4", size = 9102377, upload-time = "2025-11-21T02:49:45.508Z" }, + { url = "https://files.pythonhosted.org/packages/31/a5/a0425af36c20d55a3ea884db4b4efca25a43bea9214ba69ca7932dd997b4/numcodecs-0.16.5-cp314-cp314-win_amd64.whl", hash = "sha256:b16a14303800e9fb88abc39463ab4706c037647ac17e49e297faa5f7d7dbbf1d", size = 819022, upload-time = "2025-11-21T02:49:47.39Z" }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -1399,9 +1550,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/10/8b/c265f4823726ab832de836cdd184d0986dcf94480f81e8739692a7ac7af2/numpy-2.4.3.tar.gz", hash = "sha256:483a201202b73495f00dbc83796c6ae63137a9bdade074f7648b3e32613412dd", size = 20727743, upload-time = "2026-03-09T07:58:53.426Z" } wheels = [ @@ -1640,9 +1794,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] dependencies = [ { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2680,3 +2837,68 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/47/3fa2286c3cb162c71cdb34c4224d5745a1ceceb391b2bd9b19b668a8d724/yarl-1.23.0-cp314-cp314t-win_arm64.whl", hash = "sha256:44bb7bef4ea409384e3f8bc36c063d77ea1b8d4a5b2706956c0d6695f07dcc25", size = 86041, upload-time = "2026-03-01T22:07:49.026Z" }, { url = "https://files.pythonhosted.org/packages/69/68/c8739671f5699c7dc470580a4f821ef37c32c4cb0b047ce223a7f115757f/yarl-1.23.0-py3-none-any.whl", hash = "sha256:a2df6afe50dea8ae15fa34c9f824a3ee958d785fd5d089063d960bae1daa0a3f", size = 48288, upload-time = "2026-03-01T22:07:51.388Z" }, ] + +[[package]] +name = "zarr" +version = "2.18.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "asciitree", marker = "python_full_version < '3.11'" }, + { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, +] + +[[package]] +name = "zarr" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version == '3.11.*'" }, + { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "packaging", marker = "python_full_version == '3.11.*'" }, + { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, +] + +[[package]] +name = "zarr" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version >= '3.12'" }, + { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "packaging", marker = "python_full_version >= '3.12'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, +] From 08ca3c0b370384c67f1d97c23c80fd82259c5f78 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:22:33 +0200 Subject: [PATCH 02/39] Move Zarr row splitting into reader --- src/refiner/pipeline/pipeline.py | 14 +- src/refiner/pipeline/sources/readers/zarr.py | 49 ++++++- src/refiner/robotics/row.py | 120 +--------------- tests/readers/test_zarr_reader.py | 29 +++- tests/robotics/test_robotics_row.py | 142 +------------------ 5 files changed, 78 insertions(+), 276 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 20e2cae2..5ba7952a 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -181,7 +181,6 @@ def to_robot_rows( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, ) -> "RefinerPipeline": """Expose rows through the RoboticsRow semantic view. @@ -212,11 +211,8 @@ def to_robot_rows( schema=self.output_schema(), stats_key=stats_key, stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, ) - if episode_ends_key is None: - return self.map(cast(MapFn, converter)) - return self.flat_map(cast(FlatMapFn, converter)) + return self.map(cast(MapFn, converter)) def map_async( self, @@ -836,21 +832,23 @@ def read_zarr( *, arrays: ZarrPathSelection | None = None, attrs: ZarrPathSelection | None = None, + row_ends: str | None = None, file_path_column: str | None = "file_path", missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, ) -> RefinerPipeline: """Create a pipeline with a Zarr reader source. - The reader emits one row for the Zarr group. Select arrays with `arrays` - and pass `episode_ends_key` to `to_robot_rows(...)` for Diffusion - Policy-style dataset arrays. + The reader emits one row for the Zarr group. If `row_ends` is provided, + it reads that Zarr array as cumulative end offsets and emits one row per + `[start:end]` slice. """ return RefinerPipeline( source=ZarrReader( input, arrays=arrays, attrs=attrs, + row_ends=row_ends, file_path_column=file_path_column, missing_policy=missing_policy, dtypes=dtypes, diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 3b670210..d0127bfc 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -51,7 +51,7 @@ def _decode_value(value: Any) -> Any: class ZarrReader(BaseSource): - """Read one Zarr group as one row of selected arrays and attributes.""" + """Read one Zarr group as one row, or split arrays by cumulative row ends.""" name = "read_zarr" @@ -61,6 +61,7 @@ def __init__( *, arrays: PathSelection | None = None, attrs: PathSelection | None = None, + row_ends: str | None = None, file_path_column: str | None = "file_path", missing_policy: MissingPolicy = "error", dtypes: DTypeMapping | None = None, @@ -68,6 +69,7 @@ def __init__( self.root = DataFolder.resolve(input) self.arrays = _selection_map(arrays) self.attrs = _selection_map(attrs) + self.row_ends = row_ends self.file_path_column = file_path_column self.missing_policy = missing_policy self.dtypes = dtypes @@ -85,6 +87,7 @@ def describe(self) -> dict[str, Any]: "path": self.root.abs_path(), "arrays": dict(self.arrays), "attrs": dict(self.attrs), + "row_ends": self.row_ends, "file_path_column": self.file_path_column, "missing_policy": self.missing_policy, "dtypes": ( @@ -112,16 +115,50 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: import zarr group = zarr.open_group(self.root.abs_path(), mode="r") - arrays = self.arrays or {path: path for path in _iter_array_paths(group)} + arrays = self.arrays or { + path: path for path in _iter_array_paths(group) if path != self.row_ends + } + if self.row_ends is not None: + try: + row_ends = [int(value) for value in group[self.row_ends][:]] + except KeyError: + if self.missing_policy == "drop_row": + return + raise KeyError( + f"Zarr row_ends array not found: {self.row_ends}" + ) from None + start = 0 + for end in row_ends: + row = self._read_row(group, arrays, start=start, end=end) + if row is None: + return + yield DictRow(row) + start = end + return + + row = self._read_row(group, arrays) + if row is not None: + yield DictRow(row) + + def _read_row( + self, + group: Any, + arrays: Mapping[str, str], + *, + start: int | None = None, + end: int | None = None, + ) -> dict[str, Any] | None: row: dict[str, Any] = {} if self.file_path_column is not None: row[self.file_path_column] = self.root.abs_path() for output_name, path in arrays.items(): try: - row[output_name] = group[path][:] + row[output_name] = ( + group[path][start:end] if start is not None else group[path][:] + ) except KeyError: if self.missing_policy == "drop_row": - return + return None if self.missing_policy == "set_null": row[output_name] = None continue @@ -129,13 +166,13 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: for output_name, attr_name in self.attrs.items(): if attr_name not in group.attrs: if self.missing_policy == "drop_row": - return + return None if self.missing_policy == "set_null": row[output_name] = None continue raise KeyError(f"Zarr attr not found: {attr_name}") row[output_name] = _decode_value(group.attrs[attr_name]) - yield DictRow(row) + return row def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: diff --git a/src/refiner/robotics/row.py b/src/refiner/robotics/row.py index a3df7ba4..8bf0a971 100644 --- a/src/refiner/robotics/row.py +++ b/src/refiner/robotics/row.py @@ -551,36 +551,8 @@ def _robot_row_converter( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, schema: pa.Schema | None = None, -) -> Callable[[Row], Row] | Callable[[Row], Iterable[Row]]: - if episode_ends_key is not None: - - def split_row(row: Row) -> Iterable[Row]: - return cast( - Iterable[Row], - _rows_from_episode_ends( - row, - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, - ), - ) - - return split_row - +) -> Callable[[Row], Row]: spec = _RoboticsRowSpec.from_options( episode_id_key=episode_id_key, task_key=task_key, @@ -616,90 +588,6 @@ def _valid_nested_frames_key(row: Row, key: str | None) -> str | None: return key if isinstance(value, Sequence) else None -def _rows_from_episode_ends( - row: Row, - *, - episode_id_key: str | None, - task_key: str | None, - fps: float | None, - fps_key: str | None, - robot_type: str | None, - robot_type_key: str | None, - timestamp_key: str | None, - action_key: str | None, - state_key: str | Sequence[str] | None, - extra_observation_keys: Mapping[str, str] | Iterable[str] | None, - video_keys: Mapping[str, str] | Iterable[str] | None, - schema: pa.Schema | None, - stats_key: str | None, - stats_prefix: str, - episode_ends_key: str, -) -> Iterator[RoboticsRow]: - episode_ends = [int(value) for value in _get_path(row, episode_ends_key)] - spec = _RoboticsRowSpec.from_options( - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - nested_frames_key=None, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - ) - start = 0 - for episode_idx, end in enumerate(episode_ends): - if episode_id_key is None: - split_row = row - else: - split_row = row.update( - {episode_id_key: f"{row.get(episode_id_key, '-1')}-{episode_idx}"} - ) - for source_key in spec.frame_source_map.values(): - for key in _source_keys(source_key): - if _has_path(row, key): - split_row = split_row.update( - _set_path( - split_row, - key, - _slice_values(_get_path(row, key), start, end), - ) - ) - video_fps = int(spec.wrap(row).fps or 30) - for source_key, storage in spec.video_source_map.values(): - if not _has_path(row, source_key): - continue - value = _get_path(row, source_key) - video = video_from_storage_value(storage, value, fps=video_fps) - if video is None: - continue - clip_fps = int(getattr(video, "fps", video_fps)) - split_row = split_row.update( - _set_path( - split_row, - source_key, - video.clipped( - from_timestamp_s=start / clip_fps, - to_timestamp_s=end / clip_fps, - ), - ) - ) - if task_key is not None and task_key in row: - split_row = split_row.update({task_key: row[task_key]}) - if fps_key is not None and fps_key in row: - split_row = split_row.update({fps_key: row[fps_key]}) - if robot_type_key is not None and robot_type_key in row: - split_row = split_row.update({robot_type_key: row[robot_type_key]}) - yield spec.wrap(split_row) - start = end - - def _video_sources( *, schema: pa.Schema | None, @@ -853,12 +741,6 @@ def _set_path(row: Mapping[str, Any], key: str, values: Any) -> dict[str, Any]: return {head: root} -def _slice_values(values: Any, start: int, end: int) -> Any: - if isinstance(values, pa.ChunkedArray | pa.Array): - return values.slice(start, end - start) - return values[start:end] - - def _select_frame_table(table: Tabular, indices: Sequence[int]) -> Tabular: selected = table.table.take(pa.array(indices, type=pa.int64())) if "frame_index" in selected.column_names: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index a74372f3..4adde7a1 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -49,6 +49,32 @@ def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"][:2], [[0.0], [0.1]]) +def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + rows = mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + assert [row["task"] for row in rows] == ["push tee", "push tee"] + assert [len(row["action"]) for row in rows] == [2, 3] + np.testing.assert_allclose(rows[0]["action"], [[0.0], [0.1]]) + np.testing.assert_allclose(rows[1]["action"], [[1.0], [1.1], [1.2]]) + np.testing.assert_array_equal( + rows[0]["frames"], + np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), + ) + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" @@ -62,15 +88,14 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: "action": "data/action", "observation.state": "data/state", "frames": "data/rgb", - "episode_ends": "meta/episode_ends", }, attrs={"dataset_id": "dataset_id", "task": "task"}, + row_ends="meta/episode_ends", file_path_column=None, ) .to_robot_rows( episode_id_key="dataset_id", task_key="task", - episode_ends_key="episode_ends", action_key="action", state_key="observation.state", video_keys={"observation.images.front": "frames"}, diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index cbb1717c..b4e026e8 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -34,10 +34,7 @@ def _robot_rows( converter = _robot_row_converter(**cast(Any, {**kwargs, "schema": rows.schema})) return [cast(RoboticsRow, converter(row)) for row in rows] converter = _robot_row_converter(**kwargs) - converted = converter(rows) - if kwargs.get("episode_ends_key") is not None: - return [cast(RoboticsRow, row) for row in converted] - return [cast(RoboticsRow, converted)] + return [cast(RoboticsRow, converter(rows))] def test_to_robot_rows_does_not_treat_video_uri_frames_as_frame_table() -> None: @@ -249,31 +246,6 @@ def test_to_robot_rows_preserves_explicit_fps_and_robot_type_keys() -> None: assert robotics_row.robot_type == "aloha" -def test_to_robot_rows_reuses_literal_fps_and_robot_type_for_episode_splits() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [1, 2], - "action": [[0.0], [1.0]], - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - state_key=None, - fps=30.0, - robot_type="pusht", - ) - ) - - assert [row.fps for row in rows] == [30.0, 30.0] - assert [row.robot_type for row in rows] == ["pusht", "pusht"] - - def test_pipeline_to_robot_rows_forwards_literals_and_explicit_keys() -> None: literal_row = ( from_items([{"episode_id": "episode-1"}]) @@ -654,118 +626,6 @@ def test_to_robot_rows_flattens_nested_frame_dicts() -> None: assert robotics_row.states.to_pylist() == [[0.0], [1.0]] -def test_to_robot_rows_splits_dataset_arrays_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "meta": {"episode_ends": [2, 5]}, - "data": { - "obs": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "action": [[0.0], [0.1], [1.0], [1.1], [1.2]], - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - state_key="data/obs", - action_key="data/action", - ) - ) - - assert [row.episode_id for row in rows] == ["pusht-0", "pusht-1"] - assert [row.num_frames for row in rows] == [2, 3] - assert rows[1].actions == [[1.0], [1.1], [1.2]] - - -def test_to_robot_rows_splits_video_frame_arrays_with_episode_ends() -> None: - frames = np.arange(5 * 2 * 2 * 3, dtype=np.uint8).reshape(5, 2, 2, 3) - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [2, 5], - "actions": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "front": frames, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - action_key="actions", - state_key=None, - video_keys={"observation.images.front": "front"}, - fps=10, - ) - ) - - videos = [row.videos["observation.images.front"] for row in rows] - - assert [ - video.frame_count for video in videos if isinstance(video, VideoFrameArray) - ] == [ - 2, - 3, - ] - assert isinstance(videos[0], VideoFrameArray) - assert isinstance(videos[1], VideoFrameArray) - np.testing.assert_array_equal( - np.asarray(list(videos[0].iter_frame_arrays())), - frames[:2], - ) - np.testing.assert_array_equal( - np.asarray(list(videos[1].iter_frame_arrays())), - frames[2:], - ) - - -def test_to_robot_rows_splits_tuple_state_key_sources_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "robomimic", - "meta": {"episode_ends": [2, 4]}, - "actions": [[0.0], [0.1], [1.0], [1.1]], - "obs": { - "joint_pos": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], - "joint_vel": [[0.1], [0.2], [0.3], [0.4]], - "eef_pos": np.asarray( - [[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]], - dtype=np.float32, - ), - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - action_key="actions", - state_key=("obs/joint_pos", "obs/joint_vel", "obs/eef_pos"), - ) - ) - - assert [row.episode_id for row in rows] == ["robomimic-0", "robomimic-1"] - assert [row.num_frames for row in rows] == [2, 2] - assert rows[0].states == [ - [1.0, 2.0, 0.1, 10.0, 11.0], - [3.0, 4.0, 0.2, 12.0, 13.0], - ] - assert rows[1].states == [ - [5.0, 6.0, 0.3, 14.0, 15.0], - [7.0, 8.0, 0.4, 16.0, 17.0], - ] - - def test_motion_trim_works_with_mapped_semantic_keys() -> None: row = DictRow( { From 9e0971707f6b557cfdeb7427b187113c34124232 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:29:18 +0200 Subject: [PATCH 03/39] Keep Zarr PR reader-only --- src/refiner/pipeline/pipeline.py | 19 +--- src/refiner/pipeline/sinks/__init__.py | 2 - src/refiner/pipeline/sinks/zarr.py | 123 ------------------------- tests/readers/test_zarr_reader.py | 33 ------- 4 files changed, 1 insertion(+), 176 deletions(-) delete mode 100644 src/refiner/pipeline/sinks/zarr.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 5ba7952a..837cc2be 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -432,23 +432,6 @@ def write_parquet( ) ) - def write_zarr( - self, - output: DataFolderLike, - *, - arrays: Mapping[str, str] | None = None, - episode_ends_path: str | None = "meta/episode_ends", - overwrite: bool = True, - ) -> "RefinerPipeline": - return self.with_sink( - ZarrSink( - output=output, - arrays=arrays, - episode_ends_path=episode_ends_path, - overwrite=overwrite, - ) - ) - def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index f0623f21..8c1f26df 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,7 +2,6 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink -from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -11,5 +10,4 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", - "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/zarr.py b/src/refiner/pipeline/sinks/zarr.py deleted file mode 100644 index e3beb498..00000000 --- a/src/refiner/pipeline/sinks/zarr.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -from typing import Any, cast - -import numpy as np -import pyarrow as pa - -from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.pipeline.data.block import Block -from refiner.pipeline.data.row import Row -from refiner.pipeline.data.tabular import Tabular -from refiner.pipeline.sinks.base import BaseSink -from refiner.robotics.row import RoboticsRow -from refiner.utils import check_required_dependencies - - -class ZarrSink(BaseSink): - def __init__( - self, - output: DataFolderLike, - *, - arrays: Mapping[str, str] | None = None, - episode_ends_path: str | None = "meta/episode_ends", - overwrite: bool = True, - ): - self.output = DataFolder.resolve(output) - self.arrays = dict(arrays) if arrays is not None else None - self.episode_ends_path = episode_ends_path - self.overwrite = overwrite - self._chunks: dict[str, list[np.ndarray]] = {} - self._episode_ends: list[int] = [] - - def write_shard_block(self, shard_id: str, block: Block) -> int: - del shard_id - rows = list(block) if isinstance(block, Tabular) else block - for row in rows: - self._write_row(row) - return len(rows) - - def _write_row(self, row: Row) -> None: - arrays = self.arrays or _default_robotics_arrays(row) - lengths: list[int] = [] - for zarr_path, source_key in arrays.items(): - value = _row_value(row, source_key) - if value is None: - continue - array = _as_array(value) - if array.ndim == 0: - array = array.reshape(1) - lengths.append(int(array.shape[0])) - self._chunks.setdefault(zarr_path, []).append(array) - if lengths and self.episode_ends_path is not None: - length = lengths[0] - if any(item != length for item in lengths): - raise ValueError("Zarr arrays for one row must have matching lengths") - end = (self._episode_ends[-1] if self._episode_ends else 0) + length - self._episode_ends.append(end) - - def close(self) -> None: - if not self._chunks and not self._episode_ends: - return - check_required_dependencies("write_zarr", ["zarr"], dist="zarr") - import zarr - - mode = "w" if self.overwrite else "w-" - root = zarr.open_group(self.output.abs_path(), mode=mode) - for path, chunks in self._chunks.items(): - root.create_dataset(path, data=np.concatenate(chunks, axis=0)) - if self.episode_ends_path is not None: - root.create_dataset( - self.episode_ends_path, - data=np.asarray(self._episode_ends, dtype=np.int64), - ) - - def describe(self) -> tuple[str, str, dict[str, object]]: - return ( - "write_zarr", - "writer", - { - "path": self.output.abs_path(), - "arrays": dict(self.arrays) if self.arrays is not None else None, - "episode_ends_path": self.episode_ends_path, - "overwrite": self.overwrite, - }, - ) - - -def _default_robotics_arrays(row: Row) -> dict[str, str]: - if not isinstance(row, RoboticsRow): - raise ValueError("write_zarr requires arrays=... for non-RoboticsRow inputs") - arrays: dict[str, str] = {} - if row.actions is not None: - arrays["data/action"] = "action" - if row.states is not None: - arrays["data/observation.state"] = "observation.state" - if row.timestamps is not None: - arrays["data/timestamp"] = "timestamp" - return arrays - - -def _row_value(row: Row, key: str) -> Any: - if isinstance(row, RoboticsRow): - if key == "action": - return row.actions - if key == "observation.state": - return row.states - if key == "timestamp": - return row.timestamps - if key.startswith("observation."): - return row.observations(key) - return row[key] - - -def _as_array(value: Any) -> np.ndarray: - if isinstance(value, pa.ChunkedArray | pa.Array): - return np.asarray(value.to_pylist()) - if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): - return np.asarray(list(cast(Iterable[Any], value))) - return np.asarray(value) - - -__all__ = ["ZarrSink"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 4adde7a1..f344434c 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -78,7 +78,6 @@ def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" - zarr_out = tmp_path / "roundtrip.zarr" _write_policy_zarr(path) ( @@ -114,35 +113,3 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: ] assert [episode.num_frames for episode in episodes] == [2, 3] assert [episode.task for episode in episodes] == ["push tee", "push tee"] - - ( - mdr.read_lerobot(str(lerobot_out)) - .write_zarr( - str(zarr_out), - arrays={ - "data/action": "action", - "data/state": "observation.state", - }, - ) - .launch_local( - name="lerobot-to-zarr", num_workers=1, rundir=str(tmp_path / "run2") - ) - ) - - row = mdr.read_zarr( - zarr_out, - arrays={ - "action": "data/action", - "state": "data/state", - "episode_ends": "meta/episode_ends", - }, - file_path_column=None, - ).take(1)[0] - - assert row["episode_ends"].tolist() == [2, 5] - np.testing.assert_allclose( - row["action"], np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]]) - ) - np.testing.assert_allclose( - row["state"], np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]]) - ) From b1d6a2c56bf9b026288ebb9acd1a705b5b55f533 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:36:15 +0200 Subject: [PATCH 04/39] Tighten Zarr reader selections --- src/refiner/pipeline/sources/readers/zarr.py | 35 +++++++++++++------- tests/readers/test_zarr_reader.py | 29 ++++++++++++++++ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index d0127bfc..cbff465f 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -20,9 +20,7 @@ PathSelection = Mapping[str, str] | Sequence[str] | str -def _selection_map(value: PathSelection | None) -> dict[str, str]: - if value is None: - return {} +def _selection_map(value: PathSelection) -> dict[str, str]: if isinstance(value, str): return {value.rsplit("/", 1)[-1]: value} if isinstance(value, Mapping): @@ -67,8 +65,10 @@ def __init__( dtypes: DTypeMapping | None = None, ): self.root = DataFolder.resolve(input) - self.arrays = _selection_map(arrays) - self.attrs = _selection_map(attrs) + check_required_dependencies("read_zarr", ["zarr"], dist="zarr") + self.arrays = None if arrays is None else _selection_map(arrays) + self.attrs = None if attrs is None else _selection_map(attrs) + _validate_output_names(self.arrays or {}, self.attrs or {}) self.row_ends = row_ends self.file_path_column = file_path_column self.missing_policy = missing_policy @@ -85,8 +85,8 @@ def schema(self) -> pa.Schema | None: def describe(self) -> dict[str, Any]: return { "path": self.root.abs_path(), - "arrays": dict(self.arrays), - "attrs": dict(self.attrs), + "arrays": dict(self.arrays) if self.arrays is not None else None, + "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, "file_path_column": self.file_path_column, "missing_policy": self.missing_policy, @@ -111,13 +111,14 @@ def list_shards(self) -> list[Shard]: def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: del shard - check_required_dependencies("read_zarr", ["zarr"], dist="zarr") import zarr group = zarr.open_group(self.root.abs_path(), mode="r") - arrays = self.arrays or { - path: path for path in _iter_array_paths(group) if path != self.row_ends - } + arrays = ( + {path: path for path in _iter_array_paths(group) if path != self.row_ends} + if self.arrays is None + else self.arrays + ) if self.row_ends is not None: try: row_ends = [int(value) for value in group[self.row_ends][:]] @@ -163,7 +164,7 @@ def _read_row( row[output_name] = None continue raise KeyError(f"Zarr array not found: {path}") from None - for output_name, attr_name in self.attrs.items(): + for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: if self.missing_policy == "drop_row": return None @@ -175,6 +176,16 @@ def _read_row( return row +def _validate_output_names( + arrays: Mapping[str, str], + attrs: Mapping[str, str], +) -> None: + duplicates = set(arrays).intersection(attrs) + if duplicates: + names = ", ".join(sorted(repr(name) for name in duplicates)) + raise ValueError(f"Zarr arrays and attrs use duplicate output names: {names}") + + def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index f344434c..21614449 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -4,6 +4,7 @@ from typing import cast import numpy as np +import pytest import zarr import refiner as mdr @@ -75,6 +76,34 @@ def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: ) +def test_read_zarr_allows_attrs_only_reads(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + row = mdr.read_zarr( + path, + arrays={}, + attrs={"task": "task"}, + file_path_column=None, + ).take(1)[0] + + assert list(row) == ["task"] + assert row["task"] == "push tee" + + +def test_read_zarr_rejects_duplicate_output_names(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="duplicate output names"): + mdr.read_zarr( + path, + arrays={"task": "data/action"}, + attrs={"task": "task"}, + file_path_column=None, + ) + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" From 28bd3ba7adec49c8cd81e2e6936ca37445b88472 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 13:50:08 +0200 Subject: [PATCH 05/39] Shard Zarr row-end reads --- pyproject.toml | 2 +- src/refiner/io/zarr.py | 25 ++++ src/refiner/pipeline/pipeline.py | 2 + src/refiner/pipeline/sources/readers/zarr.py | 51 ++++++-- tests/readers/test_zarr_reader.py | 29 ++++- uv.lock | 117 ++----------------- 6 files changed, 108 insertions(+), 118 deletions(-) create mode 100644 src/refiner/io/zarr.py diff --git a/pyproject.toml b/pyproject.toml index dc3ba3e7..c7764dbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ hdf5 = [ "h5py", ] zarr = [ - "zarr>=2.18", + "zarr>=2.18,<3", ] s3 = [ "s3fs", diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py new file mode 100644 index 00000000..bff3958d --- /dev/null +++ b/src/refiner/io/zarr.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Literal + +from refiner.io.datafolder import DataFolder + + +def zarr_store( + folder: DataFolder, + path: str = "", + *, + mode: Literal["r", "w", "w-", "a"] = "r", +): + import zarr + + create = mode in {"w", "w-", "a"} + return zarr.storage.FSStore( + folder._join(path), + fs=folder.fs, + mode=mode, + create=create, + ) + + +__all__ = ["zarr_store"] diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 837cc2be..8fbfd55a 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -816,6 +816,7 @@ def read_zarr( arrays: ZarrPathSelection | None = None, attrs: ZarrPathSelection | None = None, row_ends: str | None = None, + row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, @@ -832,6 +833,7 @@ def read_zarr( arrays=arrays, attrs=attrs, row_ends=row_ends, + row_index_column=row_index_column, file_path_column=file_path_column, missing_policy=missing_policy, dtypes=dtypes, diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index cbff465f..7ac992af 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -6,13 +6,14 @@ import pyarrow as pa from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.io.zarr import zarr_store from refiner.pipeline.data.datatype import ( DTypeMapping, dtype_to_plan, schema_with_dtypes, ) from refiner.pipeline.data.row import DictRow -from refiner.pipeline.data.shard import Shard +from refiner.pipeline.data.shard import RowRangeDescriptor, Shard from refiner.pipeline.sources.base import BaseSource, SourceUnit from refiner.utils import check_required_dependencies @@ -60,6 +61,7 @@ def __init__( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, + row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: MissingPolicy = "error", dtypes: DTypeMapping | None = None, @@ -70,6 +72,7 @@ def __init__( self.attrs = None if attrs is None else _selection_map(attrs) _validate_output_names(self.arrays or {}, self.attrs or {}) self.row_ends = row_ends + self.row_index_column = row_index_column self.file_path_column = file_path_column self.missing_policy = missing_policy self.dtypes = dtypes @@ -88,6 +91,7 @@ def describe(self) -> dict[str, Any]: "arrays": dict(self.arrays) if self.arrays is not None else None, "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, + "row_index_column": self.row_index_column, "file_path_column": self.file_path_column, "missing_policy": self.missing_policy, "dtypes": ( @@ -99,6 +103,21 @@ def describe(self) -> dict[str, Any]: def list_shards(self) -> list[Shard]: path = self.root.abs_path() + if self.row_ends is not None: + import zarr + + group = zarr.open_group(store=zarr_store(self.root), mode="r") + row_count = len(group[self.row_ends]) + return [ + Shard.from_row_range( + start=index, + end=index + 1, + global_ordinal=index, + start_key=path, + end_key=path, + ) + for index in range(row_count) + ] return [ Shard.from_row_range( start=0, @@ -110,27 +129,42 @@ def list_shards(self) -> list[Shard]: ] def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: - del shard import zarr - group = zarr.open_group(self.root.abs_path(), mode="r") + group = zarr.open_group(store=zarr_store(self.root), mode="r") arrays = ( {path: path for path in _iter_array_paths(group) if path != self.row_ends} if self.arrays is None else self.arrays ) if self.row_ends is not None: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) try: - row_ends = [int(value) for value in group[self.row_ends][:]] + ends_array = group[self.row_ends] + row_ends = [ + int(value) + for value in ends_array[descriptor.start : descriptor.end] + ] + start = ( + 0 + if descriptor.start == 0 + else int(ends_array[descriptor.start - 1]) + ) except KeyError: if self.missing_policy == "drop_row": return raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None - start = 0 - for end in row_ends: - row = self._read_row(group, arrays, start=start, end=end) + for offset, end in enumerate(row_ends): + row = self._read_row( + group, + arrays, + start=start, + end=end, + row_index=descriptor.start + offset, + ) if row is None: return yield DictRow(row) @@ -148,10 +182,13 @@ def _read_row( *, start: int | None = None, end: int | None = None, + row_index: int | None = None, ) -> dict[str, Any] | None: row: dict[str, Any] = {} if self.file_path_column is not None: row[self.file_path_column] = self.root.abs_path() + if self.row_index_column is not None and row_index is not None: + row[self.row_index_column] = row_index for output_name, path in arrays.items(): try: row[output_name] = ( diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 21614449..c9f08a25 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -9,6 +9,8 @@ import refiner as mdr from refiner.robotics.row import RoboticsRow +from refiner.pipeline.data.row import Row +from refiner.pipeline.data.shard import RowRangeDescriptor def _write_policy_zarr(path: Path) -> None: @@ -76,6 +78,30 @@ def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: ) +def test_read_zarr_plans_one_shard_per_row_end(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 1), + (1, 2), + ] + + rows = [cast(Row, row) for row in pipeline.source.read_shard(shards[1])] + assert len(rows) == 1 + assert rows[0]["row_index"] == 1 + np.testing.assert_allclose(rows[0]["action"], [[1.0], [1.1], [1.2]]) + + def test_read_zarr_allows_attrs_only_reads(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -122,7 +148,7 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: file_path_column=None, ) .to_robot_rows( - episode_id_key="dataset_id", + episode_id_key="row_index", task_key="task", action_key="action", state_key="observation.state", @@ -140,5 +166,6 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: cast(RoboticsRow, row) for row in mdr.read_lerobot(str(lerobot_out)).materialize() ] + episodes.sort(key=lambda episode: int(episode.episode_id)) assert [episode.num_frames for episode in episodes] == [2, 3] assert [episode.task for episode in episodes] == ["push tee", "push tee"] diff --git a/uv.lock b/uv.lock index 731b1849..82a0836e 100644 --- a/uv.lock +++ b/uv.lock @@ -618,18 +618,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] -[[package]] -name = "donfig" -version = "0.8.1.post1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, -] - [[package]] name = "exceptiongroup" version = "1.3.1" @@ -795,41 +783,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, - { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, - { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, - { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, - { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, - { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, - { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, - { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, - { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, - { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, - { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, - { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, - { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, - { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, - { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, - { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, -] - [[package]] name = "h11" version = "0.16.0" @@ -1066,9 +1019,7 @@ all = [ { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] hdf5 = [ { name = "h5py" }, @@ -1094,9 +1045,7 @@ testing = [ { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] text = [ { name = "warcio" }, @@ -1105,9 +1054,7 @@ video = [ { name = "av" }, ] zarr = [ - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] [package.dev-dependencies] @@ -1145,7 +1092,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, - { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, ] provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] @@ -2842,63 +2789,15 @@ wheels = [ name = "zarr" version = "2.18.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] dependencies = [ - { name = "asciitree", marker = "python_full_version < '3.11'" }, - { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, + { name = "asciitree" }, + { name = "fasteners", marker = "sys_platform != 'emscripten'" }, { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, ] - -[[package]] -name = "zarr" -version = "3.1.6" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version == '3.11.*'" }, - { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "packaging", marker = "python_full_version == '3.11.*'" }, - { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, -] - -[[package]] -name = "zarr" -version = "3.2.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version >= '3.12'" }, - { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "packaging", marker = "python_full_version >= '3.12'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, -] From 841aaed7e64bdd6d2ef0d6d8175f0d2ea141de87 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:02:31 +0200 Subject: [PATCH 06/39] Restore generic episode-end robot splits --- src/refiner/pipeline/pipeline.py | 8 +- src/refiner/pipeline/sources/readers/zarr.py | 11 +- src/refiner/robotics/row.py | 120 ++++++++++++++++++- tests/readers/test_zarr_reader.py | 1 + tests/robotics/test_robotics_row.py | 117 +++++++++++++++++- 5 files changed, 251 insertions(+), 6 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 8fbfd55a..e40844c3 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -181,6 +181,7 @@ def to_robot_rows( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", + episode_ends_key: str | None = None, ) -> "RefinerPipeline": """Expose rows through the RoboticsRow semantic view. @@ -211,8 +212,11 @@ def to_robot_rows( schema=self.output_schema(), stats_key=stats_key, stats_prefix=stats_prefix, + episode_ends_key=episode_ends_key, ) - return self.map(cast(MapFn, converter)) + if episode_ends_key is None: + return self.map(cast(MapFn, converter)) + return self.flat_map(cast(FlatMapFn, converter)) def map_async( self, @@ -816,6 +820,7 @@ def read_zarr( arrays: ZarrPathSelection | None = None, attrs: ZarrPathSelection | None = None, row_ends: str | None = None, + rows_per_shard: int = 128, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: ZarrMissingPolicy = "error", @@ -833,6 +838,7 @@ def read_zarr( arrays=arrays, attrs=attrs, row_ends=row_ends, + rows_per_shard=rows_per_shard, row_index_column=row_index_column, file_path_column=file_path_column, missing_policy=missing_policy, diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 7ac992af..ca8e3b3b 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -61,6 +61,7 @@ def __init__( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, + rows_per_shard: int = 128, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: MissingPolicy = "error", @@ -72,6 +73,7 @@ def __init__( self.attrs = None if attrs is None else _selection_map(attrs) _validate_output_names(self.arrays or {}, self.attrs or {}) self.row_ends = row_ends + self.rows_per_shard = rows_per_shard self.row_index_column = row_index_column self.file_path_column = file_path_column self.missing_policy = missing_policy @@ -91,6 +93,7 @@ def describe(self) -> dict[str, Any]: "arrays": dict(self.arrays) if self.arrays is not None else None, "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, + "rows_per_shard": self.rows_per_shard, "row_index_column": self.row_index_column, "file_path_column": self.file_path_column, "missing_policy": self.missing_policy, @@ -104,19 +107,21 @@ def describe(self) -> dict[str, Any]: def list_shards(self) -> list[Shard]: path = self.root.abs_path() if self.row_ends is not None: + if self.rows_per_shard <= 0: + raise ValueError("rows_per_shard must be greater than zero") import zarr group = zarr.open_group(store=zarr_store(self.root), mode="r") row_count = len(group[self.row_ends]) return [ Shard.from_row_range( - start=index, - end=index + 1, + start=start, + end=min(start + self.rows_per_shard, row_count), global_ordinal=index, start_key=path, end_key=path, ) - for index in range(row_count) + for index, start in enumerate(range(0, row_count, self.rows_per_shard)) ] return [ Shard.from_row_range( diff --git a/src/refiner/robotics/row.py b/src/refiner/robotics/row.py index 8bf0a971..a3df7ba4 100644 --- a/src/refiner/robotics/row.py +++ b/src/refiner/robotics/row.py @@ -551,8 +551,36 @@ def _robot_row_converter( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", + episode_ends_key: str | None = None, schema: pa.Schema | None = None, -) -> Callable[[Row], Row]: +) -> Callable[[Row], Row] | Callable[[Row], Iterable[Row]]: + if episode_ends_key is not None: + + def split_row(row: Row) -> Iterable[Row]: + return cast( + Iterable[Row], + _rows_from_episode_ends( + row, + episode_id_key=episode_id_key, + task_key=task_key, + fps=fps, + fps_key=fps_key, + robot_type=robot_type, + robot_type_key=robot_type_key, + timestamp_key=timestamp_key, + action_key=action_key, + state_key=state_key, + extra_observation_keys=extra_observation_keys, + video_keys=video_keys, + schema=schema, + stats_key=stats_key, + stats_prefix=stats_prefix, + episode_ends_key=episode_ends_key, + ), + ) + + return split_row + spec = _RoboticsRowSpec.from_options( episode_id_key=episode_id_key, task_key=task_key, @@ -588,6 +616,90 @@ def _valid_nested_frames_key(row: Row, key: str | None) -> str | None: return key if isinstance(value, Sequence) else None +def _rows_from_episode_ends( + row: Row, + *, + episode_id_key: str | None, + task_key: str | None, + fps: float | None, + fps_key: str | None, + robot_type: str | None, + robot_type_key: str | None, + timestamp_key: str | None, + action_key: str | None, + state_key: str | Sequence[str] | None, + extra_observation_keys: Mapping[str, str] | Iterable[str] | None, + video_keys: Mapping[str, str] | Iterable[str] | None, + schema: pa.Schema | None, + stats_key: str | None, + stats_prefix: str, + episode_ends_key: str, +) -> Iterator[RoboticsRow]: + episode_ends = [int(value) for value in _get_path(row, episode_ends_key)] + spec = _RoboticsRowSpec.from_options( + episode_id_key=episode_id_key, + task_key=task_key, + fps=fps, + fps_key=fps_key, + robot_type=robot_type, + robot_type_key=robot_type_key, + nested_frames_key=None, + timestamp_key=timestamp_key, + action_key=action_key, + state_key=state_key, + extra_observation_keys=extra_observation_keys, + video_keys=video_keys, + schema=schema, + stats_key=stats_key, + stats_prefix=stats_prefix, + ) + start = 0 + for episode_idx, end in enumerate(episode_ends): + if episode_id_key is None: + split_row = row + else: + split_row = row.update( + {episode_id_key: f"{row.get(episode_id_key, '-1')}-{episode_idx}"} + ) + for source_key in spec.frame_source_map.values(): + for key in _source_keys(source_key): + if _has_path(row, key): + split_row = split_row.update( + _set_path( + split_row, + key, + _slice_values(_get_path(row, key), start, end), + ) + ) + video_fps = int(spec.wrap(row).fps or 30) + for source_key, storage in spec.video_source_map.values(): + if not _has_path(row, source_key): + continue + value = _get_path(row, source_key) + video = video_from_storage_value(storage, value, fps=video_fps) + if video is None: + continue + clip_fps = int(getattr(video, "fps", video_fps)) + split_row = split_row.update( + _set_path( + split_row, + source_key, + video.clipped( + from_timestamp_s=start / clip_fps, + to_timestamp_s=end / clip_fps, + ), + ) + ) + if task_key is not None and task_key in row: + split_row = split_row.update({task_key: row[task_key]}) + if fps_key is not None and fps_key in row: + split_row = split_row.update({fps_key: row[fps_key]}) + if robot_type_key is not None and robot_type_key in row: + split_row = split_row.update({robot_type_key: row[robot_type_key]}) + yield spec.wrap(split_row) + start = end + + def _video_sources( *, schema: pa.Schema | None, @@ -741,6 +853,12 @@ def _set_path(row: Mapping[str, Any], key: str, values: Any) -> dict[str, Any]: return {head: root} +def _slice_values(values: Any, start: int, end: int) -> Any: + if isinstance(values, pa.ChunkedArray | pa.Array): + return values.slice(start, end - start) + return values[start:end] + + def _select_frame_table(table: Tabular, indices: Sequence[int]) -> Tabular: selected = table.table.take(pa.array(indices, type=pa.int64())) if "frame_index" in selected.column_names: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index c9f08a25..2d69a2d9 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -86,6 +86,7 @@ def test_read_zarr_plans_one_shard_per_row_end(tmp_path: Path) -> None: path, arrays={"action": "data/action"}, row_ends="meta/episode_ends", + rows_per_shard=1, file_path_column=None, ) diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index b4e026e8..94cccaab 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -34,7 +34,10 @@ def _robot_rows( converter = _robot_row_converter(**cast(Any, {**kwargs, "schema": rows.schema})) return [cast(RoboticsRow, converter(row)) for row in rows] converter = _robot_row_converter(**kwargs) - return [cast(RoboticsRow, converter(rows))] + converted = converter(rows) + if kwargs.get("episode_ends_key") is not None: + return [cast(RoboticsRow, row) for row in converted] + return [cast(RoboticsRow, converted)] def test_to_robot_rows_does_not_treat_video_uri_frames_as_frame_table() -> None: @@ -626,6 +629,118 @@ def test_to_robot_rows_flattens_nested_frame_dicts() -> None: assert robotics_row.states.to_pylist() == [[0.0], [1.0]] +def test_to_robot_rows_splits_dataset_arrays_with_episode_ends() -> None: + row = DictRow( + { + "dataset_id": "pusht", + "meta": {"episode_ends": [2, 5]}, + "data": { + "obs": [[0.0], [0.1], [1.0], [1.1], [1.2]], + "action": [[0.0], [0.1], [1.0], [1.1], [1.2]], + }, + } + ) + + rows = list( + _robot_rows( + row, + episode_id_key="dataset_id", + episode_ends_key="meta/episode_ends", + timestamp_key=None, + state_key="data/obs", + action_key="data/action", + ) + ) + + assert [row.episode_id for row in rows] == ["pusht-0", "pusht-1"] + assert [row.num_frames for row in rows] == [2, 3] + assert rows[1].actions == [[1.0], [1.1], [1.2]] + + +def test_to_robot_rows_splits_video_frame_arrays_with_episode_ends() -> None: + frames = np.arange(5 * 2 * 2 * 3, dtype=np.uint8).reshape(5, 2, 2, 3) + row = DictRow( + { + "dataset_id": "pusht", + "episode_ends": [2, 5], + "actions": [[0.0], [0.1], [1.0], [1.1], [1.2]], + "front": frames, + } + ) + + rows = list( + _robot_rows( + row, + episode_id_key="dataset_id", + episode_ends_key="episode_ends", + timestamp_key=None, + action_key="actions", + state_key=None, + video_keys={"observation.images.front": "front"}, + fps=10, + ) + ) + + videos = [row.videos["observation.images.front"] for row in rows] + + assert [ + video.frame_count for video in videos if isinstance(video, VideoFrameArray) + ] == [ + 2, + 3, + ] + assert isinstance(videos[0], VideoFrameArray) + assert isinstance(videos[1], VideoFrameArray) + np.testing.assert_array_equal( + np.asarray(list(videos[0].iter_frame_arrays())), + frames[:2], + ) + np.testing.assert_array_equal( + np.asarray(list(videos[1].iter_frame_arrays())), + frames[2:], + ) + + +def test_to_robot_rows_splits_tuple_state_key_sources_with_episode_ends() -> None: + row = DictRow( + { + "dataset_id": "robomimic", + "meta": {"episode_ends": [2, 4]}, + "actions": [[0.0], [0.1], [1.0], [1.1]], + "obs": { + "joint_pos": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + "joint_vel": [[0.1], [0.2], [0.3], [0.4]], + "eef_pos": np.asarray( + [[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]], + dtype=np.float32, + ), + }, + } + ) + + rows = list( + _robot_rows( + row, + episode_id_key="dataset_id", + episode_ends_key="meta/episode_ends", + timestamp_key=None, + action_key="actions", + state_key=("obs/joint_pos", "obs/joint_vel", "obs/eef_pos"), + ) + ) + + assert [row.episode_id for row in rows] == ["robomimic-0", "robomimic-1"] + assert [row.num_frames for row in rows] == [2, 2] + assert rows[0].states == [ + [1.0, 2.0, 0.1, 10.0, 11.0], + [3.0, 4.0, 0.2, 12.0, 13.0], + ] + assert rows[1].states == [ + [5.0, 6.0, 0.3, 14.0, 15.0], + [7.0, 8.0, 0.4, 16.0, 17.0], + ] + + def test_motion_trim_works_with_mapped_semantic_keys() -> None: row = DictRow( { From 55b81bba5630bfaea7a5118c474bdff88f68475f Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:08:32 +0200 Subject: [PATCH 07/39] Validate Zarr reader metadata names --- src/refiner/pipeline/sources/readers/zarr.py | 34 +++++++++++++- tests/readers/test_zarr_reader.py | 48 ++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index ca8e3b3b..cffb7e9e 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -71,13 +71,17 @@ def __init__( check_required_dependencies("read_zarr", ["zarr"], dist="zarr") self.arrays = None if arrays is None else _selection_map(arrays) self.attrs = None if attrs is None else _selection_map(attrs) - _validate_output_names(self.arrays or {}, self.attrs or {}) self.row_ends = row_ends self.rows_per_shard = rows_per_shard self.row_index_column = row_index_column self.file_path_column = file_path_column self.missing_policy = missing_policy self.dtypes = dtypes + _validate_output_names( + self.arrays or {}, + self.attrs or {}, + reserved=self._reserved_output_names(row_index=row_ends is not None), + ) if missing_policy not in ("error", "drop_row", "set_null"): raise ValueError( "missing_policy must be one of 'error', 'drop_row', or 'set_null'" @@ -112,7 +116,14 @@ def list_shards(self) -> list[Shard]: import zarr group = zarr.open_group(store=zarr_store(self.root), mode="r") - row_count = len(group[self.row_ends]) + try: + row_count = len(group[self.row_ends]) + except KeyError: + if self.missing_policy == "drop_row": + return [] + raise KeyError( + f"Zarr row_ends array not found: {self.row_ends}" + ) from None return [ Shard.from_row_range( start=start, @@ -142,6 +153,11 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: if self.arrays is None else self.arrays ) + _validate_output_names( + arrays, + self.attrs or {}, + reserved=self._reserved_output_names(row_index=self.row_ends is not None), + ) if self.row_ends is not None: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) @@ -180,6 +196,14 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: if row is not None: yield DictRow(row) + def _reserved_output_names(self, *, row_index: bool) -> set[str]: + names = set() + if self.file_path_column is not None: + names.add(self.file_path_column) + if row_index and self.row_index_column is not None: + names.add(self.row_index_column) + return names + def _read_row( self, group: Any, @@ -221,11 +245,17 @@ def _read_row( def _validate_output_names( arrays: Mapping[str, str], attrs: Mapping[str, str], + *, + reserved: set[str] | None = None, ) -> None: duplicates = set(arrays).intersection(attrs) if duplicates: names = ", ".join(sorted(repr(name) for name in duplicates)) raise ValueError(f"Zarr arrays and attrs use duplicate output names: {names}") + reserved_matches = set(arrays).union(attrs).intersection(reserved or set()) + if reserved_matches: + names = ", ".join(sorted(repr(name) for name in reserved_matches)) + raise ValueError(f"Zarr selections use reserved output names: {names}") def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 2d69a2d9..b4de37cd 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -32,6 +32,14 @@ def _write_policy_zarr(path: Path) -> None: root.attrs["task"] = "push tee" +def test_read_zarr_rejects_reserved_file_path_output_name(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="reserved output names"): + mdr.read_zarr(path, arrays={"file_path": "data/action"}) + + def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -131,6 +139,46 @@ def test_read_zarr_rejects_duplicate_output_names(tmp_path: Path) -> None: ) +def test_read_zarr_rejects_discovered_array_attr_collisions(tmp_path: Path) -> None: + path = tmp_path / "collision.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset("task", data=np.asarray([1], dtype=np.int64)) + root.attrs["task"] = "push tee" + + pipeline = mdr.read_zarr(path, attrs={"task": "task"}, file_path_column=None) + + with pytest.raises(ValueError, match="duplicate output names"): + pipeline.take(1) + + +def test_read_zarr_rejects_reserved_row_index_output_name(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="reserved output names"): + mdr.read_zarr( + path, + arrays={"row_index": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + + +def test_read_zarr_drop_row_handles_missing_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/missing_episode_ends", + missing_policy="drop_row", + file_path_column=None, + ) + + assert pipeline.source.list_shards() == [] + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" From 6f58278f9679298c4cce854e6263d299790f2043 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:15:36 +0200 Subject: [PATCH 08/39] Reject duplicate Zarr metadata columns --- src/refiner/pipeline/sources/readers/zarr.py | 6 ++++++ tests/readers/test_zarr_reader.py | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index cffb7e9e..87883f82 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -86,6 +86,12 @@ def __init__( raise ValueError( "missing_policy must be one of 'error', 'drop_row', or 'set_null'" ) + if ( + row_ends is not None + and file_path_column is not None + and file_path_column == row_index_column + ): + raise ValueError("file_path_column and row_index_column must be distinct") @property def schema(self) -> pa.Schema | None: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index b4de37cd..5ccd65d1 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -164,6 +164,19 @@ def test_read_zarr_rejects_reserved_row_index_output_name(tmp_path: Path) -> Non ) +def test_read_zarr_rejects_duplicate_metadata_column_names(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="must be distinct"): + mdr.read_zarr( + path, + row_ends="meta/episode_ends", + file_path_column="metadata", + row_index_column="metadata", + ) + + def test_read_zarr_drop_row_handles_missing_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From 528182c97160c2c753db86ead75ad9be1bd14fd9 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:21:23 +0200 Subject: [PATCH 09/39] Batch Zarr row-end reads --- src/refiner/pipeline/sources/readers/zarr.py | 50 ++++++++++++++------ tests/robotics/test_robotics_row.py | 25 ++++++++++ 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 87883f82..5c6f2e99 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -184,21 +184,37 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None + shard_start = start + shard_end = row_ends[-1] if row_ends else start + shard_arrays = self._read_arrays( + group, + arrays, + start=shard_start, + end=shard_end, + ) + if shard_arrays is None: + return for offset, end in enumerate(row_ends): - row = self._read_row( - group, - arrays, - start=start, - end=end, - row_index=descriptor.start + offset, - ) + row = self._row_metadata(row_index=descriptor.start + offset) + relative_start = start - shard_start + relative_end = end - shard_start + for output_name, value in shard_arrays.items(): + row[output_name] = ( + None if value is None else value[relative_start:relative_end] + ) + row = self._read_attrs(group, row) if row is None: return yield DictRow(row) start = end return - row = self._read_row(group, arrays) + row = self._row_metadata(row_index=None) + row_arrays = self._read_arrays(group, arrays) + if row_arrays is None: + return + row.update(row_arrays) + row = self._read_attrs(group, row) if row is not None: yield DictRow(row) @@ -210,20 +226,23 @@ def _reserved_output_names(self, *, row_index: bool) -> set[str]: names.add(self.row_index_column) return names - def _read_row( + def _row_metadata(self, *, row_index: int | None) -> dict[str, Any]: + row: dict[str, Any] = {} + if self.file_path_column is not None: + row[self.file_path_column] = self.root.abs_path() + if self.row_index_column is not None and row_index is not None: + row[self.row_index_column] = row_index + return row + + def _read_arrays( self, group: Any, arrays: Mapping[str, str], *, start: int | None = None, end: int | None = None, - row_index: int | None = None, ) -> dict[str, Any] | None: row: dict[str, Any] = {} - if self.file_path_column is not None: - row[self.file_path_column] = self.root.abs_path() - if self.row_index_column is not None and row_index is not None: - row[self.row_index_column] = row_index for output_name, path in arrays.items(): try: row[output_name] = ( @@ -236,6 +255,9 @@ def _read_row( row[output_name] = None continue raise KeyError(f"Zarr array not found: {path}") from None + return row + + def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any] | None: for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: if self.missing_policy == "drop_row": diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index 94cccaab..cbb1717c 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -249,6 +249,31 @@ def test_to_robot_rows_preserves_explicit_fps_and_robot_type_keys() -> None: assert robotics_row.robot_type == "aloha" +def test_to_robot_rows_reuses_literal_fps_and_robot_type_for_episode_splits() -> None: + row = DictRow( + { + "dataset_id": "pusht", + "episode_ends": [1, 2], + "action": [[0.0], [1.0]], + } + ) + + rows = list( + _robot_rows( + row, + episode_id_key="dataset_id", + episode_ends_key="episode_ends", + timestamp_key=None, + state_key=None, + fps=30.0, + robot_type="pusht", + ) + ) + + assert [row.fps for row in rows] == [30.0, 30.0] + assert [row.robot_type for row in rows] == ["pusht", "pusht"] + + def test_pipeline_to_robot_rows_forwards_literals_and_explicit_keys() -> None: literal_row = ( from_items([{"episode_id": "episode-1"}]) From 09e8636eab3c56a36000122d653b8cf18f72a9a2 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:30:27 +0200 Subject: [PATCH 10/39] Share reader path selection types --- src/refiner/pipeline/pipeline.py | 12 ++--- src/refiner/pipeline/sources/readers/hdf5.py | 28 +++-------- .../pipeline/sources/readers/selection.py | 34 ++++++++++++++ src/refiner/pipeline/sources/readers/zarr.py | 47 +++++++++---------- 4 files changed, 68 insertions(+), 53 deletions(-) create mode 100644 src/refiner/pipeline/sources/readers/selection.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index e40844c3..3301f99b 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -43,11 +43,7 @@ ZarrReader, ) from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader -from refiner.pipeline.sources.readers.hdf5 import MissingPolicy, PathSelection -from refiner.pipeline.sources.readers.zarr import ( - MissingPolicy as ZarrMissingPolicy, - PathSelection as ZarrPathSelection, -) +from refiner.pipeline.sources.readers.selection import MissingPolicy, PathSelection from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -817,13 +813,13 @@ def read_hdf5( def read_zarr( input: DataFolderLike, *, - arrays: ZarrPathSelection | None = None, - attrs: ZarrPathSelection | None = None, + arrays: PathSelection | None = None, + attrs: PathSelection | None = None, row_ends: str | None = None, rows_per_shard: int = 128, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", - missing_policy: ZarrMissingPolicy = "error", + missing_policy: MissingPolicy = "error", dtypes: DTypeMapping | None = None, ) -> RefinerPipeline: """Create a pipeline with a Zarr reader source. diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index 0bbef12b..b50b8c03 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -3,7 +3,7 @@ from collections.abc import Iterator, Mapping, Sequence import fnmatch from glob import has_magic -from typing import Any, Literal, cast +from typing import Any from fsspec import AbstractFileSystem @@ -13,14 +13,15 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import FilePartsDescriptor from refiner.pipeline.sources.readers.base import BaseReader, Shard, SourceUnit +from refiner.pipeline.sources.readers.selection import ( + MissingPolicy, + PathSelection, + path_selection_map, +) from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES from refiner.utils import check_required_dependencies -MissingPolicy = Literal["error", "drop_row", "set_null"] -PathSelection = Mapping[str, str] | Sequence[str] | str - - def _decode_value( value: Any, *, @@ -140,22 +141,7 @@ def __init__( def _mapping( value: PathSelection | None, ) -> dict[str, str]: - if value is None: - return {} - if isinstance(value, str): - return {value.rsplit("/", 1)[-1]: value} - if isinstance(value, Mapping): - return dict(cast(Mapping[str, str], value)) - out: dict[str, str] = {} - for path in value: - name = path.rsplit("/", 1)[-1] - if name in out: - raise ValueError( - "HDF5 path selections must have unique derived column names; " - f"use an explicit mapping for duplicate name {name!r}" - ) - out[name] = path - return out + return path_selection_map(value, format_name="HDF5") def describe(self) -> dict[str, Any]: description = super().describe() diff --git a/src/refiner/pipeline/sources/readers/selection.py b/src/refiner/pipeline/sources/readers/selection.py new file mode 100644 index 00000000..3f161607 --- /dev/null +++ b/src/refiner/pipeline/sources/readers/selection.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Literal, cast + + +MissingPolicy = Literal["error", "drop_row", "set_null"] +PathSelection = Mapping[str, str] | Sequence[str] | str + + +def path_selection_map( + value: PathSelection | None, + *, + format_name: str, +) -> dict[str, str]: + if value is None: + return {} + if isinstance(value, str): + return {value.rsplit("/", 1)[-1]: value} + if isinstance(value, Mapping): + return dict(cast(Mapping[str, str], value)) + out: dict[str, str] = {} + for path in value: + name = path.rsplit("/", 1)[-1] + if name in out: + raise ValueError( + f"{format_name} path selections must have unique derived column names; " + f"use an explicit mapping for duplicate name {name!r}" + ) + out[name] = path + return out + + +__all__ = ["MissingPolicy", "PathSelection", "path_selection_map"] diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 5c6f2e99..bbf733dd 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Iterator, Mapping, Sequence -from typing import Any, Literal, cast +from collections.abc import Iterator, Mapping +from typing import Any import pyarrow as pa @@ -15,28 +15,13 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import RowRangeDescriptor, Shard from refiner.pipeline.sources.base import BaseSource, SourceUnit +from refiner.pipeline.sources.readers.selection import ( + MissingPolicy, + PathSelection, + path_selection_map, +) from refiner.utils import check_required_dependencies -MissingPolicy = Literal["error", "drop_row", "set_null"] -PathSelection = Mapping[str, str] | Sequence[str] | str - - -def _selection_map(value: PathSelection) -> dict[str, str]: - if isinstance(value, str): - return {value.rsplit("/", 1)[-1]: value} - if isinstance(value, Mapping): - return dict(cast(Mapping[str, str], value)) - out: dict[str, str] = {} - for path in value: - name = path.rsplit("/", 1)[-1] - if name in out: - raise ValueError( - "Zarr path selections must have unique derived column names; " - f"use an explicit mapping for duplicate name {name!r}" - ) - out[name] = path - return out - def _decode_value(value: Any) -> Any: if hasattr(value, "shape") and value.shape == (): @@ -69,8 +54,22 @@ def __init__( ): self.root = DataFolder.resolve(input) check_required_dependencies("read_zarr", ["zarr"], dist="zarr") - self.arrays = None if arrays is None else _selection_map(arrays) - self.attrs = None if attrs is None else _selection_map(attrs) + self.arrays = ( + None + if arrays is None + else path_selection_map( + arrays, + format_name="Zarr", + ) + ) + self.attrs = ( + None + if attrs is None + else path_selection_map( + attrs, + format_name="Zarr", + ) + ) self.row_ends = row_ends self.rows_per_shard = rows_per_shard self.row_index_column = row_index_column From eb86905d0bec590fca168b34798ab73a3fef9cdf Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:39:01 +0200 Subject: [PATCH 11/39] Keep row splitting in Zarr reader --- src/refiner/pipeline/pipeline.py | 9 +- src/refiner/pipeline/sources/readers/zarr.py | 23 +-- src/refiner/robotics/row.py | 120 +--------------- tests/readers/test_zarr_reader.py | 23 ++- tests/robotics/test_robotics_row.py | 142 +------------------ 5 files changed, 28 insertions(+), 289 deletions(-) diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 3301f99b..9402cdd3 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -44,6 +44,7 @@ ) from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader from refiner.pipeline.sources.readers.selection import MissingPolicy, PathSelection +from refiner.pipeline.sources.readers.zarr import ZarrMissingPolicy from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -177,7 +178,6 @@ def to_robot_rows( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, ) -> "RefinerPipeline": """Expose rows through the RoboticsRow semantic view. @@ -208,11 +208,8 @@ def to_robot_rows( schema=self.output_schema(), stats_key=stats_key, stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, ) - if episode_ends_key is None: - return self.map(cast(MapFn, converter)) - return self.flat_map(cast(FlatMapFn, converter)) + return self.map(cast(MapFn, converter)) def map_async( self, @@ -819,7 +816,7 @@ def read_zarr( rows_per_shard: int = 128, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", - missing_policy: MissingPolicy = "error", + missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, ) -> RefinerPipeline: """Create a pipeline with a Zarr reader source. diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index bbf733dd..8c32ccc6 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterator, Mapping -from typing import Any +from typing import Any, Literal import pyarrow as pa @@ -16,12 +16,13 @@ from refiner.pipeline.data.shard import RowRangeDescriptor, Shard from refiner.pipeline.sources.base import BaseSource, SourceUnit from refiner.pipeline.sources.readers.selection import ( - MissingPolicy, PathSelection, path_selection_map, ) from refiner.utils import check_required_dependencies +ZarrMissingPolicy = Literal["error", "set_null"] + def _decode_value(value: Any) -> Any: if hasattr(value, "shape") and value.shape == (): @@ -49,7 +50,7 @@ def __init__( rows_per_shard: int = 128, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", - missing_policy: MissingPolicy = "error", + missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, ): self.root = DataFolder.resolve(input) @@ -81,10 +82,8 @@ def __init__( self.attrs or {}, reserved=self._reserved_output_names(row_index=row_ends is not None), ) - if missing_policy not in ("error", "drop_row", "set_null"): - raise ValueError( - "missing_policy must be one of 'error', 'drop_row', or 'set_null'" - ) + if missing_policy not in ("error", "set_null"): + raise ValueError("missing_policy must be one of 'error' or 'set_null'") if ( row_ends is not None and file_path_column is not None @@ -124,8 +123,6 @@ def list_shards(self) -> list[Shard]: try: row_count = len(group[self.row_ends]) except KeyError: - if self.missing_policy == "drop_row": - return [] raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None @@ -178,8 +175,6 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: else int(ends_array[descriptor.start - 1]) ) except KeyError: - if self.missing_policy == "drop_row": - return raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None @@ -248,8 +243,6 @@ def _read_arrays( group[path][start:end] if start is not None else group[path][:] ) except KeyError: - if self.missing_policy == "drop_row": - return None if self.missing_policy == "set_null": row[output_name] = None continue @@ -259,8 +252,6 @@ def _read_arrays( def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any] | None: for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: - if self.missing_policy == "drop_row": - return None if self.missing_policy == "set_null": row[output_name] = None continue @@ -294,4 +285,4 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: yield from _iter_array_paths(item, path) -__all__ = ["MissingPolicy", "PathSelection", "ZarrReader"] +__all__ = ["PathSelection", "ZarrMissingPolicy", "ZarrReader"] diff --git a/src/refiner/robotics/row.py b/src/refiner/robotics/row.py index a3df7ba4..8bf0a971 100644 --- a/src/refiner/robotics/row.py +++ b/src/refiner/robotics/row.py @@ -551,36 +551,8 @@ def _robot_row_converter( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, schema: pa.Schema | None = None, -) -> Callable[[Row], Row] | Callable[[Row], Iterable[Row]]: - if episode_ends_key is not None: - - def split_row(row: Row) -> Iterable[Row]: - return cast( - Iterable[Row], - _rows_from_episode_ends( - row, - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, - ), - ) - - return split_row - +) -> Callable[[Row], Row]: spec = _RoboticsRowSpec.from_options( episode_id_key=episode_id_key, task_key=task_key, @@ -616,90 +588,6 @@ def _valid_nested_frames_key(row: Row, key: str | None) -> str | None: return key if isinstance(value, Sequence) else None -def _rows_from_episode_ends( - row: Row, - *, - episode_id_key: str | None, - task_key: str | None, - fps: float | None, - fps_key: str | None, - robot_type: str | None, - robot_type_key: str | None, - timestamp_key: str | None, - action_key: str | None, - state_key: str | Sequence[str] | None, - extra_observation_keys: Mapping[str, str] | Iterable[str] | None, - video_keys: Mapping[str, str] | Iterable[str] | None, - schema: pa.Schema | None, - stats_key: str | None, - stats_prefix: str, - episode_ends_key: str, -) -> Iterator[RoboticsRow]: - episode_ends = [int(value) for value in _get_path(row, episode_ends_key)] - spec = _RoboticsRowSpec.from_options( - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - nested_frames_key=None, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - ) - start = 0 - for episode_idx, end in enumerate(episode_ends): - if episode_id_key is None: - split_row = row - else: - split_row = row.update( - {episode_id_key: f"{row.get(episode_id_key, '-1')}-{episode_idx}"} - ) - for source_key in spec.frame_source_map.values(): - for key in _source_keys(source_key): - if _has_path(row, key): - split_row = split_row.update( - _set_path( - split_row, - key, - _slice_values(_get_path(row, key), start, end), - ) - ) - video_fps = int(spec.wrap(row).fps or 30) - for source_key, storage in spec.video_source_map.values(): - if not _has_path(row, source_key): - continue - value = _get_path(row, source_key) - video = video_from_storage_value(storage, value, fps=video_fps) - if video is None: - continue - clip_fps = int(getattr(video, "fps", video_fps)) - split_row = split_row.update( - _set_path( - split_row, - source_key, - video.clipped( - from_timestamp_s=start / clip_fps, - to_timestamp_s=end / clip_fps, - ), - ) - ) - if task_key is not None and task_key in row: - split_row = split_row.update({task_key: row[task_key]}) - if fps_key is not None and fps_key in row: - split_row = split_row.update({fps_key: row[fps_key]}) - if robot_type_key is not None and robot_type_key in row: - split_row = split_row.update({robot_type_key: row[robot_type_key]}) - yield spec.wrap(split_row) - start = end - - def _video_sources( *, schema: pa.Schema | None, @@ -853,12 +741,6 @@ def _set_path(row: Mapping[str, Any], key: str, values: Any) -> dict[str, Any]: return {head: root} -def _slice_values(values: Any, start: int, end: int) -> Any: - if isinstance(values, pa.ChunkedArray | pa.Array): - return values.slice(start, end - start) - return values[start:end] - - def _select_frame_table(table: Tabular, indices: Sequence[int]) -> Tabular: selected = table.table.take(pa.array(indices, type=pa.int64())) if "frame_index" in selected.column_names: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 5ccd65d1..083ffe12 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -177,19 +177,28 @@ def test_read_zarr_rejects_duplicate_metadata_column_names(tmp_path: Path) -> No ) -def test_read_zarr_drop_row_handles_missing_row_ends(tmp_path: Path) -> None: +def test_read_zarr_rejects_drop_row_missing_policy(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - pipeline = mdr.read_zarr( + with pytest.raises(ValueError, match="missing_policy"): + mdr.read_zarr(path, missing_policy="drop_row") # type: ignore[arg-type] + + +def test_read_zarr_missing_set_null_keeps_group_row(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + row = mdr.read_zarr( path, - arrays={"action": "data/action"}, - row_ends="meta/missing_episode_ends", - missing_policy="drop_row", + arrays={"missing": "data/missing"}, + attrs={"missing_attr": "missing_attr"}, + missing_policy="set_null", file_path_column=None, - ) + ).take(1)[0] - assert pipeline.source.list_shards() == [] + assert row["missing"] is None + assert row["missing_attr"] is None def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index cbb1717c..b4e026e8 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -34,10 +34,7 @@ def _robot_rows( converter = _robot_row_converter(**cast(Any, {**kwargs, "schema": rows.schema})) return [cast(RoboticsRow, converter(row)) for row in rows] converter = _robot_row_converter(**kwargs) - converted = converter(rows) - if kwargs.get("episode_ends_key") is not None: - return [cast(RoboticsRow, row) for row in converted] - return [cast(RoboticsRow, converted)] + return [cast(RoboticsRow, converter(rows))] def test_to_robot_rows_does_not_treat_video_uri_frames_as_frame_table() -> None: @@ -249,31 +246,6 @@ def test_to_robot_rows_preserves_explicit_fps_and_robot_type_keys() -> None: assert robotics_row.robot_type == "aloha" -def test_to_robot_rows_reuses_literal_fps_and_robot_type_for_episode_splits() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [1, 2], - "action": [[0.0], [1.0]], - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - state_key=None, - fps=30.0, - robot_type="pusht", - ) - ) - - assert [row.fps for row in rows] == [30.0, 30.0] - assert [row.robot_type for row in rows] == ["pusht", "pusht"] - - def test_pipeline_to_robot_rows_forwards_literals_and_explicit_keys() -> None: literal_row = ( from_items([{"episode_id": "episode-1"}]) @@ -654,118 +626,6 @@ def test_to_robot_rows_flattens_nested_frame_dicts() -> None: assert robotics_row.states.to_pylist() == [[0.0], [1.0]] -def test_to_robot_rows_splits_dataset_arrays_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "meta": {"episode_ends": [2, 5]}, - "data": { - "obs": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "action": [[0.0], [0.1], [1.0], [1.1], [1.2]], - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - state_key="data/obs", - action_key="data/action", - ) - ) - - assert [row.episode_id for row in rows] == ["pusht-0", "pusht-1"] - assert [row.num_frames for row in rows] == [2, 3] - assert rows[1].actions == [[1.0], [1.1], [1.2]] - - -def test_to_robot_rows_splits_video_frame_arrays_with_episode_ends() -> None: - frames = np.arange(5 * 2 * 2 * 3, dtype=np.uint8).reshape(5, 2, 2, 3) - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [2, 5], - "actions": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "front": frames, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - action_key="actions", - state_key=None, - video_keys={"observation.images.front": "front"}, - fps=10, - ) - ) - - videos = [row.videos["observation.images.front"] for row in rows] - - assert [ - video.frame_count for video in videos if isinstance(video, VideoFrameArray) - ] == [ - 2, - 3, - ] - assert isinstance(videos[0], VideoFrameArray) - assert isinstance(videos[1], VideoFrameArray) - np.testing.assert_array_equal( - np.asarray(list(videos[0].iter_frame_arrays())), - frames[:2], - ) - np.testing.assert_array_equal( - np.asarray(list(videos[1].iter_frame_arrays())), - frames[2:], - ) - - -def test_to_robot_rows_splits_tuple_state_key_sources_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "robomimic", - "meta": {"episode_ends": [2, 4]}, - "actions": [[0.0], [0.1], [1.0], [1.1]], - "obs": { - "joint_pos": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], - "joint_vel": [[0.1], [0.2], [0.3], [0.4]], - "eef_pos": np.asarray( - [[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]], - dtype=np.float32, - ), - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - action_key="actions", - state_key=("obs/joint_pos", "obs/joint_vel", "obs/eef_pos"), - ) - ) - - assert [row.episode_id for row in rows] == ["robomimic-0", "robomimic-1"] - assert [row.num_frames for row in rows] == [2, 2] - assert rows[0].states == [ - [1.0, 2.0, 0.1, 10.0, 11.0], - [3.0, 4.0, 0.2, 12.0, 13.0], - ] - assert rows[1].states == [ - [5.0, 6.0, 0.3, 14.0, 15.0], - [7.0, 8.0, 0.4, 16.0, 17.0], - ] - - def test_motion_trim_works_with_mapped_semantic_keys() -> None: row = DictRow( { From 1b00385cb9b7214a9dab5087f977343359849c15 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 14:55:40 +0200 Subject: [PATCH 12/39] Harden Zarr row-end reads --- docs/reading-and-writing.md | 76 ++++++++++++++++++++ src/refiner/pipeline/pipeline.py | 2 +- src/refiner/pipeline/sources/readers/zarr.py | 48 ++++++++++--- tests/readers/test_zarr_reader.py | 30 ++++++++ 4 files changed, 144 insertions(+), 12 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index d0d76fa5..b50f0a12 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -228,6 +228,82 @@ datasets or attributes default to raising an error. Set If a selected column can be missing from every group in an input file, pass `dtypes` for that column so the reader can emit a stable Arrow type. +## Zarr + +Zarr support lives behind the optional `macrodata-refiner[zarr]` extra. + +```bash +uv add "macrodata-refiner[zarr]" +``` + +`read_zarr(...)` reads one Zarr group. By default, the group becomes one output +row and selected arrays are loaded as full array values. + +```python +import refiner as mdr + +pipeline = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "state": "data/state", + }, + attrs={"task": "task"}, +) +``` + +`arrays` and `attrs` accept the same selection forms as HDF5: a mapping from +output column name to Zarr path, one path string, or a sequence of path strings +with unique final components. Use a mapping when derived names would collide. + +For robotics-style replay buffers, pass `row_ends` to split concatenated arrays +into logical rows, usually episodes: + +```python +episodes = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + row_index_column="episode_id", + file_path_column=None, +) +``` + +For a store shaped like: + +```text +replay_buffer.zarr +├── data +│ ├── action # shape [total_steps, action_dim] +│ ├── state # shape [total_steps, state_dim] +│ └── rgb # shape [total_steps, height, width, channels] +└── meta + └── episode_ends # cumulative end offsets, for example [152, 319, 477] +``` + +this emits one row per `[start:end]` slice. The selected arrays are sliced along +their leading dimension, while selected attrs are repeated on each row. +`row_index_column` receives the row/episode index when `row_ends` is set. Set it +to `None` to omit that metadata. + +When `row_ends` is set, `rows_per_shard` controls how many output rows are read +as one shard. The default is `1` so image-heavy robotics episodes are not batched +into a large in-memory slice by default. Increase it only when each row slice is +small enough to materialize together. + +`row_ends` is reader control metadata, not an output selection. If you also want +the raw offsets as a column in non-split mode, select that path through `arrays`. + +Missing selected arrays or attrs default to raising an error. Set +`missing_policy="set_null"` to keep the group row and emit `None` for missing +selected values. Zarr does not support `drop_row` because a selected path is +missing at the group schema level rather than per output row. + ## Common Crawl text readers [Common Crawl](https://commoncrawl.org/) publishes large public web crawls. diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 9402cdd3..92bce01f 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -813,7 +813,7 @@ def read_zarr( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, - rows_per_shard: int = 128, + rows_per_shard: int = 1, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: ZarrMissingPolicy = "error", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 8c32ccc6..823c1f3f 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -47,7 +47,7 @@ def __init__( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, - rows_per_shard: int = 128, + rows_per_shard: int = 1, row_index_column: str | None = "row_index", file_path_column: str | None = "file_path", missing_policy: ZarrMissingPolicy = "error", @@ -178,6 +178,13 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None + _validate_row_ends( + row_ends, + start=start, + group=group, + arrays=arrays, + missing_policy=self.missing_policy, + ) shard_start = start shard_end = row_ends[-1] if row_ends else start shard_arrays = self._read_arrays( @@ -186,8 +193,6 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: start=shard_start, end=shard_end, ) - if shard_arrays is None: - return for offset, end in enumerate(row_ends): row = self._row_metadata(row_index=descriptor.start + offset) relative_start = start - shard_start @@ -197,20 +202,15 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: None if value is None else value[relative_start:relative_end] ) row = self._read_attrs(group, row) - if row is None: - return yield DictRow(row) start = end return row = self._row_metadata(row_index=None) row_arrays = self._read_arrays(group, arrays) - if row_arrays is None: - return row.update(row_arrays) row = self._read_attrs(group, row) - if row is not None: - yield DictRow(row) + yield DictRow(row) def _reserved_output_names(self, *, row_index: bool) -> set[str]: names = set() @@ -235,7 +235,7 @@ def _read_arrays( *, start: int | None = None, end: int | None = None, - ) -> dict[str, Any] | None: + ) -> dict[str, Any]: row: dict[str, Any] = {} for output_name, path in arrays.items(): try: @@ -249,7 +249,7 @@ def _read_arrays( raise KeyError(f"Zarr array not found: {path}") from None return row - def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any] | None: + def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any]: for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: if self.missing_policy == "set_null": @@ -276,6 +276,32 @@ def _validate_output_names( raise ValueError(f"Zarr selections use reserved output names: {names}") +def _validate_row_ends( + row_ends: list[int], + *, + start: int, + group: Any, + arrays: Mapping[str, str], + missing_policy: ZarrMissingPolicy, +) -> None: + previous = start + for end in row_ends: + if end < previous: + raise ValueError("Zarr row_ends must be monotonic increasing") + previous = end + for output_name, path in arrays.items(): + try: + length = int(group[path].shape[0]) + except KeyError: + if missing_policy == "set_null": + continue + raise KeyError(f"Zarr array not found: {path}") from None + if row_ends and row_ends[-1] > length: + raise ValueError( + f"Zarr row_ends exceed leading dimension for {output_name!r}" + ) + + def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 083ffe12..1e5b77d2 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -201,6 +201,36 @@ def test_read_zarr_missing_set_null_keeps_group_row(tmp_path: Path) -> None: assert row["missing_attr"] is None +def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = zarr.open_group(str(path), mode="a") + root["meta/episode_ends"][:] = np.asarray([3, 2], dtype=np.int64) + + with pytest.raises(ValueError, match="row_ends must be monotonic"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + +def test_read_zarr_rejects_out_of_range_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = zarr.open_group(str(path), mode="a") + root["meta/episode_ends"][:] = np.asarray([2, 6], dtype=np.int64) + + with pytest.raises(ValueError, match="row_ends exceed"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" From 2127e32a79cffa2060f21f30fe5f728f8de34701 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 16:08:12 +0200 Subject: [PATCH 13/39] Plan Zarr reads from metadata --- docs/reading-and-writing.md | 35 +- src/refiner/pipeline/pipeline.py | 25 +- src/refiner/pipeline/sources/readers/zarr.py | 400 ++++++++++++------- tests/readers/test_zarr_reader.py | 89 ++++- 4 files changed, 369 insertions(+), 180 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index b50f0a12..ef6c14bf 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -269,7 +269,7 @@ episodes = mdr.read_zarr( }, attrs={"task": "task"}, row_ends="meta/episode_ends", - row_index_column="episode_id", + index_column="episode_id", file_path_column=None, ) ``` @@ -288,21 +288,34 @@ replay_buffer.zarr this emits one row per `[start:end]` slice. The selected arrays are sliced along their leading dimension, while selected attrs are repeated on each row. -`row_index_column` receives the row/episode index when `row_ends` is set. Set it -to `None` to omit that metadata. +`index_column` receives the row/episode index when `row_ends` is set. Set it to +`None` to omit that metadata. -When `row_ends` is set, `rows_per_shard` controls how many output rows are read -as one shard. The default is `1` so image-heavy robotics episodes are not batched -into a large in-memory slice by default. Increase it only when each row slice is -small enough to materialize together. +If a Zarr store has aligned arrays but no episode boundaries, use +`split_leading_axis=True` to emit leading-axis windows: + +```python +windows = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "frames": "data/rgb", + }, + split_leading_axis=True, + target_shard_bytes=256 * 1024**2, +) +``` + +This mode requires selected arrays to have the same leading dimension. Refiner +chooses contiguous windows from array metadata, using the byte-heavy array's +chunking to avoid unnecessary chunk splits where possible. Use `num_shards` when +you need a target shard count instead of byte-sized packing. `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. -Missing selected arrays or attrs default to raising an error. Set -`missing_policy="set_null"` to keep the group row and emit `None` for missing -selected values. Zarr does not support `drop_row` because a selected path is -missing at the group schema level rather than per output row. +Missing selected arrays or attrs always raise. Zarr selections describe group +schema, not row-local optional fields. ## Common Crawl text readers diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 92bce01f..5ffa852e 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -44,7 +44,6 @@ ) from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader from refiner.pipeline.sources.readers.selection import MissingPolicy, PathSelection -from refiner.pipeline.sources.readers.zarr import ZarrMissingPolicy from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -813,17 +812,22 @@ def read_zarr( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, - rows_per_shard: int = 1, - row_index_column: str | None = "row_index", + split_leading_axis: bool = False, + target_shard_bytes: int = 256 * 1024**2, + num_shards: int | None = None, + index_column: str | None = "index", file_path_column: str | None = "file_path", - missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, ) -> RefinerPipeline: """Create a pipeline with a Zarr reader source. - The reader emits one row for the Zarr group. If `row_ends` is provided, - it reads that Zarr array as cumulative end offsets and emits one row per - `[start:end]` slice. + The reader has three modes: + - group mode: one Zarr group becomes one row + - row_ends mode: cumulative offsets define whole-row source slices + - split_leading_axis mode: aligned axis-0 windows define output rows + + Missing selected arrays or attributes raise immediately. `row_ends` and + `split_leading_axis` are mutually exclusive. """ return RefinerPipeline( source=ZarrReader( @@ -831,10 +835,11 @@ def read_zarr( arrays=arrays, attrs=attrs, row_ends=row_ends, - rows_per_shard=rows_per_shard, - row_index_column=row_index_column, + split_leading_axis=split_leading_axis, + target_shard_bytes=target_shard_bytes, + num_shards=num_shards, + index_column=index_column, file_path_column=file_path_column, - missing_policy=missing_policy, dtypes=dtypes, ) ) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 823c1f3f..c20dd3bb 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -1,7 +1,9 @@ from __future__ import annotations -from collections.abc import Iterator, Mapping -from typing import Any, Literal +from collections.abc import Iterator, Mapping, Sequence +from dataclasses import dataclass +from math import ceil, prod +from typing import Any import pyarrow as pa @@ -21,7 +23,33 @@ ) from refiner.utils import check_required_dependencies -ZarrMissingPolicy = Literal["error", "set_null"] +DEFAULT_TARGET_SHARD_BYTES = 256 * 1024**2 + + +@dataclass(frozen=True, slots=True) +class _ArrayInfo: + output_name: str + path: str + shape: tuple[int, ...] + chunks: tuple[int, ...] + dtype: Any + + @property + def leading_length(self) -> int: + if not self.shape: + raise ValueError( + f"Zarr array {self.path!r} must have a leading dimension to split" + ) + return int(self.shape[0]) + + @property + def leading_chunk(self) -> int: + return int(self.chunks[0]) if self.chunks else self.leading_length + + @property + def bytes_per_step(self) -> int: + trailing_shape = self.shape[1:] + return max(1, int(self.dtype.itemsize) * int(prod(trailing_shape or (1,)))) def _decode_value(value: Any) -> Any: @@ -36,7 +64,7 @@ def _decode_value(value: Any) -> Any: class ZarrReader(BaseSource): - """Read one Zarr group as one row, or split arrays by cumulative row ends.""" + """Read a Zarr group as one row, episode rows, or leading-axis windows.""" name = "read_zarr" @@ -47,49 +75,66 @@ def __init__( arrays: PathSelection | None = None, attrs: PathSelection | None = None, row_ends: str | None = None, - rows_per_shard: int = 1, - row_index_column: str | None = "row_index", + split_leading_axis: bool = False, + target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, + num_shards: int | None = None, + index_column: str | None = "index", file_path_column: str | None = "file_path", - missing_policy: ZarrMissingPolicy = "error", dtypes: DTypeMapping | None = None, ): + """Create a Zarr reader. + + Args: + input: Zarr group path. + arrays: Array selections as output-name to Zarr-path mapping, a + single path, a path sequence, or None to discover all arrays. + attrs: Attribute selections with the same shape as ``arrays``. + row_ends: Optional Zarr array path containing cumulative end offsets. + When set, emitted rows are whole source ranges and never split + across these boundaries. + split_leading_axis: Emit aligned leading-axis windows when no + ``row_ends`` path is provided. + target_shard_bytes: Approximate byte target used to pack output rows + into shards in split modes. + num_shards: Optional target shard count for split modes. + index_column: Output metadata column containing the episode/window + index in split modes, or None to omit it. + file_path_column: Output metadata column containing the source path, + or None to omit it. + dtypes: Optional dtype overrides for output columns. + """ self.root = DataFolder.resolve(input) check_required_dependencies("read_zarr", ["zarr"], dist="zarr") + if row_ends is not None and split_leading_axis: + raise ValueError("row_ends and split_leading_axis are mutually exclusive") + if target_shard_bytes <= 0: + raise ValueError("target_shard_bytes must be greater than zero") + if num_shards is not None and num_shards <= 0: + raise ValueError("num_shards must be greater than zero") self.arrays = ( - None - if arrays is None - else path_selection_map( - arrays, - format_name="Zarr", - ) + None if arrays is None else path_selection_map(arrays, format_name="Zarr") ) self.attrs = ( - None - if attrs is None - else path_selection_map( - attrs, - format_name="Zarr", - ) + None if attrs is None else path_selection_map(attrs, format_name="Zarr") ) self.row_ends = row_ends - self.rows_per_shard = rows_per_shard - self.row_index_column = row_index_column + self.split_leading_axis = split_leading_axis + self.target_shard_bytes = target_shard_bytes + self.num_shards = num_shards + self.index_column = index_column self.file_path_column = file_path_column - self.missing_policy = missing_policy self.dtypes = dtypes _validate_output_names( self.arrays or {}, self.attrs or {}, - reserved=self._reserved_output_names(row_index=row_ends is not None), + reserved=self._reserved_output_names(split=self._is_split_mode), ) - if missing_policy not in ("error", "set_null"): - raise ValueError("missing_policy must be one of 'error' or 'set_null'") if ( - row_ends is not None + self._is_split_mode and file_path_column is not None - and file_path_column == row_index_column + and file_path_column == index_column ): - raise ValueError("file_path_column and row_index_column must be distinct") + raise ValueError("file_path_column and index_column must be distinct") @property def schema(self) -> pa.Schema | None: @@ -101,10 +146,11 @@ def describe(self) -> dict[str, Any]: "arrays": dict(self.arrays) if self.arrays is not None else None, "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, - "rows_per_shard": self.rows_per_shard, - "row_index_column": self.row_index_column, + "split_leading_axis": self.split_leading_axis, + "target_shard_bytes": self.target_shard_bytes, + "num_shards": self.num_shards, + "index_column": self.index_column, "file_path_column": self.file_path_column, - "missing_policy": self.missing_policy, "dtypes": ( {key: dtype_to_plan(dtype) for key, dtype in self.dtypes.items()} if self.dtypes @@ -114,120 +160,139 @@ def describe(self) -> dict[str, Any]: def list_shards(self) -> list[Shard]: path = self.root.abs_path() - if self.row_ends is not None: - if self.rows_per_shard <= 0: - raise ValueError("rows_per_shard must be greater than zero") - import zarr + import zarr - group = zarr.open_group(store=zarr_store(self.root), mode="r") - try: - row_count = len(group[self.row_ends]) - except KeyError: - raise KeyError( - f"Zarr row_ends array not found: {self.row_ends}" - ) from None - return [ - Shard.from_row_range( - start=start, - end=min(start + self.rows_per_shard, row_count), - global_ordinal=index, - start_key=path, - end_key=path, - ) - for index, start in enumerate(range(0, row_count, self.rows_per_shard)) - ] + group = zarr.open_group(store=zarr_store(self.root), mode="r") + arrays = self._array_selection(group) + _validate_output_names( + arrays, + self.attrs or {}, + reserved=self._reserved_output_names(split=self._is_split_mode), + ) + split_ranges = self._split_ranges(group, self._array_infos(group, arrays)) return [ Shard.from_row_range( - start=0, - end=1, - global_ordinal=0, + start=start, + end=end, + global_ordinal=index, start_key=path, end_key=path, ) + for index, (start, end) in enumerate(split_ranges) ] def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: import zarr group = zarr.open_group(store=zarr_store(self.root), mode="r") - arrays = ( - {path: path for path in _iter_array_paths(group) if path != self.row_ends} - if self.arrays is None - else self.arrays - ) + arrays = self._array_selection(group) _validate_output_names( arrays, self.attrs or {}, - reserved=self._reserved_output_names(row_index=self.row_ends is not None), + reserved=self._reserved_output_names(split=self._is_split_mode), ) - if self.row_ends is not None: + if self._is_split_mode: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - try: - ends_array = group[self.row_ends] - row_ends = [ - int(value) - for value in ends_array[descriptor.start : descriptor.end] - ] - start = ( - 0 - if descriptor.start == 0 - else int(ends_array[descriptor.start - 1]) - ) - except KeyError: - raise KeyError( - f"Zarr row_ends array not found: {self.row_ends}" - ) from None - _validate_row_ends( - row_ends, - start=start, - group=group, - arrays=arrays, - missing_policy=self.missing_policy, - ) - shard_start = start - shard_end = row_ends[-1] if row_ends else start - shard_arrays = self._read_arrays( - group, - arrays, - start=shard_start, - end=shard_end, - ) - for offset, end in enumerate(row_ends): - row = self._row_metadata(row_index=descriptor.start + offset) - relative_start = start - shard_start - relative_end = end - shard_start - for output_name, value in shard_arrays.items(): - row[output_name] = ( - None if value is None else value[relative_start:relative_end] - ) - row = self._read_attrs(group, row) - yield DictRow(row) - start = end + source_ranges = self._source_ranges(group, self._array_infos(group, arrays)) + for row_index in range(descriptor.start, descriptor.end): + start, end = source_ranges[row_index] + row = self._row_metadata(index=row_index) + row.update(self._read_arrays(group, arrays, start=start, end=end)) + yield DictRow(self._read_attrs(group, row)) return - row = self._row_metadata(row_index=None) - row_arrays = self._read_arrays(group, arrays) - row.update(row_arrays) - row = self._read_attrs(group, row) - yield DictRow(row) + row = self._row_metadata(index=None) + row.update(self._read_arrays(group, arrays)) + yield DictRow(self._read_attrs(group, row)) - def _reserved_output_names(self, *, row_index: bool) -> set[str]: + @property + def _is_split_mode(self) -> bool: + return self.row_ends is not None or self.split_leading_axis + + def _reserved_output_names(self, *, split: bool) -> set[str]: names = set() if self.file_path_column is not None: names.add(self.file_path_column) - if row_index and self.row_index_column is not None: - names.add(self.row_index_column) + if split and self.index_column is not None: + names.add(self.index_column) return names - def _row_metadata(self, *, row_index: int | None) -> dict[str, Any]: + def _row_metadata(self, *, index: int | None) -> dict[str, Any]: row: dict[str, Any] = {} if self.file_path_column is not None: row[self.file_path_column] = self.root.abs_path() - if self.row_index_column is not None and row_index is not None: - row[self.row_index_column] = row_index + if self.index_column is not None and index is not None: + row[self.index_column] = index return row + def _array_selection(self, group: Any) -> dict[str, str]: + if self.arrays is not None: + return self.arrays + return { + path: path for path in _iter_array_paths(group) if path != self.row_ends + } + + def _array_infos( + self, + group: Any, + arrays: Mapping[str, str], + ) -> list[_ArrayInfo]: + infos: list[_ArrayInfo] = [] + for output_name, path in arrays.items(): + try: + array = group[path] + except KeyError: + raise KeyError(f"Zarr array not found: {path}") from None + infos.append( + _ArrayInfo( + output_name=output_name, + path=path, + shape=tuple(int(value) for value in array.shape), + chunks=tuple(int(value) for value in array.chunks), + dtype=array.dtype, + ) + ) + return infos + + def _source_ranges( + self, + group: Any, + infos: Sequence[_ArrayInfo], + ) -> list[tuple[int, int]]: + if self.row_ends is not None: + try: + ends = [int(value) for value in group[self.row_ends][:]] + except KeyError: + raise KeyError( + f"Zarr row_ends array not found: {self.row_ends}" + ) from None + ranges = _ranges_from_ends(ends) + _validate_source_ranges(ranges, infos, label="row_ends") + return ranges + return _leading_axis_ranges( + infos, + target_shard_bytes=self.target_shard_bytes, + num_shards=self.num_shards, + ) + + def _split_ranges( + self, + group: Any, + infos: Sequence[_ArrayInfo], + ) -> list[tuple[int, int]]: + if not self._is_split_mode: + return [(0, 1)] + if self.split_leading_axis: + source_ranges = self._source_ranges(group, infos) + return [(index, index + 1) for index in range(len(source_ranges))] + return _pack_output_rows( + self._source_ranges(group, infos), + infos, + target_shard_bytes=self.target_shard_bytes, + num_shards=self.num_shards, + ) + def _read_arrays( self, group: Any, @@ -243,18 +308,12 @@ def _read_arrays( group[path][start:end] if start is not None else group[path][:] ) except KeyError: - if self.missing_policy == "set_null": - row[output_name] = None - continue raise KeyError(f"Zarr array not found: {path}") from None return row def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any]: for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: - if self.missing_policy == "set_null": - row[output_name] = None - continue raise KeyError(f"Zarr attr not found: {attr_name}") row[output_name] = _decode_value(group.attrs[attr_name]) return row @@ -276,32 +335,91 @@ def _validate_output_names( raise ValueError(f"Zarr selections use reserved output names: {names}") -def _validate_row_ends( - row_ends: list[int], +def _ranges_from_ends(ends: Sequence[int]) -> list[tuple[int, int]]: + ranges: list[tuple[int, int]] = [] + start = 0 + for end in ends: + if end < start: + raise ValueError("Zarr row_ends must be monotonic increasing") + ranges.append((start, end)) + start = end + return ranges + + +def _validate_source_ranges( + ranges: Sequence[tuple[int, int]], + infos: Sequence[_ArrayInfo], *, - start: int, - group: Any, - arrays: Mapping[str, str], - missing_policy: ZarrMissingPolicy, + label: str, ) -> None: - previous = start - for end in row_ends: - if end < previous: - raise ValueError("Zarr row_ends must be monotonic increasing") - previous = end - for output_name, path in arrays.items(): - try: - length = int(group[path].shape[0]) - except KeyError: - if missing_policy == "set_null": - continue - raise KeyError(f"Zarr array not found: {path}") from None - if row_ends and row_ends[-1] > length: + if not infos: + return + final_end = ranges[-1][1] if ranges else 0 + for info in infos: + if final_end > info.leading_length: raise ValueError( - f"Zarr row_ends exceed leading dimension for {output_name!r}" + f"Zarr {label} exceed leading dimension for {info.output_name!r}" ) +def _leading_axis_ranges( + infos: Sequence[_ArrayInfo], + *, + target_shard_bytes: int, + num_shards: int | None, +) -> list[tuple[int, int]]: + if not infos: + raise ValueError("split_leading_axis requires at least one selected array") + lengths = {info.leading_length for info in infos} + if len(lengths) != 1: + raise ValueError("Zarr selected arrays must have the same leading dimension") + length = lengths.pop() + if length == 0: + return [] + if num_shards is not None: + step = ceil(length / num_shards) + else: + bytes_per_step = sum(info.bytes_per_step for info in infos) + target_steps = max(1, target_shard_bytes // max(1, bytes_per_step)) + heavy = max(infos, key=lambda info: info.bytes_per_step) + base = max(1, heavy.leading_chunk) + step = max(base, (target_steps // base) * base) + return [(start, min(start + step, length)) for start in range(0, length, step)] + + +def _pack_output_rows( + source_ranges: Sequence[tuple[int, int]], + infos: Sequence[_ArrayInfo], + *, + target_shard_bytes: int, + num_shards: int | None, +) -> list[tuple[int, int]]: + if not source_ranges: + return [] + if num_shards is not None: + step = ceil(len(source_ranges) / num_shards) + return [ + (start, min(start + step, len(source_ranges))) + for start in range(0, len(source_ranges), step) + ] + + bytes_per_step = sum(info.bytes_per_step for info in infos) + if bytes_per_step <= 0: + return [(0, len(source_ranges))] + ranges: list[tuple[int, int]] = [] + start_index = 0 + current_bytes = 0 + for index, (start, end) in enumerate(source_ranges): + row_bytes = max(1, end - start) * bytes_per_step + if index > start_index and current_bytes + row_bytes > target_shard_bytes: + ranges.append((start_index, index)) + start_index = index + current_bytes = 0 + current_bytes += row_bytes + ranges.append((start_index, len(source_ranges))) + return ranges + + def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name @@ -311,4 +429,4 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: yield from _iter_array_paths(item, path) -__all__ = ["PathSelection", "ZarrMissingPolicy", "ZarrReader"] +__all__ = ["PathSelection", "ZarrReader"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 1e5b77d2..bd3df917 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -86,7 +86,7 @@ def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: ) -def test_read_zarr_plans_one_shard_per_row_end(tmp_path: Path) -> None: +def test_read_zarr_plans_row_ends_with_num_shards(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -94,7 +94,7 @@ def test_read_zarr_plans_one_shard_per_row_end(tmp_path: Path) -> None: path, arrays={"action": "data/action"}, row_ends="meta/episode_ends", - rows_per_shard=1, + num_shards=2, file_path_column=None, ) @@ -107,7 +107,7 @@ def test_read_zarr_plans_one_shard_per_row_end(tmp_path: Path) -> None: rows = [cast(Row, row) for row in pipeline.source.read_shard(shards[1])] assert len(rows) == 1 - assert rows[0]["row_index"] == 1 + assert rows[0]["index"] == 1 np.testing.assert_allclose(rows[0]["action"], [[1.0], [1.1], [1.2]]) @@ -151,14 +151,14 @@ def test_read_zarr_rejects_discovered_array_attr_collisions(tmp_path: Path) -> N pipeline.take(1) -def test_read_zarr_rejects_reserved_row_index_output_name(tmp_path: Path) -> None: +def test_read_zarr_rejects_reserved_index_output_name(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) with pytest.raises(ValueError, match="reserved output names"): mdr.read_zarr( path, - arrays={"row_index": "data/action"}, + arrays={"index": "data/action"}, row_ends="meta/episode_ends", file_path_column=None, ) @@ -173,32 +173,85 @@ def test_read_zarr_rejects_duplicate_metadata_column_names(tmp_path: Path) -> No path, row_ends="meta/episode_ends", file_path_column="metadata", - row_index_column="metadata", + index_column="metadata", ) -def test_read_zarr_rejects_drop_row_missing_policy(tmp_path: Path) -> None: +def test_read_zarr_rejects_missing_selected_paths(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - with pytest.raises(ValueError, match="missing_policy"): - mdr.read_zarr(path, missing_policy="drop_row") # type: ignore[arg-type] + with pytest.raises(KeyError, match="Zarr array not found"): + mdr.read_zarr( + path, + arrays={"missing": "data/missing"}, + file_path_column=None, + ).take(1) -def test_read_zarr_missing_set_null_keeps_group_row(tmp_path: Path) -> None: +def test_read_zarr_rejects_missing_selected_attrs(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - row = mdr.read_zarr( + with pytest.raises(KeyError, match="Zarr attr not found"): + mdr.read_zarr( + path, + arrays={}, + attrs={"missing_attr": "missing_attr"}, + file_path_column=None, + ).take(1) + + +def test_read_zarr_split_leading_axis_emits_aligned_windows(tmp_path: Path) -> None: + path = tmp_path / "windows.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset( + "data/action", + data=np.arange(5, dtype=np.float32).reshape(5, 1), + chunks=(5, 1), + ) + root.create_dataset( + "data/rgb", + data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), + chunks=(2, 4, 4, 3), + ) + + pipeline = mdr.read_zarr( path, - arrays={"missing": "data/missing"}, - attrs={"missing_attr": "missing_attr"}, - missing_policy="set_null", + arrays={"action": "data/action", "rgb": "data/rgb"}, + split_leading_axis=True, + target_shard_bytes=96, file_path_column=None, - ).take(1)[0] + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 1), + (1, 2), + (2, 3), + ] - assert row["missing"] is None - assert row["missing_attr"] is None + rows = pipeline.take(3) + + assert [row["index"] for row in rows] == [0, 1, 2] + assert [len(row["action"]) for row in rows] == [2, 2, 1] + np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) + + +def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: + path = tmp_path / "misaligned.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset("data/action", data=np.zeros((5, 1), dtype=np.float32)) + root.create_dataset("data/state", data=np.zeros((4, 1), dtype=np.float32)) + + with pytest.raises(ValueError, match="same leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action", "state": "data/state"}, + split_leading_axis=True, + file_path_column=None, + ).take(1) def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: @@ -249,7 +302,7 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: file_path_column=None, ) .to_robot_rows( - episode_id_key="row_index", + episode_id_key="index", task_key="task", action_key="action", state_key="observation.state", From 5a6a29cff25e2dfa50449caab15cf9d69bef18fc Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 16:15:20 +0200 Subject: [PATCH 14/39] Clean up Zarr reader planning --- docs/reading-and-writing.md | 2 +- src/refiner/io/zarr.py | 25 ------- src/refiner/pipeline/pipeline.py | 5 +- src/refiner/pipeline/sources/readers/hdf5.py | 11 ++- .../pipeline/sources/readers/selection.py | 5 +- src/refiner/pipeline/sources/readers/zarr.py | 67 +++++++++++++++---- tests/readers/test_zarr_reader.py | 14 ++++ 7 files changed, 78 insertions(+), 51 deletions(-) delete mode 100644 src/refiner/io/zarr.py diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index ef6c14bf..f0220204 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -302,7 +302,7 @@ windows = mdr.read_zarr( "frames": "data/rgb", }, split_leading_axis=True, - target_shard_bytes=256 * 1024**2, + target_shard_bytes=128 * 1024**2, ) ``` diff --git a/src/refiner/io/zarr.py b/src/refiner/io/zarr.py deleted file mode 100644 index bff3958d..00000000 --- a/src/refiner/io/zarr.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import Literal - -from refiner.io.datafolder import DataFolder - - -def zarr_store( - folder: DataFolder, - path: str = "", - *, - mode: Literal["r", "w", "w-", "a"] = "r", -): - import zarr - - create = mode in {"w", "w-", "a"} - return zarr.storage.FSStore( - folder._join(path), - fs=folder.fs, - mode=mode, - create=create, - ) - - -__all__ = ["zarr_store"] diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 5ffa852e..6005943b 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -42,8 +42,9 @@ ParquetReader, ZarrReader, ) +from refiner.pipeline.sources.readers.hdf5 import MissingPolicy from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader -from refiner.pipeline.sources.readers.selection import MissingPolicy, PathSelection +from refiner.pipeline.sources.readers.selection import PathSelection from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -813,7 +814,7 @@ def read_zarr( attrs: PathSelection | None = None, row_ends: str | None = None, split_leading_axis: bool = False, - target_shard_bytes: int = 256 * 1024**2, + target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, num_shards: int | None = None, index_column: str | None = "index", file_path_column: str | None = "file_path", diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index b50b8c03..56583bf1 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -3,7 +3,7 @@ from collections.abc import Iterator, Mapping, Sequence import fnmatch from glob import has_magic -from typing import Any +from typing import Any, Literal from fsspec import AbstractFileSystem @@ -13,15 +13,14 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import FilePartsDescriptor from refiner.pipeline.sources.readers.base import BaseReader, Shard, SourceUnit -from refiner.pipeline.sources.readers.selection import ( - MissingPolicy, - PathSelection, - path_selection_map, -) +from refiner.pipeline.sources.readers.selection import PathSelection, path_selection_map from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES from refiner.utils import check_required_dependencies +MissingPolicy = Literal["error", "drop_row", "set_null"] + + def _decode_value( value: Any, *, diff --git a/src/refiner/pipeline/sources/readers/selection.py b/src/refiner/pipeline/sources/readers/selection.py index 3f161607..283d8cca 100644 --- a/src/refiner/pipeline/sources/readers/selection.py +++ b/src/refiner/pipeline/sources/readers/selection.py @@ -1,10 +1,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Literal, cast +from typing import cast -MissingPolicy = Literal["error", "drop_row", "set_null"] PathSelection = Mapping[str, str] | Sequence[str] | str @@ -31,4 +30,4 @@ def path_selection_map( return out -__all__ = ["MissingPolicy", "PathSelection", "path_selection_map"] +__all__ = ["PathSelection", "path_selection_map"] diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index c20dd3bb..8bba68ea 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -8,7 +8,6 @@ import pyarrow as pa from refiner.io.datafolder import DataFolder, DataFolderLike -from refiner.io.zarr import zarr_store from refiner.pipeline.data.datatype import ( DTypeMapping, dtype_to_plan, @@ -21,10 +20,9 @@ PathSelection, path_selection_map, ) +from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES from refiner.utils import check_required_dependencies -DEFAULT_TARGET_SHARD_BYTES = 256 * 1024**2 - @dataclass(frozen=True, slots=True) class _ArrayInfo: @@ -162,7 +160,10 @@ def list_shards(self) -> list[Shard]: path = self.root.abs_path() import zarr - group = zarr.open_group(store=zarr_store(self.root), mode="r") + group = zarr.open_group( + store=zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r"), + mode="r", + ) arrays = self._array_selection(group) _validate_output_names( arrays, @@ -184,19 +185,30 @@ def list_shards(self) -> list[Shard]: def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: import zarr - group = zarr.open_group(store=zarr_store(self.root), mode="r") + group = zarr.open_group( + store=zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r"), + mode="r", + ) arrays = self._array_selection(group) _validate_output_names( arrays, self.attrs or {}, reserved=self._reserved_output_names(split=self._is_split_mode), ) + infos = self._array_infos(group, arrays) if self._is_split_mode: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - source_ranges = self._source_ranges(group, self._array_infos(group, arrays)) - for row_index in range(descriptor.start, descriptor.end): - start, end = source_ranges[row_index] + source_ranges = self._source_ranges( + group, + infos, + output_rows=(descriptor.start, descriptor.end), + ) + for row_index, (start, end) in zip( + range(descriptor.start, descriptor.end), + source_ranges, + strict=True, + ): row = self._row_metadata(index=row_index) row.update(self._read_arrays(group, arrays, start=start, end=end)) yield DictRow(self._read_attrs(group, row)) @@ -259,22 +271,45 @@ def _source_ranges( self, group: Any, infos: Sequence[_ArrayInfo], + *, + output_rows: tuple[int, int] | None = None, ) -> list[tuple[int, int]]: if self.row_ends is not None: try: - ends = [int(value) for value in group[self.row_ends][:]] + row_ends_array = group[self.row_ends] except KeyError: raise KeyError( f"Zarr row_ends array not found: {self.row_ends}" ) from None - ranges = _ranges_from_ends(ends) + if output_rows is None: + ranges = _ranges_from_ends([int(value) for value in row_ends_array[:]]) + else: + row_start, row_end = output_rows + if row_end < row_start: + raise ValueError("Zarr shard row range is invalid") + if row_start == row_end: + return [] + read_start = max(0, row_start - 1) + values = [int(value) for value in row_ends_array[read_start:row_end]] + if len(values) != row_end - read_start: + raise ValueError("Zarr shard row range exceeds row_ends length") + previous = 0 if row_start == 0 else values[0] + ranges = [] + for value in values if row_start == 0 else values[1:]: + if value < previous: + raise ValueError("Zarr row_ends must be monotonic increasing") + ranges.append((previous, value)) + previous = value _validate_source_ranges(ranges, infos, label="row_ends") return ranges - return _leading_axis_ranges( + ranges = _leading_axis_ranges( infos, target_shard_bytes=self.target_shard_bytes, num_shards=self.num_shards, ) + return ( + ranges if output_rows is None else ranges[output_rows[0] : output_rows[1]] + ) def _split_ranges( self, @@ -304,11 +339,15 @@ def _read_arrays( row: dict[str, Any] = {} for output_name, path in arrays.items(): try: - row[output_name] = ( - group[path][start:end] if start is not None else group[path][:] - ) + array = group[path] except KeyError: raise KeyError(f"Zarr array not found: {path}") from None + if start is not None: + row[output_name] = array[start:end] + elif array.shape == (): + row[output_name] = array[()] + else: + row[output_name] = array[:] return row def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any]: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index bd3df917..cb816a86 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -60,6 +60,20 @@ def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: np.testing.assert_allclose(row["action"][:2], [[0.0], [0.1]]) +def test_read_zarr_reads_scalar_arrays(tmp_path: Path) -> None: + path = tmp_path / "scalar.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset("version", data=np.asarray(3, dtype=np.int64), shape=()) + + row = mdr.read_zarr( + path, + arrays={"version": "version"}, + file_path_column=None, + ).take(1)[0] + + assert row["version"] == 3 + + def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From f271cb5440cb2420dfbee24a7b9b3a5ba807111a Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 16:20:28 +0200 Subject: [PATCH 15/39] Avoid repeated Zarr shard reads --- src/refiner/pipeline/sources/readers/zarr.py | 244 +++++++++++-------- tests/readers/test_zarr_reader.py | 6 +- 2 files changed, 149 insertions(+), 101 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 8bba68ea..1a62752b 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -158,19 +158,17 @@ def describe(self) -> dict[str, Any]: def list_shards(self) -> list[Shard]: path = self.root.abs_path() - import zarr - group = zarr.open_group( - store=zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r"), - mode="r", - ) + group = self._open_group() arrays = self._array_selection(group) _validate_output_names( arrays, self.attrs or {}, reserved=self._reserved_output_names(split=self._is_split_mode), ) - split_ranges = self._split_ranges(group, self._array_infos(group, arrays)) + split_ranges = self._planned_shard_ranges( + group, self._array_infos(group, arrays) + ) return [ Shard.from_row_range( start=start, @@ -183,12 +181,7 @@ def list_shards(self) -> list[Shard]: ] def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: - import zarr - - group = zarr.open_group( - store=zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r"), - mode="r", - ) + group = self._open_group() arrays = self._array_selection(group) _validate_output_names( arrays, @@ -196,28 +189,62 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: reserved=self._reserved_output_names(split=self._is_split_mode), ) infos = self._array_infos(group, arrays) - if self._is_split_mode: + if self.row_ends is not None: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - source_ranges = self._source_ranges( + source_ranges = self._row_end_source_ranges( group, infos, - output_rows=(descriptor.start, descriptor.end), + row_start=descriptor.start, + row_end=descriptor.end, ) + if not source_ranges: + return + block_start = source_ranges[0][0] + block_end = source_ranges[-1][1] + block = self._read_arrays(group, arrays, start=block_start, end=block_end) + attrs = self._read_attrs(group, {}) for row_index, (start, end) in zip( range(descriptor.start, descriptor.end), source_ranges, strict=True, ): row = self._row_metadata(index=row_index) - row.update(self._read_arrays(group, arrays, start=start, end=end)) - yield DictRow(self._read_attrs(group, row)) + row.update( + { + name: value[start - block_start : end - block_start] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) + return + + if self.split_leading_axis: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) + row = self._row_metadata(index=shard.global_ordinal) + row.update( + self._read_arrays( + group, + arrays, + start=descriptor.start, + end=descriptor.end, + ) + ) + yield DictRow(self._read_attrs(group, row)) return row = self._row_metadata(index=None) row.update(self._read_arrays(group, arrays)) yield DictRow(self._read_attrs(group, row)) + def _open_group(self) -> Any: + import zarr + + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + return zarr.open_group(store=store, mode="r") + @property def _is_split_mode(self) -> bool: return self.row_ends is not None or self.split_leading_axis @@ -267,51 +294,32 @@ def _array_infos( ) return infos - def _source_ranges( + def _row_end_source_ranges( self, group: Any, infos: Sequence[_ArrayInfo], *, - output_rows: tuple[int, int] | None = None, + row_start: int, + row_end: int, ) -> list[tuple[int, int]]: - if self.row_ends is not None: - try: - row_ends_array = group[self.row_ends] - except KeyError: - raise KeyError( - f"Zarr row_ends array not found: {self.row_ends}" - ) from None - if output_rows is None: - ranges = _ranges_from_ends([int(value) for value in row_ends_array[:]]) - else: - row_start, row_end = output_rows - if row_end < row_start: - raise ValueError("Zarr shard row range is invalid") - if row_start == row_end: - return [] - read_start = max(0, row_start - 1) - values = [int(value) for value in row_ends_array[read_start:row_end]] - if len(values) != row_end - read_start: - raise ValueError("Zarr shard row range exceeds row_ends length") - previous = 0 if row_start == 0 else values[0] - ranges = [] - for value in values if row_start == 0 else values[1:]: - if value < previous: - raise ValueError("Zarr row_ends must be monotonic increasing") - ranges.append((previous, value)) - previous = value - _validate_source_ranges(ranges, infos, label="row_ends") - return ranges - ranges = _leading_axis_ranges( - infos, - target_shard_bytes=self.target_shard_bytes, - num_shards=self.num_shards, - ) - return ( - ranges if output_rows is None else ranges[output_rows[0] : output_rows[1]] + if row_end < row_start: + raise ValueError("Zarr shard row range is invalid") + if row_start == row_end: + return [] + row_ends_array = self._row_ends_array(group) + read_start = max(0, row_start - 1) + values = [int(value) for value in row_ends_array[read_start:row_end]] + if len(values) != row_end - read_start: + raise ValueError("Zarr shard row range exceeds row_ends length") + ranges = ( + _ranges_from_ends(values) + if row_start == 0 + else _ranges_from_ends(values[1:], start=values[0]) ) + _validate_source_ranges(ranges, infos, label="row_ends") + return ranges - def _split_ranges( + def _planned_shard_ranges( self, group: Any, infos: Sequence[_ArrayInfo], @@ -319,14 +327,64 @@ def _split_ranges( if not self._is_split_mode: return [(0, 1)] if self.split_leading_axis: - source_ranges = self._source_ranges(group, infos) - return [(index, index + 1) for index in range(len(source_ranges))] - return _pack_output_rows( - self._source_ranges(group, infos), - infos, - target_shard_bytes=self.target_shard_bytes, - num_shards=self.num_shards, - ) + return _leading_axis_ranges( + infos, + target_shard_bytes=self.target_shard_bytes, + num_shards=self.num_shards, + ) + return self._row_end_shard_ranges(group, infos) + + def _row_ends_array(self, group: Any) -> Any: + try: + row_ends_array = group[self.row_ends] + except KeyError: + raise KeyError(f"Zarr row_ends array not found: {self.row_ends}") from None + if len(row_ends_array.shape) != 1: + raise ValueError("Zarr row_ends must be one-dimensional") + return row_ends_array + + def _row_end_shard_ranges( + self, + group: Any, + infos: Sequence[_ArrayInfo], + ) -> list[tuple[int, int]]: + row_ends_array = self._row_ends_array(group) + row_count = int(row_ends_array.shape[0]) + if row_count == 0: + return [] + if self.num_shards is not None: + _validate_row_ends(row_ends_array, infos) + step = ceil(row_count / self.num_shards) + return [ + (start, min(start + step, row_count)) + for start in range(0, row_count, step) + ] + + bytes_per_step = sum(info.bytes_per_step for info in infos) + if bytes_per_step <= 0: + _validate_row_ends(row_ends_array, infos) + return [(0, row_count)] + + ranges: list[tuple[int, int]] = [] + shard_start = 0 + current_bytes = 0 + previous_end = 0 + for row_index, end in _iter_row_ends(row_ends_array): + if end < previous_end: + raise ValueError("Zarr row_ends must be monotonic increasing") + row_bytes = max(1, end - previous_end) * bytes_per_step + if ( + row_index > shard_start + and current_bytes + row_bytes > self.target_shard_bytes + ): + ranges.append((shard_start, row_index)) + shard_start = row_index + current_bytes = 0 + current_bytes += row_bytes + previous_end = end + ranges.append((shard_start, row_count)) + _validate_source_ranges([(0, previous_end)], infos, label="row_ends") + return ranges def _read_arrays( self, @@ -374,9 +432,12 @@ def _validate_output_names( raise ValueError(f"Zarr selections use reserved output names: {names}") -def _ranges_from_ends(ends: Sequence[int]) -> list[tuple[int, int]]: +def _ranges_from_ends( + ends: Sequence[int], + *, + start: int = 0, +) -> list[tuple[int, int]]: ranges: list[tuple[int, int]] = [] - start = 0 for end in ends: if end < start: raise ValueError("Zarr row_ends must be monotonic increasing") @@ -385,6 +446,26 @@ def _ranges_from_ends(ends: Sequence[int]) -> list[tuple[int, int]]: return ranges +def _iter_row_ends(array: Any) -> Iterator[tuple[int, int]]: + chunk = max(1, int(array.chunks[0]) if array.chunks else 8192) + length = int(array.shape[0]) + for start in range(0, length, chunk): + for offset, value in enumerate(array[start : min(start + chunk, length)]): + yield start + offset, int(value) + + +def _validate_row_ends( + array: Any, + infos: Sequence[_ArrayInfo], +) -> None: + previous_end = 0 + for _, end in _iter_row_ends(array): + if end < previous_end: + raise ValueError("Zarr row_ends must be monotonic increasing") + previous_end = end + _validate_source_ranges([(0, previous_end)], infos, label="row_ends") + + def _validate_source_ranges( ranges: Sequence[tuple[int, int]], infos: Sequence[_ArrayInfo], @@ -426,39 +507,6 @@ def _leading_axis_ranges( return [(start, min(start + step, length)) for start in range(0, length, step)] -def _pack_output_rows( - source_ranges: Sequence[tuple[int, int]], - infos: Sequence[_ArrayInfo], - *, - target_shard_bytes: int, - num_shards: int | None, -) -> list[tuple[int, int]]: - if not source_ranges: - return [] - if num_shards is not None: - step = ceil(len(source_ranges) / num_shards) - return [ - (start, min(start + step, len(source_ranges))) - for start in range(0, len(source_ranges), step) - ] - - bytes_per_step = sum(info.bytes_per_step for info in infos) - if bytes_per_step <= 0: - return [(0, len(source_ranges))] - ranges: list[tuple[int, int]] = [] - start_index = 0 - current_bytes = 0 - for index, (start, end) in enumerate(source_ranges): - row_bytes = max(1, end - start) * bytes_per_step - if index > start_index and current_bytes + row_bytes > target_shard_bytes: - ranges.append((start_index, index)) - start_index = index - current_bytes = 0 - current_bytes += row_bytes - ranges.append((start_index, len(source_ranges))) - return ranges - - def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index cb816a86..53a7bc99 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -241,9 +241,9 @@ def test_read_zarr_split_leading_axis_emits_aligned_windows(tmp_path: Path) -> N shards = pipeline.source.list_shards() ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] assert [(item.start, item.end) for item in ranges] == [ - (0, 1), - (1, 2), - (2, 3), + (0, 2), + (2, 4), + (4, 5), ] rows = pipeline.take(3) From d888a361f87f20a4f40264e00a86668ef0a8b49c Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 16:26:12 +0200 Subject: [PATCH 16/39] Tighten Zarr row boundary validation --- docs/reading-and-writing.md | 9 +-- src/refiner/pipeline/sources/readers/hdf5.py | 55 ++----------------- src/refiner/pipeline/sources/readers/utils.py | 49 ++++++++++++++++- src/refiner/pipeline/sources/readers/zarr.py | 53 ++++++++++-------- tests/readers/test_zarr_reader.py | 17 +++++- 5 files changed, 106 insertions(+), 77 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index f0220204..78824159 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -289,7 +289,8 @@ replay_buffer.zarr this emits one row per `[start:end]` slice. The selected arrays are sliced along their leading dimension, while selected attrs are repeated on each row. `index_column` receives the row/episode index when `row_ends` is set. Set it to -`None` to omit that metadata. +`None` to omit that metadata. The final row end must match the leading dimension +of every selected array. If a Zarr store has aligned arrays but no episode boundaries, use `split_leading_axis=True` to emit leading-axis windows: @@ -307,9 +308,9 @@ windows = mdr.read_zarr( ``` This mode requires selected arrays to have the same leading dimension. Refiner -chooses contiguous windows from array metadata, using the byte-heavy array's -chunking to avoid unnecessary chunk splits where possible. Use `num_shards` when -you need a target shard count instead of byte-sized packing. +chooses contiguous windows from array metadata and avoids splitting below the +largest selected leading-axis chunk where possible. Use `num_shards` when you +need a target shard count instead of byte-sized packing. `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index 56583bf1..59211797 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -14,59 +14,16 @@ from refiner.pipeline.data.shard import FilePartsDescriptor from refiner.pipeline.sources.readers.base import BaseReader, Shard, SourceUnit from refiner.pipeline.sources.readers.selection import PathSelection, path_selection_map -from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + decode_value, +) from refiner.utils import check_required_dependencies MissingPolicy = Literal["error", "drop_row", "set_null"] -def _decode_value( - value: Any, - *, - decode_bytes: bool = True, - preserve_arrays: bool = False, -) -> Any: - if isinstance(value, bytes): - if not decode_bytes: - return value - try: - return value.decode("utf-8") - except UnicodeDecodeError: - return value - if isinstance(value, str) and any("\udc80" <= char <= "\udcff" for char in value): - return value.encode("utf-8", errors="surrogateescape") - if hasattr(value, "shape") and value.shape == (): - return _decode_value( - value.item(), - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - if hasattr(value, "tolist"): - if preserve_arrays and getattr( - getattr(value, "dtype", None), "kind", None - ) not in ( - "O", - "S", - ): - return value - return _decode_value( - value.tolist(), - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - if isinstance(value, list): - return [ - _decode_value( - item, - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - for item in value - ] - return value - - class Hdf5Reader(BaseReader): """HDF5 reader planned at file granularity. @@ -289,7 +246,7 @@ def _read_group( raise TypeError( f"HDF5 path under {group_path} is not a dataset: {dataset_path}" ) - row[output_name] = _decode_value( + row[output_name] = decode_value( dataset[()], decode_bytes=dataset.dtype.kind != "S", preserve_arrays=True, @@ -303,7 +260,7 @@ def _read_group( row[output_name] = None continue raise KeyError(f"HDF5 attr not found on {group_path}: {attr_name}") - row[output_name] = _decode_value(group.attrs[attr_name]) + row[output_name] = decode_value(group.attrs[attr_name]) return self._with_file_path(row, source) diff --git a/src/refiner/pipeline/sources/readers/utils.py b/src/refiner/pipeline/sources/readers/utils.py index 3860a9e5..2bedd54e 100644 --- a/src/refiner/pipeline/sources/readers/utils.py +++ b/src/refiner/pipeline/sources/readers/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import io -from typing import Optional +from typing import Any, Optional from fsspec import AbstractFileSystem @@ -12,6 +12,52 @@ NON_SPLITTABLE_WHOLEFILE_EXTS = (".gz", ".bz2", ".xz", ".zip", ".zst") +def decode_value( + value: Any, + *, + decode_bytes: bool = True, + preserve_arrays: bool = False, +) -> Any: + if isinstance(value, bytes): + if not decode_bytes: + return value + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return value + if isinstance(value, str) and any("\udc80" <= char <= "\udcff" for char in value): + return value.encode("utf-8", errors="surrogateescape") + if hasattr(value, "shape") and value.shape == (): + return decode_value( + value.item(), + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + if hasattr(value, "tolist"): + if preserve_arrays and getattr( + getattr(value, "dtype", None), "kind", None + ) not in ( + "O", + "S", + ): + return value + return decode_value( + value.tolist(), + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + if isinstance(value, list): + return [ + decode_value( + item, + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + for item in value + ] + return value + + def is_splittable_by_bytes(fs: AbstractFileSystem, path: str) -> bool: """Return True if the input can be safely sharded by byte offsets.""" lp = path.lower() @@ -98,6 +144,7 @@ def readinto(self, b) -> int: __all__ = [ "DEFAULT_TARGET_SHARD_BYTES", "NON_SPLITTABLE_WHOLEFILE_EXTS", + "decode_value", "is_splittable_by_bytes", "align_byte_range_to_newlines", "BoundedBinaryReader", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 1a62752b..ac10a11e 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -20,7 +20,10 @@ PathSelection, path_selection_map, ) -from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + decode_value, +) from refiner.utils import check_required_dependencies @@ -50,17 +53,6 @@ def bytes_per_step(self) -> int: return max(1, int(self.dtype.itemsize) * int(prod(trailing_shape or (1,)))) -def _decode_value(value: Any) -> Any: - if hasattr(value, "shape") and value.shape == (): - return _decode_value(value.item()) - if isinstance(value, bytes): - try: - return value.decode("utf-8") - except UnicodeDecodeError: - return value - return value - - class ZarrReader(BaseSource): """Read a Zarr group as one row, episode rows, or leading-axis windows.""" @@ -203,7 +195,7 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: block_start = source_ranges[0][0] block_end = source_ranges[-1][1] block = self._read_arrays(group, arrays, start=block_start, end=block_end) - attrs = self._read_attrs(group, {}) + attrs = self._read_attrs(group) for row_index, (start, end) in zip( range(descriptor.start, descriptor.end), source_ranges, @@ -232,12 +224,14 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: end=descriptor.end, ) ) - yield DictRow(self._read_attrs(group, row)) + row.update(self._read_attrs(group)) + yield DictRow(row) return row = self._row_metadata(index=None) row.update(self._read_arrays(group, arrays)) - yield DictRow(self._read_attrs(group, row)) + row.update(self._read_attrs(group)) + yield DictRow(row) def _open_group(self) -> Any: import zarr @@ -383,7 +377,12 @@ def _row_end_shard_ranges( current_bytes += row_bytes previous_end = end ranges.append((shard_start, row_count)) - _validate_source_ranges([(0, previous_end)], infos, label="row_ends") + _validate_source_ranges( + [(0, previous_end)], + infos, + label="row_ends", + require_exact=True, + ) return ranges def _read_arrays( @@ -408,12 +407,13 @@ def _read_arrays( row[output_name] = array[:] return row - def _read_attrs(self, group: Any, row: dict[str, Any]) -> dict[str, Any]: + def _read_attrs(self, group: Any) -> dict[str, Any]: + attrs: dict[str, Any] = {} for output_name, attr_name in (self.attrs or {}).items(): if attr_name not in group.attrs: raise KeyError(f"Zarr attr not found: {attr_name}") - row[output_name] = _decode_value(group.attrs[attr_name]) - return row + attrs[output_name] = decode_value(group.attrs[attr_name]) + return attrs def _validate_output_names( @@ -463,7 +463,12 @@ def _validate_row_ends( if end < previous_end: raise ValueError("Zarr row_ends must be monotonic increasing") previous_end = end - _validate_source_ranges([(0, previous_end)], infos, label="row_ends") + _validate_source_ranges( + [(0, previous_end)], + infos, + label="row_ends", + require_exact=True, + ) def _validate_source_ranges( @@ -471,6 +476,7 @@ def _validate_source_ranges( infos: Sequence[_ArrayInfo], *, label: str, + require_exact: bool = False, ) -> None: if not infos: return @@ -480,6 +486,10 @@ def _validate_source_ranges( raise ValueError( f"Zarr {label} exceed leading dimension for {info.output_name!r}" ) + if require_exact and final_end != info.leading_length: + raise ValueError( + f"Zarr {label} end before leading dimension for {info.output_name!r}" + ) def _leading_axis_ranges( @@ -501,8 +511,7 @@ def _leading_axis_ranges( else: bytes_per_step = sum(info.bytes_per_step for info in infos) target_steps = max(1, target_shard_bytes // max(1, bytes_per_step)) - heavy = max(infos, key=lambda info: info.bytes_per_step) - base = max(1, heavy.leading_chunk) + base = max(1, max(info.leading_chunk for info in infos)) step = max(base, (target_steps // base) * base) return [(start, min(start + step, length)) for start in range(0, length, step)] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 53a7bc99..493d0069 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -222,7 +222,7 @@ def test_read_zarr_split_leading_axis_emits_aligned_windows(tmp_path: Path) -> N root.create_dataset( "data/action", data=np.arange(5, dtype=np.float32).reshape(5, 1), - chunks=(5, 1), + chunks=(1, 1), ) root.create_dataset( "data/rgb", @@ -298,6 +298,21 @@ def test_read_zarr_rejects_out_of_range_row_ends(tmp_path: Path) -> None: ).take(2) +def test_read_zarr_rejects_short_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = zarr.open_group(str(path), mode="a") + root["meta/episode_ends"][:] = np.asarray([2, 4], dtype=np.int64) + + with pytest.raises(ValueError, match="end before leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" From ae02651a3aec296a791d2cf50e1f6bb84f4dba1d Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:07:23 +0200 Subject: [PATCH 17/39] Emit leading-axis Zarr rows per index --- docs/reading-and-writing.md | 8 +++--- src/refiner/pipeline/sources/readers/zarr.py | 26 ++++++++++++-------- tests/readers/test_zarr_reader.py | 9 ++++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 78824159..f4b7a542 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -293,7 +293,7 @@ their leading dimension, while selected attrs are repeated on each row. of every selected array. If a Zarr store has aligned arrays but no episode boundaries, use -`split_leading_axis=True` to emit leading-axis windows: +`split_leading_axis=True` to emit one row per leading-axis item: ```python windows = mdr.read_zarr( @@ -308,9 +308,9 @@ windows = mdr.read_zarr( ``` This mode requires selected arrays to have the same leading dimension. Refiner -chooses contiguous windows from array metadata and avoids splitting below the -largest selected leading-axis chunk where possible. Use `num_shards` when you -need a target shard count instead of byte-sized packing. +plans shards from array metadata and avoids splitting shards below the largest +selected leading-axis chunk where possible. Use `num_shards` when you need a +target shard count instead of byte-sized packing. `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index ac10a11e..79f519fd 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -215,17 +215,23 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: if self.split_leading_axis: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - row = self._row_metadata(index=shard.global_ordinal) - row.update( - self._read_arrays( - group, - arrays, - start=descriptor.start, - end=descriptor.end, - ) + block = self._read_arrays( + group, + arrays, + start=descriptor.start, + end=descriptor.end, ) - row.update(self._read_attrs(group)) - yield DictRow(row) + attrs = self._read_attrs(group) + for row_index in range(descriptor.start, descriptor.end): + row = self._row_metadata(index=row_index) + row.update( + { + name: value[row_index - descriptor.start] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) return row = self._row_metadata(index=None) diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 493d0069..e41bdf0c 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -216,8 +216,8 @@ def test_read_zarr_rejects_missing_selected_attrs(tmp_path: Path) -> None: ).take(1) -def test_read_zarr_split_leading_axis_emits_aligned_windows(tmp_path: Path) -> None: - path = tmp_path / "windows.zarr" +def test_read_zarr_split_leading_axis_emits_one_row_per_index(tmp_path: Path) -> None: + path = tmp_path / "leading_axis.zarr" root = zarr.open_group(str(path), mode="w") root.create_dataset( "data/action", @@ -249,8 +249,9 @@ def test_read_zarr_split_leading_axis_emits_aligned_windows(tmp_path: Path) -> N rows = pipeline.take(3) assert [row["index"] for row in rows] == [0, 1, 2] - assert [len(row["action"]) for row in rows] == [2, 2, 1] - np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) + assert [row["action"].shape for row in rows] == [(1,), (1,), (1,)] + assert [row["rgb"].shape for row in rows] == [(4, 4, 3)] * 3 + np.testing.assert_allclose(rows[1]["action"], [1.0]) def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: From 77018f391c46a521c8fdc4e91168771716858630 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:14:12 +0200 Subject: [PATCH 18/39] Add Zarr leading axis row size --- docs/reading-and-writing.md | 15 +++--- src/refiner/pipeline/pipeline.py | 7 ++- src/refiner/pipeline/sources/readers/zarr.py | 54 +++++++++++++------ tests/readers/test_zarr_reader.py | 55 ++++++++++++++++++-- 4 files changed, 104 insertions(+), 27 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index f4b7a542..77e3ca79 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -293,24 +293,27 @@ their leading dimension, while selected attrs are repeated on each row. of every selected array. If a Zarr store has aligned arrays but no episode boundaries, use -`split_leading_axis=True` to emit one row per leading-axis item: +`split_leading_axis=True` to emit fixed-size rows along the leading axis: ```python -windows = mdr.read_zarr( +rows = mdr.read_zarr( "replay_buffer.zarr", arrays={ "action": "data/action", "frames": "data/rgb", }, split_leading_axis=True, + leading_axis_row_size=1, target_shard_bytes=128 * 1024**2, ) ``` -This mode requires selected arrays to have the same leading dimension. Refiner -plans shards from array metadata and avoids splitting shards below the largest -selected leading-axis chunk where possible. Use `num_shards` when you need a -target shard count instead of byte-sized packing. +This mode requires selected arrays to have the same leading dimension, and that +dimension must be divisible by `leading_axis_row_size`. Each output row contains +`leading_axis_row_size` contiguous items from every selected array. Refiner plans +shards from array metadata and avoids splitting shards below the largest selected +leading-axis chunk where possible. Use `num_shards` when you need a target shard +count instead of byte-sized packing. `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 6005943b..5bf6bd90 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -814,6 +814,7 @@ def read_zarr( attrs: PathSelection | None = None, row_ends: str | None = None, split_leading_axis: bool = False, + leading_axis_row_size: int = 1, target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, num_shards: int | None = None, index_column: str | None = "index", @@ -825,10 +826,11 @@ def read_zarr( The reader has three modes: - group mode: one Zarr group becomes one row - row_ends mode: cumulative offsets define whole-row source slices - - split_leading_axis mode: aligned axis-0 windows define output rows + - split_leading_axis mode: fixed-size leading-axis slices define output rows Missing selected arrays or attributes raise immediately. `row_ends` and - `split_leading_axis` are mutually exclusive. + `split_leading_axis` are mutually exclusive. `target_shard_bytes` and + `num_shards` affect shard planning, not logical row size. """ return RefinerPipeline( source=ZarrReader( @@ -837,6 +839,7 @@ def read_zarr( attrs=attrs, row_ends=row_ends, split_leading_axis=split_leading_axis, + leading_axis_row_size=leading_axis_row_size, target_shard_bytes=target_shard_bytes, num_shards=num_shards, index_column=index_column, diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 79f519fd..d9b4e1b8 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -54,7 +54,7 @@ def bytes_per_step(self) -> int: class ZarrReader(BaseSource): - """Read a Zarr group as one row, episode rows, or leading-axis windows.""" + """Read a Zarr group as one row, episode rows, or leading-axis rows.""" name = "read_zarr" @@ -66,6 +66,7 @@ def __init__( attrs: PathSelection | None = None, row_ends: str | None = None, split_leading_axis: bool = False, + leading_axis_row_size: int = 1, target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, num_shards: int | None = None, index_column: str | None = "index", @@ -82,12 +83,14 @@ def __init__( row_ends: Optional Zarr array path containing cumulative end offsets. When set, emitted rows are whole source ranges and never split across these boundaries. - split_leading_axis: Emit aligned leading-axis windows when no + split_leading_axis: Emit fixed-size leading-axis rows when no ``row_ends`` path is provided. - target_shard_bytes: Approximate byte target used to pack output rows - into shards in split modes. + leading_axis_row_size: Number of leading-axis items in each logical + row when ``split_leading_axis`` is enabled. + target_shard_bytes: Approximate byte target used to pack logical + rows into shards in split modes. num_shards: Optional target shard count for split modes. - index_column: Output metadata column containing the episode/window + index_column: Output metadata column containing the logical row index in split modes, or None to omit it. file_path_column: Output metadata column containing the source path, or None to omit it. @@ -97,6 +100,10 @@ def __init__( check_required_dependencies("read_zarr", ["zarr"], dist="zarr") if row_ends is not None and split_leading_axis: raise ValueError("row_ends and split_leading_axis are mutually exclusive") + if leading_axis_row_size <= 0: + raise ValueError("leading_axis_row_size must be greater than zero") + if leading_axis_row_size != 1 and not split_leading_axis: + raise ValueError("leading_axis_row_size requires split_leading_axis=True") if target_shard_bytes <= 0: raise ValueError("target_shard_bytes must be greater than zero") if num_shards is not None and num_shards <= 0: @@ -109,6 +116,7 @@ def __init__( ) self.row_ends = row_ends self.split_leading_axis = split_leading_axis + self.leading_axis_row_size = leading_axis_row_size self.target_shard_bytes = target_shard_bytes self.num_shards = num_shards self.index_column = index_column @@ -137,6 +145,7 @@ def describe(self) -> dict[str, Any]: "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, "split_leading_axis": self.split_leading_axis, + "leading_axis_row_size": self.leading_axis_row_size, "target_shard_bytes": self.target_shard_bytes, "num_shards": self.num_shards, "index_column": self.index_column, @@ -215,18 +224,21 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: if self.split_leading_axis: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) + raw_start = descriptor.start * self.leading_axis_row_size + raw_end = descriptor.end * self.leading_axis_row_size block = self._read_arrays( group, arrays, - start=descriptor.start, - end=descriptor.end, + start=raw_start, + end=raw_end, ) attrs = self._read_attrs(group) for row_index in range(descriptor.start, descriptor.end): + offset = (row_index - descriptor.start) * self.leading_axis_row_size row = self._row_metadata(index=row_index) row.update( { - name: value[row_index - descriptor.start] + name: value[offset : offset + self.leading_axis_row_size] for name, value in block.items() } ) @@ -327,8 +339,9 @@ def _planned_shard_ranges( if not self._is_split_mode: return [(0, 1)] if self.split_leading_axis: - return _leading_axis_ranges( + return _leading_axis_shard_ranges( infos, + row_size=self.leading_axis_row_size, target_shard_bytes=self.target_shard_bytes, num_shards=self.num_shards, ) @@ -498,9 +511,10 @@ def _validate_source_ranges( ) -def _leading_axis_ranges( +def _leading_axis_shard_ranges( infos: Sequence[_ArrayInfo], *, + row_size: int, target_shard_bytes: int, num_shards: int | None, ) -> list[tuple[int, int]]: @@ -512,14 +526,22 @@ def _leading_axis_ranges( length = lengths.pop() if length == 0: return [] + if length % row_size != 0: + raise ValueError("Zarr leading dimension must be divisible by row size") + row_count = length // row_size if num_shards is not None: - step = ceil(length / num_shards) + step = ceil(row_count / num_shards) else: - bytes_per_step = sum(info.bytes_per_step for info in infos) - target_steps = max(1, target_shard_bytes // max(1, bytes_per_step)) - base = max(1, max(info.leading_chunk for info in infos)) - step = max(base, (target_steps // base) * base) - return [(start, min(start + step, length)) for start in range(0, length, step)] + bytes_per_row = sum(info.bytes_per_step for info in infos) * row_size + target_rows = max(1, target_shard_bytes // max(1, bytes_per_row)) + chunk_rows = max( + 1, + ceil(max(info.leading_chunk for info in infos) / row_size), + ) + step = max(chunk_rows, (target_rows // chunk_rows) * chunk_rows) + return [ + (start, min(start + step, row_count)) for start in range(0, row_count, step) + ] def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index e41bdf0c..3c41fae8 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -249,9 +249,31 @@ def test_read_zarr_split_leading_axis_emits_one_row_per_index(tmp_path: Path) -> rows = pipeline.take(3) assert [row["index"] for row in rows] == [0, 1, 2] - assert [row["action"].shape for row in rows] == [(1,), (1,), (1,)] - assert [row["rgb"].shape for row in rows] == [(4, 4, 3)] * 3 - np.testing.assert_allclose(rows[1]["action"], [1.0]) + assert [row["action"].shape for row in rows] == [(1, 1), (1, 1), (1, 1)] + assert [row["rgb"].shape for row in rows] == [(1, 4, 4, 3)] * 3 + np.testing.assert_allclose(rows[1]["action"], [[1.0]]) + + +def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: + path = tmp_path / "leading_axis_rows.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset( + "data/action", + data=np.arange(6, dtype=np.float32).reshape(6, 1), + chunks=(2, 1), + ) + + rows = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + leading_axis_row_size=2, + file_path_column=None, + ).take(3) + + assert [row["index"] for row in rows] == [0, 1, 2] + assert [row["action"].shape for row in rows] == [(2, 1), (2, 1), (2, 1)] + np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: @@ -269,6 +291,33 @@ def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) - ).take(1) +def test_read_zarr_split_leading_axis_requires_full_rows(tmp_path: Path) -> None: + path = tmp_path / "partial-leading-axis-row.zarr" + root = zarr.open_group(str(path), mode="w") + root.create_dataset("data/action", data=np.zeros((5, 1), dtype=np.float32)) + + with pytest.raises(ValueError, match="divisible by row size"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + leading_axis_row_size=2, + file_path_column=None, + ).take(1) + + +def test_read_zarr_leading_axis_row_size_requires_split_mode(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="requires split_leading_axis"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + leading_axis_row_size=2, + ) + + def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From 8953a2c631db6a5d442b317d2c41effe12235e4e Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:24:12 +0200 Subject: [PATCH 19/39] Move reader path selection helpers into utils --- src/refiner/pipeline/pipeline.py | 6 ++-- src/refiner/pipeline/sources/readers/hdf5.py | 3 +- .../pipeline/sources/readers/selection.py | 33 ------------------- src/refiner/pipeline/sources/readers/utils.py | 28 ++++++++++++++++ src/refiner/pipeline/sources/readers/zarr.py | 6 ++-- 5 files changed, 36 insertions(+), 40 deletions(-) delete mode 100644 src/refiner/pipeline/sources/readers/selection.py diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 5bf6bd90..0ddd74ad 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -44,7 +44,6 @@ ) from refiner.pipeline.sources.readers.hdf5 import MissingPolicy from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader -from refiner.pipeline.sources.readers.selection import PathSelection from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -61,7 +60,10 @@ ) from refiner.execution.operators.row import ShardDeltaFn from refiner.pipeline.sources.base import SourceUnit -from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + PathSelection, +) import pyarrow as pa if TYPE_CHECKING: diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index 59211797..c22c1db4 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -13,10 +13,11 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import FilePartsDescriptor from refiner.pipeline.sources.readers.base import BaseReader, Shard, SourceUnit -from refiner.pipeline.sources.readers.selection import PathSelection, path_selection_map from refiner.pipeline.sources.readers.utils import ( DEFAULT_TARGET_SHARD_BYTES, + PathSelection, decode_value, + path_selection_map, ) from refiner.utils import check_required_dependencies diff --git a/src/refiner/pipeline/sources/readers/selection.py b/src/refiner/pipeline/sources/readers/selection.py deleted file mode 100644 index 283d8cca..00000000 --- a/src/refiner/pipeline/sources/readers/selection.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import cast - - -PathSelection = Mapping[str, str] | Sequence[str] | str - - -def path_selection_map( - value: PathSelection | None, - *, - format_name: str, -) -> dict[str, str]: - if value is None: - return {} - if isinstance(value, str): - return {value.rsplit("/", 1)[-1]: value} - if isinstance(value, Mapping): - return dict(cast(Mapping[str, str], value)) - out: dict[str, str] = {} - for path in value: - name = path.rsplit("/", 1)[-1] - if name in out: - raise ValueError( - f"{format_name} path selections must have unique derived column names; " - f"use an explicit mapping for duplicate name {name!r}" - ) - out[name] = path - return out - - -__all__ = ["PathSelection", "path_selection_map"] diff --git a/src/refiner/pipeline/sources/readers/utils.py b/src/refiner/pipeline/sources/readers/utils.py index 2bedd54e..9307082c 100644 --- a/src/refiner/pipeline/sources/readers/utils.py +++ b/src/refiner/pipeline/sources/readers/utils.py @@ -1,11 +1,14 @@ from __future__ import annotations import io +from collections.abc import Mapping, Sequence from typing import Any, Optional +from typing import cast from fsspec import AbstractFileSystem DEFAULT_TARGET_SHARD_BYTES = 128 * 1024 * 1024 +PathSelection = Mapping[str, str] | Sequence[str] | str # Extensions that generally imply whole-file/container compression (not safely splittable by byte offsets). @@ -58,6 +61,29 @@ def decode_value( return value +def path_selection_map( + value: PathSelection | None, + *, + format_name: str, +) -> dict[str, str]: + if value is None: + return {} + if isinstance(value, str): + return {value.rsplit("/", 1)[-1]: value} + if isinstance(value, Mapping): + return dict(cast(Mapping[str, str], value)) + out: dict[str, str] = {} + for path in value: + name = path.rsplit("/", 1)[-1] + if name in out: + raise ValueError( + f"{format_name} path selections must have unique derived column names; " + f"use an explicit mapping for duplicate name {name!r}" + ) + out[name] = path + return out + + def is_splittable_by_bytes(fs: AbstractFileSystem, path: str) -> bool: """Return True if the input can be safely sharded by byte offsets.""" lp = path.lower() @@ -144,7 +170,9 @@ def readinto(self, b) -> int: __all__ = [ "DEFAULT_TARGET_SHARD_BYTES", "NON_SPLITTABLE_WHOLEFILE_EXTS", + "PathSelection", "decode_value", + "path_selection_map", "is_splittable_by_bytes", "align_byte_range_to_newlines", "BoundedBinaryReader", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index d9b4e1b8..bdd4ef72 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -16,13 +16,11 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import RowRangeDescriptor, Shard from refiner.pipeline.sources.base import BaseSource, SourceUnit -from refiner.pipeline.sources.readers.selection import ( - PathSelection, - path_selection_map, -) from refiner.pipeline.sources.readers.utils import ( DEFAULT_TARGET_SHARD_BYTES, + PathSelection, decode_value, + path_selection_map, ) from refiner.utils import check_required_dependencies From 464fc6801e64562b075c21367f279e2ad63582f3 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:31:58 +0200 Subject: [PATCH 20/39] Avoid duplicate Zarr read validation --- src/refiner/pipeline/sources/readers/zarr.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index bdd4ef72..c97a4872 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -182,11 +182,6 @@ def list_shards(self) -> list[Shard]: def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: group = self._open_group() arrays = self._array_selection(group) - _validate_output_names( - arrays, - self.attrs or {}, - reserved=self._reserved_output_names(split=self._is_split_mode), - ) infos = self._array_infos(group, arrays) if self.row_ends is not None: descriptor = shard.descriptor From 935e720aa761a998f63e08bf5a3e2efda3e478be Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:37:00 +0200 Subject: [PATCH 21/39] Simplify Zarr shard planning --- src/refiner/pipeline/sources/readers/zarr.py | 300 ++++++++----------- 1 file changed, 119 insertions(+), 181 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index c97a4872..7f094354 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -1,7 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator, Mapping, Sequence -from dataclasses import dataclass +from collections.abc import Iterator, Mapping from math import ceil, prod from typing import Any @@ -25,32 +24,6 @@ from refiner.utils import check_required_dependencies -@dataclass(frozen=True, slots=True) -class _ArrayInfo: - output_name: str - path: str - shape: tuple[int, ...] - chunks: tuple[int, ...] - dtype: Any - - @property - def leading_length(self) -> int: - if not self.shape: - raise ValueError( - f"Zarr array {self.path!r} must have a leading dimension to split" - ) - return int(self.shape[0]) - - @property - def leading_chunk(self) -> int: - return int(self.chunks[0]) if self.chunks else self.leading_length - - @property - def bytes_per_step(self) -> int: - trailing_shape = self.shape[1:] - return max(1, int(self.dtype.itemsize) * int(prod(trailing_shape or (1,)))) - - class ZarrReader(BaseSource): """Read a Zarr group as one row, episode rows, or leading-axis rows.""" @@ -159,15 +132,8 @@ def list_shards(self) -> list[Shard]: path = self.root.abs_path() group = self._open_group() - arrays = self._array_selection(group) - _validate_output_names( - arrays, - self.attrs or {}, - reserved=self._reserved_output_names(split=self._is_split_mode), - ) - split_ranges = self._planned_shard_ranges( - group, self._array_infos(group, arrays) - ) + arrays = self._selected_arrays(group, validate_names=True) + split_ranges = self._shard_ranges(group, arrays) return [ Shard.from_row_range( start=start, @@ -181,14 +147,13 @@ def list_shards(self) -> list[Shard]: def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: group = self._open_group() - arrays = self._array_selection(group) - infos = self._array_infos(group, arrays) + arrays = self._selected_arrays(group) if self.row_ends is not None: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - source_ranges = self._row_end_source_ranges( + source_ranges = self._row_end_ranges( group, - infos, + arrays, row_start=descriptor.start, row_end=descriptor.end, ) @@ -196,7 +161,7 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: return block_start = source_ranges[0][0] block_end = source_ranges[-1][1] - block = self._read_arrays(group, arrays, start=block_start, end=block_end) + block = self._read_arrays(arrays, start=block_start, end=block_end) attrs = self._read_attrs(group) for row_index, (start, end) in zip( range(descriptor.start, descriptor.end), @@ -220,7 +185,6 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: raw_start = descriptor.start * self.leading_axis_row_size raw_end = descriptor.end * self.leading_axis_row_size block = self._read_arrays( - group, arrays, start=raw_start, end=raw_end, @@ -240,7 +204,7 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: return row = self._row_metadata(index=None) - row.update(self._read_arrays(group, arrays)) + row.update(self._read_arrays(arrays)) row.update(self._read_attrs(group)) yield DictRow(row) @@ -270,39 +234,37 @@ def _row_metadata(self, *, index: int | None) -> dict[str, Any]: row[self.index_column] = index return row - def _array_selection(self, group: Any) -> dict[str, str]: - if self.arrays is not None: - return self.arrays - return { - path: path for path in _iter_array_paths(group) if path != self.row_ends - } - - def _array_infos( + def _selected_arrays( self, group: Any, - arrays: Mapping[str, str], - ) -> list[_ArrayInfo]: - infos: list[_ArrayInfo] = [] - for output_name, path in arrays.items(): + *, + validate_names: bool = False, + ) -> dict[str, Any]: + paths = ( + self.arrays + if self.arrays is not None + else { + path: path for path in _iter_array_paths(group) if path != self.row_ends + } + ) + if validate_names: + _validate_output_names( + paths, + self.attrs or {}, + reserved=self._reserved_output_names(split=self._is_split_mode), + ) + arrays: dict[str, Any] = {} + for output_name, path in paths.items(): try: - array = group[path] + arrays[output_name] = group[path] except KeyError: raise KeyError(f"Zarr array not found: {path}") from None - infos.append( - _ArrayInfo( - output_name=output_name, - path=path, - shape=tuple(int(value) for value in array.shape), - chunks=tuple(int(value) for value in array.chunks), - dtype=array.dtype, - ) - ) - return infos + return arrays - def _row_end_source_ranges( + def _row_end_ranges( self, group: Any, - infos: Sequence[_ArrayInfo], + arrays: Mapping[str, Any], *, row_start: int, row_end: int, @@ -316,59 +278,84 @@ def _row_end_source_ranges( values = [int(value) for value in row_ends_array[read_start:row_end]] if len(values) != row_end - read_start: raise ValueError("Zarr shard row range exceeds row_ends length") - ranges = ( - _ranges_from_ends(values) - if row_start == 0 - else _ranges_from_ends(values[1:], start=values[0]) - ) - _validate_source_ranges(ranges, infos, label="row_ends") + ranges: list[tuple[int, int]] = [] + start = 0 if row_start == 0 else values[0] + for end in values if row_start == 0 else values[1:]: + if end < start: + raise ValueError("Zarr row_ends must be monotonic increasing") + ranges.append((start, end)) + start = end + _check_final_end(arrays, ranges[-1][1], label="row_ends") return ranges - def _planned_shard_ranges( + def _shard_ranges( self, group: Any, - infos: Sequence[_ArrayInfo], + arrays: Mapping[str, Any], ) -> list[tuple[int, int]]: if not self._is_split_mode: return [(0, 1)] - if self.split_leading_axis: - return _leading_axis_shard_ranges( - infos, - row_size=self.leading_axis_row_size, - target_shard_bytes=self.target_shard_bytes, - num_shards=self.num_shards, - ) - return self._row_end_shard_ranges(group, infos) - def _row_ends_array(self, group: Any) -> Any: - try: - row_ends_array = group[self.row_ends] - except KeyError: - raise KeyError(f"Zarr row_ends array not found: {self.row_ends}") from None - if len(row_ends_array.shape) != 1: - raise ValueError("Zarr row_ends must be one-dimensional") - return row_ends_array + if self.split_leading_axis: + if not arrays: + raise ValueError( + "split_leading_axis requires at least one selected array" + ) + lengths: set[int] = set() + for array in arrays.values(): + if not array.shape: + raise ValueError( + "Zarr selected arrays must have a leading dimension to split" + ) + lengths.add(int(array.shape[0])) + if len(lengths) != 1: + raise ValueError( + "Zarr selected arrays must have the same leading dimension" + ) + length = lengths.pop() + if length == 0: + return [] + if length % self.leading_axis_row_size != 0: + raise ValueError("Zarr leading dimension must be divisible by row size") + row_count = length // self.leading_axis_row_size + if self.num_shards is not None: + step = ceil(row_count / self.num_shards) + else: + bytes_per_row = ( + sum(_leading_item_bytes(array) for array in arrays.values()) + * self.leading_axis_row_size + ) + target_rows = max(1, self.target_shard_bytes // max(1, bytes_per_row)) + chunk_rows = max( + 1, + ceil( + max(_leading_chunk(array) for array in arrays.values()) + / self.leading_axis_row_size + ), + ) + step = max(chunk_rows, (target_rows // chunk_rows) * chunk_rows) + return [ + (start, min(start + step, row_count)) + for start in range(0, row_count, step) + ] - def _row_end_shard_ranges( - self, - group: Any, - infos: Sequence[_ArrayInfo], - ) -> list[tuple[int, int]]: row_ends_array = self._row_ends_array(group) row_count = int(row_ends_array.shape[0]) if row_count == 0: return [] if self.num_shards is not None: - _validate_row_ends(row_ends_array, infos) + final_end = _validate_row_ends(row_ends_array) + _check_final_end(arrays, final_end, label="row_ends", exact=True) step = ceil(row_count / self.num_shards) return [ (start, min(start + step, row_count)) for start in range(0, row_count, step) ] - bytes_per_step = sum(info.bytes_per_step for info in infos) + bytes_per_step = sum(_leading_item_bytes(array) for array in arrays.values()) if bytes_per_step <= 0: - _validate_row_ends(row_ends_array, infos) + final_end = _validate_row_ends(row_ends_array) + _check_final_end(arrays, final_end, label="row_ends", exact=True) return [(0, row_count)] ranges: list[tuple[int, int]] = [] @@ -389,28 +376,27 @@ def _row_end_shard_ranges( current_bytes += row_bytes previous_end = end ranges.append((shard_start, row_count)) - _validate_source_ranges( - [(0, previous_end)], - infos, - label="row_ends", - require_exact=True, - ) + _check_final_end(arrays, previous_end, label="row_ends", exact=True) return ranges + def _row_ends_array(self, group: Any) -> Any: + try: + row_ends_array = group[self.row_ends] + except KeyError: + raise KeyError(f"Zarr row_ends array not found: {self.row_ends}") from None + if len(row_ends_array.shape) != 1: + raise ValueError("Zarr row_ends must be one-dimensional") + return row_ends_array + def _read_arrays( self, - group: Any, - arrays: Mapping[str, str], + arrays: Mapping[str, Any], *, start: int | None = None, end: int | None = None, ) -> dict[str, Any]: row: dict[str, Any] = {} - for output_name, path in arrays.items(): - try: - array = group[path] - except KeyError: - raise KeyError(f"Zarr array not found: {path}") from None + for output_name, array in arrays.items(): if start is not None: row[output_name] = array[start:end] elif array.shape == (): @@ -444,20 +430,6 @@ def _validate_output_names( raise ValueError(f"Zarr selections use reserved output names: {names}") -def _ranges_from_ends( - ends: Sequence[int], - *, - start: int = 0, -) -> list[tuple[int, int]]: - ranges: list[tuple[int, int]] = [] - for end in ends: - if end < start: - raise ValueError("Zarr row_ends must be monotonic increasing") - ranges.append((start, end)) - start = end - return ranges - - def _iter_row_ends(array: Any) -> Iterator[tuple[int, int]]: chunk = max(1, int(array.chunks[0]) if array.chunks else 8192) length = int(array.shape[0]) @@ -466,75 +438,41 @@ def _iter_row_ends(array: Any) -> Iterator[tuple[int, int]]: yield start + offset, int(value) -def _validate_row_ends( - array: Any, - infos: Sequence[_ArrayInfo], -) -> None: +def _validate_row_ends(array: Any) -> int: previous_end = 0 for _, end in _iter_row_ends(array): if end < previous_end: raise ValueError("Zarr row_ends must be monotonic increasing") previous_end = end - _validate_source_ranges( - [(0, previous_end)], - infos, - label="row_ends", - require_exact=True, - ) + return previous_end -def _validate_source_ranges( - ranges: Sequence[tuple[int, int]], - infos: Sequence[_ArrayInfo], +def _check_final_end( + arrays: Mapping[str, Any], + final_end: int, *, label: str, - require_exact: bool = False, + exact: bool = False, ) -> None: - if not infos: - return - final_end = ranges[-1][1] if ranges else 0 - for info in infos: - if final_end > info.leading_length: + for output_name, array in arrays.items(): + leading_length = int(array.shape[0]) + if final_end > leading_length: raise ValueError( - f"Zarr {label} exceed leading dimension for {info.output_name!r}" + f"Zarr {label} exceed leading dimension for {output_name!r}" ) - if require_exact and final_end != info.leading_length: + if exact and final_end != leading_length: raise ValueError( - f"Zarr {label} end before leading dimension for {info.output_name!r}" + f"Zarr {label} end before leading dimension for {output_name!r}" ) -def _leading_axis_shard_ranges( - infos: Sequence[_ArrayInfo], - *, - row_size: int, - target_shard_bytes: int, - num_shards: int | None, -) -> list[tuple[int, int]]: - if not infos: - raise ValueError("split_leading_axis requires at least one selected array") - lengths = {info.leading_length for info in infos} - if len(lengths) != 1: - raise ValueError("Zarr selected arrays must have the same leading dimension") - length = lengths.pop() - if length == 0: - return [] - if length % row_size != 0: - raise ValueError("Zarr leading dimension must be divisible by row size") - row_count = length // row_size - if num_shards is not None: - step = ceil(row_count / num_shards) - else: - bytes_per_row = sum(info.bytes_per_step for info in infos) * row_size - target_rows = max(1, target_shard_bytes // max(1, bytes_per_row)) - chunk_rows = max( - 1, - ceil(max(info.leading_chunk for info in infos) / row_size), - ) - step = max(chunk_rows, (target_rows // chunk_rows) * chunk_rows) - return [ - (start, min(start + step, row_count)) for start in range(0, row_count, step) - ] +def _leading_item_bytes(array: Any) -> int: + trailing_shape = tuple(int(value) for value in array.shape[1:]) + return max(1, int(array.dtype.itemsize) * int(prod(trailing_shape or (1,)))) + + +def _leading_chunk(array: Any) -> int: + return int(array.chunks[0]) if array.chunks else int(array.shape[0]) def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: From 772d2de37eb6e053c5a8053b569f4270f30ad9ed Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 17:40:02 +0200 Subject: [PATCH 22/39] Inline tiny Zarr helpers --- src/refiner/pipeline/sources/readers/zarr.py | 27 ++++++++++---------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 7f094354..c9cd73be 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -96,10 +96,12 @@ def __init__( _validate_output_names( self.arrays or {}, self.attrs or {}, - reserved=self._reserved_output_names(split=self._is_split_mode), + reserved=self._reserved_output_names( + split=row_ends is not None or split_leading_axis + ), ) if ( - self._is_split_mode + (row_ends is not None or split_leading_axis) and file_path_column is not None and file_path_column == index_column ): @@ -214,10 +216,6 @@ def _open_group(self) -> Any: store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") return zarr.open_group(store=store, mode="r") - @property - def _is_split_mode(self) -> bool: - return self.row_ends is not None or self.split_leading_axis - def _reserved_output_names(self, *, split: bool) -> set[str]: names = set() if self.file_path_column is not None: @@ -251,7 +249,9 @@ def _selected_arrays( _validate_output_names( paths, self.attrs or {}, - reserved=self._reserved_output_names(split=self._is_split_mode), + reserved=self._reserved_output_names( + split=self.row_ends is not None or self.split_leading_axis + ), ) arrays: dict[str, Any] = {} for output_name, path in paths.items(): @@ -293,7 +293,7 @@ def _shard_ranges( group: Any, arrays: Mapping[str, Any], ) -> list[tuple[int, int]]: - if not self._is_split_mode: + if self.row_ends is None and not self.split_leading_axis: return [(0, 1)] if self.split_leading_axis: @@ -329,7 +329,12 @@ def _shard_ranges( chunk_rows = max( 1, ceil( - max(_leading_chunk(array) for array in arrays.values()) + max( + int(array.chunks[0]) + if array.chunks + else int(array.shape[0]) + for array in arrays.values() + ) / self.leading_axis_row_size ), ) @@ -471,10 +476,6 @@ def _leading_item_bytes(array: Any) -> int: return max(1, int(array.dtype.itemsize) * int(prod(trailing_shape or (1,)))) -def _leading_chunk(array: Any) -> int: - return int(array.chunks[0]) if array.chunks else int(array.shape[0]) - - def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name From 4ec551e7c5a3bbb7cd2a6cf991b3587653ac4d52 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 18:01:26 +0200 Subject: [PATCH 23/39] Support Zarr 3 runtime --- pyproject.toml | 4 +- src/refiner/pipeline/sources/readers/zarr.py | 32 ++++- tests/readers/test_zarr_reader.py | 97 +++++++++++---- uv.lock | 122 +++++++++++++++++-- 4 files changed, 219 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7764dbb..b1c8f122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,9 @@ hdf5 = [ "h5py", ] zarr = [ - "zarr>=2.18,<3", + "zarr>=2.18,<3; python_version < '3.11'", + "zarr>=3; python_version >= '3.11'", + "numcodecs<0.16; python_version < '3.11'", ] s3 = [ "s3fs", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index c9cd73be..4fa08064 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping from math import ceil, prod +from operator import index as integer_index from typing import Any import pyarrow as pa @@ -212,8 +213,15 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: def _open_group(self) -> Any: import zarr + import zarr.storage - store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + if hasattr(zarr.storage, "FsspecStore"): + store = zarr.storage.FsspecStore.from_url( + self.root.abs_path(), + read_only=True, + ) + else: + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") return zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: @@ -275,7 +283,9 @@ def _row_end_ranges( return [] row_ends_array = self._row_ends_array(group) read_start = max(0, row_start - 1) - values = [int(value) for value in row_ends_array[read_start:row_end]] + values = [ + _row_end_offset(value) for value in row_ends_array[read_start:row_end] + ] if len(values) != row_end - read_start: raise ValueError("Zarr shard row range exceeds row_ends length") ranges: list[tuple[int, int]] = [] @@ -440,7 +450,16 @@ def _iter_row_ends(array: Any) -> Iterator[tuple[int, int]]: length = int(array.shape[0]) for start in range(0, length, chunk): for offset, value in enumerate(array[start : min(start + chunk, length)]): - yield start + offset, int(value) + yield start + offset, _row_end_offset(value) + + +def _row_end_offset(value: Any) -> int: + if isinstance(value, bool): + raise ValueError("Zarr row_ends must contain integer offsets") + try: + return integer_index(value) + except TypeError: + raise ValueError("Zarr row_ends must contain integer offsets") from None def _validate_row_ends(array: Any) -> int: @@ -460,6 +479,10 @@ def _check_final_end( exact: bool = False, ) -> None: for output_name, array in arrays.items(): + if not array.shape: + raise ValueError( + f"Zarr selected array {output_name!r} must have a leading dimension" + ) leading_length = int(array.shape[0]) if final_end > leading_length: raise ValueError( @@ -477,7 +500,8 @@ def _leading_item_bytes(array: Any) -> int: def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: - for name, item in group.items(): + items = group.items() if hasattr(group, "items") else group.members() + for name, item in items: path = f"{prefix}/{name}" if prefix else name if hasattr(item, "shape"): yield path diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 3c41fae8..6180a669 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import cast +from typing import Any, Literal, cast import numpy as np import pytest @@ -13,21 +13,39 @@ from refiner.pipeline.data.shard import RowRangeDescriptor +def _open_test_zarr(path: Path, *, mode: Literal["r", "r+", "a", "w", "w-"]): + kwargs: dict[str, Any] = {"mode": mode, "zarr_format": 2} + try: + return zarr.open_group(str(path), **kwargs) + except TypeError: + return zarr.open_group(str(path), mode=mode) + + +def _create_array(root, name: str, data, **kwargs): + if hasattr(root, "create_array"): + kwargs.pop("shape", None) + return root.create_array(name, data=data, **kwargs) + return root.create_dataset(name, data=data, **kwargs) + + def _write_policy_zarr(path: Path) -> None: - root = zarr.open_group(str(path), mode="w") - root.create_dataset( + root = _open_test_zarr(path, mode="w") + _create_array( + root, "data/action", data=np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]], dtype=np.float32), ) - root.create_dataset( + _create_array( + root, "data/state", data=np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]], dtype=np.float32), ) - root.create_dataset( + _create_array( + root, "data/rgb", data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), ) - root.create_dataset("meta/episode_ends", data=np.asarray([2, 5], dtype=np.int64)) + _create_array(root, "meta/episode_ends", data=np.asarray([2, 5], dtype=np.int64)) root.attrs["dataset_id"] = "pusht" root.attrs["task"] = "push tee" @@ -62,8 +80,8 @@ def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: def test_read_zarr_reads_scalar_arrays(tmp_path: Path) -> None: path = tmp_path / "scalar.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset("version", data=np.asarray(3, dtype=np.int64), shape=()) + root = _open_test_zarr(path, mode="w") + _create_array(root, "version", data=np.asarray(3, dtype=np.int64), shape=()) row = mdr.read_zarr( path, @@ -155,8 +173,8 @@ def test_read_zarr_rejects_duplicate_output_names(tmp_path: Path) -> None: def test_read_zarr_rejects_discovered_array_attr_collisions(tmp_path: Path) -> None: path = tmp_path / "collision.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset("task", data=np.asarray([1], dtype=np.int64)) + root = _open_test_zarr(path, mode="w") + _create_array(root, "task", data=np.asarray([1], dtype=np.int64)) root.attrs["task"] = "push tee" pipeline = mdr.read_zarr(path, attrs={"task": "task"}, file_path_column=None) @@ -218,13 +236,15 @@ def test_read_zarr_rejects_missing_selected_attrs(tmp_path: Path) -> None: def test_read_zarr_split_leading_axis_emits_one_row_per_index(tmp_path: Path) -> None: path = tmp_path / "leading_axis.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset( + root = _open_test_zarr(path, mode="w") + _create_array( + root, "data/action", data=np.arange(5, dtype=np.float32).reshape(5, 1), chunks=(1, 1), ) - root.create_dataset( + _create_array( + root, "data/rgb", data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), chunks=(2, 4, 4, 3), @@ -256,8 +276,9 @@ def test_read_zarr_split_leading_axis_emits_one_row_per_index(tmp_path: Path) -> def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: path = tmp_path / "leading_axis_rows.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset( + root = _open_test_zarr(path, mode="w") + _create_array( + root, "data/action", data=np.arange(6, dtype=np.float32).reshape(6, 1), chunks=(2, 1), @@ -278,9 +299,9 @@ def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: path = tmp_path / "misaligned.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset("data/action", data=np.zeros((5, 1), dtype=np.float32)) - root.create_dataset("data/state", data=np.zeros((4, 1), dtype=np.float32)) + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) + _create_array(root, "data/state", data=np.zeros((4, 1), dtype=np.float32)) with pytest.raises(ValueError, match="same leading dimension"): mdr.read_zarr( @@ -293,8 +314,8 @@ def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) - def test_read_zarr_split_leading_axis_requires_full_rows(tmp_path: Path) -> None: path = tmp_path / "partial-leading-axis-row.zarr" - root = zarr.open_group(str(path), mode="w") - root.create_dataset("data/action", data=np.zeros((5, 1), dtype=np.float32)) + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) with pytest.raises(ValueError, match="divisible by row size"): mdr.read_zarr( @@ -321,7 +342,7 @@ def test_read_zarr_leading_axis_row_size_requires_split_mode(tmp_path: Path) -> def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - root = zarr.open_group(str(path), mode="a") + root = _open_test_zarr(path, mode="a") root["meta/episode_ends"][:] = np.asarray([3, 2], dtype=np.int64) with pytest.raises(ValueError, match="row_ends must be monotonic"): @@ -333,10 +354,25 @@ def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: ).take(1) +def test_read_zarr_rejects_non_integer_row_ends(tmp_path: Path) -> None: + path = tmp_path / "float-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) + _create_array(root, "meta/episode_ends", data=np.asarray([2.5, 5.0])) + + with pytest.raises(ValueError, match="integer offsets"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + def test_read_zarr_rejects_out_of_range_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - root = zarr.open_group(str(path), mode="a") + root = _open_test_zarr(path, mode="a") root["meta/episode_ends"][:] = np.asarray([2, 6], dtype=np.int64) with pytest.raises(ValueError, match="row_ends exceed"): @@ -351,7 +387,7 @@ def test_read_zarr_rejects_out_of_range_row_ends(tmp_path: Path) -> None: def test_read_zarr_rejects_short_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) - root = zarr.open_group(str(path), mode="a") + root = _open_test_zarr(path, mode="a") root["meta/episode_ends"][:] = np.asarray([2, 4], dtype=np.int64) with pytest.raises(ValueError, match="end before leading dimension"): @@ -363,6 +399,21 @@ def test_read_zarr_rejects_short_row_ends(tmp_path: Path) -> None: ).take(2) +def test_read_zarr_rejects_scalar_arrays_in_row_ends_mode(tmp_path: Path) -> None: + path = tmp_path / "scalar-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "version", data=np.asarray(3, dtype=np.int64), shape=()) + _create_array(root, "meta/episode_ends", data=np.asarray([1], dtype=np.int64)) + + with pytest.raises(ValueError, match="must have a leading dimension"): + mdr.read_zarr( + path, + arrays={"version": "version"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" lerobot_out = tmp_path / "lerobot" diff --git a/uv.lock b/uv.lock index 82a0836e..b9215124 100644 --- a/uv.lock +++ b/uv.lock @@ -618,6 +618,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] +[[package]] +name = "donfig" +version = "0.8.1.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -783,6 +795,41 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "google-crc32c" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, + { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, + { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, + { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, + { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, + { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, + { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, + { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, + { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, + { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, + { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, + { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, + { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, + { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, + { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, + { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, + { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, + { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1015,11 +1062,14 @@ all = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] hdf5 = [ { name = "h5py" }, @@ -1041,11 +1091,14 @@ testing = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] text = [ { name = "warcio" }, @@ -1054,7 +1107,10 @@ video = [ { name = "av" }, ] zarr = [ - { name = "zarr" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] [package.dev-dependencies] @@ -1085,6 +1141,7 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, + { name = "numcodecs", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = "<0.16" }, { name = "numpy" }, { name = "orjson" }, { name = "pyarrow" }, @@ -1092,7 +1149,8 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, - { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, + { name = "zarr", marker = "python_full_version >= '3.11' and extra == 'zarr'", specifier = ">=3" }, + { name = "zarr", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = ">=2.18,<3" }, ] provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] @@ -2789,15 +2847,63 @@ wheels = [ name = "zarr" version = "2.18.3" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] dependencies = [ - { name = "asciitree" }, - { name = "fasteners", marker = "sys_platform != 'emscripten'" }, + { name = "asciitree", marker = "python_full_version < '3.11'" }, + { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, ] + +[[package]] +name = "zarr" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version == '3.11.*'" }, + { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "packaging", marker = "python_full_version == '3.11.*'" }, + { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, +] + +[[package]] +name = "zarr" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version >= '3.12'" }, + { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "packaging", marker = "python_full_version >= '3.12'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, +] From bb7e78d7018cb1733f1fe561d68d6a88a6dc7bbe Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 18:15:01 +0200 Subject: [PATCH 24/39] Improve Zarr robotics sharding --- docs/reading-and-writing.md | 9 +++- src/refiner/pipeline/sources/readers/zarr.py | 16 ++++-- tests/readers/test_zarr_reader.py | 54 ++++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 77e3ca79..5f2223da 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -236,8 +236,9 @@ Zarr support lives behind the optional `macrodata-refiner[zarr]` extra. uv add "macrodata-refiner[zarr]" ``` -`read_zarr(...)` reads one Zarr group. By default, the group becomes one output -row and selected arrays are loaded as full array values. +`read_zarr(...)` reads one Zarr group, including directory stores and local +`.zarr.zip` stores. By default, the group becomes one output row and selected +arrays are loaded as full array values. ```python import refiner as mdr @@ -308,6 +309,10 @@ rows = mdr.read_zarr( ) ``` +Shard planning in this mode aligns to the dominant selected arrays by per-row +size, so large image arrays drive chunking instead of tiny low-dimensional +state/action arrays that may be stored as one large chunk. + This mode requires selected arrays to have the same leading dimension, and that dimension must be divisible by `leading_axis_row_size`. Each output row contains `leading_axis_row_size` contiguous items from every selected array. Refiner plans diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 4fa08064..d14a5d7d 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -215,9 +215,12 @@ def _open_group(self) -> Any: import zarr import zarr.storage - if hasattr(zarr.storage, "FsspecStore"): + path = self.root.abs_path() + if path.endswith(".zip"): + store = zarr.storage.ZipStore(path, mode="r") + elif hasattr(zarr.storage, "FsspecStore"): store = zarr.storage.FsspecStore.from_url( - self.root.abs_path(), + path, read_only=True, ) else: @@ -331,11 +334,15 @@ def _shard_ranges( if self.num_shards is not None: step = ceil(row_count / self.num_shards) else: + item_bytes = [ + (array, _leading_item_bytes(array)) for array in arrays.values() + ] bytes_per_row = ( - sum(_leading_item_bytes(array) for array in arrays.values()) + sum(bytes_count for _, bytes_count in item_bytes) * self.leading_axis_row_size ) target_rows = max(1, self.target_shard_bytes // max(1, bytes_per_row)) + largest_item_bytes = max(bytes_count for _, bytes_count in item_bytes) chunk_rows = max( 1, ceil( @@ -343,7 +350,8 @@ def _shard_ranges( int(array.chunks[0]) if array.chunks else int(array.shape[0]) - for array in arrays.values() + for array, bytes_count in item_bytes + if bytes_count == largest_item_bytes ) / self.leading_axis_row_size ), diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 6180a669..4fb3e30f 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import shutil from typing import Any, Literal, cast import numpy as np @@ -118,6 +119,22 @@ def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: ) +def test_read_zarr_reads_zip_store(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + row = mdr.read_zarr( + str(zip_path), + arrays={"action": "data/action", "frames": "data/rgb"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1)[0] + + assert row["action"].shape == (2, 1) + assert row["frames"].shape == (2, 4, 4, 3) + + def test_read_zarr_plans_row_ends_with_num_shards(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) @@ -297,6 +314,43 @@ def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) +def test_read_zarr_split_leading_axis_uses_dominant_array_chunks( + tmp_path: Path, +) -> None: + path = tmp_path / "image_dominant_chunks.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(10, dtype=np.float32).reshape(10, 1), + chunks=(10, 1), + ) + _create_array( + root, + "data/rgb", + data=np.zeros((10, 4, 4, 3), dtype=np.uint8), + chunks=(2, 4, 4, 3), + ) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action", "rgb": "data/rgb"}, + split_leading_axis=True, + target_shard_bytes=160, + file_path_column=None, + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 2), + (2, 4), + (4, 6), + (6, 8), + (8, 10), + ] + + def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: path = tmp_path / "misaligned.zarr" root = _open_test_zarr(path, mode="w") From 0416d9d1edb7045a1f44a0b7c471e588731747e8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 18:18:04 +0200 Subject: [PATCH 25/39] Preserve Zarr filesystem handles --- src/refiner/pipeline/sources/readers/zarr.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index d14a5d7d..dcb05254 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -3,7 +3,7 @@ from collections.abc import Iterator, Mapping from math import ceil, prod from operator import index as integer_index -from typing import Any +from typing import Any, cast import pyarrow as pa @@ -219,8 +219,22 @@ def _open_group(self) -> Any: if path.endswith(".zip"): store = zarr.storage.ZipStore(path, mode="r") elif hasattr(zarr.storage, "FsspecStore"): - store = zarr.storage.FsspecStore.from_url( - path, + fs = self.root.fs + if fs.async_impl and not fs.asynchronous: + import json + + import fsspec + + fs_config = json.loads(fs.to_json()) + fs_config["asynchronous"] = True + fs = fsspec.AbstractFileSystem.from_json(json.dumps(fs_config)) + elif not fs.async_impl: + from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper + + fs = AsyncFileSystemWrapper(fs, asynchronous=True) + store = zarr.storage.FsspecStore( + fs=cast(Any, fs), + path=self.root._join("").rstrip("/"), read_only=True, ) else: From 81ab4f6bc9667d7065e63a89c4907cd3a0ab3faf Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 19:29:44 +0200 Subject: [PATCH 26/39] Clarify Zarr split sharding docs --- docs/reading-and-writing.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 5f2223da..7a5047e9 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -309,16 +309,18 @@ rows = mdr.read_zarr( ) ``` -Shard planning in this mode aligns to the dominant selected arrays by per-row -size, so large image arrays drive chunking instead of tiny low-dimensional -state/action arrays that may be stored as one large chunk. +Shard planning in this mode uses chunk metadata from the selected array with the +largest per-row byte size (`dtype.itemsize * product(shape[1:])`). This keeps +large image/video arrays in control of shard boundaries, so tiny action/state +arrays stored as one huge chunk do not force Refiner to load a much larger image +block than necessary. This mode requires selected arrays to have the same leading dimension, and that dimension must be divisible by `leading_axis_row_size`. Each output row contains `leading_axis_row_size` contiguous items from every selected array. Refiner plans -shards from array metadata and avoids splitting shards below the largest selected -leading-axis chunk where possible. Use `num_shards` when you need a target shard -count instead of byte-sized packing. +shards from array metadata and tries to keep shard boundaries aligned with the +dominant array's leading-axis chunks. Use `num_shards` when you need a target +shard count instead of byte-sized packing. `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. From 08928c239c43f4dbcc55e1f6233c0b803a7ae9ca Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 20:10:50 +0200 Subject: [PATCH 27/39] Use fsspec for zipped Zarr stores --- docs/reading-and-writing.md | 6 +- pyproject.toml | 5 +- src/refiner/io/datafolder.py | 12 +- src/refiner/pipeline/sources/readers/zarr.py | 64 +++---- tests/readers/test_zarr_reader.py | 2 +- uv.lock | 174 ++----------------- 6 files changed, 62 insertions(+), 201 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 7a5047e9..74c2ad0d 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -236,9 +236,9 @@ Zarr support lives behind the optional `macrodata-refiner[zarr]` extra. uv add "macrodata-refiner[zarr]" ``` -`read_zarr(...)` reads one Zarr group, including directory stores and local -`.zarr.zip` stores. By default, the group becomes one output row and selected -arrays are loaded as full array values. +`read_zarr(...)` reads one Zarr group, including directory stores and +`.zarr.zip` stores mounted through fsspec. By default, the group becomes one +output row and selected arrays are loaded as full array values. ```python import refiner as mdr diff --git a/pyproject.toml b/pyproject.toml index b1c8f122..4001aa0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,8 @@ hdf5 = [ "h5py", ] zarr = [ - "zarr>=2.18,<3; python_version < '3.11'", - "zarr>=3; python_version >= '3.11'", - "numcodecs<0.16; python_version < '3.11'", + "zarr>=2.18,<3", + "numcodecs<0.16", ] s3 = [ "s3fs", diff --git a/src/refiner/io/datafolder.py b/src/refiner/io/datafolder.py index 501c3bd0..70e0ac4b 100644 --- a/src/refiner/io/datafolder.py +++ b/src/refiner/io/datafolder.py @@ -40,10 +40,14 @@ def __init__( auto_mkdir: if True, when opening a file in write mode its parent directories will be automatically created **storage_options: will be passed to a new fsspec filesystem object, when it is created. Ignored if fs is given """ - super().__init__( - path=path, - fs=fs if fs is not None else url_to_fs(path, **storage_options)[0], - ) + if fs is None: + fs, path = url_to_fs(path, **storage_options) + path = "/" if path is None else path + else: + path = fs._strip_protocol(path) + if path == "": + path = "/" + super().__init__(path=path, fs=fs) self.auto_mkdir = auto_mkdir @classmethod diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index dcb05254..5f3bdc85 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -3,10 +3,14 @@ from collections.abc import Iterator, Mapping from math import ceil, prod from operator import index as integer_index -from typing import Any, cast +from os import PathLike +from typing import Any +from fsspec import AbstractFileSystem +from fsspec.implementations.zip import ZipFileSystem import pyarrow as pa +from refiner.io.datafile import DataFile, DataFileLike from refiner.io.datafolder import DataFolder, DataFolderLike from refiner.pipeline.data.datatype import ( DTypeMapping, @@ -68,7 +72,32 @@ def __init__( or None to omit it. dtypes: Optional dtype overrides for output columns. """ - self.root = DataFolder.resolve(input) + zip_input: DataFileLike | None = None + if isinstance(input, PathLike): + input = str(input) + if isinstance(input, str) and input.endswith(".zip"): + zip_input = input + elif ( + isinstance(input, tuple) + and len(input) == 2 + and isinstance(input[1], AbstractFileSystem) + ): + path = input[0] + if isinstance(path, PathLike): + path = str(path) + if isinstance(path, str) and path.endswith(".zip"): + zip_input = (path, input[1]) + + if zip_input is not None: + zip_file = DataFile.resolve(zip_input) + self.root = DataFolder( + "/", + fs=ZipFileSystem(fo=zip_file.open("rb"), mode="r"), + ) + self.source_path = zip_file.abs_path() + else: + self.root = DataFolder.resolve(input) + self.source_path = self.root.abs_path() check_required_dependencies("read_zarr", ["zarr"], dist="zarr") if row_ends is not None and split_leading_axis: raise ValueError("row_ends and split_leading_axis are mutually exclusive") @@ -114,7 +143,7 @@ def schema(self) -> pa.Schema | None: def describe(self) -> dict[str, Any]: return { - "path": self.root.abs_path(), + "path": self.source_path, "arrays": dict(self.arrays) if self.arrays is not None else None, "attrs": dict(self.attrs) if self.attrs is not None else None, "row_ends": self.row_ends, @@ -132,7 +161,7 @@ def describe(self) -> dict[str, Any]: } def list_shards(self) -> list[Shard]: - path = self.root.abs_path() + path = self.source_path group = self._open_group() arrays = self._selected_arrays(group, validate_names=True) @@ -215,30 +244,7 @@ def _open_group(self) -> Any: import zarr import zarr.storage - path = self.root.abs_path() - if path.endswith(".zip"): - store = zarr.storage.ZipStore(path, mode="r") - elif hasattr(zarr.storage, "FsspecStore"): - fs = self.root.fs - if fs.async_impl and not fs.asynchronous: - import json - - import fsspec - - fs_config = json.loads(fs.to_json()) - fs_config["asynchronous"] = True - fs = fsspec.AbstractFileSystem.from_json(json.dumps(fs_config)) - elif not fs.async_impl: - from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper - - fs = AsyncFileSystemWrapper(fs, asynchronous=True) - store = zarr.storage.FsspecStore( - fs=cast(Any, fs), - path=self.root._join("").rstrip("/"), - read_only=True, - ) - else: - store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") return zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: @@ -252,7 +258,7 @@ def _reserved_output_names(self, *, split: bool) -> set[str]: def _row_metadata(self, *, index: int | None) -> dict[str, Any]: row: dict[str, Any] = {} if self.file_path_column is not None: - row[self.file_path_column] = self.root.abs_path() + row[self.file_path_column] = self.source_path if self.index_column is not None and index is not None: row[self.index_column] = index return row diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 4fb3e30f..4fdfa0ee 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -128,9 +128,9 @@ def test_read_zarr_reads_zip_store(tmp_path: Path) -> None: str(zip_path), arrays={"action": "data/action", "frames": "data/rgb"}, row_ends="meta/episode_ends", - file_path_column=None, ).take(1)[0] + assert row["file_path"] == str(zip_path) assert row["action"].shape == (2, 1) assert row["frames"].shape == (2, 4, 4, 3) diff --git a/uv.lock b/uv.lock index b9215124..d844aa09 100644 --- a/uv.lock +++ b/uv.lock @@ -618,18 +618,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] -[[package]] -name = "donfig" -version = "0.8.1.post1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, -] - [[package]] name = "exceptiongroup" version = "1.3.1" @@ -795,41 +783,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, - { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, - { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, - { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, - { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, - { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, - { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, - { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, - { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, - { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, - { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, - { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, - { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, - { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, - { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, - { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, -] - [[package]] name = "h11" version = "0.16.0" @@ -1062,14 +1015,12 @@ all = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] hdf5 = [ { name = "h5py" }, @@ -1091,14 +1042,12 @@ testing = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] text = [ { name = "warcio" }, @@ -1107,10 +1056,8 @@ video = [ { name = "av" }, ] zarr = [ - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "numcodecs" }, + { name = "zarr" }, ] [package.dev-dependencies] @@ -1141,7 +1088,7 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, - { name = "numcodecs", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = "<0.16" }, + { name = "numcodecs", marker = "extra == 'zarr'", specifier = "<0.16" }, { name = "numpy" }, { name = "orjson" }, { name = "pyarrow" }, @@ -1149,8 +1096,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, - { name = "zarr", marker = "python_full_version >= '3.11' and extra == 'zarr'", specifier = ">=3" }, - { name = "zarr", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = ">=2.18,<3" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, ] provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] @@ -1413,11 +1359,9 @@ wheels = [ name = "numcodecs" version = "0.13.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215, upload-time = "2024-10-09T16:28:00.188Z" } wheels = [ @@ -1439,49 +1383,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887, upload-time = "2024-10-09T16:27:55.039Z" }, ] -[[package]] -name = "numcodecs" -version = "0.16.5" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/44/bd/8a391e7c356366224734efd24da929cc4796fff468bfb179fe1af6548535/numcodecs-0.16.5.tar.gz", hash = "sha256:0d0fb60852f84c0bd9543cc4d2ab9eefd37fc8efcc410acd4777e62a1d300318", size = 6276387, upload-time = "2025-11-21T02:49:48.986Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/af/85/1ac101a40ead81eaa1c7dc49a8827a30e2e436211b43ebdc63c590eb1347/numcodecs-0.16.5-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:78382dcea50622f2ef1e6e7a71dbe7f861d8fe376b27b7c297c26907304fef1e", size = 1621795, upload-time = "2025-11-21T02:49:17.418Z" }, - { url = "https://files.pythonhosted.org/packages/0e/cc/0d97ef55dda48cb0f93d7b92d761208e7a99bd2eea6b0e859426e6a99a21/numcodecs-0.16.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2d04a19cb57a3c519b4127ac377cca6471aee1990d7c18f5b1e3a4fe1306689", size = 1153030, upload-time = "2025-11-21T02:49:19.089Z" }, - { url = "https://files.pythonhosted.org/packages/5e/41/e120ee1b390730ac5987cde2afd82e2b8442cec315ab40b94b0373e93e73/numcodecs-0.16.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c043af648eb280cd61785c99c22ff5c3c3460f906eb51a8511327c4f5111b283", size = 8510503, upload-time = "2025-11-21T02:49:20.324Z" }, - { url = "https://files.pythonhosted.org/packages/54/4b/195ac84cc8f6077b4f0f421e8daee21b7f1bd88cb7716414234379fe68ec/numcodecs-0.16.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c398919ef2eb0e56b8e97456f622640bfd3deed06de3acc976989cbcb22628a3", size = 9123428, upload-time = "2025-11-21T02:49:22.328Z" }, - { url = "https://files.pythonhosted.org/packages/0f/5b/af02c417954f46e5c7bd5163ac251f535877d909fce54861c99ae197f6f6/numcodecs-0.16.5-cp311-cp311-win_amd64.whl", hash = "sha256:3820860ed302d4d84a1c66e70981ff959d5eb712555be4e7d8ced49888594773", size = 801542, upload-time = "2025-11-21T02:49:24.265Z" }, - { url = "https://files.pythonhosted.org/packages/75/cc/55420f3641a67f78392dc0bc5d02cb9eb0a9dcebf2848d1ac77253ca61fa/numcodecs-0.16.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:24e675dc8d1550cd976a99479b87d872cb142632c75cc402fea04c08c4898523", size = 1656287, upload-time = "2025-11-21T02:49:25.755Z" }, - { url = "https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:94ddfa4341d1a3ab99989d13b01b5134abb687d3dab2ead54b450aefe4ad5bd6", size = 1148899, upload-time = "2025-11-21T02:49:26.87Z" }, - { url = "https://files.pythonhosted.org/packages/97/1e/98aaddf272552d9fef1f0296a9939d1487914a239e98678f6b20f8b0a5c8/numcodecs-0.16.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b554ab9ecf69de7ca2b6b5e8bc696bd9747559cb4dd5127bd08d7a28bec59c3a", size = 8534814, upload-time = "2025-11-21T02:49:28.547Z" }, - { url = "https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad1a379a45bd3491deab8ae6548313946744f868c21d5340116977ea3be5b1d6", size = 9173471, upload-time = "2025-11-21T02:49:30.444Z" }, - { url = "https://files.pythonhosted.org/packages/1c/20/2fdec87fc7f8cec950d2b0bea603c12dc9f05b4966dc5924ba5a36a61bf6/numcodecs-0.16.5-cp312-cp312-win_amd64.whl", hash = "sha256:845a9857886ffe4a3172ba1c537ae5bcc01e65068c31cf1fce1a844bd1da050f", size = 801412, upload-time = "2025-11-21T02:49:32.123Z" }, - { url = "https://files.pythonhosted.org/packages/38/38/071ced5a5fd1c85ba0e14ba721b66b053823e5176298c2f707e50bed11d9/numcodecs-0.16.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25be3a516ab677dad890760d357cfe081a371d9c0a2e9a204562318ac5969de3", size = 1654359, upload-time = "2025-11-21T02:49:33.673Z" }, - { url = "https://files.pythonhosted.org/packages/d1/c0/5f84ba7525577c1b9909fc2d06ef11314825fc4ad4378f61d0e4c9883b4a/numcodecs-0.16.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0107e839ef75b854e969cb577e140b1aadb9847893937636582d23a2a4c6ce50", size = 1144237, upload-time = "2025-11-21T02:49:35.294Z" }, - { url = "https://files.pythonhosted.org/packages/0b/00/787ea5f237b8ea7bc67140c99155f9c00b5baf11c49afc5f3bfefa298f95/numcodecs-0.16.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:015a7c859ecc2a06e2a548f64008c0ec3aaecabc26456c2c62f4278d8fc20597", size = 8483064, upload-time = "2025-11-21T02:49:36.454Z" }, - { url = "https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:84230b4b9dad2392f2a84242bd6e3e659ac137b5a1ce3571d6965fca673e0903", size = 9126063, upload-time = "2025-11-21T02:49:38.018Z" }, - { url = "https://files.pythonhosted.org/packages/27/72/6663cc0382ddbb866136c255c837bcb96cc7ce5e83562efec55e1b995941/numcodecs-0.16.5-cp313-cp313-win_amd64.whl", hash = "sha256:5088145502ad1ebf677ec47d00eb6f0fd600658217db3e0c070c321c85d6cf3d", size = 799275, upload-time = "2025-11-21T02:49:39.558Z" }, - { url = "https://files.pythonhosted.org/packages/3c/9e/38e7ca8184c958b51f45d56a4aeceb1134ecde2d8bd157efadc98502cc42/numcodecs-0.16.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b05647b8b769e6bc8016e9fd4843c823ce5c9f2337c089fb5c9c4da05e5275de", size = 1654721, upload-time = "2025-11-21T02:49:40.602Z" }, - { url = "https://files.pythonhosted.org/packages/a1/37/260fa42e7b2b08e6e00ad632f8dd620961a60a459426c26cea390f8c68d0/numcodecs-0.16.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3832bd1b5af8bb3e413076b7d93318c8e7d7b68935006b9fa36ca057d1725a8f", size = 1146887, upload-time = "2025-11-21T02:49:41.721Z" }, - { url = "https://files.pythonhosted.org/packages/4e/15/e2e1151b5a8b14a15dfd4bb4abccce7fff7580f39bc34092780088835f3a/numcodecs-0.16.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f7b7d24f103187f53135bed28bb9f0ed6b2e14c604664726487bb6d7c882e1", size = 8476987, upload-time = "2025-11-21T02:49:43.363Z" }, - { url = "https://files.pythonhosted.org/packages/6d/30/16a57fc4d9fb0ba06c600408bd6634f2f1753c54a7a351c99c5e09b51ee2/numcodecs-0.16.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aec9736d81b70f337d89c4070ee3ffeff113f386fd789492fa152d26a15043e4", size = 9102377, upload-time = "2025-11-21T02:49:45.508Z" }, - { url = "https://files.pythonhosted.org/packages/31/a5/a0425af36c20d55a3ea884db4b4efca25a43bea9214ba69ca7932dd997b4/numcodecs-0.16.5-cp314-cp314-win_amd64.whl", hash = "sha256:b16a14303800e9fb88abc39463ab4706c037647ac17e49e297faa5f7d7dbbf1d", size = 819022, upload-time = "2025-11-21T02:49:47.39Z" }, -] - [[package]] name = "numpy" version = "2.2.6" @@ -2847,63 +2748,14 @@ wheels = [ name = "zarr" version = "2.18.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] dependencies = [ - { name = "asciitree", marker = "python_full_version < '3.11'" }, - { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "asciitree" }, + { name = "fasteners", marker = "sys_platform != 'emscripten'" }, + { name = "numcodecs" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, ] - -[[package]] -name = "zarr" -version = "3.1.6" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version == '3.11.*'" }, - { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "packaging", marker = "python_full_version == '3.11.*'" }, - { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, -] - -[[package]] -name = "zarr" -version = "3.2.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version >= '3.12'" }, - { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "packaging", marker = "python_full_version >= '3.12'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, -] From 89a865bbb4303c75b7d356d9dc3627d7ba59f9e8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 21:06:41 +0200 Subject: [PATCH 28/39] Add Zarr row batch sizing --- docs/reading-and-writing.md | 4 + src/refiner/pipeline/pipeline.py | 5 +- src/refiner/pipeline/sources/readers/zarr.py | 99 ++++++++++++-------- tests/readers/test_zarr_reader.py | 50 ++++++++++ 4 files changed, 117 insertions(+), 41 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 74c2ad0d..b63c6ad7 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -322,6 +322,10 @@ shards from array metadata and tries to keep shard boundaries aligned with the dominant array's leading-axis chunks. Use `num_shards` when you need a target shard count instead of byte-sized packing. +By default, split readers load one shard block at a time and slice logical rows +from that block. Set `row_batch_size` to cap how many logical rows are loaded per +block when a shard would otherwise materialize too much data. + `row_ends` is reader control metadata, not an output selection. If you also want the raw offsets as a column in non-split mode, select that path through `arrays`. diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 0ddd74ad..6ace95d2 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -819,6 +819,7 @@ def read_zarr( leading_axis_row_size: int = 1, target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, num_shards: int | None = None, + row_batch_size: int | None = None, index_column: str | None = "index", file_path_column: str | None = "file_path", dtypes: DTypeMapping | None = None, @@ -832,7 +833,8 @@ def read_zarr( Missing selected arrays or attributes raise immediately. `row_ends` and `split_leading_axis` are mutually exclusive. `target_shard_bytes` and - `num_shards` affect shard planning, not logical row size. + `num_shards` affect shard planning, not logical row size. `row_batch_size` + bounds how many logical rows are loaded per array block within each shard. """ return RefinerPipeline( source=ZarrReader( @@ -844,6 +846,7 @@ def read_zarr( leading_axis_row_size=leading_axis_row_size, target_shard_bytes=target_shard_bytes, num_shards=num_shards, + row_batch_size=row_batch_size, index_column=index_column, file_path_column=file_path_column, dtypes=dtypes, diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 5f3bdc85..1e788753 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -45,6 +45,7 @@ def __init__( leading_axis_row_size: int = 1, target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, num_shards: int | None = None, + row_batch_size: int | None = None, index_column: str | None = "index", file_path_column: str | None = "file_path", dtypes: DTypeMapping | None = None, @@ -66,6 +67,8 @@ def __init__( target_shard_bytes: Approximate byte target used to pack logical rows into shards in split modes. num_shards: Optional target shard count for split modes. + row_batch_size: Optional maximum number of logical rows to load per + array block within a shard. None loads one whole shard block. index_column: Output metadata column containing the logical row index in split modes, or None to omit it. file_path_column: Output metadata column containing the source path, @@ -89,13 +92,11 @@ def __init__( zip_input = (path, input[1]) if zip_input is not None: - zip_file = DataFile.resolve(zip_input) - self.root = DataFolder( - "/", - fs=ZipFileSystem(fo=zip_file.open("rb"), mode="r"), - ) - self.source_path = zip_file.abs_path() + self.zip_file = DataFile.resolve(zip_input) + self.root = None + self.source_path = self.zip_file.abs_path() else: + self.zip_file = None self.root = DataFolder.resolve(input) self.source_path = self.root.abs_path() check_required_dependencies("read_zarr", ["zarr"], dist="zarr") @@ -109,6 +110,8 @@ def __init__( raise ValueError("target_shard_bytes must be greater than zero") if num_shards is not None and num_shards <= 0: raise ValueError("num_shards must be greater than zero") + if row_batch_size is not None and row_batch_size <= 0: + raise ValueError("row_batch_size must be greater than zero") self.arrays = ( None if arrays is None else path_selection_map(arrays, format_name="Zarr") ) @@ -120,6 +123,7 @@ def __init__( self.leading_axis_row_size = leading_axis_row_size self.target_shard_bytes = target_shard_bytes self.num_shards = num_shards + self.row_batch_size = row_batch_size self.index_column = index_column self.file_path_column = file_path_column self.dtypes = dtypes @@ -151,6 +155,7 @@ def describe(self) -> dict[str, Any]: "leading_axis_row_size": self.leading_axis_row_size, "target_shard_bytes": self.target_shard_bytes, "num_shards": self.num_shards, + "row_batch_size": self.row_batch_size, "index_column": self.index_column, "file_path_column": self.file_path_column, "dtypes": ( @@ -191,48 +196,50 @@ def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: ) if not source_ranges: return - block_start = source_ranges[0][0] - block_end = source_ranges[-1][1] - block = self._read_arrays(arrays, start=block_start, end=block_end) attrs = self._read_attrs(group) - for row_index, (start, end) in zip( - range(descriptor.start, descriptor.end), - source_ranges, - strict=True, - ): - row = self._row_metadata(index=row_index) - row.update( - { - name: value[start - block_start : end - block_start] - for name, value in block.items() - } - ) - row.update(attrs) - yield DictRow(row) + batch_size = self.row_batch_size or len(source_ranges) + for batch_offset in range(0, len(source_ranges), batch_size): + batch = source_ranges[batch_offset : batch_offset + batch_size] + block_start = batch[0][0] + block_end = batch[-1][1] + block = self._read_arrays(arrays, start=block_start, end=block_end) + for offset, (start, end) in enumerate(batch, start=batch_offset): + row = self._row_metadata(index=descriptor.start + offset) + row.update( + { + name: value[start - block_start : end - block_start] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) return if self.split_leading_axis: descriptor = shard.descriptor assert isinstance(descriptor, RowRangeDescriptor) - raw_start = descriptor.start * self.leading_axis_row_size - raw_end = descriptor.end * self.leading_axis_row_size - block = self._read_arrays( - arrays, - start=raw_start, - end=raw_end, - ) attrs = self._read_attrs(group) - for row_index in range(descriptor.start, descriptor.end): - offset = (row_index - descriptor.start) * self.leading_axis_row_size - row = self._row_metadata(index=row_index) - row.update( - { - name: value[offset : offset + self.leading_axis_row_size] - for name, value in block.items() - } + batch_size = self.row_batch_size or descriptor.end - descriptor.start + for batch_start in range(descriptor.start, descriptor.end, batch_size): + batch_end = min(batch_start + batch_size, descriptor.end) + raw_start = batch_start * self.leading_axis_row_size + raw_end = batch_end * self.leading_axis_row_size + block = self._read_arrays( + arrays, + start=raw_start, + end=raw_end, ) - row.update(attrs) - yield DictRow(row) + for row_index in range(batch_start, batch_end): + offset = (row_index - batch_start) * self.leading_axis_row_size + row = self._row_metadata(index=row_index) + row.update( + { + name: value[offset : offset + self.leading_axis_row_size] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) return row = self._row_metadata(index=None) @@ -244,6 +251,18 @@ def _open_group(self) -> Any: import zarr import zarr.storage + if self.zip_file is not None: + store = ( + zarr.ZipStore(self.zip_file.abs_path(), mode="r") + if self.zip_file.is_local + else zarr.storage.FSStore( + "/", + fs=ZipFileSystem(fo=self.zip_file.open("rb"), mode="r"), + mode="r", + ) + ) + return zarr.open_group(store=store, mode="r") + assert self.root is not None store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") return zarr.open_group(store=store, mode="r") diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 4fdfa0ee..8b3ff4f1 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -314,6 +314,43 @@ def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) +def test_read_zarr_split_leading_axis_uses_row_batch_size( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + path = tmp_path / "leading_axis_batched.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(6, dtype=np.float32).reshape(6, 1), + chunks=(1, 1), + ) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + row_batch_size=2, + file_path_column=None, + ) + source = cast(Any, pipeline.source) + shards = source.list_shards() + calls: list[tuple[int | None, int | None]] = [] + read_arrays = source._read_arrays + + def record_read_arrays(arrays, *, start=None, end=None): + calls.append((start, end)) + return read_arrays(arrays, start=start, end=end) + + monkeypatch.setattr(source, "_read_arrays", record_read_arrays) + rows = list(source.read_shard(shards[0])) + + assert [row["index"] for row in rows] == list(range(6)) + np.testing.assert_allclose(rows[5]["action"], [[5.0]]) + assert calls == [(0, 2), (2, 4), (4, 6)] + + def test_read_zarr_split_leading_axis_uses_dominant_array_chunks( tmp_path: Path, ) -> None: @@ -393,6 +430,19 @@ def test_read_zarr_leading_axis_row_size_requires_split_mode(tmp_path: Path) -> ) +def test_read_zarr_rejects_invalid_row_batch_size(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="row_batch_size"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + row_batch_size=0, + ) + + def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From a86e86f18a8ff6651020f2f80b6437169ba63a00 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 23:40:20 +0200 Subject: [PATCH 29/39] Document Zarr robotics reference datasets --- docs/robotics_conversion.md | 40 +++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/robotics_conversion.md b/docs/robotics_conversion.md index b23ac141..84dfe3e0 100644 --- a/docs/robotics_conversion.md +++ b/docs/robotics_conversion.md @@ -77,3 +77,43 @@ pipeline = ( ) ) ``` + +## Zarr Replay Buffers + +Some robotics replay buffers are stored as unzipped Zarr directory stores with +frame-aligned arrays under `data/` and cumulative episode boundaries under +`meta/episode_ends`. + +Reference datasets: + +- RoboCasa MT4 N216: + `hf://datasets/ahad-j/robocasa_mt4_N216_zarr/mt4_N216.zarr` + (`https://huggingface.co/datasets/ahad-j/robocasa_mt4_N216_zarr`) +- MetaWorld MT4 N200: + `hf://datasets/runningkiwi/metaworld_mt4_n200_zarr` + (`https://huggingface.co/datasets/runningkiwi/metaworld_mt4_n200_zarr`) + +```python +import refiner as mdr + +pipeline = ( + mdr.read_zarr( + "hf://datasets/ahad-j/robocasa_mt4_N216_zarr/mt4_N216.zarr", + arrays={ + "action": "data/action", + "eef_pos": "data/robot0_eef_pos", + "joint_pos": "data/robot0_joint_pos", + "gripper_qpos": "data/robot0_gripper_qpos", + "wrist": "data/robot0_eye_in_hand_rgb", + }, + row_ends="meta/episode_ends", + index_column="episode_id", + ) + .to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=("eef_pos", "joint_pos", "gripper_qpos"), + video_keys=("wrist",), + ) +) +``` From 703253d83815caad43826ee587d387d5b09b45c8 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 23:50:13 +0200 Subject: [PATCH 30/39] Disable fsspec cache for remote Zarr zips --- docs/reading-and-writing.md | 5 +-- src/refiner/pipeline/sources/readers/zarr.py | 4 ++- tests/readers/test_zarr_reader.py | 35 ++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index b63c6ad7..efa80ad8 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -237,8 +237,9 @@ uv add "macrodata-refiner[zarr]" ``` `read_zarr(...)` reads one Zarr group, including directory stores and -`.zarr.zip` stores mounted through fsspec. By default, the group becomes one -output row and selected arrays are loaded as full array values. +`.zarr.zip` stores. Local zip stores use Zarr's native zip support; remote zip +stores are mounted through fsspec with block caching disabled. By default, the +group becomes one output row and selected arrays are loaded as full array values. ```python import refiner as mdr diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 1e788753..d1c2f63c 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -257,7 +257,9 @@ def _open_group(self) -> Any: if self.zip_file.is_local else zarr.storage.FSStore( "/", - fs=ZipFileSystem(fo=self.zip_file.open("rb"), mode="r"), + fs=ZipFileSystem( + fo=self.zip_file.open("rb", cache_type="none"), mode="r" + ), mode="r", ) ) diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 8b3ff4f1..1e542db6 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -4,6 +4,7 @@ import shutil from typing import Any, Literal, cast +from fsspec.implementations.memory import MemoryFileSystem import numpy as np import pytest import zarr @@ -135,6 +136,40 @@ def test_read_zarr_reads_zip_store(tmp_path: Path) -> None: assert row["frames"].shape == (2, 4, 4, 3) +def test_read_zarr_reads_remote_zip_without_cache( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + fs = MemoryFileSystem() + remote_path = "/policy.zarr.zip" + with zip_path.open("rb") as src, fs.open(remote_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + open_calls: list[dict[str, Any]] = [] + original_open = fs.open + + def record_open(path, mode="rb", **kwargs): + if path == remote_path and mode == "rb": + open_calls.append(kwargs) + return original_open(path, mode=mode, **kwargs) + + monkeypatch.setattr(fs, "open", record_open) + + row = mdr.read_zarr( + (remote_path, fs), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["action"].shape == (2, 1) + assert open_calls + assert open_calls[0]["cache_type"] == "none" + + def test_read_zarr_plans_row_ends_with_num_shards(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From efd6efc80ff9eb5c09babc3ae8fccdc13ba3bffb Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sat, 23 May 2026 23:53:26 +0200 Subject: [PATCH 31/39] Clean up Zarr reader branching --- src/refiner/pipeline/sources/readers/zarr.py | 38 ++++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index d1c2f63c..a9ca74e2 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -169,7 +169,7 @@ def list_shards(self) -> list[Shard]: path = self.source_path group = self._open_group() - arrays = self._selected_arrays(group, validate_names=True) + arrays = self._selected_arrays(group) split_ranges = self._shard_ranges(group, arrays) return [ Shard.from_row_range( @@ -252,17 +252,16 @@ def _open_group(self) -> Any: import zarr.storage if self.zip_file is not None: - store = ( - zarr.ZipStore(self.zip_file.abs_path(), mode="r") - if self.zip_file.is_local - else zarr.storage.FSStore( + if self.zip_file.is_local: + store = zarr.ZipStore(self.zip_file.abs_path(), mode="r") + else: + store = zarr.storage.FSStore( "/", fs=ZipFileSystem( fo=self.zip_file.open("rb", cache_type="none"), mode="r" ), mode="r", ) - ) return zarr.open_group(store=store, mode="r") assert self.root is not None store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") @@ -284,20 +283,11 @@ def _row_metadata(self, *, index: int | None) -> dict[str, Any]: row[self.index_column] = index return row - def _selected_arrays( - self, - group: Any, - *, - validate_names: bool = False, - ) -> dict[str, Any]: - paths = ( - self.arrays - if self.arrays is not None - else { + def _selected_arrays(self, group: Any) -> dict[str, Any]: + if self.arrays is None: + paths = { path: path for path in _iter_array_paths(group) if path != self.row_ends } - ) - if validate_names: _validate_output_names( paths, self.attrs or {}, @@ -305,6 +295,9 @@ def _selected_arrays( split=self.row_ends is not None or self.split_leading_axis ), ) + else: + paths = self.arrays + arrays: dict[str, Any] = {} for output_name, path in paths.items(): try: @@ -407,9 +400,11 @@ def _shard_ranges( row_count = int(row_ends_array.shape[0]) if row_count == 0: return [] - if self.num_shards is not None: + if self.num_shards is not None or not arrays: final_end = _validate_row_ends(row_ends_array) _check_final_end(arrays, final_end, label="row_ends", exact=True) + if self.num_shards is None: + return [(0, row_count)] step = ceil(row_count / self.num_shards) return [ (start, min(start + step, row_count)) @@ -417,11 +412,6 @@ def _shard_ranges( ] bytes_per_step = sum(_leading_item_bytes(array) for array in arrays.values()) - if bytes_per_step <= 0: - final_end = _validate_row_ends(row_ends_array) - _check_final_end(arrays, final_end, label="row_ends", exact=True) - return [(0, row_count)] - ranges: list[tuple[int, int]] = [] shard_start = 0 current_bytes = 0 From cba3ff06339c00bf40a3164365d025790e92186e Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:00:28 +0200 Subject: [PATCH 32/39] Trim reader helper wrappers --- src/refiner/pipeline/sources/readers/hdf5.py | 10 ++-------- src/refiner/pipeline/sources/readers/zarr.py | 6 ++---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index c22c1db4..e311f691 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -84,8 +84,8 @@ def __init__( raise ValueError( "groups accepts a single glob string or a list of exact group paths" ) - self.datasets = self._mapping(datasets) - self.attrs = self._mapping(attrs) + self.datasets = path_selection_map(datasets, format_name="HDF5") + self.attrs = path_selection_map(attrs, format_name="HDF5") self.group_path_column = group_path_column self.missing_policy = missing_policy if missing_policy not in ("error", "drop_row", "set_null"): @@ -94,12 +94,6 @@ def __init__( ) self._validate_column_names() - @staticmethod - def _mapping( - value: PathSelection | None, - ) -> dict[str, str]: - return path_selection_map(value, format_name="HDF5") - def describe(self) -> dict[str, Any]: description = super().describe() description.update( diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index a9ca74e2..b3a7cb03 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -166,8 +166,6 @@ def describe(self) -> dict[str, Any]: } def list_shards(self) -> list[Shard]: - path = self.source_path - group = self._open_group() arrays = self._selected_arrays(group) split_ranges = self._shard_ranges(group, arrays) @@ -176,8 +174,8 @@ def list_shards(self) -> list[Shard]: start=start, end=end, global_ordinal=index, - start_key=path, - end_key=path, + start_key=self.source_path, + end_key=self.source_path, ) for index, (start, end) in enumerate(split_ranges) ] From 63e5541523c089373fb02e069cf9f4698430b104 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:05:26 +0200 Subject: [PATCH 33/39] Address Zarr reader review comments --- src/refiner/pipeline/sources/readers/zarr.py | 175 ++++++++++--------- tests/readers/test_zarr_reader.py | 44 +++++ 2 files changed, 141 insertions(+), 78 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index b3a7cb03..6fafe2b3 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterator, Mapping +from contextlib import contextmanager from math import ceil, prod from operator import index as integer_index from os import PathLike @@ -76,7 +77,9 @@ def __init__( dtypes: Optional dtype overrides for output columns. """ zip_input: DataFileLike | None = None - if isinstance(input, PathLike): + if isinstance(input, DataFolder) and input.path.endswith(".zip"): + zip_input = (input.path, input.fs) + elif isinstance(input, PathLike): input = str(input) if isinstance(input, str) and input.endswith(".zip"): zip_input = input @@ -166,104 +169,119 @@ def describe(self) -> dict[str, Any]: } def list_shards(self) -> list[Shard]: - group = self._open_group() - arrays = self._selected_arrays(group) - split_ranges = self._shard_ranges(group, arrays) - return [ - Shard.from_row_range( - start=start, - end=end, - global_ordinal=index, - start_key=self.source_path, - end_key=self.source_path, - ) - for index, (start, end) in enumerate(split_ranges) - ] + with self._open_group() as group: + arrays = self._selected_arrays(group) + split_ranges = self._shard_ranges(group, arrays) + return [ + Shard.from_row_range( + start=start, + end=end, + global_ordinal=index, + start_key=self.source_path, + end_key=self.source_path, + ) + for index, (start, end) in enumerate(split_ranges) + ] def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: - group = self._open_group() - arrays = self._selected_arrays(group) - if self.row_ends is not None: - descriptor = shard.descriptor - assert isinstance(descriptor, RowRangeDescriptor) - source_ranges = self._row_end_ranges( - group, - arrays, - row_start=descriptor.start, - row_end=descriptor.end, - ) - if not source_ranges: - return - attrs = self._read_attrs(group) - batch_size = self.row_batch_size or len(source_ranges) - for batch_offset in range(0, len(source_ranges), batch_size): - batch = source_ranges[batch_offset : batch_offset + batch_size] - block_start = batch[0][0] - block_end = batch[-1][1] - block = self._read_arrays(arrays, start=block_start, end=block_end) - for offset, (start, end) in enumerate(batch, start=batch_offset): - row = self._row_metadata(index=descriptor.start + offset) - row.update( - { - name: value[start - block_start : end - block_start] - for name, value in block.items() - } - ) - row.update(attrs) - yield DictRow(row) - return - - if self.split_leading_axis: - descriptor = shard.descriptor - assert isinstance(descriptor, RowRangeDescriptor) - attrs = self._read_attrs(group) - batch_size = self.row_batch_size or descriptor.end - descriptor.start - for batch_start in range(descriptor.start, descriptor.end, batch_size): - batch_end = min(batch_start + batch_size, descriptor.end) - raw_start = batch_start * self.leading_axis_row_size - raw_end = batch_end * self.leading_axis_row_size - block = self._read_arrays( + with self._open_group() as group: + arrays = self._selected_arrays(group) + if self.row_ends is not None: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) + source_ranges = self._row_end_ranges( + group, arrays, - start=raw_start, - end=raw_end, + row_start=descriptor.start, + row_end=descriptor.end, ) - for row_index in range(batch_start, batch_end): - offset = (row_index - batch_start) * self.leading_axis_row_size - row = self._row_metadata(index=row_index) - row.update( - { - name: value[offset : offset + self.leading_axis_row_size] - for name, value in block.items() - } + if not source_ranges: + return + attrs = self._read_attrs(group) + batch_size = self.row_batch_size or len(source_ranges) + for batch_offset in range(0, len(source_ranges), batch_size): + batch = source_ranges[batch_offset : batch_offset + batch_size] + block_start = batch[0][0] + block_end = batch[-1][1] + block = self._read_arrays(arrays, start=block_start, end=block_end) + for offset, (start, end) in enumerate(batch, start=batch_offset): + row = self._row_metadata(index=descriptor.start + offset) + row.update( + { + name: value[start - block_start : end - block_start] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) + return + + if self.split_leading_axis: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) + attrs = self._read_attrs(group) + batch_size = self.row_batch_size or descriptor.end - descriptor.start + for batch_start in range(descriptor.start, descriptor.end, batch_size): + batch_end = min(batch_start + batch_size, descriptor.end) + raw_start = batch_start * self.leading_axis_row_size + raw_end = batch_end * self.leading_axis_row_size + block = self._read_arrays( + arrays, + start=raw_start, + end=raw_end, ) - row.update(attrs) - yield DictRow(row) - return + for row_index in range(batch_start, batch_end): + offset = (row_index - batch_start) * self.leading_axis_row_size + row = self._row_metadata(index=row_index) + row.update( + { + name: value[ + offset : offset + self.leading_axis_row_size + ] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) + return - row = self._row_metadata(index=None) - row.update(self._read_arrays(arrays)) - row.update(self._read_attrs(group)) - yield DictRow(row) + row = self._row_metadata(index=None) + row.update(self._read_arrays(arrays)) + row.update(self._read_attrs(group)) + yield DictRow(row) + @contextmanager def _open_group(self) -> Any: import zarr import zarr.storage if self.zip_file is not None: + handle = None + zip_fs = None if self.zip_file.is_local: store = zarr.ZipStore(self.zip_file.abs_path(), mode="r") else: + handle = self.zip_file.open("rb", cache_type="none") + zip_fs = ZipFileSystem(fo=handle, mode="r") store = zarr.storage.FSStore( "/", - fs=ZipFileSystem( - fo=self.zip_file.open("rb", cache_type="none"), mode="r" - ), + fs=zip_fs, mode="r", ) - return zarr.open_group(store=store, mode="r") + try: + yield zarr.open_group(store=store, mode="r") + finally: + close_store = getattr(store, "close", None) + if callable(close_store): + close_store() + if zip_fs is not None: + zip_fs.close() + if handle is not None: + handle.close() + return assert self.root is not None store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") - return zarr.open_group(store=store, mode="r") + yield zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: names = set() @@ -397,6 +415,7 @@ def _shard_ranges( row_ends_array = self._row_ends_array(group) row_count = int(row_ends_array.shape[0]) if row_count == 0: + _check_final_end(arrays, 0, label="row_ends", exact=True) return [] if self.num_shards is not None or not arrays: final_end = _validate_row_ends(row_ends_array) diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 1e542db6..535be8c5 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -10,6 +10,7 @@ import zarr import refiner as mdr +from refiner.io.datafolder import DataFolder from refiner.robotics.row import RoboticsRow from refiner.pipeline.data.row import Row from refiner.pipeline.data.shard import RowRangeDescriptor @@ -136,6 +137,21 @@ def test_read_zarr_reads_zip_store(tmp_path: Path) -> None: assert row["frames"].shape == (2, 4, 4, 3) +def test_read_zarr_reads_zip_datafolder(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + row = mdr.read_zarr( + DataFolder(str(zip_path)), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["file_path"] == str(zip_path) + assert row["action"].shape == (2, 1) + + def test_read_zarr_reads_remote_zip_without_cache( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, @@ -150,11 +166,23 @@ def test_read_zarr_reads_remote_zip_without_cache( shutil.copyfileobj(src, dst) open_calls: list[dict[str, Any]] = [] + opened_files: list[int] = [] + closed_files: list[int] = [] original_open = fs.open def record_open(path, mode="rb", **kwargs): if path == remote_path and mode == "rb": open_calls.append(kwargs) + file = original_open(path, mode=mode, **kwargs) + opened_files.append(id(file)) + original_close = file.close + + def record_close(): + closed_files.append(id(file)) + original_close() + + file.close = record_close + return file return original_open(path, mode=mode, **kwargs) monkeypatch.setattr(fs, "open", record_open) @@ -168,6 +196,7 @@ def record_open(path, mode="rb", **kwargs): assert row["action"].shape == (2, 1) assert open_calls assert open_calls[0]["cache_type"] == "none" + assert set(closed_files) == set(opened_files) def test_read_zarr_plans_row_ends_with_num_shards(tmp_path: Path) -> None: @@ -538,6 +567,21 @@ def test_read_zarr_rejects_short_row_ends(tmp_path: Path) -> None: ).take(2) +def test_read_zarr_rejects_empty_row_ends_for_nonempty_arrays(tmp_path: Path) -> None: + path = tmp_path / "empty-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((2, 1), dtype=np.float32)) + _create_array(root, "meta/episode_ends", data=np.asarray([], dtype=np.int64)) + + with pytest.raises(ValueError, match="end before leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + def test_read_zarr_rejects_scalar_arrays_in_row_ends_mode(tmp_path: Path) -> None: path = tmp_path / "scalar-row-ends.zarr" root = _open_test_zarr(path, mode="w") From c091a0ea0c93dfbcdf4df0f9e44350bd6e8346ee Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:07:51 +0200 Subject: [PATCH 34/39] Tighten Zarr zip input handling --- src/refiner/pipeline/sources/readers/zarr.py | 37 ++++++++++---------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 6fafe2b3..ad1ce18b 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -79,20 +79,21 @@ def __init__( zip_input: DataFileLike | None = None if isinstance(input, DataFolder) and input.path.endswith(".zip"): zip_input = (input.path, input.fs) - elif isinstance(input, PathLike): - input = str(input) - if isinstance(input, str) and input.endswith(".zip"): - zip_input = input - elif ( - isinstance(input, tuple) - and len(input) == 2 - and isinstance(input[1], AbstractFileSystem) - ): - path = input[0] - if isinstance(path, PathLike): - path = str(path) - if isinstance(path, str) and path.endswith(".zip"): - zip_input = (path, input[1]) + else: + if isinstance(input, PathLike): + input = str(input) + if isinstance(input, str) and input.endswith(".zip"): + zip_input = input + elif ( + isinstance(input, tuple) + and len(input) == 2 + and isinstance(input[1], AbstractFileSystem) + ): + path = input[0] + if isinstance(path, PathLike): + path = str(path) + if isinstance(path, str) and path.endswith(".zip"): + zip_input = (path, input[1]) if zip_input is not None: self.zip_file = DataFile.resolve(zip_input) @@ -278,10 +279,10 @@ def _open_group(self) -> Any: zip_fs.close() if handle is not None: handle.close() - return - assert self.root is not None - store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") - yield zarr.open_group(store=store, mode="r") + else: + assert self.root is not None + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + yield zarr.open_group(store=store, mode="r") def _reserved_output_names(self, *, split: bool) -> set[str]: names = set() From 0103d155389dadc1889bbd71167cc6225727dd89 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:11:54 +0200 Subject: [PATCH 35/39] Reject row_ends output selection --- src/refiner/pipeline/sources/readers/zarr.py | 8 +++++++- tests/readers/test_zarr_reader.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index ad1ce18b..91787107 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -122,6 +122,12 @@ def __init__( self.attrs = ( None if attrs is None else path_selection_map(attrs, format_name="Zarr") ) + if row_ends is not None and self.arrays is not None: + for output_name, path in self.arrays.items(): + if path == row_ends: + raise ValueError( + f"Zarr array selection {output_name!r} cannot also be row_ends" + ) self.row_ends = row_ends self.split_leading_axis = split_leading_axis self.leading_axis_row_size = leading_axis_row_size @@ -566,4 +572,4 @@ def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: yield from _iter_array_paths(item, path) -__all__ = ["PathSelection", "ZarrReader"] +__all__ = ["ZarrReader"] diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 535be8c5..93b87297 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -277,6 +277,19 @@ def test_read_zarr_rejects_reserved_index_output_name(tmp_path: Path) -> None: ) +def test_read_zarr_rejects_selecting_row_ends_as_output_array(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="cannot also be row_ends"): + mdr.read_zarr( + path, + arrays={"episode_ends": "meta/episode_ends"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + + def test_read_zarr_rejects_duplicate_metadata_column_names(tmp_path: Path) -> None: path = tmp_path / "policy.zarr" _write_policy_zarr(path) From 37b78fe1199616567644c34874afa4c6367ecf8e Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:38:07 +0200 Subject: [PATCH 36/39] Support Zarr 3 reader installs --- pyproject.toml | 5 +- src/refiner/pipeline/sources/readers/zarr.py | 65 ++++++- tests/readers/test_zarr_reader.py | 24 +++ uv.lock | 174 +++++++++++++++++-- 4 files changed, 252 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4001aa0c..e2f46440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,9 @@ hdf5 = [ "h5py", ] zarr = [ - "zarr>=2.18,<3", - "numcodecs<0.16", + "zarr>=2.18,<3; python_version < '3.11'", + "zarr>=3.1,<4; python_version >= '3.11'", + "numcodecs<0.16; python_version < '3.11'", ] s3 = [ "s3fs", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 91787107..f2e0b477 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping from contextlib import contextmanager +from importlib import import_module from math import ceil, prod from operator import index as integer_index from os import PathLike @@ -262,11 +263,73 @@ def _open_group(self) -> Any: import zarr import zarr.storage + if not hasattr(zarr.storage, "FSStore"): + if self.zip_file is not None: + if self.zip_file.is_local: + store = zarr.storage.ZipStore(self.zip_file.abs_path(), mode="r") + try: + yield zarr.open_group(store=store, mode="r") + finally: + store.close() + return + + handle = self.zip_file.open("rb", cache_type="none") + zip_fs = ZipFileSystem(fo=handle, mode="r") + try: + make_async = getattr( + import_module("zarr.storage._fsspec"), "_make_async" + ) + store = zarr.storage.FsspecStore( + fs=make_async(zip_fs), + path="", + read_only=True, + allowed_exceptions=( + FileNotFoundError, + IsADirectoryError, + NotADirectoryError, + KeyError, + ), + ) + try: + open_kwargs: dict[str, Any] = { + "store": store, + "mode": "r", + "zarr_format": 2, + "use_consolidated": False, + } + yield zarr.open_group(**open_kwargs) + finally: + store.close() + finally: + zip_fs.close() + handle.close() + return + + assert self.root is not None + make_async = getattr(import_module("zarr.storage._fsspec"), "_make_async") + + protocol = self.root.fs.protocol + if protocol == "file" or ( + not isinstance(protocol, str) and "file" in protocol + ): + store = zarr.storage.LocalStore(self.root._join(""), read_only=True) + else: + store = zarr.storage.FsspecStore( + fs=make_async(self.root.fs), + path=self.root._join(""), + read_only=True, + ) + try: + yield zarr.open_group(store=store, mode="r") + finally: + store.close() + return + if self.zip_file is not None: handle = None zip_fs = None if self.zip_file.is_local: - store = zarr.ZipStore(self.zip_file.abs_path(), mode="r") + store = getattr(zarr, "ZipStore")(self.zip_file.abs_path(), mode="r") else: handle = self.zip_file.open("rb", cache_type="none") zip_fs = ZipFileSystem(fo=handle, mode="r") diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py index 93b87297..3948983b 100644 --- a/tests/readers/test_zarr_reader.py +++ b/tests/readers/test_zarr_reader.py @@ -152,6 +152,30 @@ def test_read_zarr_reads_zip_datafolder(tmp_path: Path) -> None: assert row["action"].shape == (2, 1) +def test_read_zarr_reads_remote_store(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + fs = MemoryFileSystem() + remote_root = "/policy.zarr" + for source_path in path.rglob("*"): + if source_path.is_file(): + relative_path = source_path.relative_to(path) + remote_path = f"{remote_root}/{relative_path}" + fs.makedirs(str(Path(remote_path).parent), exist_ok=True) + with source_path.open("rb") as src, fs.open(remote_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + row = mdr.read_zarr( + (remote_root, fs), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["file_path"] == "memory:///policy.zarr" + assert row["action"].shape == (2, 1) + + def test_read_zarr_reads_remote_zip_without_cache( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/uv.lock b/uv.lock index d844aa09..34a5ca20 100644 --- a/uv.lock +++ b/uv.lock @@ -618,6 +618,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] +[[package]] +name = "donfig" +version = "0.8.1.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -783,6 +795,41 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "google-crc32c" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, + { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, + { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, + { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, + { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, + { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, + { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, + { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, + { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, + { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, + { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, + { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, + { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, + { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, + { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, + { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, + { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, + { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -1015,12 +1062,14 @@ all = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] hdf5 = [ { name = "h5py" }, @@ -1042,12 +1091,14 @@ testing = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] text = [ { name = "warcio" }, @@ -1056,8 +1107,10 @@ video = [ { name = "av" }, ] zarr = [ - { name = "numcodecs" }, - { name = "zarr" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, ] [package.dev-dependencies] @@ -1088,7 +1141,7 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, - { name = "numcodecs", marker = "extra == 'zarr'", specifier = "<0.16" }, + { name = "numcodecs", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = "<0.16" }, { name = "numpy" }, { name = "orjson" }, { name = "pyarrow" }, @@ -1096,7 +1149,8 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, - { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, + { name = "zarr", marker = "python_full_version >= '3.11' and extra == 'zarr'", specifier = ">=3.1,<4" }, + { name = "zarr", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = ">=2.18,<3" }, ] provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] @@ -1359,9 +1413,11 @@ wheels = [ name = "numcodecs" version = "0.13.1" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215, upload-time = "2024-10-09T16:28:00.188Z" } wheels = [ @@ -1383,6 +1439,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887, upload-time = "2024-10-09T16:27:55.039Z" }, ] +[[package]] +name = "numcodecs" +version = "0.16.5" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/44/bd/8a391e7c356366224734efd24da929cc4796fff468bfb179fe1af6548535/numcodecs-0.16.5.tar.gz", hash = "sha256:0d0fb60852f84c0bd9543cc4d2ab9eefd37fc8efcc410acd4777e62a1d300318", size = 6276387, upload-time = "2025-11-21T02:49:48.986Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/85/1ac101a40ead81eaa1c7dc49a8827a30e2e436211b43ebdc63c590eb1347/numcodecs-0.16.5-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:78382dcea50622f2ef1e6e7a71dbe7f861d8fe376b27b7c297c26907304fef1e", size = 1621795, upload-time = "2025-11-21T02:49:17.418Z" }, + { url = "https://files.pythonhosted.org/packages/0e/cc/0d97ef55dda48cb0f93d7b92d761208e7a99bd2eea6b0e859426e6a99a21/numcodecs-0.16.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2d04a19cb57a3c519b4127ac377cca6471aee1990d7c18f5b1e3a4fe1306689", size = 1153030, upload-time = "2025-11-21T02:49:19.089Z" }, + { url = "https://files.pythonhosted.org/packages/5e/41/e120ee1b390730ac5987cde2afd82e2b8442cec315ab40b94b0373e93e73/numcodecs-0.16.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c043af648eb280cd61785c99c22ff5c3c3460f906eb51a8511327c4f5111b283", size = 8510503, upload-time = "2025-11-21T02:49:20.324Z" }, + { url = "https://files.pythonhosted.org/packages/54/4b/195ac84cc8f6077b4f0f421e8daee21b7f1bd88cb7716414234379fe68ec/numcodecs-0.16.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c398919ef2eb0e56b8e97456f622640bfd3deed06de3acc976989cbcb22628a3", size = 9123428, upload-time = "2025-11-21T02:49:22.328Z" }, + { url = "https://files.pythonhosted.org/packages/0f/5b/af02c417954f46e5c7bd5163ac251f535877d909fce54861c99ae197f6f6/numcodecs-0.16.5-cp311-cp311-win_amd64.whl", hash = "sha256:3820860ed302d4d84a1c66e70981ff959d5eb712555be4e7d8ced49888594773", size = 801542, upload-time = "2025-11-21T02:49:24.265Z" }, + { url = "https://files.pythonhosted.org/packages/75/cc/55420f3641a67f78392dc0bc5d02cb9eb0a9dcebf2848d1ac77253ca61fa/numcodecs-0.16.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:24e675dc8d1550cd976a99479b87d872cb142632c75cc402fea04c08c4898523", size = 1656287, upload-time = "2025-11-21T02:49:25.755Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:94ddfa4341d1a3ab99989d13b01b5134abb687d3dab2ead54b450aefe4ad5bd6", size = 1148899, upload-time = "2025-11-21T02:49:26.87Z" }, + { url = "https://files.pythonhosted.org/packages/97/1e/98aaddf272552d9fef1f0296a9939d1487914a239e98678f6b20f8b0a5c8/numcodecs-0.16.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b554ab9ecf69de7ca2b6b5e8bc696bd9747559cb4dd5127bd08d7a28bec59c3a", size = 8534814, upload-time = "2025-11-21T02:49:28.547Z" }, + { url = "https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad1a379a45bd3491deab8ae6548313946744f868c21d5340116977ea3be5b1d6", size = 9173471, upload-time = "2025-11-21T02:49:30.444Z" }, + { url = "https://files.pythonhosted.org/packages/1c/20/2fdec87fc7f8cec950d2b0bea603c12dc9f05b4966dc5924ba5a36a61bf6/numcodecs-0.16.5-cp312-cp312-win_amd64.whl", hash = "sha256:845a9857886ffe4a3172ba1c537ae5bcc01e65068c31cf1fce1a844bd1da050f", size = 801412, upload-time = "2025-11-21T02:49:32.123Z" }, + { url = "https://files.pythonhosted.org/packages/38/38/071ced5a5fd1c85ba0e14ba721b66b053823e5176298c2f707e50bed11d9/numcodecs-0.16.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25be3a516ab677dad890760d357cfe081a371d9c0a2e9a204562318ac5969de3", size = 1654359, upload-time = "2025-11-21T02:49:33.673Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c0/5f84ba7525577c1b9909fc2d06ef11314825fc4ad4378f61d0e4c9883b4a/numcodecs-0.16.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0107e839ef75b854e969cb577e140b1aadb9847893937636582d23a2a4c6ce50", size = 1144237, upload-time = "2025-11-21T02:49:35.294Z" }, + { url = "https://files.pythonhosted.org/packages/0b/00/787ea5f237b8ea7bc67140c99155f9c00b5baf11c49afc5f3bfefa298f95/numcodecs-0.16.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:015a7c859ecc2a06e2a548f64008c0ec3aaecabc26456c2c62f4278d8fc20597", size = 8483064, upload-time = "2025-11-21T02:49:36.454Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:84230b4b9dad2392f2a84242bd6e3e659ac137b5a1ce3571d6965fca673e0903", size = 9126063, upload-time = "2025-11-21T02:49:38.018Z" }, + { url = "https://files.pythonhosted.org/packages/27/72/6663cc0382ddbb866136c255c837bcb96cc7ce5e83562efec55e1b995941/numcodecs-0.16.5-cp313-cp313-win_amd64.whl", hash = "sha256:5088145502ad1ebf677ec47d00eb6f0fd600658217db3e0c070c321c85d6cf3d", size = 799275, upload-time = "2025-11-21T02:49:39.558Z" }, + { url = "https://files.pythonhosted.org/packages/3c/9e/38e7ca8184c958b51f45d56a4aeceb1134ecde2d8bd157efadc98502cc42/numcodecs-0.16.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b05647b8b769e6bc8016e9fd4843c823ce5c9f2337c089fb5c9c4da05e5275de", size = 1654721, upload-time = "2025-11-21T02:49:40.602Z" }, + { url = "https://files.pythonhosted.org/packages/a1/37/260fa42e7b2b08e6e00ad632f8dd620961a60a459426c26cea390f8c68d0/numcodecs-0.16.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3832bd1b5af8bb3e413076b7d93318c8e7d7b68935006b9fa36ca057d1725a8f", size = 1146887, upload-time = "2025-11-21T02:49:41.721Z" }, + { url = "https://files.pythonhosted.org/packages/4e/15/e2e1151b5a8b14a15dfd4bb4abccce7fff7580f39bc34092780088835f3a/numcodecs-0.16.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f7b7d24f103187f53135bed28bb9f0ed6b2e14c604664726487bb6d7c882e1", size = 8476987, upload-time = "2025-11-21T02:49:43.363Z" }, + { url = "https://files.pythonhosted.org/packages/6d/30/16a57fc4d9fb0ba06c600408bd6634f2f1753c54a7a351c99c5e09b51ee2/numcodecs-0.16.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aec9736d81b70f337d89c4070ee3ffeff113f386fd789492fa152d26a15043e4", size = 9102377, upload-time = "2025-11-21T02:49:45.508Z" }, + { url = "https://files.pythonhosted.org/packages/31/a5/a0425af36c20d55a3ea884db4b4efca25a43bea9214ba69ca7932dd997b4/numcodecs-0.16.5-cp314-cp314-win_amd64.whl", hash = "sha256:b16a14303800e9fb88abc39463ab4706c037647ac17e49e297faa5f7d7dbbf1d", size = 819022, upload-time = "2025-11-21T02:49:47.39Z" }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -2748,14 +2847,63 @@ wheels = [ name = "zarr" version = "2.18.3" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] dependencies = [ - { name = "asciitree" }, - { name = "fasteners", marker = "sys_platform != 'emscripten'" }, - { name = "numcodecs" }, + { name = "asciitree", marker = "python_full_version < '3.11'" }, + { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, + { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, ] + +[[package]] +name = "zarr" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version == '3.11.*'" }, + { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, + { name = "packaging", marker = "python_full_version == '3.11.*'" }, + { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, +] + +[[package]] +name = "zarr" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "donfig", marker = "python_full_version >= '3.12'" }, + { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, + { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "packaging", marker = "python_full_version >= '3.12'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, +] From 535a74f0f70849a0f7ece96318b97380c83ec958 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:45:22 +0200 Subject: [PATCH 37/39] Drop Zarr 3 compatibility path --- pyproject.toml | 5 +- src/refiner/pipeline/sources/readers/zarr.py | 65 +------ uv.lock | 174 ++----------------- 3 files changed, 16 insertions(+), 228 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2f46440..4001aa0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,8 @@ hdf5 = [ "h5py", ] zarr = [ - "zarr>=2.18,<3; python_version < '3.11'", - "zarr>=3.1,<4; python_version >= '3.11'", - "numcodecs<0.16; python_version < '3.11'", + "zarr>=2.18,<3", + "numcodecs<0.16", ] s3 = [ "s3fs", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index f2e0b477..91787107 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -2,7 +2,6 @@ from collections.abc import Iterator, Mapping from contextlib import contextmanager -from importlib import import_module from math import ceil, prod from operator import index as integer_index from os import PathLike @@ -263,73 +262,11 @@ def _open_group(self) -> Any: import zarr import zarr.storage - if not hasattr(zarr.storage, "FSStore"): - if self.zip_file is not None: - if self.zip_file.is_local: - store = zarr.storage.ZipStore(self.zip_file.abs_path(), mode="r") - try: - yield zarr.open_group(store=store, mode="r") - finally: - store.close() - return - - handle = self.zip_file.open("rb", cache_type="none") - zip_fs = ZipFileSystem(fo=handle, mode="r") - try: - make_async = getattr( - import_module("zarr.storage._fsspec"), "_make_async" - ) - store = zarr.storage.FsspecStore( - fs=make_async(zip_fs), - path="", - read_only=True, - allowed_exceptions=( - FileNotFoundError, - IsADirectoryError, - NotADirectoryError, - KeyError, - ), - ) - try: - open_kwargs: dict[str, Any] = { - "store": store, - "mode": "r", - "zarr_format": 2, - "use_consolidated": False, - } - yield zarr.open_group(**open_kwargs) - finally: - store.close() - finally: - zip_fs.close() - handle.close() - return - - assert self.root is not None - make_async = getattr(import_module("zarr.storage._fsspec"), "_make_async") - - protocol = self.root.fs.protocol - if protocol == "file" or ( - not isinstance(protocol, str) and "file" in protocol - ): - store = zarr.storage.LocalStore(self.root._join(""), read_only=True) - else: - store = zarr.storage.FsspecStore( - fs=make_async(self.root.fs), - path=self.root._join(""), - read_only=True, - ) - try: - yield zarr.open_group(store=store, mode="r") - finally: - store.close() - return - if self.zip_file is not None: handle = None zip_fs = None if self.zip_file.is_local: - store = getattr(zarr, "ZipStore")(self.zip_file.abs_path(), mode="r") + store = zarr.ZipStore(self.zip_file.abs_path(), mode="r") else: handle = self.zip_file.open("rb", cache_type="none") zip_fs = ZipFileSystem(fo=handle, mode="r") diff --git a/uv.lock b/uv.lock index 34a5ca20..d844aa09 100644 --- a/uv.lock +++ b/uv.lock @@ -618,18 +618,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] -[[package]] -name = "donfig" -version = "0.8.1.post1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyyaml", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/25/71/80cc718ff6d7abfbabacb1f57aaa42e9c1552bfdd01e64ddd704e4a03638/donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52", size = 19506, upload-time = "2024-05-23T14:14:31.513Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, -] - [[package]] name = "exceptiongroup" version = "1.3.1" @@ -795,41 +783,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "google-crc32c" -version = "1.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/ac/6f7bc93886a823ab545948c2dd48143027b2355ad1944c7cf852b338dc91/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff", size = 31296, upload-time = "2025-12-16T00:19:07.261Z" }, - { url = "https://files.pythonhosted.org/packages/f7/97/a5accde175dee985311d949cfcb1249dcbb290f5ec83c994ea733311948f/google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288", size = 30870, upload-time = "2025-12-16T00:29:17.669Z" }, - { url = "https://files.pythonhosted.org/packages/3d/63/bec827e70b7a0d4094e7476f863c0dbd6b5f0f1f91d9c9b32b76dcdfeb4e/google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d", size = 33214, upload-time = "2025-12-16T00:40:19.618Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/11b70614df04c289128d782efc084b9035ef8466b3d0a8757c1b6f5cf7ac/google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092", size = 33589, upload-time = "2025-12-16T00:40:20.7Z" }, - { url = "https://files.pythonhosted.org/packages/3e/00/a08a4bc24f1261cc5b0f47312d8aebfbe4b53c2e6307f1b595605eed246b/google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733", size = 34437, upload-time = "2025-12-16T00:35:19.437Z" }, - { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, - { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, - { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, - { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, - { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, - { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, - { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, - { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, - { url = "https://files.pythonhosted.org/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b", size = 31297, upload-time = "2025-12-16T00:23:20.709Z" }, - { url = "https://files.pythonhosted.org/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27", size = 30867, upload-time = "2025-12-16T00:43:14.628Z" }, - { url = "https://files.pythonhosted.org/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa", size = 33344, upload-time = "2025-12-16T00:40:24.742Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e8/b33784d6fc77fb5062a8a7854e43e1e618b87d5ddf610a88025e4de6226e/google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8", size = 33694, upload-time = "2025-12-16T00:40:25.505Z" }, - { url = "https://files.pythonhosted.org/packages/92/b1/d3cbd4d988afb3d8e4db94ca953df429ed6db7282ed0e700d25e6c7bfc8d/google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f", size = 34435, upload-time = "2025-12-16T00:35:22.107Z" }, - { url = "https://files.pythonhosted.org/packages/21/88/8ecf3c2b864a490b9e7010c84fd203ec8cf3b280651106a3a74dd1b0ca72/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697", size = 31301, upload-time = "2025-12-16T00:24:48.527Z" }, - { url = "https://files.pythonhosted.org/packages/36/c6/f7ff6c11f5ca215d9f43d3629163727a272eabc356e5c9b2853df2bfe965/google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651", size = 30868, upload-time = "2025-12-16T00:48:12.163Z" }, - { url = "https://files.pythonhosted.org/packages/56/15/c25671c7aad70f8179d858c55a6ae8404902abe0cdcf32a29d581792b491/google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2", size = 33381, upload-time = "2025-12-16T00:40:26.268Z" }, - { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, - { url = "https://files.pythonhosted.org/packages/08/a5/7b059810934a09fb3ccb657e0843813c1fee1183d3bc2c8041800374aa2c/google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2", size = 34878, upload-time = "2025-12-16T00:35:23.142Z" }, - { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, - { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, -] - [[package]] name = "h11" version = "0.16.0" @@ -1062,14 +1015,12 @@ all = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] hdf5 = [ { name = "h5py" }, @@ -1091,14 +1042,12 @@ testing = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "zarr" }, ] text = [ { name = "warcio" }, @@ -1107,10 +1056,8 @@ video = [ { name = "av" }, ] zarr = [ - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "zarr", version = "3.1.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "zarr", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "numcodecs" }, + { name = "zarr" }, ] [package.dev-dependencies] @@ -1141,7 +1088,7 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, - { name = "numcodecs", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = "<0.16" }, + { name = "numcodecs", marker = "extra == 'zarr'", specifier = "<0.16" }, { name = "numpy" }, { name = "orjson" }, { name = "pyarrow" }, @@ -1149,8 +1096,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, - { name = "zarr", marker = "python_full_version >= '3.11' and extra == 'zarr'", specifier = ">=3.1,<4" }, - { name = "zarr", marker = "python_full_version < '3.11' and extra == 'zarr'", specifier = ">=2.18,<3" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, ] provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] @@ -1413,11 +1359,9 @@ wheels = [ name = "numcodecs" version = "0.13.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215, upload-time = "2024-10-09T16:28:00.188Z" } wheels = [ @@ -1439,49 +1383,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887, upload-time = "2024-10-09T16:27:55.039Z" }, ] -[[package]] -name = "numcodecs" -version = "0.16.5" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/44/bd/8a391e7c356366224734efd24da929cc4796fff468bfb179fe1af6548535/numcodecs-0.16.5.tar.gz", hash = "sha256:0d0fb60852f84c0bd9543cc4d2ab9eefd37fc8efcc410acd4777e62a1d300318", size = 6276387, upload-time = "2025-11-21T02:49:48.986Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/af/85/1ac101a40ead81eaa1c7dc49a8827a30e2e436211b43ebdc63c590eb1347/numcodecs-0.16.5-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:78382dcea50622f2ef1e6e7a71dbe7f861d8fe376b27b7c297c26907304fef1e", size = 1621795, upload-time = "2025-11-21T02:49:17.418Z" }, - { url = "https://files.pythonhosted.org/packages/0e/cc/0d97ef55dda48cb0f93d7b92d761208e7a99bd2eea6b0e859426e6a99a21/numcodecs-0.16.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2d04a19cb57a3c519b4127ac377cca6471aee1990d7c18f5b1e3a4fe1306689", size = 1153030, upload-time = "2025-11-21T02:49:19.089Z" }, - { url = "https://files.pythonhosted.org/packages/5e/41/e120ee1b390730ac5987cde2afd82e2b8442cec315ab40b94b0373e93e73/numcodecs-0.16.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c043af648eb280cd61785c99c22ff5c3c3460f906eb51a8511327c4f5111b283", size = 8510503, upload-time = "2025-11-21T02:49:20.324Z" }, - { url = "https://files.pythonhosted.org/packages/54/4b/195ac84cc8f6077b4f0f421e8daee21b7f1bd88cb7716414234379fe68ec/numcodecs-0.16.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c398919ef2eb0e56b8e97456f622640bfd3deed06de3acc976989cbcb22628a3", size = 9123428, upload-time = "2025-11-21T02:49:22.328Z" }, - { url = "https://files.pythonhosted.org/packages/0f/5b/af02c417954f46e5c7bd5163ac251f535877d909fce54861c99ae197f6f6/numcodecs-0.16.5-cp311-cp311-win_amd64.whl", hash = "sha256:3820860ed302d4d84a1c66e70981ff959d5eb712555be4e7d8ced49888594773", size = 801542, upload-time = "2025-11-21T02:49:24.265Z" }, - { url = "https://files.pythonhosted.org/packages/75/cc/55420f3641a67f78392dc0bc5d02cb9eb0a9dcebf2848d1ac77253ca61fa/numcodecs-0.16.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:24e675dc8d1550cd976a99479b87d872cb142632c75cc402fea04c08c4898523", size = 1656287, upload-time = "2025-11-21T02:49:25.755Z" }, - { url = "https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:94ddfa4341d1a3ab99989d13b01b5134abb687d3dab2ead54b450aefe4ad5bd6", size = 1148899, upload-time = "2025-11-21T02:49:26.87Z" }, - { url = "https://files.pythonhosted.org/packages/97/1e/98aaddf272552d9fef1f0296a9939d1487914a239e98678f6b20f8b0a5c8/numcodecs-0.16.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b554ab9ecf69de7ca2b6b5e8bc696bd9747559cb4dd5127bd08d7a28bec59c3a", size = 8534814, upload-time = "2025-11-21T02:49:28.547Z" }, - { url = "https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad1a379a45bd3491deab8ae6548313946744f868c21d5340116977ea3be5b1d6", size = 9173471, upload-time = "2025-11-21T02:49:30.444Z" }, - { url = "https://files.pythonhosted.org/packages/1c/20/2fdec87fc7f8cec950d2b0bea603c12dc9f05b4966dc5924ba5a36a61bf6/numcodecs-0.16.5-cp312-cp312-win_amd64.whl", hash = "sha256:845a9857886ffe4a3172ba1c537ae5bcc01e65068c31cf1fce1a844bd1da050f", size = 801412, upload-time = "2025-11-21T02:49:32.123Z" }, - { url = "https://files.pythonhosted.org/packages/38/38/071ced5a5fd1c85ba0e14ba721b66b053823e5176298c2f707e50bed11d9/numcodecs-0.16.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25be3a516ab677dad890760d357cfe081a371d9c0a2e9a204562318ac5969de3", size = 1654359, upload-time = "2025-11-21T02:49:33.673Z" }, - { url = "https://files.pythonhosted.org/packages/d1/c0/5f84ba7525577c1b9909fc2d06ef11314825fc4ad4378f61d0e4c9883b4a/numcodecs-0.16.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0107e839ef75b854e969cb577e140b1aadb9847893937636582d23a2a4c6ce50", size = 1144237, upload-time = "2025-11-21T02:49:35.294Z" }, - { url = "https://files.pythonhosted.org/packages/0b/00/787ea5f237b8ea7bc67140c99155f9c00b5baf11c49afc5f3bfefa298f95/numcodecs-0.16.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:015a7c859ecc2a06e2a548f64008c0ec3aaecabc26456c2c62f4278d8fc20597", size = 8483064, upload-time = "2025-11-21T02:49:36.454Z" }, - { url = "https://files.pythonhosted.org/packages/c4/e6/d359fdd37498e74d26a167f7a51e54542e642ea47181eb4e643a69a066c3/numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:84230b4b9dad2392f2a84242bd6e3e659ac137b5a1ce3571d6965fca673e0903", size = 9126063, upload-time = "2025-11-21T02:49:38.018Z" }, - { url = "https://files.pythonhosted.org/packages/27/72/6663cc0382ddbb866136c255c837bcb96cc7ce5e83562efec55e1b995941/numcodecs-0.16.5-cp313-cp313-win_amd64.whl", hash = "sha256:5088145502ad1ebf677ec47d00eb6f0fd600658217db3e0c070c321c85d6cf3d", size = 799275, upload-time = "2025-11-21T02:49:39.558Z" }, - { url = "https://files.pythonhosted.org/packages/3c/9e/38e7ca8184c958b51f45d56a4aeceb1134ecde2d8bd157efadc98502cc42/numcodecs-0.16.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b05647b8b769e6bc8016e9fd4843c823ce5c9f2337c089fb5c9c4da05e5275de", size = 1654721, upload-time = "2025-11-21T02:49:40.602Z" }, - { url = "https://files.pythonhosted.org/packages/a1/37/260fa42e7b2b08e6e00ad632f8dd620961a60a459426c26cea390f8c68d0/numcodecs-0.16.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3832bd1b5af8bb3e413076b7d93318c8e7d7b68935006b9fa36ca057d1725a8f", size = 1146887, upload-time = "2025-11-21T02:49:41.721Z" }, - { url = "https://files.pythonhosted.org/packages/4e/15/e2e1151b5a8b14a15dfd4bb4abccce7fff7580f39bc34092780088835f3a/numcodecs-0.16.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f7b7d24f103187f53135bed28bb9f0ed6b2e14c604664726487bb6d7c882e1", size = 8476987, upload-time = "2025-11-21T02:49:43.363Z" }, - { url = "https://files.pythonhosted.org/packages/6d/30/16a57fc4d9fb0ba06c600408bd6634f2f1753c54a7a351c99c5e09b51ee2/numcodecs-0.16.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aec9736d81b70f337d89c4070ee3ffeff113f386fd789492fa152d26a15043e4", size = 9102377, upload-time = "2025-11-21T02:49:45.508Z" }, - { url = "https://files.pythonhosted.org/packages/31/a5/a0425af36c20d55a3ea884db4b4efca25a43bea9214ba69ca7932dd997b4/numcodecs-0.16.5-cp314-cp314-win_amd64.whl", hash = "sha256:b16a14303800e9fb88abc39463ab4706c037647ac17e49e297faa5f7d7dbbf1d", size = 819022, upload-time = "2025-11-21T02:49:47.39Z" }, -] - [[package]] name = "numpy" version = "2.2.6" @@ -2847,63 +2748,14 @@ wheels = [ name = "zarr" version = "2.18.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11'", -] dependencies = [ - { name = "asciitree", marker = "python_full_version < '3.11'" }, - { name = "fasteners", marker = "python_full_version < '3.11' and sys_platform != 'emscripten'" }, - { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "asciitree" }, + { name = "fasteners", marker = "sys_platform != 'emscripten'" }, + { name = "numcodecs" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, ] - -[[package]] -name = "zarr" -version = "3.1.6" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.11.*' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and sys_platform == 'emscripten'", - "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version == '3.11.*'" }, - { name = "google-crc32c", marker = "python_full_version == '3.11.*'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, - { name = "packaging", marker = "python_full_version == '3.11.*'" }, - { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/31/5a/b8a0cf39a14c770c30bd1f2d120c54000c8cd9e84e8e79f38d9a7ce58071/zarr-3.1.6.tar.gz", hash = "sha256:d95e72cbea4b90e9a70679468b8266400331756232576ae2b43400ac5108d0eb", size = 386531, upload-time = "2026-03-23T17:25:18.748Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/7c/ba8ca8cbe9dbef8e83a95fc208fed8e6686c98b4719aaa0aa7f3d31fe390/zarr-3.1.6-py3-none-any.whl", hash = "sha256:b5a82c5079d1c3d4ee8f06746fa3b9a98a7d804300fa3f4be154362a33e1207e", size = 295655, upload-time = "2026-03-23T17:25:17.189Z" }, -] - -[[package]] -name = "zarr" -version = "3.2.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", -] -dependencies = [ - { name = "donfig", marker = "python_full_version >= '3.12'" }, - { name = "google-crc32c", marker = "python_full_version >= '3.12'" }, - { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, - { name = "packaging", marker = "python_full_version >= '3.12'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/93/8d/aeb164004f87543b06ef54f885d02c342c31ceb274e2bbec470a98927621/zarr-3.2.1.tar.gz", hash = "sha256:71565b738a0e7e8ed226f0516eba8c6bb53440ad7669a8c48ebb3534a161d035", size = 675161, upload-time = "2026-05-05T12:37:22.383Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/0a/469e2bd01be1490336e6c8707386845655d59261543315778a3ccc7e8019/zarr-3.2.1-py3-none-any.whl", hash = "sha256:f78cdd3d9687ad0e9f9cba2c5683b64f0c52589c19f685eeabe872e93cc0d2c7", size = 319617, upload-time = "2026-05-05T12:37:20.66Z" }, -] From 3c6ead23031b02dfdee65ee5fd790f53f9211643 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:48:47 +0200 Subject: [PATCH 38/39] Simplify Zarr zip docs --- docs/reading-and-writing.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index efa80ad8..35c160d1 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -237,9 +237,8 @@ uv add "macrodata-refiner[zarr]" ``` `read_zarr(...)` reads one Zarr group, including directory stores and -`.zarr.zip` stores. Local zip stores use Zarr's native zip support; remote zip -stores are mounted through fsspec with block caching disabled. By default, the -group becomes one output row and selected arrays are loaded as full array values. +`.zarr.zip` stores. By default, the group becomes one output row and selected +arrays are loaded as full array values. ```python import refiner as mdr From 10627ae98e70fa71ed2cf2275fd283b5b876daad Mon Sep 17 00:00:00 2001 From: guipenedo Date: Sun, 24 May 2026 00:51:34 +0200 Subject: [PATCH 39/39] Trim Zarr docs and cover DataFolder URL roots --- docs/reading-and-writing.md | 6 ------ tests/io/test_datafile_datafolder.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 35c160d1..22ff835e 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -326,12 +326,6 @@ By default, split readers load one shard block at a time and slice logical rows from that block. Set `row_batch_size` to cap how many logical rows are loaded per block when a shard would otherwise materialize too much data. -`row_ends` is reader control metadata, not an output selection. If you also want -the raw offsets as a column in non-split mode, select that path through `arrays`. - -Missing selected arrays or attrs always raise. Zarr selections describe group -schema, not row-local optional fields. - ## Common Crawl text readers [Common Crawl](https://commoncrawl.org/) publishes large public web crawls. diff --git a/tests/io/test_datafile_datafolder.py b/tests/io/test_datafile_datafolder.py index 4b0e4345..a33a82ea 100644 --- a/tests/io/test_datafile_datafolder.py +++ b/tests/io/test_datafile_datafolder.py @@ -203,6 +203,17 @@ def test_datafolder_resolve_with_path_fs_tuple(tmp_path): assert folder.exists("out.txt") +def test_datafolder_resolve_strips_protocol_root_from_url(): + fs = MemoryFileSystem() + fs.pipe("/bucket/root/file.txt", b"ok") + + folder = DataFolder("memory://bucket/root") + + assert folder.path == "/bucket/root" + assert folder.abs_path() == "memory:///bucket/root" + assert folder.open("file.txt").read() == b"ok" + + class _CountingMemoryFS(MemoryFileSystem): def __init__(self): super().__init__()