diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index d9f4fcd0..67f26707 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -12,6 +12,9 @@ the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. + +Optional B preshuffle uses the same on-disk layout as +``preshuffle_gemm_v2`` / ``shuffle_weight((16, 16))``. """ import flydsl.compiler as flyc @@ -27,6 +30,13 @@ from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +def preshuffle_b(b_t): + """Permute row-major ``B_T`` ``(N, K)`` for ``b_preshuffled=True``.""" + n, k = b_t.shape[-2:] + assert n % 16 == 0 and k % 64 == 0, f"need N%16==0 and K%64==0, got N={n} K={k}" + return b_t.reshape(n // 16, 16, k // 64, 4, 16).permute(0, 2, 3, 1, 4).contiguous() + + def _divmod(a, b): return (a // b, a % b) @@ -59,7 +69,16 @@ def _xcd_swizzle(num_pid_m, num_pid_n): return (pid_m, pid_n) -def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True): +def compile_fp8_gemm( + *, + M: int, + N: int, + K: int, + BLOCK_M: int = 256, + BLOCK_N: int = 256, + use_xcd_remap: bool = True, + b_preshuffled: bool = False, +): # MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64. BLOCK_K = 128 LDS_BLOCK_M = BLOCK_M // 2 @@ -136,8 +155,10 @@ def kernel_gemm( wave_j = wave_id % 2 A0_gl_offset = (tile_i * BLOCK_M) * K A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K + A_K_STEP = BLOCK_K B0_gl_offset = (tile_j * BLOCK_N) * K B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K + B_K_STEP = (2 * 1024) if b_preshuffled else BLOCK_K # A/B come in as torch.int8 (PyTorch fp8 view restriction); recast # the buffer-desc pointer's element type to fp8 so typed copy @@ -171,24 +192,38 @@ def _swizzle_128(row, col): swizzled = offset ^ swz return swizzled // BLOCK_K, swizzled % BLOCK_K - def _compute_global_swizzle(): + def _compute_global_swizzle(preshuffled): offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): - row = lane_id // 8 + wave_id * 8 + round * 32 - col = (lane_id % 8) * 16 - r, c = _swizzle_128(row, col) - offsets.append(r * K + c) + if const_expr(preshuffled): + row = lane_id % 8 + wave_id * 8 + round * 32 + col = (lane_id // 8) * 16 + offsets.append( + (row // 16) * (K * 16) + + (row % 16) * 16 + + (col // 64) * 1024 + + ((col % 64) // 16) * 256 + + (col % 16) + ) + else: + row = lane_id // 8 + wave_id * 8 + round * 32 + col = (lane_id % 8) * 16 + r, c = _swizzle_128(row, col) + offsets.append(r * K + c) return offsets - def _compute_lds_swizzle(wave_idx, n_tiles): + def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False): lds_swz = [] for row_offset in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 swz = [] for i in range_constexpr(2): col = (lane_id // 16) * 16 + i * 64 - r, c = _swizzle_128(row, col) - swz.append(r * BLOCK_K + c) + if const_expr(preshuffled): + swz.append((row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128) + else: + r, c = _swizzle_128(row, col) + swz.append(r * BLOCK_K + c) lds_swz.append(swz) return lds_swz @@ -224,15 +259,19 @@ def _load_one_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, tile_idx): def _pack_i32x4_i32x8(lo, hi): return lo.shuffle(hi, list(range(8))) - def _load_rt(lds_src, wave_idx, n_tiles): + def _load_rt(lds_src, wave_idx, n_tiles, preshuffled=False): frag = [] for i in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + i * 16 + lane_id % 16 halves = [] for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 - r, c = _swizzle_128(row, col) - v = Vec.load(Vec16_t, lds_src, [fx.Index(r * BLOCK_K + c)]) + if const_expr(preshuffled): + byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128 + else: + r, c = _swizzle_128(row, col) + byte = r * BLOCK_K + c + v = Vec.load(Vec16_t, lds_src, [fx.Index(byte)]) halves.append(v.bitcast(fx.Int32)) frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag @@ -315,15 +354,25 @@ def _mfma_ABt_one(a, b, c, m, n): c[_c_idx(m, n)] = _mfma(a[m], b[n], c[_c_idx(m, n)]) return c - def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c): - # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg - # loads interleaved between MFMAs to hide latency. + def _interleaved_cluster( + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + a, + b, + c, + lds_src_preshuffled=False, + ): rt_dst = [] c = _mfma_ABt_one(a, b, c, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 1) - lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds) + lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds, preshuffled=lds_src_preshuffled) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 0) rt_dst_0 = _load_one_rt(lds_src, lds_swz, 0, 0) @@ -373,21 +422,66 @@ def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_sr return c, rt_dst def _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, + lds_src_preshuffled=False, ): _load_lds(gl_src, lds_dst, k_offset, gl_offsets, n_tiles_lds) - rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt) + rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt, preshuffled=lds_src_preshuffled) c = _mfma_ABt_all(a, b, c) return c, rt_dst - def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c): + def _compute_block( + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, + lds_src_preshuffled=False, + ): if const_expr(_use_interleaved_block): return _interleaved_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + a, + b, + c, + lds_src_preshuffled=lds_src_preshuffled, ) else: return _compute_cluster( - lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c + lds_dst, + gl_src, + k_offset, + gl_offsets, + wave_idx, + lds_src, + n_tiles_lds, + n_tiles_rt, + a, + b, + c, + lds_src_preshuffled=lds_src_preshuffled, ) # Each wave handles 2x2 64x64 sub-tiles of the output. @@ -396,18 +490,19 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c10_frag = [RT_C_i] * N_ACCUMS c11_frag = [RT_C_i] * N_ACCUMS - global_offsets = _compute_global_swizzle() + gl_off_a = _compute_global_swizzle(preshuffled=False) + gl_off_b = _compute_global_swizzle(b_preshuffled) # Prologue: 8-buffer LDS pipeline pre-fill. - _load_lds(ga_div, a_cur0, A0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A) - _load_lds(gb_div, b_cur0, B0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(gb_div, b_cur1, B1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(ga_div, a_cur1, A1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A) + _load_lds(ga_div, a_cur0, A0_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A) + _load_lds(gb_div, b_cur0, B0_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(gb_div, b_cur1, B1_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(ga_div, a_cur1, A1_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A) - _load_lds(ga_div, a_next0, A0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A) - _load_lds(gb_div, b_next0, B0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(gb_div, b_next1, B1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B) - _load_lds(ga_div, a_next1, A1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A) + _load_lds(ga_div, a_next0, A0_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A) + _load_lds(gb_div, b_next0, B0_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(gb_div, b_next1, B1_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B) + _load_lds(ga_div, a_next1, A1_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A) _wait_barrier((3 * N_TILES_A) + (4 * N_TILES_B)) @@ -415,7 +510,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t _wait_barrier((3 * N_TILES_A) + (3 * N_TILES_B)) - b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B) + b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B, preshuffled=b_preshuffled) for k in range_constexpr(K_ITERS - 2): _wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B)) @@ -423,8 +518,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c00_frag, b1_frag = _compute_block( a_cur0, ga_div, - A0_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + A0_gl_offset + (k + 2) * A_K_STEP, + gl_off_a, wave_j, b_cur1, N_TILES_A, @@ -432,13 +527,14 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t a0_frag, b0_frag, c00_frag, + lds_src_preshuffled=b_preshuffled, ) c01_frag, a1_frag = _compute_block( b_cur0, gb_div, - B0_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + B0_gl_offset + (k + 2) * B_K_STEP, + gl_off_b, wave_i, a_cur1, N_TILES_B, @@ -453,8 +549,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c10_frag, a0_frag = _compute_block( b_cur1, gb_div, - B1_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + B1_gl_offset + (k + 2) * B_K_STEP, + gl_off_b, wave_i, a_next0, N_TILES_B, @@ -467,8 +563,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t c11_frag, b0_frag = _compute_block( a_cur1, ga_div, - A1_gl_offset + (k + 2) * BLOCK_K, - global_offsets, + A1_gl_offset + (k + 2) * A_K_STEP, + gl_off_a, wave_j, b_next0, N_TILES_A, @@ -476,6 +572,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t a1_frag, b1_frag, c11_frag, + lds_src_preshuffled=b_preshuffled, ) a_cur0, a_next0 = a_next0, a_cur0 @@ -485,14 +582,14 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t # Tail step k_iters - 2. _wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B)) - b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B) + b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled) c00_frag = _mfma_ABt_all(a0_frag, b0_frag, c00_frag) a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A) c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag) _wait_barrier((1 * N_TILES_A) + (1 * N_TILES_B)) a0_frag = _load_rt(a_next0, wave_i, N_TILES_A) c10_frag = _mfma_ABt_all(a1_frag, b0_frag, c10_frag) - b0_frag = _load_rt(b_next0, wave_j, N_TILES_B) + b0_frag = _load_rt(b_next0, wave_j, N_TILES_B, preshuffled=b_preshuffled) c11_frag = _mfma_ABt_all(a1_frag, b1_frag, c11_frag) a_cur0, a_next0 = a_next0, a_cur0 @@ -504,7 +601,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t base_row = tile_i * BLOCK_M + wave_i * (N_TILES_A * 16) base_col = tile_j * BLOCK_N + wave_j * (N_TILES_B * 16) _wait_barrier(0) - b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B) + b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled) a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A) c00_frag = _mfma_ABt_all(a0_frag, b0_frag, c00_frag) c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag) diff --git a/tests/kernels/test_fp8_gemm_4wave.py b/tests/kernels/test_fp8_gemm_4wave.py index d8cedae0..7c8f1fab 100644 --- a/tests/kernels/test_fp8_gemm_4wave.py +++ b/tests/kernels/test_fp8_gemm_4wave.py @@ -23,7 +23,7 @@ sys.path.insert(0, _REPO_ROOT) from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.fp8_gemm_4wave import compile_fp8_gemm # noqa: E402 +from kernels.fp8_gemm_4wave import compile_fp8_gemm, preshuffle_b # noqa: E402 from tests.test_common import run_perftest, verify_output # noqa: E402 from tests.utils import pertoken_quant # noqa: E402 @@ -72,6 +72,7 @@ def _bench_fp8_gemm_4wave( num_warmups: int = 2, num_iters: int = 10, vs_torch: bool = False, + b_preshuffled: bool = False, ): """Run + verify a single (M, N, K, tile) configuration. Returns TFLOPS.""" if "gfx95" not in ARCH: @@ -92,6 +93,8 @@ def _bench_fp8_gemm_4wave( c_ref = _run_torch(a_q, b_q, scale_a, scale_b) + b_kernel = preshuffle_b(b_q) if b_preshuffled else b_q + launch_fn = compile_fp8_gemm( M=M, N=N, @@ -99,9 +102,12 @@ def _bench_fp8_gemm_4wave( BLOCK_M=tile_m, BLOCK_N=tile_n, use_xcd_remap=not disable_xcd_remap, + b_preshuffled=b_preshuffled, ) print( - f"\n[fp8_gemm_4wave] M={M} N={N} K={K} " f"BLOCK_M={tile_m} BLOCK_N={tile_n} xcd_remap={not disable_xcd_remap}" + f"\n[fp8_gemm_4wave] M={M} N={N} K={K} " + f"BLOCK_M={tile_m} BLOCK_N={tile_n} xcd_remap={not disable_xcd_remap} " + f"preshuffle_b={b_preshuffled}" ) def _args(c, a, b, sa, sb): @@ -114,7 +120,7 @@ def _args(c, a, b, sa, sb): torch.cuda.current_stream(), ) - compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_q, b_q, scale_a, scale_b)) + compiled = flyc.compile(launch_fn, *_args(c_out_raw, a_q, b_kernel, scale_a, scale_b)) def _launch(c, a, b, sa, sb): compiled(*_args(c, a, b, sa, sb)) @@ -124,7 +130,7 @@ def _launch(c, a, b, sa, sb): _launch, c_out_raw, a_q, - b_q, + b_kernel, scale_a, scale_b, num_iters=num_iters, @@ -175,8 +181,16 @@ def _launch(c, a, b, sa, sb): pytest.param(9728, 8192, 8320, 256, 256, marks=pytest.mark.large_shape, id="9728x8192x8320"), ], ) -def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): - _bench_fp8_gemm_4wave(M=M, N=N, K=K, tile_m=tile_m, tile_n=tile_n) +@pytest.mark.parametrize("preshuffle_b", [False, True], ids=["rowmajor", "preshuffle_b"]) +def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n, preshuffle_b): + _bench_fp8_gemm_4wave( + M=M, + N=N, + K=K, + tile_m=tile_m, + tile_n=tile_n, + b_preshuffled=preshuffle_b, + ) if __name__ == "__main__": @@ -189,14 +203,30 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): parser.add_argument("--tile_m", type=int, default=256) parser.add_argument("--tile_n", type=int, default=256) parser.add_argument("--disable_xcd_remap", action="store_true", default=False) - parser.add_argument("--num_iters", type=int, default=10) - parser.add_argument("--num_warmups", type=int, default=2) + parser.add_argument( + "--num_iters", + type=int, + default=100, + help="Benchmark iterations.", + ) + parser.add_argument( + "--num_warmups", + type=int, + default=10, + help="Warmup iterations.", + ) parser.add_argument( "--vs_torch", action="store_true", default=False, help="Also run torch._scaled_mm with the same input + harness for perf comparison.", ) + parser.add_argument( + "--preshuffle_b", + action="store_true", + default=False, + help="Use preshuffled B layout.", + ) args = parser.parse_args() torch.set_default_device("cuda") @@ -212,6 +242,7 @@ def test_fp8_gemm_4wave(M, N, K, tile_m, tile_n): num_warmups=args.num_warmups, num_iters=args.num_iters, vs_torch=args.vs_torch, + b_preshuffled=args.preshuffle_b, ) except pytest.skip.Exception as e: print(f"Skipped: {e}")