From da2167b306f8dac832200ac39aaaaf0b2651fecb Mon Sep 17 00:00:00 2001 From: Feiyang Wu Date: Mon, 11 May 2026 17:40:00 -0400 Subject: [PATCH 1/4] Add Unitree LeRobot streaming mapper --- iltools/datasets/lerobot_stream.py | 495 ++++++++++++++++++++++++++ pyproject.toml | 1 + tests/datasets/test_lerobot_stream.py | 194 ++++++++++ 3 files changed, 690 insertions(+) create mode 100644 iltools/datasets/lerobot_stream.py create mode 100644 tests/datasets/test_lerobot_stream.py diff --git a/iltools/datasets/lerobot_stream.py b/iltools/datasets/lerobot_stream.py new file mode 100644 index 0000000..c14e46b --- /dev/null +++ b/iltools/datasets/lerobot_stream.py @@ -0,0 +1,495 @@ +"""LeRobot streaming ingestion utilities for offline imitation pretraining. + +This module keeps LeRobot on the input side and TorchRL TensorDict replay +buffers on the training side. Heavy conversion happens before samples enter +the replay buffer, so algorithm code can consume validated TensorDict batches. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from tensordict import TensorDict +from tensordict.base import TensorDictBase +from torch import Tensor +from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data.replay_buffers.writers import TensorDictRoundRobinWriter + +logger = logging.getLogger(__name__) + + +UNITREE_G1_WBT_DEFAULT_REPO_ID = "unitreerobotics/G1_WBT_Brainco_Pickup_Pillow" + + +@dataclass(frozen=True) +class UnitreeG1WBT29DofMapperConfig: + """Schema and robot constants for Unitree G1 WBT low-dimensional data.""" + + robot_q_current_key: str = "observation.state.robot_q_current" + robot_q_desired_key: str = "action.robot_q_desired" + episode_key: str = "episode_index" + dt: float = 1.0 / 30.0 + default_joint_pos: Sequence[Any] = () + action_scale: Sequence[float] = () + quat_order: str = "wxyz" + + +@dataclass(frozen=True) +class LeRobotStreamingCacheConfig: + """Runtime options for streaming LeRobot data into a TorchRL cache.""" + + repo_id: str = UNITREE_G1_WBT_DEFAULT_REPO_ID + split: str = "train" + cache_dir: str | Path = "/tmp/iltools_lerobot_torchrl_cache" + max_cache_transitions: int = 5_000_000 + min_ready_transitions: int = 100_000 + low_watermark: int = 1_000_000 + starvation_timeout_s: float = 300.0 + local_sample_prefetch: int = 0 + batch_size: int | None = None + max_episodes: int | None = None + mapper: UnitreeG1WBT29DofMapperConfig = UnitreeG1WBT29DofMapperConfig() + + +def _to_tensor(value: Any) -> Tensor: + if torch.is_tensor(value): + return value.detach().to(dtype=torch.float32) + return torch.as_tensor(value, dtype=torch.float32) + + +def _normalize_quat_wxyz(quat: Tensor, quat_order: str) -> Tensor: + if quat.shape[-1] != 4: + raise ValueError(f"Expected quaternion width 4, got {tuple(quat.shape)}.") + if quat_order == "wxyz": + quat_wxyz = quat + elif quat_order == "xyzw": + quat_wxyz = quat[..., [3, 0, 1, 2]] + else: + raise ValueError(f"Unsupported quat_order={quat_order!r}.") + return quat_wxyz / quat_wxyz.norm(dim=-1, keepdim=True).clamp_min(1.0e-8) + + +def _quat_conjugate_wxyz(quat: Tensor) -> Tensor: + return torch.cat([quat[..., :1], -quat[..., 1:]], dim=-1) + + +def _quat_mul_wxyz(lhs: Tensor, rhs: Tensor) -> Tensor: + lw, lx, ly, lz = lhs.unbind(dim=-1) + rw, rx, ry, rz = rhs.unbind(dim=-1) + return torch.stack( + ( + lw * rw - lx * rx - ly * ry - lz * rz, + lw * rx + lx * rw + ly * rz - lz * ry, + lw * ry - lx * rz + ly * rw + lz * rx, + lw * rz + lx * ry - ly * rx + lz * rw, + ), + dim=-1, + ) + + +def _axis_angle_from_quat_wxyz(quat: Tensor) -> Tensor: + quat = quat / quat.norm(dim=-1, keepdim=True).clamp_min(1.0e-8) + vector = quat[..., 1:] + vector_norm = vector.norm(dim=-1, keepdim=True) + angle = 2.0 * torch.atan2(vector_norm, quat[..., :1].clamp(-1.0, 1.0)) + axis = vector / vector_norm.clamp_min(1.0e-8) + return axis * angle + + +def _finite_difference(values: Tensor, dt: float) -> Tensor: + if values.shape[0] < 2: + raise ValueError("Need at least two frames to finite-difference an episode.") + return torch.gradient(values, spacing=(float(dt),), dim=0)[0] + + +def _so3_derivative_wxyz(rotations: Tensor, dt: float) -> Tensor: + if rotations.shape[0] < 3: + return torch.zeros( + (rotations.shape[0], 3), dtype=rotations.dtype, device=rotations.device + ) + q_prev = rotations[:-2] + q_next = rotations[2:] + q_rel = _quat_mul_wxyz(_quat_conjugate_wxyz(q_prev), q_next) + omega = _axis_angle_from_quat_wxyz(q_rel) / (2.0 * float(dt)) + return torch.cat([omega[:1], omega, omega[-1:]], dim=0) + + +def _get_required(mapping: Mapping[Any, Any] | TensorDictBase, key: str) -> Any: + if isinstance(mapping, TensorDictBase): + value = mapping.get(key) + elif key in mapping: + value = mapping[key] + else: + raise KeyError(f"Missing required LeRobot field {key!r}.") + if value is None: + raise KeyError(f"Missing required LeRobot field {key!r}.") + return value + + +def _stack_rows(rows: Sequence[Mapping[str, Any]], keys: Sequence[str]) -> TensorDict: + if len(rows) == 0: + raise ValueError("Cannot stack an empty episode.") + data = {} + for key in keys: + data[key] = torch.stack([_to_tensor(_get_required(row, key)) for row in rows]) + return TensorDict(data, batch_size=[len(rows)]) + + +def _identity_rot6d(*, length: int, device: torch.device, dtype: torch.dtype) -> Tensor: + rot6d = torch.zeros((length, 6), device=device, dtype=dtype) + rot6d[:, 0] = 1.0 + rot6d[:, 4] = 1.0 + return rot6d + + +class UnitreeG1WBT29DofMapper: + """Map Unitree G1 WBT LeRobot episodes into canonical training transitions.""" + + robot_q_width = 36 + joint_width = 29 + + def __init__(self, config: UnitreeG1WBT29DofMapperConfig) -> None: + self.config = config + if float(config.dt) <= 0.0: + raise ValueError("mapper.dt must be positive.") + default_joint_pos = _to_tensor(config.default_joint_pos) + if default_joint_pos.ndim == 1: + self.default_joint_pos_pool = default_joint_pos.unsqueeze(0) + elif default_joint_pos.ndim == 2: + self.default_joint_pos_pool = default_joint_pos + else: + raise ValueError( + "default_joint_pos must have shape [29] or [N, 29], got " + f"{tuple(default_joint_pos.shape)}." + ) + self.default_joint_pos = self.default_joint_pos_pool[0] + self.action_scale = _to_tensor(config.action_scale).flatten() + if self.default_joint_pos_pool.shape[-1] != self.joint_width: + raise ValueError( + "default_joint_pos must contain 29 G1 joint values per row, got " + f"{tuple(self.default_joint_pos_pool.shape)}." + ) + if self.default_joint_pos_pool.shape[0] <= 0: + raise ValueError("default_joint_pos must contain at least one row.") + if tuple(self.action_scale.shape) != (self.joint_width,): + raise ValueError( + "action_scale must contain 29 G1 joint values, got " + f"{tuple(self.action_scale.shape)}." + ) + if torch.any(self.action_scale.abs() <= 1.0e-8): + raise ValueError("action_scale must not contain zeros.") + + def map_episode( + self, episode: TensorDictBase | Mapping[str, Any] | Sequence[Mapping[str, Any]] + ) -> TensorDict: + if isinstance(episode, Sequence) and not isinstance(episode, TensorDictBase): + episode_td = _stack_rows( + episode, + ( + self.config.episode_key, + self.config.robot_q_current_key, + self.config.robot_q_desired_key, + ), + ) + else: + data = { + self.config.robot_q_current_key: _to_tensor( + _get_required(episode, self.config.robot_q_current_key) # type: ignore[arg-type] + ), + self.config.robot_q_desired_key: _to_tensor( + _get_required(episode, self.config.robot_q_desired_key) # type: ignore[arg-type] + ), + } + if self.config.episode_key in episode: # type: ignore[operator] + data[self.config.episode_key] = _to_tensor( # type: ignore[index] + episode[self.config.episode_key] # type: ignore[index] + ) + episode_td = TensorDict( + data, + batch_size=[ + int( + _to_tensor( + _get_required(episode, self.config.robot_q_current_key) # type: ignore[arg-type] + ).shape[0] + ) + ], + ) + return self._map_batched_episode(episode_td) + + def _episode_default_joint_pos(self, episode: TensorDictBase, like: Tensor) -> Tensor: + pool = self.default_joint_pos_pool.to(device=like.device, dtype=like.dtype) + if pool.shape[0] == 1: + return pool[0] + episode_index = episode.get(self.config.episode_key) + if episode_index is None: + raise KeyError( + "Unitree G1 WBT mapper requires episode_index to select from a " + "default_joint_pos pool." + ) + pool_index = int(_to_tensor(episode_index).flatten()[0].item()) % int( + pool.shape[0] + ) + return pool[pool_index] + + def _map_batched_episode(self, episode: TensorDictBase) -> TensorDict: + robot_q_current = _to_tensor(episode.get(self.config.robot_q_current_key)) + robot_q_desired = _to_tensor(episode.get(self.config.robot_q_desired_key)) + if robot_q_current.ndim != 2 or robot_q_current.shape[-1] != self.robot_q_width: + raise ValueError( + "robot_q_current must have shape [T, 36], got " + f"{tuple(robot_q_current.shape)}." + ) + if robot_q_desired.ndim != 2 or robot_q_desired.shape[-1] != self.robot_q_width: + raise ValueError( + "robot_q_desired must have shape [T, 36], got " + f"{tuple(robot_q_desired.shape)}." + ) + if robot_q_current.shape[0] != robot_q_desired.shape[0]: + raise ValueError("robot_q_current and robot_q_desired lengths differ.") + if int(robot_q_current.shape[0]) < 2: + raise ValueError("A WBT episode must contain at least two frames.") + + default_joint_pos = self._episode_default_joint_pos(episode, robot_q_current) + action_scale = self.action_scale.to( + device=robot_q_current.device, dtype=robot_q_current.dtype + ) + root_quat = _normalize_quat_wxyz( + robot_q_current[:, 3:7], self.config.quat_order + ) + root_pos = robot_q_current[:, :3] + joint_pos = robot_q_current[:, 7:] + joint_vel = _finite_difference(joint_pos, self.config.dt) + base_ang_vel = _so3_derivative_wxyz(root_quat, self.config.dt) + expert_motion = torch.cat([joint_pos, joint_vel], dim=-1) + expert_anchor_pos_b = torch.zeros( + (robot_q_current.shape[0], 3), + device=robot_q_current.device, + dtype=robot_q_current.dtype, + ) + expert_anchor_ori_b = _identity_rot6d( + length=int(robot_q_current.shape[0]), + device=robot_q_current.device, + dtype=robot_q_current.dtype, + ) + expert_action = (robot_q_desired[:, 7:] - default_joint_pos) / action_scale + last_action = torch.cat( + [torch.zeros_like(expert_action[:1]), expert_action[:-1]], dim=0 + ) + + n = int(robot_q_current.shape[0]) - 1 + done = torch.zeros(n, dtype=torch.bool, device=robot_q_current.device) + done[-1] = True + + return TensorDict( + { + ("policy", "root_pos"): root_pos[:-1], + ("policy", "root_quat"): root_quat[:-1], + ("policy", "joint_pos"): joint_pos[:-1], + ("policy", "base_ang_vel"): base_ang_vel[:-1], + ("policy", "joint_pos_rel"): joint_pos[:-1] - default_joint_pos, + ("policy", "joint_vel_rel"): joint_vel[:-1], + ("policy", "last_action"): last_action[:-1], + ("policy", "expert_motion"): expert_motion[:-1], + ("policy", "expert_anchor_pos_b"): expert_anchor_pos_b[:-1], + ("policy", "expert_anchor_ori_b"): expert_anchor_ori_b[:-1], + ("critic", "expert_motion"): expert_motion[:-1], + ("critic", "expert_anchor_pos_b"): expert_anchor_pos_b[:-1], + ("critic", "expert_anchor_ori_b"): expert_anchor_ori_b[:-1], + ("reward_input", "expert_motion"): expert_motion[:-1], + ("reward_input", "expert_anchor_pos_b"): expert_anchor_pos_b[:-1], + ("reward_input", "expert_anchor_ori_b"): expert_anchor_ori_b[:-1], + ("next", "policy", "root_pos"): root_pos[1:], + ("next", "policy", "root_quat"): root_quat[1:], + ("next", "policy", "joint_pos"): joint_pos[1:], + ("next", "policy", "base_ang_vel"): base_ang_vel[1:], + ("next", "policy", "joint_pos_rel"): joint_pos[1:] - default_joint_pos, + ("next", "policy", "joint_vel_rel"): joint_vel[1:], + ("next", "policy", "last_action"): last_action[1:], + ("next", "policy", "expert_motion"): expert_motion[1:], + ("next", "policy", "expert_anchor_pos_b"): expert_anchor_pos_b[1:], + ("next", "policy", "expert_anchor_ori_b"): expert_anchor_ori_b[1:], + "action": expert_action[:-1], + "expert_action": expert_action[:-1], + "done": done, + ("next", "done"): done, + ("next", "truncated"): torch.zeros_like(done), + }, + batch_size=[n], + ) + + +class StreamingTensorDictReplayCache: + """Bounded local TensorDict replay cache populated by a background producer.""" + + def __init__( + self, + config: LeRobotStreamingCacheConfig, + *, + mapper: UnitreeG1WBT29DofMapper, + source: Iterable[Mapping[str, Any]] | None = None, + ) -> None: + if int(config.max_cache_transitions) <= 0: + raise ValueError("max_cache_transitions must be positive.") + if int(config.min_ready_transitions) < 0: + raise ValueError("min_ready_transitions must be >= 0.") + if int(config.low_watermark) < 0: + raise ValueError("low_watermark must be >= 0.") + if int(config.min_ready_transitions) > int(config.max_cache_transitions): + raise ValueError( + "min_ready_transitions cannot exceed max_cache_transitions." + ) + if int(config.low_watermark) > int(config.max_cache_transitions): + raise ValueError("low_watermark cannot exceed max_cache_transitions.") + self.config = config + self.mapper = mapper + self.source = source + self.cache_dir = Path(config.cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + storage = LazyMemmapStorage( + int(config.max_cache_transitions), + scratch_dir=str(self.cache_dir), + device="cpu", + existsok=True, + ) + self.replay_buffer = TensorDictReplayBuffer( + storage=storage, + writer=TensorDictRoundRobinWriter(), + batch_size=config.batch_size, + prefetch=int(config.local_sample_prefetch) + if int(config.local_sample_prefetch) > 0 + else None, + ) + self._condition = threading.Condition() + self._thread: threading.Thread | None = None + self._stop_event = threading.Event() + self._error: BaseException | None = None + self._episodes_written = 0 + + @property + def ready_transitions(self) -> int: + return len(self.replay_buffer) + + def start(self) -> None: + if self._thread is not None: + raise RuntimeError("Streaming cache producer has already been started.") + self._thread = threading.Thread( + target=self._producer_loop, + name="lerobot-streaming-cache", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + thread = self._thread + if thread is not None: + thread.join(timeout=5.0) + + def wait_until_ready(self, timeout_s: float | None = None) -> None: + min_ready = int(self.config.min_ready_transitions) + timeout_s = ( + float(self.config.starvation_timeout_s) if timeout_s is None else timeout_s + ) + with self._condition: + ready = self._condition.wait_for( + lambda: self.ready_transitions >= min_ready + or self._error is not None + or (self._thread is not None and not self._thread.is_alive()), + timeout=float(timeout_s), + ) + if self._error is not None: + raise RuntimeError( + "LeRobot streaming producer failed." + ) from self._error + if self.ready_transitions >= min_ready: + return + if not ready: + raise TimeoutError( + "Timed out waiting for LeRobot cache readiness: " + f"ready={self.ready_transitions}, min_ready={min_ready}." + ) + raise RuntimeError( + "LeRobot streaming producer finished before cache reached " + f"min_ready_transitions={min_ready}; ready={self.ready_transitions}." + ) + + def sample(self, batch_size: int | None = None) -> TensorDict: + with self._condition: + if self._error is not None: + raise RuntimeError( + "LeRobot streaming producer failed." + ) from self._error + if self.ready_transitions <= 0: + raise RuntimeError("Cannot sample from an empty LeRobot cache.") + return self.replay_buffer.sample(batch_size) + + def _source_iter(self) -> Iterator[Mapping[str, Any]]: + if self.source is not None: + yield from self.source + return + try: + from lerobot.datasets import StreamingLeRobotDataset + except ImportError: + try: + from datasets import load_dataset + except ImportError as exc: + raise ImportError( + "lerobot_stream requires either lerobot or datasets. " + "Install iltools[lerobot], install lerobot directly, or install " + "huggingface datasets." + ) from exc + yield from load_dataset( + self.config.repo_id, + split=self.config.split, + streaming=True, + ) + return + yield from StreamingLeRobotDataset(self.config.repo_id) + + def _producer_loop(self) -> None: + try: + current_episode_id: int | None = None + current_rows: list[Mapping[str, Any]] = [] + for row in self._source_iter(): + if self._stop_event.is_set(): + break + episode_id = int(_get_required(row, self.config.mapper.episode_key)) + if current_episode_id is None: + current_episode_id = episode_id + if episode_id != current_episode_id: + self._write_episode(current_rows) + current_rows = [] + current_episode_id = episode_id + if ( + self.config.max_episodes is not None + and self._episodes_written >= int(self.config.max_episodes) + ): + break + current_rows.append(row) + if current_rows and not self._stop_event.is_set(): + self._write_episode(current_rows) + except BaseException as exc: # noqa: BLE001 + with self._condition: + self._error = exc + self._condition.notify_all() + + def _write_episode(self, rows: Sequence[Mapping[str, Any]]) -> None: + transitions = self.mapper.map_episode(rows) + with self._condition: + self.replay_buffer.extend(transitions) + self._episodes_written += 1 + ready_transitions = self.ready_transitions + self._condition.notify_all() + logger.debug( + "cached Unitree WBT episode %d | rows=%d | transitions=%d | ready=%d", + self._episodes_written, + len(rows), + transitions.numel(), + ready_transitions, + ) diff --git a/pyproject.toml b/pyproject.toml index 3ff5c4b..6a50d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ [project.optional-dependencies] loco-mujoco = ["loco-mujoco"] +lerobot = ["lerobot"] [tool.pytest.ini_options] addopts = "-q" diff --git a/tests/datasets/test_lerobot_stream.py b/tests/datasets/test_lerobot_stream.py new file mode 100644 index 0000000..58a1a59 --- /dev/null +++ b/tests/datasets/test_lerobot_stream.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import pytest +import torch +from tensordict import TensorDict + +from iltools.datasets.lerobot_stream import ( + LeRobotStreamingCacheConfig, + StreamingTensorDictReplayCache, + UnitreeG1WBT29DofMapper, + UnitreeG1WBT29DofMapperConfig, +) + + +def _make_mapper() -> UnitreeG1WBT29DofMapper: + default_joint_pos = torch.linspace(-0.2, 0.2, 29) + action_scale = torch.linspace(0.5, 1.5, 29) + return UnitreeG1WBT29DofMapper( + UnitreeG1WBT29DofMapperConfig( + default_joint_pos=default_joint_pos.tolist(), + action_scale=action_scale.tolist(), + ) + ) + + +def _make_fake_wbt_rows( + *, + episode_index: int = 7, + length: int = 5, +) -> tuple[list[dict[str, object]], torch.Tensor, torch.Tensor]: + mapper = _make_mapper() + default_joint_pos = mapper.default_joint_pos + action_scale = mapper.action_scale + + q_current = torch.zeros(length, 36) + q_current[:, 3:7] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + joint_offsets = torch.linspace(-0.3, 0.3, 29) + frame_offsets = torch.arange(length, dtype=torch.float32).unsqueeze(-1) * 0.1 + q_current[:, 7:] = default_joint_pos + joint_offsets + frame_offsets + + expert_action = torch.stack( + [torch.linspace(-0.5, 0.5, 29) + 0.05 * float(frame) for frame in range(length)] + ) + q_desired = q_current.clone() + q_desired[:, 7:] = default_joint_pos + expert_action * action_scale + + rows = [ + { + "episode_index": episode_index, + "observation.state.robot_q_current": q_current[index], + "action.robot_q_desired": q_desired[index], + } + for index in range(length) + ] + return rows, q_current, expert_action + + +def test_unitree_g1_wbt_mapper_builds_aligned_training_transitions() -> None: + mapper = _make_mapper() + rows, q_current, expert_action = _make_fake_wbt_rows(length=5) + + transitions = mapper.map_episode(rows) + + assert transitions.batch_size == torch.Size([4]) + torch.testing.assert_close(transitions["action"], expert_action[:-1]) + torch.testing.assert_close(transitions["expert_action"], expert_action[:-1]) + torch.testing.assert_close( + transitions.get(("policy", "root_quat")), + torch.tensor([1.0, 0.0, 0.0, 0.0]).expand(4, 4), + ) + torch.testing.assert_close( + transitions.get(("next", "policy", "joint_pos")), + q_current[1:, 7:], + ) + torch.testing.assert_close( + transitions.get(("policy", "last_action"))[0], + torch.zeros(29), + ) + torch.testing.assert_close( + transitions.get(("policy", "last_action"))[1:], + expert_action[:-2], + ) + torch.testing.assert_close( + transitions.get(("next", "policy", "last_action")), + expert_action[:-1], + ) + torch.testing.assert_close( + transitions.get(("policy", "joint_vel_rel")), + torch.full((4, 29), 3.0), + ) + torch.testing.assert_close( + transitions.get(("policy", "base_ang_vel")), + torch.zeros(4, 3), + ) + torch.testing.assert_close( + transitions.get(("policy", "expert_motion")), + torch.cat([q_current[:-1, 7:], torch.full((4, 29), 3.0)], dim=-1), + ) + torch.testing.assert_close( + transitions.get(("policy", "expert_anchor_pos_b")), + torch.zeros(4, 3), + ) + expected_rot6d = torch.zeros(4, 6) + expected_rot6d[:, 0] = 1.0 + expected_rot6d[:, 4] = 1.0 + torch.testing.assert_close( + transitions.get(("policy", "expert_anchor_ori_b")), + expected_rot6d, + ) + torch.testing.assert_close( + transitions.get(("reward_input", "expert_motion")), + transitions.get(("policy", "expert_motion")), + ) + assert transitions["done"].tolist() == [False, False, False, True] + assert transitions.get(("next", "done")).tolist() == [False, False, False, True] + + +def test_unitree_g1_wbt_mapper_selects_episode_default_from_pool() -> None: + default_a = torch.zeros(29) + default_b = torch.linspace(-0.1, 0.1, 29) + action_scale = torch.ones(29) + mapper = UnitreeG1WBT29DofMapper( + UnitreeG1WBT29DofMapperConfig( + default_joint_pos=[default_a.tolist(), default_b.tolist()], + action_scale=action_scale.tolist(), + ) + ) + length = 4 + q_current = torch.zeros(length, 36) + q_current[:, 3:7] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + q_current[:, 7:] = default_b + 0.25 + q_desired = q_current.clone() + q_desired[:, 7:] = default_b + 0.5 + rows = [ + { + "episode_index": 1, + "observation.state.robot_q_current": q_current[index], + "action.robot_q_desired": q_desired[index], + } + for index in range(length) + ] + + transitions = mapper.map_episode(rows) + + torch.testing.assert_close( + transitions.get(("policy", "joint_pos_rel")), + torch.full((length - 1, 29), 0.25), + ) + torch.testing.assert_close( + transitions.get("expert_action"), + torch.full((length - 1, 29), 0.5), + ) + + +def test_unitree_g1_wbt_mapper_fails_fast_on_bad_robot_width() -> None: + mapper = _make_mapper() + episode = TensorDict( + { + "observation.state.robot_q_current": torch.zeros(3, 35), + "action.robot_q_desired": torch.zeros(3, 36), + }, + batch_size=[3], + ) + + with pytest.raises(ValueError, match=r"shape \[T, 36\]"): + mapper.map_episode(episode) + + +def test_streaming_cache_fills_memmap_before_sampling(tmp_path) -> None: + mapper = _make_mapper() + rows, _, _ = _make_fake_wbt_rows(length=6) + cache = StreamingTensorDictReplayCache( + LeRobotStreamingCacheConfig( + cache_dir=tmp_path, + max_cache_transitions=16, + min_ready_transitions=5, + low_watermark=4, + batch_size=2, + max_episodes=1, + mapper=mapper.config, + ), + mapper=mapper, + source=rows, + ) + + cache.start() + cache.wait_until_ready(timeout_s=5.0) + sample = cache.sample(2) + cache.stop() + + assert sample.numel() == 2 + assert ("policy", "base_ang_vel") in sample.keys(True) + assert ("next", "policy", "joint_pos_rel") in sample.keys(True) + assert "expert_action" in sample.keys(True) From ff60e58a299286a3132c7f9d648fc6ac6fd8a4fd Mon Sep 17 00:00:00 2001 From: Feiyang Wu Date: Tue, 12 May 2026 17:00:14 -0400 Subject: [PATCH 2/4] Make dataset loaders lazy and optional --- iltools/cli/main.py | 31 +++++--- iltools/datasets/__init__.py | 8 +++ iltools/datasets/loaders.py | 98 ++++++++++++++++++++++++++ tests/datasets/test_loader_registry.py | 35 +++++++++ 4 files changed, 161 insertions(+), 11 deletions(-) create mode 100644 iltools/datasets/loaders.py create mode 100644 tests/datasets/test_loader_registry.py diff --git a/iltools/cli/main.py b/iltools/cli/main.py index 9a5de46..b8c01ed 100644 --- a/iltools/cli/main.py +++ b/iltools/cli/main.py @@ -2,9 +2,7 @@ from rich.console import Console from rich.table import Table -from iltools.datasets.amass.loader import AmassLoader -from iltools.datasets.loco_mujoco.loader import LocoMuJoCoLoader -from iltools.datasets.trajopt.loader import TrajoptLoader +from iltools.datasets.loaders import load_dataset_loader, registered_dataset_loaders app = typer.Typer(help="Imitation Learning Tools CLI") console = Console() @@ -23,17 +21,28 @@ def load( Loads a dataset and prints its metadata. """ with console.status(f"[bold green]Loading {dataset_name}...[/bold green]"): - if dataset_name == "amass": - loader = AmassLoader(data_path, model_path) - elif dataset_name == "loco_mujoco": - loader = LocoMuJoCoLoader( + try: + loader_cls = load_dataset_loader(dataset_name) + except KeyError: + choices = ", ".join(registered_dataset_loaders()) + console.print( + "[bold red]" + f"Unknown dataset: {dataset_name}. Choices: {choices}" + "[/bold red]" + ) + raise typer.Exit(1) from None + except ImportError as exc: + raise typer.BadParameter(str(exc)) from exc + + normalized_name = dataset_name.strip().lower().replace("-", "_") + if normalized_name == "loco_mujoco": + loader = loader_cls( env_name="Humanoid", task="walk", control_freq=control_freq ) - elif dataset_name == "trajopt": - loader = TrajoptLoader(data_path) + elif normalized_name in {"lafan1", "lafan1_csv"}: + loader = loader_cls(data_path) else: - console.print(f"[bold red]Unknown dataset: {dataset_name}[/bold red]") - raise typer.Exit(1) + loader = loader_cls(data_path) num_trajectories = len(loader) metadata = loader.metadata diff --git a/iltools/datasets/__init__.py b/iltools/datasets/__init__.py index d0df986..c587c17 100644 --- a/iltools/datasets/__init__.py +++ b/iltools/datasets/__init__.py @@ -1 +1,9 @@ """Dataset package.""" + +from .loaders import ( # noqa: F401 + DatasetLoaderSpec, + get_dataset_loader_spec, + load_dataset_loader, + register_dataset_loader, + registered_dataset_loaders, +) diff --git a/iltools/datasets/loaders.py b/iltools/datasets/loaders.py new file mode 100644 index 0000000..30df145 --- /dev/null +++ b/iltools/datasets/loaders.py @@ -0,0 +1,98 @@ +"""Lazy dataset-loader registry. + +This module intentionally stores loader targets as import strings so optional +dataset backends do not become import-time package requirements. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from importlib import import_module +from typing import Any + + +@dataclass(frozen=True) +class DatasetLoaderSpec: + """Import target for a dataset loader class.""" + + module: str + class_name: str + optional_dependency: str | None = None + + +_DATASET_LOADERS: dict[str, DatasetLoaderSpec] = { + "lafan1": DatasetLoaderSpec( + module="iltools.datasets.lafan1.loader", + class_name="Lafan1CsvLoader", + ), + "lafan1_csv": DatasetLoaderSpec( + module="iltools.datasets.lafan1.loader", + class_name="Lafan1CsvLoader", + ), + "loco_mujoco": DatasetLoaderSpec( + module="iltools.datasets.loco_mujoco.loader", + class_name="LocoMuJoCoLoader", + optional_dependency="loco-mujoco", + ), +} + + +def register_dataset_loader( + name: str, + *, + module: str, + class_name: str, + optional_dependency: str | None = None, +) -> None: + """Register or override a dataset loader import target.""" + key = _normalize_name(name) + _DATASET_LOADERS[key] = DatasetLoaderSpec( + module=module, + class_name=class_name, + optional_dependency=optional_dependency, + ) + + +def get_dataset_loader_spec(name: str) -> DatasetLoaderSpec | None: + """Return the registered loader spec for ``name`` if one exists.""" + return _DATASET_LOADERS.get(_normalize_name(name)) + + +def load_dataset_loader(name: str) -> type[Any]: + """Import and return a registered loader class. + + Raises: + KeyError: if ``name`` is not registered. + ImportError: if the loader module or class cannot be imported. + """ + key = _normalize_name(name) + spec = _DATASET_LOADERS.get(key) + if spec is None: + raise KeyError(f"Unknown dataset loader: {name}") + + try: + module = import_module(spec.module) + except ImportError as exc: + dependency = spec.optional_dependency or spec.module + raise ImportError( + f"Dataset loader '{key}' requires optional dependency '{dependency}'." + ) from exc + + try: + loader_cls = getattr(module, spec.class_name) + except AttributeError as exc: + raise ImportError( + f"Dataset loader '{key}' module '{spec.module}' does not define " + f"'{spec.class_name}'." + ) from exc + + return loader_cls + + +def registered_dataset_loaders() -> tuple[str, ...]: + """Return registered dataset loader names.""" + return tuple(sorted(_DATASET_LOADERS)) + + +def _normalize_name(name: str) -> str: + return name.strip().lower().replace("-", "_") diff --git a/tests/datasets/test_loader_registry.py b/tests/datasets/test_loader_registry.py new file mode 100644 index 0000000..985923b --- /dev/null +++ b/tests/datasets/test_loader_registry.py @@ -0,0 +1,35 @@ +from iltools.datasets.loaders import ( + DatasetLoaderSpec, + get_dataset_loader_spec, + load_dataset_loader, + register_dataset_loader, + registered_dataset_loaders, +) + + +def test_registered_loaders_do_not_import_optional_backends_on_listing(): + names = registered_dataset_loaders() + + assert "lafan1_csv" in names + assert "loco_mujoco" in names + + +def test_load_lafan1_csv_loader(): + loader_cls = load_dataset_loader("lafan1-csv") + + assert loader_cls.__name__ == "Lafan1CsvLoader" + + +def test_register_dataset_loader_import_target(): + register_dataset_loader( + "base_loader", + module="iltools.datasets.base_loader", + class_name="BaseLoader", + ) + + spec = get_dataset_loader_spec("base-loader") + assert spec == DatasetLoaderSpec( + module="iltools.datasets.base_loader", + class_name="BaseLoader", + ) + assert load_dataset_loader("base_loader").__name__ == "BaseLoader" From 5f8d41fccb0f3caa3e26282e2449bb7e10ee6cf6 Mon Sep 17 00:00:00 2001 From: Feiyang Wu Date: Wed, 13 May 2026 16:49:25 -0400 Subject: [PATCH 3/4] Add LeRobot loader and trajectory cache support --- iltools/datasets/lafan1/loader.py | 108 +- iltools/datasets/lerobot/__init__.py | 5 + iltools/datasets/lerobot/loader.py | 1639 ++++++++++++++++++++++ iltools/datasets/lerobot_stream.py | 183 ++- iltools/datasets/loaders.py | 5 + tests/datasets/test_lafan1_csv_loader.py | 73 +- tests/datasets/test_lerobot_loader.py | 154 ++ tests/datasets/test_lerobot_stream.py | 126 ++ tests/datasets/test_loader_registry.py | 7 + 9 files changed, 2249 insertions(+), 51 deletions(-) create mode 100644 iltools/datasets/lerobot/__init__.py create mode 100644 iltools/datasets/lerobot/loader.py create mode 100644 tests/datasets/test_lerobot_loader.py diff --git a/iltools/datasets/lafan1/loader.py b/iltools/datasets/lafan1/loader.py index 9567d89..5cfe0ab 100644 --- a/iltools/datasets/lafan1/loader.py +++ b/iltools/datasets/lafan1/loader.py @@ -98,7 +98,9 @@ def _normalize_frame_range(value: Any) -> tuple[int, int] | None: else: seq = list(value) if len(seq) != 2: - raise ValueError("frame_range must have exactly two elements: [start, end].") + raise ValueError( + "frame_range must have exactly two elements: [start, end]." + ) start, end = seq start_i = int(start) end_i = int(end) @@ -315,7 +317,10 @@ def _normalize_source_entries(self, value: Any) -> list[Any]: if self._looks_like_motion_entry(value): return [value] # Mapping style: motion_name -> path(s) - return [{"name": str(name), "path": path_spec} for name, path_spec in value.items()] + return [ + {"name": str(name), "path": path_spec} + for name, path_spec in value.items() + ] return _as_list(value) def _looks_like_motion_entry(self, value: Mapping[str, Any]) -> bool: @@ -370,7 +375,9 @@ def _resolve_entry(self, entry: Any) -> list[MotionSource]: "files", entry.get( "csv_files", - entry.get("path", entry.get("file", entry.get("csv_path", None))), + entry.get( + "path", entry.get("file", entry.get("csv_path", None)) + ), ), ), ) @@ -514,7 +521,9 @@ def _get_trajectories( if build_zarr_dataset and dataset_group is not None: motion_group = dataset_group.create_group(motion_name) - motion_entry = motion_info_dict.setdefault(self.dataset_name, {}).setdefault( + motion_entry = motion_info_dict.setdefault( + self.dataset_name, {} + ).setdefault( motion_name, { "motion_name": motion_name, @@ -579,7 +588,9 @@ def _get_trajectories( if motion_group is not None: motion_group.attrs["num_trajectories"] = len(motion_sources) - motion_group.attrs["trajectory_lengths"] = motion_entry["trajectory_lengths"] + motion_group.attrs["trajectory_lengths"] = motion_entry[ + "trajectory_lengths" + ] motion_group.attrs["source_files"] = motion_entry["source_files"] motion_group.attrs["source_fps"] = motion_entry["source_fps"] motion_group.attrs["output_fps"] = motion_entry["output_fps"] @@ -721,7 +732,11 @@ def _load_npz_motion( with np.load(source.path) as npz_data: arrays = {key: np.asarray(npz_data[key]) for key in npz_data.files} - source_fps = float(np.asarray(arrays.get("fps", source.input_fps)).reshape(-1)[0]) + self._apply_npz_name_metadata(arrays=arrays, path=source.path) + + source_fps = float( + np.asarray(arrays.get("fps", source.input_fps)).reshape(-1)[0] + ) if source_fps <= 0.0: raise ValueError(f"Invalid source fps ({source_fps}) for {source.path}") @@ -755,11 +770,17 @@ def _load_npz_motion( f"NPZ {source.path} must contain 'qpos' or 'joint_pos'." ) joint_pos = joint_pos.astype(np.float32) - joint_pos = self._apply_frame_range(joint_pos, source.frame_range, source.path) + joint_pos = self._apply_frame_range( + joint_pos, source.frame_range, source.path + ) root_pos, root_quat = self._extract_root_pose_from_npz(arrays, source.path) - root_pos = self._apply_frame_range(root_pos, source.frame_range, source.path) - root_quat = self._apply_frame_range(root_quat, source.frame_range, source.path) + root_pos = self._apply_frame_range( + root_pos, source.frame_range, source.path + ) + root_quat = self._apply_frame_range( + root_quat, source.frame_range, source.path + ) joint_vel_raw = arrays.get("joint_vel") root_lin_vel_raw, root_ang_vel_raw = self._extract_root_vel_from_npz(arrays) @@ -769,12 +790,16 @@ def _load_npz_motion( else None ) root_lin_vel = ( - self._apply_frame_range(root_lin_vel_raw, source.frame_range, source.path) + self._apply_frame_range( + root_lin_vel_raw, source.frame_range, source.path + ) if root_lin_vel_raw is not None else None ) root_ang_vel = ( - self._apply_frame_range(root_ang_vel_raw, source.frame_range, source.path) + self._apply_frame_range( + root_ang_vel_raw, source.frame_range, source.path + ) if root_ang_vel_raw is not None else None ) @@ -825,6 +850,55 @@ def _load_npz_motion( ) return traj_data, source_fps, output_fps + def _apply_npz_name_metadata( + self, *, arrays: Mapping[str, np.ndarray], path: Path + ) -> None: + self._apply_npz_names( + key="joint_names", + current=self._joint_names, + setter=lambda names: setattr(self, "_joint_names", names), + arrays=arrays, + path=path, + ) + self._apply_npz_names( + key="body_names", + current=self._body_names, + setter=lambda names: setattr(self, "_body_names", names), + arrays=arrays, + path=path, + ) + self._apply_npz_names( + key="site_names", + current=self._site_names, + setter=lambda names: setattr(self, "_site_names", names), + arrays=arrays, + path=path, + ) + + def _apply_npz_names( + self, + *, + key: str, + current: list[str] | None, + setter: Any, + arrays: Mapping[str, np.ndarray], + path: Path, + ) -> None: + if key not in arrays: + return + value = np.asarray(arrays[key]) + if value.ndim != 1: + raise ValueError(f"NPZ {path} {key} must be a 1D array.") + parsed = [str(item) for item in value.tolist()] + if current is None: + setter(parsed) + return + if parsed != current: + raise ValueError( + f"NPZ {path} {key} does not match configured {key}: " + f"expected={current}, got={parsed}." + ) + def _extract_root_pose_from_npz( self, arrays: Mapping[str, np.ndarray], path: Path ) -> tuple[np.ndarray, np.ndarray]: @@ -890,9 +964,9 @@ def _build_trajectory_dict( qpos = np.concatenate([root_pos, root_quat, joint_pos], axis=-1).astype( np.float32 ) - qvel = np.concatenate( - [root_lin_vel, root_ang_vel, joint_vel], axis=-1 - ).astype(np.float32) + qvel = np.concatenate([root_lin_vel, root_ang_vel, joint_vel], axis=-1).astype( + np.float32 + ) traj_data: dict[str, np.ndarray] = { "qpos": qpos, @@ -1124,9 +1198,9 @@ def _quat_slerp_batch( theta = theta_0 * t[~linear_mask] s0 = np.sin(theta_0 - theta) / np.maximum(sin_theta_0, EPS) s1 = np.sin(theta) / np.maximum(sin_theta_0, EPS) - out[~linear_mask] = qa[~linear_mask] * s0[:, None] + qb[~linear_mask] * s1[ - :, None - ] + out[~linear_mask] = ( + qa[~linear_mask] * s0[:, None] + qb[~linear_mask] * s1[:, None] + ) out[~linear_mask] = self._normalize_quat(out[~linear_mask]) return out.astype(np.float32) diff --git a/iltools/datasets/lerobot/__init__.py b/iltools/datasets/lerobot/__init__.py new file mode 100644 index 0000000..8f49e56 --- /dev/null +++ b/iltools/datasets/lerobot/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .loader import LeRobotLoader + +__all__ = ["LeRobotLoader"] diff --git a/iltools/datasets/lerobot/loader.py b/iltools/datasets/lerobot/loader.py new file mode 100644 index 0000000..b7de8a6 --- /dev/null +++ b/iltools/datasets/lerobot/loader.py @@ -0,0 +1,1639 @@ +from __future__ import annotations + +import inspect +import os +import re +from collections.abc import Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import zarr +from zarr.storage import LocalStore + +from iltools.core.metadata_schema import DatasetMeta +from iltools.datasets.base_loader import BaseLoader + +TrajectoryEntry = dict[str, Any] +MotionIndex = dict[str, dict[str, dict[str, Any]]] + +BASE_EXPORT_KEYS: frozenset[str] = frozenset( + [ + "qpos", + "qvel", + "root_pos", + "root_quat", + "root_lin_vel", + "root_ang_vel", + "joint_pos", + "joint_vel", + "episode_index", + "frame_index", + ] +) +TRANSITION_EXPORT_KEYS: frozenset[str] = frozenset( + [ + "next_qpos", + "next_qvel", + "next_root_pos", + "next_root_quat", + "next_root_lin_vel", + "next_root_ang_vel", + "next_joint_pos", + "next_joint_vel", + "next_episode_index", + "next_frame_index", + ] +) +OPTIONAL_ACTION_KEYS: frozenset[str] = frozenset( + [ + "action", + "next_action", + "target_qpos", + "next_target_qpos", + "target_joint_pos", + "next_target_joint_pos", + "target_root_pos", + "next_target_root_pos", + "target_root_quat", + "next_target_root_quat", + ] +) +OPTIONAL_TIME_KEYS: frozenset[str] = frozenset(["timestamp", "next_timestamp"]) + +UNITREE_G1_WBT_DEFAULT_REPO_ID = "unitreerobotics/G1_WBT_Brainco_Pickup_Pillow" +UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES: tuple[str, ...] = ( + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_shoulder_pitch_joint", + "left_shoulder_roll_joint", + "left_shoulder_yaw_joint", + "left_elbow_joint", + "left_wrist_roll_joint", + "left_wrist_pitch_joint", + "left_wrist_yaw_joint", + "right_shoulder_pitch_joint", + "right_shoulder_roll_joint", + "right_shoulder_yaw_joint", + "right_elbow_joint", + "right_wrist_roll_joint", + "right_wrist_pitch_joint", + "right_wrist_yaw_joint", +) + +EPS = 1.0e-8 +_MISSING = object() + + +def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any: + if cfg is None: + return default + if isinstance(cfg, Mapping): + return cfg.get(key, default) + return getattr(cfg, key, default) + + +def _cfg_get_nested(cfg: Any, keys: Sequence[str], default: Any = None) -> Any: + current = cfg + for key in keys: + if current is None: + return default + if isinstance(current, Mapping): + current = current.get(key, None) + else: + current = getattr(current, key, None) + return default if current is None else current + + +def _as_list(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, (str, os.PathLike)): + return [value] + if isinstance(value, Mapping): + return [value] + if isinstance(value, Sequence): + return list(value) + return [value] + + +def _maybe_list_of_str(value: Any) -> list[str] | None: + if value is None: + return None + return [str(v) for v in _as_list(value)] + + +def _maybe_list_of_int(value: Any) -> tuple[int, ...] | None: + if value is None: + return None + return tuple(int(v) for v in _as_list(value)) + + +def _normalize_frame_range(value: Any) -> tuple[int, int] | None: + if value is None: + return None + if isinstance(value, Mapping): + start = value.get("start") + end = value.get("end") + else: + seq = list(value) + if len(seq) != 2: + raise ValueError( + "frame_range must have exactly two elements: [start, end]." + ) + start, end = seq + start_i = int(start) + end_i = int(end) + if start_i < 1: + raise ValueError("frame_range start must be >= 1 (1-indexed inclusive).") + if end_i < start_i: + raise ValueError("frame_range end must be >= start.") + return start_i, end_i + + +def _sanitize_name(name: str) -> str: + cleaned = re.sub(r"[^A-Za-z0-9_\-]+", "_", str(name)).strip("_") + return cleaned or "motion" + + +def _row_get(row: Mapping[str, Any], key: str, default: Any = _MISSING) -> Any: + if key in row: + return row[key] + + current: Any = row + for part in key.split("."): + if isinstance(current, Mapping) and part in current: + current = current[part] + else: + if default is _MISSING: + raise KeyError(key) + return default + return current + + +def _row_get_first( + row: Mapping[str, Any], keys: Sequence[str], default: Any = _MISSING +) -> tuple[Any, str | None]: + for key in keys: + try: + return _row_get(row, key), key + except KeyError: + continue + if default is _MISSING: + raise KeyError(f"Missing required row field; tried {list(keys)}.") + return default, None + + +def _to_numpy( + value: Any, *, dtype: np.dtype[Any] | type | None = np.float32 +) -> np.ndarray: + if hasattr(value, "detach") and callable(value.detach): + value = value.detach().cpu().numpy() + elif hasattr(value, "cpu") and callable(value.cpu): + value = value.cpu().numpy() + array = np.asarray(value) + if dtype is not None: + array = array.astype(dtype) + return array + + +def _stack_rows( + rows: Sequence[Mapping[str, Any]], + keys: Sequence[str], + *, + required: bool, + dtype: np.dtype[Any] | type | None = np.float32, +) -> tuple[np.ndarray | None, str | None]: + values: list[np.ndarray] = [] + selected_key: str | None = None + for row in rows: + value, key = _row_get_first(row, keys, default=None) + if key is None: + if required: + raise KeyError(f"Missing required row field; tried {list(keys)}.") + return None, None + if selected_key is None: + selected_key = key + elif selected_key != key: + raise ValueError( + "Rows mix multiple keys for the same field: " + f"first={selected_key!r}, current={key!r}." + ) + array = _to_numpy(value, dtype=dtype) + if array.ndim == 0: + array = array.reshape(1) + values.append(array) + if not values: + if required: + raise ValueError("Cannot stack an empty episode.") + return None, None + return np.stack(values, axis=0), selected_key + + +def _normalize_quat_wxyz(quat: np.ndarray, quat_order: str) -> np.ndarray: + if quat.shape[-1] != 4: + raise ValueError(f"Expected quaternion width 4, got {tuple(quat.shape)}.") + if quat_order == "wxyz": + quat_wxyz = quat + elif quat_order == "xyzw": + quat_wxyz = quat[..., [3, 0, 1, 2]] + else: + raise ValueError(f"Unsupported quat_order={quat_order!r}.") + norm = np.linalg.norm(quat_wxyz, axis=-1, keepdims=True) + norm = np.where(norm < EPS, 1.0, norm) + return (quat_wxyz / norm).astype(np.float32) + + +def _quat_conjugate_wxyz(quat: np.ndarray) -> np.ndarray: + out = quat.copy() + out[..., 1:] *= -1.0 + return out + + +def _quat_mul_wxyz(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray: + lw, lx, ly, lz = np.moveaxis(lhs, -1, 0) + rw, rx, ry, rz = np.moveaxis(rhs, -1, 0) + return np.stack( + ( + lw * rw - lx * rx - ly * ry - lz * rz, + lw * rx + lx * rw + ly * rz - lz * ry, + lw * ry - lx * rz + ly * rw + lz * rx, + lw * rz + lx * ry - ly * rx + lz * rw, + ), + axis=-1, + ) + + +def _axis_angle_from_quat_wxyz(quat: np.ndarray) -> np.ndarray: + q = _normalize_quat_wxyz(quat, "wxyz") + w = np.clip(q[..., 0], -1.0, 1.0) + vector = q[..., 1:] + vector_norm = np.linalg.norm(vector, axis=-1, keepdims=True) + axis = np.divide( + vector, + np.where(vector_norm < EPS, 1.0, vector_norm), + out=np.zeros_like(vector), + where=vector_norm >= EPS, + ) + angle = 2.0 * np.arctan2(vector_norm[..., 0], w) + axis_angle = axis * angle[..., None] + axis_angle[vector_norm[..., 0] < EPS] = 0.0 + return axis_angle.astype(np.float32) + + +def _so3_derivative_wxyz(rotations: np.ndarray, dt: float) -> np.ndarray: + q = _normalize_quat_wxyz(rotations, "wxyz") + n = q.shape[0] + if n <= 1: + return np.zeros((n, 3), dtype=np.float32) + if n == 2: + q_rel = _quat_mul_wxyz(q[1:2], _quat_conjugate_wxyz(q[0:1])) + omega = _axis_angle_from_quat_wxyz(q_rel) / float(dt) + return np.repeat(omega, 2, axis=0).astype(np.float32) + q_prev = q[:-2] + q_next = q[2:] + q_rel = _quat_mul_wxyz(q_next, _quat_conjugate_wxyz(q_prev)) + omega = _axis_angle_from_quat_wxyz(q_rel) / (2.0 * float(dt)) + return np.concatenate([omega[:1], omega, omega[-1:]], axis=0).astype(np.float32) + + +def _lerp(lhs: np.ndarray, rhs: np.ndarray, blend: np.ndarray) -> np.ndarray: + return lhs * (1.0 - blend) + rhs * blend + + +def _quat_slerp_wxyz(q0: np.ndarray, q1: np.ndarray, blend: np.ndarray) -> np.ndarray: + qa = _normalize_quat_wxyz(q0.astype(np.float32), "wxyz") + qb = _normalize_quat_wxyz(q1.astype(np.float32), "wxyz") + t = np.asarray(blend, dtype=np.float32) + + dot = np.sum(qa * qb, axis=-1) + neg_mask = dot < 0.0 + qb = qb.copy() + qb[neg_mask] *= -1.0 + dot = np.abs(dot) + dot = np.clip(dot, -1.0, 1.0) + + out = np.empty_like(qa) + linear_mask = dot > 0.9995 + if np.any(linear_mask): + t_linear = t[linear_mask][:, None] + out[linear_mask] = _normalize_quat_wxyz( + qa[linear_mask] * (1.0 - t_linear) + qb[linear_mask] * t_linear, + "wxyz", + ) + + if np.any(~linear_mask): + theta_0 = np.arccos(dot[~linear_mask]) + sin_theta_0 = np.sin(theta_0) + theta = theta_0 * t[~linear_mask] + s0 = np.sin(theta_0 - theta) / np.maximum(sin_theta_0, EPS) + s1 = np.sin(theta) / np.maximum(sin_theta_0, EPS) + out[~linear_mask] = ( + qa[~linear_mask] * s0[:, None] + qb[~linear_mask] * s1[:, None] + ) + out[~linear_mask] = _normalize_quat_wxyz(out[~linear_mask], "wxyz") + return out.astype(np.float32) + + +def _filter_kwargs(callable_obj: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: + try: + signature = inspect.signature(callable_obj) + except (TypeError, ValueError): + return {key: value for key, value in kwargs.items() if value is not None} + parameters = signature.parameters + accepts_kwargs = any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in parameters.values() + ) + return { + key: value + for key, value in kwargs.items() + if value is not None and (accepts_kwargs or key in parameters) + } + + +@dataclass(frozen=True) +class LeRobotSource: + repo_id: str + motion_name: str + split: str + root: str | None + revision: str | None + episodes: tuple[int, ...] | None + streaming: bool + fps: float | None + max_episodes: int | None + max_rows: int | None + source_rows_key: str | None = None + + +@dataclass(frozen=True) +class TrajectoryInfo: + dataset: str + motion: str + motion_name: str + trajectory_index: int + trajectory_in_motion: int + start: int + end: int + + @property + def length(self) -> int: + return self.end - self.start + + def to_dict(self) -> TrajectoryEntry: + return { + "dataset": self.dataset, + "motion": self.motion, + "motion_name": self.motion_name, + "trajectory_index": self.trajectory_index, + "trajectory_in_motion": self.trajectory_in_motion, + "start": self.start, + "end": self.end, + "length": self.length, + } + + +class LeRobotLoader(BaseLoader): + """Load low-dimensional LeRobot episodes into ILTools Zarr trajectories. + + By default this targets the public Unitree G1 WBT LeRobot schema, where + ``observation.state.robot_q_current`` and ``action.robot_q_desired`` are + 36-wide vectors: root position, root quaternion, then 29 joint positions. + The state/action keys, joint names, and quaternion order are configurable so + the loader can also ingest standard LeRobot rows that expose + ``observation.state`` and ``action``. + """ + + def __init__( + self, + cfg: Any, + build_zarr_dataset: bool = True, + zarr_path: str | None = None, + *, + source: Iterable[Mapping[str, Any]] + | Mapping[str, Iterable[Mapping[str, Any]]] + | None = None, + **kwargs: Any, + ) -> None: + super().__init__() + self.cfg = cfg + self.dataset_name = str( + _cfg_get( + cfg, + "dataset_name", + _cfg_get_nested(cfg, ("dataset", "name"), "lerobot"), + ) + ) + self.dataset_source = str(_cfg_get(cfg, "source_name", "lerobot")) + self.default_split = str( + _cfg_get(cfg, "split", _cfg_get_nested(cfg, ("dataset", "split"), "train")) + ) + self.default_fps = float(_cfg_get(cfg, "fps", 30.0)) + if self.default_fps <= 0.0: + raise ValueError("fps must be positive.") + self.control_freq = self._resolve_control_freq() + self.frame_range = _normalize_frame_range(_cfg_get(cfg, "frame_range", None)) + + self.state_key_candidates = self._resolve_key_candidates( + primary_key="state_key", + nested_path=("dataset", "state_key"), + defaults=( + "observation.state.robot_q_current", + "observation.state", + ), + ) + self.action_key_candidates = self._resolve_key_candidates( + primary_key="action_key", + nested_path=("dataset", "action_key"), + defaults=( + "action.robot_q_desired", + "action", + ), + allow_none=True, + ) + self.episode_key = str(_cfg_get(cfg, "episode_key", "episode_index")) + self.frame_key = str(_cfg_get(cfg, "frame_key", "frame_index")) + self.timestamp_key = str(_cfg_get(cfg, "timestamp_key", "timestamp")) + self.quat_order = str(_cfg_get(cfg, "quat_order", "wxyz")) + self.align_root_z_to_default = bool( + _cfg_get(cfg, "align_root_z_to_default", False) + ) + self.default_root_height = float(_cfg_get(cfg, "default_root_height", 0.0)) + self.drop_short_episodes = bool(_cfg_get(cfg, "drop_short_episodes", True)) + + self.default_root = self._optional_path_str( + _cfg_get(cfg, "root", _cfg_get_nested(cfg, ("dataset", "root"), None)) + ) + self.default_revision = _cfg_get( + cfg, "revision", _cfg_get_nested(cfg, ("dataset", "revision"), None) + ) + self.default_episodes = _maybe_list_of_int( + _cfg_get( + cfg, + "episodes", + _cfg_get_nested(cfg, ("dataset", "episodes"), None), + ) + ) + self.default_streaming = bool(_cfg_get(cfg, "streaming", False)) + self.force_cache_sync = bool(_cfg_get(cfg, "force_cache_sync", False)) + self.download_videos = bool(_cfg_get(cfg, "download_videos", False)) + self.video_backend = _cfg_get(cfg, "video_backend", None) + self.max_episodes = self._optional_positive_int( + _cfg_get(cfg, "max_episodes", None), "max_episodes" + ) + self.max_rows = self._optional_positive_int( + _cfg_get(cfg, "max_rows", None), "max_rows" + ) + + self._configured_joint_names = self._read_optional_name_list( + ("joint_names",), ("dataset", "joint_names") + ) + self._configured_body_names = self._read_optional_name_list( + ("body_names",), ("dataset", "body_names") + ) + self._configured_site_names = self._read_optional_name_list( + ("site_names",), ("dataset", "site_names") + ) + self._joint_names: list[str] | None = ( + list(self._configured_joint_names) + if self._configured_joint_names is not None + else None + ) + self._body_names: list[str] | None = ( + list(self._configured_body_names) + if self._configured_body_names is not None + else None + ) + self._site_names: list[str] | None = ( + list(self._configured_site_names) + if self._configured_site_names is not None + else None + ) + self.dataset_joint_names = tuple( + _maybe_list_of_str(_cfg_get(cfg, "dataset_joint_names", None)) + or UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES + ) + self.target_joint_names = tuple( + _maybe_list_of_str(_cfg_get(cfg, "target_joint_names", None)) or () + ) + self._joint_reorder_index: np.ndarray | None = None + self._validate_joint_name_config() + + self._source_rows = self._normalize_source_override(source) + self.sources = self._collect_sources() + self._available_keys: set[str] = set() + self._trajectory_output_fps: list[float] = [] + self._source_metadata: dict[str, dict[str, Any]] = {} + + self.logger.info( + "Initializing LeRobotLoader with %d source(s)", len(self.sources) + ) + self._trajectory_info_list, self._motion_info_dict = self._get_trajectories( + build_zarr_dataset=build_zarr_dataset, + path=zarr_path or kwargs.pop("path", None), + **kwargs, + ) + self._metadata = self._discover_metadata() + + def _resolve_key_candidates( + self, + *, + primary_key: str, + nested_path: Sequence[str], + defaults: Sequence[str], + allow_none: bool = False, + ) -> tuple[str, ...]: + value = _cfg_get(self.cfg, primary_key, _cfg_get_nested(self.cfg, nested_path)) + if value is None: + return () if allow_none and defaults == () else tuple(defaults) + if allow_none and str(value).lower() in {"", "none", "null"}: + return () + return tuple(str(item) for item in _as_list(value)) + + def _resolve_control_freq(self) -> float | None: + control_freq = _cfg_get(self.cfg, "control_freq", None) + if control_freq is None: + control_freq = _cfg_get(self.cfg, "output_fps", None) + + if control_freq is None: + sim_dt = _cfg_get_nested(self.cfg, ("sim", "dt"), None) + decimation = _cfg_get(self.cfg, "decimation", None) + if sim_dt is not None and decimation is not None: + control_freq = 1.0 / (float(sim_dt) * float(decimation)) + + if control_freq is None: + sim_dt = _cfg_get_nested(self.cfg, ("sim", "dt"), None) + n_substeps = _cfg_get(self.cfg, "n_substeps", None) + if sim_dt is not None and n_substeps is not None: + control_freq = 1.0 / (float(sim_dt) * float(n_substeps)) + + if control_freq is None: + return None + + control_freq = float(control_freq) + if control_freq <= 0.0: + raise ValueError("control_freq must be positive.") + return control_freq + + def _optional_path_str(self, value: Any) -> str | None: + if value is None: + return None + return str(Path(str(value)).expanduser()) + + def _optional_positive_int(self, value: Any, label: str) -> int | None: + if value is None: + return None + parsed = int(value) + if parsed <= 0: + raise ValueError(f"{label} must be positive when provided.") + return parsed + + def _read_optional_name_list(self, *paths: Sequence[str]) -> list[str] | None: + for path in paths: + value = _cfg_get_nested(self.cfg, path, None) + parsed = _maybe_list_of_str(value) + if parsed is not None: + return parsed + return None + + def _validate_joint_name_config(self) -> None: + if len(self.dataset_joint_names) == 0: + return + if len(set(self.dataset_joint_names)) != len(self.dataset_joint_names): + raise ValueError("dataset_joint_names must not contain duplicate names.") + if not self.target_joint_names: + return + if len(set(self.target_joint_names)) != len(self.target_joint_names): + raise ValueError("target_joint_names must not contain duplicate names.") + missing = [ + joint_name + for joint_name in self.target_joint_names + if joint_name not in self.dataset_joint_names + ] + extra = [ + joint_name + for joint_name in self.dataset_joint_names + if joint_name not in self.target_joint_names + ] + if missing or extra: + raise ValueError( + "target_joint_names must contain the same joints as " + f"dataset_joint_names; missing={missing}, extra={extra}." + ) + self._joint_reorder_index = np.asarray( + [ + self.dataset_joint_names.index(joint_name) + for joint_name in self.target_joint_names + ], + dtype=np.int64, + ) + + def _normalize_source_override( + self, + source: Iterable[Mapping[str, Any]] + | Mapping[str, Iterable[Mapping[str, Any]]] + | None, + ) -> dict[str, list[Mapping[str, Any]]] | None: + if source is None: + return None + if isinstance(source, Mapping): + return { + _sanitize_name(str(name)): list(rows) + for name, rows in source.items() + } + return {"in_memory": list(source)} + + def _collect_sources(self) -> list[LeRobotSource]: + if self._source_rows is not None: + return [ + LeRobotSource( + repo_id=name, + motion_name=name, + split=self.default_split, + root=None, + revision=None, + episodes=None, + streaming=False, + fps=self.default_fps, + max_episodes=self.max_episodes, + max_rows=self.max_rows, + source_rows_key=name, + ) + for name in self._source_rows + ] + + entries = self._collect_source_entries() + sources = [self._resolve_source_entry(entry) for entry in entries] + if not sources: + raise ValueError("No LeRobot sources resolved from config.") + return sources + + def _collect_source_entries(self) -> list[Any]: + candidates = [ + ("dataset", "trajectories", "lerobot"), + ("dataset", "lerobot"), + ("dataset", "repo_ids"), + ("dataset", "repo_id"), + ("repo_ids",), + ("repo_id",), + ] + for candidate in candidates: + value = _cfg_get_nested(self.cfg, candidate, None) + if value is not None: + return self._normalize_source_entries(value) + return [UNITREE_G1_WBT_DEFAULT_REPO_ID] + + def _normalize_source_entries(self, value: Any) -> list[Any]: + if isinstance(value, Mapping): + if self._looks_like_source_entry(value): + return [value] + return [ + {"name": str(name), "repo_id": repo_id} + for name, repo_id in value.items() + ] + return _as_list(value) + + def _looks_like_source_entry(self, value: Mapping[str, Any]) -> bool: + entry_keys = { + "repo_id", + "id", + "name", + "root", + "split", + "revision", + "episodes", + "streaming", + "fps", + "max_episodes", + "max_rows", + } + return any(key in value for key in entry_keys) + + def _resolve_source_entry(self, entry: Any) -> LeRobotSource: + if isinstance(entry, Mapping): + repo_id_value = entry.get("repo_id", entry.get("id", None)) + if repo_id_value is None: + raise ValueError("LeRobot source entry must include 'repo_id'.") + repo_id = str(repo_id_value) + motion_name = _sanitize_name(str(entry.get("name", repo_id))) + split = str(entry.get("split", self.default_split)) + root = self._optional_path_str(entry.get("root", self.default_root)) + revision_value = entry.get("revision", self.default_revision) + revision = None if revision_value is None else str(revision_value) + episodes = _maybe_list_of_int(entry.get("episodes", self.default_episodes)) + streaming = bool(entry.get("streaming", self.default_streaming)) + fps_value = entry.get("fps", None) + fps = None if fps_value is None else float(fps_value) + max_episodes = self._optional_positive_int( + entry.get("max_episodes", self.max_episodes), + "max_episodes", + ) + max_rows = self._optional_positive_int( + entry.get("max_rows", self.max_rows), + "max_rows", + ) + else: + repo_id = str(entry) + motion_name = _sanitize_name(repo_id) + split = self.default_split + root = self.default_root + revision = ( + None + if self.default_revision is None + else str(self.default_revision) + ) + episodes = self.default_episodes + streaming = self.default_streaming + fps = None + max_episodes = self.max_episodes + max_rows = self.max_rows + + if fps is not None and fps <= 0.0: + raise ValueError("source fps must be positive.") + + return LeRobotSource( + repo_id=repo_id, + motion_name=motion_name, + split=split, + root=root, + revision=revision, + episodes=episodes, + streaming=streaming, + fps=fps, + max_episodes=max_episodes, + max_rows=max_rows, + ) + + @property + def num_traj(self) -> int: + return len(self._trajectory_info_list) + + @property + def control_dt(self) -> float | list[float]: + if not self._trajectory_output_fps: + fps = self.control_freq or self.default_fps + return 1.0 / float(fps) + if len(set(self._trajectory_output_fps)) == 1: + return 1.0 / float(self._trajectory_output_fps[0]) + return [1.0 / float(fps) for fps in self._trajectory_output_fps] + + @property + def metadata(self) -> DatasetMeta: + return self._metadata + + def __len__(self) -> int: + return self.num_traj + + @property + def trajectory_info_list(self) -> list[TrajectoryEntry]: + return list(self._trajectory_info_list) + + @property + def motion_info_dict(self) -> MotionIndex: + return dict(self._motion_info_dict) + + def _get_trajectories( + self, + build_zarr_dataset: bool = False, + path: str | None = None, + **kwargs: Any, + ) -> tuple[list[TrajectoryEntry], MotionIndex]: + if build_zarr_dataset and path is None: + raise ValueError("path must be provided when build_zarr_dataset is True") + + trajectory_info_list: list[TrajectoryEntry] = [] + motion_info_dict: MotionIndex = {} + global_idx = 0 + + dataset_group: zarr.Group | None = None + if build_zarr_dataset: + chunk_size = int(kwargs.get("chunk_size", 64)) + shard_size = int(kwargs.get("shard_size", 512)) + overwrite = bool(kwargs.get("overwrite", False)) + os.makedirs(path, exist_ok=True) + store = LocalStore(path) + root = zarr.group(store=store, overwrite=overwrite) + if self.dataset_name in root: + if not overwrite: + raise ValueError( + f"Group '{self.dataset_name}' already exists in {path}. " + "Use overwrite=True to rebuild." + ) + del root[self.dataset_name] + dataset_group = root.create_group(self.dataset_name) + else: + chunk_size = 64 + shard_size = 512 + + motion_local_cursors: dict[str, int] = {} + motion_local_counts: dict[str, int] = {} + motion_groups: dict[str, zarr.Group] = {} + motion_metadata: dict[str, dict[str, Any]] = {} + + for source in self.sources: + motion_name = source.motion_name + motion_group = motion_groups.get(motion_name) + if ( + motion_group is None + and build_zarr_dataset + and dataset_group is not None + ): + motion_group = dataset_group.create_group(motion_name) + motion_groups[motion_name] = motion_group + + motion_entry = motion_info_dict.setdefault( + self.dataset_name, {} + ).setdefault( + motion_name, + { + "motion_name": motion_name, + "repo_ids": [], + "splits": [], + "trajectory_indices": [], + "trajectory_lengths": [], + "trajectory_local_start_indices": [], + "trajectory_local_end_indices": [], + "source_fps": [], + "output_fps": [], + "episode_indices": [], + "state_keys": [], + "action_keys": [], + }, + ) + if source.repo_id not in motion_entry["repo_ids"]: + motion_entry["repo_ids"].append(source.repo_id) + if source.split not in motion_entry["splits"]: + motion_entry["splits"].append(source.split) + + local_start_cursor = motion_local_cursors.get(motion_name, 0) + local_count = motion_local_counts.get(motion_name, 0) + + source_episode_count = 0 + for episode_id, rows, source_fps in self._iter_source_episodes(source): + if source.max_episodes is not None and source_episode_count >= int( + source.max_episodes + ): + break + traj_data, output_fps, state_key, action_key = self._load_episode_rows( + rows=rows, + source_fps=source_fps, + ) + if traj_data is None: + continue + self._available_keys.update(traj_data.keys()) + self._infer_or_validate_names(traj_data) + + traj_len = int(traj_data["qpos"].shape[0]) + local_start = local_start_cursor + local_end = local_start + traj_len + local_start_cursor = local_end + + traj_info = TrajectoryInfo( + dataset=self.dataset_name, + motion=motion_name, + motion_name=motion_name, + trajectory_index=global_idx, + trajectory_in_motion=local_count, + start=local_start, + end=local_end, + ) + trajectory_info_list.append(traj_info.to_dict()) + + motion_entry["trajectory_indices"].append(global_idx) + motion_entry["trajectory_lengths"].append(traj_len) + motion_entry["trajectory_local_start_indices"].append(local_start) + motion_entry["trajectory_local_end_indices"].append(local_end) + motion_entry["source_fps"].append(float(source_fps)) + motion_entry["output_fps"].append(float(output_fps)) + motion_entry["episode_indices"].append(int(episode_id)) + motion_entry["state_keys"].append(state_key) + motion_entry["action_keys"].append(action_key) + self._trajectory_output_fps.append(float(output_fps)) + + if motion_group is not None: + traj_group = motion_group.create_group(f"trajectory_{local_count}") + self._save_trajectory_data( + traj_group, + traj_data, + chunk_size=chunk_size, + shard_size=shard_size, + ) + + source_episode_count += 1 + local_count += 1 + global_idx += 1 + + motion_local_cursors[motion_name] = local_start_cursor + motion_local_counts[motion_name] = local_count + + if motion_group is not None: + motion_group.attrs["num_trajectories"] = local_count + motion_group.attrs["trajectory_lengths"] = motion_entry[ + "trajectory_lengths" + ] + motion_group.attrs["repo_ids"] = motion_entry["repo_ids"] + motion_group.attrs["splits"] = motion_entry["splits"] + motion_group.attrs["episode_indices"] = motion_entry["episode_indices"] + motion_group.attrs["source_fps"] = motion_entry["source_fps"] + motion_group.attrs["output_fps"] = motion_entry["output_fps"] + + motion_metadata[motion_name] = { + "repo_ids": motion_entry["repo_ids"], + "splits": motion_entry["splits"], + "num_trajectories": len(motion_entry["trajectory_indices"]), + "trajectory_lengths": motion_entry["trajectory_lengths"], + "episode_indices": motion_entry["episode_indices"], + } + + if not trajectory_info_list: + raise ValueError("No LeRobot trajectories were loaded.") + + if build_zarr_dataset and dataset_group is not None: + dataset_group.attrs["num_trajectories"] = len(trajectory_info_list) + dataset_group.attrs["trajectory_lengths"] = [ + e["length"] for e in trajectory_info_list + ] + dataset_group.attrs["keys"] = sorted(self._available_keys) + dataset_group.attrs["joint_names"] = self._joint_names or [] + dataset_group.attrs["body_names"] = self._body_names or [] + dataset_group.attrs["site_names"] = self._site_names or [] + dataset_group.attrs["dt"] = self.control_dt + dataset_group.attrs["control_freq"] = self._metadata_control_freq() + dataset_group.attrs["transition_format"] = "flat_next_keys_v1" + dataset_group.attrs["transition_keys"] = sorted( + key for key in self._available_keys if key.startswith("next_") + ) + dataset_group.attrs["motion_metadata"] = motion_metadata + dataset_group.attrs["source_metadata"] = self._source_metadata + dataset_group.attrs["trajectory_info_list"] = trajectory_info_list + dataset_group.attrs["motion_info_dict"] = motion_info_dict + self.logger.info("Saved trajectories to Zarr store at %s", path) + + self.logger.info( + "Built LeRobot trajectory manifest with %d entries across %d motions", + len(trajectory_info_list), + sum(len(motions) for motions in motion_info_dict.values()), + ) + return trajectory_info_list, motion_info_dict + + def _iter_source_episodes( + self, source: LeRobotSource + ) -> Iterator[tuple[int, list[Mapping[str, Any]], float]]: + iterable, source_fps = self._make_source_iterable(source) + source_fps = float( + source.fps or source_fps or self.control_freq or self.default_fps + ) + if source_fps <= 0.0: + raise ValueError(f"Invalid LeRobot source fps: {source_fps}.") + + current_episode_id: int | None = None + current_rows: list[Mapping[str, Any]] = [] + rows_seen = 0 + yielded = 0 + + for row in iterable: + rows_seen += 1 + if source.max_rows is not None and rows_seen > int(source.max_rows): + break + episode_id = int( + _to_numpy(_row_get(row, self.episode_key), dtype=None).reshape(-1)[0] + ) + if source.episodes is not None and episode_id not in source.episodes: + if current_rows and current_episode_id is not None: + yield ( + current_episode_id, + self._finalize_episode_rows(current_rows), + source_fps, + ) + yielded += 1 + current_rows = [] + current_episode_id = None + continue + if current_episode_id is None: + current_episode_id = episode_id + if episode_id != current_episode_id: + yield ( + current_episode_id, + self._finalize_episode_rows(current_rows), + source_fps, + ) + yielded += 1 + if source.max_episodes is not None and yielded >= int( + source.max_episodes + ): + return + current_rows = [] + current_episode_id = episode_id + current_rows.append(row) + + if current_rows and current_episode_id is not None: + yield ( + current_episode_id, + self._finalize_episode_rows(current_rows), + source_fps, + ) + + def _make_source_iterable( + self, source: LeRobotSource + ) -> tuple[Iterable[Mapping[str, Any]], float | None]: + if self._source_rows is not None: + if source.source_rows_key is None: + raise RuntimeError( + "Internal source_rows_key missing for in-memory source." + ) + return self._source_rows[source.source_rows_key], source.fps + + if source.streaming: + return self._make_streaming_lerobot_iterable(source), source.fps + return self._make_lerobot_dataset_iterable(source) + + def _make_lerobot_dataset_iterable( + self, source: LeRobotSource + ) -> tuple[Iterable[Mapping[str, Any]], float | None]: + try: + from lerobot.datasets import LeRobotDataset + except ImportError as exc: + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + except ImportError: + raise ImportError( + "LeRobotLoader requires the optional 'lerobot' package. " + "Install iltools[lerobot] or install lerobot directly." + ) from exc + + kwargs = _filter_kwargs( + LeRobotDataset, + { + "root": source.root, + "episodes": list(source.episodes) + if source.episodes is not None + else None, + "revision": source.revision, + "force_cache_sync": self.force_cache_sync, + "download_videos": self.download_videos, + "video_backend": self.video_backend, + "split": source.split, + }, + ) + dataset = LeRobotDataset(source.repo_id, **kwargs) + self._source_metadata[source.motion_name] = self._extract_lerobot_metadata( + dataset + ) + fps = self._read_dataset_fps(dataset) + + def _iter_dataset_rows() -> Iterator[Mapping[str, Any]]: + for index in range(len(dataset)): + yield dataset[index] + + return _iter_dataset_rows(), fps + + def _make_streaming_lerobot_iterable( + self, source: LeRobotSource + ) -> Iterable[Mapping[str, Any]]: + try: + from lerobot.datasets import StreamingLeRobotDataset + except ImportError as exc: + try: + from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset + except ImportError: + raise ImportError( + "LeRobotLoader streaming mode requires " + "lerobot.datasets.StreamingLeRobotDataset." + ) from exc + kwargs = _filter_kwargs( + StreamingLeRobotDataset, + { + "split": source.split, + "revision": source.revision, + }, + ) + return StreamingLeRobotDataset(source.repo_id, **kwargs) + + def _read_dataset_fps(self, dataset: Any) -> float | None: + fps = getattr(dataset, "fps", None) + if fps is not None: + return float(fps) + meta = getattr(dataset, "meta", None) + fps = getattr(meta, "fps", None) + if fps is not None: + return float(fps) + if isinstance(meta, Mapping) and meta.get("fps") is not None: + return float(meta["fps"]) + return None + + def _extract_lerobot_metadata(self, dataset: Any) -> dict[str, Any]: + metadata: dict[str, Any] = {} + for attr in ("fps", "num_frames", "num_episodes", "features"): + value = getattr(dataset, attr, None) + if value is not None: + metadata[attr] = self._json_safe(value) + meta = getattr(dataset, "meta", None) + if meta is not None: + for attr in ("repo_id", "total_frames", "total_episodes"): + value = getattr(meta, attr, None) + if value is not None: + metadata[attr] = self._json_safe(value) + return metadata + + def _json_safe(self, value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, Mapping): + return {str(k): self._json_safe(v) for k, v in value.items()} + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return [self._json_safe(v) for v in value] + if hasattr(value, "tolist"): + return self._json_safe(value.tolist()) + return str(value) + + def _finalize_episode_rows( + self, rows: list[Mapping[str, Any]] + ) -> list[Mapping[str, Any]]: + if not rows: + return rows + rows = list(rows) + try: + rows.sort( + key=lambda row: int( + _to_numpy(_row_get(row, self.frame_key), dtype=None).reshape(-1)[0] + ) + ) + except KeyError: + pass + if self.frame_range is None: + return rows + start, end = self.frame_range + if end > len(rows): + raise ValueError( + f"frame_range {self.frame_range} exceeds episode length {len(rows)}." + ) + return rows[start - 1 : end] + + def _load_episode_rows( + self, + *, + rows: Sequence[Mapping[str, Any]], + source_fps: float, + ) -> tuple[dict[str, np.ndarray] | None, float, str, str | None]: + if len(rows) < 2: + if self.drop_short_episodes: + self.logger.debug("Skipping LeRobot episode with fewer than 2 rows.") + return None, source_fps, "", None + raise ValueError("A LeRobot episode must contain at least two rows.") + + state, state_key = _stack_rows( + rows, + self.state_key_candidates, + required=True, + dtype=np.float32, + ) + if state is None or state_key is None: + raise KeyError("Missing required LeRobot state field.") + if state.ndim != 2 or state.shape[-1] < 8: + raise ValueError( + f"{state_key} must have shape [T, >=8], got {tuple(state.shape)}." + ) + + action: np.ndarray | None + action_key: str | None + if self.action_key_candidates: + action, action_key = _stack_rows( + rows, + self.action_key_candidates, + required=False, + dtype=np.float32, + ) + else: + action, action_key = None, None + + episode_index, _ = _stack_rows( + rows, + (self.episode_key,), + required=True, + dtype=np.int64, + ) + frame_index, _ = _stack_rows( + rows, + (self.frame_key,), + required=False, + dtype=np.int64, + ) + timestamp, _ = _stack_rows( + rows, + (self.timestamp_key,), + required=False, + dtype=np.float32, + ) + + root_pos = state[:, :3].astype(np.float32) + root_quat = _normalize_quat_wxyz( + state[:, 3:7].astype(np.float32), self.quat_order + ) + if self.align_root_z_to_default: + root_pos = root_pos.copy() + root_pos[:, 2] += float(self.default_root_height) - float(root_pos[0, 2]) + joint_pos = self._reorder_joints(state[:, 7:].astype(np.float32)) + + target_root_pos = None + target_root_quat = None + target_joint_pos = None + if action is not None: + if action.ndim == 1: + action = action[:, None] + if action.ndim != 2: + raise ValueError( + f"{action_key} must have shape [T, A], got {tuple(action.shape)}." + ) + if action.shape[0] != state.shape[0]: + raise ValueError( + f"{action_key} length {action.shape[0]} does not match " + f"{state_key} length {state.shape[0]}." + ) + if action.shape[-1] == state.shape[-1]: + target_root_pos = action[:, :3].astype(np.float32) + target_root_quat = _normalize_quat_wxyz( + action[:, 3:7].astype(np.float32), self.quat_order + ) + target_joint_pos = self._reorder_joints( + action[:, 7:].astype(np.float32) + ) + elif action.shape[-1] == joint_pos.shape[-1]: + target_joint_pos = action.astype(np.float32) + + output_fps = float(self.control_freq or source_fps) + if output_fps <= 0.0: + raise ValueError("output_fps must be positive.") + ( + root_pos, + root_quat, + joint_pos, + action, + target_root_pos, + target_root_quat, + target_joint_pos, + episode_index, + frame_index, + timestamp, + ) = self._maybe_resample_episode( + source_fps=source_fps, + output_fps=output_fps, + root_pos=root_pos, + root_quat=root_quat, + joint_pos=joint_pos, + action=action, + target_root_pos=target_root_pos, + target_root_quat=target_root_quat, + target_joint_pos=target_joint_pos, + episode_index=episode_index, + frame_index=frame_index, + timestamp=timestamp, + ) + + root_lin_vel, root_ang_vel, joint_vel = self._compute_velocities( + root_pos=root_pos, + root_quat=root_quat, + joint_pos=joint_pos, + dt=1.0 / output_fps, + ) + traj_data = self._build_trajectory_dict( + root_pos=root_pos, + root_quat=root_quat, + joint_pos=joint_pos, + root_lin_vel=root_lin_vel, + root_ang_vel=root_ang_vel, + joint_vel=joint_vel, + episode_index=episode_index, + frame_index=frame_index, + timestamp=timestamp, + action=action, + target_root_pos=target_root_pos, + target_root_quat=target_root_quat, + target_joint_pos=target_joint_pos, + ) + return traj_data, output_fps, state_key, action_key + + def _reorder_joints(self, joint_data: np.ndarray) -> np.ndarray: + if self._joint_reorder_index is None: + return joint_data.astype(np.float32) + if joint_data.shape[-1] != len(self.dataset_joint_names): + raise ValueError( + "Cannot apply target_joint_names because joint width " + f"{joint_data.shape[-1]} does not match dataset_joint_names length " + f"{len(self.dataset_joint_names)}." + ) + return joint_data[:, self._joint_reorder_index].astype(np.float32) + + def _maybe_resample_episode( + self, + *, + source_fps: float, + output_fps: float, + root_pos: np.ndarray, + root_quat: np.ndarray, + joint_pos: np.ndarray, + action: np.ndarray | None, + target_root_pos: np.ndarray | None, + target_root_quat: np.ndarray | None, + target_joint_pos: np.ndarray | None, + episode_index: np.ndarray, + frame_index: np.ndarray | None, + timestamp: np.ndarray | None, + ) -> tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + np.ndarray | None, + np.ndarray, + np.ndarray | None, + np.ndarray | None, + ]: + if root_pos.shape[0] == 0: + raise ValueError("Cannot load empty LeRobot episode.") + if root_pos.shape[0] == 1 or np.isclose(source_fps, output_fps): + return ( + root_pos.astype(np.float32), + root_quat.astype(np.float32), + joint_pos.astype(np.float32), + None if action is None else action.astype(np.float32), + None + if target_root_pos is None + else target_root_pos.astype(np.float32), + None + if target_root_quat is None + else target_root_quat.astype(np.float32), + None + if target_joint_pos is None + else target_joint_pos.astype(np.float32), + episode_index.astype(np.int64), + None if frame_index is None else frame_index.astype(np.int64), + None if timestamp is None else timestamp.astype(np.float32), + ) + + input_dt = 1.0 / float(source_fps) + output_dt = 1.0 / float(output_fps) + duration = (root_pos.shape[0] - 1) * input_dt + if duration <= 0.0: + raise ValueError("Cannot resample a zero-duration LeRobot episode.") + + times = np.arange(0.0, duration, output_dt, dtype=np.float64) + if times.size < 2: + times = np.array([0.0, duration], dtype=np.float64) + phase = times / duration + index_0 = np.floor(phase * (root_pos.shape[0] - 1)).astype(np.int64) + index_1 = np.minimum(index_0 + 1, root_pos.shape[0] - 1) + blend = (phase * (root_pos.shape[0] - 1) - index_0).astype(np.float32) + blend_col = blend[:, None] + + def maybe_lerp(value: np.ndarray | None) -> np.ndarray | None: + if value is None: + return None + return _lerp(value[index_0], value[index_1], blend_col).astype( + np.float32 + ) + + def maybe_slerp(value: np.ndarray | None) -> np.ndarray | None: + if value is None: + return None + return _quat_slerp_wxyz(value[index_0], value[index_1], blend) + + out_episode = np.full( + (times.shape[0], 1), int(episode_index[0, 0]), dtype=np.int64 + ) + out_frame = None + if frame_index is not None: + out_frame = np.rint( + _lerp(frame_index[index_0], frame_index[index_1], blend_col) + ).astype(np.int64) + out_timestamp = None + if timestamp is not None: + out_timestamp = _lerp( + timestamp[index_0], timestamp[index_1], blend_col + ).astype(np.float32) + + return ( + _lerp(root_pos[index_0], root_pos[index_1], blend_col).astype( + np.float32 + ), + _quat_slerp_wxyz(root_quat[index_0], root_quat[index_1], blend), + _lerp(joint_pos[index_0], joint_pos[index_1], blend_col).astype( + np.float32 + ), + maybe_lerp(action), + maybe_lerp(target_root_pos), + maybe_slerp(target_root_quat), + maybe_lerp(target_joint_pos), + out_episode, + out_frame, + out_timestamp, + ) + + def _compute_velocities( + self, + *, + root_pos: np.ndarray, + root_quat: np.ndarray, + joint_pos: np.ndarray, + dt: float, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + if dt <= 0.0: + raise ValueError("dt must be positive when computing velocities.") + if root_pos.shape[0] <= 1: + zeros_root = np.zeros_like(root_pos, dtype=np.float32) + zeros_joint = np.zeros_like(joint_pos, dtype=np.float32) + zeros_ang = np.zeros((root_pos.shape[0], 3), dtype=np.float32) + return zeros_root, zeros_ang, zeros_joint + root_lin_vel = np.gradient(root_pos, dt, axis=0).astype(np.float32) + joint_vel = np.gradient(joint_pos, dt, axis=0).astype(np.float32) + root_ang_vel = _so3_derivative_wxyz(root_quat, dt).astype(np.float32) + return root_lin_vel, root_ang_vel, joint_vel + + def _build_trajectory_dict( + self, + *, + root_pos: np.ndarray, + root_quat: np.ndarray, + joint_pos: np.ndarray, + root_lin_vel: np.ndarray, + root_ang_vel: np.ndarray, + joint_vel: np.ndarray, + episode_index: np.ndarray, + frame_index: np.ndarray | None, + timestamp: np.ndarray | None, + action: np.ndarray | None, + target_root_pos: np.ndarray | None, + target_root_quat: np.ndarray | None, + target_joint_pos: np.ndarray | None, + ) -> dict[str, np.ndarray]: + qpos = np.concatenate([root_pos, root_quat, joint_pos], axis=-1).astype( + np.float32 + ) + qvel = np.concatenate([root_lin_vel, root_ang_vel, joint_vel], axis=-1).astype( + np.float32 + ) + if qpos.shape[0] < 2: + raise ValueError("A LeRobot trajectory must contain at least two frames.") + + traj_data: dict[str, np.ndarray] = { + "qpos": qpos, + "qvel": qvel, + "root_pos": root_pos.astype(np.float32), + "root_quat": root_quat.astype(np.float32), + "root_lin_vel": root_lin_vel.astype(np.float32), + "root_ang_vel": root_ang_vel.astype(np.float32), + "joint_pos": joint_pos.astype(np.float32), + "joint_vel": joint_vel.astype(np.float32), + "episode_index": episode_index.astype(np.int64).reshape( + qpos.shape[0], -1 + ), + "next_qpos": qpos[1:].astype(np.float32), + "next_qvel": qvel[1:].astype(np.float32), + "next_root_pos": root_pos[1:].astype(np.float32), + "next_root_quat": root_quat[1:].astype(np.float32), + "next_root_lin_vel": root_lin_vel[1:].astype(np.float32), + "next_root_ang_vel": root_ang_vel[1:].astype(np.float32), + "next_joint_pos": joint_pos[1:].astype(np.float32), + "next_joint_vel": joint_vel[1:].astype(np.float32), + "next_episode_index": episode_index[1:] + .astype(np.int64) + .reshape(qpos.shape[0] - 1, -1), + } + if frame_index is not None: + frame_index = frame_index.astype(np.int64).reshape(qpos.shape[0], -1) + traj_data["frame_index"] = frame_index + traj_data["next_frame_index"] = frame_index[1:] + if timestamp is not None: + timestamp = timestamp.astype(np.float32).reshape(qpos.shape[0], -1) + traj_data["timestamp"] = timestamp + traj_data["next_timestamp"] = timestamp[1:] + if action is not None: + action = action.astype(np.float32) + traj_data["action"] = action + traj_data["next_action"] = action[1:] + if target_root_pos is not None: + traj_data["target_root_pos"] = target_root_pos.astype(np.float32) + traj_data["next_target_root_pos"] = target_root_pos[1:].astype(np.float32) + if target_root_quat is not None: + traj_data["target_root_quat"] = target_root_quat.astype(np.float32) + traj_data["next_target_root_quat"] = target_root_quat[1:].astype(np.float32) + if target_joint_pos is not None: + traj_data["target_joint_pos"] = target_joint_pos.astype(np.float32) + traj_data["next_target_joint_pos"] = target_joint_pos[1:].astype(np.float32) + if ( + target_root_pos is not None + and target_root_quat is not None + and target_joint_pos is not None + ): + target_qpos = np.concatenate( + [target_root_pos, target_root_quat, target_joint_pos], axis=-1 + ).astype(np.float32) + traj_data["target_qpos"] = target_qpos + traj_data["next_target_qpos"] = target_qpos[1:] + return traj_data + + def _infer_or_validate_names(self, traj_data: Mapping[str, np.ndarray]) -> None: + joint_count = int(traj_data["joint_pos"].shape[-1]) + if self._joint_names is None: + if joint_count == len(UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES): + if self.target_joint_names: + self._joint_names = list(self.target_joint_names) + else: + self._joint_names = list(UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES) + else: + self._joint_names = [f"joint_{index}" for index in range(joint_count)] + elif len(self._joint_names) != joint_count: + raise ValueError( + f"joint_names length mismatch: expected {joint_count}, " + f"got {len(self._joint_names)}." + ) + if self._body_names is None: + self._body_names = [] + if self._site_names is None: + self._site_names = [] + + def _save_trajectory_data( + self, + traj_group: zarr.Group, + traj_data: Mapping[str, np.ndarray], + *, + chunk_size: int, + shard_size: int, + ) -> None: + for key, value in traj_data.items(): + array = np.asarray(value) + if array.ndim == 0 or array.shape[0] == 0: + continue + chunks = [min(chunk_size, array.shape[0])] + list(array.shape[1:]) + shards = [min(shard_size, array.shape[0])] + list(array.shape[1:]) + ds = traj_group.create_array( + key, + shape=array.shape, + dtype=array.dtype, + chunks=chunks, + shards=shards, + ) + ds[:] = array + + def _discover_metadata(self) -> DatasetMeta: + trajectory_lengths = [int(e["length"]) for e in self._trajectory_info_list] + return DatasetMeta( + name=self.dataset_name, + source=self.dataset_source, + version="1.0.0", + citation=( + "LeRobot datasets loaded through the optional lerobot package and " + "converted to ILTools qpos/qvel trajectories." + ), + num_trajectories=len(self._trajectory_info_list), + keys=sorted(self._available_keys), + trajectory_lengths=trajectory_lengths, + dt=self.control_dt, + joint_names=self._joint_names or [], + body_names=self._body_names or [], + site_names=self._site_names or [], + metadata={ + "trajectory_info_list": self._trajectory_info_list, + "motion_info_dict": self._motion_info_dict, + "source_metadata": self._source_metadata, + "state_key_candidates": list(self.state_key_candidates), + "action_key_candidates": list(self.action_key_candidates), + "control_freq": self._metadata_control_freq(), + "sources": [ + { + "repo_id": source.repo_id, + "motion_name": source.motion_name, + "split": source.split, + "root": source.root, + "revision": source.revision, + "episodes": source.episodes, + "streaming": source.streaming, + "fps": source.fps, + } + for source in self.sources + ], + }, + ) + + def _metadata_control_freq(self) -> float | list[float]: + if not self._trajectory_output_fps: + return float(self.control_freq or self.default_fps) + if len(set(self._trajectory_output_fps)) == 1: + return float(self._trajectory_output_fps[0]) + return [float(fps) for fps in self._trajectory_output_fps] diff --git a/iltools/datasets/lerobot_stream.py b/iltools/datasets/lerobot_stream.py index c14e46b..6cadb26 100644 --- a/iltools/datasets/lerobot_stream.py +++ b/iltools/datasets/lerobot_stream.py @@ -24,7 +24,41 @@ logger = logging.getLogger(__name__) +# The public Unitree WBT datasets store robot_q_current[7:] and robot_q_desired[7:] +# in Unitree G1_29_JointIndex order. This matches Unitree's unitree_lerobot enum +# and the G1 29-DoF MuJoCo actuator order. UNITREE_G1_WBT_DEFAULT_REPO_ID = "unitreerobotics/G1_WBT_Brainco_Pickup_Pillow" +UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES: tuple[str, ...] = ( + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_shoulder_pitch_joint", + "left_shoulder_roll_joint", + "left_shoulder_yaw_joint", + "left_elbow_joint", + "left_wrist_roll_joint", + "left_wrist_pitch_joint", + "left_wrist_yaw_joint", + "right_shoulder_pitch_joint", + "right_shoulder_roll_joint", + "right_shoulder_yaw_joint", + "right_elbow_joint", + "right_wrist_roll_joint", + "right_wrist_pitch_joint", + "right_wrist_yaw_joint", +) @dataclass(frozen=True) @@ -37,6 +71,10 @@ class UnitreeG1WBT29DofMapperConfig: dt: float = 1.0 / 30.0 default_joint_pos: Sequence[Any] = () action_scale: Sequence[float] = () + dataset_joint_names: Sequence[str] = UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES + target_joint_names: Sequence[str] = () + align_root_z_to_default: bool = False + default_root_height: float = 0.0 quat_order: str = "wxyz" @@ -45,7 +83,14 @@ class LeRobotStreamingCacheConfig: """Runtime options for streaming LeRobot data into a TorchRL cache.""" repo_id: str = UNITREE_G1_WBT_DEFAULT_REPO_ID + """Primary LeRobot repo id. Used when ``repo_ids`` is empty.""" + + repo_ids: Sequence[str] = () + """Optional ordered list of LeRobot repo ids to stream sequentially.""" + split: str = "train" + """Dataset split shared by all configured repos.""" + cache_dir: str | Path = "/tmp/iltools_lerobot_torchrl_cache" max_cache_transitions: int = 5_000_000 min_ready_transitions: int = 100_000 @@ -54,6 +99,7 @@ class LeRobotStreamingCacheConfig: local_sample_prefetch: int = 0 batch_size: int | None = None max_episodes: int | None = None + max_episodes_per_repo: int | None = None mapper: UnitreeG1WBT29DofMapperConfig = UnitreeG1WBT29DofMapperConfig() @@ -132,6 +178,12 @@ def _get_required(mapping: Mapping[Any, Any] | TensorDictBase, key: str) -> Any: return value +def _get_optional(mapping: Mapping[Any, Any] | TensorDictBase, key: str) -> Any: + if isinstance(mapping, TensorDictBase): + return mapping.get(key) + return mapping.get(key) + + def _stack_rows(rows: Sequence[Mapping[str, Any]], keys: Sequence[str]) -> TensorDict: if len(rows) == 0: raise ValueError("Cannot stack an empty episode.") @@ -158,6 +210,23 @@ def __init__(self, config: UnitreeG1WBT29DofMapperConfig) -> None: self.config = config if float(config.dt) <= 0.0: raise ValueError("mapper.dt must be positive.") + self.dataset_joint_names = tuple(config.dataset_joint_names) + if len(self.dataset_joint_names) == 0: + self.dataset_joint_names = UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES + self.target_joint_names = tuple(config.target_joint_names) + if len(self.target_joint_names) == 0: + self.target_joint_names = self.dataset_joint_names + self._validate_joint_names( + self.dataset_joint_names, label="dataset_joint_names" + ) + self._validate_joint_names(self.target_joint_names, label="target_joint_names") + self._dataset_to_target_index = torch.tensor( + [ + self.dataset_joint_names.index(joint_name) + for joint_name in self.target_joint_names + ], + dtype=torch.int64, + ) default_joint_pos = _to_tensor(config.default_joint_pos) if default_joint_pos.ndim == 1: self.default_joint_pos_pool = default_joint_pos.unsqueeze(0) @@ -185,6 +254,33 @@ def __init__(self, config: UnitreeG1WBT29DofMapperConfig) -> None: if torch.any(self.action_scale.abs() <= 1.0e-8): raise ValueError("action_scale must not contain zeros.") + def _validate_joint_names(self, joint_names: Sequence[str], *, label: str) -> None: + if len(joint_names) != self.joint_width: + raise ValueError(f"{label} must contain 29 joint names.") + if len(set(joint_names)) != len(joint_names): + raise ValueError(f"{label} must not contain duplicate joint names.") + if label == "target_joint_names": + missing = [ + joint_name + for joint_name in joint_names + if joint_name not in self.dataset_joint_names + ] + extra = [ + joint_name + for joint_name in self.dataset_joint_names + if joint_name not in joint_names + ] + if missing or extra: + raise ValueError( + "target_joint_names must contain the same joints as " + "dataset_joint_names; " + f"missing={missing}, extra={extra}." + ) + + def _dataset_joints_to_target_order(self, joints: Tensor) -> Tensor: + index = self._dataset_to_target_index.to(device=joints.device) + return joints.index_select(-1, index) + def map_episode( self, episode: TensorDictBase | Mapping[str, Any] | Sequence[Mapping[str, Any]] ) -> TensorDict: @@ -222,7 +318,9 @@ def map_episode( ) return self._map_batched_episode(episode_td) - def _episode_default_joint_pos(self, episode: TensorDictBase, like: Tensor) -> Tensor: + def _episode_default_joint_pos( + self, episode: TensorDictBase, like: Tensor + ) -> Tensor: pool = self.default_joint_pos_pool.to(device=like.device, dtype=like.dtype) if pool.shape[0] == 1: return pool[0] @@ -263,7 +361,14 @@ def _map_batched_episode(self, episode: TensorDictBase) -> TensorDict: robot_q_current[:, 3:7], self.config.quat_order ) root_pos = robot_q_current[:, :3] - joint_pos = robot_q_current[:, 7:] + if self.config.align_root_z_to_default: + root_pos = root_pos.clone() + default_root_height = root_pos.new_tensor( + float(self.config.default_root_height) + ) + root_pos[:, 2] += default_root_height - root_pos[0, 2] + joint_pos = self._dataset_joints_to_target_order(robot_q_current[:, 7:]) + joint_pos_desired = self._dataset_joints_to_target_order(robot_q_desired[:, 7:]) joint_vel = _finite_difference(joint_pos, self.config.dt) base_ang_vel = _so3_derivative_wxyz(root_quat, self.config.dt) expert_motion = torch.cat([joint_pos, joint_vel], dim=-1) @@ -277,7 +382,7 @@ def _map_batched_episode(self, episode: TensorDictBase) -> TensorDict: device=robot_q_current.device, dtype=robot_q_current.dtype, ) - expert_action = (robot_q_desired[:, 7:] - default_joint_pos) / action_scale + expert_action = (joint_pos_desired - default_joint_pos) / action_scale last_action = torch.cat( [torch.zeros_like(expert_action[:1]), expert_action[:-1]], dim=0 ) @@ -370,6 +475,15 @@ def __init__( self._stop_event = threading.Event() self._error: BaseException | None = None self._episodes_written = 0 + self._repos_completed = 0 + self.repo_ids = self._normalize_repo_ids(config) + + @staticmethod + def _normalize_repo_ids(config: LeRobotStreamingCacheConfig) -> tuple[str, ...]: + repo_ids = tuple(str(repo_id) for repo_id in config.repo_ids if str(repo_id)) + if repo_ids: + return repo_ids + return (str(config.repo_id),) @property def ready_transitions(self) -> int: @@ -398,9 +512,11 @@ def wait_until_ready(self, timeout_s: float | None = None) -> None: ) with self._condition: ready = self._condition.wait_for( - lambda: self.ready_transitions >= min_ready - or self._error is not None - or (self._thread is not None and not self._thread.is_alive()), + lambda: ( + self.ready_transitions >= min_ready + or self._error is not None + or (self._thread is not None and not self._thread.is_alive()) + ), timeout=float(timeout_s), ) if self._error is not None: @@ -433,9 +549,12 @@ def _source_iter(self) -> Iterator[Mapping[str, Any]]: if self.source is not None: yield from self.source return + + use_lerobot = True try: from lerobot.datasets import StreamingLeRobotDataset except ImportError: + use_lerobot = False try: from datasets import load_dataset except ImportError as exc: @@ -444,28 +563,56 @@ def _source_iter(self) -> Iterator[Mapping[str, Any]]: "Install iltools[lerobot], install lerobot directly, or install " "huggingface datasets." ) from exc - yield from load_dataset( - self.config.repo_id, - split=self.config.split, - streaming=True, - ) - return - yield from StreamingLeRobotDataset(self.config.repo_id) + + max_episodes_per_repo = self.config.max_episodes_per_repo + if max_episodes_per_repo is not None and int(max_episodes_per_repo) <= 0: + max_episodes_per_repo = None + + for repo_index, repo_id in enumerate(self.repo_ids): + if use_lerobot: + iterator = StreamingLeRobotDataset(repo_id) + else: + iterator = load_dataset( # type: ignore[name-defined] + repo_id, + split=self.config.split, + streaming=True, + ) + current_episode_id: int | None = None + episodes_seen = 0 + for row in iterator: + episode_id = int(_get_required(row, self.config.mapper.episode_key)) + if current_episode_id is None: + current_episode_id = episode_id + elif episode_id != current_episode_id: + episodes_seen += 1 + if max_episodes_per_repo is not None and episodes_seen >= int( + max_episodes_per_repo + ): + break + current_episode_id = episode_id + enriched_row = dict(row) + enriched_row["__lerobot_repo_id"] = repo_id + enriched_row["__lerobot_repo_index"] = repo_index + yield enriched_row + self._repos_completed += 1 def _producer_loop(self) -> None: try: - current_episode_id: int | None = None + current_episode_key: tuple[int, int] | None = None current_rows: list[Mapping[str, Any]] = [] for row in self._source_iter(): if self._stop_event.is_set(): break episode_id = int(_get_required(row, self.config.mapper.episode_key)) - if current_episode_id is None: - current_episode_id = episode_id - if episode_id != current_episode_id: + repo_index_like = _get_optional(row, "__lerobot_repo_index") + repo_index = int(repo_index_like) if repo_index_like is not None else 0 + episode_key = (repo_index, episode_id) + if current_episode_key is None: + current_episode_key = episode_key + if episode_key != current_episode_key: self._write_episode(current_rows) current_rows = [] - current_episode_id = episode_id + current_episode_key = episode_key if ( self.config.max_episodes is not None and self._episodes_written >= int(self.config.max_episodes) diff --git a/iltools/datasets/loaders.py b/iltools/datasets/loaders.py index 30df145..b8cb14e 100644 --- a/iltools/datasets/loaders.py +++ b/iltools/datasets/loaders.py @@ -29,6 +29,11 @@ class DatasetLoaderSpec: module="iltools.datasets.lafan1.loader", class_name="Lafan1CsvLoader", ), + "lerobot": DatasetLoaderSpec( + module="iltools.datasets.lerobot.loader", + class_name="LeRobotLoader", + optional_dependency="lerobot", + ), "loco_mujoco": DatasetLoaderSpec( module="iltools.datasets.loco_mujoco.loader", class_name="LocoMuJoCoLoader", diff --git a/tests/datasets/test_lafan1_csv_loader.py b/tests/datasets/test_lafan1_csv_loader.py index 4d5a2af..1e4e52b 100644 --- a/tests/datasets/test_lafan1_csv_loader.py +++ b/tests/datasets/test_lafan1_csv_loader.py @@ -13,7 +13,9 @@ def _write_motion_csv(path: Path, *, frames: int = 12, joints: int = 4) -> None: t = np.arange(frames, dtype=np.float32) root_pos = np.stack([0.1 * t, 0.05 * t, np.ones_like(t)], axis=1) - root_quat_xyzw = np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (frames, 1)) + root_quat_xyzw = np.tile( + np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (frames, 1) + ) joint_pos = np.stack( [np.sin(0.15 * t + float(j)) for j in range(joints)], axis=1, @@ -22,7 +24,14 @@ def _write_motion_csv(path: Path, *, frames: int = 12, joints: int = 4) -> None: np.savetxt(path, motion, delimiter=",") -def _write_commands_style_npz(path: Path, *, frames: int = 10, joints: int = 6, fps: float = 50.0) -> None: +def _write_commands_style_npz( + path: Path, + *, + frames: int = 10, + joints: int = 6, + fps: float = 50.0, + joint_names: list[str] | None = None, +) -> None: t = np.arange(frames, dtype=np.float32) joint_pos = np.stack( [0.2 * np.sin(0.1 * t + i) for i in range(joints)], @@ -38,16 +47,18 @@ def _write_commands_style_npz(path: Path, *, frames: int = 10, joints: int = 6, body_lin_vel_w = np.gradient(body_pos_w, 1.0 / fps, axis=0).astype(np.float32) body_ang_vel_w = np.zeros((frames, 1, 3), dtype=np.float32) - np.savez( - path, - fps=np.array([fps], dtype=np.float32), - joint_pos=joint_pos, - joint_vel=joint_vel, - body_pos_w=body_pos_w, - body_quat_w=body_quat_w, - body_lin_vel_w=body_lin_vel_w, - body_ang_vel_w=body_ang_vel_w, - ) + payload = { + "fps": np.array([fps], dtype=np.float32), + "joint_pos": joint_pos, + "joint_vel": joint_vel, + "body_pos_w": body_pos_w, + "body_quat_w": body_quat_w, + "body_lin_vel_w": body_lin_vel_w, + "body_ang_vel_w": body_ang_vel_w, + } + if joint_names is not None: + payload["joint_names"] = np.asarray(joint_names, dtype=np.str_) + np.savez(path, **payload) def test_lafan1_csv_loader_builds_manifest_and_metadata(tmp_path: Path) -> None: @@ -120,7 +131,9 @@ def test_lafan1_csv_loader_writes_zarr(tmp_path: Path) -> None: assert len(loader.trajectory_info_list) == 2 -def test_lafan1_csv_loader_make_rb_preserves_transition_alignment(tmp_path: Path) -> None: +def test_lafan1_csv_loader_make_rb_preserves_transition_alignment( + tmp_path: Path, +) -> None: csv_a = tmp_path / "sequence_a.csv" csv_b = tmp_path / "sequence_b.csv" _write_motion_csv(csv_a, frames=12, joints=3) @@ -145,7 +158,9 @@ def test_lafan1_csv_loader_make_rb_preserves_transition_alignment(tmp_path: Path next_qpos = np.asarray(traj_group["next_qpos"][:]) np.testing.assert_allclose(first["qpos"].cpu().numpy(), qpos[0], atol=1e-6) - np.testing.assert_allclose(first["next_qpos"].cpu().numpy(), next_qpos[0], atol=1e-6) + np.testing.assert_allclose( + first["next_qpos"].cpu().numpy(), next_qpos[0], atol=1e-6 + ) np.testing.assert_allclose(first["next_qpos"].cpu().numpy(), qpos[1], atol=1e-6) @@ -178,6 +193,26 @@ def test_lafan1_commands_npz_input(tmp_path: Path) -> None: assert "body_pos_w" in loader.metadata.keys +def test_lafan1_npz_joint_names_metadata(tmp_path: Path) -> None: + npz_file = tmp_path / "named_commands_style.npz" + joint_names = [f"named_joint_{index}" for index in range(4)] + _write_commands_style_npz( + npz_file, + frames=7, + joints=4, + fps=30.0, + joint_names=joint_names, + ) + + cfg = { + "dataset": {"trajectories": {"lafan1_csv": [str(npz_file)]}}, + "control_freq": 30, + } + loader = Lafan1CsvLoader(cfg=cfg, build_zarr_dataset=False) + + assert loader.metadata.joint_names == joint_names + + def test_lafan1_csv_loader_honors_frame_range(tmp_path: Path) -> None: csv_file = tmp_path / "slice_test.csv" _write_motion_csv(csv_file, frames=14, joints=2) @@ -193,7 +228,9 @@ def test_lafan1_csv_loader_honors_frame_range(tmp_path: Path) -> None: assert loader.metadata.trajectory_lengths == [8] -def test_lafan1_csv_loader_groups_multiple_files_under_one_motion(tmp_path: Path) -> None: +def test_lafan1_csv_loader_groups_multiple_files_under_one_motion( + tmp_path: Path, +) -> None: dance_a = tmp_path / "dance_a.csv" dance_b = tmp_path / "dance_b.csv" walk_a = tmp_path / "walk_a.csv" @@ -206,7 +243,11 @@ def test_lafan1_csv_loader_groups_multiple_files_under_one_motion(tmp_path: Path "dataset": { "trajectories": { "lafan1_csv": [ - {"name": "dance_combo", "paths": [str(dance_a), str(dance_b)], "input_fps": 60}, + { + "name": "dance_combo", + "paths": [str(dance_a), str(dance_b)], + "input_fps": 60, + }, {"name": "walk_combo", "path": str(walk_a), "input_fps": 60}, ] } diff --git a/tests/datasets/test_lerobot_loader.py b/tests/datasets/test_lerobot_loader.py new file mode 100644 index 0000000..8a389c4 --- /dev/null +++ b/tests/datasets/test_lerobot_loader.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import sys +import types +from pathlib import Path + +import numpy as np +import zarr +from zarr.storage import LocalStore + +from iltools.datasets.lerobot.loader import LeRobotLoader +from iltools.datasets.utils import make_rb_from + + +def _make_lerobot_rows( + *, + episode_index: int, + frames: int, + joints: int = 29, +) -> list[dict[str, object]]: + t = np.arange(frames, dtype=np.float32) + root_pos = np.stack([0.05 * t, -0.02 * t, np.ones_like(t)], axis=1) + root_quat = np.tile( + np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), (frames, 1) + ) + joint_pos = np.stack( + [0.1 * np.sin(0.2 * t + float(index)) for index in range(joints)], + axis=1, + ).astype(np.float32) + q_current = np.concatenate([root_pos, root_quat, joint_pos], axis=1) + q_desired = q_current.copy() + q_desired[:, 7:] += 0.05 + + return [ + { + "episode_index": episode_index, + "frame_index": frame, + "timestamp": float(frame) / 30.0, + "observation.state.robot_q_current": q_current[frame], + "action.robot_q_desired": q_desired[frame], + } + for frame in range(frames) + ] + + +def test_lerobot_loader_builds_manifest_and_zarr_from_rows(tmp_path: Path) -> None: + rows = _make_lerobot_rows(episode_index=0, frames=6) + _make_lerobot_rows( + episode_index=1, frames=5 + ) + zarr_path = tmp_path / "lerobot.zarr" + cfg = { + "dataset_name": "unitree_lerobot", + "control_freq": 30, + } + + loader = LeRobotLoader( + cfg=cfg, + source=rows, + build_zarr_dataset=True, + zarr_path=str(zarr_path), + ) + + assert len(loader) == 2 + assert loader.metadata.name == "unitree_lerobot" + assert loader.metadata.dt == 1.0 / 30.0 + assert loader.motion_info_dict["unitree_lerobot"]["in_memory"][ + "episode_indices" + ] == [0, 1] + + store = LocalStore(str(zarr_path)) + root = zarr.group(store=store, overwrite=False) + dataset_group = root["unitree_lerobot"] + assert dataset_group.attrs["transition_format"] == "flat_next_keys_v1" + assert dataset_group.attrs["num_trajectories"] == 2 + + traj_group = dataset_group["in_memory"]["trajectory_0"] + qpos = np.asarray(traj_group["qpos"][:]) + qvel = np.asarray(traj_group["qvel"][:]) + next_qpos = np.asarray(traj_group["next_qpos"][:]) + action = np.asarray(traj_group["action"][:]) + target_joint_pos = np.asarray(traj_group["target_joint_pos"][:]) + + assert qpos.shape == (6, 36) + assert qvel.shape == (6, 35) + assert next_qpos.shape == (5, 36) + assert action.shape == (6, 36) + assert target_joint_pos.shape == (6, 29) + np.testing.assert_allclose(next_qpos, qpos[1:], atol=1e-6) + np.testing.assert_allclose(target_joint_pos, action[:, 7:], atol=1e-6) + + +def test_lerobot_loader_make_rb_alignment(tmp_path: Path) -> None: + rows = _make_lerobot_rows(episode_index=4, frames=7) + zarr_path = tmp_path / "lerobot_rb.zarr" + _ = LeRobotLoader( + cfg={"control_freq": 30}, + source=rows, + build_zarr_dataset=True, + zarr_path=str(zarr_path), + ) + + rb, info = make_rb_from(zarr_path=zarr_path, device="cpu", verbose_tree=False) + assert info["written"] == 6 + + store = LocalStore(str(zarr_path)) + root = zarr.group(store=store, overwrite=False) + traj_group = root["lerobot"]["in_memory"]["trajectory_0"] + qpos = np.asarray(traj_group["qpos"][:]) + next_qpos = np.asarray(traj_group["next_qpos"][:]) + + first = rb[0] + np.testing.assert_allclose(first["qpos"].cpu().numpy(), qpos[0], atol=1e-6) + np.testing.assert_allclose( + first["next_qpos"].cpu().numpy(), next_qpos[0], atol=1e-6 + ) + np.testing.assert_allclose(first["next_qpos"].cpu().numpy(), qpos[1], atol=1e-6) + + +def test_lerobot_loader_uses_lerobot_dataset_package(monkeypatch) -> None: + rows = _make_lerobot_rows(episode_index=2, frames=4) + calls: list[tuple[str, dict[str, object]]] = [] + + class FakeLeRobotDataset: + fps = 30 + features = {"observation.state.robot_q_current": {"shape": (36,)}} + + def __init__(self, repo_id: str, **kwargs: object) -> None: + calls.append((repo_id, kwargs)) + + def __len__(self) -> int: + return len(rows) + + def __getitem__(self, index: int) -> dict[str, object]: + return rows[index] + + lerobot_module = types.ModuleType("lerobot") + datasets_module = types.ModuleType("lerobot.datasets") + datasets_module.LeRobotDataset = FakeLeRobotDataset + monkeypatch.setitem(sys.modules, "lerobot", lerobot_module) + monkeypatch.setitem(sys.modules, "lerobot.datasets", datasets_module) + + loader = LeRobotLoader( + cfg={ + "repo_id": "fake/repo", + "root": "/tmp/fake_lerobot_root", + "episodes": [2], + }, + build_zarr_dataset=False, + ) + + assert len(loader) == 1 + assert calls[0][0] == "fake/repo" + assert calls[0][1]["root"] == "/tmp/fake_lerobot_root" + assert loader.metadata.metadata["source_metadata"]["fake_repo"]["fps"] == 30 diff --git a/tests/datasets/test_lerobot_stream.py b/tests/datasets/test_lerobot_stream.py index 58a1a59..f8a01c4 100644 --- a/tests/datasets/test_lerobot_stream.py +++ b/tests/datasets/test_lerobot_stream.py @@ -7,6 +7,7 @@ from iltools.datasets.lerobot_stream import ( LeRobotStreamingCacheConfig, StreamingTensorDictReplayCache, + UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES, UnitreeG1WBT29DofMapper, UnitreeG1WBT29DofMapperConfig, ) @@ -152,6 +153,96 @@ def test_unitree_g1_wbt_mapper_selects_episode_default_from_pool() -> None: ) +def test_unitree_g1_wbt_mapper_reorders_dataset_joints_to_target_order() -> None: + dataset_joint_names = UNITREE_G1_WBT_29DOF_DATASET_JOINT_NAMES + target_joint_names = tuple(reversed(dataset_joint_names)) + default_joint_pos = torch.linspace(-0.2, 0.2, 29) + action_scale = torch.linspace(0.5, 1.5, 29) + mapper = UnitreeG1WBT29DofMapper( + UnitreeG1WBT29DofMapperConfig( + default_joint_pos=default_joint_pos.tolist(), + action_scale=action_scale.tolist(), + target_joint_names=target_joint_names, + ) + ) + + length = 4 + q_current = torch.zeros(length, 36) + q_current[:, 3:7] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + dataset_joint_pos = torch.stack( + [ + torch.arange(29, dtype=torch.float32) + 100.0 * frame + for frame in range(length) + ] + ) + q_current[:, 7:] = dataset_joint_pos + expected_target_joint_pos = torch.flip(dataset_joint_pos, dims=[1]) + + expert_action = torch.stack( + [torch.linspace(-0.5, 0.5, 29) + 0.05 * float(frame) for frame in range(length)] + ) + q_desired = q_current.clone() + desired_target_order = default_joint_pos + expert_action * action_scale + q_desired[:, 7:] = torch.flip(desired_target_order, dims=[1]) + rows = [ + { + "episode_index": 0, + "observation.state.robot_q_current": q_current[index], + "action.robot_q_desired": q_desired[index], + } + for index in range(length) + ] + + transitions = mapper.map_episode(rows) + + torch.testing.assert_close( + transitions.get(("policy", "joint_pos")), + expected_target_joint_pos[:-1], + ) + torch.testing.assert_close(transitions["expert_action"], expert_action[:-1]) + torch.testing.assert_close( + transitions.get(("policy", "expert_motion"))[:, :29], + expected_target_joint_pos[:-1], + ) + + +def test_unitree_g1_wbt_mapper_aligns_first_root_z_to_default_height() -> None: + default_joint_pos = torch.zeros(29) + action_scale = torch.ones(29) + mapper = UnitreeG1WBT29DofMapper( + UnitreeG1WBT29DofMapperConfig( + default_joint_pos=default_joint_pos.tolist(), + action_scale=action_scale.tolist(), + align_root_z_to_default=True, + default_root_height=0.76, + ) + ) + length = 4 + q_current = torch.zeros(length, 36) + q_current[:, 2] = torch.tensor([0.5, 0.45, 0.4, 0.35]) + q_current[:, 3:7] = torch.tensor([1.0, 0.0, 0.0, 0.0]) + q_desired = q_current.clone() + rows = [ + { + "episode_index": 0, + "observation.state.robot_q_current": q_current[index], + "action.robot_q_desired": q_desired[index], + } + for index in range(length) + ] + + transitions = mapper.map_episode(rows) + + torch.testing.assert_close( + transitions.get(("policy", "root_pos"))[:, 2], + torch.tensor([0.76, 0.71, 0.66]), + ) + torch.testing.assert_close( + transitions.get(("next", "policy", "root_pos"))[:, 2], + torch.tensor([0.71, 0.66, 0.61]), + ) + + def test_unitree_g1_wbt_mapper_fails_fast_on_bad_robot_width() -> None: mapper = _make_mapper() episode = TensorDict( @@ -192,3 +283,38 @@ def test_streaming_cache_fills_memmap_before_sampling(tmp_path) -> None: assert ("policy", "base_ang_vel") in sample.keys(True) assert ("next", "policy", "joint_pos_rel") in sample.keys(True) assert "expert_action" in sample.keys(True) + + +def test_streaming_cache_keeps_same_episode_ids_separate_across_repos(tmp_path) -> None: + mapper = _make_mapper() + rows_a, _, _ = _make_fake_wbt_rows(episode_index=0, length=4) + rows_b, _, _ = _make_fake_wbt_rows(episode_index=0, length=4) + rows = [ + {**row, "__lerobot_repo_index": 0, "__lerobot_repo_id": "repo/a"} + for row in rows_a + ] + [ + {**row, "__lerobot_repo_index": 1, "__lerobot_repo_id": "repo/b"} + for row in rows_b + ] + cache = StreamingTensorDictReplayCache( + LeRobotStreamingCacheConfig( + repo_ids=("repo/a", "repo/b"), + cache_dir=tmp_path, + max_cache_transitions=16, + min_ready_transitions=6, + low_watermark=4, + max_episodes=2, + mapper=mapper.config, + ), + mapper=mapper, + source=rows, + ) + + cache.start() + cache.wait_until_ready(timeout_s=5.0) + sample = cache.sample(6) + cache.stop() + + assert sample.numel() == 6 + assert cache.ready_transitions == 6 + assert cache._episodes_written == 2 diff --git a/tests/datasets/test_loader_registry.py b/tests/datasets/test_loader_registry.py index 985923b..763e04a 100644 --- a/tests/datasets/test_loader_registry.py +++ b/tests/datasets/test_loader_registry.py @@ -11,6 +11,7 @@ def test_registered_loaders_do_not_import_optional_backends_on_listing(): names = registered_dataset_loaders() assert "lafan1_csv" in names + assert "lerobot" in names assert "loco_mujoco" in names @@ -20,6 +21,12 @@ def test_load_lafan1_csv_loader(): assert loader_cls.__name__ == "Lafan1CsvLoader" +def test_load_lerobot_loader(): + loader_cls = load_dataset_loader("lerobot") + + assert loader_cls.__name__ == "LeRobotLoader" + + def test_register_dataset_loader_import_target(): register_dataset_loader( "base_loader", From 2aa832b869e311cbeccc8e25ff7e39e07cd9fde5 Mon Sep 17 00:00:00 2001 From: Feiyang Wu Date: Fri, 22 May 2026 09:33:59 -0400 Subject: [PATCH 4/4] Add NPZ action labels to LAFAN loader --- iltools/datasets/lafan1/loader.py | 82 +++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/iltools/datasets/lafan1/loader.py b/iltools/datasets/lafan1/loader.py index 5cfe0ab..3478e96 100644 --- a/iltools/datasets/lafan1/loader.py +++ b/iltools/datasets/lafan1/loader.py @@ -723,6 +723,7 @@ def _load_csv_motion( root_ang_vel=root_ang_vel, joint_vel=joint_vel, extra_data=None, + action=None, ) return traj_data, source.input_fps, self.control_freq @@ -740,6 +741,8 @@ def _load_npz_motion( if source_fps <= 0.0: raise ValueError(f"Invalid source fps ({source_fps}) for {source.path}") + action = self._extract_npz_action_labels(arrays, source.path) + if "qpos" in arrays: qpos = arrays["qpos"].astype(np.float32) qpos = self._apply_frame_range(qpos, source.frame_range, source.path) @@ -809,6 +812,9 @@ def _load_npz_motion( if key in OPTIONAL_BODY_KEYS } + if action is not None: + action = self._apply_frame_range(action, source.frame_range, source.path) + root_quat = self._normalize_quat(root_quat.astype(np.float32)) needs_resample = not np.isclose(source_fps, self.control_freq) @@ -828,6 +834,7 @@ def _load_npz_motion( ) # Body states are not resampled here to avoid introducing FK assumptions. extra_data = None + action = None output_fps = self.control_freq else: output_fps = source_fps @@ -847,9 +854,70 @@ def _load_npz_motion( root_ang_vel=root_ang_vel, joint_vel=joint_vel, extra_data=extra_data, + action=action, ) return traj_data, source_fps, output_fps + def _extract_npz_action_labels( + self, + arrays: Mapping[str, np.ndarray], + source_path: Path, + ) -> np.ndarray | None: + action = arrays.get("action") + if action is not None: + return self._validate_npz_action_array( + np.asarray(action), + source_path=source_path, + key="action", + ) + + transition_action = arrays.get("transition_action") + if transition_action is None: + return None + + frame_count = self._infer_npz_frame_count(arrays, source_path) + transition_action = self._validate_npz_action_array( + np.asarray(transition_action), + source_path=source_path, + key="transition_action", + ) + if transition_action.shape[0] != frame_count - 1: + raise ValueError( + f"NPZ {source_path} has transition_action length " + f"{transition_action.shape[0]}, expected {frame_count - 1}." + ) + return np.concatenate([transition_action, transition_action[-1:]], axis=0) + + @staticmethod + def _validate_npz_action_array( + action: np.ndarray, + *, + source_path: Path, + key: str, + ) -> np.ndarray: + if action.ndim != 2: + raise ValueError( + f"NPZ {source_path} key '{key}' must have shape [T, A], " + f"got {tuple(action.shape)}." + ) + if action.shape[0] == 0 or action.shape[1] == 0: + raise ValueError( + f"NPZ {source_path} key '{key}' must be non-empty, " + f"got {tuple(action.shape)}." + ) + return action.astype(np.float32) + + @staticmethod + def _infer_npz_frame_count( + arrays: Mapping[str, np.ndarray], + source_path: Path, + ) -> int: + for key in ("qpos", "joint_pos", "body_pos_w"): + value = arrays.get(key) + if value is not None and value.ndim > 0: + return int(value.shape[0]) + raise ValueError(f"Could not infer frame count for NPZ {source_path}.") + def _apply_npz_name_metadata( self, *, arrays: Mapping[str, np.ndarray], path: Path ) -> None: @@ -960,6 +1028,7 @@ def _build_trajectory_dict( root_ang_vel: np.ndarray, joint_vel: np.ndarray, extra_data: Mapping[str, np.ndarray] | None, + action: np.ndarray | None, ) -> dict[str, np.ndarray]: qpos = np.concatenate([root_pos, root_quat, joint_pos], axis=-1).astype( np.float32 @@ -998,6 +1067,19 @@ def _build_trajectory_dict( continue traj_data[key] = value + if action is not None: + action = np.asarray(action, dtype=np.float32) + if action.shape[0] != qpos.shape[0]: + raise ValueError( + "Frame-aligned action labels must match qpos length: " + f"action={tuple(action.shape)}, qpos={tuple(qpos.shape)}." + ) + traj_data["action"] = action + traj_data["last_action"] = np.concatenate( + [np.zeros_like(action[:1]), action[:-1]], + axis=0, + ).astype(np.float32) + return traj_data def _apply_frame_range(