
Preshuffle fp8 4wave main #520

Open

coderfeli wants to merge 4 commits into main from preshuffle-fp8-4wave-main

Conversation

@coderfeli
Collaborator

Preshuffle fp8 4wave main: no perf regression on 8k^3, and better than the original preshuffle_gemm on the 256x256x256 tile.

coderfeli and others added 4 commits May 13, 2026 10:58
Adds opt-in ``a_preshuffled`` / ``b_preshuffled`` compile flags to
``compile_fp8_gemm``. When enabled, the matching operand is read
from a K-major outermost DRAM layout
(``(K//128, rows//16, 16, 128)``) where every K-atom slab of
``rows * 128`` bytes is contiguous; the 8 row_parts that one
wave-step touches span only 16 KB (8 x 16 rows x 128 B) instead of
the row-major ``8 * K`` bytes. LDS layout, MFMA fragment shape, and
wave assignment are unchanged.

Host side:
* ``preshuffle_a(A)`` / ``preshuffle_b(B_T)`` torch helpers perform
  the offline permute via
  ``T.reshape(rows//16, 16, K//128, 128).permute(2, 0, 1, 3).contiguous()``
  (sketched below).
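
A minimal runnable sketch of that helper (the name and the transform
come from this commit; the example shape and the uint8 stand-in for
fp8 bytes are illustrative):

    import torch

    def preshuffle_b(B_T: torch.Tensor) -> torch.Tensor:
        # (rows, K) row-major -> (K//128, rows//16, 16, 128), making
        # every K-atom slab of rows * 128 bytes contiguous in DRAM.
        rows, K = B_T.shape
        assert rows % 16 == 0 and K % 128 == 0
        return (B_T.reshape(rows // 16, 16, K // 128, 128)
                   .permute(2, 0, 1, 3)
                   .contiguous())

    # e.g. a 256 x 512 operand (uint8 as a 1-byte fp8 stand-in):
    B_T = torch.randint(0, 256, (256, 512), dtype=torch.uint8)
    assert preshuffle_b(B_T).shape == (4, 16, 16, 128)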

Kernel side:
* ``_compute_global_swizzle`` takes a ``preshuffled`` flag; when set
  it emits the K-major slab linearisation
  ``(r // 16) * 2048 + (r % 16) * 128 + c`` instead of ``r * K + c``.
* Tile-level base becomes ``tile_row * BLOCK_K`` (the in-slab
  offset) and per-K-iter step becomes ``rows * BLOCK_K`` (next K
  slab), both fed through the unchanged ``soffset`` path; the
  arithmetic is checked in the sketch after this list.
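
A quick host-side consistency check of that addressing (a sketch, not
kernel code; it assumes 1-byte fp8 elements and ``BLOCK_K = 128``,
both implied above):

    import torch

    rows, K = 64, 256
    T = torch.arange(rows * K, dtype=torch.int32).reshape(rows, K)
    # the same permute the host helper performs
    flat = (T.reshape(rows // 16, 16, K // 128, 128)
             .permute(2, 0, 1, 3).contiguous().reshape(-1))

    for r in range(rows):
        for k in range(K):
            in_slab = (r // 16) * 2048 + (r % 16) * 128 + (k % 128)
            slab_base = (k // 128) * rows * 128   # per-K-iter step
            assert flat[slab_base + in_slab] == T[r, k]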

Test side:
* ``test_fp8_gemm_4wave`` is parametrised across
  ``[rowmajor, preshuffle]``, doubling the test matrix (sketched
  after this list).
* CLI gets ``--preshuffle_a`` / ``--preshuffle_b`` flags; the
  default iteration count is bumped to 10 warmup + 100 iters so
  the BLOCK=256 shapes reach steady state for stable perf.
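
A hypothetical sketch of that parametrisation (the test name and the
layout labels come from this commit; the shape list mirrors the perf
table below, and the test body is elided):

    import pytest

    SHAPES = [(512, 2112, 7168), (5120, 5120, 8320),
              (8192, 8192, 8192), (9728, 8192, 8320)]

    @pytest.mark.parametrize("m,n,k", SHAPES)
    @pytest.mark.parametrize("layout", ["rowmajor", "preshuffle"])
    def test_fp8_gemm_4wave(m, n, k, layout):
        b_preshuffled = (layout == "preshuffle")
        ...  # compile, run, compare against a reference GEMM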

Perf (3 trials, 10 warmup + 100 iters):
  512  x 2112 x 7168 (BLOCK=64):  549  -> 588  TFLOPS  (+7.1%)
  5120 x 5120 x 8320:            1937 -> 1967 TFLOPS  (+1.5%)
  8192 x 8192 x 8192:            2387 ~  2378 TFLOPS  (95.1% peak)
  9728 x 8192 x 8320:            2362 ~  2357 TFLOPS  (94.3% peak)

Preshuffle is strictly non-regressing across all parametrised
shapes; gains accumulate on small / memory-bound shapes, while large
compute-bound shapes are already at ~95% of MI355X peak with the
row-major layout, so no further DRAM-side headroom is left.

Co-authored-by: Cursor <cursoragent@cursor.com>
Empirical sweep on all four parametrised shapes showed A-side
preshuffle adding zero measurable benefit beyond B-side:

  8192 x 8192 x 8192:  baseline 2372 / B-only 2364 / A-only 2381 /
                       A+B 2374 TFLOPS -- all within ~0.5% std

That matches the physics: 8k^3 sits at ~95% MI355X fp8 peak, with
the remaining gap coming from LDS-side barriers, not DRAM access.
Smaller shapes (512 / 5120) already saw the full gain from B-only
(+7.1% / +1.5%), so the second, symmetric path only added
kernel-parameter and host-helper surface area for no return.

Removed:
  * ``a_preshuffled`` flag on ``compile_fp8_gemm``
  * ``preshuffle_a`` host helper
  * ``--preshuffle_a`` CLI flag on the test
  * The "_preshuffle_2d shared" plumbing -- B uses an inline torch
    reshape/permute now

The remaining surface is the minimal one: ``b_preshuffled=True``
plus ``preshuffle_b(B_T)``. Pytest is parametrised as
``[rowmajor, preshuffle_b]`` (4 shapes x 2 = 8 cases).

Co-authored-by: Cursor <cursoragent@cursor.com>