feat(losses): split batch-invariant frequency-minimality out of imp-min #899
Open
ocg-goodfire wants to merge 1 commit into
Ports #543 to the JAX trainer. The rolled imp-min `lp + beta·entropy` becomes two independently-coefficiented terms computed from one shared `(c+eps)^p` pass:

    imp  = Σ_c f_c                      (imp.coeff)
    freq = Σ_c f_c · log2(1 + a'·f_c)   (freq.coeff), a' = reference_token_count

The frequency normalizer is now EXPLICIT (`reference_token_count`) instead of the implicit global `B·T`, so the penalty's curvature is invariant to batch size at a fixed firing rate: batch and frequency-penalty strength become independently tunable. Setting `a' = B·T` reproduces the old behavior exactly; coefficients transfer as `freq.coeff = old imp.coeff · beta`.

Tied form: nested `frequency: (coeff, reference_token_count) | None` on BOTH imp configs (`ImportanceMinimalityLoss` + `SmoothL0`), reusing the shared per-component sums in one pass, not a separate flat `LossTerm` (which would refetch the sums and break the self-contained-term model). `beta` is removed. SPEC S7/S8 amended + new S8'.

Migrations: all in-repo YAMLs + test/experiment configs (`freq.coeff = imp.coeff·beta`, `reference_token_count = global B·T`; `beta:0` → no freq block). Equivalence goldens preserved WITHOUT regen (`a' = n_positions` reproduces the old `entropy` bit-for-bit).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01JDynbo3BeS7AEwu8iyje6F
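The two terms named in the commit message can be illustrated with a minimal JAX sketch. This is not the repo's `_imp_min_terms`: the function names `imp_and_freq_terms` and `old_rolled_loss`, the `(n_tokens, n_components)` array layout, and the single-device reduction are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def imp_and_freq_terms(ci, pnorm, eps, reference_token_count):
    """Return (imp, freq) from one shared (ci + eps)^p pass.

    ci is assumed to have shape (n_tokens, n_components); this layout is
    hypothetical and chosen only to keep the sketch small.
    """
    n_tokens = ci.shape[0]
    # Shared pass: per-component firing frequency f_c over all tokens.
    f = jnp.sum((ci + eps) ** pnorm, axis=0) / n_tokens
    imp = jnp.sum(f)
    # a' = reference_token_count is explicit, so it does not change with n_tokens.
    freq = jnp.sum(f * jnp.log2(1.0 + reference_token_count * f))
    return imp, freq


def old_rolled_loss(ci, pnorm, eps, beta):
    """Old rolled form: sum_c f_c * (1 + beta * log2(1 + B*T * f_c))."""
    n_tokens = ci.shape[0]
    f = jnp.sum((ci + eps) ** pnorm, axis=0) / n_tokens
    return jnp.sum(f * (1.0 + beta * jnp.log2(1.0 + n_tokens * f)))


# Illustrative data: 64 tokens, 8 components (values are arbitrary).
ci = jax.random.uniform(jax.random.PRNGKey(0), (64, 8))
pnorm, eps, beta = 2.0, 1e-12, 0.25

# Setting a' = B*T recovers the old rolled value: lp + beta * freq.
imp, freq = imp_and_freq_terms(ci, pnorm, eps, reference_token_count=ci.shape[0])
rolled = old_rolled_loss(ci, pnorm, eps, beta)

# Doubling the batch at the same firing rate leaves freq unchanged for fixed a'.
ci_doubled = jnp.tile(ci, (2, 1))
_, freq_small = imp_and_freq_terms(ci, pnorm, eps, reference_token_count=4096)
_, freq_large = imp_and_freq_terms(ci_doubled, pnorm, eps, reference_token_count=4096)
```

Under these assumptions, `imp + beta * freq` matches `rolled` up to floating-point tolerance, and `freq_small` matches `freq_large`, which is the batch-invariance property the commit message describes.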
Description
Ports #543 (torch) to the JAX trainer: splits the frequency-weighted `log2` term out of the rolled importance-minimality loss into a standalone, batch-invariant frequency penalty. Per component `c`, with `f_c` the per-token firing frequency over the global batch (`f_c = Σ_{B·T tokens} (ci_c+eps)^p / (B·T)`):

- `L_imp = Σ_c f_c`
- `L_freq = Σ_c f_c · log2(1 + a'·f_c)`, where `a' = reference_token_count`

The old rolled form was `Σ_c f_c·(1 + beta·log2(1 + B·T·f_c))`: the `B·T` was implicit inside the `log2`, so the penalty's curvature drifted with batch size. Making the normalizer the explicit `a'` knob decouples the two sparsity pressures and removes the batch coupling. `total = imp_coeff·L_imp + freq_coeff·L_freq` with independent coefficients.

This is the deferred "token-count reparameterization" of imp-min (see the imp-min scaling notes): the old `beta`/entropy term's `log2(N)` coupling was a per-token-batch artifact; `reference_token_count` removes it.

Tied form (single config, two outputs)
Unlike the torch PR's two separate `Metric` classes, the JAX trainer computes both terms from one `(c+eps)^p` pass. So `beta` is replaced by a nested, optional `frequency: (coeff, reference_token_count) | None` on both imp-min configs (`ImportanceMinimalityLoss` + `SmoothL0ImportanceMinimalityLoss`), reusing the shared per-component sums. A separate flat `LossTerm` was rejected: a standalone frequency term has no penalty `ψ` of its own and would either refetch the sums (a second full pass over CI at 8B) or reach into the imp term's intermediates, breaking the self-contained-term model that replaced `LossSpec`.

Coefficient translation (for migrated configs)
`freq.coeff = old imp.coeff · old beta`, `imp.coeff` unchanged, `reference_token_count = global B·T = pd.batch_size · max_seq_len`. `beta: 0` configs just drop `beta` (no `frequency` block).

Motivation and Context
The implicit-`B·T` coupling meant frequency-penalty strength silently rescaled with batch size. Splitting + normalizing makes the two sparsity pressures independently tunable and batch-invariant.

How Has This Been Tested?
- `make type` clean.
- `param_decomp/tests/test_frequency_minimality.py`: closed-form, batch-invariance, `f=0 → 0`, `a'=B·T` reproduces the old rolled value, `lp` independent of `a'`, dispatch (`frequency` None → `freq=0`).
- `a' = n_positions` reproduces the old `entropy` bit-for-bit, so `test_imp_min_bf16_seam` / `test_imp_min_global_reduction` / the `tests/equivalence/` harness just pass `reference_token_count = n_positions`. The smooth-L0 + global-reduction tests also pass under the 4-device sim.

Code changes
- `losses.py` (`_imp_min_terms → (lp, freq)` with `reference_token_count`; `importance_minimality_terms` / `smooth_l0_importance_minimality_terms` / `imp_min_terms` signatures), `configs.py` (`FrequencyMinimalityConfig`, `beta` dropped from both imp configs), `train.py` (`imp_coeff·lp + freq_coeff·freq`; new `freq` metric), `run.py` (`train/loss/FrequencyMinimalityLoss`).
- […] `imp_min_terms`), the config example, and the `L_freq` rationale. (Per the repo's "one rule," the SPEC change wants @ocg-goodfire's deliberate sign-off.)
- `invariance_check` keeps `frequency` active (it guards the global-reduction / Jensen path under the device-count invariance check); structural smoke tests (checkpoint, llama8b, simple-mlp, io, no-bake) migrate to `frequency=None`.

Notes / decisions made under ambiguity
- […] `feature/jax-full32L-port` (the merge of `feature/jax` into the full-32L work, which carries the flat-`LossTerm` refactor + the smooth-L0 penalty this builds on). Not `feature/jax`, so the diff stays imp-min-only.
- […] `frequency` field on both imp configs mirrors how they already duplicated `beta`; a shared base would dedupe it but the existing code chose duplication; left consistent.
- […] `pnorm`/`gamma`/`eps` for the frequency term is intentionally not offered: the two terms share one pass, so it's optionality that doesn't exist in practice. A future second-pass variant could add it if ever wanted.

Does this PR introduce a breaking change?
Yes: `ImportanceMinimalityLoss` / `SmoothL0ImportanceMinimalityLoss` no longer accept `beta`. No legacy fallback (house style); all in-repo configs are migrated here. Stored/snapshotted configs of running runs are unaffected.

🤖 Generated with Claude Code
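For configs that live outside the repo, the coefficient-translation rule stated in this description could be applied along the following lines. This is a hedged sketch only: the helper name `migrate_imp_min_config` and the flat-dict layout with `coeff`, `beta`, and `frequency` keys are assumptions for illustration and may not match the real YAML schema.

```python
from typing import Any


def migrate_imp_min_config(
    old: dict[str, Any], batch_size: int, max_seq_len: int
) -> dict[str, Any]:
    """Translate a legacy imp-min config dict (with `beta`) to the split form.

    Hypothetical layout: `old` holds `coeff` (the imp coefficient) and `beta`.
    """
    # imp.coeff and every other field carry over unchanged; beta is removed.
    new = {key: value for key, value in old.items() if key != "beta"}
    beta = old.get("beta", 0)
    if beta == 0:
        # beta: 0 configs just drop beta and get no frequency block.
        return new
    new["frequency"] = {
        # freq.coeff = old imp.coeff * old beta
        "coeff": old["coeff"] * beta,
        # reference_token_count = global B*T = batch_size * max_seq_len
        "reference_token_count": batch_size * max_seq_len,
    }
    return new


# Example with arbitrary values.
legacy = {"coeff": 0.5, "beta": 0.25, "pnorm": 2.0}
migrated = migrate_imp_min_config(legacy, batch_size=8, max_seq_len=512)
legacy_no_beta = {"coeff": 0.5, "beta": 0, "pnorm": 2.0}
migrated_no_beta = migrate_imp_min_config(legacy_no_beta, batch_size=8, max_seq_len=512)
```

Pinning `reference_token_count` to the old global `B·T` is what keeps a migrated run numerically equivalent to its `beta` predecessor; after migration the value can be held fixed while the batch size changes.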