diff --git a/benchmarks/profile_indexer.py b/benchmarks/profile_indexer.py new file mode 100644 index 000000000..43f824139 --- /dev/null +++ b/benchmarks/profile_indexer.py @@ -0,0 +1,144 @@ +"""Profile the low-rank lightning-indexer at realistic shapes (bf16). + +Measures wall time and effective TFLOPS for the einsum baseline vs the +fused Triton kernel. + +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer.py' +""" + +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.sparse_attention.indexer import indexer + +try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +# --- Inputs / FLOP accounting ---------------------------------------------------- + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): + # 2 flops per multiply-add. Counts the contractions in the low-rank + # indexer with learnable output-weight projection: + # C_q = Q @ W_dq : 2 * B*oH * T * d_c * d + # H_q = einsum(C_q, W_uq) : 2 * B*oH * T * H * d_i * d_c + # H_k = K @ W_k : 2 * B*oH * S * d_i * d + # scores = relu(H_q @ H_k^T) : 2 * B*oH * T * H * S * d_i + # W_o = Q @ W_w : 2 * B*oH * T * d * H + # O = sum_h scores * W_o : 2 * B*oH * T * S * H + n = B * oH + return 2 * ( + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H + ) + + +def time_fn(fn, args, n_warmup=15, n_iter=50): + for _ in range(n_warmup): + out = fn(*args) + jax.block_until_ready(out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.block_until_ready(out) + return (time.perf_counter() - t0) / n_iter + + +# --- Driver --------------------------------------------------------------------- + +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 4096, 4096, 512, 1024, 64, 128), +] + + +def _build_impl(backend): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + return indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) + return fn + + +def _dump_autotuner_winner(): + """Print the autotuner-selected config(s) for _score_reduce_kernel.""" + if not _HAVE_HYBRID: + return + try: + from transformer_engine.jax.triton_extensions.indexer import ( + _score_reduce_kernel, + ) + except ImportError: + return + cache = getattr(_score_reduce_kernel, "cache", None) + if not cache: + print(" [autotune] no cache entries") + return + for key, cfg in cache.items(): + print(f" [autotune] key={key} -> {cfg}") + + +if not _HAVE_HYBRID: + print(f"[profile_indexer] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") + + +def main(): + print(f"jax devices: {jax.devices()}\n") + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) + args = (Q, K, W_uq, W_dq, W_k, W_w) + flops = theoretical_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") + print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call") + + # impls = [("baseline", _build_impl("reference"))] + impls = [] + if _HAVE_HYBRID: + impls.append(("hybrid", _build_impl("hybrid"))) + + baseline_ms = None + for name, fn in impls: + try: + sec = time_fn(fn, args) + tflops = flops / sec / 1e12 + ms = sec * 1e3 + if name == "baseline": + baseline_ms = ms + speed = "" + elif baseline_ms is not None: + speed = f" ({baseline_ms/ms:.2f}x baseline)" + else: + speed = "" + print(f" {name:<10} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") + except Exception as e: # noqa: BLE001 + print(f" {name:<10} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + _dump_autotuner_winner() + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_indexer_bwd.py b/benchmarks/profile_indexer_bwd.py new file mode 100644 index 000000000..4e2688078 --- /dev/null +++ b/benchmarks/profile_indexer_bwd.py @@ -0,0 +1,182 @@ +"""Profile lightning-indexer backward pass throughput (bf16). + +Measures wall time and effective TFLOPS for forward, backward, and +value_and_grad. Uses the standard "backward = 2x forward FLOPs" convention, +so value_and_grad total work = 3x forward FLOPs. + +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_bwd.py' + +Select backends and passes via flags: + --backends reference hybrid + --passes fwd bwd vag +""" + +import argparse +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.sparse_attention.indexer import indexer + +try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +ALL_BACKENDS = ["reference", "hybrid"] +ALL_PASSES = ["fwd", "bwd", "vag"] + + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def theoretical_fwd_flops(B, oH, T, S, d, d_c, H, d_i): + n = B * oH + return 2 * ( + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H + ) + + +def time_fn(fn, args, n_warmup=10, n_iter=30): + for _ in range(n_warmup): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + return (time.perf_counter() - t0) / n_iter + + +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), +] + + +def _build_fwd(backend): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + O = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) + return jnp.sum(O.astype(jnp.float32)) + return fn + + +def _build_bwd(backend): + """Backward only: returns gradients.""" + fwd = _build_fwd(backend) + return jax.jit(jax.grad(fwd, argnums=(0, 1, 2, 3, 4, 5))) + + +def _build_value_and_grad(backend): + fwd = _build_fwd(backend) + return jax.jit(jax.value_and_grad(fwd, argnums=(0, 1, 2, 3, 4, 5))) + + +PASS_SPECS = { + "fwd": ("forward", _build_fwd, 1), + "bwd": ("backward", _build_bwd, 2), + "vag": ("value_and_grad", _build_value_and_grad, 3), +} + + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--backends", + nargs="+", + choices=ALL_BACKENDS, + default=None, + help=( + "Backends to benchmark. Default: 'reference' plus 'hybrid' if importable." + ), + ) + p.add_argument( + "--passes", + nargs="+", + choices=ALL_PASSES, + default=ALL_PASSES, + help="Which passes to run: fwd, bwd, vag. Default: all three.", + ) + return p.parse_args() + + +def resolve_backends(requested): + if requested is None: + backends = ["reference"] + if _HAVE_HYBRID: + backends.append("hybrid") + return backends + if "hybrid" in requested and not _HAVE_HYBRID: + print( + f"WARNING: 'hybrid' backend requested but unavailable " + f"({type(_HYBRID_IMPORT_ERROR).__name__}: {_HYBRID_IMPORT_ERROR}). " + "Running it anyway — expect failure." + ) + return requested + + +def main(): + args = parse_args() + backends = resolve_backends(args.backends) + passes = args.passes + + print(f"jax devices: {jax.devices()}") + print(f"backends: {backends}") + print(f"passes: {passes}\n") + + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) + fn_args = (Q, K, W_uq, W_dq, W_k, W_w) + fwd_flops = theoretical_fwd_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") + print(f" forward GFLOPs/call: {fwd_flops/1e9:.2f}") + if "bwd" in passes: + print(f" bwd GFLOPs/call (~2x): {2*fwd_flops/1e9:.2f}") + if "vag" in passes: + print(f" f+b GFLOPs/call (~3x): {3*fwd_flops/1e9:.2f}") + print() + + print(f" {'backend':<10s} {'pass':<14s} {'ms':>8s} {'TFLOP/s':>8s}") + + for backend in backends: + for pass_key in passes: + label, builder, flop_mult = PASS_SPECS[pass_key] + try: + fn = builder(backend) + sec = time_fn(fn, fn_args) + ms = sec * 1e3 + tflops = flop_mult * fwd_flops / sec / 1e12 + print(f" {backend:<10s} {label:<14s} {ms:8.3f} {tflops:8.2f}") + except Exception as e: # noqa: BLE001 + msg = str(e).splitlines()[0] if str(e) else "" + print( + f" {backend:<10s} {label:<14s} FAILED: " + f"{type(e).__name__}: {msg}" + ) + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_indexer_topk.py b/benchmarks/profile_indexer_topk.py new file mode 100644 index 000000000..68b9c86a5 --- /dev/null +++ b/benchmarks/profile_indexer_topk.py @@ -0,0 +1,146 @@ +"""Profile indexer + per-row top-k along T_s (bf16). + +Same canonical backends as ``profile_indexer.py`` (reference einsum vs +hybrid einsum+Triton score-reduce), with ``jax.lax.top_k`` applied to the +score matrix. Reports wall time and effective TFLOPS for the indexer +compute (top-k is comparison-only and counted as 0 FLOP). + +Run inside the container: + docker exec zain-w2 sh -c 'cd /workspace && python benchmarks/profile_indexer_topk.py' +""" + +import time + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.sparse_attention.indexer import indexer, indexer_topk + +try: + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton # noqa: F401 + _HAVE_HYBRID = True +except Exception as _e: # noqa: BLE001 + _HAVE_HYBRID = False + _HYBRID_IMPORT_ERROR = _e + + +# --- Inputs / FLOP accounting --------------------------------------------------- +# Mirrors profile_indexer.py — keeping the two profilers in lockstep. + +def make_inputs(B, oH, T, S, d, d_c, H, d_i, dtype, seed=0): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T, d), dtype=dtype) + K = jax.random.normal(keys[1], (B, oH, S, d), dtype=dtype) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=dtype) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=dtype) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=dtype) + W_w = jax.random.normal(keys[5], (d, H), dtype=dtype) + return Q, K, W_uq, W_dq, W_k, W_w + + +def theoretical_flops(B, oH, T, S, d, d_c, H, d_i): + # 2 flops per multiply-add. top-k is comparison-only, counted as 0 FLOP. + n = B * oH + return 2 * ( + n * T * d_c * d + + n * T * H * d_i * d_c + + n * S * d_i * d + + n * T * H * S * d_i + + n * T * d * H + + n * T * S * H + ) + + +def time_fn(fn, args, n_warmup=15, n_iter=50): + for _ in range(n_warmup): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + t0 = time.perf_counter() + for _ in range(n_iter): + out = fn(*args) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), out) + return (time.perf_counter() - t0) / n_iter + + +# --- Driver --------------------------------------------------------------------- + +CONFIGS = [ + #(B, oH, T, S, d, d_c, H, d_i) + ( 2, 64, 1024, 1024, 512, 1024, 64, 128), +] + +K_TOPK = 512 + + +def _build_topk(backend, k): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + scores = indexer(Q, K, W_uq, W_dq, W_k, W_w, backend=backend) + return jax.lax.top_k(scores, k) + return fn + + +def _build_fused_topk(k): + @jax.jit + def fn(Q, K, W_uq, W_dq, W_k, W_w): + return indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=k) + return fn + + +@jax.jit +def _topk_only(scores): + return jax.lax.top_k(scores, K_TOPK) + + +if not _HAVE_HYBRID: + print(f"[profile_indexer_topk] Hybrid backend unavailable: {_HYBRID_IMPORT_ERROR}") + + +def main(): + print(f"jax devices: {jax.devices()}\nk = {K_TOPK}\n") + for B, oH, T, S, d, d_c, H, d_i in CONFIGS: + Q, K, W_uq, W_dq, W_k, W_w = make_inputs( + B, oH, T, S, d, d_c, H, d_i, jnp.bfloat16 + ) + args = (Q, K, W_uq, W_dq, W_k, W_w) + flops = theoretical_flops(B, oH, T, S, d, d_c, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} d={d} d_c={d_c} H={H} d_i={d_i} bfloat16 ---") + print(f" theoretical work = {flops/1e9:.2f} GFLOPs/call (top-k = 0 FLOP)") + + # impls = [("baseline+topk", _build_topk("reference", K_TOPK))] + impls = [] + if _HAVE_HYBRID: + impls.append(("hybrid+topk", _build_topk("hybrid", K_TOPK))) + impls.append(("hybrid_fused_topk", _build_fused_topk(K_TOPK))) + + baseline_ms = None + for name, fn in impls: + try: + sec = time_fn(fn, args) + ms = sec * 1e3 + tflops = flops / sec / 1e12 + if name == "baseline+topk": + baseline_ms = ms + speed = "" + elif baseline_ms is not None: + speed = f" ({baseline_ms/ms:.2f}x baseline)" + else: + speed = "" + print(f" {name:<18} {ms:8.3f} ms {tflops:6.2f} TFLOP/s{speed}") + except Exception as e: # noqa: BLE001 + print(f" {name:<18} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + + # Time top_k alone on a precomputed (reference) score matrix to + # isolate the top-k cost from the indexer compute. + try: + scores_mat = indexer(*args, backend="reference") + sec = time_fn(_topk_only, (scores_mat,)) + print(f" {'(top_k alone)':<18} {sec*1e3:8.3f} ms") + except Exception as e: # noqa: BLE001 + print(f" {'(top_k alone) FAILED':<18} {type(e).__name__}") + print() + + +if __name__ == "__main__": + main() diff --git a/tests/jax/test_indexer.py b/tests/jax/test_indexer.py new file mode 100644 index 000000000..0f3597777 --- /dev/null +++ b/tests/jax/test_indexer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Correctness tests for the lightning-indexer JAX ops. + +Ported from the in-module ``__main__`` smoke tests of +``transformer_engine.jax.sparse_attention.indexer``. The hybrid and top-k backends require +rank-4 ``(B, oH, T, d)`` inputs, so every leading shape here is length-2. +""" + +import jax +import jax.numpy as jnp +import pytest + +from transformer_engine.jax.sparse_attention.indexer import ( + LightningIndexer, + indexer, + indexer_topk, +) + + +def _indexer_inputs(B, oH, T_t, T_s, d, d_c, H, d_i, seed): + keys = jax.random.split(jax.random.PRNGKey(seed), 6) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + W_uq = jax.random.normal(keys[2], (H, d_c, d_i), dtype=jnp.bfloat16) + W_dq = jax.random.normal(keys[3], (d, d_c), dtype=jnp.bfloat16) + W_k = jax.random.normal(keys[4], (d, d_i), dtype=jnp.bfloat16) + W_w = jax.random.normal(keys[5], (d, H), dtype=jnp.bfloat16) + return Q, K, W_uq, W_dq, W_k, W_w + + +def _rel_err(actual, ref): + actual = actual.astype(jnp.float32) + ref = ref.astype(jnp.float32) + return float(jnp.linalg.norm(actual - ref) / (jnp.linalg.norm(ref) + 1e-30)) + + +@pytest.mark.parametrize("B,oH", [(2, 3), (1, 1), (1, 4)]) +def test_hybrid_matches_reference(B, oH): + """Hybrid Triton score-reduce matches the pure-einsum reference forward.""" + args = _indexer_inputs(B, oH, T_t=64, T_s=64, d=32, d_c=32, H=8, d_i=32, seed=100) + o_ref = indexer(*args, backend="reference") + o_hyb = indexer(*args, backend="hybrid") + assert o_hyb.shape == o_ref.shape + assert _rel_err(o_hyb, o_ref) < 5e-3 + + +@pytest.mark.parametrize("k", [32]) +def test_topk_matches_reference(k): + """Fused top-k selects the same scores as reference + ``jax.lax.top_k``. + + Index set-equality is too strict (backends break ties differently), so the + check is on the *scores* at the fused-selected indices. ``k`` is kept in the + top quartile of ``T_s``: a cutoff in the dense middle of the distribution + makes boundary scores closely spaced, so the kernel's fp32 ranking and the + bf16-rounded reference grid resolve near-ties differently (a test-grid + sensitivity, not a kernel error). + """ + args = _indexer_inputs(2, 3, T_t=64, T_s=128, d=32, d_c=32, H=16, d_i=32, seed=200) + o_ref = indexer(*args, backend="reference").astype(jnp.float32) + topk_idx = indexer_topk(*args, k=k) + assert topk_idx.shape == (2, 3, 64, k) + + ref_vals = jax.lax.top_k(o_ref, k=k)[0] + picked = jnp.take_along_axis(o_ref, topk_idx, axis=-1) + picked_sorted = jnp.sort(picked, axis=-1)[..., ::-1] + max_rel = float((jnp.abs(ref_vals - picked_sorted) / (jnp.abs(ref_vals) + 1e-6)).max()) + assert max_rel < 1e-2 + + +@pytest.mark.parametrize("B,oH", [(2, 3), (1, 2)]) +def test_hybrid_backward_matches_reference_grad(B, oH): + """``jax.grad`` through the hybrid backend matches grad through reference. + + Tolerance is 5e-2 (bf16 projections + Triton score recompute) — looser than + the 5e-3 forward tolerance; tighten once per-grad error is characterized + on-device. + """ + args = _indexer_inputs(B, oH, T_t=32, T_s=32, d=32, d_c=32, H=8, d_i=32, seed=300) + + def _loss(backend): + def inner(*a): + return jnp.sum(indexer(*a, backend=backend).astype(jnp.float32)) + return inner + + argnums = (0, 1, 2, 3, 4, 5) + grads_ref = jax.grad(_loss("reference"), argnums=argnums)(*args) + grads_hyb = jax.grad(_loss("hybrid"), argnums=argnums)(*args) + for gr, gh in zip(grads_ref, grads_hyb): + assert _rel_err(gh, gr) < 5e-2 + + +def test_lightning_indexer_module_matches_functional(): + """``LightningIndexer`` (Flax module) reproduces the functional ``indexer`` + when fed the module's own initialized weights.""" + B, oH, T_t, T_s, d, d_c, H, d_i = 2, 3, 64, 64, 32, 32, 8, 32 + keys = jax.random.split(jax.random.PRNGKey(7), 3) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + + mod = LightningIndexer(num_heads=H, d_c=d_c, d_i=d_i, backend="reference") + variables = mod.init(keys[2], Q, K) + o_mod = mod.apply(variables, Q, K) + assert o_mod.shape == (B, oH, T_t, T_s) + + p = variables["params"] + o_fn = indexer(Q, K, p["W_uq"], p["W_dq"], p["W_k"], p["W_w"], backend="reference") + assert _rel_err(o_mod, o_fn) < 1e-5 + + +def test_lightning_indexer_topk_mode(): + """``LightningIndexer(topk=k)`` returns fused top-k indices of shape (..., T, k).""" + B, oH, T_t, T_s, d, d_c, H, d_i, k = 2, 3, 64, 128, 32, 32, 16, 32, 32 + keys = jax.random.split(jax.random.PRNGKey(9), 2) + Q = jax.random.normal(keys[0], (B, oH, T_t, d), dtype=jnp.bfloat16) + K = jax.random.normal(keys[1], (B, oH, T_s, d), dtype=jnp.bfloat16) + + mod = LightningIndexer(num_heads=H, d_c=d_c, d_i=d_i, topk=k) + variables = mod.init(jax.random.PRNGKey(0), Q, K) + idx = mod.apply(variables, Q, K) + assert idx.shape == (B, oH, T_t, k) + assert idx.dtype == jnp.int32 diff --git a/tests/jax/test_sparse_attention.py b/tests/jax/test_sparse_attention.py new file mode 100644 index 000000000..a911b6049 --- /dev/null +++ b/tests/jax/test_sparse_attention.py @@ -0,0 +1,371 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Tests for Deep Sparse Attention (DSA) composition + HCA / fused scaffold contracts.""" + +import jax +import jax.numpy as jnp +import pytest +from flax import linen as nn + +from transformer_engine.jax.sparse_attention import ( + DeepSparseAttention, + deep_sparse_attention_core, + _causal_keep_mask, + _topk_indices_to_attn_mask, +) +from transformer_engine.jax.sparse_attention.compressed_attention import ( + HeavilyCompressedAttention, + heavily_compressed_attention, +) +from transformer_engine.jax.sparse_attention.indexer import indexer +from transformer_engine.jax.triton_extensions import fused_sparse_attention_triton + + +@pytest.fixture(autouse=True) +def _force_unfused_attn(monkeypatch): + """Override conftest's enable_fused_attn_after_hopper for this module. + + The DSA composition path uses an arbitrary topk-derived attention mask. The + fused-attention backends on some platforms restrict mask semantics (padding- + style only). Force the unfused softmax path so reference comparisons hold. + Production callers can still set NVTE_FUSED_ATTN=1 — these tests are only + asserting the composition math, not the fused-path's mask handling. + """ + monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + yield + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_dsa_module(*, oH=4, D=8, iH=2, idc=16, idi=16, k=4, + backend="composition", indexer_backend="hybrid"): + return DeepSparseAttention( + head_dim=D, + num_attention_heads=oH, + indexer_num_heads=iH, + indexer_d_c=idc, + indexer_d_i=idi, + topk=k, + backend=backend, + indexer_backend=indexer_backend, + dtype=jnp.bfloat16, + ) + + +def _make_inputs(B=1, oH=4, T=16, hidden=32, dtype=jnp.bfloat16, seed=0): + """Rank-4 inputs [B, oH, T, hidden].""" + return jax.random.normal(jax.random.PRNGKey(seed), (B, oH, T, hidden), dtype=dtype) + + +def _ref_dense_softmax_per_head(query, key, value, mask_out, scale): + """Per-head dense softmax attention with an arbitrary mask (no DPA). + + query/key/value: [B, oH, T, head_dim]; mask_out: [B, oH, T_t, T_s] uint8 + (1 = mask out). Returns [B, oH, T_t, head_dim]. + """ + logits = jnp.einsum("bhtd,bhsd->bhts", query, key) * scale + logits = logits.astype(jnp.float32) + logits = jnp.where( + mask_out.astype(jnp.bool_), + jnp.asarray(-jnp.inf, jnp.float32), + logits, + ) + weights = jax.nn.softmax(logits, axis=-1) + return jnp.einsum("bhts,bhsd->bhtd", weights.astype(value.dtype), value) + + +def _ref_dsa_jax( + inputs_q, inputs_kv, + W_q_kernel, W_k_kernel, W_v_kernel, + W_uq, W_dq, W_k_idx, W_w, + *, + head_dim, k, causal, +): + """Pure-JAX reference matching ``deep_sparse_attention_core``.""" + T_t = inputs_q.shape[2] + T_s = inputs_kv.shape[2] + + q = jnp.einsum("bhtd,dk->bhtk", inputs_q, W_q_kernel) + kk = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_k_kernel) + v = jnp.einsum("bhsd,dk->bhsk", inputs_kv, W_v_kernel) + + scores = indexer( + inputs_q, inputs_kv, W_uq, W_dq, W_k_idx, W_w, + backend="reference", out_dtype=jnp.float32, + ) + if causal: + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] + scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, jnp.float32)) + _, topk_idx = jax.lax.top_k(scores, min(k, T_s)) + mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=causal) + return _ref_dense_softmax_per_head( + q, kk, v, mask_out, scale=1.0 / jnp.sqrt(head_dim).astype(q.dtype), + ) + + +# ----------------------------------------------------------------------------- +# Mask helpers +# ----------------------------------------------------------------------------- + + +def test_causal_keep_mask_self_attention(): + """T_t == T_s: standard lower-triangular keep mask.""" + m = _causal_keep_mask(4, 4) + expected = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_)) + assert jnp.array_equal(m, expected) + + +def test_causal_keep_mask_cross_attention_with_prefix(): + """T_t < T_s: causal cutoff aligned to bottom-right (prefix context allowed).""" + m = _causal_keep_mask(2, 5) # T_t=2, T_s=5 → prefix of 3 always visible + expected = jnp.array( + [[True, True, True, True, False], + [True, True, True, True, True]], + dtype=jnp.bool_, + ) + assert jnp.array_equal(m, expected) + + +def test_topk_indices_to_attn_mask_basic(): + # B=1, oH=1, T_t=2, k=2 + indices = jnp.array([[[[0, 2], [1, 3]]]], dtype=jnp.int32) # [1, 1, 2, 2] + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=False) + expected = jnp.array( + [[[[0, 1, 0, 1], + [1, 0, 1, 0]]]], + dtype=jnp.uint8, + ) + assert mask_out.shape == (1, 1, 2, 4) + assert mask_out.dtype == jnp.uint8 + assert jnp.array_equal(mask_out, expected) + + +def test_topk_indices_to_attn_mask_per_head_diverges(): + """Different oH heads pick different topk → different per-head masks.""" + # B=1, oH=2, T_t=1, k=2 + indices = jnp.array([[[[0, 1]], [[2, 3]]]], dtype=jnp.int32) + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=False) + # Head 0 keeps {0,1} → mask [0,0,1,1]; head 1 keeps {2,3} → mask [1,1,0,0]. + expected = jnp.array( + [[[[0, 0, 1, 1]], + [[1, 1, 0, 0]]]], + dtype=jnp.uint8, + ) + assert mask_out.shape == (1, 2, 1, 4) + assert jnp.array_equal(mask_out, expected) + + +def test_topk_indices_to_attn_mask_causal_intersect(): + """Causal AND topk in self-attention: query t cannot keep positions > t.""" + # B=1, oH=1, T_t=T_s=4, k=2 + indices = jnp.array([[[[2, 3], [0, 1], [1, 2], [2, 3]]]], dtype=jnp.int32) + mask_out = _topk_indices_to_attn_mask(indices, T_s=4, causal=True) + # q=0: picks {2,3}, causal {0} → intersect {} → all-1 row. + assert bool((mask_out[0, 0, 0, :] == 1).all()) + # q=2: picks {1,2}, causal {0,1,2} → intersect {1,2}. + assert mask_out[0, 0, 2, 0] == 1 + assert mask_out[0, 0, 2, 1] == 0 + assert mask_out[0, 0, 2, 2] == 0 + assert mask_out[0, 0, 2, 3] == 1 + + +# ----------------------------------------------------------------------------- +# DSA composition correctness +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize("B,oH,T,hidden,D,iH,idc,idi,k", [ + (1, 4, 16, 32, 8, 2, 16, 16, 4), + (2, 4, 32, 64, 16, 2, 32, 32, 8), + (1, 2, 8, 16, 8, 1, 8, 8, 2), +]) +def test_dsa_composition_vs_pure_jax_reference(B, oH, T, hidden, D, iH, idc, idi, k): + """DSA module output (composition + hybrid indexer) matches pure-JAX reference.""" + inputs = _make_inputs(B=B, oH=oH, T=T, hidden=hidden) + keys = jax.random.split(jax.random.PRNGKey(123), 2) + module = _make_dsa_module(oH=oH, D=D, iH=iH, idc=idc, idi=idi, k=k) + params = module.init(keys[0], inputs, inputs, deterministic=True) + out = module.apply(params, inputs, inputs, deterministic=True) + assert out.shape == (B, oH, T, D) + + p = nn.meta.unbox(params)["params"] + out_ref = _ref_dsa_jax( + inputs, inputs, + p["query"]["kernel"], p["key"]["kernel"], p["value"]["kernel"], + p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], + head_dim=D, k=k, causal=True, + ) + + diff = (out.astype(jnp.float32) - out_ref.astype(jnp.float32)) + rel = float( + jnp.linalg.norm(diff) + / (jnp.linalg.norm(out_ref.astype(jnp.float32)) + 1e-30) + ) + assert rel < 5e-2, f"DSA output diverges from reference: rel.err={rel:.3e}" + + +def test_dsa_composition_reference_indexer_matches_hybrid(): + """Same correctness check using indexer_backend='reference' (pure einsum).""" + B, oH, T, hidden, D, iH, idc, idi, k = 1, 2, 8, 16, 8, 1, 8, 8, 2 + inputs = _make_inputs(B=B, oH=oH, T=T, hidden=hidden) + keys = jax.random.split(jax.random.PRNGKey(7), 2) + module = _make_dsa_module(oH=oH, D=D, iH=iH, idc=idc, idi=idi, k=k, + indexer_backend="reference") + params = module.init(keys[0], inputs, inputs, deterministic=True) + out = module.apply(params, inputs, inputs, deterministic=True) + assert out.shape == (B, oH, T, D) + + +@pytest.mark.parametrize("T_t,T_s,k", [(8, 8, 4), (8, 8, 2), (16, 16, 8)]) +def test_dsa_topk_count_equals_kept_count_under_causal(T_t, T_s, k): + """For each query t, the number of unmasked key positions equals min(k, t+1).""" + B, oH, hidden = 1, 2, 16 + inputs = _make_inputs(B=B, oH=oH, T=T_t, hidden=hidden, seed=7) + keys = jax.random.split(jax.random.PRNGKey(7), 2) + module = _make_dsa_module(oH=oH, D=8, iH=1, idc=8, idi=8, k=k) + params = module.init(keys[0], inputs, inputs, deterministic=True) + + p = nn.meta.unbox(params)["params"] + scores = indexer( + inputs, inputs, + p["indexer_W_uq"], p["indexer_W_dq"], p["indexer_W_k"], p["indexer_W_w"], + backend="reference", out_dtype=jnp.float32, + ) # [B, oH, T_t, T_s] + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] + scores_masked = jnp.where(ckeep, scores, -jnp.inf) + _, topk_idx = jax.lax.top_k(scores_masked, min(k, T_s)) + mask_out = _topk_indices_to_attn_mask(topk_idx, T_s, causal=True) + # Each (b, h, t) row should have exactly min(k, t+1) zeros. + for h in range(oH): + kept_per_q = (mask_out[0, h] == 0).sum(axis=-1) # [T_t] + for t in range(T_t): + expected = min(k, t + 1) + assert int(kept_per_q[t]) == expected, ( + f"oH={h}, t={t}: kept {int(kept_per_q[t])} keys, expected {expected}" + ) + + +# ----------------------------------------------------------------------------- +# Backward shape sanity +# ----------------------------------------------------------------------------- + + +def test_dsa_backward_runs_without_shape_errors(): + inputs = _make_inputs(B=1, oH=2, T=8, hidden=16) + keys = jax.random.split(jax.random.PRNGKey(5), 2) + module = _make_dsa_module(oH=2, D=8, iH=1, idc=8, idi=8, k=2) + params = module.init(keys[0], inputs, inputs, deterministic=True) + + def loss(p, x): + out = module.apply(p, x, x, deterministic=True) + return jnp.sum(out.astype(jnp.float32)) + + grads = jax.grad(loss)(params, inputs) + leaves = jax.tree_util.tree_leaves(grads) + assert all(bool(jnp.isfinite(leaf).all()) for leaf in leaves), \ + "DSA backward produced NaN/Inf gradients" + + +# ----------------------------------------------------------------------------- +# Scaffold contracts +# ----------------------------------------------------------------------------- + + +def test_dsa_fused_backend_raises_not_implemented(): + inputs = _make_inputs(B=1, oH=2, T=8, hidden=16) + keys = jax.random.split(jax.random.PRNGKey(0), 2) + module = _make_dsa_module(oH=2, D=8, iH=1, idc=8, idi=8, k=2, backend="fused") + # Flax materializes the call during init, so NotImplementedError fires there. + with pytest.raises(NotImplementedError, match="phase-2 scaffold"): + module.init(keys[0], inputs, inputs, deterministic=True) + + +def test_fused_sparse_attention_triton_direct_raises(): + """Calling the primitive directly also raises (locked contract).""" + q = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) # [B, T, H, D] + kk = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + v = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + iq = jnp.zeros((1, 2, 4, 8), dtype=jnp.bfloat16) + ik = jnp.zeros((1, 2, 8), dtype=jnp.bfloat16) + iw = jnp.zeros((1, 2, 2), dtype=jnp.bfloat16) + with pytest.raises(NotImplementedError, match="phase-2 scaffold"): + jax.jit( + lambda *args: fused_sparse_attention_triton(*args, k=2) + )(q, kk, v, iq, ik, iw) + + +def test_hca_module_raises_not_implemented(): + module = HeavilyCompressedAttention( + head_dim=8, num_attention_heads=4, + q_lora_rank=16, kv_lora_rank=16, + qk_nope_head_dim=4, qk_rope_head_dim=4, v_head_dim=8, + ) + inputs = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 32), dtype=jnp.bfloat16) + keys = jax.random.split(jax.random.PRNGKey(0), 2) + with pytest.raises(NotImplementedError, match="design.*deferred|DESIGN DEFERRED|scaffold"): + module.init(keys[0], inputs, inputs, deterministic=True) + + +def test_hca_functional_raises_not_implemented(): + inputs = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 32), dtype=jnp.bfloat16) + with pytest.raises(NotImplementedError): + heavily_compressed_attention( + inputs, inputs, + head_dim=8, num_attention_heads=4, + q_lora_rank=16, kv_lora_rank=16, + qk_nope_head_dim=4, qk_rope_head_dim=4, v_head_dim=8, + ) + + +# ----------------------------------------------------------------------------- +# Functional API surface +# ----------------------------------------------------------------------------- + + +def test_deep_sparse_attention_core_invalid_backend_raises(): + q = jnp.zeros((1, 2, 4, 8)) # rank-4 + iq = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + with pytest.raises(ValueError, match="unknown backend"): + deep_sparse_attention_core( + q, q, q, iq, iq, Wuq, W, W[:, :8], W[:, :1], + k=2, backend="bogus", + ) + + +def test_deep_sparse_attention_core_unsupported_mask_type_raises(): + q = jnp.zeros((1, 2, 4, 8)) + iq = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + with pytest.raises(NotImplementedError, match="attn_mask_type"): + deep_sparse_attention_core( + q, q, q, iq, iq, Wuq, W, W[:, :8], W[:, :1], + k=2, attn_mask_type="padding", + ) + + +def test_deep_sparse_attention_core_rejects_rank3_inputs(): + """Rank-3 inputs (missing oH) should be rejected with a clear error.""" + q3 = jnp.zeros((1, 4, 8)) # rank-3 + iq4 = jnp.zeros((1, 2, 4, 16)) + W = jnp.zeros((16, 8)) + Wuq = jnp.zeros((1, 8, 8)) + with pytest.raises(ValueError, match="rank-4"): + deep_sparse_attention_core( + q3, q3, q3, iq4, iq4, Wuq, W, W[:, :8], W[:, :1], + k=2, + ) + + +def test_dsa_module_rejects_oh_mismatch(): + """Module asserts num_attention_heads matches inputs.shape[1].""" + inputs = _make_inputs(B=1, oH=3, T=8, hidden=16) # oH=3 in input + module = _make_dsa_module(oH=4, D=8, iH=1, idc=8, idi=8, k=2) # oH=4 in module + with pytest.raises(ValueError, match="must equal num_attention_heads"): + module.init(jax.random.PRNGKey(0), inputs, inputs, deterministic=True) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index d0afc1ff2..9172a9867 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,6 +34,9 @@ from . import flax from . import quantize +# AMD lightning-indexer / sparse-attention staging module. +from . import sparse_attention + from .quantize import autocast, fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME @@ -51,4 +54,5 @@ "MeshResource", "flax", "quantize", + "sparse_attention", ] diff --git a/transformer_engine/jax/sparse_attention/__init__.py b/transformer_engine/jax/sparse_attention/__init__.py new file mode 100644 index 000000000..8a66d05d4 --- /dev/null +++ b/transformer_engine/jax/sparse_attention/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Deep Sparse Attention (DSA) family. + +Bundles the lightning indexer and the attention modules built on top of it: + + * :mod:`~transformer_engine.jax.sparse_attention.indexer` — the lightning + indexer op (``indexer`` / ``indexer_topk``). + * :mod:`~transformer_engine.jax.sparse_attention.dsa` — Deep Sparse + Attention, which composes the indexer with dense attention. + * :mod:`~transformer_engine.jax.sparse_attention.compressed_attention` — + Heavily Compressed Attention (MLA-style scaffold, design deferred). + +The Triton kernel backends live in +:mod:`transformer_engine.jax.triton_extensions` alongside the other Triton +kernels. +""" + +from . import indexer +from . import dsa +from . import compressed_attention + +from .indexer import LightningIndexer +from .dsa import ( + DeepSparseAttention, + deep_sparse_attention_core, + _causal_keep_mask, + _topk_indices_to_attn_mask, +) +from .compressed_attention import ( + HeavilyCompressedAttention, + heavily_compressed_attention, +) + +__all__ = [ + "indexer", + "dsa", + "compressed_attention", + "LightningIndexer", + "DeepSparseAttention", + "deep_sparse_attention_core", + "HeavilyCompressedAttention", + "heavily_compressed_attention", +] diff --git a/transformer_engine/jax/sparse_attention/compressed_attention.py b/transformer_engine/jax/sparse_attention/compressed_attention.py new file mode 100644 index 000000000..3d7aee172 --- /dev/null +++ b/transformer_engine/jax/sparse_attention/compressed_attention.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Heavily Compressed Attention (HCA) — design-deferred scaffold. + +This module stakes out the API surface for a future MLA-style (DeepSeek-V2/V3 +Multi-head Latent Attention) implementation. The Flax module and functional +entry point both raise :class:`NotImplementedError` so downstream code can +write against the eventual signature today while the design is finalized. + +The intended math (for context, not yet implemented):: + + C_q = LayerNorm(X) @ W_dq # (..., T, q_lora_rank) + Q = C_q @ W_uq # (..., T, H, qk_nope_head_dim + qk_rope_head_dim) + C_kv = X @ W_dkv # (..., S, kv_lora_rank) <-- KV cache stores this + K = C_kv @ W_uk # (..., S, H, qk_nope_head_dim + qk_rope_head_dim) + V = C_kv @ W_uv # (..., S, H, v_head_dim) + K_rope, K_nope = split(K, ...) + apply RoPE to (Q_rope, K_rope) + O = softmax(Q @ K^T / sqrt(d)) @ V + +See ``transformer_engine.jax.sparse_attention.dsa`` for the sibling DSA module +that is implemented today. +""" + +from typing import Optional + +from flax import linen as nn + +from . import indexer as _indexer # noqa: F401 — surface to assert package layout + + +_HCA_DEFER_MESSAGE = ( + "HeavilyCompressedAttention is a phase-1 scaffold (design deferred).\n" + "Open design questions to resolve before implementing:\n" + " 1. RoPE applied on compressed (C_q/C_kv) or decompressed (Q/K) tensors?\n" + " - DeepSeek-V2 applies RoPE on a separate sub-head; we should match.\n" + " 2. KV cache layout: latent-only (memory-optimal) vs latent+RoPE-sub-head?\n" + " 3. Backward through decompression: recompute (memory) vs store (bandwidth)?\n" + " 4. Should this share projection plumbing with MultiHeadAttention's " + "LayerNormDenseGeneral, or use bespoke low-rank projections?\n" + " 5. Interaction with TE's existing fused-attn backends — does any of " + "CK/AITER/cuDNN support split (RoPE/no-RoPE) head dims natively?\n" + "Pin these before filling in. See " + "transformer_engine.jax.sparse_attention.dsa for the working DSA module." +) + + +class HeavilyCompressedAttention(nn.Module): # pylint: disable=too-few-public-methods + """MLA-style heavily compressed attention — **DESIGN DEFERRED**. + + Parameters + ---------- + head_dim : int + Per-head dimension of the dense (decompressed) attention. + num_attention_heads : int + Number of attention heads. + q_lora_rank : int + Rank of the query low-rank compression (``d_c`` in indexer notation). + kv_lora_rank : int + Rank of the key/value low-rank compression. The KV cache stores + only this latent (``kv_lora_rank``-dimensional) representation. + qk_nope_head_dim : int + Per-head dimension for the non-RoPE component of Q/K. + qk_rope_head_dim : int + Per-head dimension for the RoPE component of Q/K. Total Q/K head + dim is ``qk_nope_head_dim + qk_rope_head_dim``. + v_head_dim : int + Per-head dimension of V (may differ from Q/K head dim). + attn_mask_type : str, default = ``"causal"`` + Mask type. Plumbed to the eventual dense attention call. + attention_dropout : float, default = ``0.0`` + qkv_layout : str, default = ``"bshd_bshd_bshd"`` + scale_factor : Optional[float], default = ``None`` + Defaults to ``1/sqrt(qk_nope_head_dim + qk_rope_head_dim)`` when implemented. + """ + + head_dim: int + num_attention_heads: int + q_lora_rank: int + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + attn_mask_type: str = "causal" + attention_dropout: float = 0.0 + qkv_layout: str = "bshd_bshd_bshd" + scale_factor: Optional[float] = None + + @nn.compact + def __call__(self, inputs_q, inputs_kv, *, deterministic: bool = False): # noqa: D401 + del inputs_q, inputs_kv, deterministic + raise NotImplementedError(_HCA_DEFER_MESSAGE) + + +def heavily_compressed_attention( + inputs_q, + inputs_kv, + *, + head_dim: int, + num_attention_heads: int, + q_lora_rank: int, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + attn_mask_type: str = "causal", + scale_factor: Optional[float] = None, +): + """Functional HCA — **DESIGN DEFERRED** (raises NotImplementedError). + + Mirrors the planned :class:`HeavilyCompressedAttention` surface as a + stateless function for callers that prefer functional composition. + """ + del ( + inputs_q, + inputs_kv, + head_dim, + num_attention_heads, + q_lora_rank, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + attn_mask_type, + scale_factor, + ) + raise NotImplementedError(_HCA_DEFER_MESSAGE) diff --git a/transformer_engine/jax/sparse_attention/dsa.py b/transformer_engine/jax/sparse_attention/dsa.py new file mode 100644 index 000000000..e3ee8d741 --- /dev/null +++ b/transformer_engine/jax/sparse_attention/dsa.py @@ -0,0 +1,416 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Deep Sparse Attention (DSA) — composes the lightning indexer with dense attention. + +Phase 1 (this file, working) composes the existing pieces: + + 1. Per-attention-head Q/K/V projection (DenseGeneral) + 2. Lightning-indexer scoring via the hybrid Triton backend + 3. Causal mask + jax.lax.top_k on each per-head score row + 4. Scatter top-k indices into a per-head sparse attention mask + 5. Call transformer_engine.jax.flax.DotProductAttention with that mask + +Phase 2 will dispatch the entire stack to a single fused Triton kernel +``transformer_engine.jax.triton_extensions.fused_sparse_attention_triton``; +the dispatch site lives in :func:`deep_sparse_attention_core` under +``backend="fused"`` and currently raises NotImplementedError (the scaffold +holds the signature stable so the kernel can land without API churn). + +**Shape contract — all DSA tensors are rank-4 with the outer-head dim +explicit:** + + inputs_q : [B, oH, T_t, hidden] + inputs_kv : [B, oH, T_s, hidden] + output : [B, oH, T_t, head_dim] + +``oH ≡ num_attention_heads``. Each attention head has its own indexer +score row, its own top-k pattern, and its own attention output. The +indexer projection *weights* are shared across attention heads — the +per-head divergence comes from the per-head input slice (the caller is +expected to have already produced per-head hidden states upstream). + +This shape contract aligns with the lightning-indexer benchmark's +``[B, oH, T, d]`` convention (see ``benchmarks/profile_indexer_topk.py``) +and lets us call the Triton hybrid backend directly without rank +adjustment. + +Zero modifications are made to upstream-tracked TE files; DSA composes +:class:`DotProductAttention` from the outside via its public ``mask=`` +argument. +""" + +from typing import Literal, Optional + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from transformer_engine.jax.flax.module import DenseGeneral +from transformer_engine.jax.flax.transformer import DotProductAttention +from .indexer import indexer as _indexer_fn, _indexer_projections + + +# Backends supported by deep_sparse_attention_core. +_BACKENDS = ("composition", "fused") + + +# ----------------------------------------------------------------------------- +# Mask construction helpers +# ----------------------------------------------------------------------------- + + +def _causal_keep_mask(T_t: int, T_s: int, dtype=jnp.bool_): + """Lower-triangular keep mask aligned to the bottom-right corner. + + For self-attention (T_t == T_s) this is the standard ``jnp.tril(ones)``. + For cross-attention with T_t < T_s, query position ``t`` attends to key + positions ``[0, T_s - T_t + t]``. This matches the convention used by + causal cross-attention with prefix context. + """ + q_pos = jnp.arange(T_t)[:, None] # [T_t, 1] + k_pos = jnp.arange(T_s)[None, :] # [1, T_s] + keep = k_pos <= (q_pos + (T_s - T_t)) # [T_t, T_s] + return keep.astype(dtype) + + +def _topk_indices_to_attn_mask( + indices: jax.Array, + T_s: int, + *, + causal: bool, +) -> jax.Array: + """Convert per-(B, oH, T_t) top-k indices into a DPA-style mask. + + Args: + indices: ``[B, oH, T_t, k]`` int32 — top-k key positions per (B, oH, T_t). + T_s: number of key positions. + causal: if True, AND the keep-mask with a causal keep-mask before + inverting. + + Returns: + ``[B, oH, T_t, T_s]`` uint8 — ``1`` means *mask out*. The caller + reshapes to ``[B*oH, 1, T_t, T_s]`` for DPA dispatch. + """ + B, oH, T_t, _k = indices.shape + + # Scatter True at every (b, h, t, indices[b, h, t, :]) position. + keep = jnp.zeros((B, oH, T_t, T_s), dtype=jnp.bool_) + b_idx = jnp.arange(B)[:, None, None, None] # [B, 1, 1, 1] + h_idx = jnp.arange(oH)[None, :, None, None] # [1, oH, 1, 1] + t_idx = jnp.arange(T_t)[None, None, :, None] # [1, 1, T_t, 1] + # Duplicates from .at[].set(True) are idempotent — safe when k > finite scores. + keep = keep.at[b_idx, h_idx, t_idx, indices].set(True) # [B, oH, T_t, T_s] + + if causal: + keep = keep & _causal_keep_mask(T_t, T_s)[None, None, :, :] + + mask_out = jnp.logical_not(keep) + # TE's ScaledMaskedSoftmax expects uint8 mask (cpp_extensions/softmax.py:483). + return mask_out.astype(jnp.uint8) # [B, oH, T_t, T_s] + + +# ----------------------------------------------------------------------------- +# Functional API +# ----------------------------------------------------------------------------- + + +def deep_sparse_attention_core( + query: jax.Array, + key: jax.Array, + value: jax.Array, + indexer_inputs_q: jax.Array, + indexer_inputs_kv: jax.Array, + indexer_W_uq: jax.Array, + indexer_W_dq: jax.Array, + indexer_W_k: jax.Array, + indexer_W_w: jax.Array, + *, + k: int, + attn_mask_type: str = "causal", + scale_factor: Optional[float] = None, + attention_dropout: float = 0.0, + deterministic: bool = True, + backend: Literal["composition", "fused"] = "composition", + dropout_rng_name: str = "dropout", + indexer_backend: str = "hybrid", +) -> jax.Array: + """Functional DSA: indexer-top-k + per-head sparse attention. + + Args: + query, key, value: ``[B, oH, T, head_dim]`` — post-projection per-head + attention tensors. ``oH ≡ num_attention_heads``; each outer-head + slice owns a single attention head of dimension ``head_dim``. + indexer_inputs_q: ``[B, oH, T_t, hidden]`` — per-head hidden states + fed to the indexer's query side. + indexer_inputs_kv: ``[B, oH, T_s, hidden]`` — per-head hidden states + fed to the indexer's key side. + indexer_W_uq: ``[H_idx, d_c, d_i]`` indexer up-projection (shared). + indexer_W_dq: ``[hidden, d_c]`` indexer down-projection (shared). + indexer_W_k: ``[hidden, d_i]`` indexer key projection (shared). + indexer_W_w: ``[hidden, H_idx]`` indexer output-weight projection (shared). + k: number of top key positions to retain per (B, oH, T_t). + attn_mask_type: ``"causal"`` or ``"no_mask"`` (phase 1 only). + scale_factor: passed through to DPA. ``None`` → ``1/sqrt(head_dim)``. + attention_dropout, deterministic, dropout_rng_name: passed through to DPA. + backend: ``"composition"`` (working) or ``"fused"`` (phase-2 scaffold). + indexer_backend: which indexer implementation to use when + ``backend == "composition"``. ``"hybrid"`` (default, fast Triton) or + ``"reference"`` (pure einsum). + + Returns: + Attention output of the same shape as ``query``: ``[B, oH, T_t, head_dim]``. + """ + if backend not in _BACKENDS: + raise ValueError(f"unknown backend {backend!r}; expected one of {_BACKENDS}") + + if attn_mask_type not in ("causal", "no_mask"): + raise NotImplementedError( + f"deep_sparse_attention_core: attn_mask_type={attn_mask_type!r} " + "not supported in phase 1. Supported: 'causal', 'no_mask'. " + "(Padding / segment-id mask types are tracked as a follow-up.)" + ) + + if query.ndim != 4 or key.ndim != 4 or value.ndim != 4: + raise ValueError( + f"DSA expects rank-4 query/key/value [B, oH, T, head_dim]; got " + f"shapes query={query.shape} key={key.shape} value={value.shape}" + ) + if indexer_inputs_q.ndim != 4 or indexer_inputs_kv.ndim != 4: + raise ValueError( + f"DSA expects rank-4 indexer inputs [B, oH, T, hidden]; got " + f"shapes indexer_inputs_q={indexer_inputs_q.shape} " + f"indexer_inputs_kv={indexer_inputs_kv.shape}" + ) + + if backend == "fused": + from transformer_engine.jax.triton_extensions import ( + fused_sparse_attention_triton, + ) + # Project the indexer side so the fused primitive sees Hq/Hk/W_o tensors. + # (Scaffold lowering raises; the projections are computed for shape only.) + Hq, Hk, W_o = _indexer_projections( + indexer_inputs_q, indexer_inputs_kv, + indexer_W_uq, indexer_W_dq, indexer_W_k, indexer_W_w, + ) + return fused_sparse_attention_triton( + query, key, value, Hq, Hk, W_o, k=k, + ) + + # ---- composition backend ---- + B, oH, T_t, head_dim = query.shape + T_s = key.shape[2] + if key.shape != (B, oH, T_s, head_dim) or value.shape != (B, oH, T_s, head_dim): + raise ValueError( + f"DSA shape mismatch: query={query.shape} key={key.shape} value={value.shape}" + ) + + # 1. Indexer produces a per-head score row [B, oH, T_t, T_s]. + scores = _indexer_fn( + indexer_inputs_q, + indexer_inputs_kv, + indexer_W_uq, + indexer_W_dq, + indexer_W_k, + indexer_W_w, + backend=indexer_backend, + out_dtype=jnp.float32, + ) # [B, oH, T_t, T_s] fp32 + + # 2. Causal mask BEFORE top-k so non-causal positions are excluded. + causal = (attn_mask_type == "causal") + if causal: + ckeep = _causal_keep_mask(T_t, T_s)[None, None, :, :] # [1, 1, T_t, T_s] + scores = jnp.where(ckeep, scores, jnp.asarray(-jnp.inf, dtype=scores.dtype)) + + # 3. Per-(B, oH, T_t) top-k. + k_eff = min(k, T_s) + _, topk_idx = jax.lax.top_k(scores, k_eff) # [B, oH, T_t, k_eff] + + # 4. Scatter into [B, oH, T_t, T_s] uint8 DPA mask (1 = mask out). + sparse_mask = _topk_indices_to_attn_mask( + topk_idx, T_s, causal=causal, + ) # [B, oH, T_t, T_s] uint8 + + # 5. Dense attention with the sparse mask. We collapse (B, oH) into the + # batch dim of DPA so each attention head gets its own mask. attn_mask_type + # 'padding' tells DPA to honor the provided mask as-is (causal is baked in). + BH = B * oH + q_r = query.reshape(BH, T_t, 1, head_dim) # [BH, T_t, 1, D] + k_r = key.reshape(BH, T_s, 1, head_dim) + v_r = value.reshape(BH, T_s, 1, head_dim) + mask_r = sparse_mask.reshape(BH, 1, T_t, T_s) # [BH, 1, T_t, T_s] + + dpa = DotProductAttention( + head_dim=head_dim, + num_attention_heads=1, + num_gqa_groups=1, # one head per oH slice; must be int (probe rejects None) + attention_dropout=attention_dropout, + attn_mask_type="padding", + qkv_layout="bshd_bshd_bshd", + scale_factor=scale_factor, + dropout_rng_name=dropout_rng_name, + ) + out = dpa( + q_r, k_r, v_r, + sequence_descriptor=mask_r, + deterministic=deterministic, + ) # [BH, T_t, head_dim] (flattened H=1) + # DPA flattens the H=1 axis on output. Reshape back to [B, oH, T_t, head_dim]. + return out.reshape(B, oH, T_t, head_dim) + + +# ----------------------------------------------------------------------------- +# Flax module +# ----------------------------------------------------------------------------- + + +class DeepSparseAttention(nn.Module): # pylint: disable=too-few-public-methods + """Deep Sparse Attention (DSA) Flax module — rank-4, per-attention-head. + + Composes the lightning indexer with TE's :class:`DotProductAttention`. + Each attention head (``oH``) has its own indexer score row, top-k + pattern, and dense-attention output. Indexer projection weights are + shared across heads. + + Parameters + ---------- + head_dim : int + Per-attention-head dimension. + num_attention_heads : int + Number of attention heads (``oH``). + indexer_num_heads : int + Number of indexer-internal heads (``H`` in the indexer notation). + indexer_d_c : int + Indexer down-projection rank (``d_c``). + indexer_d_i : int + Indexer inner head dimension (``d_i``). + topk : int + Number of top key positions to retain per query. + attn_mask_type : str, default ``"causal"`` + ``"causal"`` or ``"no_mask"`` (phase 1). + attention_dropout : float, default ``0.0`` + scale_factor : Optional[float] + Defaults to ``1/sqrt(head_dim)`` inside DPA. + backend : str, default ``"composition"`` + ``"composition"`` (working) or ``"fused"`` (phase-2 scaffold). + indexer_backend : str, default ``"hybrid"`` + ``"hybrid"`` (fast Triton) or ``"reference"`` (pure einsum). Only used + when ``backend == "composition"``. + dtype : Optional[jnp.dtype] + Parameter dtype. Defaults to the input dtype. + """ + + head_dim: int + num_attention_heads: int + indexer_num_heads: int + indexer_d_c: int + indexer_d_i: int + topk: int + attn_mask_type: str = "causal" + attention_dropout: float = 0.0 + scale_factor: Optional[float] = None + backend: str = "composition" + indexer_backend: str = "hybrid" + dtype: Optional[jnp.dtype] = None + + @nn.compact + def __call__( + self, + inputs_q: jax.Array, + inputs_kv: jax.Array, + *, + deterministic: bool = True, + ) -> jax.Array: + """Run DSA on rank-4 per-head inputs. + + Args: + inputs_q: ``[B, oH, T_t, hidden]`` — per-head query-side hidden state. + inputs_kv: ``[B, oH, T_s, hidden]`` — per-head key-side hidden state. + deterministic: forwarded to DPA. + + Returns: + ``[B, oH, T_t, head_dim]`` — per-head attention output. + """ + if inputs_q.ndim != 4 or inputs_kv.ndim != 4: + raise ValueError( + f"DeepSparseAttention expects rank-4 inputs [B, oH, T, hidden]; " + f"got inputs_q.shape={inputs_q.shape}, inputs_kv.shape={inputs_kv.shape}" + ) + B, oH, T_t, hidden = inputs_q.shape + if oH != self.num_attention_heads: + raise ValueError( + f"DeepSparseAttention: inputs_q.shape[1]={oH} must equal " + f"num_attention_heads={self.num_attention_heads}" + ) + if inputs_kv.shape[0] != B or inputs_kv.shape[1] != oH or inputs_kv.shape[3] != hidden: + raise ValueError( + f"DeepSparseAttention: inputs_kv.shape={inputs_kv.shape} must match " + f"(B={B}, oH={oH}, T_s, hidden={hidden})" + ) + + param_dtype = self.dtype if self.dtype is not None else inputs_q.dtype + + # ---- per-head Q/K/V projections ---- + # DenseGeneral with features=head_dim and axis=-1 maps [..., hidden] → + # [..., head_dim], preserving the (B, oH, T) leading dims. Each attention + # head (oH slice) shares the projection kernel — divergence comes from the + # per-head input slice the caller provides. + query = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="query", + )(inputs_q) # [B, oH, T_t, head_dim] + key = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="key", + )(inputs_kv) # [B, oH, T_s, head_dim] + value = DenseGeneral( + features=self.head_dim, + use_bias=False, + dtype=param_dtype, + name="value", + )(inputs_kv) # [B, oH, T_s, head_dim] + + # ---- indexer projections (shared across oH) ---- + # Shapes mirror transformer_engine.jax.sparse_attention.indexer:31-48. + W_dq = self.param( + "indexer_W_dq", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_d_c), + param_dtype, + ) + W_uq = self.param( + "indexer_W_uq", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (self.indexer_num_heads, self.indexer_d_c, self.indexer_d_i), + param_dtype, + ) + W_k_idx = self.param( + "indexer_W_k", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_d_i), + param_dtype, + ) + W_w = self.param( + "indexer_W_w", + nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"), + (hidden, self.indexer_num_heads), + param_dtype, + ) + + return deep_sparse_attention_core( + query, key, value, + inputs_q, inputs_kv, + W_uq, W_dq, W_k_idx, W_w, + k=self.topk, + attn_mask_type=self.attn_mask_type, + scale_factor=self.scale_factor, + attention_dropout=self.attention_dropout, + deterministic=deterministic, + backend=self.backend, + indexer_backend=self.indexer_backend, + ) # [B, oH, T_t, head_dim] diff --git a/transformer_engine/jax/sparse_attention/indexer.py b/transformer_engine/jax/sparse_attention/indexer.py new file mode 100644 index 000000000..98e93f5fa --- /dev/null +++ b/transformer_engine/jax/sparse_attention/indexer.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Indexer op (forward only), bf16 inputs. + +Two canonical backends: + * ``"reference"`` — pure ``jnp.einsum``. Materializes the + (B, oH, T, H, S) pre-relu score tensor in HBM via hipBLASLt. + * ``"hybrid"`` — same einsum projections (C_q, H_q, H_k, W_o) followed + by a fused Triton kernel that does score+relu+H-reduction in + registers. Avoids the score-tensor HBM round-trip that dominates the + reference path. + +Functional entry point: ``indexer(Q, K, W_uq, W_dq, W_k, W_w, *, backend=...)``. +User-facing Flax module: :class:`LightningIndexer`, which owns the projection +weights and delegates to ``indexer`` / ``indexer_topk``. + +Math (low-rank form: Q is hidden state; query heads are produced by a +down-projection (d -> d_c) followed by an up-projection (d_c -> H * d_i); +output weights are produced from Q via a learnable d -> H projection): + + C_q = Q @ W_dq # (..., T, d_c) + H_q = einsum("...tc,hci->...thi", C_q, W_uq) # (..., T, H, d_i) + H_k = K @ W_k # (..., S, d_i) + W_o = Q @ W_w # (..., T, H) + H = relu(einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + O = einsum("...ths,...th->...ts", H, W_o) # (..., T, S) +""" + +import functools +from typing import Optional + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +def _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w): + """Low-rank indexer projections shared by every backend. + + Returns (H_q, H_k, W_o) with shapes + (..., T, H, d_i), (..., S, d_i), (..., T, H). + """ + C_q = jnp.einsum("...td,dc->...tc", Q, W_dq) + H_q = jnp.einsum("...tc,hci->...thi", C_q, W_uq) + H_k = jnp.einsum("...sd,di->...si", K, W_k) + W_o = jnp.einsum("...td,dh->...th", Q, W_w) + return H_q, H_k, W_o + + +def _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): + """ + Q [..., T, d] + K [..., S, d] + W_dq [d, d_c] + W_uq [H, d_c, d_i] + W_k [d, d_i] + W_w [..., d, H] # leading dims must match Q's + """ + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w) + H = jax.nn.relu(jnp.einsum("...thi,...si->...ths", H_q, H_k)) # (..., T, H, S) + O = jnp.einsum("...ths,...th->...ts", H, W_o) # (..., T, S) + if out_dtype is not None: + O = O.astype(out_dtype) + return O + + +def _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, W_w, out_dtype=None): + """Einsum projections + Triton score-relu-reduce. + + Mirrors ``_indexer_impl_reference`` for the four projections (which + lower to hipBLASLt bf16 GEMMs), then hands Hq / Hk / W_o to a fused + Triton kernel that does score+relu+H-reduction in registers — + eliminating the (B, oH, T, H, S) pre-relu-score HBM round-trip the + pure-einsum path pays. + """ + from transformer_engine.jax.triton_extensions.indexer import score_reduce_triton + + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, W_w) + return score_reduce_triton(H_q, H_k, W_o, + out_dtype=out_dtype if out_dtype else Q.dtype) + + +@functools.partial(jax.jit, static_argnames=("k",)) +def indexer_topk(Q, K, W_uq, W_dq, W_k, weights, *, k): + """Lightning-indexer + top-k (fused). + + Same projections as ``indexer()`` (reference math), then a single Triton + kernel that computes the score row, ReLU, weighted H-reduction, and + streaming top-k all in one pass — the (B, oH, T_t, T_s) score matrix is + never materialized. + + Args: + Q, K, W_uq, W_dq, W_k, weights: same as ``indexer()``. + k: number of top scores to return per (B, oH, T_t) row. + Must be a power of 2 and <= S. + + Returns: + Topk_idx: (..., T_t, k) int32 — top-k indices into the S axis, + in descending score order. + """ + from transformer_engine.jax.triton_extensions.indexer import score_topk_triton + H_q, H_k, W_o = _indexer_projections(Q, K, W_uq, W_dq, W_k, weights) + return score_topk_triton(H_q, H_k, W_o, k=k) + + +@functools.partial(jax.jit, static_argnames=("backend", "out_dtype")) +def indexer(Q, K, W_uq, W_dq, W_k, weights, *, out_dtype=None, backend="reference"): + """Low-rank lightning-indexer (bf16). + + Args: + Q: (..., T, d) hidden state (per token) + K: (..., S, d) key hidden state + W_uq: (H, d_c, d_i) up-projection: d_c -> d_i (per head) + W_dq: (d, d_c) down-projection: d -> d_c + W_k: (d, d_i) key projection + weights: (d, H) learnable output-weight projection + (W_o = Q @ weights inside the impl) + out_dtype: output dtype override (defaults to Q.dtype). + backend: "reference" (pure einsum) or "hybrid" (einsum projections + + Triton score-relu-reduce kernel). + + Returns: + O of shape (..., T, S). + """ + if backend == "reference": + return _indexer_impl_reference(Q, K, W_uq, W_dq, W_k, weights, out_dtype=out_dtype) + if backend == "hybrid": + return _indexer_impl_hybrid(Q, K, W_uq, W_dq, W_k, weights, out_dtype=out_dtype) + raise ValueError( + f"unknown backend {backend!r}; expected 'reference' or 'hybrid'" + ) + + +class LightningIndexer(nn.Module): # pylint: disable=too-few-public-methods + """Lightning-indexer Flax module — the user-facing indexer API. + + Owns the low-rank indexer projection weights (``W_dq``, ``W_uq``, ``W_k``, + ``W_w``) and delegates to the functional :func:`indexer` / :func:`indexer_topk` + ops. Weight shapes mirror :func:`indexer`'s ``Args`` and are inferred from the + trailing hidden dimension ``d`` of ``Q`` at call time. + + Parameters + ---------- + num_heads : int + Number of indexer-internal heads (``H``). + d_c : int + Down-projection rank (``d -> d_c``). + d_i : int + Inner head dimension (``d_i``). + topk : Optional[int], default ``None`` + If set, :meth:`__call__` returns the fused top-``k`` indices + (``(..., T, k)`` int32) via :func:`indexer_topk`, and ``backend`` / + ``out_dtype`` are ignored (top-k always uses the fused Triton kernel). + If ``None``, :meth:`__call__` returns the full score tensor + ``(..., T, S)``. + backend : str, default ``"reference"`` + ``"reference"`` (pure einsum) or ``"hybrid"`` (Triton score-relu-reduce). + Only used when ``topk is None``. + out_dtype : Optional[jnp.dtype] + Output dtype override; defaults to ``Q.dtype``. Unused when ``topk`` is set. + dtype : Optional[jnp.dtype] + Parameter dtype. Defaults to the input dtype. + """ + + num_heads: int + d_c: int + d_i: int + topk: Optional[int] = None + backend: str = "reference" + out_dtype: Optional[jnp.dtype] = None + dtype: Optional[jnp.dtype] = None + + @nn.compact + def __call__(self, Q: jax.Array, K: jax.Array) -> jax.Array: + """Run the indexer on ``Q`` / ``K``. + + Args: + Q: ``(..., T, d)`` query-side hidden state. + K: ``(..., S, d)`` key-side hidden state. + + Returns: + ``(..., T, S)`` scores if ``topk is None``, else ``(..., T, k)`` + int32 top-k indices (in descending score order). + """ + d = Q.shape[-1] + param_dtype = self.dtype if self.dtype is not None else Q.dtype + init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + + W_dq = self.param("W_dq", init, (d, self.d_c), param_dtype) + W_uq = self.param("W_uq", init, (self.num_heads, self.d_c, self.d_i), param_dtype) + W_k = self.param("W_k", init, (d, self.d_i), param_dtype) + W_w = self.param("W_w", init, (d, self.num_heads), param_dtype) + + if self.topk is not None: + return indexer_topk(Q, K, W_uq, W_dq, W_k, W_w, k=self.topk) + return indexer( + Q, K, W_uq, W_dq, W_k, W_w, + out_dtype=self.out_dtype, backend=self.backend, + ) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py index d9708fde9..79ccc0f73 100644 --- a/transformer_engine/jax/triton_extensions/__init__.py +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -58,3 +58,5 @@ def lowering(ctx, x, **kwargs): from .utils import * from .permutation import * +from .indexer import score_reduce_triton, score_topk_triton +from .sparse_attention import fused_sparse_attention_triton diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py new file mode 100644 index 000000000..e6477d6ae --- /dev/null +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -0,0 +1,828 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Triton score-relu-reduce kernel for the lightning-indexer hybrid backend. + +The hybrid backend computes the four projections (C_q, H_q, H_k, W_o) via +``jnp.einsum`` (which lowers to hipBLASLt bf16 GEMMs) and then hands the +results to this kernel for the score matmul + ReLU + per-(t, h) weighted +H-reduction: + + scores = relu(einsum("...thi,...si->...ths", H_q, H_k)) # never written + O = einsum("...ths,...th->...ts", scores, W_o) + +The kernel keeps each per-head score tile in registers, avoiding the +(B, oH, T, H, S) HBM round-trip that an einsum-only implementation pays +on the pre-relu score tensor. +""" + +import functools + +import jax +import jax.numpy as jnp +import triton +import triton.language as tl + +from jax import core +from jax.extend import core as extend_core +from jax.interpreters import mlir, xla + +from .utils import triton_call_lowering + + +def _score_reduce_autotune_configs(): + # The kernel is dominated by Hq reads (one (BLOCK_T, d_i) load per H + # iteration). Bigger BLOCK_T ⇒ fewer T tiles ⇒ less total Hq traffic. + # Bigger BLOCK_S ⇒ more Hk reuse but bigger per-CTA footprint. + # + # BLOCK_T=512 was tried and consistently failed to launch on MI355X + # (resource exhaustion — VGPR/LDS budget for 64-iter H-loop with that + # large an accumulator). Capped at 256. + cfgs = [] + for bt in (64, 128, 256): + for bs in (32, 64, 128): + for num_warps in (4, 8): + for num_stages in (1, 2): + cfgs.append(triton.Config( + {"BLOCK_T": bt, "BLOCK_S": bs}, + num_warps=num_warps, num_stages=num_stages, + )) + # A few skinny / fat shapes the regular grid above won't hit. + cfgs += [ + triton.Config({"BLOCK_T": 32, "BLOCK_S": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 32, "BLOCK_S": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 256, "BLOCK_S": 32}, num_warps=8, num_stages=2), + ] + return cfgs + + +@triton.autotune(configs=_score_reduce_autotune_configs(), key=["H", "d_i"]) +@triton.jit +def _score_reduce_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) — produced by einsum("...tc,hci->...thi") + Hk_ptr, # (B, oH, T_s, d_i) + W_o_ptr, # (B, oH, T_t, H) + O_ptr, # (B, oH, T_t, T_s) + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """Compute one (BLOCK_T, BLOCK_S) tile of O for one (b, h_outer) slice. + + Grid order: (cdiv(T_s, BLOCK_S), cdiv(T_t, BLOCK_T), B * oH). + + S is the fastest-dispatching axis so consecutive CTAs share (B*oH, T) + and vary only in S — they all read the same per-head Hq slab, hitting + L2 instead of HBM. Hq layout is the natural einsum output + (..., T, H, d_i); per-head loads are strided in T (stride H*d_i). + """ + pid_s = tl.program_id(0) + pid_t = tl.program_id(1) + pid_bh = tl.program_id(2) + + # int64 indexing — Hq alone has B*oH*T*H*d_i = 4.3 B elements at T=S=4096, + # exceeds int32 range. + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) + rdi = tl.arange(0, d_i) + + rt_mask = rt < T_t + rs_mask = rs < T_s + + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + o_base = b * (oH * T_t * T_s) + h_outer * (T_t * T_s) + + # Load the (BLOCK_S, d_i) Hk slab once — it is loop-invariant over H. + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_tile = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + Hk_T = tl.trans(Hk_tile) # (d_i, BLOCK_S) + + acc = tl.zeros((BLOCK_T, BLOCK_S), dtype=tl.float32) + + for h in range(H): + hq_ptrs = (Hq_ptr + hq_base + + rt[:, None] * (H * d_i) + h * d_i + rdi[None, :]) + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + wo_ptrs = W_o_ptr + wo_base + rt * H + h + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0) + + score = tl.dot(Hq_h, Hk_T) + score = tl.maximum(score, 0.0) + acc += score * w_h[:, None].to(tl.float32) + + o_ptrs = O_ptr + o_base + rt[:, None] * T_s + rs[None, :] + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :]) + + +_score_reduce_p = extend_core.Primitive("te_indexer_score_reduce_triton") +_score_reduce_p.multiple_results = True + + +@_score_reduce_p.def_abstract_eval +def _score_reduce_abstract(Hq, Hk, W_o, *, out_dtype): + del W_o + # Hq layout: (B, oH, T_t, H, d_i) + B, oH, T_t, _H, _d_i = Hq.shape + T_s = Hk.shape[2] + return [core.ShapedArray((B, oH, T_t, T_s), out_dtype)] + + +_score_reduce_p.def_impl(functools.partial(xla.apply_primitive, _score_reduce_p)) + + +def _score_reduce_lowering(ctx, Hq, Hk, W_o, *, out_dtype): + del out_dtype + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + + def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", 64) + bs = merged_kwargs.get("BLOCK_S", 64) + # S as grid_x (fastest-dispatching) so per-(B*oH, T-tile) S workgroups + # cluster in time and hit L2 on the shared Hq slab. + return (triton.cdiv(T_s, bs), triton.cdiv(T_t, bt), B * oH) + + return triton_call_lowering( + ctx, + _score_reduce_kernel, + Hq, Hk, W_o, + grid=grid_fn, + num_warps=4, + num_stages=2, + constexprs={ + "B": B, + "oH": oH, + "T_t": T_t, + "T_s": T_s, + "H": H, + "d_i": d_i, + }, + ) + + +mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="rocm") +mlir.register_lowering(_score_reduce_p, _score_reduce_lowering, platform="cuda") + + +# --- Chunked score-tile kernel for hybrid bwd -------------------------------- +# +# Produces dscores_chunk[B, oH, T, H_CHUNK, T_s] and dW_o_chunk[B, oH, T, H_CHUNK] +# for ONE h-chunk. Caller loops over H/H_CHUNK chunks and feeds dscores_chunk +# to hipBLASLt einsums for dHq/dHk reductions. Bounds peak materialization to +# H/H_CHUNK fraction of the full (B, oH, T, H, T_s) score tensor. +# +# Fuses score recompute + relu + mask + dO*W_o broadcast in registers -- +# nothing of size (B, oH, T, H, T_s) ever lands in HBM at full size. dW_o is +# reduced inline (sum_s of h_relu * dO) so h_relu also never materializes. + + +_HBWD_BLOCK_T = 64 +_HBWD_BLOCK_S = 64 + + +@triton.jit +def _score_dscores_chunk_kernel( + Hq_chunk_ptr, # input (B, oH, T, H_CHUNK, d_i) bf16 + Hk_ptr, # input (B, oH, T_s, d_i) bf16 + W_o_chunk_ptr, # input (B, oH, T, H_CHUNK) bf16 + dO_ptr, # input (B, oH, T, T_s) fp32 + dscores_chunk_ptr, # output (B, oH, T, H_CHUNK, T_s) bf16 + dWo_chunk_ptr, # output (B, oH, T, H_CHUNK) bf16 + B: tl.constexpr, + oH: tl.constexpr, + T: tl.constexpr, + T_s: tl.constexpr, + H_CHUNK: tl.constexpr, + d_i: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_S: tl.constexpr, +): + """One CTA handles (T_tile, h_in) for one (b, h_outer). Loops over s_chunks. + + Each CTA writes its T_tile rows of (dscores_chunk[..., h_in, :], + dW_o_chunk[..., h_in]). dW_o is reduced in registers (sum over s) so + h_relu never lands in HBM -- we compute it on-the-fly and consume it. + """ + pid_t = tl.program_id(0) + pid_h_bh = tl.program_id(1) + h_in = pid_h_bh % H_CHUNK + pid_bh = pid_h_bh // H_CHUNK + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + h_in_64 = h_in.to(tl.int64) + + rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + rdi = tl.arange(0, d_i) + rt_mask = rt < T + + hq_base = b * (oH * T * H_CHUNK * d_i) + h_outer * (T * H_CHUNK * d_i) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + wo_base = b * (oH * T * H_CHUNK) + h_outer * (T * H_CHUNK) + do_base = b * (oH * T * T_s) + h_outer * (T * T_s) + ds_base = b * (oH * T * H_CHUNK * T_s) + h_outer * (T * H_CHUNK * T_s) + + # Load Hq[..., t_tile, h_in, :] -> [BLOCK_T, d_i] once per CTA + hq_ptrs = (Hq_chunk_ptr + hq_base + + rt[:, None] * (H_CHUNK * d_i) + + h_in_64 * d_i + + rdi[None, :]) + Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) + + # Load W_o[..., t_tile, h_in] -> [BLOCK_T] once per CTA + wo_ptrs = W_o_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 + w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) + + # dW_o accumulator: sum_s (h_relu * dO) -- reduced in regs + dWo_acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + + for s_start in range(0, T_s, BLOCK_S): + rs = s_start + tl.arange(0, BLOCK_S) + rs_mask = rs < T_s + + # Load Hk[..., s_chunk, :] and dO[..., t_tile, s_chunk] + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + + do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] + dO_chunk = tl.load( + do_ptrs, + mask=rt_mask[:, None] & rs_mask[None, :], + other=0.0, + ) + + # scores tile in registers (never lands in HBM at full size) + scores = tl.dot(Hq_h, tl.trans(Hk_chunk)) + relu_mask = scores > 0 + h_relu = tl.where(relu_mask, scores, 0.0) + + # dW_o contribution: sum_s (h_relu * dO) + dWo_acc += tl.sum(h_relu * dO_chunk, axis=1) + + # dscores tile = relu_mask * (dO * W_o) + dscores = tl.where(relu_mask, dO_chunk * w_h[:, None], 0.0) + + # Store dscores tile to HBM (bf16). Total dscores_chunk size is + # H_CHUNK x smaller than the full (B,oH,T,H,T_s) tensor. + ds_ptrs = (dscores_chunk_ptr + ds_base + + rt[:, None] * (H_CHUNK * T_s) + + h_in_64 * T_s + + rs[None, :]) + tl.store( + ds_ptrs, + dscores.to(dscores_chunk_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :], + ) + + # Store dW_o[..., t_tile, h_in] + dwo_out_ptrs = dWo_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 + tl.store( + dwo_out_ptrs, + dWo_acc.to(dWo_chunk_ptr.dtype.element_ty), + mask=rt_mask, + ) + + +_score_dscores_chunk_p = extend_core.Primitive("te_indexer_score_dscores_chunk") +_score_dscores_chunk_p.multiple_results = True + + +@_score_dscores_chunk_p.def_abstract_eval +def _score_dscores_chunk_abstract(Hq_chunk, Hk, W_o_chunk, dO): + del Hk, W_o_chunk + B, oH, T, H_CHUNK, _ = Hq_chunk.shape + T_s = dO.shape[-1] + return [ + core.ShapedArray((B, oH, T, H_CHUNK, T_s), Hq_chunk.dtype), # dscores + core.ShapedArray((B, oH, T, H_CHUNK), Hq_chunk.dtype), # dW_o + ] + + +_score_dscores_chunk_p.def_impl( + functools.partial(xla.apply_primitive, _score_dscores_chunk_p) +) + + +def _score_dscores_chunk_lowering(ctx, Hq_chunk, Hk, W_o_chunk, dO): + Hq_aval = ctx.avals_in[0] + dO_aval = ctx.avals_in[3] + B, oH, T, H_CHUNK, d_i = Hq_aval.shape + T_s = dO_aval.shape[-1] + BLOCK_T = _HBWD_BLOCK_T if T >= _HBWD_BLOCK_T else T + BLOCK_S = _HBWD_BLOCK_S if T_s >= _HBWD_BLOCK_S else T_s + n_t_tiles = (T + BLOCK_T - 1) // BLOCK_T + + return triton_call_lowering( + ctx, + _score_dscores_chunk_kernel, + Hq_chunk, Hk, W_o_chunk, dO, + grid=(n_t_tiles, B * oH * H_CHUNK), + num_warps=4, + num_stages=2, + constexprs={ + "B": B, "oH": oH, "T": T, "T_s": T_s, + "H_CHUNK": H_CHUNK, "d_i": d_i, + "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, + }, + ) + + +mlir.register_lowering(_score_dscores_chunk_p, _score_dscores_chunk_lowering, platform="rocm") +mlir.register_lowering(_score_dscores_chunk_p, _score_dscores_chunk_lowering, platform="cuda") + + +# --- Public score_reduce_triton with custom_vjp ------------------------------ + + +@functools.partial(jax.custom_vjp, nondiff_argnums=(3,)) +def _score_reduce_with_vjp(Hq, Hk, W_o, out_dtype): + return _score_reduce_p.bind(Hq, Hk, W_o, out_dtype=out_dtype)[0] + + +def _score_reduce_fwd(Hq, Hk, W_o, out_dtype): + out = _score_reduce_p.bind(Hq, Hk, W_o, out_dtype=out_dtype)[0] + return out, (Hq, Hk, W_o) + + +_BWD_H_CHUNK = 8 # peak (B, oH, T, H_CHUNK, T_s) tile -- bounds materialization + + +def _score_reduce_bwd(out_dtype, residuals, dO): + del out_dtype + Hq, Hk, W_o = residuals + B, oH, T, H, d_i = Hq.shape + + # Hybrid scheme with bounded materialization: + # For each h-chunk of size H_CHUNK (driven by lax.scan, NOT Python + # unroll, so intermediates are freed between iterations): + # 1. Triton kernel fuses (score recompute + relu + mask + dO*W_o + # broadcast) and writes dscores_chunk[B,oH,T,H_CHUNK,T_s] to HBM. + # h_relu is consumed in-register to also produce dWo_chunk + # without ever materializing the (B,oH,T,H,T_s) h_relu tensor. + # 2. hipBLASLt einsums on dscores_chunk give dHq_chunk and a partial + # dHk contribution. + # Peak HBM intermediate stays at H_CHUNK/H fraction of the full score. + # + # The fully-fused Triton bwd variants (v2/v3/v4) remain in this file for + # reference -- they don't materialize the score tensor either but are + # slower than the hipBLASLt-based reductions used here (~2x at 4096^2). + if H % _BWD_H_CHUNK == 0: + H_CHUNK = _BWD_H_CHUNK + else: + H_CHUNK = 1 + for c in (4, 2): + if H % c == 0: + H_CHUNK = c + break + n_chunks = H // H_CHUNK + + Hq_r = Hq.reshape(B, oH, T, n_chunks, H_CHUNK, d_i) + Wo_r = W_o.reshape(B, oH, T, n_chunks, H_CHUNK) + # Move chunk axis to leading for scan over axis 0. + Hq_s = jnp.moveaxis(Hq_r, -3, 0) # (n_chunks, B, oH, T, H_CHUNK, d_i) + Wo_s = jnp.moveaxis(Wo_r, -2, 0) # (n_chunks, B, oH, T, H_CHUNK) + + def step(dHk_acc, chunk): + Hq_c, Wo_c = chunk + # Triton: dscores_chunk + dWo_chunk; no full (B,oH,T,H,T_s) tensor + # ever exists in HBM. + dscores_c, dWo_c = _score_dscores_chunk_p.bind(Hq_c, Hk, Wo_c, dO) + dHq_c = jnp.einsum("...ths,...si->...thi", dscores_c, Hk) + dHk_c = jnp.einsum("...ths,...thi->...si", dscores_c, Hq_c) + new_dHk_acc = dHk_acc + dHk_c.astype(jnp.float32) + return new_dHk_acc, (dHq_c, dWo_c) + + init = jnp.zeros(Hk.shape, dtype=jnp.float32) + dHk_acc, (dHq_chunks, dWo_chunks) = jax.lax.scan( + step, init, (Hq_s, Wo_s), + ) + # dHq_chunks: (n_chunks, B, oH, T, H_CHUNK, d_i) + # dWo_chunks: (n_chunks, B, oH, T, H_CHUNK) + dHq = jnp.moveaxis(dHq_chunks, 0, -3).reshape(B, oH, T, H, d_i) + dWo = jnp.moveaxis(dWo_chunks, 0, -2).reshape(B, oH, T, H) + dHk = dHk_acc.astype(Hk.dtype) + + return dHq.astype(Hq.dtype), dHk, dWo.astype(W_o.dtype) + + +_score_reduce_with_vjp.defvjp(_score_reduce_fwd, _score_reduce_bwd) + + +def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): + """Triton fused score-matmul + relu + per-(t, h) weighted H-reduction. + + Replaces the pattern: + + scores = relu(jnp.einsum("...thi,...si->...ths", Hq, Hk)) # never write + O = jnp.einsum("...ths,...th->...ts", scores, W_o) + + with a single kernel that holds the per-head score tile in registers, + avoiding the (B, oH, T, H, S) HBM round-trip an einsum+XLA chain pays. + + Differentiable via two backward kernels (FlashAttention-style: residuals + are just (Hq, Hk, W_o); the (T, H, S) score tensor is recomputed inside + backward, never materialized). + + Args: + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + out_dtype: defaults to Hq.dtype. + + Returns: + O: (B, oH, T_t, T_s) + """ + if Hq.ndim != 5: + raise ValueError( + f"Hq must be rank-5 (B, oH, T_t, H, d_i); got shape {Hq.shape}" + ) + if Hk.ndim != 4: + raise ValueError( + f"Hk must be rank-4 (B, oH, T_s, d_i); got shape {Hk.shape}" + ) + if W_o.ndim != 4: + raise ValueError( + f"W_o must be rank-4 (B, oH, T_t, H); got shape {W_o.shape}" + ) + + B, oH, T_t, H, d_i = Hq.shape + Bk, oHk, T_s, d_i_k = Hk.shape + Bw, oHw, T_t_w, H_w = W_o.shape + if (Bk, oHk) != (B, oH): + raise ValueError( + f"(B, oH) mismatch: Hq has {(B, oH)}, Hk has {(Bk, oHk)}" + ) + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: Hq has {d_i}, Hk has {d_i_k}") + if (Bw, oHw, T_t_w, H_w) != (B, oH, T_t, H): + raise ValueError( + f"W_o shape {W_o.shape} does not match expected " + f"(B={B}, oH={oH}, T_t={T_t}, H={H})" + ) + + if out_dtype is None: + out_dtype = Hq.dtype + + return _score_reduce_with_vjp(Hq, Hk, W_o, jnp.dtype(out_dtype)) + + +# --- Streaming top-k variant ---------------------------------------------------- +# +# Same einsum-projected (Hq, Hk, W_o) inputs, but fuses top-k indices into the +# kernel: one CTA per (B, oH, T_t) query token, score row never materialized. +# +# Algorithm (mirrors TileLang dsa_sparse_finetune/indexer_topk_reducesum): +# - Maintain a 2K-sized buffer of (score_bits, index) packed uint64 +# - Stream over T_s in BLOCK_S chunks; each chunk computes BLOCK_S new scores +# - Place chunk into buffer[K:K+BLOCK_S], zero buffer[K+BLOCK_S:2K] +# - tl.sort descending; top half is the running top-K +# - After all chunks: buffer[:K] is the answer +# +# tl.sort returns values only, so we pack (score_bits << 32) | index into uint64. +# Post-ReLU scores are >= 0, so fp32 bit pattern is monotone in value. + + +# Autotune sweep for _score_topk_kernel. +# +# BLOCK_T: number of query tokens per CTA. BLOCK_T>1 amortizes the Hk_chunk +# load across BLOCK_T queries — the single biggest lever at large T_s. At +# BLOCK_T=1 (original), each CTA reloads all of Hk for its (b, oH) slab, +# causing L2 thrash. BLOCK_T=2 halves Hk HBM traffic; BLOCK_T=4 quarters it, +# but grows per-CTA register pressure (Hq_token, top_packed, logits all +# scale with BLOCK_T). +# +# BLOCK_S knobs the inner-chunk size; bigger BLOCK_S = better matmul +# arithmetic intensity, but bigger per-CTA transient footprint +# (logits[BLOCK_S, BLOCK_T*H] fp32 + Hk_chunk[BLOCK_S, d_i] bf16). +# +# Constraint: BLOCK_S must divide K (so INNER = K // BLOCK_S is an integer +# >= 1). Configs whose BLOCK_S exceeds K or doesn't divide K are filtered +# out at lowering time — otherwise jaxlib's autotuner would time them as +# zero-work (fast) and pick a bogus winner that returns all-zero indices. +_SCORE_TOPK_CONFIGS = [ + triton.Config({"BLOCK_S": bs, "BLOCK_T": bt}, num_warps=nw, num_stages=ns) + for bt in (1, 2) + for bs in (32, 64, 128, 256) + for nw in (4, 8) + for ns in (1, 2) +] + [ + # BLOCK_T=4 only at smaller BLOCK_S — at BLOCK_S=256 the logits + # intermediate [256, 4*H=256] fp32 = 256 KB overflows reliably. + triton.Config({"BLOCK_S": bs, "BLOCK_T": 4}, num_warps=nw, num_stages=ns) + for bs in (32, 64, 128) + for nw in (4, 8) + for ns in (1, 2) +] + + +@triton.jit +def _score_topk_kernel( + Hq_ptr, # (B, oH, T_t, H, d_i) bf16 + Hk_ptr, # (B, oH, T_s, d_i) bf16 + W_o_ptr, # (B, oH, T_t, H) bf16 + Topk_idx_ptr, # (B, oH, T_t, K) int32 OUTPUT + B: tl.constexpr, + oH: tl.constexpr, + T_t: tl.constexpr, + T_s: tl.constexpr, + H: tl.constexpr, + d_i: tl.constexpr, + K: tl.constexpr, + S_PAD: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_T: tl.constexpr, +): + """Per-CTA: BLOCK_T consecutive query tokens, all sharing Hk loads. + + Grid: (cdiv(T_t, BLOCK_T), B * oH). Each CTA does: + - Pre-load Hq[..., rt, :, :] for BLOCK_T contiguous query tokens + - For each S chunk: load Hk_chunk ONCE, do one [BLOCK_S, d_i] @ + [d_i, BLOCK_T*H] matmul, weighted-H-reduce per T + - Maintain a single 1D top buffer of size BLOCK_T*2K, with T encoded + in the top 8 bits of each packed entry. After global sort desc, + per-T entries stay grouped together so per-T top-K can be sliced + from fixed offsets. + + Note on layout (1D vs 2D top buffer): + A 2D [BLOCK_T, 2K] top buffer with per-row sort is the natural + design, but `tl.gather + tl.sort(dim=1)` on uint64 2D tensors trips + `TritonGPUOptimizeThreadLocality` on the AMD backend (gfx950, Triton + 3.4.0). The 1D-with-encoded-T workaround sidesteps this — it pays a + ~1.5x sort-cost penalty (one sort of BLOCK_T*2K vs BLOCK_T sorts of + 2K) for BLOCK_T=2, but unblocks Hk-load amortization across queries. + """ + pid_t = tl.program_id(0) + pid_bh = tl.program_id(1) + # int64 indexing — Hq has B*oH*T*H*d_i = 4.3 B elements at T=S=4096. + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + + rh = tl.arange(0, H) + rdi = tl.arange(0, d_i) + rs_chunk = tl.arange(0, BLOCK_S) + rk = tl.arange(0, K) + rt_local = tl.arange(0, BLOCK_T) + + rt = pid_t * BLOCK_T + rt_local + rt_64 = rt.to(tl.int64) + rt_mask = rt < T_t + + # Load Hq[b, h_outer, rt, :, :] -> [BLOCK_T, H, d_i]. + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + Hq_token = tl.load( + Hq_ptr + hq_base + + rt_64[:, None, None] * (H * d_i) + + rh[None, :, None] * d_i + + rdi[None, None, :], + mask=rt_mask[:, None, None], + other=0.0, + ) + + # Load w_o[b, h_outer, rt, :] -> [BLOCK_T, H] + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + w_o = tl.load( + W_o_ptr + wo_base + rt_64[:, None] * H + rh[None, :], + mask=rt_mask[:, None], + other=0.0, + ).to(tl.float32) + + # Flatten Hq for one big matmul per Hk_chunk: [BLOCK_T * H, d_i] -> trans + Hq_flat = tl.reshape(Hq_token, (BLOCK_T * H, d_i)) + Hq_T = tl.trans(Hq_flat) # [d_i, BLOCK_T * H] + w_o_flat = tl.reshape(w_o, (BLOCK_T * H,)) + + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + + TOP_BUF: tl.constexpr = 2 * K + INNER: tl.constexpr = K // BLOCK_S # chunks per sort + N_OUTER: tl.constexpr = S_PAD // K # number of sorts per CTA + BIG_BUF: tl.constexpr = BLOCK_T * TOP_BUF + + # Initialize 1D top buffer with t-encoding pre-applied so per-T regions + # stay grouped after global sort. Each slot at position rb gets: + # t_pos = rb // TOP_BUF -> which T this slot belongs to + # t_enc = BLOCK_T - t_pos -> 1..BLOCK_T (never 0 → never collides with + # reserved init pattern) + # packed = (t_enc << 56) | 0 -> score=0 (sortable=0), index=0 + # Real candidates also get tagged with their t_enc; after global sort + # desc, all entries with t_enc=BLOCK_T (i.e. t=0) come first, then + # t_enc=BLOCK_T-1, etc. Within each t group, ordered by score then index. + rb = tl.arange(0, BIG_BUF) + rb_t = rb // TOP_BUF # [BIG_BUF] in [0, BLOCK_T) + rb_pos = rb % TOP_BUF # [BIG_BUF] in [0, TOP_BUF) + t_enc_per_slot = (BLOCK_T - rb_t).to(tl.uint64) + top_packed = t_enc_per_slot << 56 + + # Pre-compute the per-slot (t, pos)-to-flat-chunk-index map used in + # scatter: for each rb, identify the (t, j) in chunk_packed_flat to pull + # from. j depends on `chunk_offset` (varies per inner iter), so the + # gather index is recomputed each iter. + + for o in tl.static_range(N_OUTER): + for i in tl.static_range(INNER): + c = o * INNER + i + s_start = c * BLOCK_S + rs = s_start + rs_chunk # [BLOCK_S] + rs_mask = rs < T_s + + # Load Hk_chunk[BLOCK_S, d_i] ONCE — shared across BLOCK_T queries. + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + + # One big matmul: [BLOCK_S, d_i] @ [d_i, BLOCK_T*H] -> [BLOCK_S, BLOCK_T*H] + logits = tl.dot(Hk_chunk, Hq_T) + logits = tl.maximum(logits, 0.0) + + # Weighted reduce over H per (s, t): + # chunk_scores[s, t] = sum_h logits[s, t*H + h] * w_o[t, h] + weighted = logits * w_o_flat[None, :] + weighted_3d = tl.reshape(weighted, (BLOCK_S, BLOCK_T, H)) + chunk_scores = tl.sum(weighted_3d, axis=2) # [BLOCK_S, BLOCK_T] + chunk_scores_T = tl.trans(chunk_scores) # [BLOCK_T, BLOCK_S] + + # Radix-flip: fp32 bit pattern -> sortable uint32 across full sign + # range (positives: flip sign bit; negatives: flip all bits). + # See https://stereopsis.com/radix.html + bits = chunk_scores_T.to(tl.uint32, bitcast=True) + sign = bits >> 31 + flip_mask = (0 - sign.to(tl.int32)).to(tl.uint32) | 0x80000000 + sortable = bits ^ flip_mask + sortable = tl.where(rs_mask[None, :], sortable, 0) + + # Pack: (t_enc<<56) | (sortable<<24) | (index in low 24 bits). + # 24-bit index supports T_s up to 16M, far above our regime. + t_enc_chunk = (BLOCK_T - rt_local).to(tl.uint64) # [BLOCK_T] + rs_2d = tl.broadcast_to(rs[None, :], (BLOCK_T, BLOCK_S)) + chunk_packed_2d = ( + (t_enc_chunk[:, None] << 56) + | (sortable.to(tl.uint64) << 24) + | rs_2d.to(tl.uint64) + ) # [BLOCK_T, BLOCK_S] + # Flatten to 1D for the scatter (1D gather + 1D sort sidesteps + # the AMD-backend bug with 2D gather+sort combos). + chunk_packed_flat = tl.reshape(chunk_packed_2d, (BLOCK_T * BLOCK_S,)) + + # Scatter into top_packed[t*TOP_BUF + K+i*BLOCK_S : ...] for each t. + # For each rb in [0, BIG_BUF): + # t = rb // TOP_BUF + # pos = rb % TOP_BUF + # in_slot = (pos >= K + i*BLOCK_S) & (pos < K + (i+1)*BLOCK_S) + # flat_idx = t * BLOCK_S + (pos - (K + i*BLOCK_S)) + chunk_offset = K + i * BLOCK_S + in_slot = (rb_pos >= chunk_offset) & (rb_pos < chunk_offset + BLOCK_S) + j = rb_pos - chunk_offset + flat_idx = tl.where(in_slot, rb_t * BLOCK_S + j, 0).to(tl.int32) + gathered = tl.gather(chunk_packed_flat, flat_idx, axis=0) + top_packed = tl.where(in_slot, gathered, top_packed) + + # 1D sort of the entire buffer. Per-T regions stay grouped via t_enc. + top_packed = tl.sort(top_packed, descending=True) + + # Extract per-T top K. After sort desc, t=0's top K is at positions + # [0, K), t=1's at [TOP_BUF, TOP_BUF+K), etc. — i.e. base = t*TOP_BUF. + out_idx = rt_local[:, None] * TOP_BUF + rk[None, :] # [BLOCK_T, K] + out_idx_flat = tl.reshape(out_idx, (BLOCK_T * K,)).to(tl.int32) + top_k_packed_flat = tl.gather(top_packed, out_idx_flat, axis=0) + top_k_packed = tl.reshape(top_k_packed_flat, (BLOCK_T, K)) + # Strip the t_enc and sortable bits, keep low 24 bits (index). + top_k_idx = (top_k_packed & 0xFFFFFF).to(tl.int32) + + out_base = b * (oH * T_t * K) + h_outer * (T_t * K) + out_ptrs = Topk_idx_ptr + out_base + rt_64[:, None] * K + rk[None, :] + tl.store(out_ptrs, top_k_idx, mask=rt_mask[:, None]) + + +_score_topk_p = extend_core.Primitive("te_indexer_score_topk_triton") +_score_topk_p.multiple_results = True + + +def _next_pow2(n): + p = 1 + while p < n: + p *= 2 + return p + + +@_score_topk_p.def_abstract_eval +def _score_topk_abstract(Hq, Hk, W_o, *, k): + del Hk, W_o + B, oH, T_t, _H, _d_i = Hq.shape + return [core.ShapedArray((B, oH, T_t, k), jnp.int32)] + + +_score_topk_p.def_impl(functools.partial(xla.apply_primitive, _score_topk_p)) + + +def _score_topk_lowering(ctx, Hq, Hk, W_o, *, k): + Hq_aval = ctx.avals_in[0] + Hk_aval = ctx.avals_in[1] + B, oH, T_t, H, d_i = Hq_aval.shape + T_s = Hk_aval.shape[2] + S_PAD = _next_pow2(T_s) + + # Build a K-filtered autotuner around the plain JIT kernel. We do this at + # lowering time (rather than decorating the kernel at definition) because + # configs with BLOCK_S > k or BLOCK_S that doesn't divide k would compile + # to a kernel where INNER = k // BLOCK_S = 0 — i.e. a no-op that's fastest + # in the autotune timing race. Filtering ensures the runtime picker only + # sees configs that actually do the work. + # + # Also filter BLOCK_T configs that don't evenly divide T_t — we mask the + # tail but unnecessary padding hurts L1/L2 efficiency. + valid_configs = [ + c for c in _SCORE_TOPK_CONFIGS + if c.kwargs["BLOCK_S"] <= k + and k % c.kwargs["BLOCK_S"] == 0 + and T_t % c.kwargs["BLOCK_T"] == 0 + ] + if not valid_configs: + raise ValueError( + f"No valid BLOCK_S/BLOCK_T config for k={k}, T_t={T_t}" + ) + + autotuned_kernel = triton.autotune( + configs=valid_configs, + key=["H", "d_i", "T_s", "K"], + )(_score_topk_kernel) + + def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", 1) + return (triton.cdiv(T_t, bt), B * oH) + + return triton_call_lowering( + ctx, + autotuned_kernel, + Hq, Hk, W_o, + grid=grid_fn, + constexprs={ + "B": B, "oH": oH, "T_t": T_t, "T_s": T_s, + "H": H, "d_i": d_i, + "K": k, "S_PAD": S_PAD, + }, + ) + + +mlir.register_lowering(_score_topk_p, _score_topk_lowering, platform="rocm") +mlir.register_lowering(_score_topk_p, _score_topk_lowering, platform="cuda") + + +def score_topk_triton(Hq, Hk, W_o, *, k): + """Fused score-relu-reduce + streaming top-k. + + Computes the same scores as ``score_reduce_triton`` but never materializes the + (B, oH, T_t, T_s) score matrix — instead, returns the top-k indices into the + T_s axis directly. + + Args: + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + k: number of top scores to return per (b, oH, T_t) row. Must be a + power of 2 and <= T_s. + + Returns: + Topk_idx: (B, oH, T_t, k) int32 — top-k indices into T_s axis, in + descending score order. + + Notes: + Streaming: maintains a 2K candidate buffer and bitonic-sorts on each + chunk. For k >> S/8 (e.g., k=S/2), this is algorithmically slower than a + single full-row sort but matches the TileLang reference structure and + generalizes to large S without per-CTA registers scaling with S. + """ + if Hq.ndim != 5: + raise ValueError(f"Hq must be rank-5; got shape {Hq.shape}") + if Hk.ndim != 4: + raise ValueError(f"Hk must be rank-4; got shape {Hk.shape}") + if W_o.ndim != 4: + raise ValueError(f"W_o must be rank-4; got shape {W_o.shape}") + + B, oH, T_t, H, d_i = Hq.shape + Bk, oHk, T_s, d_i_k = Hk.shape + Bw, oHw, T_t_w, H_w = W_o.shape + if (Bk, oHk) != (B, oH): + raise ValueError(f"(B, oH) mismatch: Hq has {(B, oH)}, Hk has {(Bk, oHk)}") + if d_i != d_i_k: + raise ValueError(f"d_i mismatch: Hq has {d_i}, Hk has {d_i_k}") + if (Bw, oHw, T_t_w, H_w) != (B, oH, T_t, H): + raise ValueError(f"W_o shape {W_o.shape} != expected (B, oH, T_t, H)") + + if k <= 0 or (k & (k - 1)) != 0: + raise ValueError(f"k must be a positive power of 2; got {k}") + if k > T_s: + raise ValueError(f"k={k} must be <= T_s={T_s}") + + return _score_topk_p.bind(Hq, Hk, W_o, k=k)[0] diff --git a/transformer_engine/jax/triton_extensions/sparse_attention.py b/transformer_engine/jax/triton_extensions/sparse_attention.py new file mode 100644 index 000000000..28b784ac1 --- /dev/null +++ b/transformer_engine/jax/triton_extensions/sparse_attention.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Phase-2 scaffold for the fused sparse-attention Triton primitive. + +The functional API `fused_sparse_attention_triton(query, key, value, +indexer_query, indexer_key, indexer_weights, *, k, ...)` is declared and +registered as a JAX primitive with abstract evaluation, but the kernel +body and MLIR lowering both raise NotImplementedError. + +Purpose: lock the call signature so the DSA Flax module can dispatch to +this primitive via ``backend="fused"`` today, and the real kernel can +land later without any caller-side changes. + +The composition path in ``transformer_engine.jax.sparse_attention`` +(indexer + sparse mask + DotProductAttention) is the supported phase-1 +implementation. +""" + +import functools + +import jax.numpy as jnp + +from jax import core +from jax.extend import core as extend_core +from jax.interpreters import mlir, xla + + +_fused_sparse_attention_p = extend_core.Primitive("te_fused_sparse_attention_triton") +_fused_sparse_attention_p.multiple_results = False + + +@_fused_sparse_attention_p.def_abstract_eval +def _fused_sparse_attention_abstract( + query, key, value, indexer_query, indexer_key, indexer_weights, *, k +): + """Output has the same shape/dtype as ``query`` (BSHD layout assumed).""" + del key, value, indexer_query, indexer_key, indexer_weights, k + return core.ShapedArray(query.shape, query.dtype) + + +_fused_sparse_attention_p.def_impl( + functools.partial(xla.apply_primitive, _fused_sparse_attention_p) +) + + +def _fused_sparse_attention_lowering_unavailable(ctx, *args, **kwargs): + raise NotImplementedError( + "fused_sparse_attention_triton is a phase-2 scaffold: the Triton kernel " + "has not been implemented yet. Use backend='composition' in " + "transformer_engine.jax.sparse_attention.deep_sparse_attention_core(...) " + "for the working composition path." + ) + + +mlir.register_lowering( + _fused_sparse_attention_p, + _fused_sparse_attention_lowering_unavailable, + platform="rocm", +) +mlir.register_lowering( + _fused_sparse_attention_p, + _fused_sparse_attention_lowering_unavailable, + platform="cuda", +) + + +def fused_sparse_attention_triton( + query, + key, + value, + indexer_query, + indexer_key, + indexer_weights, + *, + k: int, +): + """Fused indexer + sparse attention (phase-2 scaffold — raises NotImplementedError). + + Intended contract for the future fused kernel: + + Args: + query: (B, T_t, H, D) attention queries (BSHD) + key: (B, T_s, H_kv, D) attention keys + value: (B, T_s, H_kv, D) attention values + indexer_query: (B, T_t, H_idx, d_i) post-projection indexer Hq + indexer_key: (B, T_s, d_i) post-projection indexer Hk + indexer_weights: (B, T_t, H_idx) post-projection indexer W_o + k: number of top-k key positions per query token + + Returns: + Output of shape (B, T_t, H, D) — sparse attention output where each + query attends only to its indexer-selected top-k key positions + (intersected with the causal mask). + + The signature is intentionally minimal so phase-2 has room to grow it + (e.g. window_size, attn_bias). Add kwargs only via the function + signature — abstract_eval and lowering both already accept ``**kwargs``. + """ + return _fused_sparse_attention_p.bind( + query, + key, + value, + indexer_query, + indexer_key, + indexer_weights, + k=k, + ) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6ea4092cb..1c3baf2c4 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -32,6 +32,7 @@ import hashlib import os +import tempfile import warnings from typing import Any, Callable, Mapping import zlib @@ -39,6 +40,7 @@ from jax import core import jax import jax.numpy as jnp +from transformer_engine.jax.util import is_hip_extension # Placeholder package version on PyPI that should never be used @@ -161,6 +163,12 @@ def _check_triton_compatibility(): ) from e +# AMD/HIP backend imports are additive: the NVIDIA path above is left untouched. +if is_hip_extension(): + from triton.backends.amd import compiler as cb_hip # noqa: E402 + from triton.backends.compiler import GPUTarget as _TritonGPUTarget # noqa: E402 + + __all__ = ["triton_call_lowering", "get_triton_info"] # Triton kernel cache (module-level, shared across all kernels) @@ -212,6 +220,9 @@ def get_triton_dtype(aval): jnp.dtype("float16"): "fp16", jnp.dtype("float8_e4m3fn"): "fp8e4nv", jnp.dtype("float8_e5m2"): "fp8e5", + # AMD gfx942 "FNUZ" variants — Triton calls these fp8e4b8/fp8e5b16. + jnp.dtype("float8_e4m3fnuz"): "fp8e4b8", + jnp.dtype("float8_e5m2fnuz"): "fp8e5b16", jnp.dtype("int64"): "i64", jnp.dtype("int32"): "i32", jnp.dtype("int16"): "i16", @@ -273,6 +284,22 @@ def compile_triton( if cache_key in _TRITON_KERNEL_CACHE: return _TRITON_KERNEL_CACHE[cache_key] + # AMD/HIP uses a separate compilation path; the NVIDIA path below is the + # unchanged upstream implementation. + if is_hip_extension(): + kernel = _compile_triton_hip( + kernel_fn, + signature, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, + ) + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel + # Compile kernel options = cb.CUDAOptions( num_warps=num_warps, @@ -332,6 +359,81 @@ def compile_triton( return kernel +# Track HSACO temp files for the lifetime of the process so the kernel paths +# we hand to jaxlib don't get garbage-collected. +_HSACO_TEMP_FILES: list[str] = [] + + +def _compile_triton_hip( + kernel_fn, + signature, + constants, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion, +): + # AMD/HIP returns an arch string like "gfx950:sramecc+:xnack-"; strip the + # target-feature suffix -> "gfx950". + arch = gpu_triton.get_arch_details(0).split(":", 1)[0] + # Mirror what triton's parse_options would do per-arch: the default + # HIPOptions.supported_fp8_dtypes is just ("fp8e5",), and constructing + # HIPOptions directly bypasses the per-arch augmentation. Set it + # explicitly so FP8 e4m3 kernels compile on gfx942/gfx950. + if arch == "gfx942": + fp8_dtypes = ("fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16") + elif arch == "gfx950" or arch.startswith("gfx12"): + fp8_dtypes = ("fp8e4nv", "fp8e5") + else: + fp8_dtypes = ("fp8e5",) + options = cb_hip.HIPOptions( + num_warps=num_warps, + num_stages=num_stages, + num_ctas=num_ctas, + cluster_dims=(1, 1, 1), + debug=False, + enable_fp_fusion=enable_fp_fusion, + arch=arch, + supported_fp8_dtypes=fp8_dtypes, + ) + + # Mark constants as constexpr in signature (mirrors the NVIDIA path). + signature_with_constexpr = dict(signature) + for const_name in constants.keys(): + if const_name in signature_with_constexpr: + signature_with_constexpr[const_name] = "constexpr" + + src = tc.ASTSource( + fn=kernel_fn, + constexprs=constants, + signature=signature_with_constexpr, + ) + compiled = tc.compile( + src, + target=_TritonGPUTarget("hip", arch, warp_size=64), + options=options.__dict__, + ) + + # jaxlib's HIP TritonKernel ctor takes a path to an HSACO blob, not bytes. + fd, hsaco_path = tempfile.mkstemp(suffix=".hsaco", prefix=f"te_{compiled.name}_") + with os.fdopen(fd, "wb") as f: + f.write(compiled.asm["hsaco"]) + _HSACO_TEMP_FILES.append(hsaco_path) + + return gpu_triton.TritonKernel( + compiled.name, + num_warps, + compiled.metadata.shared, + hsaco_path, + str(compiled.asm.get("ttir", "")), + compute_capability, + 1, + 1, + 1, + ) + + def triton_call_lowering( ctx, kernel_fn: Callable, @@ -339,6 +441,9 @@ def triton_call_lowering( grid, input_output_aliases: Mapping[int, int] = None, constexprs: Mapping[str, Any] = None, + num_warps: int = None, + num_stages: int = None, + num_ctas: int = None, ): """Helper for MLIR lowering that calls a Triton kernel. @@ -348,7 +453,12 @@ def triton_call_lowering( ctx: MLIR lowering context kernel_fn: Triton kernel function *array_args: Input arrays (from ctx) - grid: Grid dimensions (int or tuple) + grid: Grid dimensions. Either: + * an int / 1-3 element tuple (fixed grid), OR + * a callable ``(merged_kwargs) -> tuple`` for autotuned kernels + whose grid depends on the autotune-selected meta-args + (e.g. BLOCK_T/BLOCK_S). ``merged_kwargs`` is the union of + ``constexprs`` and the per-config ``Config.kwargs``. input_output_aliases: Mapping of input to output aliases constexprs: Compile-time constants for the kernel. This includes both tl.constexpr arguments AND scalar runtime arguments (like @@ -389,23 +499,35 @@ def lowering(ctx, x, *, block_size): tensor_arg_names = [n for n in arg_names if n not in constexpr_names] signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)} - # Normalize grid to 3D - if isinstance(grid, int): - grid_tuple = (grid, 1, 1) - elif len(grid) == 1: - grid_tuple = (grid[0], 1, 1) - elif len(grid) == 2: - grid_tuple = (grid[0], grid[1], 1) + # Normalize grid to 3D. `grid` may be a callable for autotuned kernels + # whose grid depends on the per-config meta-args (BLOCK_T/BLOCK_S etc.). + grid_fn = grid if callable(grid) else None + + def _normalize_grid(g): + if isinstance(g, int): + return (g, 1, 1) + if len(g) == 1: + return (g[0], 1, 1) + if len(g) == 2: + return (g[0], g[1], 1) + return g[:3] + + if grid_fn is None: + grid_tuple = _normalize_grid(grid) else: - grid_tuple = grid[:3] + # For non-autotune fallback, evaluate with just the user constexprs. + grid_tuple = _normalize_grid(grid_fn(constexprs or {})) - # Default values for the kernel + # Default values for the kernel (used unless the caller overrides them). actual_kernel_fn = kernel_fn - num_warps = 32 - num_stages = ( - 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas - ) - num_ctas = 1 + if num_warps is None: + num_warps = 32 + if num_stages is None: + num_stages = ( + 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas + ) + if num_ctas is None: + num_ctas = 1 kernel_constexprs = constexprs if constexprs is not None else {} # Handle autotuned kernels - compile all configs @@ -424,6 +546,12 @@ def lowering(ctx, x, *, block_size): # Merge config kwargs with user constexprs config_constexprs = {**config.kwargs, **(constexprs if constexprs else {})} + # Per-config grid: re-evaluate grid_fn with this config's merged + # kwargs so configs that vary BLOCK_T/BLOCK_S launch at the right + # cdiv(T_t, BLOCK_T) etc. (grid_tuple is otherwise the fixed grid.) + if grid_fn is not None: + grid_tuple = _normalize_grid(grid_fn(config_constexprs)) + # Compile this config config_kernel = compile_triton( actual_kernel_fn,