diff --git a/frontend/src/pages/Training.tsx b/frontend/src/pages/Training.tsx index bcdedcc..a361d04 100644 --- a/frontend/src/pages/Training.tsx +++ b/frontend/src/pages/Training.tsx @@ -199,28 +199,31 @@ const ConfigurationMode: React.FC = () => { return; } - // Pre-flight: smolvla/pi0/diffusion need an optional package. Catch it here - // with a one-click installer instead of a buried ImportError after the job - // has already started. - try { - const r = await fetchWithHeaders( - `${baseUrl}/system/policy-extra/${trainingConfig.policy_type}`, - ); - if (r.ok) { - const extra = await r.json(); - if (extra.needs_extra && !extra.available) { - setPolicyExtra({ - policyType: trainingConfig.policy_type, - packageName: extra.package, - installTarget: extra.install_target, - installHint: extra.install_hint, - }); - return; + // Pre-flight: smolvla/pi0/diffusion need an optional package installed + // locally. Catch it here with a one-click installer instead of a buried + // ImportError after the job has already started. Cloud jobs run in their + // own environment, so the local package is irrelevant — skip the check. + if (trainingConfig.target.runner === "local") { + try { + const r = await fetchWithHeaders( + `${baseUrl}/system/policy-extra/${trainingConfig.policy_type}`, + ); + if (r.ok) { + const extra = await r.json(); + if (extra.needs_extra && !extra.available) { + setPolicyExtra({ + policyType: trainingConfig.policy_type, + packageName: extra.package, + installTarget: extra.install_target, + installHint: extra.install_hint, + }); + return; + } } + } catch { + // Check failed (offline / older backend) — fall through and let the + // job report any problem itself. } - } catch { - // Check failed (offline / older backend) — fall through and let the job - // report any problem itself. 
} setIsStarting(true); diff --git a/lelab/jobs.py b/lelab/jobs.py index cf6706a..10810a1 100644 --- a/lelab/jobs.py +++ b/lelab/jobs.py @@ -30,7 +30,7 @@ import sys import threading import time -from collections.abc import Callable +from collections.abc import Callable, Iterable from datetime import datetime from pathlib import Path from queue import Empty, Queue @@ -208,11 +208,14 @@ def parse_metrics_into(line: str, metrics: TrainingMetrics) -> None: logger.debug("Error parsing log line %r: %s", line, exc) -class LocalJobRunner: - """Run a training as a local subprocess. +class SubprocessJobRunner: + """Spawn a subprocess and pump its stdout into a log file + in-memory queue. - The runner is single-shot: instantiate a fresh one per job. Lifetime of - the underlying subprocess is bounded by this object's existence in memory. + The shared engine behind both LocalJobRunner (which runs `lerobot-train` + directly) and HfCloudJobRunner (which runs `lerobot-train --job.target=...`, + a local process that submits the job and streams the remote logs to its own + stdout). Subclasses override `_on_line` to inspect each stdout line for + runner-specific markers (e.g. the HF job id / page URL). """ def __init__( @@ -229,27 +232,21 @@ def __init__( self._log_file = None # type: ignore[assignment] self._wandb_run_url: str | None = None - def start( - self, - job_id: str, - config: TrainingRequest, - output_dir: str, - ) -> None: - if self._process is not None: - raise RuntimeError("LocalJobRunner already started") - - # Build the command via the helper that lives in train.py. - from .train import build_training_command # avoid import cycle at module load - - cmd = build_training_command(config, output_dir, sys.executable) - logger.info("Starting job %s: %s", job_id, " ".join(cmd)) - - # Open the persistent log sink (one JSON line per stdout line). Held - # open for the subprocess's lifetime so we don't reopen per write. 
+ def _open_log_file(self) -> None: + """Open the persistent log sink (one JSON line per consumed line). + Held open for the consumer thread's lifetime so we don't reopen per + write; _consume_lines closes it when its iterator is exhausted.""" if self._log_file_path is not None: self._log_file_path.parent.mkdir(parents=True, exist_ok=True) self._log_file = self._log_file_path.open("a", buffering=1) + def _spawn(self, cmd: list[str], thread_name: str) -> None: + """Open the log sink, launch `cmd`, and start the stdout pump thread.""" + if self._process is not None: + raise RuntimeError(f"{type(self).__name__} already started") + + self._open_log_file() + # PYTHONUNBUFFERED makes the child's stdout flush per line. Without it # block-buffering hides log lines from our parser for many seconds. child_env = os.environ.copy() @@ -259,7 +256,7 @@ def start( # group. Without it, signals sent to the uvicorn worker (e.g. when # --reload restarts it on a .py file change) cascade to the child # and kill the training. With it, the child survives reloads; the - # next worker re-attaches via TailingJobRunner using job.json's pid. + # next worker re-attaches via the reattach path. self._process = subprocess.Popen( cmd, stdout=subprocess.PIPE, @@ -270,9 +267,7 @@ def start( start_new_session=True, ) - self._monitor_thread = threading.Thread( - target=self._pump_stdout, name=f"job-{job_id}-stdout", daemon=True - ) + self._monitor_thread = threading.Thread(target=self._pump_stdout, name=thread_name, daemon=True) self._monitor_thread.start() def pid(self) -> int | None: @@ -316,10 +311,17 @@ def wandb_run_url(self) -> str | None: # -- internals -- - def _pump_stdout(self) -> None: - assert self._process is not None + def _on_line(self, line: str) -> None: + """Hook for subclasses to inspect each stdout line. Default: no-op.""" + + def _consume_lines(self, lines: Iterable[str]) -> None: + """Drive each text line through the metric/marker parse + log.jsonl + append + in-memory queue. 
Source-agnostic: a subprocess's stdout + (LocalJobRunner) or a remote log stream iterator (cloud reattach) feed + the same pipeline. Closes the log file when the iterator is exhausted. + """ try: - for line in iter(self._process.stdout.readline, ""): + for line in lines: if self._stop_event.is_set(): break stripped = line.rstrip() @@ -330,25 +332,51 @@ def _pump_stdout(self) -> None: url = extract_wandb_run_url(stripped) if url is not None: self._wandb_run_url = url + self._on_line(stripped) log_line = LogLine(timestamp=time.time(), message=stripped) if self._log_file is not None: try: self._log_file.write(log_line.model_dump_json() + "\n") except Exception as exc: # pragma: no cover — best-effort persist logger.exception("Error writing to log file: %s", exc) - # Cap queue so a chatty subprocess can't grow memory unbounded. + # Cap queue so a chatty source can't grow memory unbounded. if self._log_queue.qsize() >= 1000: with contextlib.suppress(Empty): self._log_queue.get_nowait() self._log_queue.put(log_line) except Exception as exc: - logger.exception("Error reading subprocess stdout: %s", exc) + logger.exception("Error consuming log lines: %s", exc) finally: if self._log_file is not None: with contextlib.suppress(Exception): self._log_file.close() self._log_file = None + def _pump_stdout(self) -> None: + assert self._process is not None + self._consume_lines(iter(self._process.stdout.readline, "")) + + +class LocalJobRunner(SubprocessJobRunner): + """Run a training as a local subprocess. + + The runner is single-shot: instantiate a fresh one per job. Lifetime of + the underlying subprocess is bounded by this object's existence in memory. + """ + + def start( + self, + job_id: str, + config: TrainingRequest, + output_dir: str, + ) -> None: + # Build the command via the helper that lives in train.py. 
+ from .train import build_training_command # avoid import cycle at module load + + cmd = build_training_command(config, output_dir, sys.executable) + logger.info("Starting job %s: %s", job_id, " ".join(cmd)) + self._spawn(cmd, thread_name=f"job-{job_id}-stdout") + class TailingJobRunner: """Re-attaches to a detached subprocess after a uvicorn reload. @@ -551,17 +579,6 @@ def _list_imported_hub(api, repo_id: str) -> list[JobCheckpoint]: return [] -def _list_hub_checkpoints(api, repo_id: str) -> list[JobCheckpoint]: - """List checkpoints by introspecting the model repo file tree.""" - try: - files = api.list_repo_files(repo_id, repo_type="model") - except Exception: - # Repo may not exist yet (training just started, sidecar hasn't - # uploaded anything). Treat as no checkpoints. - return [] - return _hub_checkpoints_from_files(files, repo_id) - - _LANGUAGE_CONDITIONED_POLICY_TYPES = {"smolvla", "pi0", "pi0_fast", "pi05"} @@ -829,15 +846,12 @@ def start(self, config: TrainingRequest, target: JobTarget | None = None) -> Job self._persist(record, force=True) raise - # Capture runner-specific identifiers. + # Capture runner-specific identifiers. For cloud jobs the HF job id + # / page URL / model repo are printed by lerobot's submit_to_hf and + # only appear in stdout a few seconds after start, so they're None + # here; the watchdog (_tick) parses and persists them once they land. if target.runner == "local": record.process_pid = runner.pid() - else: - record.hf_job_id = runner.hf_job_id() - record.hf_job_url = runner.hf_job_url() - # config was mutated by HfCloudJobRunner.start to set - # policy_repo_id; mirror it onto the record for the UI. 
- record.hf_repo_id = config.policy_repo_id self._persist(record, force=True) self._runners[job_id] = runner @@ -1003,10 +1017,13 @@ def read_metrics_history(self, job_id: str) -> builtins.list[MetricsHistoryPoint def _checkpoints_for(self, record: JobRecord) -> builtins.list[JobCheckpoint]: if record.runner == "imported": if record.hf_repo_id: - return self._list_cloud_cached(record.hf_repo_id, _list_imported_hub) + return self._list_cloud_cached(record.hf_repo_id) return _list_imported_local(record.output_dir) if record.runner == "local": return _list_local_checkpoints(record.output_dir) + # Cloud: _list_imported_hub prefers the checkpoints// tree (pushed when + # save_checkpoint_to_hub is on) and falls back to the final model at the repo + # root, so a finished run is always reachable even with no per-step tree. return self._list_cloud_cached(record.hf_repo_id) def list_checkpoints(self, job_id: str) -> builtins.list[JobCheckpoint]: @@ -1021,12 +1038,10 @@ def list_checkpoints(self, job_id: str) -> builtins.list[JobCheckpoint]: raise JobNotFoundError(job_id) return self._checkpoints_for(record) - def _list_cloud_cached( - self, repo_id: str | None, fetch=_list_hub_checkpoints - ) -> builtins.list[JobCheckpoint]: - """30s-TTL cache over a hub checkpoint listing. `fetch(api, repo_id)` - defaults to the training-job tree scan; imported hub models pass - `_list_imported_hub` so they share the same cache + rate-limit budget.""" + def _list_cloud_cached(self, repo_id: str | None) -> builtins.list[JobCheckpoint]: + """30s-TTL cache over the hub checkpoint listing (`_list_imported_hub`: + the checkpoints// tree, else the root model). 
All hub listings — + cloud-trained and imported alike — share this cache + rate-limit budget.""" if not repo_id: return [] now = time.time() @@ -1035,7 +1050,7 @@ def _list_cloud_cached( return cached[1] from .utils.hf_auth import shared_hf_api # lazy: keeps unit-test imports cheap - result = fetch(shared_hf_api(), repo_id) + result = _list_imported_hub(shared_hf_api(), repo_id) self._cloud_ckpt_cache[repo_id] = (now + _CLOUD_CKPT_TTL_SECONDS, result) return result @@ -1188,6 +1203,10 @@ def _tick(self) -> None: with self._lock: record.wandb_run_url = url self._persist(record, force=True) + # Cloud jobs print their HF job id / page URL / model repo a few + # seconds after start; capture them onto the record once parsed. + if record.runner == "hf_cloud": + self._sync_cloud_ids(record, runner) # Persist metric snapshot at most once per second. self._persist(record, force=False) progress_snapshots.append( @@ -1202,6 +1221,10 @@ def _tick(self) -> None: continue # Subprocess exited since the last tick. Finalise. + # Capture any cloud ids printed right before exit (e.g. a job that + # submitted then failed fast) so checkpoint listing has the repo id. + if record.runner == "hf_cloud": + self._sync_cloud_ids(record, runner) rc = runner.returncode() with self._lock: if record.wandb_run_url is None: @@ -1210,22 +1233,31 @@ def _tick(self) -> None: record.ended_at = time.time() record.exit_code = rc if rc != 0 and record.error_message is None: - # Prefer a runner-supplied reason (e.g. HF Jobs' - # 'Job timeout') over the synthetic exit-code message. 
- reason = None - get_message = getattr(runner, "terminal_message", None) - if callable(get_message): - try: - reason = get_message() - except Exception: - reason = None - record.error_message = reason or f"Subprocess exited with code {rc}" + record.error_message = f"Job exited with code {rc}" self._runners.pop(jid, None) self._persist(record, force=True) self._notify_change() self._notify_progress(progress_snapshots) + def _sync_cloud_ids(self, record: JobRecord, runner: JobRunner) -> None: + """Copy HF job id / page URL / model repo from a cloud runner onto the + record once lerobot's submit_to_hf has printed them. Persists on first + appearance so the ids survive a uvicorn --reload (which drives reattach). + """ + changed = False + for attr in ("hf_job_id", "hf_job_url", "hf_repo_id"): + if getattr(record, attr) is not None: + continue + getter = getattr(runner, attr, None) + value = getter() if callable(getter) else None + if value is not None: + with self._lock: + setattr(record, attr, value) + changed = True + if changed: + self._persist(record, force=True) + def _persist(self, record: JobRecord, force: bool) -> None: now = time.time() last = self._last_persist_at.get(record.id, 0.0) @@ -1262,6 +1294,7 @@ def _write_meta(self, record: JobRecord) -> None: "JobCheckpoint", "MetricsHistoryPoint", "JobRunner", + "SubprocessJobRunner", "LocalJobRunner", "JobRegistry", "JobAlreadyRunningError", diff --git a/lelab/record.py b/lelab/record.py index 2e61d9e..3a08629 100644 --- a/lelab/record.py +++ b/lelab/record.py @@ -600,7 +600,8 @@ def record_with_web_events(cfg: RecordConfig, web_events: dict) -> LeRobotDatase cfg.dataset.repo_id, root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, - vcodec=cfg.dataset.vcodec, + rgb_encoder=cfg.dataset.rgb_encoder, + depth_encoder=cfg.dataset.depth_encoder, streaming_encoding=cfg.dataset.streaming_encoding, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, 
encoder_threads=cfg.dataset.encoder_threads, @@ -622,7 +623,8 @@ def record_with_web_events(cfg: RecordConfig, web_events: dict) -> LeRobotDatase image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), batch_encoding_size=cfg.dataset.video_encoding_batch_size, - vcodec=cfg.dataset.vcodec, + rgb_encoder=cfg.dataset.rgb_encoder, + depth_encoder=cfg.dataset.depth_encoder, streaming_encoding=cfg.dataset.streaming_encoding, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_threads=cfg.dataset.encoder_threads, diff --git a/lelab/runners/hf_cloud.py b/lelab/runners/hf_cloud.py index 7d778b3..f15ca78 100644 --- a/lelab/runners/hf_cloud.py +++ b/lelab/runners/hf_cloud.py @@ -12,196 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""HF Jobs runner — runs a training as an HF Jobs job on HuggingFace's GPUs. - -Uses huggingface/lerobot-gpu:latest as the runtime image (lerobot pre-installed). -Tails logs via HfApi.fetch_job_logs and reuses the existing parse_metrics_into -parser since stdout format is identical to a local lerobot run. +"""HF Jobs runner — submits a training to HuggingFace's GPUs via LeRobot's +native remote-training feature (`lerobot-train --job.target=`). + +With the native feature the LOCAL machine runs `lerobot-train --job.target=...`; +that local process submits the job to HF Jobs and streams the remote pod's logs +to its own stdout. So the cloud path is "spawn a local subprocess and tail its +stdout" — the exact machinery LocalJobRunner already provides. This runner just +adds a stdout parser for the HF job id / page URL / model repo that lerobot's +`submit_to_hf` prints, and a stop() that cancels the remote job too. 
""" from __future__ import annotations import contextlib import logging -import netrc -import os +import re +import sys import threading import time from pathlib import Path -from queue import Empty, Queue - -from huggingface_hub import get_token -from huggingface_hub.errors import RepositoryNotFoundError -from ..jobs import LogLine, TrainingMetrics, extract_wandb_run_url, parse_metrics_into +from ..jobs import JobTarget, SubprocessJobRunner, TrainingMetrics from ..train import TrainingRequest, build_training_command -from ..utils.config import with_lelab_tag -from ..utils.hf_auth import cached_whoami, shared_hf_api +from ..utils.hf_auth import shared_hf_api logger = logging.getLogger(__name__) -LEROBOT_IMAGE = "huggingface/lerobot-gpu:latest" - -# Where the trainer writes checkpoints inside the HF Jobs container. The host -# path the registry hands us (under ~/.cache/...) doesn't exist on the remote -# pod, so we ignore it and pin a writable container-local path instead. The -# wrapper reads --output_dir from the trainer argv and uploads checkpoints from -# here to the Hub, so the lelab UI never reads this path directly. -_CONTAINER_OUTPUT_DIR = "/tmp/lelab/train" # nosec B108 — fixed path inside the remote HF Jobs container, not host-local - -# Inlined sidecar uploader for HF Jobs. Spawns the lerobot trainer as a -# subprocess and concurrently uploads new /checkpoints// -# directories to the Hub model repo, so the lelab UI can list them while -# training is in progress. -# -# Sent verbatim as the value of `python -c '...'`. Anything after `--` in -# the command argv is forwarded to the trainer. 
-WRAPPER_SOURCE = r''' -import os, re, sys, threading, subprocess -from pathlib import Path -from huggingface_hub import HfApi - -argv = sys.argv[1:] -if "--" not in argv: - print("[wrapper] missing -- separator", flush=True) - sys.exit(2) -sep = argv.index("--") -trainer_argv = argv[sep + 1:] - - -def _arg(name): - """Return the value of --name=foo or --name foo from trainer_argv.""" - for i, tok in enumerate(trainer_argv): - if tok == name and i + 1 < len(trainer_argv): - return trainer_argv[i + 1] - if tok.startswith(name + "="): - return tok.split("=", 1)[1] - return None - - -output_dir = _arg("--output_dir") -repo_id = _arg("--policy.repo_id") -if not output_dir or not repo_id: - print(f"[wrapper] need --output_dir and --policy.repo_id; got {output_dir} / {repo_id}", flush=True) - sys.exit(2) - -api = HfApi() -# lerobot only calls push_to_hub at the end of training, so the repo doesn't -# exist when our checkpoint watcher fires. Create it up front (idempotent). -try: - api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) - print(f"[wrapper] repo ready: {repo_id}", flush=True) -except Exception as exc: - print(f"[wrapper] create_repo failed: {exc}", flush=True) - -seen = set() -stop_event = threading.Event() - - -def _scan_and_upload(): - root = Path(output_dir) / "checkpoints" - if not root.is_dir(): - return - # Snapshot before iterating so deletions during the walk do not raise. 
- entries = sorted(p for p in root.iterdir() if p.is_dir() and not p.is_symlink()) - for entry in entries: - if not re.fullmatch(r"\d+", entry.name): - continue - config_json = entry / "pretrained_model" / "config.json" - if not config_json.is_file(): - continue - if entry.name in seen: - continue - try: - api.upload_folder( - folder_path=str(entry), - repo_id=repo_id, - path_in_repo=f"checkpoints/{entry.name}", - commit_message=f"checkpoint {entry.name}", - ) - seen.add(entry.name) - print(f"[wrapper] uploaded checkpoint {entry.name}", flush=True) - except Exception as exc: - print(f"[wrapper] upload failed for {entry.name}: {exc}", flush=True) - - -def _watch(): - while not stop_event.is_set(): - try: - _scan_and_upload() - except Exception as exc: - print(f"[wrapper] scan error: {exc}", flush=True) - stop_event.wait(15) - - -watch_thread = threading.Thread(target=_watch, name="ckpt-watcher", daemon=True) -watch_thread.start() - -print(f"[wrapper] launching trainer: {' '.join(trainer_argv)}", flush=True) -proc = subprocess.Popen(list(trainer_argv), env=os.environ.copy()) -try: - rc = proc.wait() -finally: - stop_event.set() - # One final pass picks up any checkpoint saved in the last 15s window. - try: - _scan_and_upload() - except Exception as exc: - print(f"[wrapper] final scan error: {exc}", flush=True) - -print(f"[wrapper] trainer exited with rc={rc}", flush=True) -sys.exit(rc) -''' - -# HF Jobs' platform default timeout has killed legitimate runs that pushed -# the model successfully but were still uploading auxiliary files. 2h covers -# our typical ACT/SmolVLA runs on t4-small with comfortable headroom. -HF_JOB_TIMEOUT = "2h" - -# Cadence at which the status poller hits inspect_job. inspect_job is the -# authoritative source for job liveness; the log stream is best-effort and -# may drop during long runs (NAT eviction, laptop sleep, proxy idle timeout) -# without the job actually ending. -_STATUS_POLL_INTERVAL_S = 5.0 - -# Stages we treat as terminal. 
Allowlist (not "anything except RUNNING") so -# freshly-submitted jobs in transient stages like QUEUED / BUILDING / STARTING -# aren't mistaken for failures before they get a chance to run. +# HF Jobs stages we treat as terminal (job is no longer making progress). _TERMINAL_STAGES = frozenset({"COMPLETED", "CANCELED", "ERROR", "DELETED"}) -# How long _tail_loop waits before reconnecting after a clean stream end -# (gives the status poller a chance to confirm the job is actually terminal, -# so we don't reconnect and re-replay the entire buffered log). -_TAIL_CLEAN_END_WAIT_S = 15.0 - -# How long _tail_loop waits before reconnecting after an SSE exception -# (transient network blip during a long training). -_TAIL_RECONNECT_BACKOFF_S = 5.0 - - -def resolve_wandb_api_key() -> str | None: - """Look up the host's wandb API key for forwarding to a cloud job. - - Checks WANDB_API_KEY first, then falls back to ~/.netrc (where - `wandb login` writes the key under machine api.wandb.ai). Returns None - if neither source has it; the caller decides how to surface that. +# Min seconds between inspect_job calls on the reattach path. The watchdog calls +# is_running()/returncode() at ~1Hz; without throttling that hammers /jobs. +_STAGE_POLL_INTERVAL_S = 5.0 + +# Markers printed by lerobot's submit_to_hf (src/lerobot/jobs/hf.py). Kept in +# sync with the exact f-strings emitted on submission: +# print(f"Job submitted: {job_id}") +# print(f" Job page: {job_url}") +# print(f" Model repo: https://huggingface.co/{repo_id}") +_JOB_ID_RE = re.compile(r"^Job submitted:\s*(\S+)") +_JOB_PAGE_RE = re.compile(r"^\s*Job page:\s*(\S+)") +_MODEL_REPO_RE = re.compile(r"^\s*Model repo:\s*(https://huggingface\.co/\S+)") + + +class HfCloudJobRunner(SubprocessJobRunner): + """Run a training on HF Jobs. Single-shot — instantiate per job. 
+ + Reuses SubprocessJobRunner's spawn/pump/parse pipeline: the tailed + subprocess is the local `lerobot-train --job.target=` process, + whose stdout carries both the remote training logs and the submission + markers we parse for the HF job id / page URL / model repo. + + hf_job_id / hf_job_url / hf_repo_id are discovered ASYNCHRONOUSLY by + parsing that local stdout (lerobot's submit_to_hf prints them a few + seconds after start), so all three return None until the markers appear. + JobRegistry._tick → _sync_cloud_ids polls the getters and persists the + values onto the JobRecord once present. """ - key = os.environ.get("WANDB_API_KEY") - if key: - return key - try: - rc = netrc.netrc() - except (FileNotFoundError, netrc.NetrcParseError, OSError): - return None - auth = rc.authenticators("api.wandb.ai") - if auth is None: - return None - _login, _account, password = auth - return password or None - - -class HfCloudJobRunner: - """Run a training as an HF Jobs job. Single-shot — instantiate per job.""" def __init__( self, @@ -209,286 +77,111 @@ def __init__( log_file_path: Path, flavor: str, ) -> None: - self._metrics = metrics - self._log_file_path = log_file_path + super().__init__(metrics, log_file_path) self._flavor = flavor - # Shared HfApi: its in-process whoami cache covers run_job's - # internal self.whoami(token=...) call too (see utils/hf_auth.py), - # so submitting many jobs doesn't hammer /whoami-v2. - self._api = shared_hf_api() self._hf_job_id: str | None = None self._hf_job_url: str | None = None - self._log_queue: Queue[LogLine] = Queue() - self._tail_thread: threading.Thread | None = None - # _status_thread polls inspect_job and is the sole writer of - # _terminal_status (except for stop(), which pre-sets CANCELED). - # Decoupling liveness from the log stream means a flaky SSE - # connection no longer makes us declare a running job as failed. 
- self._status_thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._log_file = None # type: ignore[assignment] - # Cached terminal status once the job ends; None while live. - self._terminal_status: str | None = None - # Status.message at the terminal tick (e.g. "Job timeout"), so the - # registry can surface it to the UI instead of a synthetic exit code. - self._terminal_message: str | None = None - self._wandb_run_url: str | None = None - # Count of log lines processed across (possibly multiple) SSE - # connections, so reconnects skip the replayed prefix. - self._lines_processed: int = 0 + self._hf_repo_id: str | None = None + # Set on reattach so is_running()/returncode() derive liveness from the + # remote job stage rather than the log stream (which just ends when the + # job is terminal, carrying no exit code). + self._reattached_job_id: str | None = None + self._reattach_thread: threading.Thread | None = None + # 5s-TTL cache over inspect_job for the reattach path. + self._stage_cache: str | None = None + self._stage_fetched_at: float = 0.0 def start(self, job_id: str, config: TrainingRequest, output_dir: str) -> None: - # output_dir is the host-local path the registry pins for local jobs; - # it doesn't exist on the remote pod, so cloud jobs write to a - # container-local path instead (checkpoints reach the UI via the Hub). - del output_dir - if self._hf_job_id is not None: - raise RuntimeError("HfCloudJobRunner already started") - - token = get_token() - if not token: - raise RuntimeError("HF token not found. Run 'hf auth login' before launching cloud jobs.") - - whoami = cached_whoami() - username = whoami.get("name") if whoami else None - if not username: - raise RuntimeError("Could not resolve HF username from whoami()") - - # Open the log file early so dataset-upload progress is recorded - # before the cloud job is submitted. 
- self._log_file_path.parent.mkdir(parents=True, exist_ok=True) - self._log_file = self._log_file_path.open("a", buffering=1) - - # Cloud pods can't see the host's LeRobot cache. If the dataset - # only exists locally, push it to the Hub before submitting. - self._ensure_dataset_on_hub(config.dataset_repo_id) - - # Mutate the config so build_training_command emits the right flags. - # The mutated config is what gets persisted in JobRecord.config, so - # the historical record reflects what actually ran. - config.policy_push_to_hub = True - # job_id is already a unique slug like "act_dataset_2026-05-04_10-22-03". - config.policy_repo_id = f"{username}/{job_id}" - - trainer_argv = build_training_command(config, _CONTAINER_OUTPUT_DIR) - # The wrapper expects `python -c WRAPPER_SOURCE -- `. - # `python -c` consumes the first non-option argument as the script, - # so we prepend a "--" sentinel of our own. - wrapped_command = ["python", "-c", WRAPPER_SOURCE, "--", *trainer_argv] - logger.info( - "Submitting HF Cloud job %s on %s (wrapped trainer): %s", - job_id, - self._flavor, - " ".join(trainer_argv), + # The submission runs LOCALLY (lerobot's submit_to_hf), so the + # subprocess must use lelab's own interpreter — same as LocalJobRunner. + cmd = build_training_command( + config, + output_dir, + sys.executable, + job_target=JobTarget(runner="hf_cloud", flavor=self._flavor), ) - - # HF_TOKEN goes via `secrets` (not `env`) so it doesn't show up in - # the job's environment variable inspection / logs. - secrets = {"HF_TOKEN": token} - if config.wandb_enable: - wandb_key = resolve_wandb_api_key() - if not wandb_key: - # ValueError so main.py maps it to a 400 + detail the UI shows. - raise ValueError( - "WANDB_API_KEY not found on this machine. " - "Run `wandb login` or export WANDB_API_KEY before launching " - "cloud jobs with W&B enabled." 
- ) - secrets["WANDB_API_KEY"] = wandb_key - - job = self._api.run_job( - image=LEROBOT_IMAGE, - command=wrapped_command, - flavor=self._flavor, - secrets=secrets, - timeout=HF_JOB_TIMEOUT, - ) - self._hf_job_id = job.id - self._hf_job_url = getattr(job, "url", None) - - self._start_worker_threads(job_id) + logger.info("Submitting HF Cloud job %s on %s: %s", job_id, self._flavor, " ".join(cmd)) + self._spawn(cmd, thread_name=f"hf-job-{job_id}-logs") def reattach(self, hf_job_id: str) -> None: - """Take over an existing HF job after a process restart. - - Skips submission; just opens the log file in append mode and starts - the log-tailing + status-polling threads. + """Take over an existing HF job after a uvicorn --reload restart. + + The local lerobot-train process that submitted the job is gone, but + the remote job persists. We re-stream its logs via the always-available + Python API HfApi.fetch_job_logs(follow=True) in a daemon thread feeding + the SAME _consume_lines pipeline as the subprocess path — no dependency + on the `hf` CLI being on PATH. A terminal job's follow stream just ends; + liveness/outcome are read from inspect_job in is_running()/returncode(). """ - if self._hf_job_id is not None: - raise RuntimeError("HfCloudJobRunner already started") self._hf_job_id = hf_job_id - self._log_file_path.parent.mkdir(parents=True, exist_ok=True) - self._log_file = self._log_file_path.open("a", buffering=1) - self._start_worker_threads(f"{hf_job_id}-reattach") - - def _start_worker_threads(self, label: str) -> None: - """Start the log tail and status poll threads. 
Both run for the - life of the runner; the status poller is what marks the job terminal.""" - self._tail_thread = threading.Thread(target=self._tail_loop, name=f"hf-job-{label}-logs", daemon=True) - self._tail_thread.start() - self._status_thread = threading.Thread( - target=self._status_poll_loop, name=f"hf-job-{label}-status", daemon=True + self._reattached_job_id = hf_job_id + self._open_log_file() + self._reattach_thread = threading.Thread( + target=self._stream_remote_logs, name=f"hf-job-{hf_job_id}-reattach", daemon=True ) - self._status_thread.start() + self._reattach_thread.start() - def _set_terminal(self, status: str, message: str | None = None) -> None: - """Record the job's terminal stage. Idempotent. Wakes the tail loop.""" - if self._terminal_status is not None: - return - self._terminal_status = status - if message: - self._terminal_message = message - self._stop_event.set() - - def _log_line(self, message: str) -> None: - """Append a wrapper-style line to the job's log file.""" - if self._log_file is None: - return - line = LogLine(timestamp=time.time(), message=message) + def _stream_remote_logs(self) -> None: + """Feed the remote job's followed log stream into _consume_lines. + fetch_job_logs yields text lines and ends when the job is terminal; if + it raises (job already gone / transient error) we log and stop — the + terminal state is still recoverable via inspect_job.""" try: - self._log_file.write(line.model_dump_json() + "\n") + lines = shared_hf_api().fetch_job_logs(job_id=self._reattached_job_id, follow=True) except Exception as exc: - logger.warning("Could not write upload log line: %s", exc) - - def _ensure_dataset_on_hub(self, repo_id: str) -> None: - """If the dataset is local-only, push it to the Hub. - - The cloud pod resolves the dataset by repo_id; it can't see the - host's `~/.cache/huggingface/lerobot`. 
We push synchronously and - let any failure bubble up — JobRegistry.start marks the record - as failed with the exception message. - """ - try: - self._api.dataset_info(repo_id) - return - except RepositoryNotFoundError: - pass - - cache_root = Path(os.environ.get("HF_LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() - if not (cache_root / repo_id / "meta" / "info.json").is_file(): - # Neither local nor on Hub. Let the trainer surface the error - # — same behaviour as before. - return - - self._log_line(f"[upload] dataset {repo_id} not on Hub; pushing local copy...") - from lerobot.datasets import LeRobotDataset - - try: - LeRobotDataset(repo_id).push_to_hub(tags=with_lelab_tag(None), private=False) - except Exception as exc: - msg = f"Failed to upload local dataset {repo_id} to Hub: {exc}" - self._log_line(f"[upload] {msg}") - raise RuntimeError(msg) from exc - self._log_line(f"[upload] dataset {repo_id} uploaded.") - - def _tail_loop(self) -> None: - """Stream HfApi.fetch_job_logs, teeing each line to disk and the - in-memory queue. Reconnects on stream end or transient error while - the status poller still says the job is alive — SSE death is no - longer fatal. Exits when _stop_event is set (status poller saw a - terminal stage, or stop() was called). - """ - assert self._hf_job_id is not None - try: - while not self._stop_event.is_set(): - clean_end = False - try: - seen = 0 - for raw in self._api.fetch_job_logs(job_id=self._hf_job_id, follow=True): - if self._stop_event.is_set(): - return - seen += 1 - # Skip the replayed prefix from a reconnect. 
- if seen <= self._lines_processed: - continue - self._lines_processed = seen - stripped = raw.rstrip() - if not stripped: - continue - parse_metrics_into(stripped, self._metrics) - if self._wandb_run_url is None: - url = extract_wandb_run_url(stripped) - if url is not None: - self._wandb_run_url = url - log_line = LogLine(timestamp=time.time(), message=stripped) - if self._log_file is not None: - try: - self._log_file.write(log_line.model_dump_json() + "\n") - except Exception as exc: # pragma: no cover - logger.exception("Error writing HF log: %s", exc) - if self._log_queue.qsize() >= 1000: - with contextlib.suppress(Empty): - self._log_queue.get_nowait() - self._log_queue.put(log_line) - clean_end = True - except Exception as exc: - logger.info("HF log tail disconnected, will reconnect: %s", exc) - - wait_s = _TAIL_CLEAN_END_WAIT_S if clean_end else _TAIL_RECONNECT_BACKOFF_S - if self._stop_event.wait(wait_s): - return - finally: + logger.warning("fetch_job_logs(%s) failed on reattach: %s", self._reattached_job_id, exc) if self._log_file is not None: with contextlib.suppress(Exception): self._log_file.close() self._log_file = None + return + self._consume_lines(lines) - def _status_poll_loop(self) -> None: - """Poll inspect_job until the job reaches a terminal stage. + def _on_line(self, line: str) -> None: + if self._hf_job_id is None: + m = _JOB_ID_RE.match(line) + if m: + self._hf_job_id = m.group(1) + if self._hf_job_url is None: + m = _JOB_PAGE_RE.match(line) + if m: + self._hf_job_url = m.group(1) + if self._hf_repo_id is None: + m = _MODEL_REPO_RE.match(line) + if m: + # Store the bare repo id; the rest of lelab keys checkpoints on it. + self._hf_repo_id = m.group(1).removeprefix("https://huggingface.co/") - Sole writer of _terminal_status under normal operation. Decoupled - from the log stream: a dropped SSE connection during a long run - (NAT eviction, sleep, proxy idle timeout) no longer causes LeLab - to declare a still-running job as failed. 
- """ - assert self._hf_job_id is not None - while not self._stop_event.is_set(): + def stop(self) -> None: + # Signal the consumer to break, then on the start path kill the local + # lerobot-train subprocess (graceful→force via the shared impl). The + # reattach path has no subprocess; setting the event unblocks + # _consume_lines on its next yield. Either way the local side only + # detaches the log stream, so also cancel the remote job if we have its id. + self._stop_event.set() + super().stop() + if self._hf_job_id is not None: try: - info = self._api.inspect_job(job_id=self._hf_job_id) - status_obj = getattr(info, "status", None) - stage = getattr(status_obj, "stage", None) if status_obj is not None else None - if stage is not None: - stage_str = str(stage).upper() - if stage_str in _TERMINAL_STAGES: - msg = getattr(status_obj, "message", None) - self._set_terminal(stage_str, str(msg) if msg else None) - return + shared_hf_api().cancel_job(job_id=self._hf_job_id) except Exception as exc: - logger.warning("inspect_job poll failed for %s: %s", self._hf_job_id, exc) - if self._stop_event.wait(_STATUS_POLL_INTERVAL_S): - return - - def stop(self) -> None: - if self._hf_job_id is None: - return - # Pre-set CANCELED so the watchdog finalises as canceled regardless - # of whether the status poller observed a terminal stage first. - self._set_terminal("CANCELED") - try: - self._api.cancel_job(job_id=self._hf_job_id) - except Exception as exc: - # Already-completed jobs may 404; that's fine. - logger.info("cancel_job(%s) ignored: %s", self._hf_job_id, exc) + # Already-finished jobs may 404; that's fine. + logger.info("cancel_job(%s) ignored: %s", self._hf_job_id, exc) def is_running(self) -> bool: - # Liveness is driven by _status_poll_loop's inspect_job calls. 
- if self._hf_job_id is None: - return False - return self._terminal_status is None + if self._reattached_job_id is None: + return super().is_running() + # Reattach: the followed log stream ending doesn't mean the run is over, + # and a terminal run may keep it open during finalization. inspect_job + # is authoritative. + return self._remote_stage() not in _TERMINAL_STAGES def returncode(self) -> int | None: - if self._terminal_status is None: + if self._reattached_job_id is None: + return super().returncode() + stage = self._remote_stage() + if stage not in _TERMINAL_STAGES: return None - return 0 if self._terminal_status == "COMPLETED" else 1 - - def stream_log_lines(self) -> list[LogLine]: - out: list[LogLine] = [] - try: - while True: - out.append(self._log_queue.get_nowait()) - except Empty: - pass - return out + return 0 if stage == "COMPLETED" else 1 def hf_job_id(self) -> str | None: return self._hf_job_id @@ -496,14 +189,30 @@ def hf_job_id(self) -> str | None: def hf_job_url(self) -> str | None: return self._hf_job_url - def wandb_run_url(self) -> str | None: - return self._wandb_run_url + def hf_repo_id(self) -> str | None: + return self._hf_repo_id - def terminal_message(self) -> str | None: - """Status.message captured when the job reached a terminal stage. + # -- internals -- - Set by _status_poll_loop when it observes a terminal stage. Used by - the registry watchdog to surface platform reasons like 'Job timeout' - rather than a synthetic 'exit code 1'. - """ - return self._terminal_message + def _remote_stage(self) -> str | None: + """Current HF Jobs stage for the reattached job, upper-cased, or None + if it can't be resolved (transient API error → treated as running). 
+ Cached for _STAGE_POLL_INTERVAL_S so the ~1Hz watchdog doesn't spam + inspect_job.""" + now = time.time() + if now - self._stage_fetched_at < _STAGE_POLL_INTERVAL_S: + return self._stage_cache + self._stage_fetched_at = now + try: + info = shared_hf_api().inspect_job(job_id=self._reattached_job_id) + status_obj = getattr(info, "status", None) + stage = getattr(status_obj, "stage", None) if status_obj is not None else None + # huggingface_hub may give a plain str ("COMPLETED") or a JobStage enum; + # unwrap the enum so `str(...).upper()` yields the bare value, not + # "JOBSTAGE.COMPLETED" (which would never match _TERMINAL_STAGES). + stage = getattr(stage, "value", stage) + self._stage_cache = str(stage).upper() if stage is not None else None + except Exception as exc: + logger.warning("inspect_job poll failed for %s: %s", self._reattached_job_id, exc) + self._stage_cache = None + return self._stage_cache diff --git a/lelab/train.py b/lelab/train.py index 6ee726e..bfdc022 100644 --- a/lelab/train.py +++ b/lelab/train.py @@ -19,9 +19,13 @@ """ import re +from typing import TYPE_CHECKING from pydantic import BaseModel +if TYPE_CHECKING: + from lelab.jobs import JobTarget + _SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+") @@ -44,7 +48,7 @@ class TrainingRequest(BaseModel): # Logging and checkpointing log_freq: int = 250 save_freq: int = 1000 - eval_freq: int = 0 + env_eval_freq: int = 0 save_checkpoint: bool = True # Output configuration @@ -87,7 +91,10 @@ class TrainingRequest(BaseModel): def build_training_command( - request: TrainingRequest, output_dir: str, python_executable: str = "python" + request: TrainingRequest, + output_dir: str, + python_executable: str = "python", + job_target: "JobTarget | None" = None, ) -> list[str]: """Build the argv list to invoke `<python_executable> -m lerobot.scripts.lerobot_train`.
@@ -126,20 +133,28 @@ def build_training_command( if request.policy_device: cmd.extend(["--policy.device", request.policy_device]) cmd.extend(["--policy.use_amp", "true" if request.policy_use_amp else "false"]) + # On HF Cloud, lerobot's submit_to_hf owns the model repo and sets push_to_hub on + # the pod itself; _pod_forwarded_args drops any --policy.push_to_hub/--policy.repo_id + # we'd pass, so we must not emit them. Local runs keep the existing behavior: # LeRobot defaults push_to_hub=True and demands --policy.repo_id when so. - # Local jobs keep it off; HF Cloud jobs flip it on via the runner. - cmd.extend(["--policy.push_to_hub", "true" if request.policy_push_to_hub else "false"]) - if request.policy_push_to_hub and request.policy_repo_id: - cmd.extend(["--policy.repo_id", request.policy_repo_id]) + is_cloud = job_target is not None and job_target.runner == "hf_cloud" + if not is_cloud: + cmd.extend(["--policy.push_to_hub", "true" if request.policy_push_to_hub else "false"]) + if request.policy_push_to_hub and request.policy_repo_id: + cmd.extend(["--policy.repo_id", request.policy_repo_id]) # Logging / checkpointing cmd.extend(["--log_freq", str(request.log_freq)]) cmd.extend(["--save_freq", str(request.save_freq)]) - cmd.extend(["--eval_freq", str(request.eval_freq)]) + cmd.extend(["--env_eval_freq", str(request.env_eval_freq)]) cmd.extend(["--save_checkpoint", "true" if request.save_checkpoint else "false"]) - # Output - cmd.extend(["--output_dir", output_dir]) + # Output. On HF Cloud the pod, not this host, runs the trainer: an absolute host + # output_dir (e.g. ~/.cache/.../outputs/train) is baked into the staged config and + # the pod crashes trying to mkdir it under /Users. Checkpoints land on the Hub repo + # anyway, so we omit it for cloud and let lerobot pick its in-pod default. 
+ if not is_cloud: + cmd.extend(["--output_dir", output_dir]) cmd.extend(["--resume", "true" if request.resume else "false"]) if request.job_name: cmd.extend(["--job_name", request.job_name]) @@ -185,4 +200,18 @@ def build_training_command( if request.config_path: cmd.extend(["--config_path", request.config_path]) + # HF Jobs: --job.target=<flavor> dispatches the run remotely (lerobot commit #3856). + # Image/timeout use lerobot's JobConfig defaults. lelab tags its jobs; lerobot always + # adds a "lerobot" tag too. A pod's local checkpoints die with it, so push each one to + # the model repo's checkpoints/<step>/ tree (the native replacement for lelab's old + # in-pod uploader) — that's what makes the trained checkpoints reachable afterwards. + if is_cloud and job_target.flavor: + cmd.extend(["--job.target", job_target.flavor]) + cmd.extend(["--job.tags", '["lelab"]']) + # save_checkpoint_to_hub needs policy.repo_id, which submit_to_hf only sets on the + # fresh-run path; on a resume it isn't set before validate(), so the flag would + # abort the submit. A resume already pushes back to its source repo, so skip it. + if request.save_checkpoint and not request.resume: + cmd.extend(["--save_checkpoint_to_hub", "true"]) + return cmd diff --git a/pyproject.toml b/pyproject.toml index f94d038..2eefec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "websockets>=15.0.1", "uvicorn>=0.24.0", "psutil>=5.9.0", - "lerobot[core_scripts,feetech,training] @ git+https://github.com/huggingface/lerobot.git@82dffde7fad11cba91f7916b050fbe7d7eea35ab", + "lerobot[core_scripts,feetech,training] @ git+https://github.com/huggingface/lerobot.git@5ac3b49a5fd25d9e570ea1de4e6a81c77d603bd3", # Windows-only: real DirectShow camera names for /available-cameras so the # frontend can match a camera to its browser deviceId (issues #12, #16). # Guarded by try/except at the call site, so its absence degrades gracefully.
diff --git a/tests/test_runners_hf_cloud.py b/tests/test_runners_hf_cloud.py index 9a7d507..035d56a 100644 --- a/tests/test_runners_hf_cloud.py +++ b/tests/test_runners_hf_cloud.py @@ -11,100 +11,211 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for lelab.runners.hf_cloud — covers the host-side wandb credential -resolution path. HfCloudJobRunner itself talks to HF Jobs and is not unit- -testable without a heavy mock of HfApi; we intentionally leave it for -integration tests.""" - -from __future__ import annotations - -import netrc - -import pytest - - -def test_resolve_wandb_api_key_prefers_environment_variable( - monkeypatch: pytest.MonkeyPatch, -) -> None: - from lelab.runners.hf_cloud import resolve_wandb_api_key - - monkeypatch.setenv("WANDB_API_KEY", "env-key-123") - assert resolve_wandb_api_key() == "env-key-123" +"""Tests for lelab.runners.hf_cloud. +With LeRobot's native remote-training feature, HfCloudJobRunner is a thin +subprocess tailer: it runs `lerobot-train --job.target=<flavor>` locally and +parses the submission markers lerobot prints to stdout. The credential / dataset +/ checkpoint-upload work is now lerobot's. The unit-testable surface is the +stdout parser (`_on_line`) — it must stay in lockstep with the exact strings +lerobot's submit_to_hf emits.
Submission against HF Jobs is left to integration +tests.""" -def test_resolve_wandb_api_key_falls_back_to_netrc(monkeypatch: pytest.MonkeyPatch) -> None: - """When WANDB_API_KEY is unset, the function must read the same place - `wandb login` writes — ~/.netrc under machine api.wandb.ai.""" - from lelab.runners.hf_cloud import resolve_wandb_api_key - - monkeypatch.delenv("WANDB_API_KEY", raising=False) - - class _FakeNetrc: - def authenticators(self, host): - assert host == "api.wandb.ai" - return ("login", "account", "netrc-key-456") - - monkeypatch.setattr(netrc, "netrc", lambda: _FakeNetrc()) - assert resolve_wandb_api_key() == "netrc-key-456" - - -def test_resolve_wandb_api_key_returns_none_when_netrc_has_no_wandb_entry( - monkeypatch: pytest.MonkeyPatch, -) -> None: - from lelab.runners.hf_cloud import resolve_wandb_api_key +from __future__ import annotations - monkeypatch.delenv("WANDB_API_KEY", raising=False) +from pathlib import Path +from unittest.mock import MagicMock, patch - class _FakeNetrc: - def authenticators(self, host): - return None +from lelab.jobs import TrainingMetrics +from lelab.runners.hf_cloud import HfCloudJobRunner - monkeypatch.setattr(netrc, "netrc", lambda: _FakeNetrc()) - assert resolve_wandb_api_key() is None +def _runner(tmp_path: Path) -> HfCloudJobRunner: + return HfCloudJobRunner(TrainingMetrics(), tmp_path / "log.jsonl", flavor="t4-small") -def test_resolve_wandb_api_key_returns_none_when_netrc_missing( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """No env var, no ~/.netrc — neither source has it, caller decides.""" - from lelab.runners.hf_cloud import resolve_wandb_api_key - monkeypatch.delenv("WANDB_API_KEY", raising=False) +def _stage_info(stage: str) -> MagicMock: + """A fake huggingface_hub JobInfo with .status.stage.""" + info = MagicMock() + info.status.stage = stage + return info - def _raise_missing(): - raise FileNotFoundError("~/.netrc") - monkeypatch.setattr(netrc, "netrc", _raise_missing) - assert 
resolve_wandb_api_key() is None +def test_on_line_parses_submission_markers(tmp_path: Path) -> None: + """Feed the exact lines lerobot's submit_to_hf prints and assert the + runner picks up the job id, page URL, and (bare) model repo id.""" + runner = _runner(tmp_path) + for line in [ + "Submitting job to HF Jobs (flavor=t4-small, image=huggingface/lerobot-gpu:latest) ...", + "Job submitted: 0123abcd", + " Job page: https://huggingface.co/jobs/me/0123abcd", + " Model repo: https://huggingface.co/me/act_dataset_2026-06-30", + " Monitor: hf jobs logs 0123abcd", + ]: + runner._on_line(line) + assert runner.hf_job_id() == "0123abcd" + assert runner.hf_job_url() == "https://huggingface.co/jobs/me/0123abcd" + assert runner.hf_repo_id() == "me/act_dataset_2026-06-30" -def test_resolve_wandb_api_key_returns_none_when_netrc_parse_fails( - monkeypatch: pytest.MonkeyPatch, -) -> None: - from lelab.runners.hf_cloud import resolve_wandb_api_key - monkeypatch.delenv("WANDB_API_KEY", raising=False) +def test_on_line_ignores_unrelated_lines(tmp_path: Path) -> None: + runner = _runner(tmp_path) + runner._on_line("INFO step:250 loss:0.42 lr:1e-4") + runner._on_line("Training: 1%| | 125/10000 [02:02<2:36:10, 1.05step/s]") - def _raise_parse(): - raise netrc.NetrcParseError("bad netrc", "~/.netrc", 1) + assert runner.hf_job_id() is None + assert runner.hf_job_url() is None + assert runner.hf_repo_id() is None - monkeypatch.setattr(netrc, "netrc", _raise_parse) - assert resolve_wandb_api_key() is None +def test_on_line_keeps_first_job_id(tmp_path: Path) -> None: + """Once parsed, ids are sticky — a later spurious match must not clobber.""" + runner = _runner(tmp_path) + runner._on_line("Job submitted: first") + runner._on_line("Job submitted: second") + assert runner.hf_job_id() == "first" -def test_resolve_wandb_api_key_returns_none_when_password_is_empty( - monkeypatch: pytest.MonkeyPatch, -) -> None: - """An empty password from netrc is treated as missing — the helper - contract is 
'returns the usable key or None', not 'returns whatever - netrc happened to have'.""" - from lelab.runners.hf_cloud import resolve_wandb_api_key - monkeypatch.delenv("WANDB_API_KEY", raising=False) +# -- reattach: re-stream remote logs via the Python API (no `hf` CLI) ---------- - class _FakeNetrc: - def authenticators(self, host): - return ("login", "account", "") - monkeypatch.setattr(netrc, "netrc", lambda: _FakeNetrc()) - assert resolve_wandb_api_key() is None +def test_reattach_streams_remote_logs_through_pipeline(tmp_path: Path) -> None: + """reattach() must re-stream the job's logs via HfApi.fetch_job_logs(follow=True) + — not the `hf` CLI — feeding the same parse/persist/queue pipeline as the + subprocess path: markers parsed, metrics updated, lines queued + persisted.""" + runner = _runner(tmp_path) + remote_lines = [ + "Job submitted: jb_42\n", + " Model repo: https://huggingface.co/me/act_run\n", + "INFO step:250 loss:0.42 lr:1e-4\n", + ] + api = MagicMock() + api.fetch_job_logs.return_value = iter(remote_lines) + + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner.reattach("jb_42") + assert runner._reattach_thread is not None + runner._reattach_thread.join(timeout=5) + + api.fetch_job_logs.assert_called_once_with(job_id="jb_42", follow=True) + assert runner.hf_repo_id() == "me/act_run" # parsed from the streamed marker + assert runner._metrics.current_step == 250 # metrics parsed from the streamed line + messages = [line.message for line in runner.stream_log_lines()] + assert "Job submitted: jb_42" in messages + assert (tmp_path / "log.jsonl").exists() # lines were persisted + + +def test_reattach_survives_fetch_job_logs_error(tmp_path: Path) -> None: + """If the job is already gone, fetch_job_logs raises; reattach must not crash.""" + runner = _runner(tmp_path) + api = MagicMock() + api.fetch_job_logs.side_effect = RuntimeError("job not found") + + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + 
runner.reattach("gone") + assert runner._reattach_thread is not None + runner._reattach_thread.join(timeout=5) + + assert runner.stream_log_lines() == [] # nothing queued, no exception escaped + + +# -- stop(): cancel the remote job too ----------------------------------------- + + +def test_stop_cancels_remote_job_when_id_known(tmp_path: Path) -> None: + runner = _runner(tmp_path) + runner._hf_job_id = "jb_99" + api = MagicMock() + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner.stop() + api.cancel_job.assert_called_once_with(job_id="jb_99") + + +def test_stop_without_id_does_not_cancel(tmp_path: Path) -> None: + """Before the submission marker is parsed there is no id; stop() must not + call cancel_job (killing only the local tail would leave the pod running, + but there is nothing we can cancel yet) and must not raise.""" + runner = _runner(tmp_path) + assert runner.hf_job_id() is None + api = MagicMock() + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner.stop() + api.cancel_job.assert_not_called() + + +def test_stop_ignores_cancel_job_failure(tmp_path: Path) -> None: + """An already-finished job 404s on cancel; stop() must swallow it.""" + runner = _runner(tmp_path) + runner._hf_job_id = "done" + api = MagicMock() + api.cancel_job.side_effect = RuntimeError("404 not found") + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner.stop() # must not raise + + +# -- reattach liveness/returncode derived from the remote stage ---------------- + + +def test_reattach_stage_maps_to_liveness_and_returncode(tmp_path: Path) -> None: + """On the reattach path, inspect_job's stage is authoritative: non-terminal + → running/None; COMPLETED → done/0; ERROR or CANCELED → done/1.""" + runner = _runner(tmp_path) + runner._reattached_job_id = "jb_1" # force the reattach branch + api = MagicMock() + + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner._stage_fetched_at 
= 0.0 + api.inspect_job.return_value = _stage_info("RUNNING") + assert runner.is_running() is True + runner._stage_fetched_at = 0.0 + assert runner.returncode() is None + + runner._stage_fetched_at = 0.0 + api.inspect_job.return_value = _stage_info("COMPLETED") + assert runner.is_running() is False + runner._stage_fetched_at = 0.0 + assert runner.returncode() == 0 + + for bad in ("ERROR", "CANCELED"): + runner._stage_fetched_at = 0.0 + api.inspect_job.return_value = _stage_info(bad) + assert runner.is_running() is False + runner._stage_fetched_at = 0.0 + assert runner.returncode() == 1 + + +def test_reattach_stage_poll_is_throttled(tmp_path: Path) -> None: + """inspect_job is cached for the poll interval so the ~1Hz watchdog calling + is_running()/returncode() repeatedly doesn't hammer the /jobs API.""" + runner = _runner(tmp_path) + runner._reattached_job_id = "jb_1" + api = MagicMock() + api.inspect_job.return_value = _stage_info("RUNNING") + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + runner.is_running() + runner.is_running() + runner.returncode() + api.inspect_job.assert_called_once() # subsequent calls hit the TTL cache + + +def test_reattach_handles_jobstage_enum(tmp_path: Path) -> None: + """huggingface_hub may return status.stage as a JobStage enum rather than a + str. 
_remote_stage must unwrap `.value`; otherwise str(enum).upper() yields + 'JOBSTAGE.COMPLETED', never matches a terminal stage, and the reattached job + is stranded as 'running' forever.""" + from enum import Enum + + class _JobStage(Enum): + COMPLETED = "COMPLETED" + + runner = _runner(tmp_path) + runner._reattached_job_id = "jb_1" + api = MagicMock() + info = MagicMock() + info.status.stage = _JobStage.COMPLETED + api.inspect_job.return_value = info + with patch("lelab.runners.hf_cloud.shared_hf_api", return_value=api): + assert runner.is_running() is False + runner._stage_fetched_at = 0.0 + assert runner.returncode() == 0 diff --git a/tests/test_train.py b/tests/test_train.py index cd2fa12..58fd0d1 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -120,3 +120,58 @@ def test_training_request_validates_required_field() -> None: with pytest.raises(ValidationError): TrainingRequest() # dataset_repo_id is required + + +def test_env_eval_freq_flag() -> None: + from lelab.train import TrainingRequest, build_training_command + + cmd = build_training_command(TrainingRequest(dataset_repo_id="x", env_eval_freq=5000), "/tmp/out") + # LeRobot main renamed eval_freq -> env_eval_freq (top-level flag, underscore form). + assert _arg_value(cmd, "--env_eval_freq") == "5000" + assert "--eval_freq" not in cmd + + +def test_cloud_target_emits_job_flags_and_skips_push_to_hub() -> None: + from lelab.jobs import JobTarget + from lelab.train import TrainingRequest, build_training_command + + # push_to_hub is requested, but for a cloud target lerobot's submit_to_hf owns the + # model repo and _pod_forwarded_args drops --policy.* — so we must NOT emit them. 
+ req = TrainingRequest(dataset_repo_id="x", policy_push_to_hub=True, policy_repo_id="me/x") + cmd = build_training_command( + req, "/tmp/out", job_target=JobTarget(runner="hf_cloud", flavor="a10g-small") + ) + assert _arg_value(cmd, "--job.target") == "a10g-small" + assert _arg_value(cmd, "--job.tags") == '["lelab"]' + assert "--policy.push_to_hub" not in cmd + assert "--policy.repo_id" not in cmd + # An absolute host output_dir would be baked into the staged config and crash the + # pod (mkdir /Users ...); checkpoints go to the Hub repo, so it must be omitted. + assert "--output_dir" not in cmd + # Pod checkpoints are ephemeral, so they must be pushed to the Hub to be reachable. + assert _arg_value(cmd, "--save_checkpoint_to_hub") == "true" + + +def test_cloud_resume_omits_save_checkpoint_to_hub() -> None: + from lelab.jobs import JobTarget + from lelab.train import TrainingRequest, build_training_command + + # On a cloud resume, submit_to_hf never sets policy.repo_id before validate(), so + # --save_checkpoint_to_hub would abort the submit. It must be suppressed. + req = TrainingRequest(dataset_repo_id="x", save_checkpoint=True, resume=True) + cmd = build_training_command( + req, "/tmp/out", job_target=JobTarget(runner="hf_cloud", flavor="a10g-small") + ) + assert "--save_checkpoint_to_hub" not in cmd + assert _arg_value(cmd, "--job.target") == "a10g-small" + + +def test_local_target_keeps_push_to_hub() -> None: + from lelab.jobs import JobTarget + from lelab.train import TrainingRequest, build_training_command + + req = TrainingRequest(dataset_repo_id="x", policy_push_to_hub=True, policy_repo_id="me/x") + cmd = build_training_command(req, "/tmp/out", job_target=JobTarget(runner="local")) + assert _arg_value(cmd, "--policy.push_to_hub") == "true" + assert _arg_value(cmd, "--policy.repo_id") == "me/x" + assert "--job.target" not in cmd