diff --git a/docs/reading-and-writing.md b/docs/reading-and-writing.md index 22ff835e..479f5f29 100644 --- a/docs/reading-and-writing.md +++ b/docs/reading-and-writing.md @@ -441,6 +441,7 @@ Built-in sinks: | `.write_jsonl(output, ...)` | JSON Lines files | one output file per worker/shard according to the filename template | | `.write_parquet(output, ...)` | Parquet files | columnar output with optional compression | | `.write_lerobot(output, ...)` | LeRobot-compatible robotics datasets | materializes frame/video assets and dataset metadata | +| `.write_zarr(output, ...)` | Zarr stores | one store per shard/worker according to the store template | Example: @@ -488,6 +489,43 @@ pipeline = pipeline.map( pipeline = pipeline.cast(video=mdr.datatype.video_path()) ``` +Use `write_zarr(...)` when you want chunked array output, usually for robotics +episode rows or replay-buffer style data: + +```python +import refiner as mdr + +( + mdr.read_lerobot("hf://datasets/user/robot-data") + .write_zarr( + "s3://my-bucket/robot-data-zarr/", + arrays={ + "data/action": "action", + "data/state": "observation.state", + }, + ) +) +``` + +The `arrays` mapping is from output Zarr path to source row key. For +`RoboticsRow` inputs, omitting `arrays` writes the available default robotics +arrays: actions, states, and timestamps. The default schema is inferred once and +later rows must expose the same fields. Video sources selected through `arrays` +are decoded as RGB frame arrays and appended in bounded batches controlled by +`video_frame_batch_size`. `array_chunk_bytes` controls the target chunk size for +new arrays and the reducer read/write batch size when shard-local stores are +merged into a final store. + +By default, `write_zarr(...)` also writes cumulative episode boundaries to +`meta/episode_ends`. Set `episode_ends_path=None` to omit them. + +Launched runs write isolated stores per shard/worker using +`store_template="{shard_id}__w{worker_id}.zarr"`. 
This avoids concurrent workers +mutating the same Zarr group. By default, a reducer stage streams those +shard-local stores into one final Zarr group at the requested output path. Set +`reduce_to_single_store=False` to keep the isolated stores and read them +individually. + When you run a writer through `launch_local(...)` or `launch_cloud(...)`, some sinks add a reducer stage after the main writer stage. For `write_jsonl(...)` and `write_parquet(...)`, that reducer removes stale shard/worker files and diff --git a/src/refiner/pipeline/pipeline.py b/src/refiner/pipeline/pipeline.py index 6ace95d2..96016a6a 100644 --- a/src/refiner/pipeline/pipeline.py +++ b/src/refiner/pipeline/pipeline.py @@ -30,7 +30,7 @@ VectorizedSegmentStep, WithColumnsStep, ) -from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink +from refiner.pipeline.sinks import BaseSink, JsonlSink, ParquetSink, ZarrSink from refiner.pipeline.sinks.assets import MissingAssetPolicy from refiner.pipeline.sources import ( BaseSource, @@ -431,6 +431,53 @@ def write_parquet( ) ) + def write_zarr( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + attrs: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + store_template: str = "{shard_id}__w{worker_id}.zarr", + video_frame_batch_size: int = 8, + array_chunk_bytes: int = 8 * 1024 * 1024, + reduce_to_single_store: bool = True, + ) -> "RefinerPipeline": + """Write rows to Zarr array stores. + + Args: + output: Output folder or URL prefix for the Zarr store(s). + arrays: Mapping from output Zarr array path to source row key. If + omitted for ``RoboticsRow`` inputs, writes the available default + robotics arrays: actions, states, and timestamps. + attrs: Mapping from output Zarr root attribute name to source row key. + Attribute values must be stable across rows in each output store. + episode_ends_path: Output Zarr path for cumulative row/episode end + offsets. 
Set to None to omit episode boundaries. + store_template: Per-shard store path template. Must include + ``{shard_id}`` and ``{worker_id}``. + video_frame_batch_size: Maximum decoded video frames to append per + video write batch. + array_chunk_bytes: Target byte size for chunks created for newly + written arrays and for read/write batches when reducing shard + stores into a single store. + reduce_to_single_store: If True, add a reducer stage that merges + shard-local stores into one Zarr group at ``output``. Defaults + to True. + """ + return self.with_sink( + ZarrSink( + output=output, + arrays=arrays, + attrs=attrs, + episode_ends_path=episode_ends_path, + store_template=store_template, + video_frame_batch_size=video_frame_batch_size, + array_chunk_bytes=array_chunk_bytes, + reduce_to_single_store=reduce_to_single_store, + ) + ) + def __iter__(self) -> Iterator[Row]: return iter(self.iter_rows()) diff --git a/src/refiner/pipeline/sinks/__init__.py b/src/refiner/pipeline/sinks/__init__.py index 8c1f26df..f0623f21 100644 --- a/src/refiner/pipeline/sinks/__init__.py +++ b/src/refiner/pipeline/sinks/__init__.py @@ -2,6 +2,7 @@ from refiner.pipeline.sinks.jsonl import JsonlSink from refiner.pipeline.sinks.parquet import ParquetSink from refiner.pipeline.sinks.reducer import FileCleanupReducerSink, LeRobotMetaReduceSink +from refiner.pipeline.sinks.zarr import ZarrSink __all__ = [ "BaseSink", @@ -10,4 +11,5 @@ "JsonlSink", "LeRobotMetaReduceSink", "ParquetSink", + "ZarrSink", ] diff --git a/src/refiner/pipeline/sinks/base.py b/src/refiner/pipeline/sinks/base.py index b481523e..f63634cc 100644 --- a/src/refiner/pipeline/sinks/base.py +++ b/src/refiner/pipeline/sinks/base.py @@ -82,6 +82,10 @@ def on_shard_complete(self, shard_id: str) -> None: """ del shard_id + def on_shard_finalized(self, shard_id: str) -> None: + """Run cleanup after the shard has been marked complete.""" + del shard_id + def close(self) -> None: """Finalize sink resources after all shard work is 
complete. diff --git a/src/refiner/pipeline/sinks/reducer/__init__.py b/src/refiner/pipeline/sinks/reducer/__init__.py index 748cc262..9363a80f 100644 --- a/src/refiner/pipeline/sinks/reducer/__init__.py +++ b/src/refiner/pipeline/sinks/reducer/__init__.py @@ -1,7 +1,9 @@ from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink from refiner.pipeline.sinks.reducer.lerobot import LeRobotMetaReduceSink +from refiner.pipeline.sinks.reducer.zarr import ZarrReducerSink __all__ = [ "FileCleanupReducerSink", "LeRobotMetaReduceSink", + "ZarrReducerSink", ] diff --git a/src/refiner/pipeline/sinks/reducer/file.py b/src/refiner/pipeline/sinks/reducer/file.py index 2225301e..f32b079f 100644 --- a/src/refiner/pipeline/sinks/reducer/file.py +++ b/src/refiner/pipeline/sinks/reducer/file.py @@ -20,32 +20,36 @@ _DEFAULT_FIELD_PATTERN = r"[^/]+" -def _compile_managed_path_pattern(filename_template: str) -> re.Pattern[str]: - parts: list[str] = [] +def _compile_output_path_patterns(filename_template: str) -> list[re.Pattern[str]]: + path_parts: list[str] = [] + patterns: list[re.Pattern[str]] = [] seen_fields: set[str] = set() - for literal_text, field_name, format_spec, conversion in Formatter().parse( - filename_template - ): - parts.append(re.escape(literal_text)) - if field_name is None: - continue - if conversion is not None or format_spec: - raise ValueError( - "filename_template reducer matching only supports plain " - "named fields without conversion or format specifiers" - ) - if not field_name.isidentifier(): - raise ValueError( - "filename_template reducer matching only supports plain named fields" - ) - if field_name in seen_fields: - # Repeated fields in the template must resolve to the same path segment. 
- parts.append(f"(?P={field_name})") - continue - pattern = _FIELD_PATTERNS.get(field_name, _DEFAULT_FIELD_PATTERN) - parts.append(f"(?P<{field_name}>{pattern})") - seen_fields.add(field_name) + for segment in (part for part in filename_template.split("/") if part): + segment_parts: list[str] = [] + for literal_text, field_name, format_spec, conversion in Formatter().parse( + segment + ): + segment_parts.append(re.escape(literal_text)) + if field_name is None: + continue + if conversion is not None or format_spec: + raise ValueError( + "filename_template reducer matching only supports plain " + "named fields without conversion or format specifiers" + ) + if not field_name.isidentifier(): + raise ValueError( + "filename_template reducer matching only supports plain named fields" + ) + if field_name in seen_fields: + segment_parts.append(f"(?P={field_name})") + continue + pattern = _FIELD_PATTERNS.get(field_name, _DEFAULT_FIELD_PATTERN) + segment_parts.append(f"(?P<{field_name}>{pattern})") + seen_fields.add(field_name) + path_parts.append("".join(segment_parts)) + patterns.append(re.compile("^" + "/".join(path_parts) + "$")) missing_fields = sorted(_REQUIRED_TEMPLATE_FIELDS.difference(seen_fields)) if missing_fields: @@ -54,7 +58,7 @@ def _compile_managed_path_pattern(filename_template: str) -> re.Pattern[str]: + ", ".join(f"{{{field_name}}}" for field_name in missing_fields) ) - return re.compile("^" + "".join(parts) + "$") + return patterns class FileCleanupReducerSink(BaseSink): @@ -72,7 +76,7 @@ def __init__( self.filename_template = filename_template self.reducer_name = reducer_name self.assets_subdir = assets_subdir - self._managed_path_pattern = _compile_managed_path_pattern(filename_template) + self._output_path_patterns = _compile_output_path_patterns(filename_template) self._cleanup_ran = False def write_shard_block(self, shard_id, block) -> None: @@ -108,56 +112,67 @@ def _run_cleanup(self) -> None: ) keep_pairs = { - ( - row.shard_id, - row.worker_token, 
- ) + (row.shard_id, row.worker_token) for row in get_finalized_workers(stage_index=stage_index - 1) } - try: - listed_paths = self.output.find("") - except FileNotFoundError: - listed_paths = [] - - assets_prefix = ( - f"{self.assets_subdir.rstrip('/')}/" - if self.assets_subdir is not None - else None + literal_prefix = "" + for literal_text, field_name, _format_spec, _conversion in Formatter().parse( + self.filename_template + ): + literal_prefix += literal_text + if field_name is not None: + break + listing_prefix = ( + "" if "/" not in literal_prefix else literal_prefix.rsplit("/", 1)[0] ) + paths = [listing_prefix] + prefix_parts = [part for part in listing_prefix.split("/") if part] + for pattern in self._output_path_patterns[len(prefix_parts) :]: + next_paths: list[str] = [] + for path in paths: + try: + next_paths.extend( + item + for item in self.output.ls(path, detail=False) + if pattern.fullmatch(item) + ) + except (FileNotFoundError, NotADirectoryError): + continue + paths = next_paths - removed_asset_attempts: set[str] = set() - # Extra template fields are treated as structure only. Authority is decided - # solely from the finalized (shard_id, worker_id) pair extracted from each - # managed path. - for rel_path in listed_paths: - if not isinstance(rel_path, str) or not rel_path or rel_path == ".": + paths_to_delete: set[str] = set() + # Extra template fields are structure only. Authority is decided from + # the finalized (shard_id, worker_id) pair extracted from the path. 
+ for rel_path in paths: + match = self._output_path_patterns[-1].fullmatch(rel_path) + if match is None: continue - if assets_prefix is not None and ( - rel_path == self.assets_subdir or rel_path.startswith(assets_prefix) - ): - attempt_dir = rel_path[len(assets_prefix) :].split("/", maxsplit=1)[0] + if (match.group("shard_id"), match.group("worker_id")) not in keep_pairs: + paths_to_delete.add(rel_path) + + if self.assets_subdir is not None: + asset_prefix = f"{self.assets_subdir.rstrip('/')}/" + try: + asset_paths = self.output.find(self.assets_subdir) + except FileNotFoundError: + asset_paths = [] + for rel_path in asset_paths: + if not rel_path.startswith(asset_prefix): + continue + attempt_dir = rel_path[len(asset_prefix) :].split("/", maxsplit=1)[0] match = ASSET_ATTEMPT_DIR_RE.fullmatch(attempt_dir) if match is None: continue - if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: - continue - if attempt_dir in removed_asset_attempts: - continue - removed_asset_attempts.add(attempt_dir) - try: - self.output.rm(f"{assets_prefix}{attempt_dir}", recursive=True) - except FileNotFoundError: - continue - continue + if ( + match.group("shard_id"), + match.group("worker_id"), + ) not in keep_pairs: + paths_to_delete.add(f"{asset_prefix}{attempt_dir}") - match = self._managed_path_pattern.fullmatch(rel_path) - if match is None: - continue - if (match.group("shard_id"), match.group("worker_id")) in keep_pairs: - continue + for path in sorted(paths_to_delete): try: - self.output.rm(rel_path) + self.output.rm(path, recursive=True) except FileNotFoundError: continue diff --git a/src/refiner/pipeline/sinks/reducer/zarr.py b/src/refiner/pipeline/sinks/reducer/zarr.py new file mode 100644 index 00000000..da56c105 --- /dev/null +++ b/src/refiner/pipeline/sinks/reducer/zarr.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import numpy as np + +from refiner.io.datafolder import 
DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.sinks.reducer.file import FileCleanupReducerSink +from refiner.pipeline.sinks.zarr import ( + _DEFAULT_ARRAY_CHUNK_BYTES, + _append_zarr_array, + _batch_length, + _render_store_relpath, + _zarr_store, +) +from refiner.utils import check_required_dependencies +from refiner.worker.context import get_active_stage_index, get_finalized_workers +from refiner.worker.lifecycle import sort_finalized_workers + + +class ZarrReducerSink(FileCleanupReducerSink): + def __init__( + self, + output: DataFolderLike, + *, + store_template: str, + episode_ends_path: str | None = None, + array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, + reduce_to_single_store: bool = True, + ) -> None: + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + super().__init__( + output=output, + filename_template=( + f"_parts/{store_template}" if reduce_to_single_store else store_template + ), + reducer_name="write_zarr_reduce", + ) + self.store_template = store_template + self.episode_ends_path = episode_ends_path + self.array_chunk_bytes = array_chunk_bytes + self.reduce_to_single_store = reduce_to_single_store + + def write_shard_block(self, shard_id: str, block: Block) -> None: + self._run_cleanup() + if self.reduce_to_single_store: + self._merge() + return + + relpaths = self._finalized_store_paths() + self._collect_stores(relpaths, for_merge=False) + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + self._clear_root_payload_except(relpaths) + + def _clear_root_payload_except(self, relpaths: Iterable[str]) -> None: + import zarr + + keep_paths = set(relpaths) + try: + root = zarr.open_group(store=_zarr_store(self.output, "", mode="r+")) + except Exception: + return + + def clear_group(group: Any, prefix: str = "") -> None: + group_keys = set(group.group_keys()) + for key in sorted({*group.array_keys(), *group_keys}): + path = f"{prefix}/{key}" if prefix else key + 
if path == "_refiner" or path.startswith("_refiner/"): + continue + if path in keep_paths: + continue + if key in group_keys and any( + keep_path.startswith(f"{path}/") for keep_path in keep_paths + ): + clear_group(group[key], path) + continue + del group[key] + group.attrs.clear() + + clear_group(root) + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr_reduce", + "writer", + { + "path": self.output.abs_path(), + "store_template": self.store_template, + "array_chunk_bytes": self.array_chunk_bytes, + "reduce_to_single_store": self.reduce_to_single_store, + }, + ) + + def _merge(self) -> None: + expected_parts = self._finalized_store_paths(prefix="_parts/") + import zarr + + parts = self._collect_stores(expected_parts, for_merge=True) + final_attrs: dict[str, Any] | None = None + for relpath, _paths in parts: + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + source_attrs = dict(source.attrs) + if final_attrs is None: + final_attrs = source_attrs + elif source_attrs != final_attrs: + raise ValueError("Zarr part store attrs differ") + + final = zarr.open_group( + store=_zarr_store(self.output, "", mode="a"), + mode="a", + ) + for key in sorted({*final.array_keys(), *final.group_keys()}): + if key != "_parts": + del final[key] + final.attrs.clear() + if final_attrs is not None: + final.attrs.update(final_attrs) + + row_offset = 0 + arrays: dict[str, Any] = {} + for relpath, paths in parts: + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + episode_path = self.episode_ends_path + payload_paths = paths - ( + {episode_path} if episode_path is not None else set() + ) + payload_lengths = {int(source[path].shape[0]) for path in payload_paths} + if len(payload_lengths) > 1: + raise ValueError( + "Zarr part store payload arrays must have matching lengths" + ) + payload_rows = next(iter(payload_lengths), None) + part_end = 0 + if episode_path is 
not None and episode_path in paths: + episode_ends = source[episode_path] + if episode_ends.shape[0] > 0: + part_end = int(np.asarray(episode_ends[-1])) + if payload_rows is not None and part_end != payload_rows: + raise ValueError( + "Zarr part store episode_ends final value does not match " + "payload row count" + ) + for path in sorted(paths): + source_array = source[path] + chunks = getattr(source_array, "chunks", None) + compressor = getattr(source_array, "compressor", None) + if source_array.shape[0] == 0 and path == episode_path: + continue + if source_array.shape[0] == 0: + _append_zarr_array( + final, + arrays, + path, + np.asarray(source_array[:0]), + chunks=chunks, + compressor=compressor, + ) + continue + + batch_size = _batch_length(source_array, self.array_chunk_bytes) + for start in range(0, int(source_array.shape[0]), batch_size): + end = min(int(source_array.shape[0]), start + batch_size) + values = np.asarray(source_array[start:end]) + if path == episode_path: + values = np.asarray(values, dtype=np.int64) + values = values + row_offset + _append_zarr_array( + final, + arrays, + path, + values, + chunks=chunks, + compressor=compressor, + ) + if path == episode_path: + row_offset += part_end + + def on_shard_finalized(self, shard_id: str) -> None: + del shard_id + if not self.reduce_to_single_store: + return + try: + self.output.rm("_parts", recursive=True) + except FileNotFoundError: + pass + + def _finalized_store_paths(self, prefix: str = "") -> list[str]: + stage_index = get_active_stage_index() + if stage_index is None or stage_index <= 0: + raise ValueError( + "write_zarr_reduce requires an active reducer stage with a prior writer stage" + ) + return [ + prefix + + _render_store_relpath( + self.store_template, + shard_id=row.shard_id, + worker_id=row.worker_token, + ) + for row in sort_finalized_workers( + get_finalized_workers(stage_index=stage_index - 1) + ) + ] + + def _collect_stores( + self, + relpaths: Iterable[str], + *, + for_merge: 
bool, + ) -> list[tuple[str, set[str]]]: + import zarr + + stores: list[tuple[str, set[str]]] = [] + payload_paths: set[str] | None = None + schemas: dict[str, tuple[tuple[int, ...], np.dtype[Any]]] = {} + episode_path = self.episode_ends_path if for_merge else None + for relpath in relpaths: + if not self.output.exists(relpath): + kind = "part store" if for_merge else "store" + raise ValueError(f"Zarr {kind} is missing: {relpath}") + source = zarr.open_group( + store=_zarr_store(self.output, relpath, mode="r"), + mode="r", + ) + source_paths = set(_iter_array_paths(source)) + if not source_paths: + continue + source_payload_paths = ( + source_paths - {episode_path} + if episode_path is not None + else source_paths + ) + if ( + episode_path is not None + and source_payload_paths + and episode_path not in source_paths + ): + raise ValueError(f"Zarr part stores must contain {episode_path!r}") + if payload_paths is None: + payload_paths = source_payload_paths + elif source_payload_paths != payload_paths: + kind = "part stores" if for_merge else "stores" + payload = " payload" if for_merge else "" + raise ValueError(f"Zarr {kind} must contain the same{payload} arrays") + for path in source_paths: + source_array = source[path] + schema = (tuple(source_array.shape[1:]), np.dtype(source_array.dtype)) + previous = schemas.setdefault(path, schema) + if previous != schema: + if previous[0] != schema[0]: + raise ValueError( + f"Zarr arrays for {path!r} must have matching trailing shapes" + ) + raise ValueError( + f"Zarr arrays for {path!r} must have matching dtypes" + ) + stores.append((relpath, source_paths)) + return stores + + +def _iter_array_paths(group: Any, prefix: str = "") -> Iterable[str]: + for name, item in group.items(): + path = f"{prefix}/{name}" if prefix else name + if hasattr(item, "shape"): + yield path + else: + yield from _iter_array_paths(item, path) + + +__all__ = ["ZarrReducerSink"] diff --git a/src/refiner/pipeline/sinks/zarr.py 
b/src/refiner/pipeline/sinks/zarr.py new file mode 100644 index 00000000..96d0b7fe --- /dev/null +++ b/src/refiner/pipeline/sinks/zarr.py @@ -0,0 +1,600 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from string import Formatter +from typing import Any + +import numpy as np +import pyarrow as pa + +from refiner.execution.asyncio.runtime import submit +from refiner.io.datafolder import DataFolder, DataFolderLike +from refiner.pipeline.data.block import Block +from refiner.pipeline.data.row import Row +from refiner.pipeline.sinks.base import BaseSink +from refiner.robotics.row import RoboticsRow +from refiner.utils import check_required_dependencies +from refiner.video import VideoSource +from refiner.worker.context import get_active_worker_token + +_DEFAULT_ARRAY_CHUNK_BYTES = 8 * 1024 * 1024 +_MAX_INITIAL_CHUNK_ROWS = 1024 + + +@dataclass +class _ZarrWriteState: + root: Any + arrays: dict[str, Any] = field(default_factory=dict) + row_end: int = 0 + + +class ZarrSink(BaseSink): + def __init__( + self, + output: DataFolderLike, + *, + arrays: Mapping[str, str] | None = None, + attrs: Mapping[str, str] | None = None, + episode_ends_path: str | None = "meta/episode_ends", + store_template: str = "{shard_id}__w{worker_id}.zarr", + video_frame_batch_size: int = 8, + array_chunk_bytes: int = _DEFAULT_ARRAY_CHUNK_BYTES, + reduce_to_single_store: bool = True, + ): + check_required_dependencies("write_zarr", ["zarr"], dist="zarr") + if video_frame_batch_size <= 0: + raise ValueError("video_frame_batch_size must be greater than zero") + if array_chunk_bytes <= 0: + raise ValueError("array_chunk_bytes must be greater than zero") + _validate_store_template(store_template) + self.output = DataFolder.resolve(output) + self.episode_ends_path = ( + _normalize_public_zarr_path(episode_ends_path, "episode_ends_path") + if episode_ends_path is not None + else None + ) + self.arrays = ( + 
_normalize_array_paths(arrays, self.episode_ends_path) + if arrays is not None + else None + ) + self.attrs = dict(attrs) if attrs is not None else None + if self.arrays is not None: + if not self.arrays: + raise ValueError("write_zarr arrays must not be empty") + self.store_template = store_template + self.video_frame_batch_size = video_frame_batch_size + self.array_chunk_bytes = array_chunk_bytes + self.reduce_to_single_store = reduce_to_single_store + self._stores: dict[str, _ZarrWriteState] = {} + self._default_arrays: dict[str, str] | None = None + + def write_shard_block(self, shard_id: str, block: Block) -> int: + count = 0 + pending_arrays: dict[str, list[np.ndarray]] = {} + pending_lengths: list[int] = [] + pending_bytes = 0 + + def flush_pending() -> None: + nonlocal pending_arrays, pending_lengths, pending_bytes + if not pending_arrays: + return + store = self._store(shard_id) + previous_row_end = store.row_end + rollback_lengths = self._snapshot_array_lengths( + store, + pending_arrays.keys(), + ) + if self.episode_ends_path is not None: + rollback_lengths.update( + self._snapshot_array_lengths(store, [self.episode_ends_path]) + ) + try: + combined = { + zarr_path: ( + arrays[0] if len(arrays) == 1 else np.concatenate(arrays) + ) + for zarr_path, arrays in pending_arrays.items() + } + for zarr_path, array in combined.items(): + self._append_array(store, zarr_path, array) + if self.episode_ends_path is not None: + row_ends = ( + np.cumsum(np.asarray(pending_lengths, dtype=np.int64)) + + store.row_end + ) + self._append_array(store, self.episode_ends_path, row_ends) + store.row_end = int(row_ends[-1]) + except Exception: + self._restore_array_lengths(store, rollback_lengths) + store.row_end = previous_row_end + raise + finally: + pending_arrays = {} + pending_lengths = [] + pending_bytes = 0 + + for row in block: + try: + arrays = self._arrays_for_row(row) + row_arrays, row_videos, lengths = self._split_row_values(row, arrays) + except Exception: + 
flush_pending() + raise + + if row_videos: + flush_pending() + self._write_row_values(shard_id, row, row_arrays, row_videos, lengths) + count += 1 + continue + + try: + length = _matching_length(lengths) + except Exception: + flush_pending() + raise + if length is not None: + row_bytes = sum(array.nbytes for array in row_arrays.values()) + if pending_arrays and ( + pending_bytes + row_bytes > self.array_chunk_bytes + or len(pending_lengths) >= _MAX_INITIAL_CHUNK_ROWS + or set(row_arrays) != set(pending_arrays) + or any( + pending_arrays[zarr_path][0].shape[1:] != array.shape[1:] + or pending_arrays[zarr_path][0].dtype != array.dtype + for zarr_path, array in row_arrays.items() + ) + ): + flush_pending() + store = self._store(shard_id) + self._write_attrs(store, row) + for zarr_path, array in row_arrays.items(): + pending_arrays.setdefault(zarr_path, []).append(array) + pending_lengths.append(length) + pending_bytes += row_bytes + count += 1 + flush_pending() + return count + + def _write_attrs(self, store: _ZarrWriteState, row: Row) -> None: + if self.attrs is None: + return + for attr_name, source_key in self.attrs.items(): + value = _row_value(row, source_key) + if isinstance(value, np.generic): + value = value.item() + elif isinstance(value, np.ndarray): + value = value.tolist() + if attr_name in store.root.attrs and store.root.attrs[attr_name] != value: + raise ValueError(f"Zarr attr {attr_name!r} changed across rows") + store.root.attrs[attr_name] = value + + def _write_row_values( + self, + shard_id: str, + row: Row, + row_arrays: dict[str, np.ndarray], + row_videos: list[tuple[str, VideoSource]], + lengths: list[int], + ) -> None: + store = self._store(shard_id) + self._write_attrs(store, row) + expected_length = _matching_length(lengths) + + previous_row_end = store.row_end + rollback_lengths = self._snapshot_array_lengths( + store, + [*row_arrays, *(path for path, _ in row_videos)], + ) + try: + for zarr_path, array in row_arrays.items(): + 
self._append_array(store, zarr_path, array) + for zarr_path, video in row_videos: + video_length = submit( + self._append_video( + store, + zarr_path, + video, + expected_length=expected_length, + ) + ).result() + lengths.append(video_length) + length = _matching_length(lengths) + if length is not None and self.episode_ends_path is not None: + rollback_lengths.update( + self._snapshot_array_lengths(store, [self.episode_ends_path]) + ) + row_end = store.row_end + length + self._append_array( + store, + self.episode_ends_path, + np.asarray([row_end], dtype=np.int64), + ) + store.row_end = row_end + except Exception: + self._restore_array_lengths(store, rollback_lengths) + store.row_end = previous_row_end + raise + + def _split_row_values( + self, + row: Row, + arrays: Mapping[str, str], + ) -> tuple[dict[str, np.ndarray], list[tuple[str, VideoSource]], list[int]]: + row_arrays: dict[str, np.ndarray] = {} + row_videos: list[tuple[str, VideoSource]] = [] + lengths: list[int] = [] + for zarr_path, source_key in arrays.items(): + value = _row_value(row, source_key) + if value is None: + raise ValueError(f"Zarr source value is missing: {source_key}") + if isinstance(value, VideoSource): + row_videos.append((zarr_path, value)) + continue + array = _as_array(value) + if array.ndim == 0: + array = array.reshape(1) + lengths.append(int(array.shape[0])) + row_arrays[zarr_path] = array + return row_arrays, row_videos, lengths + + async def _append_video( + self, + store: _ZarrWriteState, + path: str, + video: VideoSource, + *, + expected_length: int | None = None, + ) -> int: + batch: list[np.ndarray] = [] + batch_limit: int | None = None + count = 0 + + def flush_batch() -> None: + nonlocal count + if not batch: + return + if expected_length is not None and count + len(batch) > expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") + self._append_array( + store, + path, + np.stack(batch, axis=0), + chunks=(batch_limit or len(batch), 
*batch[0].shape), + ) + count += len(batch) + batch.clear() + + async for frame in video.iter_numpy_frames(): + limit = batch_limit + if limit is None: + limit = min( + self.video_frame_batch_size, + _batch_length_for_shape( + (1, *frame.shape), frame.dtype, self.array_chunk_bytes + ), + ) + batch_limit = limit + batch.append(frame) + if len(batch) >= limit: + flush_batch() + flush_batch() + if count == 0: + raise ValueError("Zarr video source produced no frames") + if expected_length is not None and count != expected_length: + raise ValueError("Zarr arrays for one row must have matching lengths") + return count + + def _arrays_for_row(self, row: Row) -> dict[str, str]: + if self.arrays is not None: + return self.arrays + default_arrays = _default_robotics_arrays(row) + if self._default_arrays is None: + self._default_arrays = _normalize_array_paths( + default_arrays, + self.episode_ends_path, + ) + if not self._default_arrays: + raise ValueError( + "write_zarr inferred no default robotics arrays; pass arrays=..." + ) + elif default_arrays != self._default_arrays: + raise ValueError( + "Zarr default arrays changed across rows; pass arrays=... 
" + "to write an explicit stable schema" + ) + return self._default_arrays + + def _store(self, shard_id: str) -> _ZarrWriteState: + relpath = self._store_relpath(shard_id) + store = self._stores.get(relpath) + if store is not None: + return store + import zarr + + store = _ZarrWriteState( + zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") + ) + self._stores[relpath] = store + return store + + def _store_relpath(self, shard_id: str) -> str: + relpath = _render_store_relpath( + self.store_template, + shard_id=shard_id, + worker_id=get_active_worker_token(), + ) + return f"_parts/{relpath}" if self.reduce_to_single_store else relpath + + def on_shard_complete(self, shard_id: str) -> None: + relpath = self._store_relpath(shard_id) + if relpath not in self._stores: + import zarr + + zarr.open_group(store=_zarr_store(self.output, relpath, mode="w"), mode="w") + self._stores.pop(relpath, None) + + def _append_array( + self, + store: _ZarrWriteState, + path: str, + array: np.ndarray, + *, + chunks: tuple[int, ...] 
| None = None, + ) -> None: + _append_zarr_array( + store.root, + store.arrays, + path, + array, + chunks=chunks or _chunk_shape(array, self.array_chunk_bytes), + ) + + def _snapshot_array_lengths( + self, + store: _ZarrWriteState, + paths: Iterable[str], + ) -> dict[str, int | None]: + return { + path: None + if (dataset := store.arrays.get(path)) is None + else int(dataset.shape[0]) + for path in paths + } + + def _restore_array_lengths( + self, + store: _ZarrWriteState, + lengths: Mapping[str, int | None], + ) -> None: + for path, length in lengths.items(): + if length is None: + store.arrays.pop(path, None) + try: + del store.root[path] + except (KeyError, FileNotFoundError): + pass + continue + dataset = store.arrays.get(path) + if dataset is not None: + dataset.resize((length, *dataset.shape[1:])) + + def close(self) -> None: + self._stores.clear() + + def describe(self) -> tuple[str, str, dict[str, object]]: + return ( + "write_zarr", + "writer", + { + "path": self.output.abs_path(), + "arrays": dict(self.arrays) if self.arrays is not None else None, + "attrs": dict(self.attrs) if self.attrs is not None else None, + "episode_ends_path": self.episode_ends_path, + "store_template": self.store_template, + "video_frame_batch_size": self.video_frame_batch_size, + "array_chunk_bytes": self.array_chunk_bytes, + "reduce_to_single_store": self.reduce_to_single_store, + }, + ) + + def build_reducer(self) -> BaseSink | None: + from refiner.pipeline.sinks.reducer.zarr import ZarrReducerSink + + return ZarrReducerSink( + output=self.output, + store_template=self.store_template, + episode_ends_path=self.episode_ends_path, + array_chunk_bytes=self.array_chunk_bytes, + reduce_to_single_store=self.reduce_to_single_store, + ) + + +def _default_robotics_arrays(row: Row) -> dict[str, str]: + if not isinstance(row, RoboticsRow): + raise ValueError("write_zarr requires arrays=... 
def _validate_store_template(store_template: str) -> None:
    """Check that ``store_template`` is a usable per-shard store path template.

    The template text is first checked with ``_normalize_public_zarr_path``.
    It must then reference both ``{shard_id}`` and ``{worker_id}``, and may
    reference nothing else; conversions (``!r``) and format specs (``:>12``)
    are rejected.

    Args:
        store_template: ``str.format``-style template for store paths.

    Raises:
        ValueError: If a field is not plain ``{shard_id}``/``{worker_id}``, or
            if either of those two fields is missing.
    """
    _normalize_public_zarr_path(store_template, "store_template")
    required = {"shard_id", "worker_id"}
    seen: set[str] = set()
    for _, name, spec, conversion in Formatter().parse(store_template):
        if name is None:
            # Pure literal text segment: nothing to validate.
            continue
        is_plain = conversion is None and not spec
        if not is_plain or name not in required:
            raise ValueError(
                "store_template only supports plain {shard_id} and {worker_id} fields"
            )
        seen.add(name)

    absent = sorted(required - seen)
    if absent:
        placeholders = ", ".join("{" + name + "}" for name in absent)
        raise ValueError("store_template requires fields: " + placeholders)
_normalize_public_zarr_path(path: str, label: str) -> str: + path = str(path) + if path.startswith("/"): + raise ValueError(f"{label} must be relative") + parts = [part for part in path.split("/") if part] + if not parts: + raise ValueError(f"{label} must not be empty") + if any(part in {".", ".."} for part in parts): + raise ValueError(f"{label} must not contain '.' or '..' segments") + root = parts[0] + if root in {"_parts", "_refiner"}: + raise ValueError(f"{label} must not use reserved root: {root}") + return "/".join(parts) + + +def _row_value(row: Row, key: str) -> Any: + if isinstance(row, RoboticsRow): + if key == "action": + return row.actions + if key == "observation.state": + return row.states + if key == "timestamp": + return row.timestamps + if key.startswith("observation."): + try: + return row.observations(key) + except KeyError: + video = row.videos.get(key) + if video is None: + raise + return video + return row[key] + + +def _as_array(value: Any) -> np.ndarray: + if isinstance(value, pa.ChunkedArray): + return _as_array(value.combine_chunks()) + if isinstance(value, pa.Array): + if pa.types.is_primitive(value.type): + return value.to_numpy(zero_copy_only=False) + return np.asarray(value.to_pylist()) + if isinstance(value, Iterable) and not isinstance(value, str | bytes | np.ndarray): + return np.asarray(list(value)) + return np.asarray(value) + + +def _matching_length(lengths: list[int]) -> int | None: + if not lengths: + return None + length = lengths[0] + if any(item != length for item in lengths): + raise ValueError("Zarr arrays for one row must have matching lengths") + return length + + +def _zarr_store(output: DataFolder, path: str = "", *, mode: str = "r"): + import zarr + + return zarr.storage.FSStore( + output._join(path), + fs=output.fs, + mode=mode, + create=mode in {"w", "w-", "a"}, + ) + + +def _chunk_shape(array: np.ndarray, target_bytes: int) -> tuple[int, ...]: + chunk_rows = min( + _batch_length(array, target_bytes), + 
max(int(array.shape[0]), _MAX_INITIAL_CHUNK_ROWS), + ) + return (chunk_rows, *array.shape[1:]) + + +def _batch_length(array: Any, target_bytes: int) -> int: + return _batch_length_for_shape(tuple(array.shape), array.dtype, target_bytes) + + +def _batch_length_for_shape( + shape: tuple[int, ...], + dtype: np.dtype[Any] | type[Any], + target_bytes: int, +) -> int: + dtype = np.dtype(dtype) + row_values = int(np.prod(shape[1:], dtype=np.int64)) + row_bytes = max(1, dtype.itemsize * max(1, row_values)) + return max(1, target_bytes // row_bytes) + + +def _validate_array_schema(path: str, dataset: Any, values: np.ndarray) -> None: + if tuple(dataset.shape[1:]) != tuple(values.shape[1:]): + raise ValueError(f"Zarr arrays for {path!r} must have matching trailing shapes") + if dataset.dtype != values.dtype: + raise ValueError(f"Zarr arrays for {path!r} must have matching dtypes") + + +def _append_zarr_array( + root: Any, + arrays: dict[str, Any], + path: str, + values: np.ndarray, + *, + chunks: tuple[int, ...] 
| None = None, + compressor: Any = None, +) -> None: + dataset = arrays.get(path) + if dataset is None: + dataset = root.create_dataset( + path, + shape=(0, *values.shape[1:]), + chunks=chunks, + dtype=values.dtype, + compressor=compressor, + ) + arrays[path] = dataset + else: + _validate_array_schema(path, dataset, values) + dataset.append(values, axis=0) + + +__all__ = ["ZarrSink"] diff --git a/src/refiner/pipeline/sources/readers/zarr.py b/src/refiner/pipeline/sources/readers/zarr.py index 91787107..587c94d0 100644 --- a/src/refiner/pipeline/sources/readers/zarr.py +++ b/src/refiner/pipeline/sources/readers/zarr.py @@ -563,8 +563,7 @@ def _leading_item_bytes(array: Any) -> int: def _iter_array_paths(group: Any, prefix: str = "") -> Iterator[str]: - items = group.items() if hasattr(group, "items") else group.members() - for name, item in items: + for name, item in group.items(): path = f"{prefix}/{name}" if prefix else name if hasattr(item, "shape"): yield path diff --git a/src/refiner/video/types.py b/src/refiner/video/types.py index a5d67368..4414d347 100644 --- a/src/refiner/video/types.py +++ b/src/refiner/video/types.py @@ -28,6 +28,8 @@ def clipped( def iter_frames(self) -> AsyncIterator[DecodedVideoFrame]: ... + def iter_numpy_frames(self) -> AsyncIterator[np.ndarray]: ... 
def sort_finalized_workers(
    rows: Iterable[FinalizedShardWorker],
) -> list[FinalizedShardWorker]:
    """Return ``rows`` as a new list in a deterministic order.

    The list is ordered by ``global_ordinal`` when every row carries one. If
    at least one row has ``global_ordinal`` set to ``None``, all rows are
    ordered by ``shard_id`` instead.

    Args:
        rows: Finalized shard/worker records; any iterable is accepted and it
            is consumed once.

    Returns:
        A newly built, sorted list; the input is not modified.
    """
    ordered = list(rows)
    every_row_has_ordinal = all(
        worker.global_ordinal is not None for worker in ordered
    )
    sort_attr = "global_ordinal" if every_row_has_ordinal else "shard_id"
    ordered.sort(key=lambda worker: getattr(worker, sort_attr))
    return ordered
f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" + loser_dir = output_dir / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert winner_dir.exists() + assert not loser_dir.exists() + + +def test_file_cleanup_reducer_removes_non_finalized_nested_directories( + tmp_path, +) -> None: + output_dir = tmp_path / "zarr-cleanup-nested" + shard_id = "0123456789ab" + winner_worker_id = "worker-2" + loser_worker_id = "worker-1" + winner_dir = ( + output_dir / "split" / f"{shard_id}__w{worker_token_for(winner_worker_id)}.zarr" + ) + loser_dir = ( + output_dir / "split" / f"{shard_id}__w{worker_token_for(loser_worker_id)}.zarr" + ) + (winner_dir / "data").mkdir(parents=True) + (loser_dir / "data").mkdir(parents=True) + (winner_dir / "data" / "0").write_bytes(b"keep") + (loser_dir / "data" / "0").write_bytes(b"drop") + + reducer = FileCleanupReducerSink( + output_dir, + filename_template="split/{shard_id}__w{worker_id}.zarr", + reducer_name="cleanup_zarr", + ) + with set_active_run_context( + job_id="job", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast( + RuntimeLifecycle, + _FinalizedWorkersRuntime( + [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)] + ), + ), + ): + reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")]) + + assert 
def test_file_cleanup_reducer_removes_dynamic_nested_directories(tmp_path) -> None:
    """Cleanup works when ``{shard_id}`` is its own directory level in the template.

    Two ``.zarr`` directories are created under ``split/<shard_id>/``, one per
    worker. Only ``worker-2`` is reported as finalized, so only its directory
    may remain after the reducer runs.
    """
    output_dir = tmp_path / "zarr-cleanup-dynamic-nested"
    shard_id = "0123456789ab"
    winner_worker_id = "worker-2"
    loser_worker_id = "worker-1"
    shard_dir = output_dir / "split" / shard_id
    winner_dir = shard_dir / f"{worker_token_for(winner_worker_id)}.zarr"
    loser_dir = shard_dir / f"{worker_token_for(loser_worker_id)}.zarr"
    for store_dir, payload in ((winner_dir, b"keep"), (loser_dir, b"drop")):
        data_dir = store_dir / "data"
        data_dir.mkdir(parents=True)
        (data_dir / "0").write_bytes(payload)

    reducer = FileCleanupReducerSink(
        output_dir,
        filename_template="split/{shard_id}/{worker_id}.zarr",
        reducer_name="cleanup_zarr",
    )
    finalized = [FinalizedShardWorker(shard_id=shard_id, worker_id=winner_worker_id)]
    lifecycle = cast(RuntimeLifecycle, _FinalizedWorkersRuntime(finalized))
    with set_active_run_context(
        job_id="job",
        stage_index=1,
        worker_id="reducer",
        worker_name=None,
        runtime_lifecycle=lifecycle,
    ):
        reducer.write_block([DictRow({"task_rank": 0}, shard_id="reduce")])

    assert winner_dir.exists()
    assert not loser_dir.exists()
def test_file_cleanup_reducer_propagates_template_listing_errors(
    tmp_path, monkeypatch
) -> None:
    """An ``OSError`` raised by the output folder's ``ls`` must reach the caller."""
    output_dir = tmp_path / "zarr-cleanup-list-error"
    output_dir.mkdir()
    reducer = FileCleanupReducerSink(
        output_dir,
        filename_template="{shard_id}__w{worker_id}.zarr",
        reducer_name="cleanup_zarr",
    )

    def _raise_listing_error(*_args, **_kwargs):
        raise OSError("list failed")

    # Every listing attempt on the reducer's output folder now fails.
    monkeypatch.setattr(reducer.output, "ls", _raise_listing_error)

    lifecycle = cast(RuntimeLifecycle, _FinalizedWorkersRuntime([]))
    reduce_rows = [DictRow({"task_rank": 0}, shard_id="reduce")]
    with set_active_run_context(
        job_id="job",
        stage_index=1,
        worker_id="reducer",
        worker_name=None,
        runtime_lifecycle=lifecycle,
    ), pytest.raises(OSError, match="list failed"):
        reducer.write_block(reduce_rows)
def _write_part_zarr(path: Path, arrays: dict[str, np.ndarray]) -> None:
    """Open a Zarr group at ``path`` in ``"w"`` mode and add one array per entry.

    Args:
        path: Location of the store to create.
        arrays: Mapping from array path inside the group to its contents.
    """
    group = _open_test_zarr(path, mode="w")
    for array_path, payload in arrays.items():
        _create_array(group, array_path, data=payload)
stream.width = 4 + stream.height = 4 + stream.pix_fmt = "yuv420p" + + for value in range(num_frames): + frame = av.VideoFrame.from_ndarray( + np.full((4, 4, 3), value, dtype=np.uint8), + format="rgb24", + ) + for packet in stream.encode(frame): + container.mux(packet) + + for packet in stream.encode(None): + container.mux(packet) def test_read_zarr_reads_selected_arrays_and_attrs(tmp_path: Path) -> None: @@ -673,3 +729,1414 @@ def test_zarr_to_robot_rows_and_lerobot_roundtrip(tmp_path: Path) -> None: episodes.sort(key=lambda episode: int(episode.episode_id)) assert [episode.num_frames for episode in episodes] == [2, 3] assert [episode.task for episode in episodes] == ["push tee", "push tee"] + + +def test_write_zarr_roundtrips_lerobot_rows(tmp_path: Path) -> None: + path = tmp_path / "policy.zarr" + lerobot_out = tmp_path / "lerobot" + zarr_out = tmp_path / "roundtrip.zarr" + _write_policy_zarr(path) + + ( + mdr.read_zarr( + path, + arrays={ + "action": "data/action", + "observation.state": "data/state", + "frames": "data/rgb", + }, + attrs={"task": "task"}, + row_ends="meta/episode_ends", + file_path_column=None, + ) + .to_robot_rows( + episode_id_key="row_index", + task_key="task", + action_key="action", + state_key="observation.state", + video_keys={"observation.images.front": "frames"}, + fps=10, + robot_type="pusht", + ) + .write_lerobot(str(lerobot_out), max_video_prepare_in_flight=1) + .launch_local( + name="zarr-to-lerobot", num_workers=1, rundir=str(tmp_path / "run1") + ) + ) + + ( + mdr.read_lerobot(str(lerobot_out)) + .write_zarr( + str(zarr_out), + arrays={ + "data/action": "action", + "data/state": "observation.state", + }, + reduce_to_single_store=False, + ) + .launch_local( + name="lerobot-to-zarr", num_workers=1, rundir=str(tmp_path / "run2") + ) + ) + + rows = [ + mdr.read_zarr( + zarr_store, + arrays={ + "action": "data/action", + "state": "data/state", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + for 
def test_write_zarr_can_reduce_to_single_store(tmp_path: Path) -> None:
    """Two one-item shards end up concatenated in one store at the output path.

    Checks array contents, cumulative ``meta/episode_ends``, the ``task``
    attribute, and that no ``_parts`` directory is left behind.
    """
    zarr_out = tmp_path / "single.zarr"
    items = [
        {"action": [[0.0], [0.1]], "state": [[1.0], [1.1]], "task": "push"},
        {"action": [[0.2]], "state": [[1.2]], "task": "push"},
    ]

    writer = mdr.from_items(items, items_per_shard=1).write_zarr(
        str(zarr_out),
        arrays={
            "data/action": "action",
            "data/state": "state",
        },
        attrs={"task": "task"},
        reduce_to_single_store=True,
    )
    writer.launch_local(
        name="zarr-single-store", num_workers=1, rundir=str(tmp_path / "run")
    )

    reader = mdr.read_zarr(
        zarr_out,
        arrays={
            "action": "data/action",
            "state": "data/state",
            "episode_ends": "meta/episode_ends",
        },
        attrs={"task": "task"},
        file_path_column=None,
    )
    row = reader.take(1)[0]

    np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]])
    np.testing.assert_allclose(row["state"], [[1.0], [1.1], [1.2]])
    assert row["episode_ends"].tolist() == [2, 3]
    assert row["task"] == "push"
    assert not (zarr_out / "_parts").exists()
def test_write_zarr_rejects_unsupported_store_template_fields(tmp_path: Path) -> None:
    """``ZarrSink`` refuses templates with an unknown field or a format spec."""
    rejected_templates = (
        # Extra field beyond shard_id / worker_id.
        ("extra-template.zarr", "{shard_id}__w{worker_id}__{part}.zarr"),
        # Known field, but decorated with a format spec.
        ("format-template.zarr", "{shard_id:>12}__w{worker_id}.zarr"),
    )
    for output_name, template in rejected_templates:
        with pytest.raises(ValueError, match="store_template only supports"):
            ZarrSink(str(tmp_path / output_name), store_template=template)
def test_write_zarr_rejects_empty_default_robotics_arrays(tmp_path: Path) -> None:
    """Writing robot rows without ``arrays=`` fails when no default array exists.

    The rows are built with the action, state, and timestamp keys all set to
    ``None``, so there is nothing for the default schema to pick up.
    """
    robot_pipeline = mdr.from_items([{"episode_id": "episode-1"}]).to_robot_rows(
        episode_id_key="episode_id",
        action_key=None,
        state_key=None,
        timestamp_key=None,
    )
    rows = list(robot_pipeline)
    output = str(tmp_path / "empty-defaults.zarr")

    with pytest.raises(ValueError, match="inferred no default robotics arrays"):
        ZarrSink(output).write_block(rows)
".zgroup").write_text('{"zarr_format": 2}', encoding="utf-8") + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-replace-second", + num_workers=1, + rundir=str(tmp_path / "run-second"), + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert not stale_part.exists() + + +def test_write_zarr_sharded_replace_removes_single_store_payload_and_parts( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "sharded-replaces-single-store.zarr" + + ( + mdr.from_items([{"action": [[0.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-sharded-replace-first", + num_workers=1, + rundir=str(tmp_path / "run-sharded-replace-first"), + ) + ) + stale_part = zarr_out / "_parts" / "old__wold.zarr" + stale_part.mkdir(parents=True) + (stale_part / ".zgroup").write_text('{"zarr_format": 2}', encoding="utf-8") + + ( + mdr.from_items([{"action": [[1.0], [2.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=False, + ) + .launch_local( + name="zarr-sharded-replace-second", + num_workers=1, + rundir=str(tmp_path / "run-sharded-replace-second"), + ) + ) + + assert not (zarr_out / "data").exists() + assert not (zarr_out / "meta").exists() + assert not (zarr_out / "_parts").exists() + stores = sorted(zarr_out.glob("*.zarr")) + assert len(stores) == 1 + row = mdr.read_zarr( + stores[0], + arrays={ + "action": "data/action", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0], [2.0]]) + assert row["episode_ends"].tolist() == [2] + + +def 
def test_write_zarr_non_reduced_cleanup_rejects_missing_finalized_store(
    tmp_path: Path,
) -> None:
    """The non-reducing reducer errors when a finalized worker's store is absent.

    The runtime reports ``shard-a``/``worker-a`` as finalized, but nothing was
    ever written under the output path.
    """
    zarr_out = tmp_path / "sharded-missing-finalized-store.zarr"
    finalized = [FinalizedShardWorker(shard_id="shard-a", worker_id="worker-a")]
    runtime = _FinalizedWorkersRuntime(finalized)
    reduce_rows = [DictRow({}, shard_id="reduce")]

    with set_active_run_context(
        job_id="local",
        stage_index=1,
        worker_id="reducer",
        worker_name=None,
        runtime_lifecycle=cast(RuntimeLifecycle, runtime),
    ), pytest.raises(ValueError, match="Zarr store is missing"):
        ZarrReducerSink(
            str(zarr_out),
            store_template="{shard_id}__w{worker_id}.zarr",
            reduce_to_single_store=False,
        ).write_block(reduce_rows)
def test_write_zarr_non_reduced_cleanup_keeps_empty_stores_retryable(
    tmp_path: Path,
) -> None:
    """Running the non-reducing reducer twice leaves a finalized empty store alone.

    The second pass stands in for a retry: it must succeed and the empty
    store created up front must still exist afterwards.
    """
    zarr_out = tmp_path / "sharded-empty-cleanup-retry.zarr"
    worker_id = "worker-a"
    empty_store = zarr_out / f"shard-a__w{worker_token_for(worker_id)}.zarr"
    _open_test_zarr(empty_store, mode="w")

    finalized = [FinalizedShardWorker(shard_id="shard-a", worker_id=worker_id)]
    runtime = _FinalizedWorkersRuntime(finalized)
    attempts = 2
    for _attempt in range(attempts):
        with set_active_run_context(
            job_id="local",
            stage_index=1,
            worker_id="reducer",
            worker_name=None,
            runtime_lifecycle=cast(RuntimeLifecycle, runtime),
        ):
            # A fresh reducer per attempt, as in a real retry.
            reducer = ZarrReducerSink(
                str(zarr_out),
                store_template="{shard_id}__w{worker_id}.zarr",
                reduce_to_single_store=False,
            )
            reducer.write_block([DictRow({}, shard_id="reduce")])

    assert empty_store.exists()
_write_part_zarr( + second, + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "data/state": np.asarray([[2.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + reducer = ZarrSink( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=False, + ).build_reducer() + assert reducer is not None + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="same arrays"): + reducer.write_block([DictRow({}, shard_id="reduce")]) + + +def test_write_zarr_single_store_skips_empty_shards(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-empty-shards.zarr" + + ( + mdr.from_items( + [{"action": [[0.0]]}, {"action": [[0.1]]}], + items_per_shard=1, + ) + .filter(lambda row: False) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-empty-shards", + num_workers=2, + rundir=str(tmp_path / "run-empty"), + ) + ) + + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_single_store_empty_replace_ignores_stale_done_marker( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-empty-replace-stale-done.zarr" + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-stale-done-first", + num_workers=1, + rundir=str(tmp_path / "run-stale-done-first"), + ) + ) + + ( + mdr.from_items([{"action": [[2.0]]}], items_per_shard=1) + .filter(lambda row: False) + .write_zarr( + 
str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-stale-done-second", + num_workers=1, + rundir=str(tmp_path / "run-stale-done-second"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert "data/action" not in root + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_single_store_skips_mixed_empty_shards(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-mixed-empty-shards.zarr" + + ( + mdr.from_items( + [{"action": [[0.0]]}, {"action": [[1.0]]}], + items_per_shard=1, + ) + .filter(lambda row: float(row["action"][0][0]) > 0.0) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-mixed-empty-shards", + num_workers=2, + rundir=str(tmp_path / "run-mixed-empty"), + ) + ) + + row = mdr.read_zarr( + zarr_out, + arrays={ + "action": "data/action", + "episode_ends": "meta/episode_ends", + }, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[1.0]]) + assert row["episode_ends"].tolist() == [1] + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_single_store_offsets_batched_episode_ends( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-batched-episode-ends.zarr" + workers = ["worker-a", "worker-b"] + for shard_id, worker in enumerate(workers): + part = ( + zarr_out / "_parts" / f"shard-{shard_id}__w{worker_token_for(worker)}.zarr" + ) + _write_part_zarr( + part, + { + "data/action": np.arange(2, dtype=np.float32).reshape(2, 1), + "meta/episode_ends": np.asarray([1, 2], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id=f"shard-{shard_id}", + worker_id=worker, + global_ordinal=shard_id, + ) + for shard_id, worker in enumerate(workers) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + 
runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=8, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + assert root["meta/episode_ends"][:].tolist() == [1, 2, 3, 4] + + +def test_write_zarr_single_store_rejects_inconsistent_part_payloads( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-inconsistent-parts.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], dtype=np.float32), + "data/state": np.asarray([[2.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="same payload arrays"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + assert first_part.exists() + assert second_part.exists() + + +def 
test_write_zarr_single_store_rejects_part_missing_episode_ends( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-missing-episode-ends.zarr" + worker = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker)}.zarr" + _write_part_zarr( + part, + {"data/action": np.asarray([[0.0]], dtype=np.float32)}, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker, + global_ordinal=0, + ) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="meta/episode_ends"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + +def test_write_zarr_single_store_rejects_part_row_end_mismatch( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-row-end-mismatch.zarr" + worker = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker)}.zarr" + _write_part_zarr( + part, + { + "data/action": np.asarray([[0.0], [1.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker, + global_ordinal=0, + ) + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="episode_ends final value"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + +def 
test_write_zarr_single_store_rejects_missing_finalized_part( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-missing-part.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id="worker-a", + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="part store is missing"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[9.0]]) + assert row["episode_ends"].tolist() == [1] + + +def test_write_zarr_single_store_removes_parts_only_on_completion( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-complete-cleans-parts.zarr" + worker_id = "worker-a" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + part, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + reducer = ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + 
episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ) + reducer.write_block([DictRow({}, shard_id="reduce")]) + assert part.exists() + reducer.on_shard_finalized("reduce") + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[9.0]]) + assert row["episode_ends"].tolist() == [1] + assert not (zarr_out / "_parts").exists() + + +def test_write_zarr_single_store_zero_shard_replace_clears_existing_output( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-zero-shard-replace.zarr" + _write_part_zarr( + zarr_out, + { + "data/action": np.asarray([[9.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime([]) + + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + assert "data/action" not in root + + +def test_write_zarr_single_store_parts_are_resume_stable(tmp_path: Path) -> None: + zarr_out = tmp_path / "single-resume-stable.zarr" + worker_id = "original-worker" + part = zarr_out / "_parts" / f"shard-a__w{worker_token_for(worker_id)}.zarr" + _write_part_zarr( + part, + { + "data/action": np.asarray([[4.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=worker_id, + global_ordinal=0, + ) + ] + ) + + with set_active_run_context( + job_id="resumed-job", + stage_index=1, + 
worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + row = mdr.read_zarr( + zarr_out, + arrays={"action": "data/action", "episode_ends": "meta/episode_ends"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_allclose(row["action"], [[4.0]]) + assert row["episode_ends"].tolist() == [1] + + +def test_write_zarr_single_store_replace_clears_root_attrs( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-replace-attrs.zarr" + root = _open_test_zarr(zarr_out, mode="w") + root.attrs["task"] = "old" + + ( + mdr.from_items([{"action": [[1.0]]}], items_per_shard=1) + .write_zarr( + str(zarr_out), + arrays={"data/action": "action"}, + reduce_to_single_store=True, + ) + .launch_local( + name="zarr-single-replace-attrs", + num_workers=1, + rundir=str(tmp_path / "run-replace-attrs"), + ) + ) + + root = _open_test_zarr(zarr_out, mode="r") + assert dict(root.attrs) == {} + + +def test_write_zarr_single_store_rejects_part_attr_drift_before_clearing( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-attr-drift.zarr" + _write_part_zarr( + zarr_out, + {"data/action": np.asarray([[99.0]], dtype=np.float32)}, + ) + root = _open_test_zarr(zarr_out, mode="r+") + root.attrs["task"] = "old" + + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], 
dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _open_test_zarr(first_part, mode="r+").attrs["task"] = "first" + _open_test_zarr(second_part, mode="r+").attrs["task"] = "second" + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="attrs differ"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + + root = _open_test_zarr(zarr_out, mode="r") + np.testing.assert_allclose(root["data/action"][:], [[99.0]]) + assert dict(root.attrs) == {"task": "old"} + + +def test_write_zarr_single_store_rejects_part_dtype_drift( + tmp_path: Path, +) -> None: + zarr_out = tmp_path / "single-dtype-drift.zarr" + first_worker = "worker-a" + second_worker = "worker-b" + first_part = ( + zarr_out / "_parts" / f"shard-a__w{worker_token_for(first_worker)}.zarr" + ) + second_part = ( + zarr_out / "_parts" / f"shard-b__w{worker_token_for(second_worker)}.zarr" + ) + _write_part_zarr( + first_part, + { + "data/action": np.asarray([[0.0]], dtype=np.float32), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + _write_part_zarr( + second_part, + { + "data/action": np.asarray([[1.0]], dtype=np.float64), + "meta/episode_ends": np.asarray([1], dtype=np.int64), + }, + ) + + runtime = _FinalizedWorkersRuntime( + [ + FinalizedShardWorker( + shard_id="shard-a", + worker_id=first_worker, + global_ordinal=0, + ), + FinalizedShardWorker( + shard_id="shard-b", + 
worker_id=second_worker, + global_ordinal=1, + ), + ] + ) + with set_active_run_context( + job_id="local", + stage_index=1, + worker_id="reducer", + worker_name=None, + runtime_lifecycle=cast(RuntimeLifecycle, runtime), + ): + with pytest.raises(ValueError, match="matching dtypes"): + ZarrReducerSink( + str(zarr_out), + store_template="{shard_id}__w{worker_id}.zarr", + episode_ends_path="meta/episode_ends", + array_chunk_bytes=1024, + reduce_to_single_store=True, + ).write_block([DictRow({}, shard_id="reduce")]) + assert first_part.exists() + assert second_part.exists() + + +def test_write_zarr_rejects_rows_missing_inferred_default_arrays( + tmp_path: Path, +) -> None: + output = tmp_path / "missing-default-array.zarr" + rows = list( + mdr.from_items( + [ + {"action": [[0.0]], "observation.state": [[1.0]]}, + {"action": [[0.1]]}, + ] + ).to_robot_rows( + action_key="action", + state_key="observation.state", + timestamp_key=None, + ) + ) + + with pytest.raises(ValueError, match="default arrays changed"): + ZarrSink(str(output)).write_block(rows) + + +def test_write_zarr_rejects_late_default_arrays(tmp_path: Path) -> None: + output = tmp_path / "late-default-array.zarr" + rows = list( + mdr.from_items( + [ + {"action": [[0.0]]}, + {"action": [[0.1]], "observation.state": [[1.1]]}, + ] + ).to_robot_rows( + action_key="action", + state_key="observation.state", + timestamp_key=None, + ) + ) + + with pytest.raises(ValueError, match="default arrays changed"): + ZarrSink(str(output)).write_block(rows) + + +def test_write_zarr_rejects_episode_ends_path_collision(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="collides with episode_ends_path"): + ZarrSink( + str(tmp_path / "collision.zarr"), + arrays={"meta/episode_ends": "action"}, + ) + with pytest.raises(ValueError, match="collides with episode_ends_path"): + ZarrSink( + str(tmp_path / "normalized-collision.zarr"), + arrays={"meta//episode_ends": "action"}, + ) + + +def 
test_write_zarr_rejects_shape_drift_before_appending_bad_row( + tmp_path: Path, +) -> None: + output = tmp_path / "shape-mismatch.zarr" + rows: list[Row] = [ + DictRow({"action": [[0.0]], "state": [[1.0, 2.0]]}, shard_id="shard"), + DictRow({"action": [[0.1]], "state": [[1.1]]}, shard_id="shard"), + ] + + with pytest.raises(ValueError, match="matching trailing shapes"): + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/state": "state", + }, + reduce_to_single_store=False, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "state": "data/state"}, + file_path_column=None, + ).take(1)[0] + assert row["action"].shape == (1, 1) + assert row["state"].shape == (1, 2) + + +def test_write_zarr_materializes_frame_array_videos(tmp_path: Path) -> None: + output = tmp_path / "video.zarr" + frames = np.arange(2 * 4 * 4 * 3, dtype=np.uint8).reshape(2, 4, 4, 3) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + reduce_to_single_store=False, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "rgb": "data/rgb"}, + file_path_column=None, + ).take(1)[0] + np.testing.assert_array_equal(row["rgb"], frames) + np.testing.assert_allclose(row["action"], [[0.0], [0.1]]) + + +def test_write_zarr_rejects_empty_frame_array_videos(tmp_path: Path) -> None: + output = tmp_path / "empty-video.zarr" + frames = np.empty((0, 4, 5, 3), dtype=np.uint8) + rows = list( + mdr.from_items([{"episode_id": "episode-1", "frames": frames}]).to_robot_rows( + 
episode_id_key="episode_id", + action_key=None, + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + with pytest.raises(ValueError, match="produced no frames"): + ZarrSink( + str(output), + arrays={"data/rgb": "observation.images.front"}, + reduce_to_single_store=False, + ).write_block(rows) + + +def test_write_zarr_uses_byte_budgeted_chunks_for_large_rows(tmp_path: Path) -> None: + output = tmp_path / "video-chunks.zarr" + frames = np.zeros((2, 4, 4, 3), dtype=np.uint8) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + array_chunk_bytes=50, + reduce_to_single_store=False, + ).write_block(rows) + + root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") + assert root["data/rgb"].chunks == (1, 4, 4, 3) + + +def test_write_zarr_caps_low_dimensional_initial_chunks(tmp_path: Path) -> None: + output = tmp_path / "small-array-chunks.zarr" + ZarrSink( + str(output), + arrays={"data/action": "action"}, + reduce_to_single_store=False, + ).write_block( + [ + DictRow( + {"action": np.asarray([[1.0]], dtype=np.float32)}, + shard_id="shard", + ) + ] + ) + + root = _open_test_zarr(next(output.glob("*.zarr")), mode="r") + assert root["data/action"].chunks == (1024, 1) + + +def test_write_zarr_streams_encoded_videos(tmp_path: Path) -> None: + source = tmp_path / "source.mp4" + output = tmp_path / "encoded-video.zarr" + _write_video(source, num_frames=3, fps=5) + + rows = list( + mdr.from_items( + [ + { + "episode_id": "episode-1", + "clip": mdr.video.VideoFile(DataFile.resolve(source)), + "action": [[0.0], [0.1], [0.2]], + } + ] + ).to_robot_rows( + 
episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "clip"}, + ) + ) + + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + video_frame_batch_size=2, + reduce_to_single_store=False, + ).write_block(rows) + + zarr_store = next(output.glob("*.zarr")) + row = mdr.read_zarr( + zarr_store, + arrays={"action": "data/action", "rgb": "data/rgb"}, + file_path_column=None, + ).take(1)[0] + assert row["rgb"].shape == (3, 4, 4, 3) + assert row["rgb"].dtype == np.uint8 + np.testing.assert_allclose(row["action"], [[0.0], [0.1], [0.2]]) + + +def test_write_zarr_rejects_empty_encoded_video_source(tmp_path: Path) -> None: + output = tmp_path / "empty-encoded-video.zarr" + + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "clip": _EmptyVideoSource()}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key=None, + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "clip"}, + ) + ) + + with pytest.raises(ValueError, match="produced no frames"): + ZarrSink( + str(output), + arrays={"data/rgb": "observation.images.front"}, + ).write_block(rows) + + +def test_write_zarr_rejects_video_length_mismatch_before_final_append( + tmp_path: Path, +) -> None: + output = tmp_path / "video-length-mismatch.zarr" + frames = np.zeros((3, 4, 4, 3), dtype=np.uint8) + rows = list( + mdr.from_items( + [{"episode_id": "episode-1", "frames": frames, "action": [[0.0], [0.1]]}] + ).to_robot_rows( + episode_id_key="episode_id", + action_key="action", + state_key=None, + timestamp_key=None, + video_keys={"observation.images.front": "frames"}, + fps=10, + ) + ) + + with pytest.raises(ValueError, match="matching lengths"): + ZarrSink( + str(output), + arrays={ + "data/action": "action", + "data/rgb": "observation.images.front", + }, + video_frame_batch_size=2, + reduce_to_single_store=False, + ).write_block(rows) + + 
zarr_store = next(output.glob("*.zarr")) + root = _open_test_zarr(zarr_store, mode="r") + assert "data/action" not in root + assert "data/rgb" not in root + assert "__tmp" not in root diff --git a/tests/test_video_decode.py b/tests/test_video_decode.py index e8765888..a9dba9ba 100644 --- a/tests/test_video_decode.py +++ b/tests/test_video_decode.py @@ -85,9 +85,11 @@ def test_video_frame_array_iter_frames() -> None: video = mdr.video.VideoFrameArray(frames, fps=5) decoded = asyncio.run(_collect_frames(video)) + arrays = list(video.iter_frame_arrays()) assert [frame.index for frame in decoded] == [0, 1, 2] assert [frame.timestamp_s for frame in decoded] == [0.0, 0.2, 0.4] + assert [int(frame[0, 0, 0]) for frame in arrays] == [0, 1, 2] def test_video_stream_writer_accepts_video_frame_array(tmp_path) -> None: diff --git a/tests/worker/test_runner.py b/tests/worker/test_runner.py index bfc9ae91..50661272 100644 --- a/tests/worker/test_runner.py +++ b/tests/worker/test_runner.py @@ -19,7 +19,7 @@ from refiner.pipeline.sources.readers.base import BaseReader from refiner.pipeline.data.row import DictRow, Row from refiner.worker.metrics.api import log_gauge -from refiner.worker.lifecycle import FinalizedShardWorker +from refiner.worker.lifecycle import FinalizedShardWorker, sort_finalized_workers class _FakeReader(BaseReader): @@ -71,6 +71,20 @@ def _shard(path: str, start: int, end: int) -> Shard: return Shard.from_file_parts([FilePart(path=path, start=start, end=end)]) +def test_sort_finalized_workers_uses_legacy_order_when_any_ordinal_is_missing() -> None: + rows = [ + FinalizedShardWorker("shard-c", "worker-c", global_ordinal=0), + FinalizedShardWorker("shard-a", "worker-a"), + FinalizedShardWorker("shard-b", "worker-b", global_ordinal=1), + ] + + assert [row.shard_id for row in sort_finalized_workers(rows)] == [ + "shard-a", + "shard-b", + "shard-c", + ] + + class _NoopTelemetryEmitter: def emit_user_counter(self, **kwargs) -> None: del kwargs @@ -453,6 +467,41 @@ def 
test_worker_completes_shards_only_after_sink_drain() -> None: assert runtime_lifecycle.completed_ids == [shard.id] +def test_worker_runs_post_completion_sink_hook_after_runtime_complete() -> None: + shard = _shard("p", 0, 1) + events: list[str] = [] + + class _OrderedRuntimeLifecycle(_FakeRuntimeLifecycle): + def complete(self, shard: Shard) -> None: + super().complete(shard) + events.append("runtime_complete") + + class _OrderedSink(_RecordingSink): + def on_shard_complete(self, shard_id: str) -> None: + super().on_shard_complete(shard_id) + events.append("sink_complete") + + def on_shard_finalized(self, shard_id: str) -> None: + events.append("sink_finalized") + + runtime_lifecycle = _OrderedRuntimeLifecycle([shard]) + sink = _OrderedSink() + worker = Worker( + pipeline=RefinerPipeline( + source=_FakeReader({shard.id: [DictRow({"x": 1})]}) + ).with_sink(sink), + job_id="job", + stage_index=0, + worker_id=runtime_lifecycle.worker_id, + runtime_lifecycle=runtime_lifecycle, + ) + + stats = worker.run() + + assert stats.completed == 1 + assert events == ["sink_complete", "runtime_complete", "sink_finalized"] + + def test_worker_metrics_use_correct_step_indexes_for_all_block_types( monkeypatch: pytest.MonkeyPatch, tmp_path ) -> None: