Skip to content

feat(losses): split batch-invariant frequency-minimality out of imp-min#899

Open
ocg-goodfire wants to merge 1 commit into
feature/jax-full32L-portfrom
feature/stable-imp-min
Open

feat(losses): split batch-invariant frequency-minimality out of imp-min#899
ocg-goodfire wants to merge 1 commit into
feature/jax-full32L-portfrom
feature/stable-imp-min

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

Description

Ports #543 (torch) to the JAX trainer: splits the frequency-weighted log2 term out of the rolled importance-minimality loss into a standalone, batch-invariant frequency penalty. Per component c, with f_c the per-token firing frequency over the global batch (f_c = Σ_{B·T tokens}(ci_c+eps)^p / B·T):

  • importance-minimality → the bare mean: L_imp = Σ_c f_c
  • frequency-minimality (split out) → L_freq = Σ_c f_c · log2(1 + a'·f_c), where a' = reference_token_count

The old rolled form was Σ_c f_c·(1 + beta·log2(1 + B·T·f_c)) — the B·T was implicit inside the log2, so the penalty's curvature drifted with batch size. Making the normalizer the explicit a' knob decouples the two sparsity pressures and removes the batch coupling. total = imp_coeff·L_imp + freq_coeff·L_freq with independent coefficients.

This is the deferred "token-count reparameterization" of imp-min (see the imp-min scaling notes): the old beta/entropy term's log2(N) coupling was a per-token-batch artifact; reference_token_count removes it.

Tied form (single config, two outputs)

Unlike the torch PR's two separate Metric classes, the JAX trainer computes both terms from one (c+eps)^p pass. So beta is replaced by a nested, optional frequency: (coeff, reference_token_count) | None on both imp-min configs (ImportanceMinimalityLoss + SmoothL0ImportanceMinimalityLoss), reusing the shared per-component sums. A separate flat LossTerm was rejected: a standalone frequency term has no penalty ψ of its own and would either refetch the sums (a second full pass over CI at 8B) or reach into the imp term's intermediates, breaking the self-contained-term model that replaced LossSpec.

- type: ImportanceMinimalityLoss     # (or SmoothL0ImportanceMinimalityLoss)
  coeff: 5.0e-06
  pnorm: 2.0
  p_anneal_final_p: 0.4
  eps: 1.0e-06
  frequency:                         # replaces `beta: 0.2`; omit entirely for old beta:0
    coeff: 1.0e-06                    # = old imp.coeff * beta
    reference_token_count: 65536      # global B*T  → reproduces the old behavior exactly

Coefficient translation (for migrated configs)

freq.coeff = old imp.coeff · old beta, imp.coeff unchanged, reference_token_count = global B·T = pd.batch_size · max_seq_len. beta: 0 configs just drop beta (no frequency block).

Motivation and Context

The implicit-B·T coupling meant frequency-penalty strength silently rescaled with batch size. Splitting + normalizing makes the two sparsity pressures independently tunable and batch-invariant.

How Has This Been Tested?

  • make type clean.
  • Full non-slow suite: 449 passed, 5 skipped, 11 xfailed, 0 failures.
  • New param_decomp/tests/test_frequency_minimality.py: closed-form, batch-invariance, f=0 → 0, a'=B·T reproduces the old rolled value, lp independent of a', dispatch (frequency Nonefreq=0).
  • Equivalence goldens preserved WITHOUT regen: a' = n_positions reproduces the old entropy bit-for-bit, so test_imp_min_bf16_seam / test_imp_min_global_reduction / the tests/equivalence/ harness just pass reference_token_count = n_positions. The smooth-L0 + global-reduction tests also pass under the 4-device sim.
  • All 23 in-repo YAMLs validated through the pydantic config union.

Code changes

  • Core: losses.py (_imp_min_terms → (lp, freq) with reference_token_count; importance_minimality_terms / smooth_l0_importance_minimality_terms / imp_min_terms signatures), configs.py (FrequencyMinimalityConfig, beta dropped from both imp configs), train.py (imp_coeff·lp + freq_coeff·freq; new freq metric), run.py (train/loss/FrequencyMinimalityLoss).
  • SPEC.md: amended S7/S8 + new S8′, rewrote the imp-min pseudocode (imp_min_terms), the config example, and the L_freq rationale. (Per the repo's "one rule," the SPEC change wants @ocg-goodfire's deliberate sign-off.)
  • Migrations: all in-repo run/experiment YAMLs + Python test/experiment config constructions. invariance_check keeps frequency active (it guards the global-reduction / Jensen path under the device-count invariance check); structural smoke tests (checkpoint, llama8b, simple-mlp, io, no-bake) migrate to frequency=None.

Notes / decisions made under ambiguity

  • Base branch: targets feature/jax-full32L-port (the merge of feature/jax into the full-32L work, which carries the flat-LossTerm refactor + the smooth-L0 penalty this builds on). Not feature/jax, so the diff stays imp-min-only.
  • Duplicated frequency field on both imp configs mirrors how they already duplicated beta; a shared base would dedupe it but the existing code chose duplication — left consistent.
  • An independent pnorm/gamma/eps for the frequency term is intentionally not offered: the two terms share one pass, so it's optionality that doesn't exist in practice. A future second-pass variant could add it if ever wanted.

Does this PR introduce a breaking change?

Yes — ImportanceMinimalityLoss / SmoothL0ImportanceMinimalityLoss no longer accept beta. No legacy fallback (house style); all in-repo configs are migrated here. Stored/snapshotted configs of running runs are unaffected.

🤖 Generated with Claude Code

Ports #543 to the JAX trainer. The rolled imp-min `lp + beta·entropy` becomes two
independently-coefficiented terms computed from one shared `(c+eps)^p` pass:

  imp  = Σ_c f_c                       (imp.coeff)
  freq = Σ_c f_c · log2(1 + a'·f_c)     (freq.coeff), a' = reference_token_count

The frequency normalizer is now EXPLICIT (`reference_token_count`) instead of the
implicit global `B·T`, so the penalty's curvature is invariant to batch size at a
fixed firing rate — batch and frequency-penalty strength become independently
tunable. Setting `a' = B·T` reproduces the old behavior exactly; coefficients
transfer as `freq.coeff = old imp.coeff · beta`.

Tied form: nested `frequency: (coeff, reference_token_count) | None` on BOTH imp
configs (`ImportanceMinimalityLoss` + `SmoothL0`), reusing the shared per-component
sums in one pass — not a separate flat `LossTerm` (which would refetch the sums and
break the self-contained-term model). `beta` is removed.

SPEC S7/S8 amended + new S8'. Migrations: all in-repo YAMLs + test/experiment
configs (`freq.coeff = imp.coeff·beta`, `reference_token_count = global B·T`;
`beta:0` → no freq block). Equivalence goldens preserved WITHOUT regen
(`a' = n_positions` reproduces the old `entropy` bit-for-bit).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01JDynbo3BeS7AEwu8iyje6F
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant