Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 139 additions & 42 deletions kernels/fp8_gemm_4wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and
the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to
preserve the original kernel's interleaved-cluster scheduling.

Optional B preshuffle uses the same on-disk layout as
``preshuffle_gemm_v2`` / ``shuffle_weight((16, 16))``.
"""

import flydsl.compiler as flyc
Expand All @@ -27,6 +30,13 @@
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr


def preshuffle_b(b_t):
    """Permute row-major ``B_T`` ``(N, K)`` for ``b_preshuffled=True``.

    The tensor is viewed as ``(N // 16, 16, K // 64, 4, 16)`` and the axes are
    reordered to ``(N // 16, K // 64, 4, 16, 16)``.  In other words the element
    at ``(n, k)`` of the input ends up at index
    ``[n // 16, k // 64, (k % 64) // 16, n % 16, k % 16]`` of the result.

    Args:
        b_t: tensor whose last two dimensions are ``(N, K)``; ``N`` must be a
            multiple of 16 and ``K`` a multiple of 64.

    Returns:
        A contiguous tensor of shape ``(N // 16, K // 64, 4, 16, 16)``.

    Raises:
        ValueError: if ``N`` is not a multiple of 16 or ``K`` is not a
            multiple of 64.
    """
    n, k = b_t.shape[-2:]
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would let a bad shape reach the reshape below.
    if n % 16 != 0 or k % 64 != 0:
        raise ValueError(f"need N%16==0 and K%64==0, got N={n} K={k}")
    tiled = b_t.reshape(n // 16, 16, k // 64, 4, 16)
    return tiled.permute(0, 2, 3, 1, 4).contiguous()


def _divmod(a, b):
    """Return ``(a // b, a % b)`` as a tuple.

    Spelled out with the ``//`` and ``%`` operators so it works for any
    operand type that implements those two operators.
    """
    quotient = a // b
    remainder = a % b
    return quotient, remainder

Expand Down Expand Up @@ -59,7 +69,16 @@ def _xcd_swizzle(num_pid_m, num_pid_n):
return (pid_m, pid_n)


def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True):
def compile_fp8_gemm(
*,
M: int,
N: int,
K: int,
BLOCK_M: int = 256,
BLOCK_N: int = 256,
use_xcd_remap: bool = True,
b_preshuffled: bool = False,
):
# MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64.
BLOCK_K = 128
LDS_BLOCK_M = BLOCK_M // 2
Expand Down Expand Up @@ -136,8 +155,10 @@ def kernel_gemm(
wave_j = wave_id % 2
A0_gl_offset = (tile_i * BLOCK_M) * K
A1_gl_offset = (tile_i * BLOCK_M + LDS_BLOCK_M) * K
A_K_STEP = BLOCK_K
B0_gl_offset = (tile_j * BLOCK_N) * K
B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K
B_K_STEP = (2 * 1024) if b_preshuffled else BLOCK_K

# A/B come in as torch.int8 (PyTorch fp8 view restriction); recast
# the buffer-desc pointer's element type to fp8 so typed copy
Expand Down Expand Up @@ -171,24 +192,38 @@ def _swizzle_128(row, col):
swizzled = offset ^ swz
return swizzled // BLOCK_K, swizzled % BLOCK_K

def _compute_global_swizzle():
def _compute_global_swizzle(preshuffled):
offsets = []
for round in range_constexpr(max(N_TILES_A, N_TILES_B)):
row = lane_id // 8 + wave_id * 8 + round * 32
col = (lane_id % 8) * 16
r, c = _swizzle_128(row, col)
offsets.append(r * K + c)
if const_expr(preshuffled):
row = lane_id % 8 + wave_id * 8 + round * 32
col = (lane_id // 8) * 16
offsets.append(
(row // 16) * (K * 16)
+ (row % 16) * 16
+ (col // 64) * 1024
+ ((col % 64) // 16) * 256
+ (col % 16)
)
else:
row = lane_id // 8 + wave_id * 8 + round * 32
col = (lane_id % 8) * 16
r, c = _swizzle_128(row, col)
offsets.append(r * K + c)
return offsets

def _compute_lds_swizzle(wave_idx, n_tiles):
def _compute_lds_swizzle(wave_idx, n_tiles, preshuffled=False):
lds_swz = []
for row_offset in range_constexpr(n_tiles):
row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16
swz = []
for i in range_constexpr(2):
col = (lane_id // 16) * 16 + i * 64
r, c = _swizzle_128(row, col)
swz.append(r * BLOCK_K + c)
if const_expr(preshuffled):
swz.append((row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128)
else:
r, c = _swizzle_128(row, col)
swz.append(r * BLOCK_K + c)
lds_swz.append(swz)
return lds_swz

Expand Down Expand Up @@ -224,15 +259,19 @@ def _load_one_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, tile_idx):
def _pack_i32x4_i32x8(lo, hi):
return lo.shuffle(hi, list(range(8)))

def _load_rt(lds_src, wave_idx, n_tiles):
def _load_rt(lds_src, wave_idx, n_tiles, preshuffled=False):
frag = []
for i in range_constexpr(n_tiles):
row = wave_idx * (n_tiles * 16) + i * 16 + lane_id % 16
halves = []
for step in range_constexpr(2):
col = (lane_id // 16) * 16 + step * 64
r, c = _swizzle_128(row, col)
v = Vec.load(Vec16_t, lds_src, [fx.Index(r * BLOCK_K + c)])
if const_expr(preshuffled):
byte = (row // 8) * 1024 + (row % 8) * 16 + (col // 16) * 128
else:
r, c = _swizzle_128(row, col)
byte = r * BLOCK_K + c
v = Vec.load(Vec16_t, lds_src, [fx.Index(byte)])
halves.append(v.bitcast(fx.Int32))
frag.append(_pack_i32x4_i32x8(halves[0], halves[1]))
return frag
Expand Down Expand Up @@ -315,15 +354,25 @@ def _mfma_ABt_one(a, b, c, m, n):
c[_c_idx(m, n)] = _mfma(a[m], b[n], c[_c_idx(m, n)])
return c

def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c):
# 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg
# loads interleaved between MFMAs to hide latency.
def _interleaved_cluster(
lds_dst,
gl_src,
k_offset,
gl_offsets,
wave_idx,
lds_src,
n_tiles_lds,
a,
b,
c,
lds_src_preshuffled=False,
):
rt_dst = []

c = _mfma_ABt_one(a, b, c, 0, 0)
c = _mfma_ABt_one(a, b, c, 0, 1)

lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds)
lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds, preshuffled=lds_src_preshuffled)
_load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 0)
rt_dst_0 = _load_one_rt(lds_src, lds_swz, 0, 0)

Expand Down Expand Up @@ -373,21 +422,66 @@ def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_sr
return c, rt_dst

def _compute_cluster(
lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c
lds_dst,
gl_src,
k_offset,
gl_offsets,
wave_idx,
lds_src,
n_tiles_lds,
n_tiles_rt,
a,
b,
c,
lds_src_preshuffled=False,
):
_load_lds(gl_src, lds_dst, k_offset, gl_offsets, n_tiles_lds)
rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt)
rt_dst = _load_rt(lds_src, wave_idx, n_tiles_rt, preshuffled=lds_src_preshuffled)
c = _mfma_ABt_all(a, b, c)
return c, rt_dst

def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c):
def _compute_block(
lds_dst,
gl_src,
k_offset,
gl_offsets,
wave_idx,
lds_src,
n_tiles_lds,
n_tiles_rt,
a,
b,
c,
lds_src_preshuffled=False,
):
if const_expr(_use_interleaved_block):
return _interleaved_cluster(
lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c
lds_dst,
gl_src,
k_offset,
gl_offsets,
wave_idx,
lds_src,
n_tiles_lds,
a,
b,
c,
lds_src_preshuffled=lds_src_preshuffled,
)
else:
return _compute_cluster(
lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, n_tiles_rt, a, b, c
lds_dst,
gl_src,
k_offset,
gl_offsets,
wave_idx,
lds_src,
n_tiles_lds,
n_tiles_rt,
a,
b,
c,
lds_src_preshuffled=lds_src_preshuffled,
)

# Each wave handles 2x2 64x64 sub-tiles of the output.
Expand All @@ -396,49 +490,51 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t
c10_frag = [RT_C_i] * N_ACCUMS
c11_frag = [RT_C_i] * N_ACCUMS

global_offsets = _compute_global_swizzle()
gl_off_a = _compute_global_swizzle(preshuffled=False)
gl_off_b = _compute_global_swizzle(b_preshuffled)

# Prologue: 8-buffer LDS pipeline pre-fill.
_load_lds(ga_div, a_cur0, A0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A)
_load_lds(gb_div, b_cur0, B0_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B)
_load_lds(gb_div, b_cur1, B1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_B)
_load_lds(ga_div, a_cur1, A1_gl_offset + 0 * BLOCK_K, global_offsets, N_TILES_A)
_load_lds(ga_div, a_cur0, A0_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A)
_load_lds(gb_div, b_cur0, B0_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B)
_load_lds(gb_div, b_cur1, B1_gl_offset + 0 * B_K_STEP, gl_off_b, N_TILES_B)
_load_lds(ga_div, a_cur1, A1_gl_offset + 0 * A_K_STEP, gl_off_a, N_TILES_A)

_load_lds(ga_div, a_next0, A0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A)
_load_lds(gb_div, b_next0, B0_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B)
_load_lds(gb_div, b_next1, B1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_B)
_load_lds(ga_div, a_next1, A1_gl_offset + 1 * BLOCK_K, global_offsets, N_TILES_A)
_load_lds(ga_div, a_next0, A0_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A)
_load_lds(gb_div, b_next0, B0_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B)
_load_lds(gb_div, b_next1, B1_gl_offset + 1 * B_K_STEP, gl_off_b, N_TILES_B)
_load_lds(ga_div, a_next1, A1_gl_offset + 1 * A_K_STEP, gl_off_a, N_TILES_A)

_wait_barrier((3 * N_TILES_A) + (4 * N_TILES_B))

a0_frag = _load_rt(a_cur0, wave_i, N_TILES_A)

_wait_barrier((3 * N_TILES_A) + (3 * N_TILES_B))

b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B)
b0_frag = _load_rt(b_cur0, wave_j, N_TILES_B, preshuffled=b_preshuffled)

for k in range_constexpr(K_ITERS - 2):
_wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B))

c00_frag, b1_frag = _compute_block(
a_cur0,
ga_div,
A0_gl_offset + (k + 2) * BLOCK_K,
global_offsets,
A0_gl_offset + (k + 2) * A_K_STEP,
gl_off_a,
wave_j,
b_cur1,
N_TILES_A,
N_TILES_B,
a0_frag,
b0_frag,
c00_frag,
lds_src_preshuffled=b_preshuffled,
)

c01_frag, a1_frag = _compute_block(
b_cur0,
gb_div,
B0_gl_offset + (k + 2) * BLOCK_K,
global_offsets,
B0_gl_offset + (k + 2) * B_K_STEP,
gl_off_b,
wave_i,
a_cur1,
N_TILES_B,
Expand All @@ -453,8 +549,8 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t
c10_frag, a0_frag = _compute_block(
b_cur1,
gb_div,
B1_gl_offset + (k + 2) * BLOCK_K,
global_offsets,
B1_gl_offset + (k + 2) * B_K_STEP,
gl_off_b,
wave_i,
a_next0,
N_TILES_B,
Expand All @@ -467,15 +563,16 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t
c11_frag, b0_frag = _compute_block(
a_cur1,
ga_div,
A1_gl_offset + (k + 2) * BLOCK_K,
global_offsets,
A1_gl_offset + (k + 2) * A_K_STEP,
gl_off_a,
wave_j,
b_next0,
N_TILES_A,
N_TILES_B,
a1_frag,
b1_frag,
c11_frag,
lds_src_preshuffled=b_preshuffled,
)

a_cur0, a_next0 = a_next0, a_cur0
Expand All @@ -485,14 +582,14 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t

# Tail step k_iters - 2.
_wait_barrier((2 * N_TILES_A) + (2 * N_TILES_B))
b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B)
b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled)
c00_frag = _mfma_ABt_all(a0_frag, b0_frag, c00_frag)
a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A)
c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag)
_wait_barrier((1 * N_TILES_A) + (1 * N_TILES_B))
a0_frag = _load_rt(a_next0, wave_i, N_TILES_A)
c10_frag = _mfma_ABt_all(a1_frag, b0_frag, c10_frag)
b0_frag = _load_rt(b_next0, wave_j, N_TILES_B)
b0_frag = _load_rt(b_next0, wave_j, N_TILES_B, preshuffled=b_preshuffled)
c11_frag = _mfma_ABt_all(a1_frag, b1_frag, c11_frag)

a_cur0, a_next0 = a_next0, a_cur0
Expand All @@ -504,7 +601,7 @@ def _compute_block(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_t
base_row = tile_i * BLOCK_M + wave_i * (N_TILES_A * 16)
base_col = tile_j * BLOCK_N + wave_j * (N_TILES_B * 16)
_wait_barrier(0)
b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B)
b1_frag = _load_rt(b_cur1, wave_j, N_TILES_B, preshuffled=b_preshuffled)
a1_frag = _load_rt(a_cur1, wave_i, N_TILES_A)
c00_frag = _mfma_ABt_all(a0_frag, b0_frag, c00_frag)
c01_frag = _mfma_ABt_all(a0_frag, b1_frag, c01_frag)
Expand Down
Loading
Loading