Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions iltools/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from rich.console import Console
from rich.table import Table

from iltools.datasets.amass.loader import AmassLoader
from iltools.datasets.loco_mujoco.loader import LocoMuJoCoLoader
from iltools.datasets.trajopt.loader import TrajoptLoader
from iltools.datasets.loaders import load_dataset_loader, registered_dataset_loaders

app = typer.Typer(help="Imitation Learning Tools CLI")
console = Console()
Expand All @@ -23,17 +21,28 @@ def load(
Loads a dataset and prints its metadata.
"""
with console.status(f"[bold green]Loading {dataset_name}...[/bold green]"):
if dataset_name == "amass":
loader = AmassLoader(data_path, model_path)
elif dataset_name == "loco_mujoco":
loader = LocoMuJoCoLoader(
try:
loader_cls = load_dataset_loader(dataset_name)
except KeyError:
choices = ", ".join(registered_dataset_loaders())
console.print(
"[bold red]"
f"Unknown dataset: {dataset_name}. Choices: {choices}"
"[/bold red]"
)
raise typer.Exit(1) from None
except ImportError as exc:
raise typer.BadParameter(str(exc)) from exc

normalized_name = dataset_name.strip().lower().replace("-", "_")
if normalized_name == "loco_mujoco":
loader = loader_cls(
env_name="Humanoid", task="walk", control_freq=control_freq
)
elif dataset_name == "trajopt":
loader = TrajoptLoader(data_path)
elif normalized_name in {"lafan1", "lafan1_csv"}:
loader = loader_cls(data_path)
else:
console.print(f"[bold red]Unknown dataset: {dataset_name}[/bold red]")
raise typer.Exit(1)
loader = loader_cls(data_path)

num_trajectories = len(loader)
metadata = loader.metadata
Expand Down
8 changes: 8 additions & 0 deletions iltools/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
"""Dataset package."""

from .loaders import ( # noqa: F401
DatasetLoaderSpec,
get_dataset_loader_spec,
load_dataset_loader,
register_dataset_loader,
registered_dataset_loaders,
)
190 changes: 173 additions & 17 deletions iltools/datasets/lafan1/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def _normalize_frame_range(value: Any) -> tuple[int, int] | None:
else:
seq = list(value)
if len(seq) != 2:
raise ValueError("frame_range must have exactly two elements: [start, end].")
raise ValueError(
"frame_range must have exactly two elements: [start, end]."
)
start, end = seq
start_i = int(start)
end_i = int(end)
Expand Down Expand Up @@ -315,7 +317,10 @@ def _normalize_source_entries(self, value: Any) -> list[Any]:
if self._looks_like_motion_entry(value):
return [value]
# Mapping style: motion_name -> path(s)
return [{"name": str(name), "path": path_spec} for name, path_spec in value.items()]
return [
{"name": str(name), "path": path_spec}
for name, path_spec in value.items()
]
return _as_list(value)

def _looks_like_motion_entry(self, value: Mapping[str, Any]) -> bool:
Expand Down Expand Up @@ -370,7 +375,9 @@ def _resolve_entry(self, entry: Any) -> list[MotionSource]:
"files",
entry.get(
"csv_files",
entry.get("path", entry.get("file", entry.get("csv_path", None))),
entry.get(
"path", entry.get("file", entry.get("csv_path", None))
),
),
),
)
Expand Down Expand Up @@ -514,7 +521,9 @@ def _get_trajectories(
if build_zarr_dataset and dataset_group is not None:
motion_group = dataset_group.create_group(motion_name)

motion_entry = motion_info_dict.setdefault(self.dataset_name, {}).setdefault(
motion_entry = motion_info_dict.setdefault(
self.dataset_name, {}
).setdefault(
motion_name,
{
"motion_name": motion_name,
Expand Down Expand Up @@ -579,7 +588,9 @@ def _get_trajectories(

if motion_group is not None:
motion_group.attrs["num_trajectories"] = len(motion_sources)
motion_group.attrs["trajectory_lengths"] = motion_entry["trajectory_lengths"]
motion_group.attrs["trajectory_lengths"] = motion_entry[
"trajectory_lengths"
]
motion_group.attrs["source_files"] = motion_entry["source_files"]
motion_group.attrs["source_fps"] = motion_entry["source_fps"]
motion_group.attrs["output_fps"] = motion_entry["output_fps"]
Expand Down Expand Up @@ -712,6 +723,7 @@ def _load_csv_motion(
root_ang_vel=root_ang_vel,
joint_vel=joint_vel,
extra_data=None,
action=None,
)
return traj_data, source.input_fps, self.control_freq

Expand All @@ -721,10 +733,16 @@ def _load_npz_motion(
with np.load(source.path) as npz_data:
arrays = {key: np.asarray(npz_data[key]) for key in npz_data.files}

source_fps = float(np.asarray(arrays.get("fps", source.input_fps)).reshape(-1)[0])
self._apply_npz_name_metadata(arrays=arrays, path=source.path)

source_fps = float(
np.asarray(arrays.get("fps", source.input_fps)).reshape(-1)[0]
)
if source_fps <= 0.0:
raise ValueError(f"Invalid source fps ({source_fps}) for {source.path}")

action = self._extract_npz_action_labels(arrays, source.path)

if "qpos" in arrays:
qpos = arrays["qpos"].astype(np.float32)
qpos = self._apply_frame_range(qpos, source.frame_range, source.path)
Expand Down Expand Up @@ -755,11 +773,17 @@ def _load_npz_motion(
f"NPZ {source.path} must contain 'qpos' or 'joint_pos'."
)
joint_pos = joint_pos.astype(np.float32)
joint_pos = self._apply_frame_range(joint_pos, source.frame_range, source.path)
joint_pos = self._apply_frame_range(
joint_pos, source.frame_range, source.path
)

root_pos, root_quat = self._extract_root_pose_from_npz(arrays, source.path)
root_pos = self._apply_frame_range(root_pos, source.frame_range, source.path)
root_quat = self._apply_frame_range(root_quat, source.frame_range, source.path)
root_pos = self._apply_frame_range(
root_pos, source.frame_range, source.path
)
root_quat = self._apply_frame_range(
root_quat, source.frame_range, source.path
)

joint_vel_raw = arrays.get("joint_vel")
root_lin_vel_raw, root_ang_vel_raw = self._extract_root_vel_from_npz(arrays)
Expand All @@ -769,12 +793,16 @@ def _load_npz_motion(
else None
)
root_lin_vel = (
self._apply_frame_range(root_lin_vel_raw, source.frame_range, source.path)
self._apply_frame_range(
root_lin_vel_raw, source.frame_range, source.path
)
if root_lin_vel_raw is not None
else None
)
root_ang_vel = (
self._apply_frame_range(root_ang_vel_raw, source.frame_range, source.path)
self._apply_frame_range(
root_ang_vel_raw, source.frame_range, source.path
)
if root_ang_vel_raw is not None
else None
)
Expand All @@ -784,6 +812,9 @@ def _load_npz_motion(
if key in OPTIONAL_BODY_KEYS
}

if action is not None:
action = self._apply_frame_range(action, source.frame_range, source.path)

root_quat = self._normalize_quat(root_quat.astype(np.float32))
needs_resample = not np.isclose(source_fps, self.control_freq)

Expand All @@ -803,6 +834,7 @@ def _load_npz_motion(
)
# Body states are not resampled here to avoid introducing FK assumptions.
extra_data = None
action = None
output_fps = self.control_freq
else:
output_fps = source_fps
Expand All @@ -822,9 +854,119 @@ def _load_npz_motion(
root_ang_vel=root_ang_vel,
joint_vel=joint_vel,
extra_data=extra_data,
action=action,
)
return traj_data, source_fps, output_fps

def _extract_npz_action_labels(
self,
arrays: Mapping[str, np.ndarray],
source_path: Path,
) -> np.ndarray | None:
action = arrays.get("action")
if action is not None:
return self._validate_npz_action_array(
np.asarray(action),
source_path=source_path,
key="action",
)

transition_action = arrays.get("transition_action")
if transition_action is None:
return None

frame_count = self._infer_npz_frame_count(arrays, source_path)
transition_action = self._validate_npz_action_array(
np.asarray(transition_action),
source_path=source_path,
key="transition_action",
)
if transition_action.shape[0] != frame_count - 1:
raise ValueError(
f"NPZ {source_path} has transition_action length "
f"{transition_action.shape[0]}, expected {frame_count - 1}."
)
return np.concatenate([transition_action, transition_action[-1:]], axis=0)

@staticmethod
def _validate_npz_action_array(
action: np.ndarray,
*,
source_path: Path,
key: str,
) -> np.ndarray:
if action.ndim != 2:
raise ValueError(
f"NPZ {source_path} key '{key}' must have shape [T, A], "
f"got {tuple(action.shape)}."
)
if action.shape[0] == 0 or action.shape[1] == 0:
raise ValueError(
f"NPZ {source_path} key '{key}' must be non-empty, "
f"got {tuple(action.shape)}."
)
return action.astype(np.float32)

@staticmethod
def _infer_npz_frame_count(
arrays: Mapping[str, np.ndarray],
source_path: Path,
) -> int:
for key in ("qpos", "joint_pos", "body_pos_w"):
value = arrays.get(key)
if value is not None and value.ndim > 0:
return int(value.shape[0])
raise ValueError(f"Could not infer frame count for NPZ {source_path}.")

def _apply_npz_name_metadata(
self, *, arrays: Mapping[str, np.ndarray], path: Path
) -> None:
self._apply_npz_names(
key="joint_names",
current=self._joint_names,
setter=lambda names: setattr(self, "_joint_names", names),
arrays=arrays,
path=path,
)
self._apply_npz_names(
key="body_names",
current=self._body_names,
setter=lambda names: setattr(self, "_body_names", names),
arrays=arrays,
path=path,
)
self._apply_npz_names(
key="site_names",
current=self._site_names,
setter=lambda names: setattr(self, "_site_names", names),
arrays=arrays,
path=path,
)

def _apply_npz_names(
self,
*,
key: str,
current: list[str] | None,
setter: Any,
arrays: Mapping[str, np.ndarray],
path: Path,
) -> None:
if key not in arrays:
return
value = np.asarray(arrays[key])
if value.ndim != 1:
raise ValueError(f"NPZ {path} {key} must be a 1D array.")
parsed = [str(item) for item in value.tolist()]
if current is None:
setter(parsed)
return
if parsed != current:
raise ValueError(
f"NPZ {path} {key} does not match configured {key}: "
f"expected={current}, got={parsed}."
)

def _extract_root_pose_from_npz(
self, arrays: Mapping[str, np.ndarray], path: Path
) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -886,13 +1028,14 @@ def _build_trajectory_dict(
root_ang_vel: np.ndarray,
joint_vel: np.ndarray,
extra_data: Mapping[str, np.ndarray] | None,
action: np.ndarray | None,
) -> dict[str, np.ndarray]:
qpos = np.concatenate([root_pos, root_quat, joint_pos], axis=-1).astype(
np.float32
)
qvel = np.concatenate(
[root_lin_vel, root_ang_vel, joint_vel], axis=-1
).astype(np.float32)
qvel = np.concatenate([root_lin_vel, root_ang_vel, joint_vel], axis=-1).astype(
np.float32
)

traj_data: dict[str, np.ndarray] = {
"qpos": qpos,
Expand Down Expand Up @@ -924,6 +1067,19 @@ def _build_trajectory_dict(
continue
traj_data[key] = value

if action is not None:
action = np.asarray(action, dtype=np.float32)
if action.shape[0] != qpos.shape[0]:
raise ValueError(
"Frame-aligned action labels must match qpos length: "
f"action={tuple(action.shape)}, qpos={tuple(qpos.shape)}."
)
traj_data["action"] = action
traj_data["last_action"] = np.concatenate(
[np.zeros_like(action[:1]), action[:-1]],
axis=0,
).astype(np.float32)

return traj_data

def _apply_frame_range(
Expand Down Expand Up @@ -1124,9 +1280,9 @@ def _quat_slerp_batch(
theta = theta_0 * t[~linear_mask]
s0 = np.sin(theta_0 - theta) / np.maximum(sin_theta_0, EPS)
s1 = np.sin(theta) / np.maximum(sin_theta_0, EPS)
out[~linear_mask] = qa[~linear_mask] * s0[:, None] + qb[~linear_mask] * s1[
:, None
]
out[~linear_mask] = (
qa[~linear_mask] * s0[:, None] + qb[~linear_mask] * s1[:, None]
)
out[~linear_mask] = self._normalize_quat(out[~linear_mask])
return out.astype(np.float32)

Expand Down
5 changes: 5 additions & 0 deletions iltools/datasets/lerobot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from .loader import LeRobotLoader

__all__ = ["LeRobotLoader"]
Loading