Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli — 1.1125 BPB (3-seed mean)#1105
abaybektursun wants to merge 3 commits into openai:main
… Brotli — val_bpb 1.1125 (3-seed mean)

Seed 314: 1.1123 BPB / 1.87802 nats, 14.52 MB, 6844 steps, 87.7 ms/step
Seed 999: 1.1124 BPB / 1.87821 nats, 14.52 MB, 6846 steps, 87.7 ms/step
Seed 1337: 1.1129 BPB / 1.87910 nats, 14.53 MB, 6828 steps, 87.7 ms/step

Delta vs merged SOTA (our PR 1019): −0.00215 nats (−0.0013 BPB).
Delta vs prior leaderboard (our PR 549): −0.01158 nats. Welch's t = −17.63, p < 0.01.

Changes from PR 1019:
1. Fused Triton TMA forward + CUTLASS EVT backward MLP kernels
2. Pre-computed activation gradient (branch-free backward)
3. MLP 3.5× (1792 hidden dim, motivated by SVD analysis)
4. Hessian-based mixed int5/int6 quantization (motivated by quant sensitivity)
5. Brotli-11 compression (−581 KB vs LZMA-9)
6. LR floor 0.05
7. Memmap multi-shard data pipeline (PR 726)

Negative: Turbo-Muon +0.0018 BPB worse at scale, reverted to NS5.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abaybektursun - this is a fantastic write-up! Congrats on the SLOT improvement. If you need to free up even more room, you should check out the shrink.py script I used in PR 1089. I was able to shrink the train_gpt.py file by ~100KB. That might let you reduce pruning and/or promote one more group to int6.
Author: Ohhh, I think with newer PyTorch, performance and speed will be even better! I will try it when I can get my hands on 8×H100s.
…ng runs

- Replace train_gpt.py with version containing SLOT eval-time adaptation (forward_hidden + compute_logits + per-batch delta optimization)
- Fix hyperparameter defaults: MLP_MULT 3.0->3.5, WARMDOWN_ITERS 3500->4000, BIGRAM_VOCAB_SIZE 2048->3072, BIGRAM_DIM 128->112, LR_FLOOR 0.0->0.05, SLOT_ENABLED 0->1
- Update submission.json: 1.1125->1.1088 BPB, 1.8784->1.8722 nats (SLOT)
- Replace logs with SLOT run logs (3-seed: 314/999/1337)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
# Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli
SLOT optimizes a shared delta on all positions then scores those same tokens — position t's prediction is influenced by future tokens through the broadcast delta. Reverted to clean non-SLOT sliding-window eval.

Results: 1.1125 BPB (3-seed mean), 1.8784 nats.
Code: train_gpt_mlp35_mixed.py with fixed defaults.
SLOT results (1.1088 BPB) kept in PR description for reference only.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Results: val_bpb 1.1125 (3-seed mean) | 1.8784 nats | 8×H100 SXM | 600s | ~14.52 MB
Mixed quantization: 10 layers int6, 56 layers int5, no pruning needed.
Our merged PR 1019 (current SOTA): 1.88059 nats (1.1138 BPB). Delta: −0.00215 nats (−0.0013 BPB). Our PR 549 (prior SOTA): 1.89002 nats (1.1194 BPB). Delta vs PR 549: −0.01158 nats, Welch's t = −17.63, p < 0.01.
SLOT study (removed from submission — causality violation)
SLOT (Selective Logit Offset Tuning) optimizes a 512-dim delta vector at the last hidden layer using AdamW (lr=0.003, 5 steps) per sliding-window batch. It gave −0.0037 BPB (1.1125 → 1.1088), but violates causality: the delta has shape `[1,1,512]` and is optimized using targets at all positions, then applied to all positions — so position t's prediction is influenced by future tokens through the shared delta. Removed from submission code; results below are for reference only. Credit: PR 609 (saml212).
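The mechanics (and the causality violation) are easiest to see in code. Below is a minimal sketch of the SLOT eval step as described above; `compute_logits` and the function signature are our assumed interface, not the PR's actual code:

```python
import torch

def slot_eval_step(hidden, targets, compute_logits, steps=5, lr=3e-3):
    """Sketch of SLOT: fit one shared [1, 1, D] delta against ALL
    positions of a window, then score those same positions.
    The causality violation: delta is optimized using targets at
    every position t' (including t' > t), then broadcast-added at
    position t, so later tokens leak into earlier predictions."""
    delta = torch.zeros(1, 1, hidden.size(-1), requires_grad=True)
    opt = torch.optim.AdamW([delta], lr=lr)
    for _ in range(steps):
        logits = compute_logits(hidden + delta)   # delta broadcasts over all positions
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1), targets.flatten())
        opt.zero_grad()
        loss.backward()
        opt.step()
    return compute_logits(hidden + delta.detach())  # scores the same tokens
```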
Prior results: fused kernels + Brotli only (val_bpb 1.1138, 3-seed)
Delta vs PR 549: −0.00943 nats. Welch's t = −10.26, df ≈ 3.78, p < 0.01.
Throughput recovery
Our PR 1019 (now merged as SOTA) traded throughput for quality — full Hessian GPTQ and BigramHash 3072×112 added 3.3ms/step. Fused MLP kernels recover that regression. Mechanistic analysis of that model identified MLP as the capacity bottleneck, leading to MLP 3.5× (enabled by mixed quantization + Brotli headroom).
Changes vs our PR 1019
1. Fused MLP Kernels: Triton TMA Forward + CUTLASS EVT Backward
Forward (Triton TMA): Fuses `F.linear(x, up_w) → LeakyReLU(0.5) → square` into a single kernel. The 302 MB intermediate never touches HBM.

Backward (CUTLASS EVT): Fuses `(go @ down_w.T) * act_grad` into a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers — eliminating one 302 MB write + read per layer.

Key design insight — pre-computed activation gradient: We store the activation gradient in the forward pass instead of the pre-activation. The identity `post = 0.5 · act_grad · pre` holds for both signs because in each branch `act = slope · pre` with slope ∈ {1, 0.5}, so `act_grad = 2 · slope · act = 2 · slope² · pre` and `0.5 · act_grad · pre = slope² · pre² = act² = post`. This eliminates all branching from the backward, reducing the CUTLASS EVT epilogue to a trivial 3-node tree: `Sm90EVT<multiplies, AccFetch, AuxLoad>`. No conditionals in the kernel. CUTLASS EVT is a hard dependency — no silent fallback.
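A plain-PyTorch reference of the activation semantics and the branch-free identity, useful for unit-testing the fused kernels (the function name is ours; the math follows the description above):

```python
import torch

def fused_act_reference(pre):
    """Unfused reference for the fused epilogue: LeakyReLU(0.5)
    followed by square, plus the activation gradient
    d(post)/d(pre) = 2 * slope * act that the kernel saves for
    backward instead of the pre-activation."""
    act = torch.nn.functional.leaky_relu(pre, negative_slope=0.5)
    post = act * act
    slope = torch.where(pre > 0, torch.ones_like(pre), torch.full_like(pre, 0.5))
    act_grad = 2.0 * slope * act
    return post, act_grad

# Branch-free identity used by the backward: in both branches
# act = slope * pre, so 0.5 * act_grad * pre = slope**2 * pre**2 = post.
pre = torch.randn(1024)
post, act_grad = fused_act_reference(pre)
assert torch.allclose(post, 0.5 * act_grad * pre)
```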
Kernel benchmarks + incremental deltas (2×H100)
Per-layer kernel timing:
CUTLASS vs Triton: +0.032 ms/layer, +0.347 ms/step kernel-level.
End-to-end training (35 steps, seed=42):
Kernel-level 0.347ms translates to 0.43ms end-to-end (cache/scheduling interactions).
8×H100: 86.7ms (our PR 1019, unfused) → 83.5ms (this PR) = −3.2ms/step (−3.7%).
Step-time profile — where all 313ms goes (2×H100, Nsight)
Why surgical fusion, not a full-MLP `autograd.Function`: The 21.6% from torch.compile's cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) only exists because these ops are visible to the compiler. Wrapping the full MLP backward in an `autograd.Function` makes it opaque to Inductor — all backward GEMMs plus cross-layer fusion run in eager mode, 2.7× slower net (identified in our PR 670). We fuse only the forward and one backward GEMM+pointwise, preserving the compiler's scope.

Top individual kernels:
Wall-clock breakdown: forward+backward compute ~94%, NCCL ~1.6%, CPU overhead ~4.1%.
2. Brotli-11 Compression (replaces LZMA-9)
−581 KB (−5.9%) vs LZMA-9. Independently discovered; PR 1089 (mikeapedia) also uses Brotli.
3. Memmap Multi-Shard Data Pipeline + GPU Prefetch
Coprime-stride sampling, daemon thread, CUDA stream prefetch. Credit: DeepReinforce (PR 726).
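The coprime-stride idea can be sketched in a few lines. This is our illustrative reconstruction, not PR 726's code: names, the stride value, and the exact traversal are assumptions, and the daemon-thread + CUDA-stream prefetch is omitted:

```python
import numpy as np

def iter_windows(shard_paths, seq_len, stride=7919, seed=0):
    """Sketch of a memmap multi-shard pipeline: each shard is a flat
    uint16 token file opened with np.memmap (no RAM copy), and windows
    are visited in coprime-stride order -- stepping Z/nZ by a stride
    coprime to the window count is a full cycle, so every window
    appears exactly once per epoch in a scrambled sequence."""
    shards = [np.memmap(p, dtype=np.uint16, mode="r") for p in shard_paths]
    counts = [len(s) // seq_len for s in shards]
    total = sum(counts)
    assert np.gcd(total, stride) == 1, "stride must be coprime to #windows"
    start = int(np.random.default_rng(seed).integers(total))
    for k in range(total):
        g = (start + stride * k) % total      # coprime-stride permutation
        for s, c in zip(shards, counts):      # global index -> (shard, offset)
            if g < c:
                yield np.asarray(s[g * seq_len:(g + 1) * seq_len])
                break
            g -= c
```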
4. MLP 3.5× (1536 → 1792 hidden dim)
Motivated by mechanistic analysis: SVD analysis of our PR 1019 model showed MLP at 94.4% rank utilization (fully packed) while attention Q sat at 72.6% (spare capacity). The model was parameter-starved in MLP, not attention — so we made MLP wider.
Increases hidden dim from 3.0 × 512 = 1536 to 3.5 × 512 = 1792. Model goes from 27.07M to 29.95M params (+2.88M). At uniform int6, the 29.95M model compresses to 17.36 MB — 1.36 MB over the 16 MB limit. This is what makes mixed quantization (change 5) necessary.
Impact: −0.003 BPB from capacity, +13ms/step on 2×H100 (bigger GEMMs). Credit: PR 185 (dttdrv), PR 344 (aryanbhosale).
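As a back-of-envelope check of the +2.88M figure, assuming a plain up/down MLP (two weight matrices per block) and 11 transformer blocks — both assumptions ours, not stated above:

```python
# Parameter delta from widening the MLP hidden dim 1536 -> 1792,
# under the assumed 11-block, two-matrix-per-MLP configuration.
d_model = 512
widen = 1792 - 1536                 # +256 hidden units per block
per_block = 2 * d_model * widen     # up projection + down projection
total_delta = 11 * per_block
print(total_delta)                  # 2883584, i.e. ~ +2.88M parameters
```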
5. Mixed int5/int6 Quantization (Hessian-based)
Motivated by mechanistic analysis: Per-matrix quantization sensitivity showed MLP accounts for 80% of int6 quantization damage (MLP_down: +0.0039 BPB total, all Q matrices: +0.0003 BPB total — a 13× gap). Giving more bits to MLP is the optimal allocation.
Instead of uniform int6 for all layers, use int5 as default and promote the top 10 most sensitive layers to int6 based on Hessian trace ranking. Sensitivity = trace(H) where H = X^TX collected during GPTQ calibration. MLP projection layers in early blocks are most sensitive — they get int6; the remaining 56 layers get int5.
Uniform int5 loses ~0.019 BPB (catastrophic). Targeted Hessian-based allocation keeps quality loss under ~0.003 BPB while saving ~1.5 MB — exactly the headroom MLP 3.5× needs to fit under 16 MB. The wider MLP also made the model 3.6× less sensitive to quantization overall — information distributed across more dimensions means no single weight is load-bearing.
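The allocation logic reduces to a sort. A minimal sketch under our assumed names (the PR's actual GPTQ integration is more involved); note that trace(H) for H = XᵀX is just the sum of squared calibration activations, so H never needs to be formed:

```python
import numpy as np

def hessian_trace(X):
    """trace(X^T X) = sum of squared activations over calibration data."""
    return float(np.sum(X * X))

def allocate_bits(hessian_traces, n_int6=10):
    """Rank weight matrices by Hessian trace and promote the top-n
    most sensitive ones to int6; the rest default to int5.
    `hessian_traces` maps layer name -> trace(H)."""
    order = sorted(hessian_traces, key=hessian_traces.get, reverse=True)
    promoted = set(order[:n_int6])
    return {name: (6 if name in promoted else 5) for name in hessian_traces}
```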
Credit: mixed quant concept PR 76 (Will DePue), gradient-guided PR 332 (saml212), Hessian-based PR 1089 (mikeapedia).
6. LR Floor (0.05)
During warmdown, the learning rate normally decays to 0. With `lr_floor=0.05`, it stops at 5% of peak instead. This prevents the optimizer from stalling while quantization-sensitive weight distributions are still being refined at the end of training.

Impact: ~0.001 BPB. Credit: PR 130 (mohosy).
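The floor is a one-line clamp on the schedule. A sketch assuming a linear warmdown (the schedule shape and function signature are our assumptions):

```python
def lr_with_floor(step, peak_lr, warmdown_start, warmdown_iters, lr_floor=0.05):
    """Linear warmdown from peak_lr toward zero, clamped so the LR
    never drops below lr_floor * peak_lr."""
    if step < warmdown_start:
        return peak_lr
    frac = min(1.0, (step - warmdown_start) / warmdown_iters)
    return peak_lr * max(lr_floor, 1.0 - frac)
```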
Negative Results
Architecture
Calibration legality: AR self-generated (64 seqs × 2048 tokens, temp=0.8). No val data, no train data accessed during quantization. Same method as our PR 1019.
Setup & Reproduction
🤖 Generated with Claude Code