diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index d0d76fa5..22ff835e 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -228,6 +228,104 @@ datasets or attributes default to raising an error. Set If a selected column can be missing from every group in an input file, pass `dtypes` for that column so the reader can emit a stable Arrow type. +## Zarr + +Zarr support lives behind the optional `macrodata-refiner[zarr]` extra. + +```bash +uv add "macrodata-refiner[zarr]" +``` + +`read_zarr(...)` reads one Zarr group, including directory stores and +`.zarr.zip` stores. By default, the group becomes one output row and selected +arrays are loaded as full array values. + +```python +import refiner as mdr + +pipeline = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "state": "data/state", + }, + attrs={"task": "task"}, +) +``` + +`arrays` and `attrs` accept the same selection forms as HDF5: a mapping from +output column name to Zarr path, one path string, or a sequence of path strings +with unique final components. Use a mapping when derived names would collide. + +For robotics-style replay buffers, pass `row_ends` to split concatenated arrays +into logical rows, usually episodes: + +```python +episodes = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + index_column="episode_id", + file_path_column=None, +) +``` + +For a store shaped like: + +```text +replay_buffer.zarr +├── data +│ ├── action # shape [total_steps, action_dim] +│ ├── state # shape [total_steps, state_dim] +│ └── rgb # shape [total_steps, height, width, channels] +└── meta + └── episode_ends # cumulative end offsets, for example [152, 319, 477] +``` + +this emits one row per `[start:end]` slice. The selected arrays are sliced along +their leading dimension, while selected attrs are repeated on each row. +`index_column` receives the row/episode index when `row_ends` is set. Set it to +`None` to omit that metadata. The final row end must match the leading dimension +of every selected array. + +If a Zarr store has aligned arrays but no episode boundaries, use +`split_leading_axis=True` to emit fixed-size rows along the leading axis: + +```python +rows = mdr.read_zarr( + "replay_buffer.zarr", + arrays={ + "action": "data/action", + "frames": "data/rgb", + }, + split_leading_axis=True, + leading_axis_row_size=1, + target_shard_bytes=128 * 1024**2, +) +``` + +Shard planning in this mode uses chunk metadata from the selected array with the +largest per-row byte size (`dtype.itemsize * product(shape[1:])`). This keeps +large image/video arrays in control of shard boundaries, so tiny action/state +arrays stored as one huge chunk do not force Refiner to load a much larger image +block than necessary. + +This mode requires selected arrays to have the same leading dimension, and that +dimension must be divisible by `leading_axis_row_size`. Each output row contains +`leading_axis_row_size` contiguous items from every selected array. Refiner plans +shards from array metadata and tries to keep shard boundaries aligned with the +dominant array's leading-axis chunks. Use `num_shards` when you need a target +shard count instead of byte-sized packing. + +By default, split readers load one shard block at a time and slice logical rows +from that block. Set `row_batch_size` to cap how many logical rows are loaded per +block when a shard would otherwise materialize too much data. + ## Common Crawl text readers [Common Crawl](https://commoncrawl.org/) publishes large public web crawls. diff --git a/docs/robotics_conversion.md b/docs/robotics_conversion.md index b23ac141..84dfe3e0 100644 --- a/docs/robotics_conversion.md +++ b/docs/robotics_conversion.md @@ -77,3 +77,43 @@ pipeline = ( ) ) ``` + +## Zarr Replay Buffers + +Some robotics replay buffers are stored as unzipped Zarr directory stores with +frame-aligned arrays under `data/` and cumulative episode boundaries under +`meta/episode_ends`. + +Reference datasets: + +- RoboCasa MT4 N216: + `hf://datasets/ahad-j/robocasa_mt4_N216_zarr/mt4_N216.zarr` + (`https://huggingface.co/datasets/ahad-j/robocasa_mt4_N216_zarr`) +- MetaWorld MT4 N200: + `hf://datasets/runningkiwi/metaworld_mt4_n200_zarr` + (`https://huggingface.co/datasets/runningkiwi/metaworld_mt4_n200_zarr`) + +```python +import refiner as mdr + +pipeline = ( + mdr.read_zarr( + "hf://datasets/ahad-j/robocasa_mt4_N216_zarr/mt4_N216.zarr", + arrays={ + "action": "data/action", + "eef_pos": "data/robot0_eef_pos", + "joint_pos": "data/robot0_joint_pos", + "gripper_qpos": "data/robot0_gripper_qpos", + "wrist": "data/robot0_eye_in_hand_rgb", + }, + row_ends="meta/episode_ends", + index_column="episode_id", + ) + .to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=("eef_pos", "joint_pos", "gripper_qpos"), + video_keys=("wrist",), + ) +) +``` diff --git a/pyproject.toml b/pyproject.toml index b359e587..4001aa0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,10 @@ text = [ hdf5 = [ "h5py", ] +zarr = [ + "zarr>=2.18,<3", + "numcodecs<0.16", +] s3 = [ "s3fs", ] @@ -51,6 +55,7 @@ testing = [ "macrodata-refiner[huggingface]", "macrodata-refiner[hdf5]", "macrodata-refiner[robotics]", + "macrodata-refiner[zarr]", "macrodata-refiner[text]", "macrodata-refiner[s3]", "pytest>=8.0.0", diff --git a/src/refiner/__init__.py b/src/refiner/__init__.py index fd298b97..bd74ec19 100644 --- a/src/refiner/__init__.py +++ b/src/refiner/__init__.py @@ -20,6 +20,7 @@ read_lerobot, read_parquet, read_videos, + read_zarr, SUPPORTED_CUDA_VERSIONS, SUPPORTED_GPU_TYPES, task, @@ -54,6 +55,7 @@ "read_lerobot", "read_parquet", "read_videos", + "read_zarr", "from_items", "from_source", "task", diff --git a/src/refiner/io/datafolder.py b/src/refiner/io/datafolder.py index 501c3bd0..70e0ac4b 100644 --- a/src/refiner/io/datafolder.py +++ b/src/refiner/io/datafolder.py @@ -40,10 +40,14 @@ def __init__( auto_mkdir: if True, when opening a file in write mode its parent directories will be automatically created **storage_options: will be passed to a new fsspec filesystem object, when it is created. Ignored if fs is given """ - super().__init__( - path=path, - fs=fs if fs is not None else url_to_fs(path, **storage_options)[0], - ) + if fs is None: + fs, path = url_to_fs(path, **storage_options) + path = "/" if path is None else path + else: + path = fs._strip_protocol(path) + if path == "": + path = "/" + super().__init__(path=path, fs=fs) self.auto_mkdir = auto_mkdir @classmethod diff --git a/src/refiner/pipeline/__init__.py b/src/refiner/pipeline/__init__.py index de41089d..a37de383 100644 --- a/src/refiner/pipeline/__init__.py +++ b/src/refiner/pipeline/__init__.py @@ -13,6 +13,7 @@ read_lerobot, read_parquet, read_videos, + read_zarr, task, ) from refiner.pipeline.resources import ( @@ -41,6 +42,7 @@ "read_lerobot", "read_parquet", "read_videos", + "read_zarr", "from_items", "from_source", "task", diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 611b0f6f..6ace95d2 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -40,9 +40,10 @@ Hdf5Reader, JsonReader, ParquetReader, + ZarrReader, ) +from refiner.pipeline.sources.readers.hdf5 import MissingPolicy from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader -from refiner.pipeline.sources.readers.hdf5 import MissingPolicy, PathSelection from refiner.pipeline.sources.items import ItemsSource from refiner.pipeline.sources.task import TaskSource from refiner.pipeline.data import datatype @@ -59,7 +60,10 @@ ) from refiner.execution.operators.row import ShardDeltaFn from refiner.pipeline.sources.base import SourceUnit -from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + PathSelection, +) import pyarrow as pa if TYPE_CHECKING: @@ -176,7 +180,6 @@ def to_robot_rows( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, ) -> "RefinerPipeline": """Expose rows through the RoboticsRow semantic view. @@ -207,11 +210,8 @@ def to_robot_rows( schema=self.output_schema(), stats_key=stats_key, stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, ) - if episode_ends_key is None: - return self.map(cast(MapFn, converter)) - return self.flat_map(cast(FlatMapFn, converter)) + return self.map(cast(MapFn, converter)) def map_async( self, @@ -809,6 +809,51 @@ def read_hdf5( ) +def read_zarr( + input: DataFolderLike, + *, + arrays: PathSelection | None = None, + attrs: PathSelection | None = None, + row_ends: str | None = None, + split_leading_axis: bool = False, + leading_axis_row_size: int = 1, + target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, + num_shards: int | None = None, + row_batch_size: int | None = None, + index_column: str | None = "index", + file_path_column: str | None = "file_path", + dtypes: DTypeMapping | None = None, +) -> RefinerPipeline: + """Create a pipeline with a Zarr reader source. + + The reader has three modes: + - group mode: one Zarr group becomes one row + - row_ends mode: cumulative offsets define whole-row source slices + - split_leading_axis mode: fixed-size leading-axis slices define output rows + + Missing selected arrays or attributes raise immediately. `row_ends` and + `split_leading_axis` are mutually exclusive. `target_shard_bytes` and + `num_shards` affect shard planning, not logical row size. `row_batch_size` + bounds how many logical rows are loaded per array block within each shard. + """ + return RefinerPipeline( + source=ZarrReader( + input, + arrays=arrays, + attrs=attrs, + row_ends=row_ends, + split_leading_axis=split_leading_axis, + leading_axis_row_size=leading_axis_row_size, + target_shard_bytes=target_shard_bytes, + num_shards=num_shards, + row_batch_size=row_batch_size, + index_column=index_column, + file_path_column=file_path_column, + dtypes=dtypes, + ) + ) + + def read_parquet( inputs: DataFileSetLike, *, diff --git a/src/refiner/pipeline/sources/__init__.py b/src/refiner/pipeline/sources/__init__.py index 4faa9083..d9aa4ee6 100644 --- a/src/refiner/pipeline/sources/__init__.py +++ b/src/refiner/pipeline/sources/__init__.py @@ -8,6 +8,7 @@ JsonReader, LeRobotEpisodeReader, ParquetReader, + ZarrReader, ) __all__ = [ @@ -20,4 +21,5 @@ "JsonReader", "LeRobotEpisodeReader", "ParquetReader", + "ZarrReader", ] diff --git a/src/refiner/pipeline/sources/readers/__init__.py b/src/refiner/pipeline/sources/readers/__init__.py index 1304d351..00da8fe1 100644 --- a/src/refiner/pipeline/sources/readers/__init__.py +++ b/src/refiner/pipeline/sources/readers/__init__.py @@ -6,6 +6,7 @@ from refiner.pipeline.sources.readers.json import JsonReader from refiner.pipeline.sources.readers.lerobot import LeRobotEpisodeReader from refiner.pipeline.sources.readers.parquet import ParquetReader +from refiner.pipeline.sources.readers.zarr import ZarrReader from refiner.robotics.lerobot_format import LeRobotRow __all__ = [ @@ -18,4 +19,5 @@ "LeRobotEpisodeReader", "LeRobotRow", "ParquetReader", + "ZarrReader", ] diff --git a/src/refiner/pipeline/sources/readers/hdf5.py b/src/refiner/pipeline/sources/readers/hdf5.py index 0bbef12b..e311f691 100644 --- a/src/refiner/pipeline/sources/readers/hdf5.py +++ b/src/refiner/pipeline/sources/readers/hdf5.py @@ -3,7 +3,7 @@ from collections.abc import Iterator, Mapping, Sequence import fnmatch from glob import has_magic -from typing import Any, Literal, cast +from typing import Any, Literal from fsspec import AbstractFileSystem @@ -13,58 +13,16 @@ from refiner.pipeline.data.row import DictRow from refiner.pipeline.data.shard import FilePartsDescriptor from refiner.pipeline.sources.readers.base import BaseReader, Shard, SourceUnit -from refiner.pipeline.sources.readers.utils import DEFAULT_TARGET_SHARD_BYTES +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + PathSelection, + decode_value, + path_selection_map, +) from refiner.utils import check_required_dependencies MissingPolicy = Literal["error", "drop_row", "set_null"] -PathSelection = Mapping[str, str] | Sequence[str] | str - - -def _decode_value( - value: Any, - *, - decode_bytes: bool = True, - preserve_arrays: bool = False, -) -> Any: - if isinstance(value, bytes): - if not decode_bytes: - return value - try: - return value.decode("utf-8") - except UnicodeDecodeError: - return value - if isinstance(value, str) and any("\udc80" <= char <= "\udcff" for char in value): - return value.encode("utf-8", errors="surrogateescape") - if hasattr(value, "shape") and value.shape == (): - return _decode_value( - value.item(), - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - if hasattr(value, "tolist"): - if preserve_arrays and getattr( - getattr(value, "dtype", None), "kind", None - ) not in ( - "O", - "S", - ): - return value - return _decode_value( - value.tolist(), - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - if isinstance(value, list): - return [ - _decode_value( - item, - decode_bytes=decode_bytes, - preserve_arrays=preserve_arrays, - ) - for item in value - ] - return value class Hdf5Reader(BaseReader): @@ -126,8 +84,8 @@ def __init__( raise ValueError( "groups accepts a single glob string or a list of exact group paths" ) - self.datasets = self._mapping(datasets) - self.attrs = self._mapping(attrs) + self.datasets = path_selection_map(datasets, format_name="HDF5") + self.attrs = path_selection_map(attrs, format_name="HDF5") self.group_path_column = group_path_column self.missing_policy = missing_policy if missing_policy not in ("error", "drop_row", "set_null"): @@ -136,27 +94,6 @@ def __init__( ) self._validate_column_names() - @staticmethod - def _mapping( - value: PathSelection | None, - ) -> dict[str, str]: - if value is None: - return {} - if isinstance(value, str): - return {value.rsplit("/", 1)[-1]: value} - if isinstance(value, Mapping): - return dict(cast(Mapping[str, str], value)) - out: dict[str, str] = {} - for path in value: - name = path.rsplit("/", 1)[-1] - if name in out: - raise ValueError( - "HDF5 path selections must have unique derived column names; " - f"use an explicit mapping for duplicate name {name!r}" - ) - out[name] = path - return out - def describe(self) -> dict[str, Any]: description = super().describe() description.update( @@ -304,7 +241,7 @@ def _read_group( raise TypeError( f"HDF5 path under {group_path} is not a dataset: {dataset_path}" ) - row[output_name] = _decode_value( + row[output_name] = decode_value( dataset[()], decode_bytes=dataset.dtype.kind != "S", preserve_arrays=True, @@ -318,7 +255,7 @@ def _read_group( row[output_name] = None continue raise KeyError(f"HDF5 attr not found on {group_path}: {attr_name}") - row[output_name] = _decode_value(group.attrs[attr_name]) + row[output_name] = decode_value(group.attrs[attr_name]) return self._with_file_path(row, source) diff --git a/src/refiner/pipeline/sources/readers/utils.py b/src/refiner/pipeline/sources/readers/utils.py index 3860a9e5..9307082c 100644 --- a/src/refiner/pipeline/sources/readers/utils.py +++ b/src/refiner/pipeline/sources/readers/utils.py @@ -1,17 +1,89 @@ from __future__ import annotations import io -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Any, Optional +from typing import cast from fsspec import AbstractFileSystem DEFAULT_TARGET_SHARD_BYTES = 128 * 1024 * 1024 +PathSelection = Mapping[str, str] | Sequence[str] | str # Extensions that generally imply whole-file/container compression (not safely splittable by byte offsets). NON_SPLITTABLE_WHOLEFILE_EXTS = (".gz", ".bz2", ".xz", ".zip", ".zst") +def decode_value( + value: Any, + *, + decode_bytes: bool = True, + preserve_arrays: bool = False, +) -> Any: + if isinstance(value, bytes): + if not decode_bytes: + return value + try: + return value.decode("utf-8") + except UnicodeDecodeError: + return value + if isinstance(value, str) and any("\udc80" <= char <= "\udcff" for char in value): + return value.encode("utf-8", errors="surrogateescape") + if hasattr(value, "shape") and value.shape == (): + return decode_value( + value.item(), + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + if hasattr(value, "tolist"): + if preserve_arrays and getattr( + getattr(value, "dtype", None), "kind", None + ) not in ( + "O", + "S", + ): + return value + return decode_value( + value.tolist(), + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + if isinstance(value, list): + return [ + decode_value( + item, + decode_bytes=decode_bytes, + preserve_arrays=preserve_arrays, + ) + for item in value + ] + return value + + +def path_selection_map( + value: PathSelection | None, + *, + format_name: str, +) -> dict[str, str]: + if value is None: + return {} + if isinstance(value, str): + return {value.rsplit("/", 1)[-1]: value} + if isinstance(value, Mapping): + return dict(cast(Mapping[str, str], value)) + out: dict[str, str] = {} + for path in value: + name = path.rsplit("/", 1)[-1] + if name in out: + raise ValueError( + f"{format_name} path selections must have unique derived column names; " + f"use an explicit mapping for duplicate name {name!r}" + ) + out[name] = path + return out + + def is_splittable_by_bytes(fs: AbstractFileSystem, path: str) -> bool: """Return True if the input can be safely sharded by byte offsets.""" lp = path.lower() @@ -98,6 +170,9 @@ def readinto(self, b) -> int: __all__ = [ "DEFAULT_TARGET_SHARD_BYTES", "NON_SPLITTABLE_WHOLEFILE_EXTS", + "PathSelection", + "decode_value", + "path_selection_map", "is_splittable_by_bytes", "align_byte_range_to_newlines", "BoundedBinaryReader", diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py new file mode 100644 index 00000000..91787107 --- /dev/null +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -0,0 +1,575 @@ +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from math import ceil, prod +from operator import index as integer_index +from os import PathLike +from typing import Any + +from fsspec import AbstractFileSystem +from fsspec.implementations.zip import ZipFileSystem +import pyarrow as pa + +from refiner.io.datafile import DataFile, DataFileLike +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.datatype import ( + DTypeMapping, + dtype_to_plan, + schema_with_dtypes, +) +from refiner.pipeline.data.row import DictRow +from refiner.pipeline.data.shard import RowRangeDescriptor, Shard +from refiner.pipeline.sources.base import BaseSource, SourceUnit +from refiner.pipeline.sources.readers.utils import ( + DEFAULT_TARGET_SHARD_BYTES, + PathSelection, + decode_value, + path_selection_map, +) +from refiner.utils import check_required_dependencies + + +class ZarrReader(BaseSource): + """Read a Zarr group as one row, episode rows, or leading-axis rows.""" + + name = "read_zarr" + + def __init__( + self, + input: DataFolderLike, + *, + arrays: PathSelection | None = None, + attrs: PathSelection | None = None, + row_ends: str | None = None, + split_leading_axis: bool = False, + leading_axis_row_size: int = 1, + target_shard_bytes: int = DEFAULT_TARGET_SHARD_BYTES, + num_shards: int | None = None, + row_batch_size: int | None = None, + index_column: str | None = "index", + file_path_column: str | None = "file_path", + dtypes: DTypeMapping | None = None, + ): + """Create a Zarr reader. + + Args: + input: Zarr group path. + arrays: Array selections as output-name to Zarr-path mapping, a + single path, a path sequence, or None to discover all arrays. + attrs: Attribute selections with the same shape as ``arrays``. + row_ends: Optional Zarr array path containing cumulative end offsets. + When set, emitted rows are whole source ranges and never split + across these boundaries. + split_leading_axis: Emit fixed-size leading-axis rows when no + ``row_ends`` path is provided. + leading_axis_row_size: Number of leading-axis items in each logical + row when ``split_leading_axis`` is enabled. + target_shard_bytes: Approximate byte target used to pack logical + rows into shards in split modes. + num_shards: Optional target shard count for split modes. + row_batch_size: Optional maximum number of logical rows to load per + array block within a shard. None loads one whole shard block. + index_column: Output metadata column containing the logical row + index in split modes, or None to omit it. + file_path_column: Output metadata column containing the source path, + or None to omit it. + dtypes: Optional dtype overrides for output columns. + """ + zip_input: DataFileLike | None = None + if isinstance(input, DataFolder) and input.path.endswith(".zip"): + zip_input = (input.path, input.fs) + else: + if isinstance(input, PathLike): + input = str(input) + if isinstance(input, str) and input.endswith(".zip"): + zip_input = input + elif ( + isinstance(input, tuple) + and len(input) == 2 + and isinstance(input[1], AbstractFileSystem) + ): + path = input[0] + if isinstance(path, PathLike): + path = str(path) + if isinstance(path, str) and path.endswith(".zip"): + zip_input = (path, input[1]) + + if zip_input is not None: + self.zip_file = DataFile.resolve(zip_input) + self.root = None + self.source_path = self.zip_file.abs_path() + else: + self.zip_file = None + self.root = DataFolder.resolve(input) + self.source_path = self.root.abs_path() + check_required_dependencies("read_zarr", ["zarr"], dist="zarr") + if row_ends is not None and split_leading_axis: + raise ValueError("row_ends and split_leading_axis are mutually exclusive") + if leading_axis_row_size <= 0: + raise ValueError("leading_axis_row_size must be greater than zero") + if leading_axis_row_size != 1 and not split_leading_axis: + raise ValueError("leading_axis_row_size requires split_leading_axis=True") + if target_shard_bytes <= 0: + raise ValueError("target_shard_bytes must be greater than zero") + if num_shards is not None and num_shards <= 0: + raise ValueError("num_shards must be greater than zero") + if row_batch_size is not None and row_batch_size <= 0: + raise ValueError("row_batch_size must be greater than zero") + self.arrays = ( + None if arrays is None else path_selection_map(arrays, format_name="Zarr") + ) + self.attrs = ( + None if attrs is None else path_selection_map(attrs, format_name="Zarr") + ) + if row_ends is not None and self.arrays is not None: + for output_name, path in self.arrays.items(): + if path == row_ends: + raise ValueError( + f"Zarr array selection {output_name!r} cannot also be row_ends" + ) + self.row_ends = row_ends + self.split_leading_axis = split_leading_axis + self.leading_axis_row_size = leading_axis_row_size + self.target_shard_bytes = target_shard_bytes + self.num_shards = num_shards + self.row_batch_size = row_batch_size + self.index_column = index_column + self.file_path_column = file_path_column + self.dtypes = dtypes + _validate_output_names( + self.arrays or {}, + self.attrs or {}, + reserved=self._reserved_output_names( + split=row_ends is not None or split_leading_axis + ), + ) + if ( + (row_ends is not None or split_leading_axis) + and file_path_column is not None + and file_path_column == index_column + ): + raise ValueError("file_path_column and index_column must be distinct") + + @property + def schema(self) -> pa.Schema | None: + return schema_with_dtypes(None, self.dtypes) + + def describe(self) -> dict[str, Any]: + return { + "path": self.source_path, + "arrays": dict(self.arrays) if self.arrays is not None else None, + "attrs": dict(self.attrs) if self.attrs is not None else None, + "row_ends": self.row_ends, + "split_leading_axis": self.split_leading_axis, + "leading_axis_row_size": self.leading_axis_row_size, + "target_shard_bytes": self.target_shard_bytes, + "num_shards": self.num_shards, + "row_batch_size": self.row_batch_size, + "index_column": self.index_column, + "file_path_column": self.file_path_column, + "dtypes": ( + {key: dtype_to_plan(dtype) for key, dtype in self.dtypes.items()} + if self.dtypes + else None + ), + } + + def list_shards(self) -> list[Shard]: + with self._open_group() as group: + arrays = self._selected_arrays(group) + split_ranges = self._shard_ranges(group, arrays) + return [ + Shard.from_row_range( + start=start, + end=end, + global_ordinal=index, + start_key=self.source_path, + end_key=self.source_path, + ) + for index, (start, end) in enumerate(split_ranges) + ] + + def read_shard(self, shard: Shard) -> Iterator[SourceUnit]: + with self._open_group() as group: + arrays = self._selected_arrays(group) + if self.row_ends is not None: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) + source_ranges = self._row_end_ranges( + group, + arrays, + row_start=descriptor.start, + row_end=descriptor.end, + ) + if not source_ranges: + return + attrs = self._read_attrs(group) + batch_size = self.row_batch_size or len(source_ranges) + for batch_offset in range(0, len(source_ranges), batch_size): + batch = source_ranges[batch_offset : batch_offset + batch_size] + block_start = batch[0][0] + block_end = batch[-1][1] + block = self._read_arrays(arrays, start=block_start, end=block_end) + for offset, (start, end) in enumerate(batch, start=batch_offset): + row = self._row_metadata(index=descriptor.start + offset) + row.update( + { + name: value[start - block_start : end - block_start] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) + return + + if self.split_leading_axis: + descriptor = shard.descriptor + assert isinstance(descriptor, RowRangeDescriptor) + attrs = self._read_attrs(group) + batch_size = self.row_batch_size or descriptor.end - descriptor.start + for batch_start in range(descriptor.start, descriptor.end, batch_size): + batch_end = min(batch_start + batch_size, descriptor.end) + raw_start = batch_start * self.leading_axis_row_size + raw_end = batch_end * self.leading_axis_row_size + block = self._read_arrays( + arrays, + start=raw_start, + end=raw_end, + ) + for row_index in range(batch_start, batch_end): + offset = (row_index - batch_start) * self.leading_axis_row_size + row = self._row_metadata(index=row_index) + row.update( + { + name: value[ + offset : offset + self.leading_axis_row_size + ] + for name, value in block.items() + } + ) + row.update(attrs) + yield DictRow(row) + return + + row = self._row_metadata(index=None) + row.update(self._read_arrays(arrays)) + row.update(self._read_attrs(group)) + yield DictRow(row) + + @contextmanager + def _open_group(self) -> Any: + import zarr + import zarr.storage + + if self.zip_file is not None: + handle = None + zip_fs = None + if self.zip_file.is_local: + store = zarr.ZipStore(self.zip_file.abs_path(), mode="r") + else: + handle = self.zip_file.open("rb", cache_type="none") + zip_fs = ZipFileSystem(fo=handle, mode="r") + store = zarr.storage.FSStore( + "/", + fs=zip_fs, + mode="r", + ) + try: + yield zarr.open_group(store=store, mode="r") + finally: + close_store = getattr(store, "close", None) + if callable(close_store): + close_store() + if zip_fs is not None: + zip_fs.close() + if handle is not None: + handle.close() + else: + assert self.root is not None + store = zarr.storage.FSStore(self.root._join(""), fs=self.root.fs, mode="r") + yield zarr.open_group(store=store, mode="r") + + def _reserved_output_names(self, *, split: bool) -> set[str]: + names = set() + if self.file_path_column is not None: + names.add(self.file_path_column) + if split and self.index_column is not None: + names.add(self.index_column) + return names + + def _row_metadata(self, *, index: int | None) -> dict[str, Any]: + row: dict[str, Any] = {} + if self.file_path_column is not None: + row[self.file_path_column] = self.source_path + if self.index_column is not None and index is not None: + row[self.index_column] = index + return row + + def _selected_arrays(self, group: Any) -> dict[str, Any]: + if self.arrays is None: + paths = { + path: path for path in _iter_array_paths(group) if path != self.row_ends + } + _validate_output_names( + paths, + self.attrs or {}, + reserved=self._reserved_output_names( + split=self.row_ends is not None or self.split_leading_axis + ), + ) + else: + paths = self.arrays + + arrays: dict[str, Any] = {} + for output_name, path in paths.items(): + try: + arrays[output_name] = group[path] + except KeyError: + raise KeyError(f"Zarr array not found: {path}") from None + return arrays + + def _row_end_ranges( + self, + group: Any, + arrays: Mapping[str, Any], + *, + row_start: int, + row_end: int, + ) -> list[tuple[int, int]]: + if row_end < row_start: + raise ValueError("Zarr shard row range is invalid") + if row_start == row_end: + return [] + row_ends_array = self._row_ends_array(group) + read_start = max(0, row_start - 1) + values = [ + _row_end_offset(value) for value in row_ends_array[read_start:row_end] + ] + if len(values) != row_end - read_start: + raise ValueError("Zarr shard row range exceeds row_ends length") + ranges: list[tuple[int, int]] = [] + start = 0 if row_start == 0 else values[0] + for end in values if row_start == 0 else values[1:]: + if end < start: + raise ValueError("Zarr row_ends must be monotonic increasing") + ranges.append((start, end)) + start = end + _check_final_end(arrays, ranges[-1][1], label="row_ends") + return ranges + + def _shard_ranges( + self, + group: Any, + arrays: Mapping[str, Any], + ) -> list[tuple[int, int]]: + if self.row_ends is None and not self.split_leading_axis: + return [(0, 1)] + + if self.split_leading_axis: + if not arrays: + raise ValueError( + "split_leading_axis requires at least one selected array" + ) + lengths: set[int] = set() + for array in arrays.values(): + if not array.shape: + raise ValueError( + "Zarr selected arrays must have a leading dimension to split" + ) + lengths.add(int(array.shape[0])) + if len(lengths) != 1: + raise ValueError( + "Zarr selected arrays must have the same leading dimension" + ) + length = lengths.pop() + if length == 0: + return [] + if length % self.leading_axis_row_size != 0: + raise ValueError("Zarr leading dimension must be divisible by row size") + row_count = length // self.leading_axis_row_size + if self.num_shards is not None: + step = ceil(row_count / self.num_shards) + else: + item_bytes = [ + (array, _leading_item_bytes(array)) for array in arrays.values() + ] + bytes_per_row = ( + sum(bytes_count for _, bytes_count in item_bytes) + * self.leading_axis_row_size + ) + target_rows = max(1, self.target_shard_bytes // max(1, bytes_per_row)) + largest_item_bytes = max(bytes_count for _, bytes_count in item_bytes) + chunk_rows = max( + 1, + ceil( + max( + int(array.chunks[0]) + if array.chunks + else int(array.shape[0]) + for array, bytes_count in item_bytes + if bytes_count == largest_item_bytes + ) + / self.leading_axis_row_size + ), + ) + step = max(chunk_rows, (target_rows // chunk_rows) * chunk_rows) + return [ + (start, min(start + step, row_count)) + for start in range(0, row_count, step) + ] + + row_ends_array = self._row_ends_array(group) + row_count = int(row_ends_array.shape[0]) + if row_count == 0: + _check_final_end(arrays, 0, label="row_ends", exact=True) + return [] + if self.num_shards is not None or not arrays: + final_end = _validate_row_ends(row_ends_array) + _check_final_end(arrays, final_end, label="row_ends", exact=True) + if self.num_shards is None: + return [(0, row_count)] + step = ceil(row_count / self.num_shards) + return [ + (start, min(start + step, row_count)) + for start in range(0, row_count, step) + ] + + bytes_per_step = sum(_leading_item_bytes(array) for array in arrays.values()) + ranges: list[tuple[int, int]] = [] + shard_start = 0 + current_bytes = 0 + previous_end = 0 + for row_index, end in _iter_row_ends(row_ends_array): + if end < previous_end: + raise ValueError("Zarr row_ends must be monotonic increasing") + row_bytes = max(1, end - previous_end) * bytes_per_step + if ( + row_index > shard_start + and current_bytes + row_bytes > self.target_shard_bytes + ): + ranges.append((shard_start, row_index)) + shard_start = row_index + current_bytes = 0 + current_bytes += row_bytes + previous_end = end + ranges.append((shard_start, row_count)) + _check_final_end(arrays, previous_end, label="row_ends", exact=True) + return ranges + + def _row_ends_array(self, group: Any) -> Any: + try: + row_ends_array = group[self.row_ends] + except KeyError: + raise KeyError(f"Zarr row_ends array not found: {self.row_ends}") from None + if len(row_ends_array.shape) != 1: + raise ValueError("Zarr row_ends must be one-dimensional") + return row_ends_array + + def _read_arrays( + self, + arrays: Mapping[str, Any], + *, + start: int | None = None, + end: int | None = None, + ) -> dict[str, Any]: + row: dict[str, Any] = {} + for output_name, array in arrays.items(): + if start is not None: + row[output_name] = array[start:end] + elif array.shape == (): + row[output_name] = array[()] + else: + row[output_name] = array[:] + return row + + def _read_attrs(self, group: Any) -> dict[str, Any]: + attrs: dict[str, Any] = {} + for output_name, attr_name in (self.attrs or {}).items(): + if attr_name not in group.attrs: + raise KeyError(f"Zarr attr not found: {attr_name}") + attrs[output_name] = decode_value(group.attrs[attr_name]) + return attrs + + +def _validate_output_names( + arrays: Mapping[str, str], + attrs: Mapping[str, str], + *, + reserved: set[str] | None = None, +) -> None: + duplicates = set(arrays).intersection(attrs) + if duplicates: + names = ", ".join(sorted(repr(name) for name in duplicates)) + raise ValueError(f"Zarr arrays and attrs use duplicate output names: {names}") + reserved_matches = set(arrays).union(attrs).intersection(reserved or set()) + if reserved_matches: + names = ", ".join(sorted(repr(name) for name in reserved_matches)) + raise ValueError(f"Zarr selections use reserved output names: {names}") + + +def _iter_row_ends(array: Any) -> Iterator[tuple[int, int]]: + chunk = max(1, int(array.chunks[0]) if array.chunks else 8192) + length = int(array.shape[0]) + for start in range(0, length, chunk): + for offset, value in enumerate(array[start : min(start + chunk, length)]): + yield start + offset, _row_end_offset(value) + + +def _row_end_offset(value: Any) -> int: + if isinstance(value, bool): + raise ValueError("Zarr row_ends must contain integer offsets") + try: + return integer_index(value) + except TypeError: + raise ValueError("Zarr row_ends must contain integer offsets") from None + + +def _validate_row_ends(array: Any) -> int: + previous_end = 0 + for _, end in _iter_row_ends(array): + if end < previous_end: + raise ValueError("Zarr row_ends must be monotonic increasing") + previous_end = end + return previous_end + + +def _check_final_end( + arrays: Mapping[str, Any], + final_end: int, + *, + label: str, + exact: bool = False, +) -> None: + for output_name, array in arrays.items(): + if not array.shape: + raise ValueError( + f"Zarr selected array {output_name!r} must have a leading dimension" + ) + leading_length = int(array.shape[0]) + if final_end > leading_length: + raise ValueError( + f"Zarr {label} exceed leading dimension for {output_name!r}" + ) + if exact and final_end != leading_length: + raise ValueError( + f"Zarr {label} end before leading dimension for {output_name!r}" + ) + + +def _leading_item_bytes(array: Any) -> int: + trailing_shape = tuple(int(value) for value in array.shape[1:]) + return max(1, int(array.dtype.itemsize) * int(prod(trailing_shape or (1,)))) + + +def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: + items = group.items() if hasattr(group, "items") else group.members() + for name, item in items: + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + +__all__ = ["ZarrReader"] diff --git a/src/refiner/robotics/row.py b/src/refiner/robotics/row.py index a3df7ba4..8bf0a971 100644 --- a/src/refiner/robotics/row.py +++ b/src/refiner/robotics/row.py @@ -551,36 +551,8 @@ def _robot_row_converter( video_keys: Mapping[str, str] | Iterable[str] | None = None, stats_key: str | None = "stats", stats_prefix: str = "stats/", - episode_ends_key: str | None = None, schema: pa.Schema | None = None, -) -> Callable[[Row], Row] | Callable[[Row], Iterable[Row]]: - if episode_ends_key is not None: - - def split_row(row: Row) -> Iterable[Row]: - return cast( - Iterable[Row], - _rows_from_episode_ends( - row, - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - episode_ends_key=episode_ends_key, - ), - ) - - return split_row - +) -> Callable[[Row], Row]: spec = _RoboticsRowSpec.from_options( episode_id_key=episode_id_key, task_key=task_key, @@ -616,90 +588,6 @@ def _valid_nested_frames_key(row: Row, key: str | None) -> str | None: return key if isinstance(value, Sequence) else None -def _rows_from_episode_ends( - row: Row, - *, - episode_id_key: str | None, - task_key: str | None, - fps: float | None, - fps_key: str | None, - robot_type: str | None, - robot_type_key: str | None, - timestamp_key: str | None, - action_key: str | None, - state_key: str | Sequence[str] | None, - extra_observation_keys: Mapping[str, str] | Iterable[str] | None, - video_keys: Mapping[str, str] | Iterable[str] | None, - schema: pa.Schema | None, - stats_key: str | None, - stats_prefix: str, - episode_ends_key: str, -) -> Iterator[RoboticsRow]: - episode_ends = [int(value) for value in _get_path(row, episode_ends_key)] - spec = _RoboticsRowSpec.from_options( - episode_id_key=episode_id_key, - task_key=task_key, - fps=fps, - fps_key=fps_key, - robot_type=robot_type, - robot_type_key=robot_type_key, - nested_frames_key=None, - timestamp_key=timestamp_key, - action_key=action_key, - state_key=state_key, - extra_observation_keys=extra_observation_keys, - video_keys=video_keys, - schema=schema, - stats_key=stats_key, - stats_prefix=stats_prefix, - ) - start = 0 - for episode_idx, end in enumerate(episode_ends): - if episode_id_key is None: - split_row = row - else: - split_row = row.update( - {episode_id_key: f"{row.get(episode_id_key, '-1')}-{episode_idx}"} - ) - for source_key in spec.frame_source_map.values(): - for key in _source_keys(source_key): - if _has_path(row, key): - split_row = split_row.update( - _set_path( - split_row, - key, - _slice_values(_get_path(row, key), start, end), - ) - ) - video_fps = int(spec.wrap(row).fps or 30) - for source_key, storage in spec.video_source_map.values(): - if not _has_path(row, source_key): - continue - value = _get_path(row, source_key) - video = video_from_storage_value(storage, value, fps=video_fps) - if video is None: - continue - clip_fps = int(getattr(video, "fps", video_fps)) - split_row = split_row.update( - _set_path( - split_row, - source_key, - video.clipped( - from_timestamp_s=start / clip_fps, - to_timestamp_s=end / clip_fps, - ), - ) - ) - if task_key is not None and task_key in row: - split_row = split_row.update({task_key: row[task_key]}) - if fps_key is not None and fps_key in row: - split_row = split_row.update({fps_key: row[fps_key]}) - if robot_type_key is not None and robot_type_key in row: - split_row = split_row.update({robot_type_key: row[robot_type_key]}) - yield spec.wrap(split_row) - start = end - - def _video_sources( *, schema: pa.Schema | None, @@ -853,12 +741,6 @@ def _set_path(row: Mapping[str, Any], key: str, values: Any) -> dict[str, Any]: return {head: root} -def _slice_values(values: Any, start: int, end: int) -> Any: - if isinstance(values, pa.ChunkedArray | pa.Array): - return values.slice(start, end - start) - return values[start:end] - - def _select_frame_table(table: Tabular, indices: Sequence[int]) -> Tabular: selected = table.table.take(pa.array(indices, type=pa.int64())) if "frame_index" in selected.column_names: diff --git a/tests/io/test_datafile_datafolder.py b/tests/io/test_datafile_datafolder.py index 4b0e4345..a33a82ea 100644 --- a/tests/io/test_datafile_datafolder.py +++ b/tests/io/test_datafile_datafolder.py @@ -203,6 +203,17 @@ def test_datafolder_resolve_with_path_fs_tuple(tmp_path): assert folder.exists("out.txt") +def test_datafolder_resolve_strips_protocol_root_from_url(): + fs = MemoryFileSystem() + fs.pipe("/bucket/root/file.txt", b"ok") + + folder = DataFolder("memory://bucket/root") + + assert folder.path == "/bucket/root" + assert folder.abs_path() == "memory:///bucket/root" + assert folder.open("file.txt").read() == b"ok" + + class _CountingMemoryFS(MemoryFileSystem): def __init__(self): super().__init__() diff --git a/tests/readers/test_zarr_reader.py b/tests/readers/test_zarr_reader.py new file mode 100644 index 00000000..3948983b --- /dev/null +++ b/tests/readers/test_zarr_reader.py @@ -0,0 +1,675 @@ +from __future__ import annotations + +from pathlib import Path +import shutil +from typing import Any, Literal, cast + +from fsspec.implementations.memory import MemoryFileSystem +import numpy as np +import pytest +import zarr + +import refiner as mdr +from refiner.io.datafolder import DataFolder +from refiner.robotics.row import RoboticsRow +from refiner.pipeline.data.row import Row +from refiner.pipeline.data.shard import RowRangeDescriptor + + +def _open_test_zarr(path: Path, *, mode: Literal["r", "r+", "a", "w", "w-"]): + kwargs: dict[str, Any] = {"mode": mode, "zarr_format": 2} + try: + return zarr.open_group(str(path), **kwargs) + except TypeError: + return zarr.open_group(str(path), mode=mode) + + +def _create_array(root, name: str, data, **kwargs): + if hasattr(root, "create_array"): + kwargs.pop("shape", None) + return root.create_array(name, data=data, **kwargs) + return root.create_dataset(name, data=data, **kwargs) + + +def _write_policy_zarr(path: Path) -> None: + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.asarray([[0.0], [0.1], [1.0], [1.1], [1.2]], dtype=np.float32), + ) + _create_array( + root, + "data/state", + data=np.asarray([[10.0], [10.1], [20.0], [20.1], [20.2]], dtype=np.float32), + ) + _create_array( + root, + "data/rgb", + data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), + ) + _create_array(root, "meta/episode_ends", data=np.asarray([2, 5], dtype=np.int64)) + root.attrs["dataset_id"] = "pusht" + root.attrs["task"] = "push tee" + + +def test_read_zarr_rejects_reserved_file_path_output_name(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="reserved output names"): + mdr.read_zarr(path, arrays={"file_path": "data/action"}) + + +def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + row = mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + attrs={"task": "task"}, + file_path_column=None, + ).take(1)[0] + + assert row["task"] == "push tee" + assert row["episode_ends"].tolist() == [2, 5] + np.testing.assert_allclose(row["action"][:2], [[0.0], [0.1]]) + + +def test_read_zarr_reads_scalar_arrays(tmp_path: Path) -> None: + path = tmp_path / "scalar.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "version", data=np.asarray(3, dtype=np.int64), shape=()) + + row = mdr.read_zarr( + path, + arrays={"version": "version"}, + file_path_column=None, + ).take(1)[0] + + assert row["version"] == 3 + + +def test_read_zarr_splits_arrays_by_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + rows = mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + assert [row["task"] for row in rows] == ["push tee", "push tee"] + assert [len(row["action"]) for row in rows] == [2, 3] + np.testing.assert_allclose(rows[0]["action"], [[0.0], [0.1]]) + np.testing.assert_allclose(rows[1]["action"], [[1.0], [1.1], [1.2]]) + np.testing.assert_array_equal( + rows[0]["frames"], + np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), + ) + + +def test_read_zarr_reads_zip_store(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + row = mdr.read_zarr( + str(zip_path), + arrays={"action": "data/action", "frames": "data/rgb"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["file_path"] == str(zip_path) + assert row["action"].shape == (2, 1) + assert row["frames"].shape == (2, 4, 4, 3) + + +def test_read_zarr_reads_zip_datafolder(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + row = mdr.read_zarr( + DataFolder(str(zip_path)), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["file_path"] == str(zip_path) + assert row["action"].shape == (2, 1) + + +def test_read_zarr_reads_remote_store(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + fs = MemoryFileSystem() + remote_root = "/policy.zarr" + for source_path in path.rglob("*"): + if source_path.is_file(): + relative_path = source_path.relative_to(path) + remote_path = f"{remote_root}/{relative_path}" + fs.makedirs(str(Path(remote_path).parent), exist_ok=True) + with source_path.open("rb") as src, fs.open(remote_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + row = mdr.read_zarr( + (remote_root, fs), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["file_path"] == "memory:///policy.zarr" + assert row["action"].shape == (2, 1) + + +def test_read_zarr_reads_remote_zip_without_cache( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + zip_path = Path(shutil.make_archive(str(path), "zip", root_dir=path)) + + fs = MemoryFileSystem() + remote_path = "/policy.zarr.zip" + with zip_path.open("rb") as src, fs.open(remote_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + open_calls: list[dict[str, Any]] = [] + opened_files: list[int] = [] + closed_files: list[int] = [] + original_open = fs.open + + def record_open(path, mode="rb", **kwargs): + if path == remote_path and mode == "rb": + open_calls.append(kwargs) + file = original_open(path, mode=mode, **kwargs) + opened_files.append(id(file)) + original_close = file.close + + def record_close(): + closed_files.append(id(file)) + original_close() + + file.close = record_close + return file + return original_open(path, mode=mode, **kwargs) + + monkeypatch.setattr(fs, "open", record_open) + + row = mdr.read_zarr( + (remote_path, fs), + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + ).take(1)[0] + + assert row["action"].shape == (2, 1) + assert open_calls + assert open_calls[0]["cache_type"] == "none" + assert set(closed_files) == set(opened_files) + + +def test_read_zarr_plans_row_ends_with_num_shards(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + num_shards=2, + file_path_column=None, + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 1), + (1, 2), + ] + + rows = [cast(Row, row) for row in pipeline.source.read_shard(shards[1])] + assert len(rows) == 1 + assert rows[0]["index"] == 1 + np.testing.assert_allclose(rows[0]["action"], [[1.0], [1.1], [1.2]]) + + +def test_read_zarr_allows_attrs_only_reads(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + row = mdr.read_zarr( + path, + arrays={}, + attrs={"task": "task"}, + file_path_column=None, + ).take(1)[0] + + assert list(row) == ["task"] + assert row["task"] == "push tee" + + +def test_read_zarr_rejects_duplicate_output_names(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="duplicate output names"): + mdr.read_zarr( + path, + arrays={"task": "data/action"}, + attrs={"task": "task"}, + file_path_column=None, + ) + + +def test_read_zarr_rejects_discovered_array_attr_collisions(tmp_path: Path) -> None: + path = tmp_path / "collision.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "task", data=np.asarray([1], dtype=np.int64)) + root.attrs["task"] = "push tee" + + pipeline = mdr.read_zarr(path, attrs={"task": "task"}, file_path_column=None) + + with pytest.raises(ValueError, match="duplicate output names"): + pipeline.take(1) + + +def test_read_zarr_rejects_reserved_index_output_name(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="reserved output names"): + mdr.read_zarr( + path, + arrays={"index": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + + +def test_read_zarr_rejects_selecting_row_ends_as_output_array(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="cannot also be row_ends"): + mdr.read_zarr( + path, + arrays={"episode_ends": "meta/episode_ends"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + + +def test_read_zarr_rejects_duplicate_metadata_column_names(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="must be distinct"): + mdr.read_zarr( + path, + row_ends="meta/episode_ends", + file_path_column="metadata", + index_column="metadata", + ) + + +def test_read_zarr_rejects_missing_selected_paths(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(KeyError, match="Zarr array not found"): + mdr.read_zarr( + path, + arrays={"missing": "data/missing"}, + file_path_column=None, + ).take(1) + + +def test_read_zarr_rejects_missing_selected_attrs(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(KeyError, match="Zarr attr not found"): + mdr.read_zarr( + path, + arrays={}, + attrs={"missing_attr": "missing_attr"}, + file_path_column=None, + ).take(1) + + +def test_read_zarr_split_leading_axis_emits_one_row_per_index(tmp_path: Path) -> None: + path = tmp_path / "leading_axis.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(5, dtype=np.float32).reshape(5, 1), + chunks=(1, 1), + ) + _create_array( + root, + "data/rgb", + data=np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3), + chunks=(2, 4, 4, 3), + ) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action", "rgb": "data/rgb"}, + split_leading_axis=True, + target_shard_bytes=96, + file_path_column=None, + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 2), + (2, 4), + (4, 5), + ] + + rows = pipeline.take(3) + + assert [row["index"] for row in rows] == [0, 1, 2] + assert [row["action"].shape for row in rows] == [(1, 1), (1, 1), (1, 1)] + assert [row["rgb"].shape for row in rows] == [(1, 4, 4, 3)] * 3 + np.testing.assert_allclose(rows[1]["action"], [[1.0]]) + + +def test_read_zarr_split_leading_axis_uses_row_size(tmp_path: Path) -> None: + path = tmp_path / "leading_axis_rows.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(6, dtype=np.float32).reshape(6, 1), + chunks=(2, 1), + ) + + rows = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + leading_axis_row_size=2, + file_path_column=None, + ).take(3) + + assert [row["index"] for row in rows] == [0, 1, 2] + assert [row["action"].shape for row in rows] == [(2, 1), (2, 1), (2, 1)] + np.testing.assert_allclose(rows[1]["action"], [[2.0], [3.0]]) + + +def test_read_zarr_split_leading_axis_uses_row_batch_size( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + path = tmp_path / "leading_axis_batched.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(6, dtype=np.float32).reshape(6, 1), + chunks=(1, 1), + ) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + row_batch_size=2, + file_path_column=None, + ) + source = cast(Any, pipeline.source) + shards = source.list_shards() + calls: list[tuple[int | None, int | None]] = [] + read_arrays = source._read_arrays + + def record_read_arrays(arrays, *, start=None, end=None): + calls.append((start, end)) + return read_arrays(arrays, start=start, end=end) + + monkeypatch.setattr(source, "_read_arrays", record_read_arrays) + rows = list(source.read_shard(shards[0])) + + assert [row["index"] for row in rows] == list(range(6)) + np.testing.assert_allclose(rows[5]["action"], [[5.0]]) + assert calls == [(0, 2), (2, 4), (4, 6)] + + +def test_read_zarr_split_leading_axis_uses_dominant_array_chunks( + tmp_path: Path, +) -> None: + path = tmp_path / "image_dominant_chunks.zarr" + root = _open_test_zarr(path, mode="w") + _create_array( + root, + "data/action", + data=np.arange(10, dtype=np.float32).reshape(10, 1), + chunks=(10, 1), + ) + _create_array( + root, + "data/rgb", + data=np.zeros((10, 4, 4, 3), dtype=np.uint8), + chunks=(2, 4, 4, 3), + ) + + pipeline = mdr.read_zarr( + path, + arrays={"action": "data/action", "rgb": "data/rgb"}, + split_leading_axis=True, + target_shard_bytes=160, + file_path_column=None, + ) + + shards = pipeline.source.list_shards() + ranges = [cast(RowRangeDescriptor, shard.descriptor) for shard in shards] + assert [(item.start, item.end) for item in ranges] == [ + (0, 2), + (2, 4), + (4, 6), + (6, 8), + (8, 10), + ] + + +def test_read_zarr_split_leading_axis_requires_aligned_lengths(tmp_path: Path) -> None: + path = tmp_path / "misaligned.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) + _create_array(root, "data/state", data=np.zeros((4, 1), dtype=np.float32)) + + with pytest.raises(ValueError, match="same leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action", "state": "data/state"}, + split_leading_axis=True, + file_path_column=None, + ).take(1) + + +def test_read_zarr_split_leading_axis_requires_full_rows(tmp_path: Path) -> None: + path = tmp_path / "partial-leading-axis-row.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) + + with pytest.raises(ValueError, match="divisible by row size"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + leading_axis_row_size=2, + file_path_column=None, + ).take(1) + + +def test_read_zarr_leading_axis_row_size_requires_split_mode(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="requires split_leading_axis"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + leading_axis_row_size=2, + ) + + +def test_read_zarr_rejects_invalid_row_batch_size(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + + with pytest.raises(ValueError, match="row_batch_size"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + split_leading_axis=True, + row_batch_size=0, + ) + + +def test_read_zarr_rejects_non_monotonic_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = _open_test_zarr(path, mode="a") + root["meta/episode_ends"][:] = np.asarray([3, 2], dtype=np.int64) + + with pytest.raises(ValueError, match="row_ends must be monotonic"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + +def test_read_zarr_rejects_non_integer_row_ends(tmp_path: Path) -> None: + path = tmp_path / "float-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((5, 1), dtype=np.float32)) + _create_array(root, "meta/episode_ends", data=np.asarray([2.5, 5.0])) + + with pytest.raises(ValueError, match="integer offsets"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + +def test_read_zarr_rejects_out_of_range_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = _open_test_zarr(path, mode="a") + root["meta/episode_ends"][:] = np.asarray([2, 6], dtype=np.int64) + + with pytest.raises(ValueError, match="row_ends exceed"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + +def test_read_zarr_rejects_short_row_ends(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + _write_policy_zarr(path) + root = _open_test_zarr(path, mode="a") + root["meta/episode_ends"][:] = np.asarray([2, 4], dtype=np.int64) + + with pytest.raises(ValueError, match="end before leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(2) + + +def test_read_zarr_rejects_empty_row_ends_for_nonempty_arrays(tmp_path: Path) -> None: + path = tmp_path / "empty-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "data/action", data=np.zeros((2, 1), dtype=np.float32)) + _create_array(root, "meta/episode_ends", data=np.asarray([], dtype=np.int64)) + + with pytest.raises(ValueError, match="end before leading dimension"): + mdr.read_zarr( + path, + arrays={"action": "data/action"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + +def test_read_zarr_rejects_scalar_arrays_in_row_ends_mode(tmp_path: Path) -> None: + path = tmp_path / "scalar-row-ends.zarr" + root = _open_test_zarr(path, mode="w") + _create_array(root, "version", data=np.asarray(3, dtype=np.int64), shape=()) + _create_array(root, "meta/episode_ends", data=np.asarray([1], dtype=np.int64)) + + with pytest.raises(ValueError, match="must have a leading dimension"): + mdr.read_zarr( + path, + arrays={"version": "version"}, + row_ends="meta/episode_ends", + file_path_column=None, + ).take(1) + + +def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + lerobot_out = tmp_path / "lerobot" + _write_policy_zarr(path) + + ( + mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"dataset_id": "dataset_id", "task": "task"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + .to_robot_rows( + episode_id_key="index", + task_key="task", + action_key="action", + state_key="observation.state", + video_keys={"observation.images.front": "frames"}, + fps=10, + robot_type="pusht", + ) + .write_lerobot(str(lerobot_out), max_video_prepare_in_flight=1) + .launch_local( + name="zarr-to-lerobot", num_workers=1, rundir=str(tmp_path / "run1") + ) + ) + + episodes = [ + cast(RoboticsRow, row) + for row in mdr.read_lerobot(str(lerobot_out)).materialize() + ] + episodes.sort(key=lambda episode: int(episode.episode_id)) + assert [episode.num_frames for episode in episodes] == [2, 3] + assert [episode.task for episode in episodes] == ["push tee", "push tee"] diff --git a/tests/robotics/test_robotics_row.py b/tests/robotics/test_robotics_row.py index cbb1717c..b4e026e8 100644 --- a/tests/robotics/test_robotics_row.py +++ b/tests/robotics/test_robotics_row.py @@ -34,10 +34,7 @@ def _robot_rows( converter = _robot_row_converter(**cast(Any, {**kwargs, "schema": rows.schema})) return [cast(RoboticsRow, converter(row)) for row in rows] converter = _robot_row_converter(**kwargs) - converted = converter(rows) - if kwargs.get("episode_ends_key") is not None: - return [cast(RoboticsRow, row) for row in converted] - return [cast(RoboticsRow, converted)] + return [cast(RoboticsRow, converter(rows))] def test_to_robot_rows_does_not_treat_video_uri_frames_as_frame_table() -> None: @@ -249,31 +246,6 @@ def test_to_robot_rows_preserves_explicit_fps_and_robot_type_keys() -> None: assert robotics_row.robot_type == "aloha" -def test_to_robot_rows_reuses_literal_fps_and_robot_type_for_episode_splits() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [1, 2], - "action": [[0.0], [1.0]], - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - state_key=None, - fps=30.0, - robot_type="pusht", - ) - ) - - assert [row.fps for row in rows] == [30.0, 30.0] - assert [row.robot_type for row in rows] == ["pusht", "pusht"] - - def test_pipeline_to_robot_rows_forwards_literals_and_explicit_keys() -> None: literal_row = ( from_items([{"episode_id": "episode-1"}]) @@ -654,118 +626,6 @@ def test_to_robot_rows_flattens_nested_frame_dicts() -> None: assert robotics_row.states.to_pylist() == [[0.0], [1.0]] -def test_to_robot_rows_splits_dataset_arrays_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "pusht", - "meta": {"episode_ends": [2, 5]}, - "data": { - "obs": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "action": [[0.0], [0.1], [1.0], [1.1], [1.2]], - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - state_key="data/obs", - action_key="data/action", - ) - ) - - assert [row.episode_id for row in rows] == ["pusht-0", "pusht-1"] - assert [row.num_frames for row in rows] == [2, 3] - assert rows[1].actions == [[1.0], [1.1], [1.2]] - - -def test_to_robot_rows_splits_video_frame_arrays_with_episode_ends() -> None: - frames = np.arange(5 * 2 * 2 * 3, dtype=np.uint8).reshape(5, 2, 2, 3) - row = DictRow( - { - "dataset_id": "pusht", - "episode_ends": [2, 5], - "actions": [[0.0], [0.1], [1.0], [1.1], [1.2]], - "front": frames, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="episode_ends", - timestamp_key=None, - action_key="actions", - state_key=None, - video_keys={"observation.images.front": "front"}, - fps=10, - ) - ) - - videos = [row.videos["observation.images.front"] for row in rows] - - assert [ - video.frame_count for video in videos if isinstance(video, VideoFrameArray) - ] == [ - 2, - 3, - ] - assert isinstance(videos[0], VideoFrameArray) - assert isinstance(videos[1], VideoFrameArray) - np.testing.assert_array_equal( - np.asarray(list(videos[0].iter_frame_arrays())), - frames[:2], - ) - np.testing.assert_array_equal( - np.asarray(list(videos[1].iter_frame_arrays())), - frames[2:], - ) - - -def test_to_robot_rows_splits_tuple_state_key_sources_with_episode_ends() -> None: - row = DictRow( - { - "dataset_id": "robomimic", - "meta": {"episode_ends": [2, 4]}, - "actions": [[0.0], [0.1], [1.0], [1.1]], - "obs": { - "joint_pos": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], - "joint_vel": [[0.1], [0.2], [0.3], [0.4]], - "eef_pos": np.asarray( - [[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]], - dtype=np.float32, - ), - }, - } - ) - - rows = list( - _robot_rows( - row, - episode_id_key="dataset_id", - episode_ends_key="meta/episode_ends", - timestamp_key=None, - action_key="actions", - state_key=("obs/joint_pos", "obs/joint_vel", "obs/eef_pos"), - ) - ) - - assert [row.episode_id for row in rows] == ["robomimic-0", "robomimic-1"] - assert [row.num_frames for row in rows] == [2, 2] - assert rows[0].states == [ - [1.0, 2.0, 0.1, 10.0, 11.0], - [3.0, 4.0, 0.2, 12.0, 13.0], - ] - assert rows[1].states == [ - [5.0, 6.0, 0.3, 14.0, 15.0], - [7.0, 8.0, 0.4, 16.0, 17.0], - ] - - def test_motion_trim_works_with_mapped_semantic_keys() -> None: row = DictRow( { diff --git a/uv.lock b/uv.lock index 30083db1..d844aa09 100644 --- a/uv.lock +++ b/uv.lock @@ -5,9 +5,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", "python_full_version < '3.11'", ] @@ -204,6 +207,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "asciitree" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/6a/885bc91484e1aa8f618f6f0228d76d0e67000b0fdd6090673b777e311913/asciitree-0.3.3.tar.gz", hash = "sha256:4aa4b9b649f85e3fcb343363d97564aa1fb62e249677f2e18a96765145cc0f6e", size = 3951, upload-time = "2016-09-05T19:10:42.681Z" } + [[package]] name = "async-timeout" version = "5.0.1" @@ -621,6 +630,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, ] +[[package]] +name = "fasteners" +version = "0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/18/7881a99ba5244bfc82f06017316ffe93217dbbbcfa52b887caa1d4f2a6d3/fasteners-0.20.tar.gz", hash = "sha256:55dce8792a41b56f727ba6e123fcaee77fd87e638a6863cec00007bfea84c8d8", size = 25087, upload-time = "2025-08-11T10:19:37.785Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/ac/e5d886f892666d2d1e5cb8c1a41146e1d79ae8896477b1153a21711d3b44/fasteners-0.20-py3-none-any.whl", hash = "sha256:9422c40d1e350e4259f509fb2e608d6bc43c0136f79a00db1b49046029d0b3b7", size = 18702, upload-time = "2025-08-11T10:19:35.716Z" }, +] + [[package]] name = "filelock" version = "3.25.2" @@ -997,10 +1015,12 @@ all = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, + { name = "zarr" }, ] hdf5 = [ { name = "h5py" }, @@ -1022,10 +1042,12 @@ testing = [ { name = "h5py" }, { name = "hf" }, { name = "huggingface-hub" }, + { name = "numcodecs" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "s3fs" }, { name = "warcio" }, + { name = "zarr" }, ] text = [ { name = "warcio" }, @@ -1033,6 +1055,10 @@ text = [ video = [ { name = "av" }, ] +zarr = [ + { name = "numcodecs" }, + { name = "zarr" }, +] [package.dev-dependencies] dev = [ @@ -1060,7 +1086,9 @@ requires-dist = [ { name = "macrodata-refiner", extras = ["testing"], marker = "extra == 'all'" }, { name = "macrodata-refiner", extras = ["text"], marker = "extra == 'testing'" }, { name = "macrodata-refiner", extras = ["video"], marker = "extra == 'robotics'" }, + { name = "macrodata-refiner", extras = ["zarr"], marker = "extra == 'testing'" }, { name = "msgspec", specifier = ">=0.20.0" }, + { name = "numcodecs", marker = "extra == 'zarr'", specifier = "<0.16" }, { name = "numpy" }, { name = "orjson" }, { name = "pyarrow" }, @@ -1068,8 +1096,9 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'testing'", specifier = ">=5.0.0" }, { name = "s3fs", marker = "extra == 's3'" }, { name = "warcio", marker = "extra == 'text'" }, + { name = "zarr", marker = "extra == 'zarr'", specifier = ">=2.18,<3" }, ] -provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "s3", "testing", "all"] +provides-extras = ["huggingface", "video", "robotics", "text", "hdf5", "zarr", "s3", "testing", "all"] [package.metadata.requires-dev] dev = [ @@ -1326,6 +1355,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, ] +[[package]] +name = "numcodecs" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/56/8895a76abe4ec94ebd01eeb6d74f587bc4cddd46569670e1402852a5da13/numcodecs-0.13.1.tar.gz", hash = "sha256:a3cf37881df0898f3a9c0d4477df88133fe85185bffe57ba31bcc2fa207709bc", size = 5955215, upload-time = "2024-10-09T16:28:00.188Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/c0/6d72cde772bcec196b7188731d41282993b2958440f77fdf0db216f722da/numcodecs-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:96add4f783c5ce57cc7e650b6cac79dd101daf887c479a00a29bc1487ced180b", size = 1580012, upload-time = "2024-10-09T16:27:19.069Z" }, + { url = "https://files.pythonhosted.org/packages/94/1d/f81fc1fa9210bbea97258242393a1f9feab4f6d8fb201f81f76003005e4b/numcodecs-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:237b7171609e868a20fd313748494444458ccd696062f67e198f7f8f52000c15", size = 1176919, upload-time = "2024-10-09T16:27:21.634Z" }, + { url = "https://files.pythonhosted.org/packages/16/e4/b9ec2f4dfc34ecf724bc1beb96a9f6fa9b91801645688ffadacd485089da/numcodecs-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96e42f73c31b8c24259c5fac6adba0c3ebf95536e37749dc6c62ade2989dca28", size = 8625842, upload-time = "2024-10-09T16:27:24.168Z" }, + { url = "https://files.pythonhosted.org/packages/fe/90/299952e1477954ec4f92813fa03e743945e3ff711bb4f6c9aace431cb3da/numcodecs-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:eda7d7823c9282e65234731fd6bd3986b1f9e035755f7fed248d7d366bb291ab", size = 828638, upload-time = "2024-10-09T16:27:27.063Z" }, + { url = "https://files.pythonhosted.org/packages/f0/78/34b8e869ef143e88d62e8231f4dbfcad85e5c41302a11fc5bd2228a13df5/numcodecs-0.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2eda97dd2f90add98df6d295f2c6ae846043396e3d51a739ca5db6c03b5eb666", size = 1580199, upload-time = "2024-10-09T16:27:29.336Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cf/f70797d86bb585d258d1e6993dced30396f2044725b96ce8bcf87a02be9c/numcodecs-0.13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2a86f5367af9168e30f99727ff03b27d849c31ad4522060dde0bce2923b3a8bc", size = 1177203, upload-time = "2024-10-09T16:27:31.011Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b5/d14ad69b63fde041153dfd05d7181a49c0d4864de31a7a1093c8370da957/numcodecs-0.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233bc7f26abce24d57e44ea8ebeb5cd17084690b4e7409dd470fdb75528d615f", size = 8868743, upload-time = "2024-10-09T16:27:32.833Z" }, + { url = "https://files.pythonhosted.org/packages/13/d4/27a7b5af0b33f6d61e198faf177fbbf3cb83ff10d9d1a6857b7efc525ad5/numcodecs-0.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:796b3e6740107e4fa624cc636248a1580138b3f1c579160f260f76ff13a4261b", size = 829603, upload-time = "2024-10-09T16:27:35.415Z" }, + { url = "https://files.pythonhosted.org/packages/37/3a/bc09808425e7d3df41e5fc73fc7a802c429ba8c6b05e55f133654ade019d/numcodecs-0.13.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5195bea384a6428f8afcece793860b1ab0ae28143c853f0b2b20d55a8947c917", size = 1575806, upload-time = "2024-10-09T16:27:37.804Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/dc74d0bfdf9ec192332a089d199f1e543e747c556b5659118db7a437dcca/numcodecs-0.13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3501a848adaddce98a71a262fee15cd3618312692aa419da77acd18af4a6a3f6", size = 1178233, upload-time = "2024-10-09T16:27:40.169Z" }, + { url = "https://files.pythonhosted.org/packages/d4/ce/434e8e3970b8e92ae9ab6d9db16cb9bc7aa1cd02e17c11de6848224100a1/numcodecs-0.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2230484e6102e5fa3cc1a5dd37ca1f92dfbd183d91662074d6f7574e3e8f53", size = 8857827, upload-time = "2024-10-09T16:27:42.743Z" }, + { url = "https://files.pythonhosted.org/packages/83/e7/1d8b1b266a92f9013c755b1c146c5ad71a2bff147ecbc67f86546a2e4d6a/numcodecs-0.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5db4824ebd5389ea30e54bc8aeccb82d514d28b6b68da6c536b8fa4596f4bca", size = 826539, upload-time = "2024-10-09T16:27:44.808Z" }, + { url = "https://files.pythonhosted.org/packages/83/8b/06771dead2cc4a8ae1ea9907737cf1c8d37a323392fa28f938a586373468/numcodecs-0.13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7a60d75179fd6692e301ddfb3b266d51eb598606dcae7b9fc57f986e8d65cb43", size = 1571660, upload-time = "2024-10-09T16:27:47.125Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ea/d925bf85f92dfe4635356018da9fe4bfecb07b1c72f62b01c1bc47f936b1/numcodecs-0.13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3f593c7506b0ab248961a3b13cb148cc6e8355662ff124ac591822310bc55ecf", size = 1169925, upload-time = "2024-10-09T16:27:49.512Z" }, + { url = "https://files.pythonhosted.org/packages/0f/d6/643a3839d571d8e439a2c77dc4b0b8cab18d96ac808e4a81dbe88e959ab6/numcodecs-0.13.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80d3071465f03522e776a31045ddf2cfee7f52df468b977ed3afdd7fe5869701", size = 8814257, upload-time = "2024-10-09T16:27:52.059Z" }, + { url = "https://files.pythonhosted.org/packages/a6/c5/f3e56bc9b4e438a287fff738993d6d11abef368c0328a612ac2842ba9fca/numcodecs-0.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:90d3065ae74c9342048ae0046006f99dcb1388b7288da5a19b3bddf9c30c3176", size = 821887, upload-time = "2024-10-09T16:27:55.039Z" }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -1399,9 +1456,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/10/8b/c265f4823726ab832de836cdd184d0986dcf94480f81e8739692a7ac7af2/numpy-2.4.3.tar.gz", hash = "sha256:483a201202b73495f00dbc83796c6ae63137a9bdade074f7648b3e32613412dd", size = 20727743, upload-time = "2026-03-09T07:58:53.426Z" } wheels = [ @@ -1640,9 +1700,12 @@ resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'win32'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform == 'emscripten'", - "python_full_version >= '3.11' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] dependencies = [ { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -2680,3 +2743,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/47/3fa2286c3cb162c71cdb34c4224d5745a1ceceb391b2bd9b19b668a8d724/yarl-1.23.0-cp314-cp314t-win_arm64.whl", hash = "sha256:44bb7bef4ea409384e3f8bc36c063d77ea1b8d4a5b2706956c0d6695f07dcc25", size = 86041, upload-time = "2026-03-01T22:07:49.026Z" }, { url = "https://files.pythonhosted.org/packages/69/68/c8739671f5699c7dc470580a4f821ef37c32c4cb0b047ce223a7f115757f/yarl-1.23.0-py3-none-any.whl", hash = "sha256:a2df6afe50dea8ae15fa34c9f824a3ee958d785fd5d089063d960bae1daa0a3f", size = 48288, upload-time = "2026-03-01T22:07:51.388Z" }, ] + +[[package]] +name = "zarr" +version = "2.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asciitree" }, + { name = "fasteners", marker = "sys_platform != 'emscripten'" }, + { name = "numcodecs" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c4/187a21ce7cf7c8f00c060dd0e04c2a81139bb7b1ab178bba83f2e1134ce2/zarr-2.18.3.tar.gz", hash = "sha256:2580d8cb6dd84621771a10d31c4d777dca8a27706a1a89b29f42d2d37e2df5ce", size = 3603224, upload-time = "2024-09-04T23:20:16.595Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/c9/142095e654c2b97133ff71df60979422717b29738b08bc8a1709a5d5e0d0/zarr-2.18.3-py3-none-any.whl", hash = "sha256:b1f7dfd2496f436745cdd4c7bcf8d3b4bc1dceef5fdd0d589c87130d842496dd", size = 210723, upload-time = "2024-09-04T23:20:14.491Z" }, +]