One-command wrappers for AlphaBrain's Continual Learning (CL) pipeline: train a single Vision-Language-Action (VLA) backbone sequentially over a stream of manipulation tasks, then evaluate the final checkpoint on the full task matrix.
CL algorithms plug into a single CLAlgorithm hook interface and are
selected entirely via YAML — switching methods needs no code edits.
The four LIBERO suites (Goal / Spatial / Object / Long) ship in-tree as
first-class CL streams; additional LeRobot-format mixtures can be
plugged in via the Custom task streams section.
| Addition | Where |
|---|---|
| MIR (Maximally Interfered Retrieval) — replay | algorithms/rehearsal_based/mir.py, see §MIR |
| LIBERO-Long (libero_10) — 10 long-horizon tasks | qwengr00t_er_lora_libero_long.yaml |
compute_cl_matrix_metrics.py — ASR / BWT / F |
computes Lopez-Paz BWT + Chaudhry forgetting from a 10×10 eval matrix |
compute_dataset_statistics.py — generic ds_stats |
recomputes / repairs dataset_statistics.json for any LeRobot mixture |
Four VLA architectures are supported, each with full-parameter and Low-Rank Adaptation (LoRA) variants:
| Architecture | Parameters | Backbone |
|---|---|---|
| QwenGR00T | ~3.8 B | Qwen2.5-VL-3B + Flow-Matching DiT head |
| NeuroVLA | ~3.0 B | Qwen2.5-VL-3B + Q-Former + SNN head |
| LlamaOFT | ~11 B | Llama-3.2-11B-Vision + MLP head |
| PaliGemmaOFT | ~3.0 B | PaliGemma-3B + MLP head |
Selectable via continual_learning.algorithm.name in YAML.
| Algorithm | Category | One-line idea |
|---|---|---|
| ER | rehearsal_based | Reservoir replay buffer; mixes past-task samples into each batch |
| MIR | rehearsal_based | Replays the top-k samples that a virtual SGD step would hurt the most |
All algorithms share the single :class:CLAlgorithm hook interface
(observe, modify_batch, compute_penalty, after_backward,
on_task_start, on_task_end) — a new method only needs to override the
hooks it cares about. Code lives under:
AlphaBrain/training/continual_learning/algorithms/
├── base.py # CLAlgorithm + CLContext
├── rehearsal_based/ # replay methods
│ ├── er.py # class ER
│ └── mir.py # class MIR(ER)
├── regularization_based/ # loss-penalty methods
│ └── ewc.py # class EWC
└── dynamic_based/ # per-task architecture changes (planned)
Reference. Chaudhry et al. 2019 (baseline rehearsal); Ratcliff 1990 (original). Mechanism. A fixed-size per-task reservoir buffer. At each training step a configurable fraction of the current batch is replaced with samples drawn uniformly (or balanced per task) from the buffer. Buffer is populated once at each task-end.
Core hyperparameters:
| Key | Default | Effect |
|---|---|---|
buffer_size_per_task |
500 | Max samples per task; total buffer grows linearly with # tasks |
replay_batch_ratio |
0.3 | Fraction of each batch replaced by replay samples (≥ 2 enforced) |
balanced_sampling |
false | True → draw equally from every stored task; False → uniform over all samples |
YAML (generic-algorithm style):
continual_learning:
algorithm:
name: er
buffer_size_per_task: 500
replay_batch_ratio: 0.3
balanced_sampling: falseOr legacy replay-style (still supported):
continual_learning:
replay:
enabled: true
method: experience_replay
buffer_size_per_task: 500
replay_batch_ratio: 0.3
balanced_sampling: falseReference. Aljundi et al., NeurIPS 2019 (arXiv:1908.04742).
Mechanism. Same storage as ER, but the sampling policy is interference-aware.
Every mir_refresh_interval steps, MIR:
- Draws
|C|candidates uniformly from the buffer. - Forwards each candidate at the current parameters →
L(θ, c). - Takes a virtual SGD step θ_v ← θ − η · g_curr (g_curr is piggy-backed
from the trainer's own backward via an
after_backwardhook — no extra reduction, DeepSpeed-safe). - Forwards each candidate at θ_v →
L(θ_v, c). - Restores θ, ranks candidates by Δ_c = L(θ_v, c) − L(θ, c), caches the top-k. Subsequent steps replay from that cached top-k until the next refresh.
This amortized variant runs ≈ 10 % slower than plain ER at defaults — the naive "every-step MIR" would be 50-100× slower on 3 B-parameter VLAs.
Gradient capture uses torch.nn.Parameter.register_hook, which fires
during autograd (before DeepSpeed ZeRO-2 moves gradients into its
contiguous buffer and clears param.grad). This is why MIR works under
ZeRO-2 with contiguous_gradients: true in the DeepSpeed config.
Hyperparameters:
| Key | Default | Effect |
|---|---|---|
buffer_size_per_task |
500 | Same as ER |
replay_batch_ratio |
0.3 | Same as ER |
balanced_sampling |
false | Fallback uniform sampling while cache is empty |
mir_refresh_interval |
200 | Lower → fresher cache, higher overhead |
mir_candidate_size |
16 | How many candidates scored per refresh ( |
mir_top_k |
8 | Cache size; must be ≤ mir_candidate_size |
mir_virtual_lr |
null | Virtual SGD lr; null → MIR.DEFAULT_VIRTUAL_LR (1e-4) |
mir_lora_only |
true | Virtual step only on LoRA params (essential on 3B+ models) |
YAML:
continual_learning:
algorithm:
name: mir
buffer_size_per_task: 500
replay_batch_ratio: 0.3
balanced_sampling: false
mir_refresh_interval: 200
mir_candidate_size: 16
mir_top_k: 8
mir_virtual_lr: null
mir_lora_only: trueBackbone QwenGR00T-3B + LoRA (r=32); 50 rollouts × 10 tasks; final-checkpoint matrix eval.
| Method | LIBERO-Goal | LIBERO-Long |
|---|---|---|
| ASR / BWT (pp) | ASR | |
| Sequential FT | 9.8 / — | — |
| ER (aligned, buf=1000) | ~52 / — | 16.0 |
| MIR (refresh50 recipe) | 77.0 / −7.8 | 30.0 |
— rows are in-flight or out of current scope. BWT requires a full
T×T matrix eval (drop --last-only); the LIBERO-Long row is a 50-trial
× 10-task final-checkpoint single-row eval.
Cross-architecture reference (LIBERO-Goal):
| Backbone | Method | ASR |
|---|---|---|
| QwenGR00T (full) | ER | 51.6 % |
# 1. Train (~17 h on 4× A800 80 GB)
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model qwengr00t --algo mir --dataset libero_goal \
--run-id qwengr00t_mir_libero_goal --gpus 0,1,2,3
# 2. Eval (50 trials × 10 tasks; final ckpt only)
bash scripts/run_continual_learning_scripts/run_cl_eval.sh \
--run-id qwengr00t_mir_libero_goal \
--model qwengr00t \
--gpus 0,1 --trials 50 --last-onlyFor the BWT/F numbers a full 10×10 matrix eval is required — drop
--last-only to evaluate every per-task checkpoint, then run
compute_cl_matrix_metrics.py results/eval_cl/<run_id>.
ASR = mean success rate across the 10 tasks at the final checkpoint.
BWT = 1/(N−1) · Σ_{i<N} (a_{N,i} − a_{i,i}) in pp (Lopez-Paz &
Ranzato 2017); 0 = no forgetting, negative = forgetting. Recompute
both for any run via
compute_cl_matrix_metrics.py.
All examples assume a fresh clone and conda activate alphabrain. Each
block cds into the repo root so commands can be pasted verbatim.
run_cl_train.sh uses a three-axis interface: pick --model, --algo,
and --dataset; the script composes cl_base.yaml with the right model
overlay automatically and wires up the algorithm-specific OmegaConf overrides.
cd /path/to/AlphaBrain-CL
# Default — QwenGR00T LoRA + ER on LIBERO-Goal (~15 h on 2× A800)
bash scripts/run_continual_learning_scripts/run_cl_train.sh
# Pipeline smoke test — 5 steps × 10 tasks, ~3 min
bash scripts/run_continual_learning_scripts/run_cl_train.sh --smoke
# MIR 77 % LIBERO-Goal recipe (refresh=50 frozen recipe)
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model qwengr00t --algo mir --dataset libero_goal --gpus 0,1,2,3
# MIR on LIBERO-Long (same refresh=50 recipe)
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model qwengr00t --algo mir --dataset libero_long --gpus 0,1,2,3
# NeuroVLA + ER
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model neurovla --algo er --dataset libero_goal \
--run-id neurovla_er_libero_goal
# LlamaOFT + ER
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model llamaoft --algo er --dataset libero_goal
# QwenGR00T full-param ER (disable LoRA via passthrough override)
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model qwengr00t --algo er -- --lora.enabled=false
# Pin GPUs + custom step budget
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--model qwengr00t --algo er --gpus 1,2 -- \
--continual_learning.steps_per_task=20000Checkpoints are written to results/Checkpoints/<run_id>/checkpoints/:
| Variant | Artifacts per task |
|---|---|
| LoRA | task_<k>_id<k>_steps_<N>_{lora_adapter/, action_model.pt, _cl_state.json} |
| Full-parameter | task_<k>_id<k>_steps_<N>_pytorch_model.pt (+ _cl_state.json) |
The _cl_state.json sidecar stores per-algorithm metadata (e.g. ER
buffer sizes). Tensor-heavy state (replay samples) is not serialized
— it is rebuilt on resume by replaying on_task_end for each completed
task.
# LIBERO — final ckpt × 50 trials per task
bash scripts/run_continual_learning_scripts/run_cl_eval.sh \
--run-id qwengr00t_mir_libero_goal \
--model qwengr00t \
--gpus 0,1 --trials 50 --last-onlyPer-task results land in results/eval_cl/<run_id>/. Drop
--last-only to evaluate every per-task checkpoint and produce the
full T×T matrix needed for BWT/F.
dataset_statistics.jsonrepair. If a run dir is missingdataset_statistics.json(legacy runs) or the file came from a different mixture, eval will mis-denormalize actions and SR collapses to ~0%. Recompute on the fly:python scripts/run_continual_learning_scripts/compute_dataset_statistics.py \ --config configs/continual_learning/<your_config>.yaml \ --out results/Checkpoints/<run_id>/dataset_statistics.jsonThe script auto-detects binary-gripper axes (
q01==0 ∧ q99==1, the LIBERO convention) and forcesmask=Falseon those, matching the eval-sideunnormalize_actionsbinarize-then-passthrough branch.
The same run_cl_train.sh handles task streams beyond the LIBERO
suites. Define a stream inline in a YAML config — no Python changes
required:
continual_learning:
task_sequence:
base_data_mix: my_custom_mix # must exist in DATASET_NAMED_MIXTURES
num_tasks: 5
task_order: [2, 0, 4, 1, 3] # optional visit order
task_stream_mode: by_dataset # | by_task_index | auto
steps_per_task: 5000Partitioning modes:
task_stream_mode |
Semantics |
|---|---|
by_task_index |
LIBERO default: partition one multi-task parquet by its task_index column. |
by_dataset |
Each sub-dataset in the mixture is one CL task. |
auto |
Try by_task_index; fall back to by_dataset if it yields < 2 tasks. |
Implementation notes and guidance for adding new benchmarks are
collected in README_custom_streams.md.
conda activate alphabrain
cp .env.example .envEdit .env with your local paths. Required:
| Variable | Purpose |
|---|---|
PRETRAINED_MODELS_DIR |
Parent directory holding Qwen2.5-VL-3B-Instruct/, Llama-3.2-11B-Vision-Instruct/, etc. |
LEROBOT_LIBERO_DATA_DIR |
LeRobot-format LIBERO data root. |
LIBERO_PYTHON |
Python from a separate conda env containing robosuite and libero (eval-only). |
LIBERO_HOME |
LIBERO project root (for simulator configuration paths). |
Optional env vars are documented inline in .env.example for any
custom streams beyond LIBERO.
| Flag | Description | Default |
|---|---|---|
--model MODEL |
VLA backbone. Choices: qwengr00t | neurovla | llamaoft | paligemma. |
qwengr00t |
--algo ALGO |
CL algorithm. Choices: er | mir. |
er |
--dataset DS |
Task stream. Choices: libero_goal | libero_long. |
libero_goal |
--run-id ID |
Override auto-generated run_id ({model}_{algo}_{dataset}). |
auto |
--gpus SPEC |
Count ("2") or id-list ("1,2,3"). A list pins CUDA_VISIBLE_DEVICES. |
auto-detect |
--port N |
accelerate main process port. |
auto-pick |
--smoke |
5 steps/task × batch 4 — end-to-end pipeline check. | off |
--yaml PATH ... |
Advanced: one or more explicit yaml paths merged left-to-right. Skips three-axis logic. | — |
-- |
Pass-through OmegaConf overrides (e.g. --lora.rank=64). |
— |
-h, --help |
Full help text. | — |
Config layout (under configs/continual_learning/):
cl_base.yaml ← shared trainer / CL / dataset defaults (ER + libero_goal)
models/
qwengr00t.yaml ← QwenGR00T-3B framework block + model-specific LR keys
neurovla.yaml ← NeuroVLA framework block + model-specific overrides
llamaoft.yaml ← LlamaOFT framework block + model-specific overrides
paligemma.yaml ← PaliGemmaOFT framework block + model-specific overrides
The script auto-composes cl_base.yaml + models/<model>.yaml and injects
algorithm / dataset dotlist overrides. For custom configs, pass explicit paths:
bash scripts/run_continual_learning_scripts/run_cl_train.sh \
--yaml configs/continual_learning/cl_base.yaml \
configs/continual_learning/models/qwengr00t.yaml \
-- --continual_learning.steps_per_task=5000Algorithm selection is via --algo flag (script generates overrides automatically):
--algo |
What the script injects |
|---|---|
er |
Base defaults (replay block already enabled in cl_base.yaml) |
mir |
Disables replay block; sets algorithm.name=mir with refresh=50 frozen recipe |
| Flag | Description | Default |
|---|---|---|
--run-id ID |
Required. Run directory under results/Checkpoints/. |
— |
--model NAME |
Required for LoRA runs — model name (qwengr00t, neurovla, llamaoft, paligemma). Auto-expands to cl_base.yaml + models/<name>.yaml. |
— |
--gpus LIST |
Comma-separated GPU id list; determines parallelism. | 0 |
--suite NAME |
libero_goal, libero_spatial, libero_object, or libero_10. |
libero_goal |
--trials N |
Rollouts per task (production = 50). | 10 |
--port-base N |
Starting port (each parallel worker gets +i). |
5694 |
--output-base PATH |
Eval results root. | results/eval_cl/<run_id> |
--last-only |
Evaluate only the final task checkpoint. | off |
The evaluator automatically:
- Discovers every
task_*_lora_adapter/ortask_*_pytorch_model.ptunder<run_id>/checkpoints/. - Detects LoRA runs and merges adapters into full checkpoints on demand (cached as
*_merged.pt). - Parallelises across
--gpus— each worker owns a dedicated policy server + port. - Emits per-checkpoint
eval.log+server.logunder the output base.
scripts/run_continual_learning_scripts/run_cl_train.sh (self-contained wrapper)
│
│ composes cl_base.yaml + models/<model>.yaml,
│ injects algo / dataset overrides, loads .env,
│ exec accelerate launch
▼
AlphaBrain/training/continual_learning/train.py (trainer)
│
│ outer loop: tasks 0..N-1
│ inner loop: standard VLA training on each task
│ dispatches through CLAlgorithm hooks
│
├── algorithms/
│ ├── base.py CLAlgorithm protocol + CLContext
│ ├── rehearsal_based/ ER, MIR
│ ├── regularization_based/ EWC
│ └── dynamic_based/ (planned: DWE / Weight Merge / PackNet)
│
├── datasets/ TaskFilteredDataset + task_sequences
└── trainer_utils/peft/
apply_lora() · save_lora_checkpoint() · load_and_merge()
merge_lora_checkpoint (CLI for post-hoc adapter merging)
Every algorithm subclass overrides only the hooks it needs. Per-step hooks run in the inner loop; task-level hooks bracket each CL task:
| Hook | When | Typical override |
|---|---|---|
observe(batch, task_id) |
Per step, before forward | Online bookkeeping (SI, streaming reservoir) |
modify_batch(batch, task_id) |
Per step, before forward | ER / MIR inject replay samples |
compute_penalty(model) |
Per step, inside autocast block | EWC / SI return λ · regularizer tensor |
after_backward(model) |
Per step, after accelerator.backward() |
MIR snapshots gradients (DeepSpeed-safe) |
on_task_start(ctx) |
Before each task's inner loop begins | DWE expands model; MIR installs grad hooks |
on_task_end(ctx) |
After each task's inner loop finishes | ER populates buffer; EWC computes Fisher |
ctx is a CLContext dataclass carrying task_id, model,
task_dataset, task_dataloader, and accelerator — the algorithm
picks the handles it needs.
| Component | Path |
|---|---|
| CL trainer | AlphaBrain/training/continual_learning/train.py |
| Custom-stream extensions (add-only) | AlphaBrain/training/continual_learning/{train_custom,datasets/custom_streams}.py |
| CL algorithms — base / rehearsal / regularization / dynamic | AlphaBrain/training/continual_learning/algorithms/ |
Task sequences + TaskFilteredDataset |
AlphaBrain/training/continual_learning/datasets/task_sequences.py |
| LoRA helpers (inject / save / load & merge) | AlphaBrain/training/trainer_utils/peft/ |
| YAML configurations | configs/continual_learning/ |
| Documentation hub | docs/continual_learning/ |
| mkdocs quickstart | docs/quickstart/continual_learning.md |
- flash-attn ABI mismatch. If the active environment has
torch ≥ 2.6butflash-attnwas built againsttorch 2.2, the defaultattn_implementation: flash_attention_2crashes at model load. Workaround — override to SDPA:or reinstall:bash run_cl_train.sh -- --framework.qwenvl.attn_implementation=sdpa
pip install flash-attn --no-build-isolation --force-reinstall. - Eval uses a separate conda env.
LIBERO_PYTHONin.envmust point to an interpreter withrobosuiteinstalled (distinct from the training env). The wrapper auto-detects and falls back to a conventionalvlacl_engine_evalinstall when the configured path lacks the dependency. - LoRA evaluation caches merged checkpoints. The first evaluation
call merges each LoRA adapter into a
*_merged.ptfile (~7 GB per task). Subsequent calls reuse the cache; remove the file to force a re-merge. - DeepSpeed ZeRO-2 +
contiguous_gradients: true. Afteraccelerator.backward()the parameter.gradattribute is None — gradients live in DeepSpeed's internal buffer. Algorithms that need per-step gradient access (currently: MIR) usetorch.nn.Parameter.register_hookto capture gradients during autograd instead. If you add a new algorithm that needs gradients, follow MIR's pattern. - Default run takes ~15 h on 2× A800 80 GB. Use
--smoketo verify the pipeline in three minutes before committing a full run.
- Full experiment record and 10×10 matrices:
docs/continual_learning/EXPERIMENTS.md - Annotated experiment matrix:
docs/continual_learning/EXPERIMENT_MATRIX.md - Source-layout index:
docs/continual_learning/CODE_LAYOUT.md - Hosted quickstart (mkdocs):
docs/quickstart/continual_learning.md - Custom-stream implementation notes:
README_custom_streams.md