9 changes: 3 additions & 6 deletions AlphaBrain/training/continual_learning/__init__.py
@@ -1,9 +1,9 @@
"""Continual Learning module.

Sub-packages:
- algorithms/ — CL algorithms (ER / MIR / EWC / …) grouped by mechanism
- (rehearsal_based, regularization_based, dynamic_based)
- plus the :class:`CLAlgorithm` base and :class:`CLContext`.
+ algorithms/ — CL algorithms (ER / MIR / …) grouped by mechanism
+ (rehearsal_based, dynamic_based) plus the
+ :class:`CLAlgorithm` base and :class:`CLContext`.
datasets/ — Task sequences and per-task dataset filtering.

Top-level entry:
@@ -13,7 +13,6 @@
Re-exports for convenience (fully-qualified paths also work):
`ER` ← `algorithms.rehearsal_based.er.ER`
`MIR` ← `algorithms.rehearsal_based.mir.MIR`
- `EWC` ← `algorithms.regularization_based.ewc.EWC`
`CLAlgorithm` ← `algorithms.base.CLAlgorithm`
`CLContext` ← `algorithms.base.CLContext`
`build_cl_algorithm` ← `algorithms.build_cl_algorithm`
@@ -22,7 +21,6 @@
CLAlgorithm,
CLContext,
ER,
- EWC,
MIR,
build_cl_algorithm,
)
@@ -31,7 +29,6 @@
"CLAlgorithm",
"CLContext",
"ER",
- "EWC",
"MIR",
"build_cl_algorithm",
]
45 changes: 11 additions & 34 deletions AlphaBrain/training/continual_learning/algorithms/__init__.py
@@ -1,11 +1,9 @@
"""Continual-learning algorithms.

- Algorithms are grouped by their mechanism into three sub-packages:
+ Algorithms are grouped by their mechanism into sub-packages:

- :mod:`.rehearsal_based` — replay past-task samples
(ER, MIR, planned DER / A-GEM)
- - :mod:`.regularization_based` — penalise movement away from important
- parameters (EWC, planned SI / LwF)
- :mod:`.dynamic_based` — adapt the model architecture across tasks
(planned DWE / Weight Merge / PackNet)

@@ -19,19 +17,16 @@
----------
The concrete classes are re-exported at this package level for convenience::

- from AlphaBrain.training.continual_learning.algorithms import EWC, MIR, ER
+ from AlphaBrain.training.continual_learning.algorithms import MIR, ER

- Fully-qualified paths also work, e.g.
- ``algorithms.regularization_based.ewc.EWC`` or
- ``algorithms.rehearsal_based.er.ER``.
+ Fully-qualified paths also work, e.g. ``algorithms.rehearsal_based.er.ER``.
"""
from typing import Optional

from AlphaBrain.training.continual_learning.algorithms.base import (
CLAlgorithm,
CLContext,
)
- from AlphaBrain.training.continual_learning.algorithms.regularization_based import EWC
from AlphaBrain.training.continual_learning.algorithms.rehearsal_based import ER, MIR


@@ -50,17 +45,17 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]:
replay_batch_ratio: 0.3
balanced_sampling: false

- 2. Generic algorithm (EWC / RETAIN / DER / DWE / …)::
+ 2. Generic algorithm (ER / MIR / …)::

continual_learning:
algorithm:
- name: ewc
- lambda: 1.0e4
- gamma: 1.0
- lora_only: true
- fisher_num_batches: 50
- fisher_clip: 1.0e4
- grad_clip_per_sample: 100.0
+ name: mir
+ buffer_size_per_task: 500
+ replay_batch_ratio: 0.5
+ mir_refresh_interval: 50
+ mir_candidate_size: 16
+ mir_top_k: 8
+ mir_lora_only: true

Returns ``None`` when neither section is configured — the trainer then
runs a plain sequential baseline without CL interventions.
@@ -94,23 +89,6 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]:
balanced_sampling=algo_cfg.get("balanced_sampling", False),
seed=seed,
)
- if key == "ewc":
- # `lambda` is a Python keyword; accept either `lambda` or
- # `ewc_lambda` (OmegaConf tolerates the former as a dict key).
- ewc_lambda = algo_cfg.get("lambda", None)
- if ewc_lambda is None:
- ewc_lambda = algo_cfg.get("ewc_lambda", 1.0e4)
- excl = algo_cfg.get("exclude_name_substrings", None)
- return EWC(
- ewc_lambda=ewc_lambda,
- gamma=algo_cfg.get("gamma", 1.0),
- lora_only=algo_cfg.get("lora_only", True),
- fisher_num_batches=algo_cfg.get("fisher_num_batches", 50),
- fisher_clip=algo_cfg.get("fisher_clip", 1.0e4),
- grad_clip_per_sample=algo_cfg.get("grad_clip_per_sample", 100.0),
- fisher_save_dir=algo_cfg.get("fisher_save_dir", None),
- exclude_name_substrings=list(excl) if excl is not None else None,
- )
if key == "mir":
return MIR(
buffer_size_per_task=algo_cfg.get("buffer_size_per_task", 500),
@@ -134,7 +112,6 @@ def build_cl_algorithm(cfg, seed: int = 42) -> Optional[CLAlgorithm]:
"CLAlgorithm",
"CLContext",
"ER",
- "EWC",
"MIR",
"build_cl_algorithm",
]
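
Reviewer note: the config-to-constructor mapping described in the `build_cl_algorithm` docstring can be sketched as follows. This is a minimal, framework-free illustration and not the repository's code: `resolve_mir_kwargs` and `ASSUMED_MIR_DEFAULTS` are hypothetical names, a plain dict stands in for the OmegaConf node, and every default except `buffer_size_per_task=500` (which is visible in the diff) is an assumption copied from the example YAML.

```python
from typing import Any, Dict, Mapping, Optional

# Illustrative defaults. Only buffer_size_per_task=500 appears in the diff;
# the others are assumptions taken from the example YAML in the docstring.
ASSUMED_MIR_DEFAULTS: Dict[str, Any] = {
    "buffer_size_per_task": 500,
    "replay_batch_ratio": 0.5,
    "mir_refresh_interval": 50,
    "mir_candidate_size": 16,
    "mir_top_k": 8,
    "mir_lora_only": True,
}


def resolve_mir_kwargs(cfg: Mapping[str, Any]) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: turn a config mapping into MIR keyword arguments.

    Returns None when no algorithm section is configured, mirroring the
    documented fallback to a plain sequential baseline.
    """
    algo_cfg = cfg.get("continual_learning", {}).get("algorithm")
    if algo_cfg is None:
        return None
    if algo_cfg.get("name") != "mir":
        # Raising here is an assumption; the diff does not show how
        # unknown algorithm names are handled.
        raise ValueError(f"unsupported algorithm: {algo_cfg.get('name')!r}")
    # Per-key lookup with a fallback, the same pattern as algo_cfg.get(key, default).
    return {
        key: algo_cfg.get(key, default)
        for key, default in ASSUMED_MIR_DEFAULTS.items()
    }


cfg = {"continual_learning": {"algorithm": {"name": "mir", "mir_top_k": 4}}}
kwargs = resolve_mir_kwargs(cfg)
```

Keys present in the config override the fallback values, and an empty config yields `None`, so a caller can branch on "no CL algorithm" without a separate flag.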
16 changes: 6 additions & 10 deletions AlphaBrain/training/continual_learning/algorithms/base.py
@@ -1,6 +1,6 @@
"""Abstract base class for continual-learning algorithms.

- All CL algorithms (Experience Replay / EWC / DER / RETAIN / DWE / ...) implement
+ All CL algorithms (Experience Replay / DER / RETAIN / DWE / ...) implement
this interface. The continual trainer (`AlphaBrain.training.continual_learning.train`)
only talks to algorithms through this protocol, so new methods can be plugged
in without touching the training loop.
@@ -10,11 +10,9 @@
- `ER` (algorithms.rehearsal_based.er) — experience replay with reservoir
sampling (uniform or per-task balanced).
- `MIR` (algorithms.rehearsal_based.mir) — interference-aware replay.
- - `EWC` (algorithms.regularization_based.ewc) — Fisher-weighted L2 penalty.

Planned
-------
- - `EWC` Elastic Weight Consolidation (Kirkpatrick et al. 2017)
- `DER` Dark Experience Replay (Buzzega et al. 2020)
- `RETAIN` Weight Merging / model souping (Wortsman et al. 2022, variants)
- `DWE` Dynamic Weight Expansion (per-task adapters)
@@ -31,14 +29,13 @@
modify_batch(batch, task_id) — return the batch the model will consume
(ER / DER inject replay samples here).
compute_penalty(model) — return a scalar added to the task loss
- (EWC / SI regularizers).
+ (regularizer-based methods).

Task-level hooks (bracket each CL task):
on_task_start(context) — before training starts on a new task
(DWE expands the model here).
on_task_end(context) — after a task finishes training
- (ER populates buffer, EWC computes Fisher,
- RETAIN merges weights).
+ (ER populates buffer, RETAIN merges weights).

On resume, the trainer replays `on_task_end` for each completed task so
algorithms can rebuild state not captured in `state_dict`.
@@ -109,9 +106,9 @@ def modify_batch(self, batch: Any, task_id: int) -> Any:
def compute_penalty(self, model: Any) -> Optional[torch.Tensor]:
"""Return a scalar penalty to add to the task loss, or None.

- Typical uses:
- * EWC: λ · Σ F_i · (θ_i − θ*_i)²
- * SI: λ · Σ Ω_i · (θ_i − θ*_i)²
+ Typical uses (regularization-based methods):
+ * Fisher-weighted L2: λ · Σ F_i · (θ_i − θ*_i)²
+ * SI path integral: λ · Σ Ω_i · (θ_i − θ*_i)²

Default: None (no extra penalty).
"""
@@ -151,7 +148,6 @@ def on_task_end(self, context: CLContext) -> None:

Typical uses:
* ER: populate the buffer from `context.task_dataset`.
- * EWC: compute Fisher on `context.task_dataloader`, snapshot θ*.
* RETAIN: merge the newly-trained weights into the running model.

Default: no-op.
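
Reviewer note: since the `compute_penalty` docstring keeps the quadratic formula λ · Σ F_i · (θ_i − θ*_i)² while the concrete EWC class is removed, a small worked example of that formula may help future implementers. The sketch below is an assumption-laden illustration, not code from this repository: `quadratic_penalty` is a hypothetical function, and each parameter is reduced to a single float where a real implementation would hold one tensor per parameter.

```python
from typing import Dict, Optional


def quadratic_penalty(
    params: Dict[str, float],
    anchors: Dict[str, float],
    importance: Dict[str, float],
    lam: float,
) -> Optional[float]:
    """Return lam * sum_i importance_i * (theta_i - theta_star_i)^2, or None.

    params      current values theta_i
    anchors     values theta*_i snapshotted after the previous task
    importance  per-parameter weights (F_i for a Fisher-weighted penalty,
                Omega_i for an SI-style penalty)
    """
    if not anchors:
        # No finished task yet, so there is nothing to anchor to. This matches
        # the hook's documented default of returning None.
        return None
    total = 0.0
    for name, theta_star in anchors.items():
        total += importance[name] * (params[name] - theta_star) ** 2
    return lam * total


penalty = quadratic_penalty(
    params={"w": 1.5, "b": 0.0},
    anchors={"w": 1.0, "b": 0.0},
    importance={"w": 2.0, "b": 3.0},
    lam=10.0,
)
```

Here only `w` has moved from its anchor, so the penalty is 10 · 2 · (0.5)² = 5.0; `b` contributes nothing despite its larger importance weight.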

This file was deleted.