Preshuffle fp8 4wave main#520
Open
coderfeli wants to merge 4 commits into
Conversation
Adds opt-in ``a_preshuffled`` / ``b_preshuffled`` compile flags to ``compile_fp8_gemm``. When enabled, the matching operand is read from a K-major-outermost DRAM layout (``(K//128, rows//16, 16, 128)``) in which every K-atom slab of ``rows * 128`` bytes is contiguous; the 8 row_parts that one wave-step touches then span only 16 KB instead of the row-major ``8 * K`` bytes. LDS layout, MFMA fragment shape, and wave assignment are unchanged.

Host side:

* ``preshuffle_a(A)`` / ``preshuffle_b(B_T)`` torch helpers perform the offline permute via ``T.reshape(rows//16, 16, K//128, 128).permute(2, 0, 1, 3).contiguous()``.

Kernel side:

* ``_compute_global_swizzle`` takes a ``preshuffled`` flag; when set, it emits the K-major slab linearisation ``(r // 16) * 2048 + (r % 16) * 128 + c`` instead of ``r * K + c``.
* The tile-level base becomes ``tile_row * BLOCK_K`` (the in-slab offset) and the per-K-iter step becomes ``rows * BLOCK_K`` (next K slab), both fed through the unchanged ``soffset`` path.

Test side:

* ``test_fp8_gemm_4wave`` is parametrised across ``[rowmajor, preshuffle]``, doubling the test matrix.
* The CLI gets ``--preshuffle_a`` / ``--preshuffle_b`` flags; the default iteration count is bumped to 10 warmup + 100 iters so the BLOCK=256 shapes reach steady state for stable perf.

Perf (3 trials, 10 warmup + 100 iters):

    512 x 2112 x 7168 (BLOCK=64):  549 -> 588 TFLOPS (+7.1%)
    5120 x 5120 x 8320:           1937 -> 1967 TFLOPS (+1.5%)
    8192 x 8192 x 8192:           2387 ~ 2378 TFLOPS (95.1% peak)
    9728 x 8192 x 8320:           2362 ~ 2357 TFLOPS (94.3% peak)

Preshuffle is strictly non-regressing across all parametrised shapes. Gains accrue on small, memory-bound shapes; large compute-bound shapes are already at ~95% of MI355X peak with the row-major layout, so no further DRAM-side headroom is left.

Co-authored-by: Cursor <cursoragent@cursor.com>
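For concreteness, here is a minimal sketch of the host-side permute and the slab address math it implies. The helper body and the linearisation formula are quoted from the description above; the toy shape and the assertion loop are illustrative only:

```python
import torch

def preshuffle_b(b_t: torch.Tensor) -> torch.Tensor:
    """Offline permute of a row-major (rows, K) operand into the
    K-major-outermost (K//128, rows//16, 16, 128) DRAM layout.
    Assumes rows % 16 == 0 and K % 128 == 0."""
    rows, k = b_t.shape
    return (b_t.reshape(rows // 16, 16, k // 128, 128)
               .permute(2, 0, 1, 3)
               .contiguous())

# Check the kernel-side slab linearisation against the permute:
# element (r, c) of the row-major tensor lands at
#   (c // 128) * rows * 128           # K-slab base / per-K-iter step
#   + (r // 16) * 2048                # row_part within the slab (16 * 128)
#   + (r % 16) * 128 + (c % 128)      # offset inside the row_part
# (fp8 elements are 1 byte, so element offsets equal byte offsets;
# int32 is used here only so arange values survive the round trip).
rows, k = 64, 512
t = torch.arange(rows * k, dtype=torch.int32).reshape(rows, k)
flat = preshuffle_b(t).flatten()
for r, c in [(0, 0), (17, 129), (63, 511)]:
    off = (c // 128) * rows * 128 + (r // 16) * 2048 + (r % 16) * 128 + (c % 128)
    assert flat[off] == t[r, c]
```

This also recovers the headline number: one wave-step's 8 row_parts are 8 * 2048 B = 16 KB of contiguous DRAM, versus ``8 * K`` bytes scattered across rows in the row-major layout.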
An empirical sweep over all four parametrised shapes showed A-side preshuffle adding zero measurable benefit beyond B-side:

    8192 x 8192 x 8192: baseline 2372 / B-only 2364 / A-only 2381 / A+B 2374 TFLOPS -- all within ~0.5% std

That matches the physics: 8k^3 sits at ~95% of MI355X fp8 peak, with the remaining gap coming from LDS-side barriers, not DRAM access. The smaller shapes (512 / 5120) already saw the full gain from B-only (+7.1% / +1.5%), so the second, symmetric path only added kernel-parameter and host-helper surface for no return.
Removed:

* ``a_preshuffled`` flag on ``compile_fp8_gemm``
* ``preshuffle_a`` host helper
* ``--preshuffle_a`` CLI flag on the test
* the shared ``_preshuffle_2d`` plumbing -- B uses an inline torch reshape/permute now

The remaining surface is the minimal one: ``b_preshuffled=True`` plus ``preshuffle_b(B_T)``. Pytest is parametrised as ``[rowmajor, preshuffle_b]`` (4 shapes x 2 = 8 cases), as sketched below.
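A hedged sketch of that test matrix, assuming pytest's stacked parametrize: the four shapes come from the perf table in the first commit message, while the test body and kernel wiring are elided because they are not shown in this PR.

```python
import pytest

@pytest.mark.parametrize("layout", ["rowmajor", "preshuffle_b"])
@pytest.mark.parametrize("m,n,k", [
    (512, 2112, 7168),
    (5120, 5120, 8320),
    (8192, 8192, 8192),
    (9728, 8192, 8320),
])  # 4 shapes x 2 layouts = 8 cases
def test_fp8_gemm_4wave(m, n, k, layout):
    # The preshuffled layout needs rows % 16 == 0 and K % 128 == 0
    # (assuming B_T has shape (n, k)); the actual compile / launch /
    # compare steps are elided -- they are not part of this excerpt.
    if layout == "preshuffle_b":
        assert n % 16 == 0 and k % 128 == 0
```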
Co-authored-by: Cursor <cursoragent@cursor.com>
Preshuffle fp8 4wave main: no perf regression on 8k^3, and better than the original preshuffle_gemm on the 256x256x256 tile.