diff --git a/AlphaBrain/training/continual_learning/__init__.py b/AlphaBrain/training/continual_learning/__init__.py index ff7f12d..1d10a61 100644 --- a/AlphaBrain/training/continual_learning/__init__.py +++ b/AlphaBrain/training/continual_learning/__init__.py @@ -1,9 +1,9 @@ """Continual Learning module. Sub-packages: - algorithms/ — CL algorithms (ER / MIR / EWC / …) grouped by mechanism - (rehearsal_based, regularization_based, dynamic_based) - plus the :class:`CLAlgorithm` base and :class:`CLContext`. + algorithms/ — CL algorithms (ER / MIR / …) grouped by mechanism + (rehearsal_based, dynamic_based) plus the + :class:`CLAlgorithm` base and :class:`CLContext`. datasets/ — Task sequences and per-task dataset filtering. Top-level entry: @@ -13,7 +13,6 @@ Re-exports for convenience (fully-qualified paths also work): `ER` ← `algorithms.rehearsal_based.er.ER` `MIR` ← `algorithms.rehearsal_based.mir.MIR` - `EWC` ← `algorithms.regularization_based.ewc.EWC` `CLAlgorithm` ← `algorithms.base.CLAlgorithm` `CLContext` ← `algorithms.base.CLContext` `build_cl_algorithm` ← `algorithms.build_cl_algorithm` @@ -22,7 +21,6 @@ CLAlgorithm, CLContext, ER, - EWC, MIR, build_cl_algorithm, ) @@ -31,7 +29,6 @@ "CLAlgorithm", "CLContext", "ER", - "EWC", "MIR", "build_cl_algorithm", ] diff --git a/AlphaBrain/training/continual_learning/algorithms/__init__.py b/AlphaBrain/training/continual_learning/algorithms/__init__.py index cf81bbc..8a51448 100644 --- a/AlphaBrain/training/continual_learning/algorithms/__init__.py +++ b/AlphaBrain/training/continual_learning/algorithms/__init__.py @@ -1,11 +1,9 @@ """Continual-learning algorithms. -Algorithms are grouped by their mechanism into three sub-packages: +Algorithms are grouped by their mechanism into sub-packages: - :mod:`.rehearsal_based` — replay past-task samples (ER, MIR, planned DER / A-GEM) -- :mod:`.regularization_based` — penalise movement away from important - parameters (EWC, planned SI / LwF) - :mod:`.dynamic_based` — adapt the model architecture across tasks (planned DWE / Weight Merge / PackNet) @@ -19,11 +17,9 @@ ---------- The concrete classes are re-exported at this package level for convenience:: - from AlphaBrain.training.continual_learning.algorithms import EWC, MIR, ER + from AlphaBrain.training.continual_learning.algorithms import MIR, ER -Fully-qualified paths also work, e.g. -``algorithms.regularization_based.ewc.EWC`` or -``algorithms.rehearsal_based.er.ER``. +Fully-qualified paths also work, e.g. ``algorithms.rehearsal_based.er.ER``. """ from typing import Optional @@ -31,7 +27,6 @@ CLAlgorithm, CLContext, ) -from AlphaBrain.training.continual_learning.algorithms.regularization_based import EWC from AlphaBrain.training.continual_learning.algorithms.rehearsal_based import ER, MIR @@ -50,17 +45,17 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]: replay_batch_ratio: 0.3 balanced_sampling: false - 2. Generic algorithm (EWC / RETAIN / DER / DWE / …):: + 2. Generic algorithm (ER / MIR / …):: continual_learning: algorithm: - name: ewc - lambda: 1.0e4 - gamma: 1.0 - lora_only: true - fisher_num_batches: 50 - fisher_clip: 1.0e4 - grad_clip_per_sample: 100.0 + name: mir + buffer_size_per_task: 500 + replay_batch_ratio: 0.5 + mir_refresh_interval: 50 + mir_candidate_size: 16 + mir_top_k: 8 + mir_lora_only: true Returns ``None`` when neither section is configured — the trainer then runs a plain sequential baseline without CL interventions. @@ -94,23 +89,6 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]: balanced_sampling=algo_cfg.get("balanced_sampling", False), seed=seed, ) - if key == "ewc": - # `lambda` is a Python keyword; accept either `lambda` or - # `ewc_lambda` (OmegaConf tolerates the former as a dict key). - ewc_lambda = algo_cfg.get("lambda", None) - if ewc_lambda is None: - ewc_lambda = algo_cfg.get("ewc_lambda", 1.0e4) - excl = algo_cfg.get("exclude_name_substrings", None) - return EWC( - ewc_lambda=ewc_lambda, - gamma=algo_cfg.get("gamma", 1.0), - lora_only=algo_cfg.get("lora_only", True), - fisher_num_batches=algo_cfg.get("fisher_num_batches", 50), - fisher_clip=algo_cfg.get("fisher_clip", 1.0e4), - grad_clip_per_sample=algo_cfg.get("grad_clip_per_sample", 100.0), - fisher_save_dir=algo_cfg.get("fisher_save_dir", None), - exclude_name_substrings=list(excl) if excl is not None else None, - ) if key == "mir": return MIR( buffer_size_per_task=algo_cfg.get("buffer_size_per_task", 500), @@ -134,7 +112,6 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]: "CLAlgorithm", "CLContext", "ER", - "EWC", "MIR", "build_cl_algorithm", ] diff --git a/AlphaBrain/training/continual_learning/algorithms/base.py b/AlphaBrain/training/continual_learning/algorithms/base.py index d784f13..96ccece 100644 --- a/AlphaBrain/training/continual_learning/algorithms/base.py +++ b/AlphaBrain/training/continual_learning/algorithms/base.py @@ -1,6 +1,6 @@ """Abstract base class for continual-learning algorithms. -All CL algorithms (Experience Replay / EWC / DER / RETAIN / DWE / ...) implement +All CL algorithms (Experience Replay / DER / RETAIN / DWE / ...) implement this interface. The continual trainer (`AlphaBrain.training.continual_learning.train`) only talks to algorithms through this protocol, so new methods can be plugged in without touching the training loop. @@ -10,11 +10,9 @@ - `ER` (algorithms.rehearsal_based.er) — experience replay with reservoir sampling (uniform or per-task balanced). - `MIR` (algorithms.rehearsal_based.mir) — interference-aware replay. -- `EWC` (algorithms.regularization_based.ewc) — Fisher-weighted L2 penalty. Planned ------- -- `EWC` Elastic Weight Consolidation (Kirkpatrick et al. 2017) - `DER` Dark Experience Replay (Buzzega et al. 2020) - `RETAIN` Weight Merging / model souping (Wortsman et al. 2022, variants) - `DWE` Dynamic Weight Expansion (per-task adapters) @@ -31,14 +29,13 @@ modify_batch(batch, task_id) — return the batch the model will consume (ER / DER inject replay samples here). compute_penalty(model) — return a scalar added to the task loss - (EWC / SI regularizers). + (regularizer-based methods). Task-level hooks (bracket each CL task): on_task_start(context) — before training starts on a new task (DWE expands the model here). on_task_end(context) — after a task finishes training - (ER populates buffer, EWC computes Fisher, - RETAIN merges weights). + (ER populates buffer, RETAIN merges weights). On resume, the trainer replays `on_task_end` for each completed task so algorithms can rebuild state not captured in `state_dict`. @@ -109,9 +106,9 @@ def modify_batch(self, batch: Any, task_id: int) -> Any: def compute_penalty(self, model: Any) -> Optional[torch.Tensor]: """Return a scalar penalty to add to the task loss, or None. - Typical uses: - * EWC: λ · Σ F_i · (θ_i − θ*_i)² - * SI: λ · Σ Ω_i · (θ_i − θ*_i)² + Typical uses (regularization-based methods): + * Fisher-weighted L2: λ · Σ F_i · (θ_i − θ*_i)² + * SI path integral: λ · Σ Ω_i · (θ_i − θ*_i)² Default: None (no extra penalty). """ @@ -151,7 +148,6 @@ def on_task_end(self, context: CLContext) -> None: Typical uses: * ER: populate the buffer from `context.task_dataset`. - * EWC: compute Fisher on `context.task_dataloader`, snapshot θ*. * RETAIN: merge the newly-trained weights into the running model. Default: no-op. diff --git a/AlphaBrain/training/continual_learning/algorithms/regularization_based/__init__.py b/AlphaBrain/training/continual_learning/algorithms/regularization_based/__init__.py deleted file mode 100644 index d9d8614..0000000 --- a/AlphaBrain/training/continual_learning/algorithms/regularization_based/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Regularization-based continual learning algorithms. - -These methods add a penalty term to the task loss that discourages moving -parameters away from values that mattered for previous tasks. The penalty -is typically weighted by an importance estimate — Fisher information, path -integrals, etc. No explicit memory of past-task samples is required at -training time (though Fisher estimation does re-read the task data once -at task-end). - -- :class:`EWC` — Elastic Weight Consolidation (Kirkpatrick et al. 2017): - λ · Σ F_i · (θ_i − θ*_i)² with diagonal Fisher computed at task end. - -Planned (not yet implemented): -- **SI** — Synaptic Intelligence: path-integral importance accumulated - online (no extra task-end pass). -- **LwF** — Learning without Forgetting: teacher-logit distillation on - the current-task batch using a frozen snapshot of the previous model. -""" -from AlphaBrain.training.continual_learning.algorithms.regularization_based.ewc import EWC - -__all__ = ["EWC"] diff --git a/AlphaBrain/training/continual_learning/algorithms/regularization_based/ewc.py b/AlphaBrain/training/continual_learning/algorithms/regularization_based/ewc.py deleted file mode 100644 index ec18d66..0000000 --- a/AlphaBrain/training/continual_learning/algorithms/regularization_based/ewc.py +++ /dev/null @@ -1,505 +0,0 @@ -"""Elastic Weight Consolidation (EWC) for continual learning. - -Reference ---------- -Kirkpatrick et al. 2017, "Overcoming catastrophic forgetting in neural -networks" (https://arxiv.org/abs/1612.00796). - -Idea ----- -After each task ends we estimate the diagonal of the Fisher information -matrix on that task's data and snapshot the current parameters θ*. -During training on the next task the loss gains a regularizer - - L_EWC = λ · Σ_i F_i · (θ_i − θ*_i)² - -that penalises moves away from θ* in directions the previous tasks were -"confident" about. Diagonal Fisher is the standard approximation (full -Fisher is O(N²) in parameter count). - -VLA / LoRA specialisation -------------------------- -For VLA models with LoRA adapters the default is to track Fisher only -over LoRA parameters (`'lora' in name.lower()`), which keeps memory -around 1–3 % of the full-model cost. Mirrors the approach in -UT-Austin-RobIn/continual-vla-rl (rlinf/algorithms/ewc.py). - -Numerical stability tricks borrowed from the same reference: -- optional per-sample gradient clip **before** squaring (reduces blow-ups - when a rare batch produces a huge gradient); -- a ceiling on Fisher values (`fisher_clip`) post-aggregation; -- Fisher + θ* kept on CPU in fp32; a small per-device GPU cache is built - lazily on first `compute_penalty` call to avoid H2D every step. - -Multi-task accumulation ------------------------ -Across tasks we blend new and old Fisher with a decay factor γ: - - F ← γ · F_old + F_new - -γ = 1.0 gives pure additive EWC (each task contributes equally). -γ < 1.0 gives "online EWC" — the exponentially-weighted variant. - -Resume semantics ----------------- -We do NOT serialize Fisher/θ* tensors (they can be hundreds of MB). On -resume the trainer replays `on_task_end` for each completed task, and -each replay recomputes Fisher from the saved model state — this mirrors -how ER repopulates its samples on resume. This is an -approximation (Fisher for early tasks gets recomputed against the final -saved model rather than the θ* that existed when that task finished) but -it keeps the trainer's resume path uniform across algorithms and is -standard practice in online-EWC implementations. -""" - -from __future__ import annotations - -import logging -import os -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch -from torch import nn - -from AlphaBrain.training.continual_learning.algorithms.base import ( - CLAlgorithm, - CLContext, -) - -logger = logging.getLogger(__name__) - - -def _unwrap(model: Any) -> nn.Module: - """Strip DeepSpeed / DDP / Accelerate wrappers to expose the user module.""" - base = model - while hasattr(base, "module"): - base = base.module - return base - - -def _zero_grad(model: Any) -> None: - """Clear gradients portably across bare nn.Module / DDP / DeepSpeed. - - ``DeepSpeedEngine.zero_grad()`` doesn't accept ``set_to_none`` so we - call it without kwargs. On a bare ``nn.Module`` this is equivalent to - ``set_to_none=True`` on torch ≥ 2.0 (the project's floor). - """ - try: - model.zero_grad() - except TypeError: - # Extremely old torch signatures — fall back to the explicit flag. - model.zero_grad(set_to_none=True) - - -def _is_main_rank() -> bool: - """Return True iff this process is rank 0 (or non-distributed).""" - try: - import torch.distributed as dist - if dist.is_available() and dist.is_initialized(): - return dist.get_rank() == 0 - except Exception: - pass - return True - - -class EWC(CLAlgorithm): - """Elastic Weight Consolidation with diagonal Fisher approximation. - - Hyperparameters - --------------- - ewc_lambda : regularization strength λ. In continual-vla-rl the default is - 1e6 for full-param Fisher; for LoRA-only Fisher a smaller value - (1e3–1e5) usually works. Tune per-dataset. - gamma : decay applied to the *old* Fisher when blending with a new one. - 1.0 = pure additive (standard EWC). <1.0 = online EWC. - lora_only : only compute Fisher over parameters whose name contains 'lora'. - Essential for large VLA backbones; set to False for models without LoRA. - fisher_num_batches : how many minibatches of the task dataloader to use - for Fisher estimation. Small values (20–100) are typical. - fisher_clip : post-aggregation clamp on Fisher entries (protects against - isolated gradient spikes). Set to None to disable. - grad_clip_per_sample : clip |grad| element-wise *before* squaring it into - Fisher. Set to None to disable. Guards against bf16 overflow. - """ - - def __init__( - self, - ewc_lambda: float = 1e4, - gamma: float = 1.0, - lora_only: bool = True, - fisher_num_batches: int = 50, - fisher_clip: Optional[float] = 1e4, - grad_clip_per_sample: Optional[float] = 100.0, - fisher_save_dir: Optional[str] = None, - exclude_name_substrings: Optional[List[str]] = None, - ): - self.ewc_lambda = float(ewc_lambda) - self.gamma = float(gamma) - self.lora_only = bool(lora_only) - self.fisher_num_batches = int(fisher_num_batches) - self.fisher_clip = None if fisher_clip is None else float(fisher_clip) - self.grad_clip_per_sample = ( - None if grad_clip_per_sample is None else float(grad_clip_per_sample) - ) - self.exclude_name_substrings = tuple( - s.lower() for s in (exclude_name_substrings or []) - ) - # When set, after every task we dump the accumulated Fisher + θ* - # tensors to `/fisher_task_.pt` for offline - # analysis. Rank 0 writes; other ranks are silent. - self.fisher_save_dir = ( - None if fisher_save_dir is None else str(fisher_save_dir) - ) - - # CPU-resident master state (fp32, keyed by full parameter name) - self._fisher: Optional[Dict[str, torch.Tensor]] = None - self._old_params: Optional[Dict[str, torch.Tensor]] = None - - # Per-device GPU cache (lazy, invalidated on every on_task_end) - self._device_cache: Optional[torch.device] = None - self._fisher_gpu: Optional[Dict[str, torch.Tensor]] = None - self._old_params_gpu: Optional[Dict[str, torch.Tensor]] = None - - # Bookkeeping - self._completed_tasks: List[int] = [] - # Most recent raw (unmerged) per-task Fisher — for `metrics()` - # and to help diagnose whether this task's estimate was sane. - self._last_task_fisher_stats: Dict[str, float] = {} - - # ------------------------------------------------------------------ - # Parameter iteration (respects lora_only filter) - # ------------------------------------------------------------------ - def _iter_params(self, model: Any) -> Iterator[Tuple[str, torch.Tensor]]: - base = _unwrap(model) - for name, param in base.named_parameters(): - if not param.requires_grad: - continue - lower = name.lower() - if self.lora_only and "lora" not in lower: - continue - if any(tok in lower for tok in self.exclude_name_substrings): - continue - yield name, param - - def _snapshot_params(self, model: Any) -> Dict[str, torch.Tensor]: - snap: Dict[str, torch.Tensor] = {} - for name, param in self._iter_params(model): - snap[name] = ( - param.detach().to(dtype=torch.float32, device="cpu").clone() - ) - return snap - - # ------------------------------------------------------------------ - # Fisher computation - # ------------------------------------------------------------------ - def _compute_fisher( - self, - model: Any, - dataloader: Any, - accelerator: Optional[Any] = None, - ) -> Dict[str, torch.Tensor]: - """Estimate diagonal Fisher by averaging squared per-batch gradients. - - The dataloader is iterated for up to ``fisher_num_batches`` batches; - each batch forwards once, backprops once, and contributes its squared - gradients (clipped if configured) to the running Fisher accumulator. - - DeepSpeed ZeRO-2 notes - ---------------------- - Under DeepSpeed ZeRO-2 with ``contiguous_gradients: true``, ``param.grad`` - is cleared once autograd finishes (gradients live in DeepSpeed's - contiguous bucket). Reading ``param.grad`` directly after - ``accelerator.backward`` therefore yields ``None`` and silently gives - an all-zero Fisher. To avoid this we register a per-parameter - backward hook that fires *during* autograd — hooks run before - DeepSpeed moves the gradient away, so we capture the live - (pre-reduction, per-rank-local) gradient there. For MIR-style - heuristic use the per-rank local value is fine. - """ - fisher: Dict[str, torch.Tensor] = {} - for name, param in self._iter_params(model): - fisher[name] = torch.zeros_like( - param.data, dtype=torch.float32, device="cpu" - ) - - if not fisher: - logger.warning( - "EWC._compute_fisher: no trainable parameters matched " - "(lora_only=%s). Returning empty Fisher.", - self.lora_only, - ) - return fisher - - base = _unwrap(model) - was_training = base.training - model.eval() - - # Install autograd hooks on every tracked LoRA param so we capture - # gradients during backward (before DeepSpeed can clear `.grad`). - captured: Dict[str, torch.Tensor] = {} - hook_handles: List[Any] = [] - for name, param in self._iter_params(model): - def _make_hook(pname: str): - def _hook(grad: torch.Tensor) -> torch.Tensor: - # Clone so later autograd ops don't mutate our copy; - # cast to fp32 for numerically stable accumulation. - captured[pname] = grad.detach().to(torch.float32).clone() - return grad # don't modify the gradient - return _hook - hook_handles.append(param.register_hook(_make_hook(name))) - - count = 0 - try: - for batch in dataloader: - if count >= self.fisher_num_batches: - break - captured.clear() - _zero_grad(model) - - with torch.autocast("cuda", dtype=torch.bfloat16): - output = model.forward(batch) - loss = output["action_loss"] - - if accelerator is not None: - accelerator.backward(loss) - else: - loss.backward() - - # After backward, `captured` holds per-param gradients - # regardless of whether DeepSpeed cleared .grad. - for name, g in captured.items(): - if self.grad_clip_per_sample is not None: - g = g.clamp( - -self.grad_clip_per_sample, self.grad_clip_per_sample - ) - fisher[name] += g.pow(2).cpu() - count += 1 - finally: - for h in hook_handles: - try: - h.remove() - except Exception: - pass - _zero_grad(model) - if was_training: - base.train() - - if count > 0: - for name in fisher: - fisher[name] = fisher[name] / float(count) - if self.fisher_clip is not None: - for name in fisher: - fisher[name] = fisher[name].clamp_(0.0, self.fisher_clip) - - # Cheap diagnostic summary — use WARNING level so it bypasses - # accelerate's default INFO filter and actually shows up in logs. - sum_all = float(sum(float(t.sum()) for t in fisher.values())) - max_all = float(max(float(t.max()) for t in fisher.values())) - nonzero_entries = int( - sum(int((t > 0).sum()) for t in fisher.values()) - ) - total_entries = int(sum(t.numel() for t in fisher.values())) - pct_nonzero = ( - 100.0 * nonzero_entries / total_entries if total_entries else 0.0 - ) - self._last_task_fisher_stats = { - "fisher_sum": sum_all, - "fisher_max": max_all, - "fisher_pct_nonzero": pct_nonzero, - "fisher_num_params": float(len(fisher)), - "fisher_num_batches_used": float(count), - } - logger.warning( - "EWC: Fisher estimated over %d batches across %d params " - "(lora_only=%s). sum=%.4e max=%.4e nonzero=%.2f%%", - count, len(fisher), self.lora_only, - sum_all, max_all, pct_nonzero, - ) - return fisher - - # ------------------------------------------------------------------ - # CLAlgorithm hooks - # ------------------------------------------------------------------ - def on_task_end(self, context: CLContext) -> None: - """Recompute Fisher on the finished task's data and snapshot θ*.""" - if context.model is None or context.task_dataloader is None: - logger.warning( - "EWC.on_task_end: model or task_dataloader missing from context; " - "skipping Fisher computation." - ) - return - - new_fisher = self._compute_fisher( - model=context.model, - dataloader=context.task_dataloader, - accelerator=context.accelerator, - ) - new_snapshot = self._snapshot_params(context.model) - - if self._fisher is None: - self._fisher = new_fisher - else: - # Merge Fisher across tasks: F ← γ · F_old + F_new - merged: Dict[str, torch.Tensor] = {} - for name in set(self._fisher.keys()) | set(new_fisher.keys()): - old = self._fisher.get(name) - incoming = new_fisher.get(name) - if old is None: - merged[name] = incoming - elif incoming is None: - merged[name] = self.gamma * old - else: - merged[name] = self.gamma * old + incoming - self._fisher = merged - - self._old_params = new_snapshot - self._completed_tasks.append(int(context.task_id)) - - # Invalidate the GPU cache — next penalty() call will rebuild it. - self._device_cache = None - self._fisher_gpu = None - self._old_params_gpu = None - - # Use WARNING level so the message survives accelerate's INFO filter. - logger.warning( - "EWC.on_task_end: consolidated task %s " - "(total tasks seen = %d, Fisher entries = %d)", - context.task_id, - len(self._completed_tasks), - len(self._fisher or {}), - ) - - # Optionally persist the Fisher + θ* snapshot to disk for offline - # analysis. Only rank 0 writes (tensors are per-rank local but we - # pick one deterministic view). - if self.fisher_save_dir is not None and _is_main_rank(): - try: - os.makedirs(self.fisher_save_dir, exist_ok=True) - save_path = os.path.join( - self.fisher_save_dir, - f"fisher_task_{int(context.task_id)}.pt", - ) - torch.save( - { - "task_id": int(context.task_id), - "completed_tasks": list(self._completed_tasks), - "fisher": self._fisher, - "old_params": self._old_params, - "last_task_fisher_stats": dict(self._last_task_fisher_stats), - "config": { - "ewc_lambda": self.ewc_lambda, - "gamma": self.gamma, - "lora_only": self.lora_only, - "fisher_num_batches": self.fisher_num_batches, - "fisher_clip": self.fisher_clip, - "grad_clip_per_sample": self.grad_clip_per_sample, - }, - }, - save_path, - ) - logger.warning("EWC: saved Fisher snapshot -> %s", save_path) - except Exception as e: - logger.warning( - "EWC: failed to save Fisher snapshot for task %s: %s", - context.task_id, e, - ) - - def compute_penalty(self, model: Any) -> Optional[torch.Tensor]: - """Return λ · Σ F · (θ − θ*)² or None if no tasks have been seen.""" - if ( - self._fisher is None - or self._old_params is None - or self.ewc_lambda == 0.0 - ): - return None - - # Resolve the device of the live parameters (first trainable param). - device: Optional[torch.device] = None - for _, p in self._iter_params(model): - device = p.device - break - if device is None: - return None - - # Lazy per-device cache: only pay H2D once per (device, task_end) epoch. - if self._device_cache != device: - self._fisher_gpu = {n: f.to(device) for n, f in self._fisher.items()} - self._old_params_gpu = {n: p.to(device) for n, p in self._old_params.items()} - self._device_cache = device - - penalty: Optional[torch.Tensor] = None - for name, param in self._iter_params(model): - if name not in self._fisher_gpu: - continue - f = self._fisher_gpu[name] - old = self._old_params_gpu[name] - diff = param.to(torch.float32) - old - term = (f * diff.pow(2)).sum() - penalty = term if penalty is None else penalty + term - - if penalty is None: - return None - return self.ewc_lambda * penalty - - # ------------------------------------------------------------------ - # Reporting / serialization - # ------------------------------------------------------------------ - def describe(self) -> Dict[str, Any]: - return { - "algorithm": self.name, - "ewc_lambda": self.ewc_lambda, - "gamma": self.gamma, - "lora_only": self.lora_only, - "fisher_num_batches": self.fisher_num_batches, - "fisher_clip": self.fisher_clip, - "grad_clip_per_sample": self.grad_clip_per_sample, - "num_tasks_consolidated": len(self._completed_tasks), - "fisher_entries": 0 if self._fisher is None else len(self._fisher), - } - - def metrics(self) -> Dict[str, float]: - m: Dict[str, float] = { - "ewc_num_tasks_consolidated": float(len(self._completed_tasks)), - } - # Surface last task's Fisher stats so training logs show whether - # Fisher estimation actually captured signal (vs silently zeroing). - for k, v in self._last_task_fisher_stats.items(): - m[f"ewc_{k}"] = float(v) - return m - - def state_dict(self) -> Dict[str, Any]: - """Return only metadata — Fisher/θ* tensors are not serialized. - - On resume the trainer replays `on_task_end` for each completed task, - which rebuilds Fisher from the saved model state. This mirrors - ER's behaviour and keeps the trainer's resume path uniform. - """ - return { - "algorithm": self.name, - "ewc_lambda": self.ewc_lambda, - "gamma": self.gamma, - "lora_only": self.lora_only, - "fisher_num_batches": self.fisher_num_batches, - "fisher_clip": self.fisher_clip, - "grad_clip_per_sample": self.grad_clip_per_sample, - "completed_tasks": list(self._completed_tasks), - } - - def load_state_dict(self, state: Dict[str, Any]) -> None: - self.ewc_lambda = float(state.get("ewc_lambda", self.ewc_lambda)) - self.gamma = float(state.get("gamma", self.gamma)) - self.lora_only = bool(state.get("lora_only", self.lora_only)) - self.fisher_num_batches = int( - state.get("fisher_num_batches", self.fisher_num_batches) - ) - fc = state.get("fisher_clip", self.fisher_clip) - self.fisher_clip = None if fc is None else float(fc) - gc = state.get("grad_clip_per_sample", self.grad_clip_per_sample) - self.grad_clip_per_sample = None if gc is None else float(gc) - # Tensor state is reset; replay of on_task_end will rebuild it. - self._fisher = None - self._old_params = None - self._device_cache = None - self._fisher_gpu = None - self._old_params_gpu = None - self._completed_tasks = [] diff --git a/AlphaBrain/training/continual_learning/algorithms/rehearsal_based/mir.py b/AlphaBrain/training/continual_learning/algorithms/rehearsal_based/mir.py index 327df20..d78cd6a 100644 --- a/AlphaBrain/training/continual_learning/algorithms/rehearsal_based/mir.py +++ b/AlphaBrain/training/continual_learning/algorithms/rehearsal_based/mir.py @@ -30,9 +30,8 @@ LoRA-only virtual step ---------------------- Only LoRA parameters (``'lora' in name.lower()``) participate in the -virtual step. This matches EWC's filter and is essential on 3B-scale -backbones — otherwise the Δ direction is swamped by frozen backbone -noise. +virtual step. This is essential on 3B-scale backbones — otherwise the +Δ direction is swamped by frozen backbone noise. Model access ------------ diff --git a/AlphaBrain/training/continual_learning/train.py b/AlphaBrain/training/continual_learning/train.py index a18136a..bb8a99b 100644 --- a/AlphaBrain/training/continual_learning/train.py +++ b/AlphaBrain/training/continual_learning/train.py @@ -2,7 +2,7 @@ Continual Learning Trainer for AlphaBrain. Trains a VLA model sequentially on a stream of tasks, delegating the CL -strategy (Experience Replay / EWC / RETAIN / DWE / …) to a pluggable +strategy (Experience Replay / DER / RETAIN / DWE / …) to a pluggable ``CLAlgorithm`` instance built by :func:`AlphaBrain.training.continual_learning.algorithms.build_cl_algorithm`. @@ -148,7 +148,7 @@ def __init__(self, cfg, model, optimizer, lr_scheduler, accelerator): from AlphaBrain.training.trainer_utils.peft import is_lora_enabled self.use_lora = is_lora_enabled(cfg) - # Continual-learning algorithm (ER / MIR / EWC / ...). + # Continual-learning algorithm (ER / MIR / ...). # None means "plain sequential baseline — no CL intervention". self.cl_algorithm = build_cl_algorithm(cfg, seed=cfg.get("seed", 42)) @@ -204,8 +204,8 @@ def train(self, full_dataset, episode_task_map): f"skipping {start_task_idx} completed tasks" ) # Rebuild CL algorithm state by replaying on_task_end for each - # completed task (ER re-populates its buffer, EWC re-computes - # Fisher from snapshots, RETAIN re-applies merges, etc.). + # completed task (ER re-populates its buffer, RETAIN re-applies + # merges, etc.). if self.cl_algorithm is not None: for skip_idx in range(start_task_idx): skip_task_id = task_order[skip_idx] @@ -298,7 +298,7 @@ def train(self, full_dataset, episode_task_map): steps_per_task=steps_per_task, ) - # Post-task CL hook (ER populates, EWC computes Fisher, RETAIN merges, …) + # Post-task CL hook (ER populates, RETAIN merges, …) if self.cl_algorithm is not None: logger.info( f"Running {self.cl_algorithm.name}.on_task_end for task {task_id}..." @@ -388,7 +388,7 @@ def _train_step(self, batch, lr_scheduler): """Execute a single training step. The task loss is the model's own `action_loss` plus an optional - algorithm-provided penalty (EWC / SI regularizer). Algorithms that + algorithm-provided penalty (regularization-based methods). Algorithms that don't need a penalty return ``None`` from :meth:`compute_penalty`. """ with self.accelerator.accumulate(self.model): @@ -522,7 +522,7 @@ def _save_task_checkpoint(self, task_id: int, task_idx_in_seq: int): else: torch.save(state_dict, checkpoint_path + "_pytorch_model.pt") - # Save CL algorithm state (ER buffer metadata, EWC Fisher stats, …) + # Save CL algorithm state (ER buffer metadata, …) if self.cl_algorithm is not None: cl_state = self.cl_algorithm.state_dict() with open(checkpoint_path + "_cl_state.json", "w") as f: diff --git a/docs/quickstart/continual_learning.md b/docs/quickstart/continual_learning.md index eb04744..9bd638a 100644 --- a/docs/quickstart/continual_learning.md +++ b/docs/quickstart/continual_learning.md @@ -1,7 +1,7 @@ # Continual Learning Train a single VLA backbone sequentially over the four LIBERO task -suites with a pluggable family of CL algorithms (ER / MIR / EWC) +suites with a pluggable family of CL algorithms (ER / MIR) selectable from YAML. Supports 4 architectures × LoRA / full-parameter. --- diff --git a/scripts/run_continual_learning_scripts/README.md b/scripts/run_continual_learning_scripts/README.md index 7f6f3ce..07e968d 100644 --- a/scripts/run_continual_learning_scripts/README.md +++ b/scripts/run_continual_learning_scripts/README.md @@ -52,8 +52,6 @@ AlphaBrain/training/continual_learning/algorithms/ ├── rehearsal_based/ # replay methods │ ├── er.py # class ER │ └── mir.py # class MIR(ER) -├── regularization_based/ # loss-penalty methods -│ └── ewc.py # class EWC └── dynamic_based/ # per-task architecture changes (planned) ``` @@ -430,7 +428,6 @@ AlphaBrain/training/continual_learning/train.py (trainer) ├── algorithms/ │ ├── base.py CLAlgorithm protocol + CLContext │ ├── rehearsal_based/ ER, MIR - │ ├── regularization_based/ EWC │ └── dynamic_based/ (planned: DWE / Weight Merge / PackNet) │ ├── datasets/ TaskFilteredDataset + task_sequences @@ -448,10 +445,10 @@ hooks run in the inner loop; task-level hooks bracket each CL task: |:--------------------------------|:----------------------------------------------|:----------------------------------------------------| | `observe(batch, task_id)` | Per step, before forward | Online bookkeeping (SI, streaming reservoir) | | `modify_batch(batch, task_id)` | Per step, before forward | ER / MIR inject replay samples | -| `compute_penalty(model)` | Per step, inside autocast block | EWC / SI return λ · regularizer tensor | +| `compute_penalty(model)` | Per step, inside autocast block | Regularization-based methods return λ · penalty | | `after_backward(model)` | Per step, after `accelerator.backward()` | MIR snapshots gradients (DeepSpeed-safe) | | `on_task_start(ctx)` | Before each task's inner loop begins | DWE expands model; MIR installs grad hooks | -| `on_task_end(ctx)` | After each task's inner loop finishes | ER populates buffer; EWC computes Fisher | +| `on_task_end(ctx)` | After each task's inner loop finishes | ER populates buffer; RETAIN merges weights | `ctx` is a `CLContext` dataclass carrying `task_id`, `model`, `task_dataset`, `task_dataloader`, and `accelerator` — the algorithm