From f9aa522443a0dfe7b10ca41bec5af642937677be Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 10:58:24 +0000 Subject: [PATCH 1/8] feat(fp8_gemm_4wave): support K-major A/B preshuffle layouts Adds opt-in ``a_preshuffled`` / ``b_preshuffled`` compile flags to ``compile_fp8_gemm``. When enabled, the matching operand is read from a K-major outermost DRAM layout (``(K//128, rows//16, 16, 128)``) where every K-atom slab of ``rows * 128`` bytes is contiguous; the 8 row_parts that one wave-step touches span only 16 KB instead of the row-major ``8 * K`` bytes. LDS layout, MFMA fragment shape, and wave assignment are unchanged. Host side: * ``preshuffle_a(A)`` / ``preshuffle_b(B_T)`` torch helpers perform the offline permute via ``T.reshape(rows//16, 16, K//128, 128).permute(2, 0, 1, 3).contiguous()``. Kernel side: * ``_compute_global_swizzle`` takes a ``preshuffled`` flag; when set it emits the K-major slab linearisation ``(r // 16) * 2048 + (r % 16) * 128 + c`` instead of ``r * K + c``. * Tile-level base becomes ``tile_row * BLOCK_K`` (the in-slab offset) and per-K-iter step becomes ``rows * BLOCK_K`` (next K slab), both fed through the unchanged ``soffset`` path. Test side: * ``test_fp8_gemm_4wave`` is parametrised across ``[rowmajor, preshuffle]``, doubling the test matrix. * CLI gets ``--preshuffle_a`` / ``--preshuffle_b`` flags; the default iteration count is bumped to 10 warmup + 100 iters so the BLOCK=256 shapes reach steady state for stable perf. Perf (3 trials, 10 warmup + 100 iters): 512 x 2112 x 7168 (BLOCK=64): 549 -> 588 TFLOPS (+7.1%) 5120 x 5120 x 8320: 1937 -> 1967 TFLOPS (+1.5%) 8192 x 8192 x 8192: 2387 ~ 2378 TFLOPS (95.1% peak) 9728 x 8192 x 8320: 2362 ~ 2357 TFLOPS (94.3% peak) Preshuffle is non-regressing (within run-to-run noise) across all parametrised shapes: gains accumulate on small / memory-bound shapes, while large compute-bound shapes are already at ~95% MI355X peak with the row-major layout, so no further DRAM-side headroom is left. 
Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 113 +++++++++++++++++++++------ tests/kernels/test_fp8_gemm_4wave.py | 60 +++++++++++--- 2 files changed, 140 insertions(+), 33 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index d9f4fcd0..65e7685b 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -12,6 +12,14 @@ the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. + +Optional K-major operand preshuffle. ``compile_fp8_gemm(..., a_preshuffled=True)`` +and ``b_preshuffled=True`` switch the per-thread DRAM access for that +operand to a K-major outermost layout (each K atom slab of +``rows * 128`` bytes contiguous). The host-side helpers +:func:`preshuffle_a` / :func:`preshuffle_b` perform the offline +permute. LDS layout, MFMA fragment shape, and wave assignment are +unchanged. """ import flydsl.compiler as flyc @@ -27,6 +35,31 @@ from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +def _preshuffle_2d(t): + """K-major outermost preshuffle for a row-major ``(rows, K)`` tensor. + + Output shape ``(K//128, rows//16, 16, 128)`` is K-major outermost so + every 128-K-col atom slab of ``rows * 128`` bytes lives contiguously. + """ + rows, k = t.shape[-2:] + assert rows % 16 == 0 and k % 128 == 0, f"need rows%16==0 and K%128==0, got rows={rows} K={k}" + return t.reshape(rows // 16, 16, k // 128, 128).permute(2, 0, 1, 3).contiguous() + + +def preshuffle_a(a): + """Permute row-major ``A`` ``(M, K)`` into the layout consumed when + ``compile_fp8_gemm(..., a_preshuffled=True)``. K-major outermost + (each K atom slab = ``M * 128`` byte contig).""" + return _preshuffle_2d(a) + + +def preshuffle_b(b_t): + """Permute row-major ``B_T`` ``(N, K)`` into the layout consumed when + ``compile_fp8_gemm(..., b_preshuffled=True)``. 
K-major outermost + (each K atom slab = ``N * 128`` byte contig).""" + return _preshuffle_2d(b_t) + + def _divmod(a, b): return (a // b, a % b) @@ -59,7 +92,17 @@ def _xcd_swizzle(num_pid_m, num_pid_n): return (pid_m, pid_n) -def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True): +def compile_fp8_gemm( + *, + M: int, + N: int, + K: int, + BLOCK_M: int = 256, + BLOCK_N: int = 256, + use_xcd_remap: bool = True, + a_preshuffled: bool = False, + b_preshuffled: bool = False, +): # MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64. BLOCK_K = 128 LDS_BLOCK_M = BLOCK_M // 2 @@ -134,10 +177,25 @@ def kernel_gemm( wave_i = wave_id // 2 wave_j = wave_id % 2 - A0_gl_offset = (tile_i * BLOCK_M) * K - A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K - B0_gl_offset = (tile_j * BLOCK_N) * K - B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K + # K-major preshuffle: tile_row's slot in k_outer=0 slab is + # ``(tile_row // 16) * 2048 = tile_row * BLOCK_K``; per-K-iter + # step jumps to next ``rows * BLOCK_K`` byte slab. 
+ if const_expr(a_preshuffled): + A0_gl_offset = (tile_i * BLOCK_M) * BLOCK_K + A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * BLOCK_K + A_K_STEP = M * BLOCK_K + else: + A0_gl_offset = (tile_i * BLOCK_M) * K + A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K + A_K_STEP = BLOCK_K + if const_expr(b_preshuffled): + B0_gl_offset = (tile_j * BLOCK_N) * BLOCK_K + B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * BLOCK_K + B_K_STEP = N * BLOCK_K + else: + B0_gl_offset = (tile_j * BLOCK_N) * K + B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K + B_K_STEP = BLOCK_K # A/B come in as torch.int8 (PyTorch fp8 view restriction); recast # the buffer-desc pointer's element type to fp8 so typed copy @@ -171,13 +229,19 @@ def _swizzle_128(row, col): swizzled = offset ^ swz return swizzled // BLOCK_K, swizzled % BLOCK_K - def _compute_global_swizzle(): + def _compute_global_swizzle(preshuffled): + # Row-major: ``r * K + c``. Preshuffled: same swizzled (r, c) + # but linearised with the K-major slab layout + # ``(r // 16) * 2048 + (r % 16) * 128 + c``. offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): row = lane_id // 8 + wave_id * 8 + round * 32 col = (lane_id % 8) * 16 r, c = _swizzle_128(row, col) - offsets.append(r * K + c) + if const_expr(preshuffled): + offsets.append((r // 16) * (16 * BLOCK_K) + (r % 16) * BLOCK_K + c) + else: + offsets.append(r * K + c) return offsets def _compute_lds_swizzle(wave_idx, n_tiles): @@ -396,18 +460,19 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c10_frag = [RT_C_i] * N_ACCUMS c11_frag = [RT_C_i] * N_ACCUMS - global_offsets = _compute_global_swizzle() + gl_off_a = _compute_global_swizzle(a_preshuffled) + gl_off_b = _compute_global_swizzle(b_preshuffled) # Prologue: 8-buffer LDS pipeline pre-fill. 
- _load_lds(ga_div, a_cur0, A0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A) - _load_lds(gb_div, b_cur0, B0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(gb_div, b_cur1, B1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(ga_div, a_cur1, A1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A) + _load_lds(ga_div, a_cur0, A0_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A) + _load_lds(gb_div, b_cur0, B0_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(gb_div, b_cur1, B1_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(ga_div, a_cur1, A1_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A) - _load_lds(ga_div, a_next0, A0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A) - _load_lds(gb_div, b_next0, B0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(gb_div, b_next1, B1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(ga_div, a_next1, A1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A) + _load_lds(ga_div, a_next0, A0_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A) + _load_lds(gb_div, b_next0, B0_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(gb_div, b_next1, B1_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(ga_div, a_next1, A1_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A) _wait_barrier((3 * N_TILES_A) + (4 * N_TILES_B)) @@ -423,8 +488,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c00_frag, b1_frag = _compute_block( a_cur0, ga_div, - A0_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + A0_gl_offset + (k + 2) * A_K_STEP, + gl_off_a, wave_j, b_cur1, N_TILES_A, @@ -437,8 +502,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c01_frag, a1_frag = _compute_block( b_cur0, gb_div, - B0_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + B0_gl_offset + (k + 2) * B_K_STEP, + gl_off_b, wave_i, a_cur1, N_TILES_B, @@ -453,8 +518,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, 
lds_src, n_t c10_frag, a0_frag = _compute_block( b_cur1, gb_div, - B1_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + B1_gl_offset + (k + 2) * B_K_STEP, + gl_off_b, wave_i, a_next0, N_TILES_B, @@ -467,8 +532,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c11_frag, b0_frag = _compute_block( a_cur1, ga_div, - A1_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + A1_gl_offset + (k + 2) * A_K_STEP, + gl_off_a, wave_j, b_next0, N_TILES_A, diff --git a/tests/kernels/test_fp8_gemm_4wave.py b/tests/kernels/test_fp8_gemm_4wave.py index d8cedae0..c7ac9713 100644 --- a/tests/kernels/test_fp8_gemm_4wave.py +++ b/tests/kernels/test_fp8_gemm_4wave.py @@ -23,7 +23,7 @@ sys.path.insert(0, _REPO_ROOT) from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.fp8_gemm_4wave import compile_fp8_gemm # noqa: E402 +from kernels.fp8_gemm_4wave import compile_fp8_gemm, preshuffle_a, preshuffle_b # noqa: E402 from tests.test_common import run_perftest, verify_output # noqa: E402 from tests.utils import pertoken_quant # noqa: E402 @@ -72,6 +72,8 @@ def _bench_fp8_gemm_4wave( num_warmups: int = 2, num_iters: int = 10, vs_torch: bool = False, + a_preshuffled: bool = False, + b_preshuffled: bool = False, ): """Run + verify a single (M, N, K, tile) configuration. 
Returns TFLOPS.""" if "gfx95" not in ARCH: @@ -92,6 +94,9 @@ def _bench_fp8_gemm_4wave( c_ref = _run_torch(a_q, b_q, scale_a, scale_b) + a_kernel = preshuffle_a(a_q) if a_preshuffled else a_q + b_kernel = preshuffle_b(b_q) if b_preshuffled else b_q + launch_fn = compile_fp8_gemm( M=M, N=N, @@ -99,9 +104,13 @@ def _bench_fp8_gemm_4wave( BLOCK_M=tile_m, BLOCK_N=tile_n, use_xcd_remap=not disable_xcd_remap, + a_preshuffled=a_preshuffled, + b_preshuffled=b_preshuffled, ) print( - f"\n[fp8_gemm_4wave] M={M} N={N} K={K} " f"BLOCK_M={tile_m} BLOCK_N={tile_n} xcd_remap={not disable_xcd_remap}" + f"\n[fp8_gemm_4wave] M={M} N={N} K={K} " + f"BLOCK_M={tile_m} BLOCK_N={tile_n} xcd_remap={not disable_xcd_remap} " + f"preshuffle_a={a_preshuffled} preshuffle_b={b_preshuffled}" ) def _args(c, a, b, sa, sb): @@ -114,7 +123,7 @@ def _args(c, a, b, sa, sb): torch.cuda.current_stream(), ) - compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_q, b_q, scale_a, scale_b)) + compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_kernel, b_kernel, scale_a, scale_b)) def _launch(c, a, b, sa, sb): compiled(*_args(c, a, b, sa, sb)) @@ -123,8 +132,8 @@ def _launch(c, a, b, sa, sb): _, us = run_perftest( _launch, c_out_raw, - a_q, - b_q, + a_kernel, + b_kernel, scale_a, scale_b, num_iters=num_iters, @@ -175,8 +184,17 @@ def _launch(c, a, b, sa, sb): pytest.param(9728, 8192, 8320, 256, 256, marks=pytest.mark.large_shape, id="9728x8192x8320"), ], ) -def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): - _bench_fp8_gemm_4wave(M=M, N=N, K=K, tile_m=tile_m, tile_n=tile_n) +@pytest.mark.parametrize("preshuffle", [False, True], ids=["rowmajor", "preshuffle"]) +def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle): + _bench_fp8_gemm_4wave( + M=M, + N=N, + K=K, + tile_m=tile_m, + tile_n=tile_n, + a_preshuffled=preshuffle, + b_preshuffled=preshuffle, + ) if __name__ == "__main__": @@ -189,14 +207,36 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): parser.add_argument("--tile_m", type=int, 
default=256) parser.add_argument("--tile_n", type=int, default=256) parser.add_argument("--disable_xcd_remap", action="store_true", default=False) - parser.add_argument("--num_iters", type=int, default=10) - parser.add_argument("--num_warmups", type=int, default=2) + parser.add_argument( + "--num_iters", + type=int, + default=100, + help="Use 100+ for stable steady-state perf measurement.", + ) + parser.add_argument( + "--num_warmups", + type=int, + default=10, + help=">=10 needed to reach steady state on the large BLOCK=256 shapes.", + ) parser.add_argument( "--vs_torch", action="store_true", default=False, help="Also run torch._scaled_mm with the same input + harness for perf comparison.", ) + parser.add_argument( + "--preshuffle_a", + action="store_true", + default=False, + help="Preshuffle A into K-major slabs (each K atom = M*128 byte contig).", + ) + parser.add_argument( + "--preshuffle_b", + action="store_true", + default=False, + help="Preshuffle B into K-major slabs (each K atom = N*128 byte contig).", + ) args = parser.parse_args() torch.set_default_device("cuda") @@ -212,6 +252,8 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): num_warmups=args.num_warmups, num_iters=args.num_iters, vs_torch=args.vs_torch, + a_preshuffled=args.preshuffle_a, + b_preshuffled=args.preshuffle_b, ) except pytest.skip.Exception as e: print(f"Skipped: {e}") From 3ff4b779d2573a047a45230f8b465d5390c426cb Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 12:32:39 +0000 Subject: [PATCH 2/8] refactor(fp8_gemm_4wave): drop A preshuffle, keep B-only Empirical sweep on all four parametrised shapes showed A-side preshuffle adding zero measurable benefit beyond B-side: 8192 x 8192 x 8192: baseline 2372 / B-only 2364 / A-only 2381 / A+B 2374 TFLOPS -- all within ~0.5% std That matches the physics: 8k^3 sits at ~95% MI355X fp8 peak, with the remaining gap coming from LDS-side barriers, not DRAM access. 
Smaller shapes (512 / 5120) already saw the full gain from B-only (+7.1% / +1.5%), so the second symmetric path only added kernel parameter and host helper surface for no return. Removed: * ``a_preshuffled`` flag on ``compile_fp8_gemm`` * ``preshuffle_a`` host helper * ``--preshuffle_a`` CLI flag on the test * The "_preshuffle_2d shared" plumbing -- B uses an inline torch reshape/permute now The remaining surface is the minimal one: ``b_preshuffled=True`` plus ``preshuffle_b(B_T)``. Pytest is parametrised as ``[rowmajor, preshuffle_b]`` (4 shapes x 2 = 8 cases). Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 63 ++++++++++------------------ tests/kernels/test_fp8_gemm_4wave.py | 25 ++++------- 2 files changed, 29 insertions(+), 59 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 65e7685b..36866d9f 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -13,13 +13,15 @@ the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. -Optional K-major operand preshuffle. ``compile_fp8_gemm(..., a_preshuffled=True)`` -and ``b_preshuffled=True`` switch the per-thread DRAM access for that -operand to a K-major outermost layout (each K atom slab of -``rows * 128`` bytes contiguous). The host-side helpers -:func:`preshuffle_a` / :func:`preshuffle_b` perform the offline -permute. LDS layout, MFMA fragment shape, and wave assignment are -unchanged. +Optional K-major preshuffle for B. ``compile_fp8_gemm(..., b_preshuffled=True)`` +switches the per-thread DRAM access for B to a K-major outermost +layout (each K atom slab of ``N * 128`` bytes contiguous). The host +helper :func:`preshuffle_b` performs the offline permute. LDS +layout, MFMA fragment shape, and wave assignment are unchanged. 
+ +A stays row-major; it didn't gain measurably from the same +treatment on the parametrised shapes (8192^3 was already at ~95% +peak compute-bound, smaller shapes gained ~7% from B alone). """ import flydsl.compiler as flyc @@ -35,29 +37,14 @@ from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr -def _preshuffle_2d(t): - """K-major outermost preshuffle for a row-major ``(rows, K)`` tensor. - - Output shape ``(K//128, rows//16, 16, 128)`` is K-major outermost so - every 128-K-col atom slab of ``rows * 128`` bytes lives contiguously. - """ - rows, k = t.shape[-2:] - assert rows % 16 == 0 and k % 128 == 0, f"need rows%16==0 and K%128==0, got rows={rows} K={k}" - return t.reshape(rows // 16, 16, k // 128, 128).permute(2, 0, 1, 3).contiguous() - - -def preshuffle_a(a): - """Permute row-major ``A`` ``(M, K)`` into the layout consumed when - ``compile_fp8_gemm(..., a_preshuffled=True)``. K-major outermost - (each K atom slab = ``M * 128`` byte contig).""" - return _preshuffle_2d(a) - - def preshuffle_b(b_t): """Permute row-major ``B_T`` ``(N, K)`` into the layout consumed when - ``compile_fp8_gemm(..., b_preshuffled=True)``. K-major outermost - (each K atom slab = ``N * 128`` byte contig).""" - return _preshuffle_2d(b_t) + ``compile_fp8_gemm(..., b_preshuffled=True)``: K-major outermost + ``(K//128, N//16, 16, 128)`` so each 128-K-col atom slab of + ``N * 128`` bytes lives contiguously.""" + n, k = b_t.shape[-2:] + assert n % 16 == 0 and k % 128 == 0, f"need N%16==0 and K%128==0, got N={n} K={k}" + return b_t.reshape(n // 16, 16, k // 128, 128).permute(2, 0, 1, 3).contiguous() def _divmod(a, b): @@ -100,7 +87,6 @@ def compile_fp8_gemm( BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True, - a_preshuffled: bool = False, b_preshuffled: bool = False, ): # MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64. 
@@ -177,18 +163,13 @@ def kernel_gemm( wave_i = wave_id // 2 wave_j = wave_id % 2 - # K-major preshuffle: tile_row's slot in k_outer=0 slab is - # ``(tile_row // 16) * 2048 = tile_row * BLOCK_K``; per-K-iter - # step jumps to next ``rows * BLOCK_K`` byte slab. - if const_expr(a_preshuffled): - A0_gl_offset = (tile_i * BLOCK_M) * BLOCK_K - A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * BLOCK_K - A_K_STEP = M * BLOCK_K - else: - A0_gl_offset = (tile_i * BLOCK_M) * K - A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K - A_K_STEP = BLOCK_K + A0_gl_offset = (tile_i * BLOCK_M) * K + A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K + A_K_STEP = BLOCK_K if const_expr(b_preshuffled): + # K-major preshuffle: tile_j's slot in k_outer=0 slab is + # ``(tile_j*BLOCK_N // 16) * 2048 = tile_j*BLOCK_N * BLOCK_K``; + # per-K-iter step jumps to next ``N * BLOCK_K`` byte slab. B0_gl_offset = (tile_j * BLOCK_N) * BLOCK_K B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * BLOCK_K B_K_STEP = N * BLOCK_K @@ -460,7 +441,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c10_frag = [RT_C_i] * N_ACCUMS c11_frag = [RT_C_i] * N_ACCUMS - gl_off_a = _compute_global_swizzle(a_preshuffled) + gl_off_a = _compute_global_swizzle(preshuffled=False) gl_off_b = _compute_global_swizzle(b_preshuffled) # Prologue: 8-buffer LDS pipeline pre-fill. 
diff --git a/tests/kernels/test_fp8_gemm_4wave.py b/tests/kernels/test_fp8_gemm_4wave.py index c7ac9713..a984a0a8 100644 --- a/tests/kernels/test_fp8_gemm_4wave.py +++ b/tests/kernels/test_fp8_gemm_4wave.py @@ -23,7 +23,7 @@ sys.path.insert(0, _REPO_ROOT) from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.fp8_gemm_4wave import compile_fp8_gemm, preshuffle_a, preshuffle_b # noqa: E402 +from kernels.fp8_gemm_4wave import compile_fp8_gemm, preshuffle_b # noqa: E402 from tests.test_common import run_perftest, verify_output # noqa: E402 from tests.utils import pertoken_quant # noqa: E402 @@ -72,7 +72,6 @@ def _bench_fp8_gemm_4wave( num_warmups: int = 2, num_iters: int = 10, vs_torch: bool = False, - a_preshuffled: bool = False, b_preshuffled: bool = False, ): """Run + verify a single (M, N, K, tile) configuration. Returns TFLOPS.""" @@ -94,7 +93,6 @@ def _bench_fp8_gemm_4wave( c_ref = _run_torch(a_q, b_q, scale_a, scale_b) - a_kernel = preshuffle_a(a_q) if a_preshuffled else a_q b_kernel = preshuffle_b(b_q) if b_preshuffled else b_q launch_fn = compile_fp8_gemm( @@ -104,13 +102,12 @@ def _bench_fp8_gemm_4wave( BLOCK_M=tile_m, BLOCK_N=tile_n, use_xcd_remap=not disable_xcd_remap, - a_preshuffled=a_preshuffled, b_preshuffled=b_preshuffled, ) print( f"\n[fp8_gemm_4wave] M={M} N={N} K={K} " f"BLOCK_M={tile_m} BLOCK_N={tile_n} xcd_remap={not disable_xcd_remap} " - f"preshuffle_a={a_preshuffled} preshuffle_b={b_preshuffled}" + f"preshuffle_b={b_preshuffled}" ) def _args(c, a, b, sa, sb): @@ -123,7 +120,7 @@ def _args(c, a, b, sa, sb): torch.cuda.current_stream(), ) - compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_kernel, b_kernel, scale_a, scale_b)) + compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_q, b_kernel, scale_a, scale_b)) def _launch(c, a, b, sa, sb): compiled(*_args(c, a, b, sa, sb)) @@ -132,7 +129,7 @@ def _launch(c, a, b, sa, sb): _, us = run_perftest( _launch, c_out_raw, - a_kernel, + a_q, b_kernel, scale_a, scale_b, @@ 
-184,16 +181,15 @@ def _launch(c, a, b, sa, sb): pytest.param(9728, 8192, 8320, 256, 256, marks=pytest.mark.large_shape, id="9728x8192x8320"), ], ) -@pytest.mark.parametrize("preshuffle", [False, True], ids=["rowmajor", "preshuffle"]) -def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle): +@pytest.mark.parametrize("preshuffle_b", [False, True], ids=["rowmajor", "preshuffle_b"]) +def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle_b): _bench_fp8_gemm_4wave( M=M, N=N, K=K, tile_m=tile_m, tile_n=tile_n, - a_preshuffled=preshuffle, - b_preshuffled=preshuffle, + b_preshuffled=preshuffle_b, ) @@ -225,12 +221,6 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle): default=False, help="Also run torch._scaled_mm with the same input + harness for perf comparison.", ) - parser.add_argument( - "--preshuffle_a", - action="store_true", - default=False, - help="Preshuffle A into K-major slabs (each K atom = M*128 byte contig).", - ) parser.add_argument( "--preshuffle_b", action="store_true", @@ -252,7 +242,6 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle): num_warmups=args.num_warmups, num_iters=args.num_iters, vs_torch=args.vs_torch, - a_preshuffled=args.preshuffle_a, b_preshuffled=args.preshuffle_b, ) except pytest.skip.Exception as e: From 6838aed9fd7f5d02b0be072a1d22a9d09710262d Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 13:17:10 +0000 Subject: [PATCH 3/8] layout change --- kernels/fp8_gemm_4wave.py | 144 +++++++++++++++++++++++++------------- 1 file changed, 95 insertions(+), 49 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 36866d9f..bc5d6525 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -13,15 +13,12 @@ the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. -Optional K-major preshuffle for B. 
``compile_fp8_gemm(..., b_preshuffled=True)`` -switches the per-thread DRAM access for B to a K-major outermost -layout (each K atom slab of ``N * 128`` bytes contiguous). The host -helper :func:`preshuffle_b` performs the offline permute. LDS -layout, MFMA fragment shape, and wave assignment are unchanged. - -A stays row-major; it didn't gain measurably from the same -treatment on the parametrised shapes (8192^3 was already at ~95% -peak compute-bound, smaller shapes gained ~7% from B alone). +Optional preshuffle for B. ``compile_fp8_gemm(..., b_preshuffled=True)`` +consumes B in the same on-disk layout as +``preshuffle_gemm_v2``/``shuffle_weight((16, 16))`` so a single +offline-shuffled weight tensor can feed either kernel. The host +helper :func:`preshuffle_b` performs the permute. LDS layout, MFMA +fragment shape, and wave assignment are unchanged. """ import flydsl.compiler as flyc @@ -39,12 +36,20 @@ def preshuffle_b(b_t): """Permute row-major ``B_T`` ``(N, K)`` into the layout consumed when - ``compile_fp8_gemm(..., b_preshuffled=True)``: K-major outermost - ``(K//128, N//16, 16, 128)`` so each 128-K-col atom slab of - ``N * 128`` bytes lives contiguously.""" + ``compile_fp8_gemm(..., b_preshuffled=True)``. + + On-disk format matches ``preshuffle_gemm.shuffle_weight((16, 16))`` + / ``preshuffle_gemm_v2``: N-major outermost + ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` + with strides ``((16, K*16), (1, 256, 1024))``. Each 64-K-col block + of 16 N rows = 1 KB is contiguous; one MFMA_Scale_128 K-tile = 2 + consecutive k0 blocks = 2 KB. + + Equivalent to ``b.reshape(N/16, 16, K/64, 4, 16).permute(0, 2, 3, 1, 4)``. 
+ """ n, k = b_t.shape[-2:] - assert n % 16 == 0 and k % 128 == 0, f"need N%16==0 and K%128==0, got N={n} K={k}" - return b_t.reshape(n // 16, 16, k // 128, 128).permute(2, 0, 1, 3).contiguous() + assert n % 16 == 0 and k % 64 == 0, f"need N%16==0 and K%64==0, got N={n} K={k}" + return b_t.reshape(n // 16, 16, k // 64, 4, 16).permute(0, 2, 3, 1, 4).contiguous() def _divmod(a, b): @@ -166,17 +171,14 @@ def kernel_gemm( A0_gl_offset = (tile_i * BLOCK_M) * K A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K A_K_STEP = BLOCK_K - if const_expr(b_preshuffled): - # K-major preshuffle: tile_j's slot in k_outer=0 slab is - # ``(tile_j*BLOCK_N // 16) * 2048 = tile_j*BLOCK_N * BLOCK_K``; - # per-K-iter step jumps to next ``N * BLOCK_K`` byte slab. - B0_gl_offset = (tile_j * BLOCK_N) * BLOCK_K - B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * BLOCK_K - B_K_STEP = N * BLOCK_K - else: - B0_gl_offset = (tile_j * BLOCK_N) * K - B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K - B_K_STEP = BLOCK_K + # N-major preshuffle matches preshuffle_gemm_v2 on-disk format: + # ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` + # with strides ``((16, K*16), (1, 256, 1024))``. Tile base is + # unchanged (``(tile_j*BLOCK_N // 16) * K*16 = tile_j*BLOCK_N*K``); + # one K-tile of 128 K-cols = 2 k0 entries forward = 2 KB. + B0_gl_offset = (tile_j * BLOCK_N) * K + B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K + B_K_STEP = (2 * 1024) if b_preshuffled else BLOCK_K # A/B come in as torch.int8 (PyTorch fp8 view restriction); recast # the buffer-desc pointer's element type to fp8 so typed copy @@ -211,29 +213,58 @@ def _swizzle_128(row, col): return swizzled // BLOCK_K, swizzled % BLOCK_K def _compute_global_swizzle(preshuffled): - # Row-major: ``r * K + c``. Preshuffled: same swizzled (r, c) - # but linearised with the K-major slab layout - # ``(r // 16) * 2048 + (r % 16) * 128 + c``. 
+ # Two thread→tile mappings: + # * Non-preshuffled (row-major B / A): ``lane//8 → row``, + # ``lane%8 → col_chunk``. 8 lanes share a row at adjacent + # 16-byte col chunks → contiguous 128 B per lane-group in DRAM. + # LDS write at ``lane*16`` produces row-major-within-wave LDS. + # XOR swizzle reduces ds_read bank conflicts. + # * Preshuffled (preshuffle_gemm_v2 N-major layout for B, + # ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` + # with strides ``((16, K*16), (1, 256, 1024))``): swap to + # ``lane%8 → row``, ``lane//8 → col_chunk``. 8 lanes now span + # NLANE=0..7 at the same (n_outer, klane, k0) → 128 B + # contiguous in DRAM. LDS write at ``lane*16`` produces + # col-major-within-wave LDS where each 128-byte LDS row gets + # exactly 8 lanes (32 banks fully utilised) → no bank conflict + # without any additional swizzle. offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): - row = lane_id // 8 + wave_id * 8 + round * 32 - col = (lane_id % 8) * 16 - r, c = _swizzle_128(row, col) if const_expr(preshuffled): - offsets.append((r // 16) * (16 * BLOCK_K) + (r % 16) * BLOCK_K + c) + row = lane_id % 8 + wave_id * 8 + round * 32 + col = (lane_id // 8) * 16 + offsets.append( + (row // 16) * (K * 16) + + (row % 16) * 16 + + (col // 64) * 1024 + + ((col % 64) // 16) * 256 + + (col % 16) + ) else: + row = lane_id // 8 + wave_id * 8 + round * 32 + col = (lane_id % 8) * 16 + r, c = _swizzle_128(row, col) offsets.append(r * K + c) return offsets - def _compute_lds_swizzle(wave_idx, n_tiles): + def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False): + # LDS read byte for (R, C) within one sub-buffer: + # * Non-preshuffled: ``byte = R*BLOCK_K + C_swizzled`` (XOR + # swizzle on col by R-bits to avoid bank conflicts). + # * Preshuffled: ``byte = (R//8)*1024 + (R%8)*16 + + # (C//16)*128 + (C%16)``. col-major-within-wave-portion; + # 8 lanes per 128-byte LDS row utilise all 32 banks. 
lds_swz = [] for row_offset in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 swz = [] for i in range_constexpr(2): col = (lane_id // 16) * 16 + i * 64 - r, c = _swizzle_128(row, col) - swz.append(r * BLOCK_K + c) + if const_expr(preshuffled): + swz.append((row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 + (col % 16)) + else: + r, c = _swizzle_128(row, col) + swz.append(r * BLOCK_K + c) lds_swz.append(swz) return lds_swz @@ -269,15 +300,19 @@ def _load_one_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, tile_idx): def _pack_i32x4_i32x8(lo, hi): return lo.shuffle(hi, list(range(8))) - def _load_rt(lds_src, wave_idx, n_tiles): + def _load_rt(lds_src, wave_idx, n_tiles, preshuffled=False): frag = [] for i in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + i * 16 + lane_id % 16 halves = [] for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 - r, c = _swizzle_128(row, col) - v = Vec.load(Vec16_t, lds_src, [fx.Index(r * BLOCK_K + c)]) + if const_expr(preshuffled): + byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 + (col % 16) + else: + r, c = _swizzle_128(row, col) + byte = r * BLOCK_K + c + v = Vec.load(Vec16_t, lds_src, [fx.Index(byte)]) halves.append(v.bitcast(fx.Int32)) frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag @@ -360,7 +395,10 @@ def _mfma_ABt_one(a, b, c, m, n): c[_c_idx(m, n)] = _mfma(a[m], b[n], c[_c_idx(m, n)]) return c - def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c): + def _interleaved_cluster( + lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c, + lds_src_preshuffled=False, + ): # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg # loads interleaved between MFMAs to hide latency. 
rt_dst = [] @@ -368,7 +406,7 @@ def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_sr c = _mfma_ABt_one(a, b, c, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 1) - lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds) + lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds, preshuffled=lds_src_preshuffled) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 0) rt_dst_0 = _load_one_rt(lds_src, lds_swz, 0, 0) @@ -418,21 +456,27 @@ def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_sr return c, rt_dst def _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c + lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_src_preshuffled=False, ): _load_lds(gl_src, lds_dst, k_offset, gl_offsets, n_tiles_lds) - rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt) + rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt, preshuffled=lds_src_preshuffled) c = _mfma_ABt_all(a, b, c) return c, rt_dst - def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c): + def _compute_block( + lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_src_preshuffled=False, + ): if const_expr(_use_interleaved_block): return _interleaved_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c + lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c, + lds_src_preshuffled=lds_src_preshuffled, ) else: return _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c + lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_src_preshuffled=lds_src_preshuffled, ) # Each wave handles 2x2 64x64 sub-tiles of the output. 
@@ -461,7 +505,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t _wait_barrier((3 * N_TILES_A) + (3 * N_TILES_B)) - b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B) + b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B, preshuffled=b_preshuffled) for k in range_constexpr(K_ITERS - 2): _wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B)) @@ -478,6 +522,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t a0_frag, b0_frag, c00_frag, + lds_src_preshuffled=b_preshuffled, ) c01_frag, a1_frag = _compute_block( @@ -522,6 +567,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t a1_frag, b1_frag, c11_frag, + lds_src_preshuffled=b_preshuffled, ) a_cur0, a_next0 = a_next0, a_cur0 @@ -531,14 +577,14 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t # Tail step k_iters - 2. _wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B)) - b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B) + b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled) c00_frag = _mfma_ABt_all(a0_frag, b0_frag, c00_frag) a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A) c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag) _wait_barrier((1 * N_TILES_A) + (1 * N_TILES_B)) a0_frag = _load_rt(a_next0, wave_i, N_TILES_A) c10_frag = _mfma_ABt_all(a1_frag, b0_frag, c10_frag) - b0_frag = _load_rt(b_next0, wave_j, N_TILES_B) + b0_frag = _load_rt(b_next0, wave_j, N_TILES_B, preshuffled=b_preshuffled) c11_frag = _mfma_ABt_all(a1_frag, b1_frag, c11_frag) a_cur0, a_next0 = a_next0, a_cur0 @@ -550,7 +596,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t base_row = tile_i * BLOCK_M + wave_i * (N_TILES_A * 16) base_col = tile_j * BLOCK_N + wave_j * (N_TILES_B * 16) _wait_barrier(0) - b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B) + b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled) a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A) c00_frag = 
_mfma_ABt_all(a0_frag, b0_frag, c00_frag) c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag) From 68f3e5af6dbc68ff05cebc7ccc8049522daf2843 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 13:41:52 +0000 Subject: [PATCH 4/8] change --- kernels/fp8_gemm_4wave.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index bc5d6525..5dcc0988 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -251,9 +251,11 @@ def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False): # LDS read byte for (R, C) within one sub-buffer: # * Non-preshuffled: ``byte = R*BLOCK_K + C_swizzled`` (XOR # swizzle on col by R-bits to avoid bank conflicts). - # * Preshuffled: ``byte = (R//8)*1024 + (R%8)*16 + - # (C//16)*128 + (C%16)``. col-major-within-wave-portion; - # 8 lanes per 128-byte LDS row utilise all 32 banks. + # * Preshuffled: ``byte = (R//8)*1024 + (R%8)*16 + (C//16)*128``. + # col-major-within-wave-portion; ``C`` is always a multiple of + # 16 here (``C = (lane//16)*16 + step*64``) so ``C%16 == 0`` + # drops out. 8 lanes per 128-byte LDS row utilise all 32 banks, + # producing the MI350-optimal 2-way ds_read_b128 bank pattern. 
lds_swz = [] for row_offset in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 @@ -261,7 +263,7 @@ def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False): for i in range_constexpr(2): col = (lane_id // 16) * 16 + i * 64 if const_expr(preshuffled): - swz.append((row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 + (col % 16)) + swz.append((row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128) else: r, c = _swizzle_128(row, col) swz.append(r * BLOCK_K + c) @@ -308,7 +310,8 @@ def _load_rt(lds_src, wave_idx, n_tiles, preshuffled=False): for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 if const_expr(preshuffled): - byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 + (col % 16) + # C%16 == 0 by construction; drop it. + byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 else: r, c = _swizzle_128(row, col) byte = r * BLOCK_K + c From c75eaa4e81fd39c014efc4cf1359396b690988c4 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Thu, 14 May 2026 09:55:16 +0000 Subject: [PATCH 5/8] fix: format fp8 4wave preshuffle kernel Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 58 +++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 5dcc0988..cee2f8f1 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -399,7 +399,16 @@ def _mfma_ABt_one(a, b, c, m, n): return c def _interleaved_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c, + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + a, + b, + c, lds_src_preshuffled=False, ): # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg @@ -459,7 +468,17 @@ def _interleaved_cluster( return c, rt_dst def _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_dst, + gl_src, + 
k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, lds_src_preshuffled=False, ): _load_lds(gl_src, lds_dst, k_offset, gl_offsets, n_tiles_lds) @@ -468,17 +487,46 @@ def _compute_cluster( return c, rt_dst def _compute_block( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, lds_src_preshuffled=False, ): if const_expr(_use_interleaved_block): return _interleaved_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c, + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + a, + b, + c, lds_src_preshuffled=lds_src_preshuffled, ) else: return _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c, + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, lds_src_preshuffled=lds_src_preshuffled, ) From 954cb167467b089587ea96993bd774feb3a1ae54 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Thu, 14 May 2026 15:04:23 +0000 Subject: [PATCH 6/8] ci: report benchmark ratios against main Co-authored-by: Cursor --- .github/workflows/flydsl.yaml | 115 ++++++++++++++++++++++++++++-- scripts/compare_benchmark.py | 127 ++++++++++++++++++++++++++++++++++ scripts/run_benchmark.sh | 32 +++++++++ 3 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 scripts/compare_benchmark.py diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index 20f59b43..ed7adeb3 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -179,19 +179,120 @@ jobs: fi - name: Run benchmarks + id: benchmarks run: | - docker exec flydsl_test bash -c "export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} && export AITER_REPO=/tmp/aiter && cd /flydsl-test && bash scripts/run_benchmark.sh" + docker 
exec flydsl_test bash -c "export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} && export AITER_REPO=/tmp/aiter && cd /flydsl-test && BENCH_LOG_DIR=/tmp/flydsl_bench_current bash scripts/run_benchmark.sh --output_csv /tmp/bench_current.csv" + + - name: Run benchmark baselines + id: bench-baselines + timeout-minutes: 180 + continue-on-error: true + run: | + docker exec flydsl_test bash <<'BASH' + set -u + cd /flydsl-test + rm -rf /tmp/flydsl-bench-main-* + rm -f /tmp/bench_main.csv /tmp/bench_prev_main.csv /tmp/bench_main_label /tmp/bench_prev_main_label + git worktree prune || true + git fetch origin main --depth=8 + + found=0 + idx=0 + for commit in $(git rev-list --max-count=8 FETCH_HEAD); do + label="main" + if [ "${idx}" -gt 0 ]; then + label="main~${idx}" + fi + worktree="/tmp/flydsl-bench-main-${idx}" + csv="/tmp/bench_candidate_${idx}.csv" + log_dir="/tmp/flydsl_bench_${label}" + echo "Trying benchmark baseline ${label} (${commit})" + + if ! git worktree add --detach "${worktree}" "${commit}"; then + echo "Failed to create worktree for ${label}; trying older main commit." + idx=$((idx + 1)) + continue + fi + + cp /flydsl-test/scripts/run_benchmark.sh "${worktree}/scripts/run_benchmark.sh" + ( + set -e -o pipefail + cd "${worktree}" + export MLIR_PATH=/llvm-project/mlir_install + python3 -m pip install -e . --use-pep517 2>&1 | tail -5 + export PYTHONPATH=/tmp/aiter:${PYTHONPATH:-} + export AITER_REPO=/tmp/aiter + BENCH_LOG_DIR="${log_dir}" bash scripts/run_benchmark.sh --output_csv "${csv}" + ) + status=$? + if [ "${status}" -eq 0 ] && [ -s "${csv}" ]; then + if [ "${found}" -eq 0 ]; then + cp "${csv}" /tmp/bench_main.csv + echo "${label}" >/tmp/bench_main_label + else + cp "${csv}" /tmp/bench_prev_main.csv + echo "${label}" >/tmp/bench_prev_main_label + fi + found=$((found + 1)) + if [ "${found}" -ge 2 ]; then + break + fi + else + echo "Benchmark baseline ${label} failed; trying older main commit." 
+ fi + idx=$((idx + 1)) + done + + if [ "${found}" -eq 0 ]; then + echo "No usable main benchmark baseline found." + exit 0 + fi + if [ "${found}" -eq 1 ]; then + echo "Only one usable main benchmark baseline found." + fi + BASH + + - name: Check benchmark performance (current vs main) + if: steps.bench-baselines.outcome != 'skipped' + timeout-minutes: 5 + run: | + docker exec flydsl_test bash -c " + if [ ! -f /tmp/bench_main.csv ]; then + echo 'No usable main benchmark baseline found; skipping main comparison.' + exit 0 + fi + cd /flydsl-test + main_label=\$(cat /tmp/bench_main_label 2>/dev/null || echo main) + python3 scripts/compare_benchmark.py /tmp/bench_main.csv /tmp/bench_current.csv \ + --baseline-label \"\${main_label}\" --current-label current + " + + - name: Check benchmark performance (current vs main~1) + if: steps.bench-baselines.outcome != 'skipped' + timeout-minutes: 5 + run: | + docker exec flydsl_test bash -c " + if [ ! -f /tmp/bench_prev_main.csv ]; then + echo 'No second usable main benchmark baseline found; skipping previous-main comparison.' + exit 0 + fi + cd /flydsl-test + prev_main_label=\$(cat /tmp/bench_prev_main_label 2>/dev/null || echo main~1) + python3 scripts/compare_benchmark.py /tmp/bench_prev_main.csv /tmp/bench_current.csv \ + --baseline-label \"\${prev_main_label}\" --current-label current + " - name: Show benchmarks logs if: failure() run: | - docker exec flydsl_test bash -c 'cd /tmp/flydsl_bench && tar czf /tmp/flydsl_bench/logs.tgz *.log 2>/dev/null || echo "no logs"' - docker cp flydsl_test:/tmp/flydsl_bench/logs.tgz . || true - if [ -f logs.tgz ]; then - tar -xzf logs.tgz || true - cat *.log || true + docker exec flydsl_test bash -c 'cd /tmp && tar czf /tmp/flydsl_bench_logs.tgz flydsl_bench* bench_*.csv 2>/dev/null || echo "no logs"' + docker cp flydsl_test:/tmp/flydsl_bench_logs.tgz . 
|| true + if [ -f flydsl_bench_logs.tgz ]; then + tar -xzf flydsl_bench_logs.tgz || true + cat flydsl_bench*/*.log || true + cat bench_*.csv || true else - echo "logs.tgz not found; skipping log extraction" + echo "flydsl_bench_logs.tgz not found; skipping log extraction" fi - name: Clean up diff --git a/scripts/compare_benchmark.py b/scripts/compare_benchmark.py new file mode 100644 index 00000000..137bdd69 --- /dev/null +++ b/scripts/compare_benchmark.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors +"""Report performance ratios between two run_benchmark.sh CSV outputs.""" + +from __future__ import annotations + +import argparse +import csv +import sys +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class BenchmarkRow: + op: str + shape: str + dtype: str + metric_name: str + metric_value: float | None + status: str + + +def _parse_float(value: str) -> float | None: + if value in {"", "-", "skip"}: + return None + try: + return float(value) + except ValueError: + return None + + +def _read_csv(path: Path) -> dict[tuple[str, str, str], BenchmarkRow]: + rows: dict[tuple[str, str, str], BenchmarkRow] = {} + with path.open(newline="") as f: + reader = csv.DictReader(f) + required = {"op", "shape", "dtype", "tbps", "tflops", "status"} + missing = required.difference(reader.fieldnames or []) + if missing: + raise SystemExit(f"{path} is missing columns: {', '.join(sorted(missing))}") + + for raw in reader: + op = raw["op"] + shape = raw["shape"] + dtype = raw["dtype"] + tflops = _parse_float(raw["tflops"]) + tbps = _parse_float(raw["tbps"]) + if tflops is not None: + metric_name = "TFLOPS" + metric_value = tflops + else: + metric_name = "TB/s" + metric_value = tbps + rows[(op, shape, dtype)] = BenchmarkRow( + op=op, + shape=shape, + dtype=dtype, + metric_name=metric_name, + metric_value=metric_value, + status=raw["status"], + ) + return rows + + 
+def _format_key(key: tuple[str, str, str]) -> str: + op, shape, dtype = key + return f"{op:>18s} {shape:>34s} {dtype:>8s}" + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("baseline_csv", type=Path) + parser.add_argument("current_csv", type=Path) + parser.add_argument("--baseline-label", default="baseline") + parser.add_argument("--current-label", default="current") + args = parser.parse_args() + + baseline = _read_csv(args.baseline_csv) + current = _read_csv(args.current_csv) + + print(f"=== Benchmark: {args.current_label} vs {args.baseline_label} ===") + + compared = 0 + + for key in sorted(current.keys() & baseline.keys()): + base = baseline[key] + curr = current[key] + if base.metric_value is None: + continue + + if curr.metric_value is None: + print(f" {_format_key(key)} {args.current_label}=missing [SKIP]") + continue + if curr.metric_name != base.metric_name: + print( + f" {_format_key(key)} metric mismatch: " + f"{args.baseline_label}={base.metric_name}, " + f"{args.current_label}={curr.metric_name} [SKIP]" + ) + continue + + compared += 1 + delta = curr.metric_value - base.metric_value + delta_pct = (delta / base.metric_value) * 100.0 if base.metric_value else 0.0 + ratio = curr.metric_value / base.metric_value if base.metric_value else 0.0 + + print( + f" {_format_key(key)} " + f"{args.baseline_label}={base.metric_value:9.3f} {base.metric_name:<6s} " + f"{args.current_label}={curr.metric_value:9.3f} {curr.metric_name:<6s} " + f"ratio={ratio:6.3f}x delta={delta:+9.3f} ({delta_pct:+6.1f}%)" + ) + + skipped_new = len(set(current) - set(baseline)) + if skipped_new: + print(f"\nSkipped {skipped_new} new current-only benchmark row(s).") + + if compared == 0: + print("No comparable benchmark rows found.") + + print("\nBenchmark comparison report completed.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 10325084..a4cd818c 
100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -26,6 +26,7 @@ fi BENCH_LOG_DIR="${BENCH_LOG_DIR:-/tmp/flydsl_bench}" mkdir -p "${BENCH_LOG_DIR}" +BENCH_OUTPUT_CSV="${BENCH_OUTPUT_CSV:-}" # Auto-select GPU with the most free VRAM (skip if HIP_VISIBLE_DEVICES is already set). if [ -z "${HIP_VISIBLE_DEVICES:-}" ] && command -v python3 >/dev/null 2>&1; then @@ -161,6 +162,7 @@ Usage: bash scripts/run_benchmark.sh softmax # run only softmax bash scripts/run_benchmark.sh layernorm moe # run only selected benchmarks bash scripts/run_benchmark.sh --only softmax,moe + bash scripts/run_benchmark.sh --output_csv /tmp/bench.csv bash scripts/run_benchmark.sh --list Supported ops: @@ -211,6 +213,21 @@ _fmt_table_header() { _emit_row() { op="$1"; shape="$2"; dtype="$3"; tbps="$4"; tflops="$5" printf "%-22.22s %-34.34s %-10.10s %10s %10s\n" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" + if [ -n "${BENCH_OUTPUT_CSV:-}" ]; then + status="ok" + if [ "${tbps}" = "skip" ] || [ "${tflops}" = "skip" ]; then + status="skip" + elif [ "${tbps}" = "-" ] && [ "${tflops}" = "-" ]; then + status="missing" + fi + python3 - "${BENCH_OUTPUT_CSV}" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" "${status}" <<'PY' +import csv +import sys + +with open(sys.argv[1], "a", newline="") as f: + csv.writer(f).writerow(sys.argv[2:]) +PY + fi } _normalize_op() { @@ -306,6 +323,16 @@ if [ "$#" -gt 0 ]; then _add_selected_op "$op" done ;; + --output_csv|--output-csv) + flag="$1" + shift + [ "$#" -gt 0 ] || _die "${flag} requires a CSV path" + BENCH_OUTPUT_CSV="$1" + ;; + --output_csv=*|--output-csv=*) + BENCH_OUTPUT_CSV="${1#*=}" + [ -n "${BENCH_OUTPUT_CSV}" ] || _die "$1 requires a CSV path" + ;; --*) _die "unknown flag '$1'" ;; @@ -321,6 +348,11 @@ if [ "$#" -gt 0 ]; then fi fi +if [ -n "${BENCH_OUTPUT_CSV}" ]; then + mkdir -p "$(dirname "${BENCH_OUTPUT_CSV}")" + printf "op,shape,dtype,tbps,tflops,status\n" >"${BENCH_OUTPUT_CSV}" +fi + _py_parse_and_emit() { # Args: op 
shape dtype log_path [M N] python3 - "$@" <<'PY' From 4d6f91cb24a2cb363ef97246aa49c006d11c4ca0 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Thu, 14 May 2026 15:08:00 +0000 Subject: [PATCH 7/8] Revert "ci: report benchmark ratios against main" This reverts commit 954cb167467b089587ea96993bd774feb3a1ae54. --- .github/workflows/flydsl.yaml | 115 ++---------------------------- scripts/compare_benchmark.py | 127 ---------------------------------- scripts/run_benchmark.sh | 32 --------- 3 files changed, 7 insertions(+), 267 deletions(-) delete mode 100644 scripts/compare_benchmark.py diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index ed7adeb3..20f59b43 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -179,120 +179,19 @@ jobs: fi - name: Run benchmarks - id: benchmarks run: | - docker exec flydsl_test bash -c "export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} && export AITER_REPO=/tmp/aiter && cd /flydsl-test && BENCH_LOG_DIR=/tmp/flydsl_bench_current bash scripts/run_benchmark.sh --output_csv /tmp/bench_current.csv" - - - name: Run benchmark baselines - id: bench-baselines - timeout-minutes: 180 - continue-on-error: true - run: | - docker exec flydsl_test bash <<'BASH' - set -u - cd /flydsl-test - rm -rf /tmp/flydsl-bench-main-* - rm -f /tmp/bench_main.csv /tmp/bench_prev_main.csv /tmp/bench_main_label /tmp/bench_prev_main_label - git worktree prune || true - git fetch origin main --depth=8 - - found=0 - idx=0 - for commit in $(git rev-list --max-count=8 FETCH_HEAD); do - label="main" - if [ "${idx}" -gt 0 ]; then - label="main~${idx}" - fi - worktree="/tmp/flydsl-bench-main-${idx}" - csv="/tmp/bench_candidate_${idx}.csv" - log_dir="/tmp/flydsl_bench_${label}" - echo "Trying benchmark baseline ${label} (${commit})" - - if ! git worktree add --detach "${worktree}" "${commit}"; then - echo "Failed to create worktree for ${label}; trying older main commit." 
- idx=$((idx + 1)) - continue - fi - - cp /flydsl-test/scripts/run_benchmark.sh "${worktree}/scripts/run_benchmark.sh" - ( - set -e -o pipefail - cd "${worktree}" - export MLIR_PATH=/llvm-project/mlir_install - python3 -m pip install -e . --use-pep517 2>&1 | tail -5 - export PYTHONPATH=/tmp/aiter:${PYTHONPATH:-} - export AITER_REPO=/tmp/aiter - BENCH_LOG_DIR="${log_dir}" bash scripts/run_benchmark.sh --output_csv "${csv}" - ) - status=$? - if [ "${status}" -eq 0 ] && [ -s "${csv}" ]; then - if [ "${found}" -eq 0 ]; then - cp "${csv}" /tmp/bench_main.csv - echo "${label}" >/tmp/bench_main_label - else - cp "${csv}" /tmp/bench_prev_main.csv - echo "${label}" >/tmp/bench_prev_main_label - fi - found=$((found + 1)) - if [ "${found}" -ge 2 ]; then - break - fi - else - echo "Benchmark baseline ${label} failed; trying older main commit." - fi - idx=$((idx + 1)) - done - - if [ "${found}" -eq 0 ]; then - echo "No usable main benchmark baseline found." - exit 0 - fi - if [ "${found}" -eq 1 ]; then - echo "Only one usable main benchmark baseline found." - fi - BASH - - - name: Check benchmark performance (current vs main) - if: steps.bench-baselines.outcome != 'skipped' - timeout-minutes: 5 - run: | - docker exec flydsl_test bash -c " - if [ ! -f /tmp/bench_main.csv ]; then - echo 'No usable main benchmark baseline found; skipping main comparison.' - exit 0 - fi - cd /flydsl-test - main_label=\$(cat /tmp/bench_main_label 2>/dev/null || echo main) - python3 scripts/compare_benchmark.py /tmp/bench_main.csv /tmp/bench_current.csv \ - --baseline-label \"\${main_label}\" --current-label current - " - - - name: Check benchmark performance (current vs main~1) - if: steps.bench-baselines.outcome != 'skipped' - timeout-minutes: 5 - run: | - docker exec flydsl_test bash -c " - if [ ! -f /tmp/bench_prev_main.csv ]; then - echo 'No second usable main benchmark baseline found; skipping previous-main comparison.' 
- exit 0 - fi - cd /flydsl-test - prev_main_label=\$(cat /tmp/bench_prev_main_label 2>/dev/null || echo main~1) - python3 scripts/compare_benchmark.py /tmp/bench_prev_main.csv /tmp/bench_current.csv \ - --baseline-label \"\${prev_main_label}\" --current-label current - " + docker exec flydsl_test bash -c "export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} && export AITER_REPO=/tmp/aiter && cd /flydsl-test && bash scripts/run_benchmark.sh" - name: Show benchmarks logs if: failure() run: | - docker exec flydsl_test bash -c 'cd /tmp && tar czf /tmp/flydsl_bench_logs.tgz flydsl_bench* bench_*.csv 2>/dev/null || echo "no logs"' - docker cp flydsl_test:/tmp/flydsl_bench_logs.tgz . || true - if [ -f flydsl_bench_logs.tgz ]; then - tar -xzf flydsl_bench_logs.tgz || true - cat flydsl_bench*/*.log || true - cat bench_*.csv || true + docker exec flydsl_test bash -c 'cd /tmp/flydsl_bench && tar czf /tmp/flydsl_bench/logs.tgz *.log 2>/dev/null || echo "no logs"' + docker cp flydsl_test:/tmp/flydsl_bench/logs.tgz . 
|| true + if [ -f logs.tgz ]; then + tar -xzf logs.tgz || true + cat *.log || true else - echo "flydsl_bench_logs.tgz not found; skipping log extraction" + echo "logs.tgz not found; skipping log extraction" fi - name: Clean up diff --git a/scripts/compare_benchmark.py b/scripts/compare_benchmark.py deleted file mode 100644 index 137bdd69..00000000 --- a/scripts/compare_benchmark.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors -"""Report performance ratios between two run_benchmark.sh CSV outputs.""" - -from __future__ import annotations - -import argparse -import csv -import sys -from dataclasses import dataclass -from pathlib import Path - - -@dataclass(frozen=True) -class BenchmarkRow: - op: str - shape: str - dtype: str - metric_name: str - metric_value: float | None - status: str - - -def _parse_float(value: str) -> float | None: - if value in {"", "-", "skip"}: - return None - try: - return float(value) - except ValueError: - return None - - -def _read_csv(path: Path) -> dict[tuple[str, str, str], BenchmarkRow]: - rows: dict[tuple[str, str, str], BenchmarkRow] = {} - with path.open(newline="") as f: - reader = csv.DictReader(f) - required = {"op", "shape", "dtype", "tbps", "tflops", "status"} - missing = required.difference(reader.fieldnames or []) - if missing: - raise SystemExit(f"{path} is missing columns: {', '.join(sorted(missing))}") - - for raw in reader: - op = raw["op"] - shape = raw["shape"] - dtype = raw["dtype"] - tflops = _parse_float(raw["tflops"]) - tbps = _parse_float(raw["tbps"]) - if tflops is not None: - metric_name = "TFLOPS" - metric_value = tflops - else: - metric_name = "TB/s" - metric_value = tbps - rows[(op, shape, dtype)] = BenchmarkRow( - op=op, - shape=shape, - dtype=dtype, - metric_name=metric_name, - metric_value=metric_value, - status=raw["status"], - ) - return rows - - -def _format_key(key: tuple[str, str, str]) -> str: - op, 
shape, dtype = key - return f"{op:>18s} {shape:>34s} {dtype:>8s}" - - -def main() -> int: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("baseline_csv", type=Path) - parser.add_argument("current_csv", type=Path) - parser.add_argument("--baseline-label", default="baseline") - parser.add_argument("--current-label", default="current") - args = parser.parse_args() - - baseline = _read_csv(args.baseline_csv) - current = _read_csv(args.current_csv) - - print(f"=== Benchmark: {args.current_label} vs {args.baseline_label} ===") - - compared = 0 - - for key in sorted(current.keys() & baseline.keys()): - base = baseline[key] - curr = current[key] - if base.metric_value is None: - continue - - if curr.metric_value is None: - print(f" {_format_key(key)} {args.current_label}=missing [SKIP]") - continue - if curr.metric_name != base.metric_name: - print( - f" {_format_key(key)} metric mismatch: " - f"{args.baseline_label}={base.metric_name}, " - f"{args.current_label}={curr.metric_name} [SKIP]" - ) - continue - - compared += 1 - delta = curr.metric_value - base.metric_value - delta_pct = (delta / base.metric_value) * 100.0 if base.metric_value else 0.0 - ratio = curr.metric_value / base.metric_value if base.metric_value else 0.0 - - print( - f" {_format_key(key)} " - f"{args.baseline_label}={base.metric_value:9.3f} {base.metric_name:<6s} " - f"{args.current_label}={curr.metric_value:9.3f} {curr.metric_name:<6s} " - f"ratio={ratio:6.3f}x delta={delta:+9.3f} ({delta_pct:+6.1f}%)" - ) - - skipped_new = len(set(current) - set(baseline)) - if skipped_new: - print(f"\nSkipped {skipped_new} new current-only benchmark row(s).") - - if compared == 0: - print("No comparable benchmark rows found.") - - print("\nBenchmark comparison report completed.") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index a4cd818c..10325084 100755 --- a/scripts/run_benchmark.sh +++ 
b/scripts/run_benchmark.sh @@ -26,7 +26,6 @@ fi BENCH_LOG_DIR="${BENCH_LOG_DIR:-/tmp/flydsl_bench}" mkdir -p "${BENCH_LOG_DIR}" -BENCH_OUTPUT_CSV="${BENCH_OUTPUT_CSV:-}" # Auto-select GPU with the most free VRAM (skip if HIP_VISIBLE_DEVICES is already set). if [ -z "${HIP_VISIBLE_DEVICES:-}" ] && command -v python3 >/dev/null 2>&1; then @@ -162,7 +161,6 @@ Usage: bash scripts/run_benchmark.sh softmax # run only softmax bash scripts/run_benchmark.sh layernorm moe # run only selected benchmarks bash scripts/run_benchmark.sh --only softmax,moe - bash scripts/run_benchmark.sh --output_csv /tmp/bench.csv bash scripts/run_benchmark.sh --list Supported ops: @@ -213,21 +211,6 @@ _fmt_table_header() { _emit_row() { op="$1"; shape="$2"; dtype="$3"; tbps="$4"; tflops="$5" printf "%-22.22s %-34.34s %-10.10s %10s %10s\n" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" - if [ -n "${BENCH_OUTPUT_CSV:-}" ]; then - status="ok" - if [ "${tbps}" = "skip" ] || [ "${tflops}" = "skip" ]; then - status="skip" - elif [ "${tbps}" = "-" ] && [ "${tflops}" = "-" ]; then - status="missing" - fi - python3 - "${BENCH_OUTPUT_CSV}" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" "${status}" <<'PY' -import csv -import sys - -with open(sys.argv[1], "a", newline="") as f: - csv.writer(f).writerow(sys.argv[2:]) -PY - fi } _normalize_op() { @@ -323,16 +306,6 @@ if [ "$#" -gt 0 ]; then _add_selected_op "$op" done ;; - --output_csv|--output-csv) - flag="$1" - shift - [ "$#" -gt 0 ] || _die "${flag} requires a CSV path" - BENCH_OUTPUT_CSV="$1" - ;; - --output_csv=*|--output-csv=*) - BENCH_OUTPUT_CSV="${1#*=}" - [ -n "${BENCH_OUTPUT_CSV}" ] || _die "$1 requires a CSV path" - ;; --*) _die "unknown flag '$1'" ;; @@ -348,11 +321,6 @@ if [ "$#" -gt 0 ]; then fi fi -if [ -n "${BENCH_OUTPUT_CSV}" ]; then - mkdir -p "$(dirname "${BENCH_OUTPUT_CSV}")" - printf "op,shape,dtype,tbps,tflops,status\n" >"${BENCH_OUTPUT_CSV}" -fi - _py_parse_and_emit() { # Args: op shape dtype log_path [M N] python3 - "$@" 
<<'PY' From 263b9f1458db46ae54f0198facd6e162c9ed39a8 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Fri, 15 May 2026 00:50:45 +0000 Subject: [PATCH 8/8] fix: trim fp8 4wave comments Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 52 ++-------------------------- tests/kernels/test_fp8_gemm_4wave.py | 6 ++-- 2 files changed, 6 insertions(+), 52 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index cee2f8f1..67f26707 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -13,12 +13,8 @@ the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. -Optional preshuffle for B. ``compile_fp8_gemm(..., b_preshuffled=True)`` -consumes B in the same on-disk layout as -``preshuffle_gemm_v2``/``shuffle_weight((16, 16))`` so a single -offline-shuffled weight tensor can feed either kernel. The host -helper :func:`preshuffle_b` performs the permute. LDS layout, MFMA -fragment shape, and wave assignment are unchanged. +Optional B preshuffle uses the same on-disk layout as +``preshuffle_gemm_v2`` / ``shuffle_weight((16, 16))``. """ import flydsl.compiler as flyc @@ -35,18 +31,7 @@ def preshuffle_b(b_t): - """Permute row-major ``B_T`` ``(N, K)`` into the layout consumed when - ``compile_fp8_gemm(..., b_preshuffled=True)``. - - On-disk format matches ``preshuffle_gemm.shuffle_weight((16, 16))`` - / ``preshuffle_gemm_v2``: N-major outermost - ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` - with strides ``((16, K*16), (1, 256, 1024))``. Each 64-K-col block - of 16 N rows = 1 KB is contiguous; one MFMA_Scale_128 K-tile = 2 - consecutive k0 blocks = 2 KB. - - Equivalent to ``b.reshape(N/16, 16, K/64, 4, 16).permute(0, 2, 3, 1, 4)``. 
- """ + """Permute row-major ``B_T`` ``(N, K)`` for ``b_preshuffled=True``.""" n, k = b_t.shape[-2:] assert n % 16 == 0 and k % 64 == 0, f"need N%16==0 and K%64==0, got N={n} K={k}" return b_t.reshape(n // 16, 16, k // 64, 4, 16).permute(0, 2, 3, 1, 4).contiguous() @@ -171,11 +156,6 @@ def kernel_gemm( A0_gl_offset = (tile_i * BLOCK_M) * K A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K A_K_STEP = BLOCK_K - # N-major preshuffle matches preshuffle_gemm_v2 on-disk format: - # ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` - # with strides ``((16, K*16), (1, 256, 1024))``. Tile base is - # unchanged (``(tile_j*BLOCK_N // 16) * K*16 = tile_j*BLOCK_N*K``); - # one K-tile of 128 K-cols = 2 k0 entries forward = 2 KB. B0_gl_offset = (tile_j * BLOCK_N) * K B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K B_K_STEP = (2 * 1024) if b_preshuffled else BLOCK_K @@ -213,21 +193,6 @@ def _swizzle_128(row, col): return swizzled // BLOCK_K, swizzled % BLOCK_K def _compute_global_swizzle(preshuffled): - # Two thread→tile mappings: - # * Non-preshuffled (row-major B / A): ``lane//8 → row``, - # ``lane%8 → col_chunk``. 8 lanes share a row at adjacent - # 16-byte col chunks → contiguous 128 B per lane-group in DRAM. - # LDS write at ``lane*16`` produces row-major-within-wave LDS. - # XOR swizzle reduces ds_read bank conflicts. - # * Preshuffled (preshuffle_gemm_v2 N-major layout for B, - # ``((nlane=16, n_outer=N/16), (kpack=16, klane=4, k0=K/64))`` - # with strides ``((16, K*16), (1, 256, 1024))``): swap to - # ``lane%8 → row``, ``lane//8 → col_chunk``. 8 lanes now span - # NLANE=0..7 at the same (n_outer, klane, k0) → 128 B - # contiguous in DRAM. LDS write at ``lane*16`` produces - # col-major-within-wave LDS where each 128-byte LDS row gets - # exactly 8 lanes (32 banks fully utilised) → no bank conflict - # without any additional swizzle. 
offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): if const_expr(preshuffled): @@ -248,14 +213,6 @@ def _compute_global_swizzle(preshuffled): return offsets def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False): - # LDS read byte for (R, C) within one sub-buffer: - # * Non-preshuffled: ``byte = R*BLOCK_K + C_swizzled`` (XOR - # swizzle on col by R-bits to avoid bank conflicts). - # * Preshuffled: ``byte = (R//8)*1024 + (R%8)*16 + (C//16)*128``. - # col-major-within-wave-portion; ``C`` is always a multiple of - # 16 here (``C = (lane//16)*16 + step*64``) so ``C%16 == 0`` - # drops out. 8 lanes per 128-byte LDS row utilise all 32 banks, - # producing the MI350-optimal 2-way ds_read_b128 bank pattern. lds_swz = [] for row_offset in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 @@ -310,7 +267,6 @@ def _load_rt(lds_src, wave_idx, n_tiles, preshuffled=False): for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 if const_expr(preshuffled): - # C%16 == 0 by construction; drop it. byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 else: r, c = _swizzle_128(row, col) @@ -411,8 +367,6 @@ def _interleaved_cluster( c, lds_src_preshuffled=False, ): - # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg - # loads interleaved between MFMAs to hide latency. 
rt_dst = [] c = _mfma_ABt_one(a, b, c, 0, 0) diff --git a/tests/kernels/test_fp8_gemm_4wave.py b/tests/kernels/test_fp8_gemm_4wave.py index a984a0a8..7c8f1fab 100644 --- a/tests/kernels/test_fp8_gemm_4wave.py +++ b/tests/kernels/test_fp8_gemm_4wave.py @@ -207,13 +207,13 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle_b): "--num_iters", type=int, default=100, - help="Use 100+ for stable steady-state perf measurement.", + help="Benchmark iterations.", ) parser.add_argument( "--num_warmups", type=int, default=10, - help=">=10 needed to reach steady state on the large BLOCK=256 shapes.", + help="Warmup iterations.", ) parser.add_argument( "--vs_torch", @@ -225,7 +225,7 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle_b): "--preshuffle_b", action="store_true", default=False, - help="Preshuffle B into K-major slabs (each K atom = N*128 byte contig).", + help="Use preshuffled B layout.", ) args = parser.parse_args()