Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions frontend/src/pages/Training.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -199,28 +199,31 @@ const ConfigurationMode: React.FC = () => {
return;
}

// Pre-flight: smolvla/pi0/diffusion need an optional package. Catch it here
// with a one-click installer instead of a buried ImportError after the job
// has already started.
try {
const r = await fetchWithHeaders(
`${baseUrl}/system/policy-extra/${trainingConfig.policy_type}`,
);
if (r.ok) {
const extra = await r.json();
if (extra.needs_extra && !extra.available) {
setPolicyExtra({
policyType: trainingConfig.policy_type,
packageName: extra.package,
installTarget: extra.install_target,
installHint: extra.install_hint,
});
return;
// Pre-flight: smolvla/pi0/diffusion need an optional package installed
// locally. Catch it here with a one-click installer instead of a buried
// ImportError after the job has already started. Cloud jobs run in their
// own environment, so the local package is irrelevant — skip the check.
if (trainingConfig.target.runner === "local") {
try {
const r = await fetchWithHeaders(
`${baseUrl}/system/policy-extra/${trainingConfig.policy_type}`,
);
if (r.ok) {
const extra = await r.json();
if (extra.needs_extra && !extra.available) {
setPolicyExtra({
policyType: trainingConfig.policy_type,
packageName: extra.package,
installTarget: extra.install_target,
installHint: extra.install_hint,
});
return;
}
}
} catch {
// Check failed (offline / older backend) — fall through and let the
// job report any problem itself.
}
} catch {
// Check failed (offline / older backend) — fall through and let the job
// report any problem itself.
}

setIsStarting(true);
Expand Down
167 changes: 100 additions & 67 deletions lelab/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import sys
import threading
import time
from collections.abc import Callable
from collections.abc import Callable, Iterable
from datetime import datetime
from pathlib import Path
from queue import Empty, Queue
Expand Down Expand Up @@ -208,11 +208,14 @@ def parse_metrics_into(line: str, metrics: TrainingMetrics) -> None:
logger.debug("Error parsing log line %r: %s", line, exc)


class LocalJobRunner:
"""Run a training as a local subprocess.
class SubprocessJobRunner:
"""Spawn a subprocess and pump its stdout into a log file + in-memory queue.

The runner is single-shot: instantiate a fresh one per job. Lifetime of
the underlying subprocess is bounded by this object's existence in memory.
The shared engine behind both LocalJobRunner (which runs `lerobot-train`
directly) and HfCloudJobRunner (which runs `lerobot-train --job.target=...`,
a local process that submits the job and streams the remote logs to its own
stdout). Subclasses override `_on_line` to inspect each stdout line for
runner-specific markers (e.g. the HF job id / page URL).
"""

def __init__(
Expand All @@ -229,27 +232,21 @@ def __init__(
self._log_file = None # type: ignore[assignment]
self._wandb_run_url: str | None = None

def start(
self,
job_id: str,
config: TrainingRequest,
output_dir: str,
) -> None:
if self._process is not None:
raise RuntimeError("LocalJobRunner already started")

# Build the command via the helper that lives in train.py.
from .train import build_training_command # avoid import cycle at module load

cmd = build_training_command(config, output_dir, sys.executable)
logger.info("Starting job %s: %s", job_id, " ".join(cmd))

# Open the persistent log sink (one JSON line per stdout line). Held
# open for the subprocess's lifetime so we don't reopen per write.
def _open_log_file(self) -> None:
"""Open the persistent log sink (one JSON line per consumed line).
Held open for the consumer thread's lifetime so we don't reopen per
write; _consume_lines closes it when its iterator is exhausted."""
if self._log_file_path is not None:
self._log_file_path.parent.mkdir(parents=True, exist_ok=True)
self._log_file = self._log_file_path.open("a", buffering=1)

def _spawn(self, cmd: list[str], thread_name: str) -> None:
"""Open the log sink, launch `cmd`, and start the stdout pump thread."""
if self._process is not None:
raise RuntimeError(f"{type(self).__name__} already started")

self._open_log_file()

# PYTHONUNBUFFERED makes the child's stdout flush per line. Without it
# block-buffering hides log lines from our parser for many seconds.
child_env = os.environ.copy()
Expand All @@ -259,7 +256,7 @@ def start(
# group. Without it, signals sent to the uvicorn worker (e.g. when
# --reload restarts it on a .py file change) cascade to the child
# and kill the training. With it, the child survives reloads; the
# next worker re-attaches via TailingJobRunner using job.json's pid.
# next worker re-attaches via the reattach path.
self._process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
Expand All @@ -270,9 +267,7 @@ def start(
start_new_session=True,
)

self._monitor_thread = threading.Thread(
target=self._pump_stdout, name=f"job-{job_id}-stdout", daemon=True
)
self._monitor_thread = threading.Thread(target=self._pump_stdout, name=thread_name, daemon=True)
self._monitor_thread.start()

def pid(self) -> int | None:
Expand Down Expand Up @@ -316,10 +311,17 @@ def wandb_run_url(self) -> str | None:

# -- internals --

def _pump_stdout(self) -> None:
assert self._process is not None
def _on_line(self, line: str) -> None:
"""Hook for subclasses to inspect each stdout line. Default: no-op."""

def _consume_lines(self, lines: Iterable[str]) -> None:
"""Drive each text line through the metric/marker parse + log.jsonl
append + in-memory queue. Source-agnostic: a subprocess's stdout
(LocalJobRunner) or a remote log stream iterator (cloud reattach) feed
the same pipeline. Closes the log file when the iterator is exhausted.
"""
try:
for line in iter(self._process.stdout.readline, ""):
for line in lines:
if self._stop_event.is_set():
break
stripped = line.rstrip()
Expand All @@ -330,25 +332,51 @@ def _pump_stdout(self) -> None:
url = extract_wandb_run_url(stripped)
if url is not None:
self._wandb_run_url = url
self._on_line(stripped)
log_line = LogLine(timestamp=time.time(), message=stripped)
if self._log_file is not None:
try:
self._log_file.write(log_line.model_dump_json() + "\n")
except Exception as exc: # pragma: no cover — best-effort persist
logger.exception("Error writing to log file: %s", exc)
# Cap queue so a chatty subprocess can't grow memory unbounded.
# Cap queue so a chatty source can't grow memory unbounded.
if self._log_queue.qsize() >= 1000:
with contextlib.suppress(Empty):
self._log_queue.get_nowait()
self._log_queue.put(log_line)
except Exception as exc:
logger.exception("Error reading subprocess stdout: %s", exc)
logger.exception("Error consuming log lines: %s", exc)
finally:
if self._log_file is not None:
with contextlib.suppress(Exception):
self._log_file.close()
self._log_file = None

def _pump_stdout(self) -> None:
assert self._process is not None
self._consume_lines(iter(self._process.stdout.readline, ""))


class LocalJobRunner(SubprocessJobRunner):
"""Run a training as a local subprocess.

The runner is single-shot: instantiate a fresh one per job. Lifetime of
the underlying subprocess is bounded by this object's existence in memory.
"""

def start(
self,
job_id: str,
config: TrainingRequest,
output_dir: str,
) -> None:
# Build the command via the helper that lives in train.py.
from .train import build_training_command # avoid import cycle at module load

cmd = build_training_command(config, output_dir, sys.executable)
logger.info("Starting job %s: %s", job_id, " ".join(cmd))
self._spawn(cmd, thread_name=f"job-{job_id}-stdout")


class TailingJobRunner:
"""Re-attaches to a detached subprocess after a uvicorn reload.
Expand Down Expand Up @@ -551,17 +579,6 @@ def _list_imported_hub(api, repo_id: str) -> list[JobCheckpoint]:
return []


def _list_hub_checkpoints(api, repo_id: str) -> list[JobCheckpoint]:
"""List checkpoints by introspecting the model repo file tree."""
try:
files = api.list_repo_files(repo_id, repo_type="model")
except Exception:
# Repo may not exist yet (training just started, sidecar hasn't
# uploaded anything). Treat as no checkpoints.
return []
return _hub_checkpoints_from_files(files, repo_id)


_LANGUAGE_CONDITIONED_POLICY_TYPES = {"smolvla", "pi0", "pi0_fast", "pi05"}


Expand Down Expand Up @@ -829,15 +846,12 @@ def start(self, config: TrainingRequest, target: JobTarget | None = None) -> Job
self._persist(record, force=True)
raise

# Capture runner-specific identifiers.
# Capture runner-specific identifiers. For cloud jobs the HF job id
# / page URL / model repo are printed by lerobot's submit_to_hf and
# only appear in stdout a few seconds after start, so they're None
# here; the watchdog (_tick) parses and persists them once they land.
if target.runner == "local":
record.process_pid = runner.pid()
else:
record.hf_job_id = runner.hf_job_id()
record.hf_job_url = runner.hf_job_url()
# config was mutated by HfCloudJobRunner.start to set
# policy_repo_id; mirror it onto the record for the UI.
record.hf_repo_id = config.policy_repo_id

self._persist(record, force=True)
self._runners[job_id] = runner
Expand Down Expand Up @@ -1003,10 +1017,13 @@ def read_metrics_history(self, job_id: str) -> builtins.list[MetricsHistoryPoint
def _checkpoints_for(self, record: JobRecord) -> builtins.list[JobCheckpoint]:
if record.runner == "imported":
if record.hf_repo_id:
return self._list_cloud_cached(record.hf_repo_id, _list_imported_hub)
return self._list_cloud_cached(record.hf_repo_id)
return _list_imported_local(record.output_dir)
if record.runner == "local":
return _list_local_checkpoints(record.output_dir)
# Cloud: _list_imported_hub prefers the checkpoints/<step>/ tree (pushed when
# save_checkpoint_to_hub is on) and falls back to the final model at the repo
# root, so a finished run is always reachable even with no per-step tree.
return self._list_cloud_cached(record.hf_repo_id)

def list_checkpoints(self, job_id: str) -> builtins.list[JobCheckpoint]:
Expand All @@ -1021,12 +1038,10 @@ def list_checkpoints(self, job_id: str) -> builtins.list[JobCheckpoint]:
raise JobNotFoundError(job_id)
return self._checkpoints_for(record)

def _list_cloud_cached(
self, repo_id: str | None, fetch=_list_hub_checkpoints
) -> builtins.list[JobCheckpoint]:
"""30s-TTL cache over a hub checkpoint listing. `fetch(api, repo_id)`
defaults to the training-job tree scan; imported hub models pass
`_list_imported_hub` so they share the same cache + rate-limit budget."""
def _list_cloud_cached(self, repo_id: str | None) -> builtins.list[JobCheckpoint]:
"""30s-TTL cache over the hub checkpoint listing (`_list_imported_hub`:
the checkpoints/<step>/ tree, else the root model). All hub listings —
cloud-trained and imported alike — share this cache + rate-limit budget."""
if not repo_id:
return []
now = time.time()
Expand All @@ -1035,7 +1050,7 @@ def _list_cloud_cached(
return cached[1]
from .utils.hf_auth import shared_hf_api # lazy: keeps unit-test imports cheap

result = fetch(shared_hf_api(), repo_id)
result = _list_imported_hub(shared_hf_api(), repo_id)
self._cloud_ckpt_cache[repo_id] = (now + _CLOUD_CKPT_TTL_SECONDS, result)
return result

Expand Down Expand Up @@ -1188,6 +1203,10 @@ def _tick(self) -> None:
with self._lock:
record.wandb_run_url = url
self._persist(record, force=True)
# Cloud jobs print their HF job id / page URL / model repo a few
# seconds after start; capture them onto the record once parsed.
if record.runner == "hf_cloud":
self._sync_cloud_ids(record, runner)
# Persist metric snapshot at most once per second.
self._persist(record, force=False)
progress_snapshots.append(
Expand All @@ -1202,6 +1221,10 @@ def _tick(self) -> None:
continue

# Subprocess exited since the last tick. Finalise.
# Capture any cloud ids printed right before exit (e.g. a job that
# submitted then failed fast) so checkpoint listing has the repo id.
if record.runner == "hf_cloud":
self._sync_cloud_ids(record, runner)
rc = runner.returncode()
with self._lock:
if record.wandb_run_url is None:
Expand All @@ -1210,22 +1233,31 @@ def _tick(self) -> None:
record.ended_at = time.time()
record.exit_code = rc
if rc != 0 and record.error_message is None:
# Prefer a runner-supplied reason (e.g. HF Jobs'
# 'Job timeout') over the synthetic exit-code message.
reason = None
get_message = getattr(runner, "terminal_message", None)
if callable(get_message):
try:
reason = get_message()
except Exception:
reason = None
record.error_message = reason or f"Subprocess exited with code {rc}"
record.error_message = f"Job exited with code {rc}"
self._runners.pop(jid, None)
self._persist(record, force=True)
self._notify_change()

self._notify_progress(progress_snapshots)

def _sync_cloud_ids(self, record: JobRecord, runner: JobRunner) -> None:
"""Copy HF job id / page URL / model repo from a cloud runner onto the
record once lerobot's submit_to_hf has printed them. Persists on first
appearance so the ids survive a uvicorn --reload (which drives reattach).
"""
changed = False
for attr in ("hf_job_id", "hf_job_url", "hf_repo_id"):
if getattr(record, attr) is not None:
continue
getter = getattr(runner, attr, None)
value = getter() if callable(getter) else None
if value is not None:
with self._lock:
setattr(record, attr, value)
changed = True
if changed:
self._persist(record, force=True)

def _persist(self, record: JobRecord, force: bool) -> None:
now = time.time()
last = self._last_persist_at.get(record.id, 0.0)
Expand Down Expand Up @@ -1262,6 +1294,7 @@ def _write_meta(self, record: JobRecord) -> None:
"JobCheckpoint",
"MetricsHistoryPoint",
"JobRunner",
"SubprocessJobRunner",
"LocalJobRunner",
"JobRegistry",
"JobAlreadyRunningError",
Expand Down
6 changes: 4 additions & 2 deletions lelab/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,8 @@ def record_with_web_events(cfg: RecordConfig, web_events: dict) -> LeRobotDatase
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
rgb_encoder=cfg.dataset.rgb_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
Expand All @@ -622,7 +623,8 @@ def record_with_web_events(cfg: RecordConfig, web_events: dict) -> LeRobotDatase
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
rgb_encoder=cfg.dataset.rgb_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
Expand Down
Loading
Loading