diff --git a/docs/ddp_design.md b/docs/ddp_design.md new file mode 100644 index 00000000..5db2b8b8 --- /dev/null +++ b/docs/ddp_design.md @@ -0,0 +1,67 @@ +# WeightsLab DDP — design + +## Two spaces + +The runtime is split, like kernel/user-space: + +- **train-space** — user code: the training loop (`next(loader); preds = model(batch); loss(preds, batch); [loss.backward();] optimizer.step()`). +- **sdk-space** — WL wrappers embedded at well-known call sites (loss, metric, optimizer, dataloader, training guard). + +All WL synchronisation lives in sdk-space, so train-space stays unmodified across single-process and DDP. + +## SPMD with one privileged rank + +Every rank runs the same script. **Only rank-0** binds the gRPC port; UI/CLI commands enter the system there. There is no IPC to non-rank-0 — sync to other ranks goes exclusively through `torch.distributed` (broadcast / gather / all_reduce). + +Data-loader workers stay simple: they only decode the indices the (main-process) sampler hands them — no collectives, no gRPC. The deny-list never reaches workers: a discarded sample is simply never yielded to them (see below). + +## Two kinds of synchronisation + +1. **Gradient reduction** — handled by `torch.distributed` (all_reduce around `optimizer.step()`); data-loader workers re-converge at each `batch_collate`. Off-the-shelf, untouched by WL. +2. **Async UI state** — the hard part. UI events land on rank-0 at arbitrary times, but only rank-0 sees them, and we've ruled out non-collective IPC. **This is what WL adds.** + +## The transactional unit + +Each loop iteration is a transaction: + +```python +batch = next(loader) +preds = model(batch) +loss(preds, batch) # + per-sample metrics +[loss.backward();] optimizer.step() +``` + +No async UI change propagates **mid-iteration**. Consistency is enforced exactly at the **train-space → sdk-space transition** — the first instruction WL controls each iteration (`guard_training_context.__enter__`). 
Every rank agrees on the consistent state before the loop body runs. + +## What needs to be consistent — and which way it flows + +| State | Direction | Why | +|----------------------------------------------------------------|-----------------|----------------------------------------------------| +| hyperparams, `pause_at_step`, `paused` | rank-0 → rank-1+| UI authors them on rank-0; ranks read-only | +| dataframe `DOWN_ONLY` (`discarded` — the deny-list) | rank-0 → rank-1+| UI mutates, ranks consume to derive the same shard | +| per-sample signals, loss/metric scalars, `last_seen` writes | rank-1+ → rank-0| each rank trains its shard; rank-0 holds the global view | + +Rank-0 is the **single source of truth**; rank-1+ hold reconciled copies sufficient for their shard. + +## Mechanism, by direction + +**DOWN — one broadcast, all consistent states.** Rank-0 builds a snapshot of every registered consistent state and broadcasts it; children diff-apply. One collective per step regardless of how many states are registered. +→ API: `register_consistent_state(name, snapshot, apply)` + `reconcile_all()`. + +**UP — one gather, all per-sample writes.** Rank-1+ stages call-time parameters (e.g. `metric.update(sid, value)`) into a local **outbox**, never touching its own dataframe. The anchor gathers the lot once per step; rank-0 then re-issues the "consolidated call" with everyone's parameters as if it ran once globally. From the caller's view it's a normal function call; under DDP it accumulates locally and is re-issued on rank-0. +→ API: `register_outbox(name, local_dump, merge)` + `flush_outbox()`. + +Each outbox dumps a **delta**, not a full snapshot — only what changed on this rank since the last flush (changed dataframe rows; signal triples past a per-`(graph, exp_hash)` cursor). 
Otherwise the per-step gather carries the whole dataframe + whole signal history every step, so payload scales with `N_samples × world` and grows unboundedly — the budget below caps the *count* of collectives, not their *bytes*, so the delta is what keeps the bytes bounded too. The delta cache is process-local (each rank ships its own delta); on respawn/restore it resets to a one-time full resend, which is safe because every `merge` is idempotent. Delta merges seed rank-0's current value first so `MAX`/`UNION` never regress and `LATEST` still resolves to the newest write. + +**Deny-list enforcement — sampler-side, no extra channel.** The `discarded` column gates *which* samples train, and it's enforced entirely in the main-process sampler: a discarded sample is never yielded, so workers never receive it. The sampler's pandas deny-list cache refreshes whenever the origin's deny-list revision bumps (a discard bumps it), so a live discard is reflected within one index. A sample already handed to a worker's prefetch queue is dropped by **iterator invalidation**: when a `DOWN_ONLY` value actually changes, `dataframe_manager` flags every loader, and the next step rebuilds the iterator. With `persistent_workers=True` that rebuild **reuses** the worker processes (PyTorch's re-iter resets them and drops the stale prefetch), so a since-discarded queued sample never reaches the model, *and no fork+dataset-reinit is paid per discard*. The invalidation is gated on an *actual* value diff, which is essential under DDP: rank-1+ re-apply the same reconciled deny-list every step and must not rebuild the iterator each step. The DOWN reconcile ships a **delta** (only the sample-ids whose `discarded` changed since last step, full snapshot once on first reconcile / post-restore), so the broadcast is O(changed), not O(N). `discarded` is the *only* `DOWN_ONLY` column — tags are rank-0 UI state and never need to reach a rank's sampler. 
+ +**Sharding — rebalance, not reshuffle.** Each rank's shard is the live set re-balanced across ranks: filter the fixed permutation (a pure function of `(ddp_seed, reshuffle_seq)`) to the non-discarded indices, pad to a multiple of `world`, then take a strided slice — `live[rank::world]` (`_ddp_rebalanced_shard`). Striding the *live* list spreads survivors evenly, so every rank's shard is **equal length** → identical batch count → the grad `all_reduce` can never deadlock waiting on a rank whose entire shard was discarded (the empty-shard-starvation case). This is a *rebalance*, not a *reshuffle*: the permutation is unchanged and each rank's relative order is preserved, so a discard/undiscard just re-derives the same permutation over the new live set — deterministic and reproducible across resets. A discard or undiscard rebuilds the iterator (above), so the new balance takes effect immediately, including when the live set **grows** (un-discard). `drop_last=False` under DDP keeps the final partial batch so a tiny live set still trains (progress) rather than dropping to zero. Cost: at most `world-1` padded duplicate encounters per pass — honest extra training events with a distinct `model_age`, not pollution. (Trade-off: a sample's owning rank shifts as the live set changes, so this is incompatible with pinned per-sample ownership; the UP outbox reconverges per-sample writes regardless of owner, so correctness is unaffected.) + +## Anchor + budget + +The anchor is split across the step's pre/post hooks, so each direction fires at its natural moment: + +1. `guard_training_context.__enter__` → `sync_step()` — the **DOWN** half: `reconcile_all()`, 1 broadcast of every consistent state, *before* the body consumes it (+ the collective pause spin). +2. `guard_training_context.__exit__` → `flush_outbox()` — the **UP** half: 1 gather of every per-sample write **delta**, at the step's *end*, so this step's writes publish with no one-step lag. 
Run unconditionally (even if the body raised) so every rank reaches the gather the same number of times — skipping it on one rank would desync the group. + +**Collective budget: ~2 rendezvous/step (+ grad all_reduce).** Everything else in WL stays local — read the reconciled value, stage to the outbox, log. A collective leaking into a hot path is a regression: `WL_DDP_LOG=1` traces who-did-what; `WL_DDP_COLLECTIVE_LOG=` records per-step counts so the invariant can be asserted in tests (`scenario_collective_budget`). The budget governs collective *count*; the outbox delta (above) is what keeps each collective's *payload* bounded by the per-step change set rather than the dataset size. diff --git a/weightslab/backend/dataloader_interface.py b/weightslab/backend/dataloader_interface.py index d9b16e94..e5699119 100644 --- a/weightslab/backend/dataloader_interface.py +++ b/weightslab/backend/dataloader_interface.py @@ -24,7 +24,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler from weightslab.data.data_samples_with_ops import DataSampleTrackingWrapper -from weightslab.utils import filter_kwargs_for_callable, restore_rng_state +from weightslab.utils import filter_kwargs_for_callable, restore_rng_state, ddp_info from weightslab.backend.ledgers import ( register_dataloader, get_hyperparams, @@ -134,6 +134,15 @@ def __init__( # Evaluation-mode allow-list: when set, only samples whose uid is in # this set are yielded. None = no filter (normal behaviour). self._eval_allow_list: Optional[set] = None + # DDP sharding: per-rank shard = rebalanced live set (see _ddp_rebalanced_shard). + # (0, 1) when not under DDP -> single-process path unchanged. + self._ddp_rank, self._ddp_world_size = ddp_info() + self._ddp_seed = int(os.environ.get("WL_DDP_SEED", "0")) + # Reshuffle generation: shard permutation = f(ddp_seed, reshuffle_seq, rank, + # world). 
Advances only on a genuine pass-end, never on a discard's iterator + # reset (a discard re-filters the SAME permutation, no reshuffle). Restore + # reproduces the stream from (seed, seq, offset, checkpointed deny-list). + self._reshuffle_seq = 0 def _get_deny_listed_uids(self, origin: str = None) -> set: """Get set of deny-listed UIDs from tracked dataset.""" @@ -220,18 +229,89 @@ def _is_deny_listed(self, idx: int) -> bool: return uid in self._refresh_deny_list_cache() - def _generate_indices(self): - """Generate base indices (shuffled or sequential).""" + def set_epoch(self, epoch: int) -> None: + """Back-compat alias for set_reshuffle_seq (mirrors DistributedSampler's + set_epoch name). Sets the reshuffle generation directly.""" + self._reshuffle_seq = int(epoch) + + def advance_reshuffle(self) -> None: + """Bump the reshuffle generation so the NEXT generated pass uses a new + permutation. Called by the loader ONLY on a genuine epoch-completion + reset — never on a discard/tag invalidation reset (which must keep the + same permutation and merely re-filter).""" + self._reshuffle_seq += 1 + + def restore_reshuffle_seq(self, seq: int, seed=None) -> None: + """Restore the reshuffle generation from a checkpoint so the next pass + reproduces the exact per-rank permutation that was live at save time. + No-op outside DDP. Warns on a seed mismatch — the permutation is a + function of (seed, reshuffle_seq), so a different seed can't reproduce.""" + if self._ddp_world_size <= 1: + return + if seed is not None and int(seed) != int(self._ddp_seed): + logger.warning( + "[ddp] restore seed mismatch (saved=%s current=%s); per-rank " + "shard order will NOT reproduce the saved run.", seed, self._ddp_seed) + self._reshuffle_seq = int(seq) + + def _ddp_rebalanced_shard(self): + """This rank's live shard: filter the fixed permutation (f(ddp_seed, + reshuffle_seq)) to the non-discarded set, pad to a multiple of world, then + stride `live[rank::world]`. 
Equal-length shards by construction -> matched + batch counts -> no empty-shard grad-allreduce deadlock. Order-preserving, + deterministic, non-advancing (so __len__/snapshots may call it freely). + + TRAINING ONLY: this shards eval loaders too, which is wrong under DDP (the + per-step anchor is training-only, so a sharded eval undercounts). Resolve + the eval sharding/aggregation policy before adding a DDP eval loop. + TODO(ddp-eval). + """ n = len(self.data_source) if self.shuffle: - indices = torch.randperm(n).tolist() + g = torch.Generator() + g.manual_seed(int(self._ddp_seed) + int(self._reshuffle_seq)) + full = torch.randperm(n, generator=g).tolist() else: - indices = list(range(n)) - return indices + full = list(range(n)) + live = list(self._iter_filtered_indices(full)) + world = int(self._ddp_world_size) + if live: # cyclic pad so len % world == 0 even when the pad exceeds len(live) + live = live + [live[i % len(live)] for i in range((-len(live)) % world)] + return live[int(self._ddp_rank)::world] + + def _rank_indices_snapshot(self): + """This rank's indices for the current reshuffle generation, WITHOUT + advancing it. Used by __len__ and yolo_pipeline's ownership snapshot.""" + if self._ddp_world_size > 1: + return self._ddp_rebalanced_shard() + return list(range(len(self.data_source))) + + def _generate_indices(self): + """Base indices for a pass. Under DDP this is the rank's filtered + + balanced shard (see _ddp_rebalanced_shard); it does NOT auto-advance the + reshuffle generation (advance_reshuffle, called only on a genuine pass + end, does), so a mid-loop discard rebuilds onto the SAME permutation + re-balanced over the new live set, never a reshuffle.""" + if self._ddp_world_size > 1: + return self._ddp_rebalanced_shard() + n = len(self.data_source) + if self.shuffle: + return torch.randperm(n).tolist() + return list(range(n)) def _iter_filtered_indices(self, indices): - """Yield indices lazily so new discards are respected mid-epoch.""" + """Yield indices lazily so new discards are respected mid-epoch. 
+ + The deny-list is enforced here, in the main-process sampler: a discarded + sample is simply never yielded. The pandas deny-list cache is refreshed + whenever the origin's deny-list revision bumps (a discard bumps it), so a + live discard is reflected within one index. Samples already yielded into a + worker prefetch queue are dropped separately by iterator invalidation + (dataframe_manager → loader._invalidate_iter → worker teardown). + """ skipped = 0 + _shard_dbg = os.environ.get("WL_DDP_SHARD_DEBUG") == "1" + _yielded = 0 unique_ids = getattr(self.tracked_dataset, "unique_ids", None) # Prefer physical_uids: after grouped indexing __len__ returns # len(physical_uids) so idx is a physical index; unique_ids is the @@ -276,58 +356,61 @@ def _iter_filtered_indices(self, indices): skipped += 1 continue + _yielded += 1 yield idx + if _shard_dbg: + print(f"[shard_dbg r{self._ddp_rank}/{self._ddp_world_size}] epoch " + f"shard_in={len(indices)} -> yielded={_yielded} " + f"deny={len(deny_listed_uids)}", flush=True) + def __iter__(self): """Iterate over indices or batches of indices.""" indices = self._generate_indices() - filtered_indices = self._iter_filtered_indices(indices) + if self._ddp_world_size > 1: + # Already this rank's filtered+balanced shard; iterate directly (re-filtering + # could shrink it mid-pass and desync batch counts — a discard rebuilds instead). + idx_source = iter(indices) + else: + idx_source = self._iter_filtered_indices(indices) if self.batch_size is None: - yield from filtered_indices + yield from idx_source else: batch = [] - for idx in filtered_indices: + for idx in idx_source: batch.append(idx) if len(batch) >= int(self.batch_size): yield list(batch) batch = [] - if batch and not self.drop_last: + # Under DDP always emit the final partial: shards are equal length so it's the + # same size on every rank, and dropping it would stall a tiny live set. 
+ if batch and (self._ddp_world_size > 1 or not self.drop_last): yield list(batch) def __len__(self): """Return the number of samples or batches.""" - # In evaluation mode with an allow-list, compute the exact filtered - # cardinality so progress/timeout logic uses the real bounded set size. - if self._eval_allow_list is not None: - total = sum(1 for _ in self._iter_filtered_indices(list(range(len(self.data_source))))) - - if self.batch_size is not None: - b = max(1, int(self.batch_size)) - if self.drop_last: - return total // b - return (total + b - 1) // b - - return total - - # Start with total dataset size - total = len(self.data_source) - - # Subtract deny-listed samples - deny_listed_uids = self._refresh_deny_list_cache() - total -= len(deny_listed_uids) - - # Subtract offset - total = max(0, total - self.offset) + # When a filter bounds the set (eval allow-list or a DDP shard), count the + # exact filtered cardinality over a non-advancing snapshot. + if self._ddp_world_size > 1: + # already filtered+balanced; count directly (re-filtering would double-skip) + total = len(self._rank_indices_snapshot()) + elif self._eval_allow_list is not None: + total = sum(1 for _ in self._iter_filtered_indices(self._rank_indices_snapshot())) + else: + # Start with total dataset size, subtract deny-listed and offset. + total = len(self.data_source) + total -= len(self._refresh_deny_list_cache()) + total = max(0, total - self.offset) - # If batching, return number of batches + # If batching, return number of batches. Under DDP we always keep the + # final partial batch (see __iter__), so count it with ceil there too. 
if self.batch_size is not None: b = max(1, int(self.batch_size)) - if self.drop_last: + if self.drop_last and self._ddp_world_size <= 1: return total // b - else: - return (total + b - 1) // b + return (total + b - 1) // b return total @@ -509,6 +592,12 @@ def __init__( ) num_workers = _resolve_safe_num_workers(self.tracked_dataset, num_workers, loader_name) + # persistent_workers: reuse workers across iterator resets so a + # discard/undiscard rebalance is a cheap re-iter (drains stale prefetch), + # not a fork. Safe because the deny-list+rebalance live in the main-process + # sampler — workers just fetch by index. Requires num_workers > 0. + persistent_workers = num_workers > 0 and bool(kwargs.pop("persistent_workers", True)) + # Finally, construct dataloader using our batch_sampler self.dataloader = DataLoader( self.tracked_dataset, @@ -529,6 +618,7 @@ def __init__( "drop_last": drop_last, "pin_memory": pin_memory, "collate_fn": collate_fn, + "persistent_workers": persistent_workers, } self._dl_build_kwargs.update(kwargs or {}) @@ -550,6 +640,12 @@ def __init__( self._samples_yielded: int = 0 self._sample_offset: int = 0 self._skipped = [] + # Flag set by dataframe_manager.upsert_df whenever a DOWN_ONLY column changes + # (UI discard / tag). The next __next__ call resets the iterator BEFORE + # pulling — workers + their prefetched-but-not-yet-consumed batches are + # shut down. Without this, a sample yielded by the sampler PRE-discard + # but still sitting in a worker's queue gets trained on POST-discard. + self._iter_invalidated: bool = False # Optionally register in the global ledger for cross-thread access. # If no explicit `loader_name` is provided, try to infer a friendly loader_name from @@ -844,9 +940,26 @@ def __next__(self) -> Any: """ self._sync_batch_size_from_ledger() - # If the previous epoch ended, reset for the next one + # DOWN_ONLY-change invalidation: a UI discard / tag wrote a column that + # affects WHICH samples are valid to train on. 
Any indices already + # yielded by the sampler and sitting in worker prefetch queues are now + # stale — they were chosen against the OLD deny-list. Reset first so + # the model never trains on a since-discarded sample. (Real fix, not a + # post-hoc filter: the forward pass on the discarded sample never runs.) + if getattr(self, '_iter_invalidated', False): + self._iter_invalidated = False + logger.debug("[DataLoaderInterface] iter invalidated by DOWN_ONLY change; resetting workers") + self._reset_iterator() + + # If the previous epoch ended, reset for the next one. THIS is the only + # place the DDP shard reshuffles — a genuine pass completion advances the + # reshuffle generation so the next pass gets a new permutation. (The + # invalidation reset above deliberately does NOT advance it.) if getattr(self, '_epoch_exhausted', False): logger.debug("Auto-resetting iterator for next epoch") + s = getattr(self, '_mutable_batch_sampler', None) + if s is not None and hasattr(s, 'advance_reshuffle'): + s.advance_reshuffle() self._reset_iterator() self._epoch_exhausted = False @@ -858,6 +971,17 @@ def __next__(self) -> Any: self._epoch_exhausted = True raise + def _invalidate_iter(self) -> None: + """Mark the active iterator as stale. Called from dataframe_manager + whenever a DOWN_ONLY column changes (the sampler-time filter only + protects FUTURE yields; this drops the prefetched queue too). 
+ + Safe to call repeatedly: only flips a flag — the actual worker shutdown + happens inside the next __next__ call's _reset_iterator() (via the + existing del + gc.collect() path which the iter destructor uses to kill + worker subprocesses cleanly).""" + self._iter_invalidated = True + # ------------------------------------------------------------------------- # Ledger / pause helpers # ------------------------------------------------------------------------- @@ -1027,8 +1151,14 @@ def _reset_iterator(self) -> None: if respawning: time.sleep(0.01) # 10ms delay for worker cleanup - # Create new iterator + # Create new iterator (persistent: reuses workers + resets; else: respawns) self._iterator = iter(self.dataloader) + # Clear the invalidate flag — this reset just satisfied it. Without this, + # the next __next__ would do a SECOND reset (load_state calls reset_iterator + # explicitly, AND our upsert hook may have set _iter_invalidated during the + # snapshot apply — so without clearing here, we'd shut down + restart workers + # twice back-to-back. Costly under num_workers>0). + self._iter_invalidated = False logger.debug(f"Created new iterator (num_workers={getattr(self.dataloader, 'num_workers', 'unknown')}, sampler_len={len(self._mutable_batch_sampler) if self._mutable_batch_sampler else 'N/A'})") def reset_iterator(self) -> None: @@ -1056,10 +1186,21 @@ def capture_iteration_state(self) -> dict: boundary. Works with and without shuffling. When shuffling, ensure RNG state is also captured/restored before calling `restore_iteration_state`. """ - return { + state = { "samples_yielded": int(self._samples_yielded), "batch_size": self.batch_size or 1 } + # DDP: save the reshuffle generation + seed so restore reproduces the + # exact per-rank permutation (DistributedSampler shuffles from (seed, + # reshuffle_seq), which global RNG capture/restore does NOT cover). 
The + # deny-list that filters the permutation is checkpointed separately (a + # DOWN_ONLY df column), so (seed, reshuffle_seq, samples_yielded, + # deny-list) together reproduce the filtered stream across a reset. + s = getattr(self, "_mutable_batch_sampler", None) + if s is not None and getattr(s, "_ddp_world_size", 1) > 1: + state["ddp_reshuffle_seq"] = int(getattr(s, "_reshuffle_seq", 0)) + state["ddp_seed"] = int(getattr(s, "_ddp_seed", 0)) + return state def restore_iteration_state(self, state: dict) -> None: """Restore iteration position efficiently without reprocessing skipped data. @@ -1080,6 +1221,15 @@ def restore_iteration_state(self, state: dict) -> None: # Calculate sample offset (how many individual samples to skip) sample_offset = samples_yielded + # DDP: restore the reshuffle generation onto the LIVE sampler now, so even + # the no-rebuild path (offset == 0, pass boundary) reproduces the saved + # per-rank permutation. The rebuild branch re-applies it to its new sampler. + ddp_seq = state.get("ddp_reshuffle_seq") + if ddp_seq is not None: + live = getattr(self, "_mutable_batch_sampler", None) + if live is not None and hasattr(live, "restore_reshuffle_seq"): + live.restore_reshuffle_seq(ddp_seq, state.get("ddp_seed")) + # If we own the dataloader construction, rebuild with offset sampler if getattr(self, "_dl_build_kwargs", None) is not None and sample_offset > 0: try: @@ -1104,6 +1254,9 @@ def restore_iteration_state(self, state: dict) -> None: batch_size=batch_size, drop_last=drop_last, ) + # Carry the restored reshuffle generation onto the new sampler. 
+ if ddp_seq is not None and hasattr(sampler, "restore_reshuffle_seq"): + sampler.restore_reshuffle_seq(ddp_seq, state.get("ddp_seed")) self._mutable_batch_sampler = sampler self._sample_offset = 0 num_workers = _resolve_safe_num_workers( @@ -1111,6 +1264,9 @@ def restore_iteration_state(self, state: dict) -> None: num_workers, getattr(self, "_ledger_name", None), ) + # Keep persistent_workers consistent with the resolved worker count + # (persistent_workers=True requires num_workers>0, else DataLoader raises). + kwargs["persistent_workers"] = num_workers > 0 and bool(kwargs.get("persistent_workers", False)) # Rebuild dataloader with offset sampler self.dataloader = DataLoader( @@ -1187,6 +1343,9 @@ def set_batch_size(self, new_batch_size: int) -> None: num_workers, getattr(self, "_ledger_name", None), ) + # Keep persistent_workers consistent with the resolved worker count + # (persistent_workers=True requires num_workers>0, else DataLoader raises). + kwargs["persistent_workers"] = num_workers > 0 and bool(kwargs.get("persistent_workers", False)) # Rebuild sampler & dataloader if we had one if getattr(self, "_mutable_batch_sampler", None) is not None: diff --git a/weightslab/backend/logger.py b/weightslab/backend/logger.py index dc077e6b..f8765229 100644 --- a/weightslab/backend/logger.py +++ b/weightslab/backend/logger.py @@ -699,6 +699,33 @@ def get_signal_history_per_sample(self): }) return result + def ingest_per_sample(self, graph_name, exp_hash, triples): + """Merge external per-sample (sample_id, step, value) triples into the + per-sample history. Idempotent by (sample_id, step) — re-ingesting the same + triples is a no-op. Used to fold per-sample signals (e.g. 
loss) computed on + OTHER DDP ranks into rank 0's logger so Break-By-Slice plots cover the whole + universe, not just rank 0's shard.""" + if not triples: + return + self.graph_names.add(graph_name) + self._signal_history_per_sample.setdefault(graph_name, {}) + if exp_hash not in self._signal_history_per_sample[graph_name]: + self._signal_history_per_sample[graph_name][exp_hash] = _make_per_sample_buf() + buf = self._signal_history_per_sample[graph_name][exp_hash] + idx_map = self._sample_index.setdefault(graph_name, {}).setdefault(exp_hash, {}) + seen = set(zip(buf["sample_ids"], buf["steps"])) + for sid, step, val in triples: + sid_s = str(sid) + key = (sid_s, int(step)) + if key in seen: + continue + row = len(buf["sample_ids"]) + buf["sample_ids"].append(sid_s) + buf["steps"].append(int(step)) + buf["values"].append(float(val)) + idx_map.setdefault(sid_s, []).append(row) + seen.add(key) + def get_current_signaL_history_per_sample(self, graph_name: str, sample_ids: list = None, exp_hash: str = None): """Get current-hash per-sample history for a specific signal.""" if graph_name not in self.graph_names: diff --git a/weightslab/components/checkpoint_manager.py b/weightslab/components/checkpoint_manager.py index 6dfca3c8..67886890 100644 --- a/weightslab/components/checkpoint_manager.py +++ b/weightslab/components/checkpoint_manager.py @@ -420,8 +420,13 @@ def get_current_experiment_hash(self) -> Optional[str]: return self.current_exp_hash + # Experiment-STATE keys: runtime state, not config. Excluded from the saved HP + # snapshot so a restore's register_hyperparams(saved_config) can't resurrect them + # (e.g. overwrite the live is_training). Same set excluded from the experiment hash. 
+ _STATE_ONLY_HP = ("is_training", "pause_at_step", "root_log_dir") + def get_HP_snapshot(self) -> Dict[str, Any]: - """Get current hyperparameters snapshot from ledger.""" + """Get current hyperparameters snapshot from ledger (excluding experiment state).""" try: hp_name = ledgers.resolve_hp_name() hp = ledgers.get_hyperparams(hp_name) @@ -430,11 +435,14 @@ def get_HP_snapshot(self) -> Dict[str, Any]: if isinstance(hp, ledgers.Proxy) and hasattr(hp, 'get') and callable(hp.get): hp = hp.get() if isinstance(hp, dict): - return hp + snap = dict(hp) elif hasattr(hp, '__dict__'): - return vars(hp) + snap = dict(vars(hp)) else: return {} + for k in self._STATE_ONLY_HP: # copy first, then strip — never mutate the ledger + snap.pop(k, None) + return snap except Exception: return {} @@ -663,10 +671,17 @@ def update_experiment_hash( rng_state = capture_rng_state() restore_rng_state(rng_state) - # Reset dataloader iterators to sync with new state + # Reset dataloader iterators to sync with new state. + # Prefer the lazy invalidate flag (_invalidate_iter): it defers the reset to the + # loader's next __next__, so workers are never torn down from a non-owning thread. 
for loader_name in get_dataloaders(): loader = get_dataloader(loader_name) - if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): + if loader is None: + continue + inv = getattr(loader, '_invalidate_iter', None) + if callable(inv): + inv() + elif hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): loader.reset_iterator() logger.debug(f"Reset iterator for dataloader: {loader_name}") except Exception as e: diff --git a/weightslab/components/global_monitoring.py b/weightslab/components/global_monitoring.py index 4dcfa184..bda2b981 100644 --- a/weightslab/components/global_monitoring.py +++ b/weightslab/components/global_monitoring.py @@ -166,6 +166,12 @@ def resume(self, force: bool = False) -> bool: def is_paused(self): return not self._event.is_set() + def wait_for_resume(self, timeout=None): + """Block until resumed (the resume Event is set), waking the instant the + gRPC resume handler fires. Returns True if resumed, False on timeout. Lets + the DDP pause-anchor wait on the resume signal instead of busy-polling.""" + return self._event.wait(timeout) + def _get_checkpoint_manager(self): if self.checkpoint_manager is None: self.checkpoint_manager = get_checkpoint_manager() @@ -218,11 +224,40 @@ def __enter__(self, f: bool = False): """ Executed upon entering the 'with' block. Sets the model to training mode. """ - self._maybe_pause_at_step() - if not is_in_evaluation(): - if f: - pause_controller.resume(force=f) - pause_controller.wait_if_paused() + # Per-step anchor + pause control plane. + # single-process: rank-0-only pause tick + eval-aware blocking wait + # (1.2.3 behaviour preserved). + # DDP (world > 1): rank-0 ticks pause state; every rank then enters + # sync_step — ONE bundled broadcast of every consistent state + # (hparams + deny-list + paused) followed by a collective spin if + # paused. No rank ever blocks alone (that would deadlock the grad + # all-reduce). 
Core states auto-register on first call so train.py + # never sees the DDP plumbing. + in_ddp = False + if self.for_training: + try: + from weightslab.utils import ddp_info + in_ddp = ddp_info()[1] > 1 + if in_ddp and ddp_info()[0] == 0: + self._maybe_pause_at_step() # rank 0 is the pause authority + except Exception as exc: + logger.debug("[GuardContext] ddp probe failed: %s", exc) + + # __exit__ runs the UP outbox flush (END-of-step half of the anchor); record mode. + self._in_ddp = in_ddp + if in_ddp: + from weightslab.components.parallel_primitives import ( + _ensure_core_ddp_registered, sync_step, + ) + _ensure_core_ddp_registered() # idempotent; no-op after first + sync_step() # DOWN reconcile (+ collective pause spin) + else: + # single-process: original eval-aware blocking pause + self._maybe_pause_at_step() + if not is_in_evaluation(): + if f: + pause_controller.resume(force=f) + pause_controller.wait_if_paused() self.architecture_guard.__enter__() # Set the current context for this execution @@ -281,6 +316,16 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any, f: bool = Fals Executed upon exiting the 'with' block (after user code runs). Reverts the model state. """ + # UP half of the per-step anchor: gather this step's per-sample write + # deltas (rank-1+ → rank-0). UNCONDITIONAL even if the body raised — every + # rank runs __exit__, so all reach this collective the same number of times; + # skipping it on one rank would desync/hang the group. No-op outside DDP. 
+ if getattr(self, "_in_ddp", False): + try: + from weightslab.components.parallel_primitives import flush_outbox + flush_outbox() # UP: 1 gather, all write deltas + except Exception as exc: + logger.debug("[GuardContext] outbox flush failed: %s", exc) if f: pause_controller.pause() diff --git a/weightslab/components/parallel_state.py b/weightslab/components/parallel_state.py new file mode 100644 index 00000000..dbabf6cd --- /dev/null +++ b/weightslab/components/parallel_state.py @@ -0,0 +1,461 @@ +"""DDP planes & reducers — the entire WL cross-rank surface in 4 named concepts. + +This module is the home for **what crosses ranks**. Once a value fits into one of +the four planes below, no train.py code is needed for it to be correctly +synchronized under DDP — the SDK's `_ensure_core_ddp_registered` hooks each +plane's local_dump/merge (or snapshot/apply) into the per-step anchor. + +The 4 planes (DOWN = reconcile broadcast, UP = outbox gather; see + parallel_primitives + docs/ddp_design.md → "Mechanism, by direction") +============ + CONFIG ↓ DOWN reconcile hparams rank-0 authority; no reducer + CONTROL ↓ DOWN reconcile paused, tracking, contexts rank-0 authority; no reducer + DATAFRAME ↕ both ways per-sample columns DOWN reconcile (deny-list, + tags) + UP outbox (last_seen, + counters, …) via dtype-keyed + reducers + LOGGER ↑ UP outbox per-sample signal history idempotent ingest keyed by + (sid, step, exp_hash); no reducer + +Reducers (only the DATAFRAME plane needs them) +============================================== + MAX numeric / bool / timestamp monotonic upward (last_seen, counters, + True-wins). Stateless and IDEMPOTENT — + the only retry-safe choice for counters. + LATEST scalar string / categorical last writer wins (rank order is determ.) + UNION list / tuple / set concat / set-union (tag lists) + RANK_0_ONLY any DOWN-only column — NEVER read UP. The one + place column names appear in the plumbing. 
+ IGNORE any local-only — never crosses ranks (debug) + +Auto-classification (default policy by pandas dtype, no per-column config needed): + bool / numeric / datetime → MAX + object / string → LATEST + list / set → UNION (resolved at value level when needed) + +Adding a new per-sample column: name it with a sensible dtype and it just works. +Adding a new DOWN-only column: append to `DOWN_ONLY`. That's the only edit. +""" +import logging + +import pandas as pd + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# REDUCER REGISTRY — stateless, idempotent, retry-safe by design. +# ============================================================================ +def _r_max(s): + d = s.dropna() + return d.max() if not d.empty else None + + +def _r_latest(s): + d = s.dropna() + return d.iloc[-1] if not d.empty else None + + +def _r_union(s): + out = set() + for v in s.dropna(): + if isinstance(v, (list, tuple, set)): + out.update(v) + else: + out.add(v) + return sorted(out) if out else None + + +REDUCERS = {"MAX": _r_max, "LATEST": _r_latest, "UNION": _r_union} + +# The ONLY place column names appear in the cross-rank plumbing. DOWN-only: rank-0 +# sets it, the reconcile broadcasts it to children, children never read it back UP. +# Just the deny-list: rank-1+ need `discarded` to derive the SAME live shard as rank-0 +# (else shards desync -> grad all_reduce deadlock). Tags don't ride — they're rank-0 +# UI/curation state (the tag->label override is vestigial); tag queries gather signals +# UP and filter on rank-0. ("user_tags" used to sit here but was never a real column.) +DOWN_ONLY = {"discarded"} + + +def policy_for(col, dtype): + """Default reducer policy by pandas dtype. 
Caller is responsible for + pre-filtering DOWN_ONLY columns; anything that reaches here is UP-flowable.""" + if pd.api.types.is_bool_dtype(dtype): + return "MAX" + if pd.api.types.is_numeric_dtype(dtype): + return "MAX" + if pd.api.types.is_datetime64_any_dtype(dtype): + return "MAX" + return "LATEST" + + +# ============================================================================ +# DATAFRAME PLANE — schema-agnostic local dump + dtype-keyed merge. +# Mirrors the structure of the LOGGER plane below, so the two planes share the +# same registration shape: (local_dump, merge). +# ============================================================================ +# Object-dtype columns whose contents are scalar strings — safe to gather and +# reduce with LATEST. Anything outside this list and not numeric/bool/datetime +# gets DROPPED at gather-time (tensors, dicts of tensors, ndarrays etc. would +# either fail to pickle cleanly across ranks or produce pandas dtype-mismatch +# warnings on the merge upsert). +_OBJECT_GATHER_ALLOWLIST: set[str] = {"origin", "group_id"} + + +def _is_gather_safe_column(col: str, dtype) -> bool: + """A column is safe to ship across ranks iff its cells reduce cleanly under + the reducer table. Numeric/bool/datetime → MAX. Object dtype is rejected + UNLESS it's a known scalar-string column, or a tag-flag column (`tag.*`), + or a list/set column (UNION-mergeable). Tensors / arrays / dicts of arrays + silently drop here — they're either signal-plane traffic (handled by the + logger outbox) or array-store traffic (handled by H5ArrayStore on rank-0). + """ + if pd.api.types.is_bool_dtype(dtype): + return True + if pd.api.types.is_numeric_dtype(dtype): + return True + if pd.api.types.is_datetime64_any_dtype(dtype): + return True + if col in _OBJECT_GATHER_ALLOWLIST: + return True + if col.startswith("tag.") or col.startswith("tags."): + return True + return False + + +# Per-rank delta state. 
The outbox ships only what CHANGED since the last flush, +# not the whole dataframe / whole signal history every step — otherwise per-step +# cost scales with dataset size (df) and grows unboundedly (signals), which is +# the real scaling wall behind the "~2 collectives/step" budget (the budget +# counts rendezvous, not bytes). The change-set is sourced from the dataframe +# manager's outbox-dirty set (sids the per-sample writers touched since the last +# flush) — NOT a snapshot diff — so building the delta is O(changes) and there is +# no fragile signature comparison. The signal cursor is the same idea for the +# append-only signal buffers. +_SIGNAL_CURSOR: dict = {} # (graph, exp_hash) -> count already sent + + +def reset_outbox_state(): + """Drop the per-rank delta cursors so the next flush re-sends everything. + Called from clear_registry (tests) and safe to call on experiment reset.""" + _SIGNAL_CURSOR.clear() + + +def local_df_writes(): + """This rank's per-sample dataframe DELTA — gather-safe columns, ONLY the rows + whose per-sample UP values changed since the last flush. The change-set comes + from the dataframe manager's outbox-dirty set (populated by the per-sample + writers: enqueue_batch / update_by_groups_bulk), so there's no whole-dataframe + snapshot diff — building the delta is O(changes), which is what keeps the + per-step gather small (the dataframe is pre-seeded with ALL sample_ids, so a + full scan would ship ~every row every step). + + Schema-agnostic: no column NAMES baked in (only the DOWN_ONLY filter + the + object-allowlist). Tensors / dicts / arrays are skipped — they don't reduce + cleanly under our reducer table. Reads `get_combined_df` so the manager's + unflushed buffer is included, then narrows to the dirty rows. 
+ """ + from weightslab.backend.ledgers import get_dataframe + try: + dfm = get_dataframe() + dirty = dfm.drain_outbox_dirty() + except Exception: + return None + if not dirty: + return None + try: + df = dfm.get_combined_df(return_proxies=False) + except Exception: + return None + if df is None or getattr(df, "empty", True): + return None + df = df.copy() + if "sample_id" not in df.columns: + if isinstance(df.index, pd.MultiIndex): + df["sample_id"] = [t[-1] for t in df.index] + else: + df["sample_id"] = df.index + df["sample_id"] = df["sample_id"].astype(str) + df = df.reset_index(drop=True) + df = df[df["sample_id"].isin(dirty)] # only the changed rows + if df.empty: + return None + # Drop DOWN-only columns (they flow ↓ not ↑) + any column whose cells are + # tensors / dicts / arrays (would mangle the merge upsert). + keep = ["sample_id"] + for c in df.columns: + if c == "sample_id" or c in DOWN_ONLY: + continue + if _is_gather_safe_column(c, df[c].dtype): + keep.append(c) + df = df[keep] + return df.to_dict(orient="records") or None + + +# ---------------------------------------------------------------------------- +# DATAFRAME-DOWN — broadcast rank-0's values for every DOWN_ONLY column. +# Replaces what used to be a column-specific "deny-list" reconcile: now adding +# a new DOWN-only column is a single DOWN_ONLY entry, zero plumbing. +# ---------------------------------------------------------------------------- +def rank0_df_down_state(): + """Rank-0's DOWN_ONLY values as {col: {sample_id: value}} for children to mirror + (apply_df_down_state). DELTA: ships only sample-ids changed since the last + reconcile (drain_down_delta), with one full snapshot on first reconcile / post- + restore so children converge before deltas — keeps the broadcast O(changed), + not O(N). 
Non-null values only (an un-discard rides as False; truly-unset cells + don't pollute children).""" + from weightslab.backend.ledgers import get_dataframe + try: + dfm = get_dataframe() + df = dfm.get_combined_df(return_proxies=False) if dfm is not None else None + except Exception: + return None + if df is None or getattr(df, "empty", True): + return None + cols = [c for c in DOWN_ONLY if c in df.columns] + if not cols: + return None + full, dirty = dfm.drain_down_delta() + if "sample_id" in df.columns: + sids = [str(s) for s in df["sample_id"].tolist()] + elif isinstance(df.index, pd.MultiIndex): + sids = [str(t[-1]) for t in df.index] + else: + sids = [str(s) for s in df.index] + if full: + want = None # everything + elif not dirty: + return None # nothing changed this step + else: + want = set(str(s) for s in dirty) + out = {} + for col in cols: + vals = df[col].tolist() + out[col] = {sid: v for sid, v in zip(sids, vals) + if (want is None or sid in want) and pd.notna(v)} + return out if any(out.values()) else None + + +def apply_df_down_state(state): + """Children: replace local DOWN_ONLY columns with rank-0's values via + upsert_df. Idempotent. 
NO direct call to discard_samples / column-specific + helpers — the column name is purely data here.""" + if not state: + return + from weightslab.backend.ledgers import get_dataframe + dfm = get_dataframe() + if dfm is None: + return + rows = {} + for col, sid_to_val in state.items(): + for sid, val in (sid_to_val or {}).items(): + rows.setdefault(str(sid), {})[col] = val + if not rows: + return + df = pd.DataFrame.from_dict(rows, orient="index") + df.index.name = "sample_id" + try: + dfm.upsert_df(df, force_flush=True) + except Exception as exc: + logger.debug("[df_down] apply upsert failed: %s", exc) + + +# ============================================================================ +# CONFIG PLANE — rank-0 hyperparams ↓ (no reducer; single source of truth) +# ============================================================================ +def _proxy_to_plain(obj): + """Recursively convert a hyperparams Proxy / nested dict to plain picklable data.""" + if hasattr(obj, "items"): + return {k: _proxy_to_plain(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_proxy_to_plain(v) for v in obj] + return obj + + +def _flatten_hparams(d, prefix=""): + """Flatten a nested dict to {dot.key.path: leaf_value}.""" + out = {} + if not hasattr(d, "items"): + return out + for k, v in d.items(): + key = f"{prefix}.{k}" if prefix else str(k) + if isinstance(v, dict): + out.update(_flatten_hparams(v, key)) + else: + out[key] = v + return out + + +def rank0_hparams(): + """Rank-0's hyperparams as a plain nested dict (the authoritative config + live edits).""" + from weightslab.backend.ledgers import get_hyperparams, resolve_hp_name + try: + hp = get_hyperparams(resolve_hp_name()) + return _proxy_to_plain(hp) if hp is not None else {} + except Exception: + return {} + + +def apply_hparams(hp): + """Children: apply rank-0's hyperparams (changed leaves only). 
lr/batch_size are + read from local hyperparams each step (optimizer.step, _sync_batch_size_from_ledger), + so syncing the dict is enough for live edits to take effect identically on every rank.""" + if not hp: + return + from weightslab.backend.ledgers import get_hyperparams, resolve_hp_name, set_hyperparam + try: + cur = _flatten_hparams(_proxy_to_plain(get_hyperparams(resolve_hp_name()))) + except Exception: + cur = {} + for key_path, val in _flatten_hparams(hp).items(): + if cur.get(key_path) != val: + try: + set_hyperparam(key_path, val) + except Exception: + pass + + +# ============================================================================ +# LOGGER PLANE — per-sample signal history ↑ (idempotent ingest; no reducer) +# Signal entries are keyed by (graph, exp_hash, sample_id, step), so re-ingesting +# the same triple is a no-op — retries are safe by construction. +# ============================================================================ +def local_signal_triples(): + """{graph: {exp_hash: [(sid, step, val)]}} of per-sample signals on THIS rank, + DELTA only — triples appended since the last flush. + + The per-sample buffers are append-only typed arrays, so a per-(graph, exp_hash) + cursor (count already sent) gives a truly incremental slice in O(new) — reading + the raw buffer directly rather than reconstructing the whole history each step + (which is O(total) and grows every step). On restore the buffer may be rebuilt + SHORTER than the cursor; we detect that (cur_len < cursor) and resend from 0. 
+ """ + from weightslab.backend.ledgers import get_logger + out = {} + try: + hist = get_logger()._signal_history_per_sample or {} + except Exception: + return out + for graph, by_hash in hist.items(): + graph_out = {} + for exp_hash, buf in by_hash.items(): + sids = buf["sample_ids"] + cur_len = len(sids) + key = (graph, exp_hash) + start = _SIGNAL_CURSOR.get(key, 0) + if start > cur_len: # buffer shrank (restore/clear) → resend all + start = 0 + if start >= cur_len: + continue # nothing new for this graph/hash + steps = buf["steps"] + vals = buf["values"] + graph_out[exp_hash] = [ + (str(sids[i]), int(steps[i]), float(vals[i])) + for i in range(start, cur_len) + ] + _SIGNAL_CURSOR[key] = cur_len + if graph_out: + out[graph] = graph_out + return out + + +def merge_signal_triples_into_logger(maps): + """Rank-0: fold gathered per-sample signal triples into the logger (idempotent).""" + from weightslab.backend.ledgers import get_logger + try: + lg = get_logger() + except Exception: + return + for m in maps: + for graph, by_hash in (m or {}).items(): + for exp_hash, triples in by_hash.items(): + try: + lg.ingest_per_sample(graph, exp_hash, triples) + except Exception as exc: + logger.debug("[signals] ingest failed for %s: %s", graph, exc) + + +def _rank0_existing_seed(sample_ids, cols): + """Rank-0's CURRENT values for `cols` over `sample_ids`, as a records frame. + Prepended (existing-first) to the per-rank deltas before reducing so the + reducers fold against the authoritative value. Critical for deltas: a rank's + delta may omit a sample rank-0 already has a HIGHER value for — without the + seed, MAX/UNION would regress it (and the upsert would lower last_seen). 
With + the seed placed first, LATEST still resolves to the newest delta (later row), + and a sample with no delta this round simply keeps its existing value.""" + from weightslab.backend.ledgers import get_dataframe + try: + df = get_dataframe().get_combined_df(return_proxies=False) + except Exception: + return None + if df is None or getattr(df, "empty", True): + return None + # Index the wanted sids directly and copy ONLY those ~batch rows + delta cols — + # the old df.copy() duplicated the WHOLE frame every flush (O(N) per step, the + # same hidden scaling cost as the DOWN reconcile had). + if "sample_id" in df.columns: + sid_idx = pd.Index(df["sample_id"].astype(str)) + elif isinstance(df.index, pd.MultiIndex): + sid_idx = pd.Index([str(t[-1]) for t in df.index]) + else: + sid_idx = df.index.astype(str) + want = set(str(s) for s in sample_ids) + mask = sid_idx.isin(want) + if not mask.any(): + return None + keep = [c for c in cols if c in df.columns] + sub = df.loc[mask, keep].copy() + sub.insert(0, "sample_id", sid_idx[mask].to_numpy()) + return sub.to_dict(orient="records") + + +def merge_df_writes(parts): + """Rank-0 fold: per-column reducer apply. Concat all per-rank records, + groupby sample_id, apply the dtype-keyed reducer per column. Upsert into + rank-0's dataframe with force_flush so the DataService snapshot picks it up.""" + from weightslab.backend.ledgers import get_dataframe + frames = [] + for p in parts: + if not p: + continue + try: + frames.append(pd.DataFrame(p)) + except Exception: + continue + if not frames: + return + delta = pd.concat(frames, ignore_index=True) + if "sample_id" not in delta.columns: + return + # Seed with rank-0's existing values (existing-first) so the per-column + # reducers fold against the authoritative value — see _rank0_existing_seed. 
+ cols = [c for c in delta.columns if c != "sample_id"] + seed = _rank0_existing_seed(delta["sample_id"].tolist(), cols) + seed_frame = pd.DataFrame(seed) if seed else None + big = pd.concat( + [f for f in (seed_frame, delta) if f is not None and not f.empty], + ignore_index=True, + ) + big = big.set_index("sample_id") + + # Vectorized fold: MAX->groupby.max(), LATEST->groupby.last() (both skipna, + # matching _r_max/_r_latest). One groupby.agg instead of a python reducer call + # per group per column — each group is <=2 rows (seed + the owning rank's write). + # policy_for only ever yields MAX/LATEST here (UNION is tags, which are DOWN-only). + agg = {} + for col in big.columns: + try: + agg[col] = "max" if policy_for(col, big[col].dtype) == "MAX" else "last" + except Exception: + agg[col] = "last" + if not agg: + return + try: + merged = big.groupby(level=0).agg(agg) + except Exception as exc: + logger.debug("[df outbox] vectorized reduce failed: %s", exc) + return + try: + get_dataframe().upsert_df(merged, force_flush=True) + except Exception as exc: + logger.debug("[df outbox] upsert failed: %s", exc) diff --git a/weightslab/data/dataframe_manager.py b/weightslab/data/dataframe_manager.py index 6dad1e7b..c209a95c 100644 --- a/weightslab/data/dataframe_manager.py +++ b/weightslab/data/dataframe_manager.py @@ -3,6 +3,8 @@ import traceback import threading import logging +import threading +import time import traceback import warnings import numpy as np @@ -93,6 +95,20 @@ def __init__(self, flush_interval: float = 3.0, flush_max_rows: int = 100, enabl self._array_store: H5ArrayStore | None = None self._origin_revisions: Dict[str, int] = {} self._pending: set[int] = set() + # Sample-ids with per-sample UP writes (signals / last_seen) since the last + # DDP outbox drain. 
Populated ONLY by the per-sample writers (enqueue_batch, + # update_by_groups_bulk) — NOT by upsert_df (merge-back / DOWN reconcile), + # so it's exactly this rank's UP change-set with no re-ship loop. Drained + # by the outbox each flush (drain_outbox_dirty). + self._outbox_dirty: set = set() + self._outbox_dirty_lock = threading.Lock() + # DOWN-plane delta: sample-ids whose DOWN_ONLY cells (discarded/tags) changed + # since the last reconcile, so rank-0 broadcasts only the change-set, not the + # whole deny-list every step. _down_full_pending forces ONE full snapshot + # (first reconcile / after a restore) so children converge before deltas. + self._down_dirty: set = set() + self._down_dirty_lock = threading.Lock() + self._down_full_pending = True self._force_flush = False self._flush_interval = flush_interval self._flush_max_rows = flush_max_rows @@ -382,6 +398,112 @@ def _collect_affected_origins(self, df_norm: pd.DataFrame, origin: str | None = return affected_origins + # ---- DOWN-only (deny-list / tags) change detection -------------------- + # The DDP plane lists DOWN_ONLY columns that flow rank-0 -> children via + # `reconcile_all`. When such a column actually changes (a UI discard / tag, + # or a child applying rank-0's reconciled snapshot), every registered + # loader's iterator is invalidated so its workers tear down and a + # since-discarded sample sitting in a prefetch queue is dropped before it can + # be trained on (the sampler also stops yielding it via the pandas deny-list + # check, which refreshes on the revision bump below). Invalidation is gated + # on an ACTUAL value change — critical under DDP, where rank-N re-applies the + # SAME deny-list snapshot every step and must not respawn workers each step. + def drain_outbox_dirty(self) -> set: + """Return and clear the set of sample-ids with fresh per-sample UP writes + since the last drain (the DDP outbox's change-set). O(changes). 
Empty set + means nothing changed → the outbox flush ships nothing this step.""" + with self._outbox_dirty_lock: + dirty = self._outbox_dirty + self._outbox_dirty = set() + return dirty + + @staticmethod + def _down_only_columns() -> set[str]: + """Source of truth: weightslab.components.parallel_state.DOWN_ONLY. Lazy + import to avoid a circular dependency (planes -> ledgers -> here).""" + try: + from weightslab.components.parallel_state import DOWN_ONLY + return set(DOWN_ONLY) + except Exception: + return {"discarded"} # safe default; mirrors parallel_state.DOWN_ONLY + + @staticmethod + def _cells_differ(old, new) -> bool: + """NaN-safe scalar/list comparison for DOWN_ONLY change detection. + Treats None/NaN as equal to each other; falls back to 'differ' on any + uncomparable type so we err toward invalidating (correctness over a + spurious respawn).""" + def _is_na(v): + if v is None: + return True + if isinstance(v, (list, tuple, set, dict)): + return False + try: + return bool(pd.isna(v)) + except Exception: + return False + o_na, n_na = _is_na(old), _is_na(new) + if o_na and n_na: + return False + if o_na != n_na: + return True + try: + return bool(old != new) + except Exception: + return True + + def _down_only_changed(self, df_norm: pd.DataFrame) -> bool: + """True iff `df_norm` changes any DOWN_ONLY (deny-list / tags) cell versus + the CURRENT self._df. Must be called BEFORE the upsert merges df_norm in. + + Gates iterator invalidation. Critical under DDP, where rank-N re-applies + the same reconciled deny-list snapshot every step and must NOT respawn + workers when nothing actually changed. + """ + down_only = self._down_only_columns() + cols = [c for c in df_norm.columns if c in down_only] + if not cols: + return False + for col in cols: + new_col = df_norm[col] + if col not in self._df.columns: + # Brand-new DOWN_ONLY column: a change iff any non-null value.
+ if new_col.notna().any(): + return True + continue + for sid, new_val in new_col.items(): + if sid in self._df.index: + old_val = self._df.at[sid, col] + else: + old_val = None # new row + if self._cells_differ(old_val, new_val): + return True + return False + + def _invalidate_loader_iters_on_down_only_change(self, df_norm: pd.DataFrame) -> None: + """If `df_norm` touched any DOWN_ONLY column, mark every registered + loader's iterator as stale. Triggers on BOTH rank-0 direct discards + AND rank-N reconcile-applies (apply_df_down_state calls upsert_df) so + every rank gets a fresh iter symmetrically. No-op if no loader has an + active iter (e.g. during ledger init, before training starts).""" + down_only = self._down_only_columns() + if not any(c in down_only for c in df_norm.columns): + return + try: + from weightslab.backend.ledgers import get_dataloaders, get_dataloader + except Exception: + return + for name in get_dataloaders(): + loader = get_dataloader(name) + if loader is None: + continue + inv = getattr(loader, "_invalidate_iter", None) + if callable(inv): + try: + inv() + except Exception as exc: + logger.debug("[invalidate iter] %s: %s", name, exc) + def get_array_store(self) -> H5ArrayStore | None: """Get the array store instance.""" return self._array_store @@ -545,7 +667,10 @@ def _load_existing_data(self, origin: str = None, autoload_arrays: bool | list | loaded_df = self._store.load_all(origin) if self._store else pd.DataFrame() if not loaded_df.empty: - # Ensure multi-level index on (sample_id, annotation_id) if available + # A restore/reload changes DOWN_ONLY state wholesale -> force one full DOWN + # reconcile so children re-converge before deltas resume. 
+ self.mark_down_full_resend() + # Ensure single-level index on sample_id if "sample_id" in loaded_df.columns: try: if "annotation_id" in loaded_df.columns: @@ -664,7 +789,15 @@ def upsert_df(self, df_local: List | pd.DataFrame, origin: str = None, force_flu with self._lock: affected_origins = self._collect_affected_origins(df_norm, origin=origin) - # Align columns + # Detect DOWN_ONLY (deny-list / tags) changes BEFORE the merge below + # overwrites the prior values — used to gate iterator invalidation. + try: + down_only_changed = self._down_only_changed(df_norm) + except Exception as exc: + down_only_changed = False + logger.debug("[down-only diff] failed: %s", exc) + + # Align columns: Ensure the global dataframe has all columns present in the update missing_cols = df_norm.columns.difference(self._df.columns) if len(missing_cols) > 0: self._df = self._df.reindex(columns=self._df.columns.union(missing_cols)) @@ -727,6 +860,42 @@ def upsert_df(self, df_local: List | pd.DataFrame, origin: str = None, force_flu self.mark_dirty_batch(sample_ids, force_flush=force_flush) self._bump_origin_revisions(affected_origins) + # Drop every registered loader's iterator ONLY if a DOWN_ONLY value + # actually changed (computed pre-merge above). Critical under DDP: + # rank-N's reconcile_all applies the SAME snapshot every step → if we + # invalidated unconditionally, workers would respawn every step and + # throughput would collapse. + if down_only_changed: + # DOWN-plane delta: these sample-ids' deny-list/tags changed, so the + # next reconcile ships just them (not the whole table). Marked after + # the merge so the drain reads the committed values. + with self._down_dirty_lock: + self._down_dirty.update(str(s) for s in df_norm.index.tolist()) + try: + self._invalidate_loader_iters_on_down_only_change(df_norm) + except Exception as exc: + logger.debug("[invalidate iter] failed: %s", exc) + + def drain_down_delta(self): + """For the DOWN reconcile. 
Returns (full, sids): full=True -> rank-0 should + broadcast the WHOLE DOWN_ONLY state once (first reconcile / post-restore); + else `sids` is the set of sample-ids whose DOWN_ONLY cells changed since the + last drain (empty -> nothing to broadcast). Drains both.""" + with self._down_dirty_lock: + if self._down_full_pending: + self._down_full_pending = False + self._down_dirty = set() + return True, None + sids = self._down_dirty + self._down_dirty = set() + return False, sids + + def mark_down_full_resend(self): + """Force the next DOWN reconcile to be a full snapshot (call on restore / + wholesale df reload so children re-converge before deltas resume).""" + with self._down_dirty_lock: + self._down_full_pending = True + def mark_dirty(self, sample_id: int): """Mark sample as dirty for H5 flush. @@ -748,10 +917,6 @@ def mark_dirty_batch(self, sample_ids: List[int], force_flush: bool = False): if force_flush: self._force_flush = True - def _is_array_column_to_norm(self, column_name: str, value: Any) -> bool: - """Check if a column should store arrays in separate H5 file.""" - return column_name in self._array_columns and isinstance(value, (np.ndarray, ArrayH5Proxy)) - def _should_array_be_stored(self, array_name) -> bool: """Check if array storage is enabled.""" return array_name in SAMPLES_STATS_TO_SAVE_TO_H5 # Regexed signals are not considered here @@ -1004,6 +1169,10 @@ def index_batch(obj, batch_index, rec=False): # Merge nested dicts: update existing sample_id records, add new ones for sample_id, record in records_to_add.items(): self._buffer.setdefault(sample_id, {}).update(record) + # Mark these sids as having fresh UP writes for the DDP outbox. + with self._outbox_dirty_lock: + self._outbox_dirty.update(str(s) for s in records_to_add) + with self._buffer_lock: logger.debug(f"Enqueued {len(records_to_add)} records to buffer. 
Buffer size is now {len(self._buffer)}.") should_flush = len(self._buffer) >= self._flush_max_rows or self.first_init # Check buffer size and trigger flush if needed @@ -1284,6 +1453,8 @@ def update_by_groups_bulk(self, origin: str, group_ids: List[Any], updates_list: if affected_ids: self.mark_dirty_batch(affected_ids) + with self._outbox_dirty_lock: + self._outbox_dirty.update(str(s) for s in affected_ids) def get_tainted_group_ids(self, group_ids: List[Any], origin: str) -> set: """Return the subset of group_ids where at least one member is discarded. @@ -1438,7 +1609,7 @@ def get_df_view(self, column: str = None, limit: int = -1, copy: bool = False, v subset = self._df[column] else: subset = self._df - if limit > 0: + if limit is not None and limit > 0: subset = subset.head(limit) return subset.copy() if copy else subset @@ -1778,6 +1949,9 @@ def _apply_buffer_records_nonblocking(self, records: List[Dict[str, Any]]): finally: self._lock.release() + # Array signals (prediction/prediction_raw/target) stored RAW — no bbox->segmap + # rasterize here (meaningless for detection, and it decoded an image per flush + # just to read H,W). Masks are produced on demand via get_prediction_mask(). # Det→seg conversion / array normalization over all written rows. if applied_index is not None and len(applied_index) > 0: if applied_index.has_duplicates: diff --git a/weightslab/data/h5_dataframe_store.py b/weightslab/data/h5_dataframe_store.py index 882468ec..3b6bec28 100644 --- a/weightslab/data/h5_dataframe_store.py +++ b/weightslab/data/h5_dataframe_store.py @@ -671,6 +671,38 @@ def upsert(self, origin: str, df: pd.DataFrame) -> int: key = self._key(origin) self._ensure_parent() + # DELTA / append-only fast path (gated by WL_H5_APPEND_ONLY): write ONLY the + # changed rows instead of read-all -> merge -> rewrite-all. O(delta) per flush + # vs O(total). Appends when the table schema matches; on first write or schema + # change, falls through to the rewrite path. 
Duplicate indices accumulate and + # are resolved keep-last on read (serving/UI dedup is a follow-up). + if os.environ.get("WL_H5_APPEND_ONLY", "0").lower() in ("1", "true", "yes", "on"): + try: + with self._local_lock: + with _InterProcessFileLock(self._lock_path, timeout=self._lock_timeout, poll_interval=self._poll_interval): + with pd.HDFStore(str(self._path), mode="a") as store: + if key in store: + head = store.select(key, start=0, stop=0) + ecols = list(head.columns) + if set(ecols) == set(df_norm.columns): + df2 = df_norm[ecols].copy() + for col in ecols: + edt = head[col].dtype + if str(edt) == "category": + df2[col] = pd.Categorical(df2[col], categories=head[col].cat.categories) + else: + try: + df2[col] = df2[col].astype(edt) + except Exception: + pass + store.append(key, df2, format="table", data_columns=True) + store.flush() + return len(df_norm) + except Exception as exc: + # Non-fatal: report the failure (with its exception type) through the logger only. + logger.warning("[H5DataFrameStore] append-only fast path fell back: %s: %s", type(exc).__name__, exc) + # else fall through to the read-merge-rewrite path below + # Create backup BEFORE any writes backup_path = self._create_backup() diff --git a/weightslab/examples/PyTorch/ws-detection/README.md b/weightslab/examples/PyTorch/ws-detection/README.md deleted file mode 100644 index ba8bddc6..00000000 --- a/weightslab/examples/PyTorch/ws-detection/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# WeightsLab — Object Detection (pure PyTorch) - -A small, fully-runnable **object detection** example wired into WeightsLab. It -trains a compact single-shot detector on the **Penn-Fudan Pedestrian** dataset -(~170 real photos, one class: `person`) and streams per-sample / per-instance -losses, IoU, and predicted bounding boxes to the WeightsLab UI. - -Everything here is plain PyTorch + torchvision — no detection framework -(no Ultralytics/Detectron). The only pretrained piece is an ImageNet backbone.
- -## Quick start - -From a WeightsLab install, the one-liner (installs this example's -`requirements.txt`, then trains + serves until `Ctrl+C`): - -```bash -weightslab start example --det -``` - -Or run it directly: - -```bash -cd weightslab/examples/PyTorch/ws-detection -pip install -r requirements.txt -python main.py -``` - -The **first run downloads** the Penn-Fudan dataset (~50 MB, into `./data/`) and -the MobileNetV3-Small ImageNet weights (~10 MB, cached by torch). Then open the -UI (e.g. `http://localhost:5173`) to watch training. - -## What you'll see in the UI - -| Signal | Meaning | -| ----------------------- | ---------------------------------------------------- | -| `train_loss/sample` | Per-image training loss (the value being optimized) | -| `test_loss/sample` | Per-image validation loss | -| `train_iou/sample` | Mean IoU per training image | -| `test_iou/sample` | Mean IoU per validation image | -| `train_iou/instance` | IoU per **ground-truth box** `(sample_id, annotation_id)` | -| `test_iou/instance` | Same, on validation | - -Ground-truth and predicted **bounding boxes** are rendered as overlays on each -sample (the dataset and model declare `task_type = "detection"`). - -## How it works - -``` -utils/data.py PennFudanDetectionDataset — downloads Penn-Fudan, derives one - bbox per pedestrian from the instance masks, returns the WL - detection target [N, 6] = [x1, y1, x2, y2, class_id, conf] - normalized to [0, 1]. ImageNet-normalized model inputs. - `det_collate` keeps the variable box count as a per-sample list. - -utils/model.py SmallDetector — ImageNet-pretrained MobileNetV3-Small backbone - (frozen by default) + a small head that predicts ONE box per - cell on an S x S grid: (objectness, tx, ty, tw, th, class...). - `decode_grid` turns raw logits into xyxy boxes. - -utils/criterions.py PerSampleDetectionLoss — YOLO-style objectness + coordinate + - class loss, one differentiable scalar per sample (what WL - backprops). 
PerSampleIoU / PerInstanceIoU — IoU metrics. - decode_predictions — top-confidence boxes for the UI overlay. - -main.py Wires it all to WeightsLab: watch_or_edit(...) for the logger, - hyperparameters, data loaders, model, optimizer and the - loss/metric signals; serve(); start_training(); train/test loop. -``` - -The detector is genuinely learnable: on a small subset, mean IoU rises from -~0.39 to ~0.83 within ~60 steps. - -## Configuration (`config.yaml`) - -| Key | Default | Notes | -| ---------------------- | ------- | ----------------------------------------------------------- | -| `num_classes` | `1` | Penn-Fudan has one class (`person`). | -| `image_size` | `256` | Square model input (UI shows the original image). | -| `grid_size` | `8` | Detector predicts on an `8 x 8` cell grid. | -| `conf_thresh` | `0.3` | `objectness * class` threshold for displayed predictions. | -| `pretrained_backbone` | `true` | Load ImageNet weights for the MobileNetV3 backbone. | -| `freeze_backbone` | `true` | Train only the head (fast, less data-hungry). Set `false` to fine-tune the whole backbone once the head has warmed up. | -| `data.*.batch_size` | `8` | Per-loader batch size. | -| `data.*.max_samples` | `null` | Cap a split for quick runs (`null` = full split). | - -## Using your own dataset (e.g. traffic lights) - -The model, loss, metrics, `main.py`, and UI rendering are **dataset-agnostic** — -only `utils/data.py` and a couple of config values change: - -1. Write a `Dataset` whose `get_items(idx, ...)` returns - `(image_tensor, uid, target, metadata)`, where `target` is an - `[N, 6]` float array `[x1, y1, x2, y2, class_id, confidence]` **normalized to - `[0, 1]`** (ground-truth confidence = `1.0`). Set `self.task_type = "detection"`, - `self.num_classes`, `self.class_names`, and expose `self.images` (a list of - image paths) so the UI can show the raw image. -2. Reuse `det_collate` unchanged. -3. In `config.yaml`, set `num_classes` to your class count (e.g. 
`3` for - `red / yellow / green`) and update `class_names` in the dataset / model. - -That's it — multi-class works out of the box (the classification head is already -in the grid prediction; it's just trivial when `num_classes == 1`). diff --git a/weightslab/integrations/ultralytics/signals.py b/weightslab/integrations/ultralytics/signals.py index e6a733f7..c472f944 100644 --- a/weightslab/integrations/ultralytics/signals.py +++ b/weightslab/integrations/ultralytics/signals.py @@ -369,9 +369,16 @@ def overlay_p(batch): # ─── top-level API (back-compat with the existing trainer.py calls) ──── +def _unwrap_ddp(model): + """Under UL native DDP the model is a DistributedDataParallel wrapper; signal + hooks need the underlying module (criterion/init_criterion/args/modules).""" + return model.module if isinstance(model, th.nn.parallel.DistributedDataParallel) else model + + def install_per_sample_signals(model, signals_cfg: dict = {}): """Default train pipeline. Equivalent to: install_train_pipeline(model, default_train_signals(model))""" + model = _unwrap_ddp(model) install_train_pipeline(model, default_train_signals(model, signals_cfg=signals_cfg)) diff --git a/weightslab/src.py b/weightslab/src.py index a4e93cb4..4f96ad4c 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -31,6 +31,7 @@ from weightslab.utils.logs import set_log_directory from weightslab.utils.tools import detach_to_cpu from weightslab.backend.logger import LoggerQueue +from weightslab.utils.tools import ddp_info, is_main_process from weightslab.backend.cli import cli_serve from weightslab.backend import ledgers from weightslab.backend.ledgers import register_signal @@ -470,7 +471,6 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): # User parameters batch_ids = wl_kw.get('batch_ids') - group_ids = wl_kw.get('group_id') batch_scalar = wl_kw.get('signals') preds = wl_kw.get('preds') targets = wl_kw.get('targets') if 'targets' in wl_kw else None @@ -481,24 +481,14 @@ def 
wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): # Original forward of the signal out = original_forward(*a, **kw) - # discarded samples/tainted groups from the loss tensor. - origin = kw.get('origin') or kwargs.get('origin') or get_active_origin() - - if origin and batch_ids is not None and hasattr(out, 'device') and out.ndim > 0: - try: - # Multi-sample Group Masking - if group_ids is not None: - mask = get_active_group_mask(group_ids, origin).to(out.device) - if len(mask) == len(out): - out = out * mask - - # Per-sample Individual Masking - else: - mask = get_active_sample_mask(batch_ids, origin).to(out.device) - if len(mask) == len(out): - out = out * mask - except Exception as e: - logger.debug(f"Automatic backend discard masking failed: {e}") + # NO post-hoc discard masking here. Discard is a TRANSACTION-BOUNDARY event + # (see docs/ddp_design.md): batch membership is fixed when the batch is built + # from the current deny-list (sampler filter), and a discard takes effect at + # the NEXT boundary — the prefetch queue is flushed by iterator invalidation + # so no since-discarded sample reaches a later batch. So whatever is in this + # batch was valid at its boundary; zeroing its loss after the fact (the old + # get_active_sample_mask / get_active_group_mask) was a symptom fix that + # actually violated the transactional model. Removed. # Per-instance handling: extract instance values + batch_idx mapping # and save per-annotation to dataframe. `out` may be a dict @@ -1080,7 +1070,17 @@ def serve(serving_cli: bool = False, serving_grpc: bool = False, **kwargs) -> No serving_cli: Start the interactive CLI server. serving_grpc: Start the gRPC server. **kwargs: Extra server options passed to underlying backends. + + Under DDP this is a no-op on non-zero ranks: the gRPC backend binds one + fixed port over this process's global ledger, so a second backend on a + child rank would only collide on the port and serve that rank's shard. 
The + studio talks to rank 0. Single-process (world_size == 1) is rank 0 -> serves + as before, so callers can write ``wl.serve()`` unconditionally. """ + rank, world_size = ddp_info() + if not is_main_process(): + logger.info("[serve] rank %d/%d is not the main process; skipping serve (rank 0 owns the UI backend).", rank, world_size) + return if serving_grpc: grpc_serve(**kwargs) @@ -1097,7 +1097,17 @@ def keep_serving(timeout: int = None, release_gpu: bool = True) -> None: until interrupted. release_gpu: If ``True``, move tracked torch objects to CPU and release CUDA cached memory before entering the wait loop. + + Under DDP this is a no-op on non-zero ranks: only rank 0 holds the process + alive to serve the UI. A child must NOT release its GPU / idle here (it is + still a live training replica), so it returns immediately and proceeds to + its own clean shutdown. """ + if not is_main_process(): + rank, world_size = ddp_info() + logger.info("[keep_serving] rank %d/%d is not the main process; returning (only rank 0 holds the UI backend alive).", rank, world_size) + return + if release_gpu: _release_gpu_resources() logger.info("WeightsLab switched to CPU idle mode for serving.") @@ -1689,7 +1699,7 @@ def normalize(x): if x is None: return None if isinstance(x, list) and isinstance(x[0], list): - return [np.max(np.array([to_numpy(t) for t in row]), axis=0) for row in x] + return [ (np.max(np.array([to_numpy(t) for t in row]), axis=0) if len(row) else np.zeros((0,), dtype=np.uint16)) for row in x] elif isinstance(x, list): return [to_numpy(t) for t in x] if isinstance(x, th.Tensor): @@ -1741,6 +1751,12 @@ def expand_dim(x): ) +# NOTE: get_active_group_mask / get_active_sample_mask were removed. 
They +# post-hoc zeroed the loss of discarded samples/tainted groups still present in a +# batch — a symptom fix that contradicted the transactional-discard model (a +# discard takes effect at the next batch boundary via the sampler filter + +# prefetch-flush, so no since-discarded sample reaches a later batch, and nothing +# needs masking). See docs/ddp_design.md and wrappered_fwd. def save_instance_signals( signals: dict, batch_ids: th.Tensor | np.ndarray | list, @@ -1930,6 +1946,17 @@ def _coerce_sid(x): if not losses_data: return + # Move per-instance targets OFF the GPU at enqueue time (gated by WL_INSTANCE_TARGETS_CPU): + # otherwise raw target tensors (e.g. seg [H,W] masks) sit in the pending-records buffer + # on-GPU until flush, so VRAM grows with flush_max (the disproportionate per-rank VRAM). + if targets is not None and os.environ.get("WL_INSTANCE_TARGETS_CPU", "0").lower() in ("1","true","yes","on"): + def _to_cpu(t): + return t.detach().cpu() if hasattr(t, "detach") else t + if isinstance(targets, (list, tuple)): + targets = [[_to_cpu(t) for t in s] if isinstance(s, (list, tuple)) else _to_cpu(s) for s in targets] + else: + targets = _to_cpu(targets) + # origin is intentionally NOT forwarded: instance rows (annotation_id >= 1) don't # carry an origin; the flush derives it from the sample row (annotation_id 0). DATAFRAME_M.enqueue_instance_batch( diff --git a/weightslab/tests/components/test_checkpoint_workflow.py b/weightslab/tests/components/test_checkpoint_workflow.py index c7f003c7..aab2cfdd 100644 --- a/weightslab/tests/components/test_checkpoint_workflow.py +++ b/weightslab/tests/components/test_checkpoint_workflow.py @@ -862,8 +862,20 @@ def test_06_reload_before_model_change(self): uids_A_original = self.state['uids_a'] # Before model change print(f"Reloading state A (before model change) for verification: {hash_A_original[:16]}...") + # Regression guard for same-arch restore: load_state must preserve the + # WRAPPED model object's identity. 
A trainer / DataLoaderInterface that + # captured `model = trainer.model` at startup holds a direct reference + # to the wrapped object — replacing it would orphan that reference. + # (The Proxy returned by ledgers.get_model() is kept stable by design; + # the bug was in what the Proxy wrapped, so we resolve via `.get()`.) + def _wrapped(p): + return p.get() if hasattr(p, "get") and callable(getattr(p, "get")) else p + pre_restore_wrapped = _wrapped(ledgers.get_model()) success = self.chkpt_manager.load_state(exp_hash=hash_A_original) self.assertTrue(success, "State A should load successfully") + self.assertIs(pre_restore_wrapped, _wrapped(ledgers.get_model()), + "load_state must preserve wrapped model identity on same-arch restore " + "(see feedback_restore_identity_preserving memory)") # Verify HP and data are from checkpoint A hp_reloaded = ledgers.get_hyperparams() diff --git a/weightslab/tests/integrations/ultralytics/__init__.py b/weightslab/tests/integrations/ultralytics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/weightslab/tests/integrations/ultralytics/ddp/.gitignore b/weightslab/tests/integrations/ultralytics/ddp/.gitignore new file mode 100644 index 00000000..c9436d3b --- /dev/null +++ b/weightslab/tests/integrations/ultralytics/ddp/.gitignore @@ -0,0 +1,8 @@ +# run artifacts — never commit +reports/ +ddp_run/ +ddp_run.*/ +*.folded +*.perf.data +*.perfstat +__pycache__/ diff --git a/weightslab/tests/integrations/ultralytics/ddp/README.md b/weightslab/tests/integrations/ultralytics/ddp/README.md new file mode 100644 index 00000000..87961ca3 --- /dev/null +++ b/weightslab/tests/integrations/ultralytics/ddp/README.md @@ -0,0 +1,45 @@ +# WeightsLab × Ultralytics — DDP integration suite + +Locally-run integration + performance harness for the DDP-on-ultralytics detection usecase +in `examples/PyTorch/ws-detection/`. **Run explicitly** — it needs a GPU and the usecase +dataset, so it is *not* a CI unit test. 
The scripts drive the usecase via a path-bootstrap +to `../../../../examples/PyTorch/ws-detection/src` (config.yaml, data, `yolo_pipeline`, +`utils.*` all resolve there). + +## Entry point: `run_ddp_report.sh` + +One driver, several MODES (`PHASES`), one report under `reports/report_/`: + +| phase | what | via | +|---|---|---| +| `info` | host / GPU / torch / ultralytics / git snapshot | — | +| `scenarios` | functional suite — pass/fail + per-scenario time + MaxRSS | `ddp_test_suite.py` | +| `ablation` | WL internal tax: `ulmanual` (hand-rolled per-sample logger) vs `wl`; per-section time/RSS/IO/bytes + the `wl − ulmanual` delta | `ddp_ablation.py` | +| `profile` | py-spy Python-frame ownership (% wall in WL SDK) + perf native hotspots + `perf stat` HW counters + /proc peak RSS/threads | `aggregate_wl_ownership.py` + perf | + +```bash +./run_ddp_report.sh # all phases +PHASES="ablation profile" ABLATE_STEPS=256 ./run_ddp_report.sh +PHASES=scenarios ./run_ddp_report.sh +``` +`profile` needs `sudo` for perf + py-spy (this host: `perf_event_paranoid=4`, `ptrace_scope=1`). + +## Pieces (also runnable standalone) + +- **`ddp_test_suite.py`** — scenarios simulating UI-driven DDP curation (discard / rebalance / + pause / checkpoint / resample / …), each on a fresh 2-rank server. + `WL_DDP_ONLY=` runs one; `WL_DDP_SKIP=a,b` excludes (resume a killed run). +- **`ddp_ablation.py`** — `WL_ABLATE=ulmanual|wl` per-step cost decomposition. + `WL_ABLATE_STEPS=N`. The honest baseline is `ulmanual` (anyone logging per-sample signals + pays decode + per-sample loss); `wl − ulmanual` is WL's true machinery tax. +- **`aggregate_wl_ownership.py`** — classifies a py-spy folded profile into WL-SDK vs + goal (decode / per-sample loss) vs model / torch / data. + +Common env: `WL_DDP_BATCH`, `WL_DDP_WORKERS`, `WL_DDP_CUDA`, `WL_DDP_IMGSZ`. 
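
As a rough illustration of what "classifies a py-spy folded profile" means, the sketch below parses folded-stack lines (`frame;frame;... count`) and computes the share of samples whose leaf frame sits under `weightslab/`. It is a simplified assumption-laden sketch, not the script itself: the helper names `parse_folded_line`, `leaf_file`, and `leaf_share` are hypothetical, only two buckets are kept, and the real classifier's extra rules (the `baseline_models` exclusion, library-leaf attribution, the grey zone) are left out.

```python
import re
from collections import Counter


def parse_folded_line(line):
    """Split one folded-stack line into (frames, sample_count); None if malformed."""
    m = re.match(r"^(.*)\s+(\d+)$", line.rstrip("\n"))
    if not m:
        return None
    # Drop py-spy's "process N:..." pseudo-frames that head subprocess stacks.
    frames = [f for f in m.group(1).split(";") if not f.startswith("process ")]
    return (frames, int(m.group(2))) if frames else None


def leaf_file(frame):
    """Extract the 'file:line' part from a 'func (file:line)' frame."""
    m = re.search(r"\(([^)]*)\)\s*$", frame)
    return m.group(1) if m else frame


def leaf_share(lines):
    """Fraction of samples per bucket, keyed on the leaf frame only (simplified)."""
    counts = Counter()
    for line in lines:
        parsed = parse_folded_line(line)
        if parsed is None:
            continue
        frames, n = parsed
        bucket = "WL-SDK" if "weightslab/" in leaf_file(frames[-1]) else "other"
        counts[bucket] += n
    total = sum(counts.values())
    return {k: v / total for k, v in counts.items()} if total else {}


# Hypothetical folded lines, made up for illustration only.
sample = [
    "process 123:python;main (ddp_ablation.py:10);save_signals (weightslab/src.py:1900) 30",
    "main (ddp_ablation.py:10);forward (ultralytics/nn/tasks.py:50) 70",
]
shares = leaf_share(sample)
```

Keying on the leaf frame answers "whose code is running right now", which is why the report speaks of ownership rather than A/B section deltas.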
+ +## Findings (see memory `project_wl_ddp_sdk_overhead`) + +WL's tax is a **fixed ~80 ms/step collective floor** (the anchor / gloo round-trip), +independent of image and batch size; `save_signals` is free. It amortizes to **≤5% at +batch ≥ 16**. Absolute cost grows only with dataset size (merge ∝ df rows) and flush +frequency — not pixels, not batch. diff --git a/weightslab/tests/integrations/ultralytics/ddp/__init__.py b/weightslab/tests/integrations/ultralytics/ddp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/weightslab/tests/integrations/ultralytics/ddp/aggregate_wl_ownership.py b/weightslab/tests/integrations/ultralytics/ddp/aggregate_wl_ownership.py new file mode 100644 index 00000000..4ca17db0 --- /dev/null +++ b/weightslab/tests/integrations/ultralytics/ddp/aggregate_wl_ownership.py @@ -0,0 +1,127 @@ +"""Aggregate a py-spy folded profile into where the instruction pointer actually is, +by OWNERSHIP (not A/B section deltas). Beyond "is it WL SDK", it carves out the +GREY ZONE — work that isn't WL SDK code but only exists BECAUSE of WL: + + WL-SDK leaf in weightslab/ (save_signals, reconcile/flush, merge, dataframe, + the wrapper's own code) — EXCL weightslab/baseline_models; OR a + pandas/numpy/h5py leaf whose nearest non-lib caller is weightslab/. + decode-for-log `_decode_preds_to_6col` on the stack — NMS run ONLY to log predictions + (pure WL-motivated overhead; the UL baseline never decodes). [GREY] + loss:per-sample `criterions.py` on the stack — the per-sample loss wrapper AND the + ultralytics loss.py/tal.py it drives. Per-sample BECAUSE WL wants + per-sample signals; upper bound on WL-induced loss cost. [GREY] + model:forward ultralytics/nn leaf, no criterions/decode — the network itself. + torch:bwd/sync torch leaf — backward + grad all_reduce + optimizer. + loss:compute ultralytics loss/tal/ops leaf with NO criterions wrapper (UL baseline). + usecase/data, dataloader/collate, the driver loop, idle. 
+ harness, other + + "WL-attributable" = WL-SDK + the grey zone. Run: + FOLDED=/tmp/wl_ablation.folded python aggregate_wl_ownership.py +""" +import os, re, collections + +PATH = os.environ.get("FOLDED", "/tmp/wl_ablation.folded") +_LIB_DATA = ("pandas", "numpy", "h5py", "pyarrow") +_GREY = {"decode-for-log", "loss:per-sample"} + + +def _file_of(frame): + m = re.search(r"\(([^)]*)\)\s*$", frame) + return m.group(1) if m else frame + + +def _is_wl(f): + return "weightslab/" in f and "baseline_models" not in f + + +_IMG = ("patches.py", "ultralytics/data/", "/cv2/", "albumentations", "PIL/", "imgaug") +_LOSS = ("ultralytics/utils/tal.py", "ultralytics/utils/loss.py", "ultralytics/utils/metrics.py") + + +def classify(frames): + files = [_file_of(f) for f in frames] + leaf = files[-1] + blob = ";".join(frames) # func names + files across the whole stack + # 1. WL SDK = WL owns the leaf (its own code running now) + if _is_wl(leaf): + return "WL-SDK" + # pandas/numpy/h5py leaf — attribute to the first non-lib caller (WL-induced?) + if any(l in leaf for l in _LIB_DATA): + for f in reversed(files[:-1]): + if any(l in f for l in _LIB_DATA): + continue + if _is_wl(f): + return "WL-SDK" + if any(s in f for s in _IMG): # numpy under imread = decode + return "data:img-decode" + break + # 2. grey zone (WL-motivated) — most specific intent wins + if "_decode_preds_to_6col" in blob: + return "decode-for-log" + # 3. image decode / augment (usecase; you'd load images regardless) + if any(s in leaf for s in _IMG): + return "data:img-decode" + # 4. loss — per-sample wrapper, or the ultralytics loss/assigner/metrics it drives + if "criterions.py" in blob: + return "loss:per-sample" + if any(s in leaf for s in _LOSS) or any(s in blob for s in _LOSS): + return "loss:compute" + # 5. 
model forward (network); torch leaf disambiguated by stack context + if "ultralytics/nn/" in leaf or "ultralytics/nn/" in blob: + return "model:forward" + if "torch/" in leaf: + return "torch:bwd/sync" + if any(s in leaf for s in ("yolo_pipeline", "/utils/data", "/data.py")): + return "data:img-decode" + if "ddp_ablation" in leaf or "ddp_test_suite" in leaf: + return "harness" + return "other/idle" + + +def main(): + buckets = collections.Counter() + wl_frames = collections.Counter() + total = 0 + with open(PATH) as fh: + for line in fh: + m = re.match(r"^(.*)\s+(\d+)$", line.rstrip("\n")) + if not m: + continue + stack, cnt = m.group(1), int(m.group(2)) + # drop py-spy's "process N:..." subprocess-header pseudo-frames + frames = [f for f in stack.split(";") if not f.startswith("process ")] + if not frames: + continue + total += cnt + buckets[classify(frames)] += cnt + seen = set() + for f in frames: + ff = _file_of(f) + if _is_wl(ff): + key = ff.split("/weightslab/")[-1] + if key not in seen: + wl_frames[key] += cnt + seen.add(key) + if not total: + print(f"no samples in {PATH}"); return + + print(f"TOTAL SAMPLES: {total} (~{total/200:.0f}s @ 200Hz)\n") + print("OWNERSHIP PARTITION (where is the instruction pointer):") + for b, c in buckets.most_common(): + tag = " <- GREY (WL-motivated)" if b in _GREY else "" + print(f" {b:18s} {c:8d} {100*c/total:5.1f}%{tag}") + + wl = buckets.get("WL-SDK", 0) + grey = sum(buckets.get(g, 0) for g in _GREY) + print(f"\n WL-SDK code = {100*wl/total:5.1f}% (decode/loss/bridge EXCLUDED)") + print(f" + grey zone (decode+per-sample loss) = {100*grey/total:5.1f}%") + print(f" = WL-ATTRIBUTABLE = {100*(wl+grey)/total:5.1f}% (SDK + only-because-of-WL work)") + + print("\nTOP WL-SDK files (inclusive — on the call path, not necessarily the leaf):") + for f, c in wl_frames.most_common(12): + print(f" {c:7d} {100*c/total:5.1f}% {f[:70]}") + + +if __name__ == "__main__": + main() diff --git 
a/weightslab/tests/integrations/ultralytics/ddp/run_ddp_report.sh b/weightslab/tests/integrations/ultralytics/ddp/run_ddp_report.sh new file mode 100755 index 00000000..f958448d --- /dev/null +++ b/weightslab/tests/integrations/ultralytics/ddp/run_ddp_report.sh @@ -0,0 +1,195 @@ +#!/usr/bin/env bash +# God-script for the WeightsLab-on-ultralytics DDP integration suite. ONE driver, +# several MODES (phases), emits a single report. Runs explicitly LOCALLY (needs a GPU + +# the usecase dataset) — it is NOT a CI unit test. Drives the ws-detection usecase under +# examples/. Runs unattended (ONE sudo prompt, kept alive). Output lands under +# reports/report_/ beside this script. +# +# Phases (override with PHASES="..."): +# info host/GPU/versions/git snapshot +# scenarios full functional suite (ddp_test_suite.py) — pass/fail + per-scn time + MaxRSS +# ablation WL internal tax: ulmanual (hand-rolled logger) vs wl (ddp_ablation.py) — +# per-section time/RSS/IO/bytes + the wl-ulmanual delta +# profile py-spy (Python-frame OWNERSHIP: % wall in WL SDK) + perf (native hotspots, +# perf stat HW counters) + /proc peak RSS/threads, on the wl ablation AND on +# PROFILE_SCN scenarios. This is the "use py-spy & perf as much as possible" part. 
+# +# Knobs (all optional): +# PHASES="info scenarios ablation profile" BATCH=16 WORKERS=2 CUDA=1 +# ABLATE_STEPS=256 SCN_ONLY= SCN_SKIP=a,b PROFILE_SCN="curate_lifecycle progressive_resample" +# SAMPLE_DUR=60 SAMPLE_WARM=20 OUT= +# +# Examples: +# ./run_ddp_report.sh # the works +# PHASES="ablation profile" ABLATE_STEPS=256 ./run_ddp_report.sh +# PHASES=profile PROFILE_SCN="curate_lifecycle" ./run_ddp_report.sh +set -uo pipefail # NOT -e: every phase runs even if one fails +cd "$(dirname "$0")" + +PY=${PY:-/home/rotaru/anaconda3/envs/wl_15_nl/bin/python} +PYSPY=${PYSPY:-/home/rotaru/anaconda3/bin/py-spy} +PERF=${PERF:-/usr/bin/perf} +GTIME=${GTIME:-/usr/bin/time} + +PHASES=${PHASES:-"info scenarios ablation profile"} +BATCH=${BATCH:-16}; WORKERS=${WORKERS:-2}; CUDA=${CUDA:-1} +ABLATE_STEPS=${ABLATE_STEPS:-256} +SAMPLE_DUR=${SAMPLE_DUR:-60}; SAMPLE_WARM=${SAMPLE_WARM:-20} +PROFILE_SCN=${PROFILE_SCN:-"curate_lifecycle progressive_resample"} +RD=${OUT:-reports/report_$(date +%Y%m%d_%H%M%S)} +mkdir -p "$RD"; RD=$(cd "$RD" && pwd) +REPORT="$RD/REPORT.md" + +# shared env for every child run +export WL_DDP_BATCH=$BATCH WL_DDP_WORKERS=$WORKERS WL_DDP_CUDA=$CUDA +export WEIGHTSLAB_SKIP_SECURE_INIT=true GRPC_TLS_ENABLED=0 WEIGHTSLAB_LOG_LEVEL=ERROR + +say(){ echo -e "$*" | tee -a "$REPORT"; } +hr(){ printf '\n%s\n' "================================================================" | tee -a "$REPORT"; } + +# ---- sudo: prompt once, keep alive for the whole run (perf+py-spy need it: paranoid=4, ptrace_scope=1) +HAVE_SUDO=1 +echo ">> caching sudo (perf + py-spy need root). One prompt now:" +if sudo -v 2>/dev/null; then + ( while true; do sudo -n true 2>/dev/null; sleep 50; done ) & KEEPALIVE=$! + trap 'kill $KEEPALIVE 2>/dev/null' EXIT +else + HAVE_SUDO=0; echo "!! 
no sudo — 'profile' phase (perf/py-spy) will be SKIPPED" +fi + +want(){ [[ " $PHASES " == *" $1 "* ]]; } + +say "# WL-on-ultralytics DDP report" +say "_$(date)_ • batch=$BATCH workers=$WORKERS cuda=$CUDA • dir: \`$RD\`" + +# ============================================================ INFO +if want info; then + hr; say "## host / versions" + say '```' + { echo "host: $(uname -srm) cores=$(nproc)" + echo "mem: $(free -h | awk '/Mem:/{print $2" total, "$7" avail"}')" + echo "gpu: $($PY - <<'P' 2>/dev/null +import torch +print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU-only", + "| torch", torch.__version__) +P +)" + echo "ultra: $($PY -c 'import ultralytics;print(ultralytics.__version__)' 2>/dev/null)" + echo "python: $($PY --version 2>&1)" + echo "git: $(git rev-parse --short HEAD 2>/dev/null) $(git rev-parse --abbrev-ref HEAD 2>/dev/null)" + echo "perf_event_paranoid=$(cat /proc/sys/kernel/perf_event_paranoid) ptrace_scope=$(cat /proc/sys/kernel/yama/ptrace_scope)" + } | tee -a "$REPORT" + say '```' +fi + +# ============================================================ SCENARIOS +if want scenarios; then + hr; say "## functional suite (ddp_test_suite.py)" + log="$RD/scenarios.log" + env WL_DDP_SCN_TIMING=1 \ + ${SCN_ONLY:+WL_DDP_ONLY=$SCN_ONLY} ${SCN_SKIP:+WL_DDP_SKIP=$SCN_SKIP} \ + "$GTIME" -v "$PY" ddp_test_suite.py >"$log" 2>&1 + rc=$? 
+ say "exit=$rc (full log: \`scenarios.log\`)" + say '```' + grep -E '^ scenario_|RESULT:' "$log" | tee -a "$REPORT" + grep -E 'took [0-9.]+s' "$log" | sed 's/^/ /' | tail -20 | tee -a "$REPORT" + grep -E 'Maximum resident|Elapsed \(wall|context switches' "$log" | sed 's/^\s*/ /' | tee -a "$REPORT" + say '```' +fi + +# ============================================================ ABLATION +if want ablation; then + hr; say "## ablation — WL internal tax vs hand-rolled logging (steps=$ABLATE_STEPS)" + for M in ulmanual wl; do + env WL_ABLATE=$M WL_ABLATE_STEPS=$ABLATE_STEPS \ + "$GTIME" -v "$PY" ddp_ablation.py >"$RD/ablation_$M.log" 2>&1 + say "### mode=$M"; say '```' + sed -n '/^=====/,/^=====/p' "$RD/ablation_$M.log" | tee -a "$REPORT" + grep -E '^\[mode=' "$RD/ablation_$M.log" | tee -a "$REPORT" + grep -E 'Maximum resident' "$RD/ablation_$M.log" | sed 's/^\s*/ /' | tee -a "$REPORT" + say '```' + done + # wl - ulmanual per-section delta = WL's internal machinery above hand-rolled + say "### wl − ulmanual per-section delta (WL internal tax; decode+loss cancel)"; say '```' + "$PY" - "$RD/ablation_ulmanual.log" "$RD/ablation_wl.log" <<'P' 2>/dev/null | tee -a "$REPORT" +import re,sys +def parse(p): + d={} + for ln in open(p): + m=re.match(r"\s+(\S.*?)\s+([\d.]+) ms/step",ln) + if m: d[m.group(1).strip()]=float(m.group(2)) + return d +man,wl=parse(sys.argv[1]),parse(sys.argv[2]) +for k in wl: + if k in man: print(f" {k:18s} {wl[k]-man[k]:+8.1f} ms (ulmanual {man[k]:6.1f} -> wl {wl[k]:6.1f})") +P + say '```' +fi + +# ============================================================ PROFILE (py-spy + perf) +# _profile