Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions benchmarks/bench_la_decode_vs_fla.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Usage:
python benchmarks/bench_la_decode_vs_fla.py
python benchmarks/bench_la_decode_vs_fla.py --heads 64 --head-dim 128
python benchmarks/bench_la_decode_vs_fla.py --heads 32 --num-v-heads 64
python benchmarks/bench_la_decode_vs_fla.py --batch-sizes 1 8 64 256
"""

Expand All @@ -63,28 +64,32 @@
# ─────────────────────────────────────────────────────────────────────────────
# Core benchmark for one configuration
# ─────────────────────────────────────────────────────────────────────────────
def run_config(B, H, K, V, layer_idx, num_layers):
def run_config(B, H, HV, K, V, layer_idx, num_layers):
assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H"
device = "cuda"
dtype = torch.bfloat16
scale = K**-0.5
group_size = HV // H

# Per-head log decay (Lightning Attention formula)
g_gamma = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, device=device, dtype=torch.float32)
g_gamma = -(8 / HV * (1 - layer_idx / num_layers)) * torch.arange(HV, device=device, dtype=torch.float32)
decay_scales = -g_gamma # la_decode convention

# ── Random inputs ──────────────────────────────────────────────────────
torch.manual_seed(42)
q_4d = torch.randn(B, 1, H, K, device=device, dtype=dtype)
k_4d = torch.randn(B, 1, H, K, device=device, dtype=dtype)
v_4d = torch.randn(B, 1, H, V, device=device, dtype=dtype)
state_init = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01
v_4d = torch.randn(B, 1, HV, V, device=device, dtype=dtype)
state_init = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01
q_fla_4d = q_4d.repeat_interleave(group_size, dim=2)
k_fla_4d = k_4d.repeat_interleave(group_size, dim=2)

# ── fla reference output ───────────────────────────────────────────────
state_fla = state_init.clone()
with torch.no_grad():
o_fla_fp32, ht_fla = fused_recurrent_fwd(
q_4d,
k_4d,
q_fla_4d,
k_fla_4d,
v_4d,
g_gamma=g_gamma,
scale=scale,
Expand All @@ -94,11 +99,11 @@ def run_config(B, H, K, V, layer_idx, num_layers):
o_fla = o_fla_fp32.to(dtype)

# ── la_decode output ───────────────────────────────────────────────────
state_cute = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous()
state_cute = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous()
q_3d = q_4d.squeeze(1)
k_3d = k_4d.squeeze(1)
v_3d = v_4d.squeeze(1)
out_cute = torch.zeros(B, H, V, device=device, dtype=dtype)
out_cute = torch.zeros(B, HV, V, device=device, dtype=dtype)
s_offsets = torch.arange(B, device=device, dtype=torch.int32)

with torch.no_grad():
Expand Down Expand Up @@ -128,7 +133,7 @@ def run_config(B, H, K, V, layer_idx, num_layers):
max_ref = torch.abs(o_fla_cmp).max().item()
rel_maxdiff = torch.abs(o_cute_cmp - o_fla_cmp).max().item() / (max_ref + 1e-8)

state_cute_back = state_cute.reshape(B, H, V, K).permute(0, 1, 3, 2).contiguous()
state_cute_back = state_cute.reshape(B, HV, V, K).permute(0, 1, 3, 2).contiguous()
state_relative_rms_error = relative_rms_error(ht_fla, state_cute_back)

# ==================================================================
Expand All @@ -140,16 +145,16 @@ def run_config(B, H, K, V, layer_idx, num_layers):
BV_fla = min(triton.next_power_of_2(V), 64)
NK = triton.cdiv(K, BK_fla)
NV = triton.cdiv(V, BV_fla)
fla_o_buf = torch.empty(NK, B, 1, H, V, device=device, dtype=torch.float32)
fla_ht_buf = torch.empty(B, H, K, V, device=device, dtype=torch.float32)
fla_o_sum = torch.empty(B, 1, H, V, device=device, dtype=torch.float32)
fla_o_buf = torch.empty(NK, B, 1, HV, V, device=device, dtype=torch.float32)
fla_ht_buf = torch.empty(B, HV, K, V, device=device, dtype=torch.float32)
fla_state_k = state_init.clone()
grid_fla = (NV, NK, B * H)
fla_o_sum = torch.empty(B, 1, HV, V, device=device, dtype=torch.float32)
grid_fla = (NV, NK, B * HV)

def kernel_fla():
fused_recurrent_fwd_kernel[grid_fla](
q=q_4d,
k=k_4d,
q=q_fla_4d,
k=k_fla_4d,
v=v_4d,
g=None,
g_gamma=g_gamma,
Expand All @@ -162,7 +167,7 @@ def kernel_fla():
scale=scale,
B=B,
T=1,
H=H,
H=HV,
K=K,
V=V,
BK=BK_fla,
Expand All @@ -176,9 +181,9 @@ def kernel_fla():
torch.sum(fla_o_buf, dim=0, out=fla_o_sum)

# cute kernel: fetch the compiled kernel and create the CUDA stream handle up front
cute_state_k = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous()
out_cute_k = torch.empty(B, H, V, device=device, dtype=dtype)
cache = _get_compiled_kernel(B, 1, H, K, V, cute_state_k.shape[0], scale, USE_FAST_MATH)
cute_state_k = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous()
out_cute_k = torch.empty(B, HV, V, device=device, dtype=dtype)
cache = _get_compiled_kernel(B, 1, H, HV, K, V, cute_state_k.shape[0], scale, USE_FAST_MATH)
compiled_cute = cache["compiled"]
stream_handle = cuda_drv.CUstream(torch.cuda.current_stream().cuda_stream)

Expand All @@ -193,13 +198,13 @@ def kernel_cute():
# Mode 2: WRAPPER (full call path as used in production)
# ==================================================================
wrap_fla_state = state_init.clone()
wrap_cute_state = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous()
wrap_cute_out = torch.empty(B, H, V, device=device, dtype=dtype)
wrap_cute_state = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous()
wrap_cute_out = torch.empty(B, HV, V, device=device, dtype=dtype)

def wrapper_fla():
fused_recurrent_fwd(
q_4d,
k_4d,
q_fla_4d,
k_fla_4d,
v_4d,
g_gamma=g_gamma,
scale=scale,
Expand Down Expand Up @@ -233,6 +238,8 @@ def wrapper_cute():

return {
"B": B,
"H": H,
"HV": HV,
"kernel_fla_ms": kernel_fla_ms,
"kernel_cute_ms": kernel_cute_ms,
"kernel_speedup": kernel_fla_ms / kernel_cute_ms,
Expand All @@ -257,16 +264,19 @@ def main():
default=[1, 2, 4, 8, 16, 32, 64, 128, 256],
)
parser.add_argument("--heads", type=int, default=32)
parser.add_argument("--num-v-heads", type=int, default=None, help="Number of value heads (default: same as --heads)")
parser.add_argument("--head-dim", type=int, default=128)
parser.add_argument("--layer-idx", type=int, default=12)
parser.add_argument("--num-layers", type=int, default=24)
args = parser.parse_args()

H, K, V = args.heads, args.head_dim, args.head_dim
H, HV, K, V = args.heads, args.num_v_heads or args.heads, args.head_dim, args.head_dim
if HV < H or HV % H != 0:
raise ValueError(f"num_v_heads ({HV}) must be >= heads ({H}) and divisible by heads")

print("Lightning Attention Decode Benchmark")
print(" la_decode (CuTe DSL) vs fla fused_recurrent_fwd (Triton)")
print(f" H={H}, K={K}, V={V}, layer={args.layer_idx}/{args.num_layers}")
print(f" H={H}, HV={HV}, K={K}, V={V}, layer={args.layer_idx}/{args.num_layers}")
print(" dtype=bf16, state=fp32, T=1")

# ── Kernel-only comparison ──────────────────────────────────────────
Expand All @@ -282,7 +292,7 @@ def main():

results = []
for B in args.batch_sizes:
r = run_config(B, H, K, V, args.layer_idx, args.num_layers)
r = run_config(B, H, HV, K, V, args.layer_idx, args.num_layers)
results.append(r)
print(
f"{r['B']:>5} | {r['kernel_fla_ms']:>10.4f} | {r['kernel_cute_ms']:>10.4f} | "
Expand All @@ -293,7 +303,7 @@ def main():
# ── Wrapper comparison ──────────────────────────────────────────────
print(f"\n{'=' * 100}")
print(" Mode 2: WRAPPER (fused_recurrent_fwd vs linear_attention_decode, full call path)")
print(" fla: alloc o[NK,B,1,H,V]+ht[B,H,K,V] + kernel + sum(0); cute: cache lookup + CUstream + kernel")
print(" fla: alloc o[NK,B,1,HV,V]+ht[B,HV,K,V] + kernel + sum(0); cute: cache lookup + CUstream + kernel")
print(f"{'=' * 100}")
print(f"{'B':>5} | {'fla (ms)':>10} | {'cute (ms)':>10} | {'speedup':>8}")
print("─" * 50)
Expand Down
Loading