From 74a6357ea7745a4af6c27dd9bfe4c187c2a50abc Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 28 May 2026 17:49:43 +0800 Subject: [PATCH 1/2] support gva --- benchmarks/bench_lightning_attn.py | 128 +++++++++++------ cula/ops/lightning_attn_sm100.py | 215 ++++++++++++++++++----------- tests/test_lightning_attn.py | 160 +++++++++++++-------- 3 files changed, 316 insertions(+), 187 deletions(-) diff --git a/benchmarks/bench_lightning_attn.py b/benchmarks/bench_lightning_attn.py index 1700144..cbd9615 100644 --- a/benchmarks/bench_lightning_attn.py +++ b/benchmarks/bench_lightning_attn.py @@ -33,6 +33,9 @@ # Custom varlen workloads python benchmarks/bench_lightning_attn.py --modes varlen --num-heads 32 64 --iterations 50 + + # GVA (value heads > Q/K heads) + python benchmarks/bench_lightning_attn.py --modes no_state h0_ht varlen --num-heads 32 --num-v-heads 64 """ import argparse @@ -80,6 +83,17 @@ def compute_decay(H, layer_idx=12, num_layers=24): return (8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, dtype=torch.float32, device=DEVICE) +def expand_qk_to_value_heads(Q, K, V): + """Expand Q/K from H heads to HV heads for FLA and naive references.""" + H = Q.shape[2] + HV = V.shape[2] + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + if HV == H: + return Q, K + group_size = HV // H + return Q.repeat_interleave(group_size, dim=2), K.repeat_interleave(group_size, dim=2) + + @torch.no_grad() def torch_naive_lightning_attn(Q, K, V, decay, scale=1.0, initial_state=None, output_final_state=False): """Recurrent FP32 reference for lightning attention (simple_gla). @@ -87,15 +101,17 @@ def torch_naive_lightning_attn(Q, K, V, decay, scale=1.0, initial_state=None, ou O(B*T*H*D^2) — exact ground truth, all computation in FP32. """ B, T, H, D = Q.shape + HV = V.shape[2] + Q, K = expand_qk_to_value_heads(Q, K, V) q, k, v = Q.float(), K.float(), V.float() - decay_factor = torch.exp(-decay.float()) # [H] + decay_factor = torch.exp(-decay.float()) # [HV] S = ( initial_state.float().clone() if initial_state is not None - else torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device) + else torch.zeros(B, HV, D, D, dtype=torch.float32, device=Q.device) ) - O = torch.zeros(B, T, H, D, dtype=torch.float32, device=Q.device) + O = torch.zeros(B, T, HV, D, dtype=torch.float32, device=Q.device) for t in range(T): S = S * decay_factor[None, :, None, None] @@ -111,13 +127,14 @@ def torch_naive_lightning_attn(Q, K, V, decay, scale=1.0, initial_state=None, ou # ============================================================================= def run_fla(Q, K, V, decay, initial_state, output_final_state, warmup, iters): """Run FLA chunk_simple_gla_fwd (standard, non-varlen).""" + Q_fla, K_fla = expand_qk_to_value_heads(Q, K, V) g_gamma = -decay scale = 1.0 def fn(): return chunk_simple_gla_fwd( - q=Q, - k=K, + q=Q_fla, + k=K_fla, v=V, g_gamma=g_gamma, scale=scale, @@ -183,13 +200,14 @@ def fn(): def run_fla_varlen(Q, K, V, decay, cu_seqlens, warmup, iters): """Run FLA native varlen (single launch via cu_seqlens). FAIR baseline.""" + Q_fla, K_fla = expand_qk_to_value_heads(Q, K, V) g_gamma = -decay cu_long = cu_seqlens.to(torch.long) def fn(): return chunk_simple_gla_fwd( - q=Q, - k=K, + q=Q_fla, + k=K_fla, v=V, g_gamma=g_gamma, scale=1.0, @@ -207,25 +225,26 @@ def fn(): # ============================================================================= # Standard (non-varlen) benchmark # ============================================================================= -def benchmark_standard_config(B, T, H, D, layer_idx, num_layers, mode, warmup, iters): +def benchmark_standard_config(B, T, H, HV, D, layer_idx, num_layers, mode, warmup, iters): """Benchmark a single standard (non-varlen) config. mode: "no_state" — no initial/final state "h0_ht" — provide random h0 and output ht """ + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" torch.manual_seed(42) Q = torch.randn(B, T, H, D, dtype=DTYPE, device=DEVICE) K = torch.randn(B, T, H, D, dtype=DTYPE, device=DEVICE) - V = torch.randn(B, T, H, D, dtype=DTYPE, device=DEVICE) - decay = compute_decay(H, layer_idx, num_layers) + V = torch.randn(B, T, HV, D, dtype=DTYPE, device=DEVICE) + decay = compute_decay(HV, layer_idx, num_layers) has_h0 = mode == "h0_ht" output_ht = mode == "h0_ht" - h0 = torch.randn(B, H, D, D, dtype=torch.float32, device=DEVICE) * 0.01 if has_h0 else None + h0 = torch.randn(B, HV, D, D, dtype=torch.float32, device=DEVICE) * 0.01 if has_h0 else None h0_fla = h0.clone() if h0 is not None else None h0_cute = h0.transpose(-1, -2).contiguous() if h0 is not None else None # BHVK for CuTe - result = {"B": B, "T": T, "H": H, "D": D, "mode": mode} + result = {"B": B, "T": T, "H": H, "HV": HV, "D": D, "mode": mode} ht_fla = None ht_cute = None @@ -289,20 +308,22 @@ def benchmark_standard_config(B, T, H, D, layer_idx, num_layers, mode, warmup, i # ============================================================================= # Varlen benchmark # ============================================================================= -def benchmark_varlen_config(N, seq_lens, H, D, warmup, iters, dist=""): +def benchmark_varlen_config(N, seq_lens, H, HV, D, warmup, iters, dist=""): """Benchmark a varlen config: persistent vs non-persistent vs FLA varlen.""" + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" T = sum(seq_lens) torch.manual_seed(42) Q = torch.randn(1, T, H, D, dtype=DTYPE, device=DEVICE) K = torch.randn(1, T, H, D, dtype=DTYPE, device=DEVICE) - V = torch.randn(1, T, H, D, dtype=DTYPE, device=DEVICE) + V = torch.randn(1, T, HV, D, dtype=DTYPE, device=DEVICE) cu = torch.tensor([0] + list(np.cumsum(seq_lens)), dtype=torch.int32, device=DEVICE) - decay = compute_decay(H) + decay = compute_decay(HV) result = { "B": N, "T": T, "H": H, + "HV": HV, "D": D, "mode": "varlen", "seq_lens": seq_lens, @@ -394,7 +415,8 @@ def print_standard_header(): def print_standard_result(r): - cfg = f"B={r['B']},T={r['T']},H={r['H']}" + hv = r.get("HV", r["H"]) + cfg = f"B={r['B']},T={r['T']},H={r['H']},HV={hv}" fla = f"{r['fla_ms']:.3f}" if _valid(r.get("fla_ms", float("nan"))) else "ERR" dsl = f"{r['cutedsl_ms']:.3f}" if _valid(r.get("cutedsl_ms", float("nan"))) else "ERR" @@ -437,7 +459,8 @@ def print_varlen_header(): def print_varlen_result(r): - cfg = f"N={r['B']},T={r['T']},H={r['H']}" + hv = r.get("HV", r["H"]) + cfg = f"N={r['B']},T={r['T']},H={r['H']},HV={hv}" dist = r.get("dist", "") p_ms = f"{r['persistent_ms']:.3f}" if _valid(r.get("persistent_ms", float("nan"))) else "ERR" @@ -482,6 +505,7 @@ def run_benchmark_suite(args): warmup = args.warmup iters = args.iterations modes = args.modes + num_v_heads = getattr(args, "num_v_heads", None) print("\n" + "=" * 100) print("Lightning Attention Benchmark: CuteDSL vs FLA") @@ -490,6 +514,7 @@ def run_benchmark_suite(args): print(f" Batch sizes: {args.batch_sizes}") print(f" Seq lengths: {args.seq_lens}") print(f" Num heads: {args.num_heads}") + print(f" Num V heads: {num_v_heads or args.num_heads}") print(f" Head dim: {D}") print(f" Layer: {layer_idx}/{num_layers}") print(f" Warmup/Iters: {warmup}/{iters}") @@ -505,14 +530,18 @@ def run_benchmark_suite(args): for B in args.batch_sizes: for T in args.seq_lens: for H in args.num_heads: - total = B * T * H * D - if total > 2_147_483_648: - continue - if T > 4096 and B > 2: - continue - r = benchmark_standard_config(B, T, H, D, layer_idx, num_layers, mode, warmup, iters) - all_results.append(r) - print_standard_result(r) + for HV in num_v_heads or [H]: + if HV < H or HV % H != 0: + print(f"Skipping invalid GVA config H={H}, HV={HV}") + continue + total = B * T * HV * D + if total > 2_147_483_648: + continue + if T > 4096 and B > 2: + continue + r = benchmark_standard_config(B, T, H, HV, D, layer_idx, num_layers, mode, warmup, iters) + all_results.append(r) + print_standard_result(r) # ===================== Varlen mode ===================== if "varlen" in modes: @@ -542,22 +571,26 @@ def run_benchmark_suite(args): workloads = unique for H in args.num_heads: - print(f"\n --- H={H}, D={D} ---") - print_varlen_header() - - for N, T_total, dist in workloads: - if dist == "uniform": - seq_lens = gen_uniform(N, T_total) - elif dist == "skewed": - seq_lens = gen_skewed(N, T_total) - elif dist == "random": - seq_lens = gen_random(N, T_total) - else: - raise ValueError(f"Unknown dist: {dist}") + for HV in num_v_heads or [H]: + if HV < H or HV % H != 0: + print(f"Skipping invalid GVA config H={H}, HV={HV}") + continue + print(f"\n --- H={H}, HV={HV}, D={D} ---") + print_varlen_header() + + for N, T_total, dist in workloads: + if dist == "uniform": + seq_lens = gen_uniform(N, T_total) + elif dist == "skewed": + seq_lens = gen_skewed(N, T_total) + elif dist == "random": + seq_lens = gen_random(N, T_total) + else: + raise ValueError(f"Unknown dist: {dist}") - r = benchmark_varlen_config(N, seq_lens, H, D, warmup, iters, dist=dist) - all_results.append(r) - print_varlen_result(r) + r = benchmark_varlen_config(N, seq_lens, H, HV, D, warmup, iters, dist=dist) + all_results.append(r) + print_varlen_result(r) # ===================== Summary ===================== print(f"\n{'=' * 100}") @@ -662,7 +695,7 @@ def plot_results(all_results, modes): if not hr: ax.set_title("varlen (no data)") continue - labels = [f"N{r['B']}T{r['T']}\n{r.get('dist', '')[:3]}" for r in hr] + labels = [f"N{r['B']}T{r['T']}H{r['H']}HV{r.get('HV', r['H'])}\n{r.get('dist', '')[:3]}" for r in hr] p_ms = [r["persistent_ms"] for r in hr] np_ms = [r["nonpersistent_ms"] for r in hr] fla_ms = [r["fla_varlen_ms"] if _valid(r.get("fla_varlen_ms", float("nan"))) else 0 for r in hr] @@ -676,7 +709,7 @@ def plot_results(all_results, modes): if not hr: ax.set_title(f"{mode} (no data)") continue - labels = [f"B{r['B']}T{r['T']}H{r['H']}" for r in hr] + labels = [f"B{r['B']}T{r['T']}H{r['H']}HV{r.get('HV', r['H'])}" for r in hr] fla = [r["fla_ms"] for r in hr] dsl = [r["cutedsl_ms"] for r in hr] x = np.arange(len(labels)) @@ -703,6 +736,7 @@ def plot_results(all_results, modes): def generate_report(all_results, modes, args): from datetime import datetime + num_v_heads = getattr(args, "num_v_heads", None) path = os.path.join(os.path.dirname(__file__), "benchmark_report.md") with open(path, "w") as f: f.write("# Lightning Attention Benchmark Report\n\n") @@ -712,6 +746,7 @@ def generate_report(all_results, modes, args): f.write(f"- Batch sizes: {args.batch_sizes}\n") f.write(f"- Seq lengths: {args.seq_lens}\n") f.write(f"- Num heads: {args.num_heads}\n") + f.write(f"- Num V heads: {num_v_heads or args.num_heads}\n") f.write(f"- Head dim: {args.head_dim}\n") f.write(f"- Layer: {args.layer_idx}/{args.num_layers}\n") f.write(f"- Warmup/Iters: {args.warmup}/{args.iterations}\n\n") @@ -723,8 +758,8 @@ def generate_report(all_results, modes, args): f.write(f"## Mode: {mode}\n\n") if mode == "varlen": - f.write("| N | T | Dist | Persist(ms) | NonPer(ms) | FLA_vl(ms) | P/NP | P/FLAvl | O diff | ht diff |\n") - f.write("|---|---|------|-------------|------------|------------|------|---------|--------|--------|\n") + f.write("| N | T | H | HV | Dist | Persist(ms) | NonPer(ms) | FLA_vl(ms) | P/NP | P/FLAvl | O diff | ht diff |\n") + f.write("|---|---|---|----|------|-------------|------------|------------|------|---------|--------|--------|\n") for r in mr: p = f"{r['persistent_ms']:.3f}" if _valid(r.get("persistent_ms", float("nan"))) else "-" np_ = f"{r['nonpersistent_ms']:.3f}" if _valid(r.get("nonpersistent_ms", float("nan"))) else "-" @@ -736,7 +771,7 @@ def generate_report(all_results, modes, args): od = f"{r['p_vs_np_O_diff']:.1e}" if not np.isnan(r.get("p_vs_np_O_diff", float("nan"))) else "-" hd = f"{r['p_vs_np_ht_diff']:.1e}" if not np.isnan(r.get("p_vs_np_ht_diff", float("nan"))) else "-" f.write( - f"| {r['B']} | {r['T']} | {r.get('dist', '')} | {p} | {np_} | {fla_vl} | {pvnp} | {pvfla_vl} | {od} | {hd} |\n" + f"| {r['B']} | {r['T']} | {r['H']} | {r.get('HV', r['H'])} | {r.get('dist', '')} | {p} | {np_} | {fla_vl} | {pvnp} | {pvfla_vl} | {od} | {hd} |\n" ) else: has_ht = mode == "h0_ht" @@ -753,7 +788,7 @@ def generate_report(all_results, modes, args): "|--------|---------|-------------|---------|---------------------------|----------------------------|\n" ) for r in mr: - cfg = f"B={r['B']},T={r['T']},H={r['H']}" + cfg = f"B={r['B']},T={r['T']},H={r['H']},HV={r.get('HV', r['H'])}" sp = f"{r['speedup']:.2f}x" if _valid(r.get("speedup", float("nan"))) else "-" fla = f"{r['fla_ms']:.3f}" if _valid(r.get("fla_ms", float("nan"))) else "-" dsl = f"{r['cutedsl_ms']:.3f}" if _valid(r.get("cutedsl_ms", float("nan"))) else "-" @@ -838,6 +873,7 @@ def parse_args(): "--seq-lens", type=int, nargs="+", default=[256, 1024, 4096, 8192, 32768], help="Sequence lengths for standard modes" ) p.add_argument("--num-heads", type=int, nargs="+", default=[64], help="Number of heads to test") + p.add_argument("--num-v-heads", type=int, nargs="+", default=None, help="Number of value heads to test (default: same as H)") p.add_argument("--head-dim", type=int, default=128) p.add_argument("--layer-idx", type=int, default=12) p.add_argument("--num-layers", type=int, default=24) diff --git a/cula/ops/lightning_attn_sm100.py b/cula/ops/lightning_attn_sm100.py index 8a6b204..3e516f9 100644 --- a/cula/ops/lightning_attn_sm100.py +++ b/cula/ops/lightning_attn_sm100.py @@ -107,7 +107,8 @@ class LinearAttentionChunkwiseDecay: chunk_size: Size of each attention chunk (default: 64) acc_dtype: Accumulator data type for all MMA computations (default: Float32) io_dtype: Input/output data type (default: BFloat16) - H: Number of attention heads + H: Number of Q/K heads + HV: Number of V/O heads. HV > H enables GVA. K: Key head dimension (must be 128) V: Value head dimension (must be 128) scale: Scaling factor for queries @@ -121,6 +122,7 @@ def __init__( has_initial_state: bool = False, output_final_state: bool = False, H: int = 64, + HV: int | None = None, K: int = 128, V: int = 128, scale: float = 1.0, @@ -129,6 +131,8 @@ def __init__( use_fast_math: bool = True, ): assert K == 128 and V == 128, f"K and V must both be 128, got K={K}, V={V}" + HV = H if HV is None else HV + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" assert_blackwell() self.use_fast_math = use_fast_math self.chunk_size = chunk_size @@ -145,6 +149,7 @@ def __init__( self.has_initial_state = has_initial_state self.output_final_state = output_final_state self.H = H + self.HV = HV self.K = K self.V = V self.D = K # Internal shorthand: K == V == D @@ -349,16 +354,16 @@ def __call__( (zero-copy C-level dlpack). Pass None for initial_state_in / final_state_in when has_initial_state / output_final_state is False. - scale, H, D are compile-time constants stored in self.__init__. + scale, H, HV, D are compile-time constants stored in self.__init__. Args: q_in: Query tensor [B, S, H, D] or [1, T, H, D] for varlen k_in: Key tensor [B, S, H, D] or [1, T, H, D] for varlen - v_in: Value tensor [B, S, H, D] or [1, T, H, D] for varlen - o_in: Output tensor [B, S, H, D] or [1, T, H, D] for varlen - decay_in: Per-head decay tensor [H] (FP32) - initial_state_in: Initial state [B, H, D, D] or state pool [pool, H, D, D] (FP32) - final_state_in: Final state [B, H, D, D] (FP32) or None (varlen uses INPLACE_UPDATE) + v_in: Value tensor [B, S, HV, D] or [1, T, HV, D] for varlen + o_in: Output tensor [B, S, HV, D] or [1, T, HV, D] for varlen + decay_in: Per-value-head decay tensor [HV] (FP32) + initial_state_in: Initial state [B, HV, D, D] or state pool [pool, HV, D, D] (FP32) + final_state_in: Final state [B, HV, D, D] (FP32) or None (varlen uses INPLACE_UPDATE) cu_seqlens_in: [N+1] int32 cumulative sequence lengths (varlen only) initial_state_indices_in: [N] int32 indices into state pool (varlen only) problem_size: (N, T) for varlen or (B, S) dynamic problem dimensions @@ -366,6 +371,7 @@ def __call__( """ B, S = problem_size H = self.H + HV = self.HV D = self.D # Setup attributes @@ -373,16 +379,16 @@ def __call__( self.cta_group = tcgen05.CtaGroup.ONE - # It's ok since torch tensor is row major, hence we've layout=(B,S,H,D):(DHS, DH, D, 1). + # It's ok since torch tensor is row major, hence we've layout=(B,S,H/HV,D):(D*heads, DH, D, 1). # Below are just permutation tricks to ease the later processing. - # For varlen: input is [1, T, H, D] → view as (T, D, H) with stride (D*H, 1, D) - # For non-varlen: input is [B, S, H, D] → view as (S, D, (H,B)) + # For varlen: Q/K are [1,T,H,D] -> (T,D,H); V/O are [1,T,HV,D] -> (D,T,HV) + # For non-varlen: Q/K use (S,D,(H,B)); V/O use (D,S,(HV,B)). if cutlass.const_expr(self.is_varlen): # Varlen: B=N (num_seqs), S=T (total_tokens), no batch stride q_layout = cute.make_layout((S, D, H), stride=(D * H, 1, D)) k_layout = cute.make_layout((S, D, H), stride=(D * H, 1, D)) - v_layout = cute.make_layout((D, S, H), stride=(1, D * H, D)) - o_layout = cute.make_layout((D, S, H), stride=(1, D * H, D)) + v_layout = cute.make_layout((D, S, HV), stride=(1, D * HV, D)) + o_layout = cute.make_layout((D, S, HV), stride=(1, D * HV, D)) else: q_layout = cute.make_layout( (S, D, (H, B)), @@ -393,27 +399,27 @@ def __call__( stride=(D * H, 1, (D, D * H * S)), ) v_layout = cute.make_layout( - (D, S, (H, B)), - stride=(1, D * H, (D, D * H * S)), + (D, S, (HV, B)), + stride=(1, D * HV, (D, D * HV * S)), ) o_layout = cute.make_layout( - (D, S, (H, B)), - stride=(1, D * H, (D, D * H * S)), + (D, S, (HV, B)), + stride=(1, D * HV, (D, D * HV * S)), ) q = cute.make_tensor(q_in.iterator, q_layout) k = cute.make_tensor(k_in.iterator, k_layout) v = cute.make_tensor(v_in.iterator, v_layout) o = cute.make_tensor(o_in.iterator, o_layout) - # Initial state / final state: [B, H, D, D] in BHVK layout (K-contiguous) - # CuTe shape (V, K, (H, B)) with strides (D, 1, ...) for K-contiguous access. + # Initial state / final state: [B, HV, D, D] in BHVK layout (K-contiguous) + # CuTe shape (V, K, (HV, B)) with strides (D, 1, ...) for K-contiguous access. # When has_initial_state / output_final_state is False, None is passed # and the parameter is eliminated at compile time via const_expr guards. - # For varlen: state pool is [pool_size, H, D, D]. We use B (=N) as the + # For varlen: state pool is [pool_size, HV, D, D]. We use B (=N) as the # pool dimension — strides are correct regardless of actual pool_size. fstate_layout = cute.make_layout( - (D, D, (H, B)), - stride=(D, 1, (D * D, D * D * H)), + (D, D, (HV, B)), + stride=(D, 1, (D * D, D * D * HV)), ) if cutlass.const_expr(self.has_initial_state): initial_state = cute.make_tensor(initial_state_in.iterator, fstate_layout) @@ -739,8 +745,8 @@ class SharedStorage: sm_count = _torch.cuda.get_device_properties(0).multi_processor_count self.grid = (sm_count, 1, 1) elif cutlass.const_expr(self.is_varlen): - # Varlen grid: (1, H, N) where B = N = num_sequences - self.grid = (1, H, B) + # Varlen grid: (1, HV, N) where B = N = num_sequences + self.grid = (1, HV, B) else: self.grid = self._compute_grid( o_shape=cute.shape(o), @@ -1001,26 +1007,30 @@ def kernel( B, S = problem_size H = self.H + HV = self.HV D = self.D C = self.chunk_size scale = cutlass.Float32(self.scale) + qk_group_size = HV // H # ===================== Block indices ===================== if cutlass.const_expr(self.is_varlen): if cutlass.const_expr(self.persistent): # 1D grid work decode: persistent (grid=SM_count) - total_work_units = H * B + total_work_units = HV * B num_iters = Int32(0) # not used, while loop controls iteration # Pre-initialize variables reassigned inside persistent loop (CuTe DSL requirement) hidx = Int32(0) + i_h = Int32(0) bidx = Int32(0) bos = Int32(0) eos = Int32(0) seq_len = Int32(0) state_idx = Int32(0) else: - # Non-persistent varlen: 3D grid (1, H, N) + # Non-persistent varlen: 3D grid (1, HV, N) (_, hidx, bidx) = cute.arch.block_idx() + i_h = hidx // qk_group_size bos = cu_seqlens[bidx] eos = cu_seqlens[bidx + 1] seq_len = eos - bos @@ -1028,6 +1038,7 @@ def kernel( num_iters = Int32(1) else: (_, hidx, bidx) = cute.arch.block_idx() + i_h = hidx // qk_group_size seq_len = S state_idx = bidx num_iters = Int32(1) @@ -1051,7 +1062,7 @@ def kernel( block_decay = Float32(0.0) else: # Non-varlen and non-persistent varlen: hidx known at CTA start - decay_tensor = cute.make_tensor(decay, cute.make_layout(H)) + decay_tensor = cute.make_tensor(decay, cute.make_layout(HV)) decay_s = decay_tensor[hidx] # Block-level decay: λ^C for inter-chunk state accumulation block_decay = cute.exp(-decay_s * cutlass.Float32(C), fastmath=self.use_fast_math) @@ -1252,8 +1263,9 @@ def kernel( while should_continue: # --- Work decode (persistent only) --- if cutlass.const_expr(self.is_varlen and self.persistent): - hidx = work_idx % H - bidx = work_idx // H + hidx = work_idx % HV + i_h = hidx // qk_group_size + bidx = work_idx // HV bos = cu_seqlens[bidx] eos = cu_seqlens[bidx + 1] seq_len = eos - bos @@ -1266,7 +1278,7 @@ def kernel( tma_tensor_q_use = cute.domain_offset((bos, 0, 0), tma_tensor_q) # K: (S, D, H) → offset S (mode 0) by bos tma_tensor_k_use = cute.domain_offset((bos, 0, 0), tma_tensor_k) - # V: (D, S, H) → offset S (mode 1) by bos + # V: (D, S, HV) → offset S (mode 1) by bos tma_tensor_v_use = cute.domain_offset((0, bos, 0), tma_tensor_v) else: tma_tensor_q_use = tma_tensor_q @@ -1282,7 +1294,7 @@ def kernel( self.qk_mma_tiler, qk_tiled_mma, operand_mode="A", - hidx=hidx, + hidx=i_h, bidx=bidx, debug_name="Q", ) @@ -1294,7 +1306,7 @@ def kernel( self.qk_mma_tiler, qk_tiled_mma, operand_mode="B", - hidx=hidx, + hidx=i_h, bidx=bidx, debug_name="K", ) @@ -1384,7 +1396,7 @@ def kernel( while should_continue: # --- Work decode (MMA only needs seq_len) --- if cutlass.const_expr(self.is_varlen and self.persistent): - bidx_mma = work_idx // H + bidx_mma = work_idx // HV seq_len = cu_seqlens[bidx_mma + 1] - cu_seqlens[bidx_mma] for chunk_start in cutlass.range(0, seq_len, C, unroll=0): @@ -1699,8 +1711,9 @@ def kernel( while should_continue: # --- Work decode (persistent only) --- if cutlass.const_expr(self.is_varlen and self.persistent): - hidx = work_idx % H - bidx = work_idx // H + hidx = work_idx % HV + i_h = hidx // qk_group_size + bidx = work_idx // HV bos = cu_seqlens[bidx] eos = cu_seqlens[bidx + 1] seq_len = eos - bos @@ -1708,7 +1721,7 @@ def kernel( # Load per-head decay parameter to register (s_h > 0) # For persistent: hidx was decoded above; for non-persistent: hidx from block_idx - decay_tensor_cuda = cute.make_tensor(decay, cute.make_layout(H)) + decay_tensor_cuda = cute.make_tensor(decay, cute.make_layout(HV)) decay_s_cuda = decay_tensor_cuda[hidx] block_decay = cute.exp(-decay_s_cuda * cutlass.Float32(C), fastmath=self.use_fast_math) @@ -2032,8 +2045,8 @@ def kernel( while should_continue: # --- Work decode (persistent only) --- if cutlass.const_expr(self.is_varlen and self.persistent): - hidx = work_idx % H - bidx = work_idx // H + hidx = work_idx % HV + bidx = work_idx // HV bos = cu_seqlens[bidx] eos = cu_seqlens[bidx + 1] seq_len = eos - bos @@ -2081,14 +2094,14 @@ def kernel( tOrO = cute.make_fragment_like(tOsO, self.io_dtype) cute.autovec_copy(tOsO, tOrO) - o_chunk_raw = o_tensor.iterator + (bos + chunk_start) * D * H + hidx * D + o_chunk_raw = o_tensor.iterator + (bos + chunk_start) * D * HV + hidx * D o_chunk_ptr = cute.make_ptr( self.io_dtype, o_chunk_raw.toint(), cute.AddressSpace.gmem, assumed_align=16, ) - o_stride_c = D * H + o_stride_c = D * HV gO_chunk = cute.make_tensor( o_chunk_ptr, cute.make_layout( @@ -2803,11 +2816,22 @@ def make_thread_cooperative_group(size: int): # Compile cache + TVM-FFI API # --------------------------------------------------------------------------- -# Internal cache: maps (has_initial_state, output_final_state, H, D, scale, chunk_size) → compiled_fn +# Internal cache: maps (has_initial_state, output_final_state, H, HV, D, scale, chunk_size) → compiled_fn _kernel_cache: dict = {} -def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, chunk_size): +def _normalize_gva_decay(decay: torch.Tensor, H: int, HV: int) -> torch.Tensor: + """Return a contiguous [HV] decay tensor, accepting [H] decay for grouped cases.""" + if decay.ndim != 1: + raise ValueError(f"decay must be a 1D tensor, got shape {tuple(decay.shape)}") + if decay.shape[0] == HV: + return decay.contiguous() + if decay.shape[0] == H and HV != H: + return decay.repeat_interleave(HV // H).contiguous() + raise ValueError(f"decay must have shape ({HV},) or ({H},), got {tuple(decay.shape)}") + + +def _compile_single_variant(has_initial_state, output_final_state, H, HV, D, scale, chunk_size): """Compile one kernel variant. Returns the compiled TVM-FFI callable. Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -2822,6 +2846,7 @@ def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, has_initial_state=has_initial_state, output_final_state=output_final_state, H=H, + HV=HV, K=D, V=D, scale=scale, @@ -2831,7 +2856,7 @@ def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, sym_b = cute.sym_int() sym_s = cute.sym_int() - # Q, K, V, O: (B, S, H, D) row-major bf16 + # Q/K: (B, S, H, D); V/O: (B, S, HV, D) row-major bf16 q_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_b, sym_s, H, D), @@ -2846,29 +2871,29 @@ def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_b, sym_s, H, D), + (sym_b, sym_s, HV, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_b, sym_s, H, D), + (sym_b, sym_s, HV, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) - # decay: (H,) float32 + # decay: (HV,) float32 decay_fake = make_fake_compact_tensor( cutlass.Float32, - (H,), + (HV,), assumed_align=128, ) - # initial_state / final_state: (B, H, D, D) float32 or None + # initial_state / final_state: (B, HV, D, D) float32 or None h0_fake = ( make_fake_compact_tensor( cutlass.Float32, - (sym_b, H, D, D), + (sym_b, HV, D, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) @@ -2878,7 +2903,7 @@ def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, ht_fake = ( make_fake_compact_tensor( cutlass.Float32, - (sym_b, H, D, D), + (sym_b, HV, D, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) @@ -2929,7 +2954,7 @@ def _compile_single_variant(has_initial_state, output_final_state, H, D, scale, return compiled_fn -def _get_compiled_kernel(has_initial_state, output_final_state, H, D, scale, chunk_size): +def _get_compiled_kernel(has_initial_state, output_final_state, H, HV, D, scale, chunk_size): """Get a compiled kernel with on-demand (lazy) compilation. Each variant is compiled exactly once and cached. Compilation is deferred @@ -2938,14 +2963,15 @@ def _get_compiled_kernel(has_initial_state, output_final_state, H, D, scale, chu where a subsequent cute.compile can invalidate previously compiled but not-yet-executed functions. - Cache key: (has_initial_state, output_final_state, H, D, scale, chunk_size, USE_FAST_MATH) + Cache key: (has_initial_state, output_final_state, H, HV, D, scale, chunk_size, USE_FAST_MATH) """ - key = (has_initial_state, output_final_state, H, D, scale, chunk_size, USE_FAST_MATH) + key = (has_initial_state, output_final_state, H, HV, D, scale, chunk_size, USE_FAST_MATH) if key not in _kernel_cache: _kernel_cache[key] = _compile_single_variant( has_initial_state, output_final_state, H, + HV, D, scale, chunk_size, @@ -2971,23 +2997,34 @@ def lightning_attn_fwd( sym_int() is used for B and S so a single compilation handles all batch-size / sequence-length combinations. - Cache key: (has_initial_state, output_final_state, H, D, scale, chunk_size) + Cache key: (has_initial_state, output_final_state, H, HV, D, scale, chunk_size) Args: Q: (B, S, H, D) bf16 query K: (B, S, H, D) bf16 key - V: (B, S, H, D) bf16 value - decay: (H,) f32 per-head decay coefficients + V: (B, S, HV, D) bf16 value. HV > H enables GVA. + decay: (HV,) f32 per-value-head decay coefficients; (H,) is accepted and expanded scale: attention scale factor (default: 1.0) - initial_state: (B, H, D, D) f32 initial state in BHVK layout, or None + initial_state: (B, HV, D, D) f32 initial state in BHVK layout, or None output_final_state: whether to output final state chunk_size: chunk size (default: 64) Returns: - (O, ht): output tensor (B,S,H,D) bf16, final state (B,H,D,D) f32 in BHVK layout or None + (O, ht): output tensor (B,S,HV,D) bf16, final state (B,HV,D,D) f32 in BHVK layout or None """ B, S, H, D = Q.shape - O = torch.zeros_like(Q) + if K.shape != Q.shape: + raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}") + if V.ndim != 4 or V.shape[0] != B or V.shape[1] != S or V.shape[3] != D: + raise ValueError(f"V must have shape (B, S, HV, D), got {tuple(V.shape)}") + HV = V.shape[2] + if HV < H or HV % H != 0: + raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H") + decay = _normalize_gva_decay(decay, H, HV) + if initial_state is not None and initial_state.shape != (B, HV, D, D): + raise ValueError(f"initial_state must have shape {(B, HV, D, D)}, got {tuple(initial_state.shape)}") + + O = torch.zeros_like(V) has_initial_state = initial_state is not None @@ -2995,13 +3032,14 @@ def lightning_attn_fwd( has_initial_state, output_final_state, H, + HV, D, scale, chunk_size, ) if output_final_state: - ht = torch.zeros(B, H, D, D, dtype=torch.float32, device=Q.device) + ht = torch.zeros(B, HV, D, D, dtype=torch.float32, device=Q.device) else: ht = None @@ -3036,7 +3074,7 @@ def lightning_attn_fwd( _varlen_kernel_cache: dict = {} -def _compile_single_variant_varlen(H, D, scale, chunk_size, persistent=True): +def _compile_single_variant_varlen(H, HV, D, scale, chunk_size, persistent=True): """Compile one varlen kernel variant. Returns the compiled TVM-FFI callable. Varlen kernel always has initial state and output_final_state (INPLACE_UPDATE). @@ -3048,6 +3086,7 @@ def _compile_single_variant_varlen(H, D, scale, chunk_size, persistent=True): has_initial_state=True, output_final_state=True, H=H, + HV=HV, K=D, V=D, scale=scale, @@ -3059,7 +3098,7 @@ def _compile_single_variant_varlen(H, D, scale, chunk_size, persistent=True): sym_n = cute.sym_int() # N: number of sequences sym_t = cute.sym_int() # T: total packed tokens - # Q, K, V, O: [1, T, H, D] row-major bf16 + # Q/K: [1, T, H, D]; V/O: [1, T, HV, D] row-major bf16 # For varlen, B=1 in the physical tensor but we view as (T, D, H) q_fake = make_fake_compact_tensor( cutlass.BFloat16, @@ -3075,29 +3114,29 @@ def _compile_single_variant_varlen(H, D, scale, chunk_size, persistent=True): ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_t, H, D), + (1, sym_t, HV, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_t, H, D), + (1, sym_t, HV, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) - # decay: (H,) float32 + # decay: (HV,) float32 decay_fake = make_fake_compact_tensor( cutlass.Float32, - (H,), + (HV,), assumed_align=128, ) - # State pool: [pool_size, H, D, D] float32 — always present for varlen + # State pool: [pool_size, HV, D, D] float32 — always present for varlen # Use sym_n as pool dimension (actual pool may be larger, strides are correct) h0_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_n, H, D, D), + (sym_n, HV, D, D), stride_order=(3, 2, 1, 0), assumed_align=128, ) @@ -3149,15 +3188,16 @@ def _compile_single_variant_varlen(H, D, scale, chunk_size, persistent=True): return compiled_fn -def _get_compiled_kernel_varlen(H, D, scale, chunk_size, persistent=True): +def _get_compiled_kernel_varlen(H, HV, D, scale, chunk_size, persistent=True): """Get a compiled varlen kernel with on-demand compilation. - Cache key: (H, D, scale, chunk_size, persistent, USE_FAST_MATH) + Cache key: (H, HV, D, scale, chunk_size, persistent, USE_FAST_MATH) """ - key = (H, D, scale, chunk_size, persistent, USE_FAST_MATH) + key = (H, HV, D, scale, chunk_size, persistent, USE_FAST_MATH) if key not in _varlen_kernel_cache: _varlen_kernel_cache[key] = _compile_single_variant_varlen( H, + HV, D, scale, chunk_size, @@ -3191,11 +3231,11 @@ def lightning_attn_fwd_varlen( Args: Q: (1, T, H, D) bf16 query — packed tokens from all sequences K: (1, T, H, D) bf16 key - V: (1, T, H, D) bf16 value - decay: (H,) f32 per-head decay coefficients + V: (1, T, HV, D) bf16 value. HV > H enables GVA. + decay: (HV,) f32 per-value-head decay coefficients; (H,) is accepted and expanded cu_seqlens: (N+1,) int32 cumulative sequence lengths scale: attention scale factor (default: 1.0) - state_pool: (pool_size, H, D, D) f32 state pool in BHVK layout, or None + state_pool: (pool_size, HV, D, D) f32 state pool in BHVK layout, or None If None, a zero state pool is allocated with pool_size=N. States are updated in-place (INPLACE_UPDATE). initial_state_indices: (N,) int32 indices into state_pool per sequence. @@ -3203,15 +3243,25 @@ def lightning_attn_fwd_varlen( chunk_size: chunk size (default: 64) Returns: - (O, state_pool): output tensor (1,T,H,D) bf16, updated state pool (pool_size,H,D,D) f32 + (O, state_pool): output tensor (1,T,HV,D) bf16, updated state pool (pool_size,HV,D,D) f32 """ _, T, H, D = Q.shape + if K.shape != Q.shape: + raise ValueError(f"K must have the same shape as Q, got K={tuple(K.shape)}, Q={tuple(Q.shape)}") + if V.ndim != 4 or V.shape[0] != 1 or V.shape[1] != T or V.shape[3] != D: + raise ValueError(f"V must have shape (1, T, HV, D), got {tuple(V.shape)}") + HV = V.shape[2] + if HV < H or HV % H != 0: + raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H") + decay = _normalize_gva_decay(decay, H, HV) N = cu_seqlens.shape[0] - 1 - O = torch.zeros_like(Q) + O = torch.zeros_like(V) # Allocate state pool if not provided if state_pool is None: - state_pool = torch.zeros(N, H, D, D, dtype=torch.float32, device=Q.device) + state_pool = torch.zeros(N, HV, D, D, dtype=torch.float32, device=Q.device) + elif state_pool.ndim != 4 or state_pool.shape[1:] != (HV, D, D): + raise ValueError(f"state_pool must have shape (pool_size, {HV}, {D}, {D}), got {tuple(state_pool.shape)}") # Default indices: identity mapping if initial_state_indices is None: @@ -3221,7 +3271,7 @@ def lightning_attn_fwd_varlen( cu_seqlens = cu_seqlens.to(torch.int32) initial_state_indices = initial_state_indices.to(torch.int32) - compiled_fn = _get_compiled_kernel_varlen(H, D, scale, chunk_size, persistent=persistent) + compiled_fn = _get_compiled_kernel_varlen(H, HV, D, scale, chunk_size, persistent=persistent) # Workspace for persistent kernel atomic counter (zeroed before each call) workspace = torch.zeros(1, dtype=torch.int32, device=Q.device) @@ -3253,6 +3303,7 @@ def main(): parser.add_argument("--batch_size", type=int, default=2, help="Batch size") parser.add_argument("--seq_len", type=int, default=4096, help="Sequence length") parser.add_argument("--num_heads", type=int, default=64, help="Number of heads") + parser.add_argument("--num_v_heads", type=int, default=None, help="Number of value heads (default: num_heads)") parser.add_argument("--head_dim", type=int, default=128, help="Head dimension") parser.add_argument("--chunk_size", type=int, default=64, help="Chunk size") parser.add_argument("--decay", type=float, default=0.95, help="Decay factor") @@ -3267,6 +3318,7 @@ def main(): print(f" Batch size: {args.batch_size}") print(f" Sequence length: {args.seq_len}") print(f" Number of heads: {args.num_heads}") + print(f" Number of value heads: {args.num_v_heads or args.num_heads}") print(f" Head dimension: {args.head_dim}") print(f" Chunk size: {args.chunk_size}") print(f" Decay factor: {args.decay}") @@ -3281,14 +3333,15 @@ def main(): # Create inputs B, S, H, D = args.batch_size, args.seq_len, args.num_heads, args.head_dim + HV = args.num_v_heads if args.num_v_heads is not None else H # Input tensors in format [B, S, H, D] Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) - V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + V = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) - # Per-head decay coefficients [H] - decay = torch.full((H,), args.decay, device="cuda", dtype=torch.float32) + # Per-value-head decay coefficients [HV] + decay = torch.full((HV,), args.decay, device="cuda", dtype=torch.float32) scale = 1.0 / (D**0.5) @@ -3307,7 +3360,7 @@ def main(): compilation_time = time.time() - start_time print(f"Compilation + first run time: {compilation_time:.4f} seconds") - print(f"B, S, H, D: {(B, S, H, D)}") + print(f"B, S, H, HV, D: {(B, S, H, HV, D)}") # Warmup (uses cached kernel — no recompilation) for _ in range(args.warmup_iterations): diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 5e52f86..73f141b 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -66,16 +66,16 @@ def run_cute_kernel( Uses TVM-FFI compile cache: first call per config compiles, subsequent reuse. Args: - Q, K, V: (B, S, H, D) bfloat16 tensors on CUDA - decay: (H,) float32 per-head decay parameter s (s > 0) + Q, K: (B, S, H, D), V: (B, S, HV, D) bfloat16 tensors on CUDA + decay: (HV,) or (H,) float32 per-head decay parameter s (s > 0) scale: attention scale factor chunk_size: chunk size C - initial_state: (B, H, D, D) float32 or None + initial_state: (B, HV, D, D) float32 or None output_final_state: whether to allocate and return final state Returns: - O: (B, S, H, D) bfloat16 output - ht: (B, H, D, D) float32 final state (or None) + O: (B, S, HV, D) bfloat16 output + ht: (B, HV, D, D) float32 final state (or None) """ O, ht = lightning_attn_fwd( Q, @@ -105,15 +105,15 @@ def run_cute_kernel_varlen( """Run the CuTeDSL varlen kernel. Args: - Q, K, V: (1, T, H, D) bfloat16 — packed sequences - decay: (H,) float32 + Q, K: (1, T, H, D), V: (1, T, HV, D) bfloat16 — packed sequences + decay: (HV,) or (H,) float32 cu_seqlens: (N+1,) int32 - state_pool: (pool_size, H, D, D) float32 or None + state_pool: (pool_size, HV, D, D) float32 or None initial_state_indices: (N,) int32 or None Returns: - O: (1, T, H, D) bfloat16 - state_pool: (pool_size, H, D, D) float32 + O: (1, T, HV, D) bfloat16 + state_pool: (pool_size, HV, D, D) float32 """ O, sp = lightning_attn_fwd_varlen( Q, @@ -135,30 +135,52 @@ def run_cute_kernel_varlen( # --------------------------------------------------------------------------- +def _expand_qk_to_value_heads(Q, K, V): + """Expand Q/K from H heads to HV heads for GVA references.""" + H = Q.shape[2] + HV = V.shape[2] + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + if HV == H: + return Q, K + group_size = HV // H + return Q.repeat_interleave(group_size, dim=2), K.repeat_interleave(group_size, dim=2) + + +def _normalize_decay_to_value_heads(decay, H, HV): + if decay.shape[0] == HV: + return decay + if decay.shape[0] == H and HV != H: + return decay.repeat_interleave(HV // H) + raise ValueError(f"decay must have shape ({HV},) or ({H},), got {tuple(decay.shape)}") + + def pytorch_reference(Q, K, V, decay, chunk_size=64, scale=1.0, initial_state=None, output_final_state=False): """PyTorch reference for chunkwise linear attention with exponential decay. Args: - Q, K, V: (B, T, H, D) — any dtype, computed in float32 - decay: (H,) float32 per-head s (s >= 0) + Q, K: (B, T, H, D), V: (B, T, HV, D) — any dtype, computed in float32 + decay: (HV,) or (H,) float32 per-head s (s >= 0) chunk_size: C scale: scalar multiplier applied to final output - initial_state: (B, H, D, D) float32 or None + initial_state: (B, HV, D, D) float32 or None output_final_state: bool Returns: - O: (B, T, H, D) float32 - final_state: (B, H, D, D) float32 or None + O: (B, T, HV, D) float32 + final_state: (B, HV, D, D) float32 or None """ B, T, H, D = Q.shape + HV = V.shape[2] C = chunk_size + Q, K = _expand_qk_to_value_heads(Q, K, V) + decay = _normalize_decay_to_value_heads(decay, H, HV) Q, K, V = Q.float(), K.float(), V.float() - O = torch.zeros(B, T, H, D, device=Q.device, dtype=torch.float32) + O = torch.zeros(B, T, HV, D, device=Q.device, dtype=torch.float32) state = ( initial_state.clone().float() if initial_state is not None - else torch.zeros(B, H, D, D, device=Q.device, dtype=torch.float32) + else torch.zeros(B, HV, D, D, device=Q.device, dtype=torch.float32) ) num_chunks = (T + C - 1) // C @@ -176,21 +198,21 @@ def pytorch_reference(Q, K, V, decay, chunk_size=64, scale=1.0, initial_state=No pos_k = torch.arange(cl, device=Q.device).view(1, cl) dist = pos_q - pos_k # (cl, cl) - s = decay.view(1, H, 1, 1) + s = decay.view(1, HV, 1, 1) mask = torch.exp(-s * dist.unsqueeze(0).unsqueeze(0).float()) mask = mask * (pos_q >= pos_k).unsqueeze(0).unsqueeze(0).float() O_intra = torch.einsum("bhts,bshd->bthd", QK * mask, Vc) # --- inter-chunk: Q @ state with per-position decay --- pos_in = torch.arange(cl, device=Q.device).float() - per_pos = torch.exp(-decay.view(1, 1, H, 1) * (pos_in.view(1, -1, 1, 1) + 1.0)) + per_pos = torch.exp(-decay.view(1, 1, HV, 1) * (pos_in.view(1, -1, 1, 1) + 1.0)) O_inter = torch.einsum("bthd,bhde->bthe", Qc, state) * per_pos O[:, cs:ce] = (O_intra + O_inter) * scale # --- state update --- - block_decay = torch.exp(-decay.view(1, H, 1, 1) * C) - pos_w = torch.exp(-decay.view(1, 1, H, 1) * (C - 1 - pos_in.view(1, -1, 1, 1))) + block_decay = torch.exp(-decay.view(1, HV, 1, 1) * C) + pos_w = torch.exp(-decay.view(1, 1, HV, 1) * (C - 1 - pos_in.view(1, -1, 1, 1))) state = state * block_decay + torch.einsum("bthd,bthe->bhde", Kc * pos_w, Vc) return O, (state if output_final_state else None) @@ -270,16 +292,17 @@ def test_different_decay_values(): return False -def test_against_reference(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): +def test_against_reference(B=1, S=128, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): """Compare against PyTorch reference (exact match).""" + HV = H if HV is None else HV if verbose: - print(f"\nRef: B={B}, S={S}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nRef: B={B}, S={S}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + V = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) O_ref, _ = pytorch_reference(Q, K, V, decay, chunk_size=C) O_ref_bf16 = O_ref.to(torch.bfloat16) @@ -291,22 +314,23 @@ def test_against_reference(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e- return passed -def test_initial_and_final_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): +def test_initial_and_final_state(B=1, S=128, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): """Test h0/ht against PyTorch reference. NOTE: This test is placed BEFORE FLA tests so that the (has_initial_state=True, output_final_state=True) kernel variant is compiled before any Triton/FLA code runs. Running Triton corrupts state needed by cute.compile. """ + HV = H if HV is None else HV if verbose: - print(f"\nh0/ht: B={B}, S={S}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nh0/ht: B={B}, S={S}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) - h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01 + V = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) + h0 = torch.randn(B, HV, D, D, device="cuda", dtype=torch.float32) * 0.01 h0_vk = h0.transpose(-1, -2).contiguous() # BHVK for CuTe kernel O_ref, ht_ref = pytorch_reference( @@ -340,7 +364,7 @@ def test_initial_and_final_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, at return passed -def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): +def test_against_fla(B=1, S=128, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): """Compare against FLA chunk_simple_gla using g_gamma = -s. FLA's g_gamma is the per-head log-decay (negative). Our decay parameter s @@ -350,20 +374,22 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto print("\n ⊘ SKIPPED: fla library not available") return True + HV = H if HV is None else HV if verbose: - print(f"\nFLA: B={B}, S={S}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nFLA: B={B}, S={S}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 + V = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + Q_fla, K_fla = _expand_qk_to_value_heads(Q, K, V) # Our decay s -> FLA g_gamma = -s - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) g_gamma = -decay # FLA reference (scale=1.0 to match our kernel) - O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0) + O_fla, _ = chunk_simple_gla(Q_fla, K_fla, V, g_gamma=g_gamma, scale=1.0) # Our kernel O_cute, _ = run_cute_kernel(Q, K, V, decay, scale=1.0, chunk_size=C) @@ -383,29 +409,31 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto return passed -def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): +def test_against_fla_with_state(B=1, S=128, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True): """Compare h0/ht against FLA chunk_simple_gla.""" if not HAS_FLA: print("\n ⊘ SKIPPED: fla library not available") return True + HV = H if HV is None else HV if verbose: - print(f"\nFLA h0/ht: B={B}, S={S}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nFLA h0/ht: B={B}, S={S}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - h0 = torch.randn(B, H, D, D, device="cuda", dtype=torch.float32) * 0.01 + V = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + Q_fla, K_fla = _expand_qk_to_value_heads(Q, K, V) + h0 = torch.randn(B, HV, D, D, device="cuda", dtype=torch.float32) * 0.01 h0_vk = h0.transpose(-1, -2).contiguous() # BHVK for CuTe kernel - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) g_gamma = -decay # FLA (expects BHKV state) O_fla, ht_fla = chunk_simple_gla( - Q, - K, + Q_fla, + K_fla, V, g_gamma=g_gamma, scale=1.0, @@ -440,16 +468,17 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato # =========================================================================== -def test_varlen_single_seq(H=4, S=128, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: +def test_varlen_single_seq(H=4, HV=None, S=128, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: """Varlen with a single sequence vs non-varlen reference.""" + HV = H if HV is None else HV if verbose: - print(f"\nVarlen single: S={S}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nVarlen single: S={S}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) Q = torch.randn(1, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(1, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(1, S, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + V = torch.randn(1, S, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) # Non-varlen reference O_ref, ht_ref = run_cute_kernel(Q, K, V, decay, chunk_size=C, output_final_state=True) @@ -465,12 +494,13 @@ def test_varlen_single_seq(H=4, S=128, D=128, C=64, decay_val=0.1, atol=5e-3, rt return passed -def test_varlen_multi_seq(seq_lens=None, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: +def test_varlen_multi_seq(seq_lens=None, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: """Varlen with multiple packed sequences vs per-sequence non-varlen reference.""" + HV = H if HV is None else HV if seq_lens is None: seq_lens = [128, 64, 192] # all multiples of C if verbose: - print(f"\nVarlen multi: seqs={seq_lens}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nVarlen multi: seqs={seq_lens}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) T = sum(seq_lens) @@ -482,8 +512,8 @@ def test_varlen_multi_seq(seq_lens=None, H=4, D=128, C=64, decay_val=0.1, atol=5 Q = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + V = torch.randn(1, T, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) O_var, sp = run_cute_kernel_varlen(Q, K, V, decay, cu_seqlens, chunk_size=C) @@ -504,12 +534,13 @@ def test_varlen_multi_seq(seq_lens=None, H=4, D=128, C=64, decay_val=0.1, atol=5 return all_pass -def test_varlen_with_initial_state(seq_lens=None, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: +def test_varlen_with_initial_state(seq_lens=None, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True) -> bool: """Varlen with initial state from state pool (non-contiguous indices).""" + HV = H if HV is None else HV if seq_lens is None: seq_lens = [128, 64] if verbose: - print(f"\nVarlen h0: seqs={seq_lens}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nVarlen h0: seqs={seq_lens}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) T = sum(seq_lens) @@ -521,12 +552,12 @@ def test_varlen_with_initial_state(seq_lens=None, H=4, D=128, C=64, decay_val=0. Q = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + V = torch.randn(1, T, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) # State pool with 3 slots, use indices [2, 0] — BHVK layout for CuTe pool_size = 3 - state_pool = torch.randn(pool_size, H, D, D, dtype=torch.float32, device="cuda").transpose(-1, -2).contiguous() * 0.01 + state_pool = torch.randn(pool_size, HV, D, D, dtype=torch.float32, device="cuda").transpose(-1, -2).contiguous() * 0.01 indices = torch.tensor([2, 0], dtype=torch.int32, device="cuda") O_var, sp = run_cute_kernel_varlen( @@ -569,13 +600,14 @@ def test_varlen_with_initial_state(seq_lens=None, H=4, D=128, C=64, decay_val=0. def test_varlen_against_pytorch_ref( - seq_lens=None, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True + seq_lens=None, H=4, HV=None, D=128, C=64, decay_val=0.1, atol=5e-3, rtol=5e-2, verbose=True ) -> bool: """Varlen against the PyTorch reference with initial state.""" + HV = H if HV is None else HV if seq_lens is None: seq_lens = [128, 192] if verbose: - print(f"\nVarlen vs ref: seqs={seq_lens}, H={H}, D={D}, C={C}, decay={decay_val}") + print(f"\nVarlen vs ref: seqs={seq_lens}, H={H}, HV={HV}, D={D}, C={C}, decay={decay_val}") torch.manual_seed(42) T = sum(seq_lens) @@ -588,10 +620,10 @@ def test_varlen_against_pytorch_ref( Q = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 K = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - V = torch.randn(1, T, H, D, device="cuda", dtype=torch.bfloat16) * 0.1 - decay = torch.full((H,), decay_val, device="cuda", dtype=torch.float32) + V = torch.randn(1, T, HV, D, device="cuda", dtype=torch.bfloat16) * 0.1 + decay = torch.full((HV,), decay_val, device="cuda", dtype=torch.float32) - state_pool = torch.randn(N, H, D, D, dtype=torch.float32, device="cuda") * 0.01 + state_pool = torch.randn(N, HV, D, D, dtype=torch.float32, device="cuda") * 0.01 state_pool_vk = state_pool.transpose(-1, -2).contiguous() # BHVK for CuTe O_var, sp = run_cute_kernel_varlen( @@ -675,6 +707,7 @@ def main(): ("Decay 0.2", dict(B=1, S=128, H=4, D=128, C=64, decay_val=0.2)), ("Decay 0.5", dict(B=1, S=128, H=4, D=128, C=64, decay_val=0.5)), ("Batch", dict(B=2, S=128, H=4, D=128, C=64, decay_val=0.1)), + ("GVA H2-HV4", dict(B=1, S=128, H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"Ref {tag}", test_against_reference(**kw, verbose=args.verbose))) @@ -694,6 +727,7 @@ def main(): ("Decay 0.2", dict(B=1, S=128, H=4, D=128, C=64, decay_val=0.2)), ("Decay 0.5", dict(B=1, S=128, H=4, D=128, C=64, decay_val=0.5)), ("Batch", dict(B=2, S=128, H=4, D=128, C=64, decay_val=0.1)), + ("GVA H2-HV4", dict(B=1, S=128, H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"FLA {tag}", test_against_fla(**kw, verbose=args.verbose))) @@ -706,6 +740,7 @@ def main(): ("Small", dict(B=1, S=64, H=2, D=128, C=64, decay_val=0.1)), ("Multi-chunk", dict(B=1, S=256, H=4, D=128, C=64, decay_val=0.1)), ("Batch", dict(B=2, S=128, H=4, D=128, C=64, decay_val=0.2)), + ("GVA H2-HV4", dict(B=1, S=128, H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"h0/ht {tag}", test_initial_and_final_state(**kw, verbose=args.verbose))) @@ -719,6 +754,7 @@ def main(): ("Small", dict(B=1, S=64, H=2, D=128, C=64, decay_val=0.1)), ("Multi-chunk", dict(B=1, S=256, H=4, D=128, C=64, decay_val=0.1)), ("Batch", dict(B=2, S=128, H=4, D=128, C=64, decay_val=0.2)), + ("GVA H2-HV4", dict(B=1, S=128, H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"FLA h0/ht {tag}", test_against_fla_with_state(**kw, verbose=args.verbose))) @@ -731,6 +767,7 @@ def main(): ("Single seq", dict(H=4, S=128, D=128, C=64, decay_val=0.1)), ("Single long", dict(H=4, S=256, D=128, C=64, decay_val=0.1)), ("Decay 0.5", dict(H=4, S=128, D=128, C=64, decay_val=0.5)), + ("GVA single", dict(H=2, HV=4, S=128, D=128, C=64, decay_val=0.1)), ]: results.append((f"Varlen {tag}", test_varlen_single_seq(**kw, verbose=args.verbose))) @@ -738,12 +775,14 @@ def main(): ("Multi 3-seq", dict(seq_lens=[128, 64, 192], H=4, D=128, C=64, decay_val=0.1)), ("Multi 2-seq", dict(seq_lens=[256, 128], H=4, D=128, C=64, decay_val=0.1)), ("Multi decay", dict(seq_lens=[128, 128], H=4, D=128, C=64, decay_val=0.5)), + ("GVA multi", dict(seq_lens=[128, 64], H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"Varlen {tag}", test_varlen_multi_seq(**kw, verbose=args.verbose))) for tag, kw in [ ("h0 indirect", dict(seq_lens=[128, 64], H=4, D=128, C=64, decay_val=0.1)), ("h0 decay 0.2", dict(seq_lens=[128, 64], H=4, D=128, C=64, decay_val=0.2)), + ("GVA h0", dict(seq_lens=[128, 64], H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"Varlen {tag}", test_varlen_with_initial_state(**kw, verbose=args.verbose))) @@ -753,6 +792,7 @@ def main(): for tag, kw in [ ("vs ref 2-seq", dict(seq_lens=[128, 192], H=4, D=128, C=64, decay_val=0.1)), ("vs ref decay", dict(seq_lens=[64, 128], H=4, D=128, C=64, decay_val=0.5)), + ("GVA vs ref", dict(seq_lens=[64, 128], H=2, HV=4, D=128, C=64, decay_val=0.1)), ]: results.append((f"Varlen {tag}", test_varlen_against_pytorch_ref(**kw, verbose=args.verbose))) From cf56cf40147dc7d7cab23f3dac13fdb622a64c8f Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 28 May 2026 18:48:03 +0800 Subject: [PATCH 2/2] support gva --- benchmarks/bench_la_decode_vs_fla.py | 64 ++++++++------- cula/ops/la_decode.py | 90 +++++++++++++-------- tests/test_la_decode.py | 113 +++++++++++++++++++++------ tests/test_la_decode_pool.py | 53 +++++++++++-- 4 files changed, 231 insertions(+), 89 deletions(-) diff --git a/benchmarks/bench_la_decode_vs_fla.py b/benchmarks/bench_la_decode_vs_fla.py index 27eafbe..ff1e20e 100644 --- a/benchmarks/bench_la_decode_vs_fla.py +++ b/benchmarks/bench_la_decode_vs_fla.py @@ -38,6 +38,7 @@ Usage: python benchmarks/bench_la_decode_vs_fla.py python benchmarks/bench_la_decode_vs_fla.py --heads 64 --head-dim 128 + python benchmarks/bench_la_decode_vs_fla.py --heads 32 --num-v-heads 64 python benchmarks/bench_la_decode_vs_fla.py --batch-sizes 1 8 64 256 """ @@ -63,28 +64,32 @@ # ───────────────────────────────────────────────────────────────────────────── # Core benchmark for one configuration # ───────────────────────────────────────────────────────────────────────────── -def run_config(B, H, K, V, layer_idx, num_layers): +def run_config(B, H, HV, K, V, layer_idx, num_layers): + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" device = "cuda" dtype = torch.bfloat16 scale = K**-0.5 + group_size = HV // H # Per-head log decay (Lightning Attention formula) - g_gamma = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, device=device, dtype=torch.float32) + g_gamma = -(8 / HV * (1 - layer_idx / num_layers)) * torch.arange(HV, device=device, dtype=torch.float32) decay_scales = -g_gamma # la_decode convention # ── Random inputs ────────────────────────────────────────────────────── torch.manual_seed(42) q_4d = torch.randn(B, 1, H, K, device=device, dtype=dtype) k_4d = torch.randn(B, 1, H, K, device=device, dtype=dtype) - v_4d = torch.randn(B, 1, H, V, device=device, dtype=dtype) - state_init = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01 + v_4d = torch.randn(B, 1, HV, V, device=device, dtype=dtype) + state_init = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01 + q_fla_4d = q_4d.repeat_interleave(group_size, dim=2) + k_fla_4d = k_4d.repeat_interleave(group_size, dim=2) # ── fla reference output ─────────────────────────────────────────────── state_fla = state_init.clone() with torch.no_grad(): o_fla_fp32, ht_fla = fused_recurrent_fwd( - q_4d, - k_4d, + q_fla_4d, + k_fla_4d, v_4d, g_gamma=g_gamma, scale=scale, @@ -94,11 +99,11 @@ def run_config(B, H, K, V, layer_idx, num_layers): o_fla = o_fla_fp32.to(dtype) # ── la_decode output ─────────────────────────────────────────────────── - state_cute = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous() + state_cute = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous() q_3d = q_4d.squeeze(1) k_3d = k_4d.squeeze(1) v_3d = v_4d.squeeze(1) - out_cute = torch.zeros(B, H, V, device=device, dtype=dtype) + out_cute = torch.zeros(B, HV, V, device=device, dtype=dtype) s_offsets = torch.arange(B, device=device, dtype=torch.int32) with torch.no_grad(): @@ -128,7 +133,7 @@ def run_config(B, H, K, V, layer_idx, num_layers): max_ref = torch.abs(o_fla_cmp).max().item() rel_maxdiff = torch.abs(o_cute_cmp - o_fla_cmp).max().item() / (max_ref + 1e-8) - state_cute_back = state_cute.reshape(B, H, V, K).permute(0, 1, 3, 2).contiguous() + state_cute_back = state_cute.reshape(B, HV, V, K).permute(0, 1, 3, 2).contiguous() state_relative_rms_error = relative_rms_error(ht_fla, state_cute_back) # ================================================================== @@ -140,16 +145,16 @@ def run_config(B, H, K, V, layer_idx, num_layers): BV_fla = min(triton.next_power_of_2(V), 64) NK = triton.cdiv(K, BK_fla) NV = triton.cdiv(V, BV_fla) - fla_o_buf = torch.empty(NK, B, 1, H, V, device=device, dtype=torch.float32) - fla_ht_buf = torch.empty(B, H, K, V, device=device, dtype=torch.float32) - fla_o_sum = torch.empty(B, 1, H, V, device=device, dtype=torch.float32) + fla_o_buf = torch.empty(NK, B, 1, HV, V, device=device, dtype=torch.float32) + fla_ht_buf = torch.empty(B, HV, K, V, device=device, dtype=torch.float32) fla_state_k = state_init.clone() - grid_fla = (NV, NK, B * H) + fla_o_sum = torch.empty(B, 1, HV, V, device=device, dtype=torch.float32) + grid_fla = (NV, NK, B * HV) def kernel_fla(): fused_recurrent_fwd_kernel[grid_fla]( - q=q_4d, - k=k_4d, + q=q_fla_4d, + k=k_fla_4d, v=v_4d, g=None, g_gamma=g_gamma, @@ -162,7 +167,7 @@ def kernel_fla(): scale=scale, B=B, T=1, - H=H, + H=HV, K=K, V=V, BK=BK_fla, @@ -176,9 +181,9 @@ def kernel_fla(): torch.sum(fla_o_buf, dim=0, out=fla_o_sum) # cute kernel: pre-create compiled + stream handle - cute_state_k = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous() - out_cute_k = torch.empty(B, H, V, device=device, dtype=dtype) - cache = _get_compiled_kernel(B, 1, H, K, V, cute_state_k.shape[0], scale, USE_FAST_MATH) + cute_state_k = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous() + out_cute_k = torch.empty(B, HV, V, device=device, dtype=dtype) + cache = _get_compiled_kernel(B, 1, H, HV, K, V, cute_state_k.shape[0], scale, USE_FAST_MATH) compiled_cute = cache["compiled"] stream_handle = cuda_drv.CUstream(torch.cuda.current_stream().cuda_stream) @@ -193,13 +198,13 @@ def kernel_cute(): # Mode 2: WRAPPER (full call path as used in production) # ================================================================== wrap_fla_state = state_init.clone() - wrap_cute_state = state_init.clone().permute(0, 1, 3, 2).reshape(B * H, V, K).contiguous() - wrap_cute_out = torch.empty(B, H, V, device=device, dtype=dtype) + wrap_cute_state = state_init.clone().permute(0, 1, 3, 2).reshape(B * HV, V, K).contiguous() + wrap_cute_out = torch.empty(B, HV, V, device=device, dtype=dtype) def wrapper_fla(): fused_recurrent_fwd( - q_4d, - k_4d, + q_fla_4d, + k_fla_4d, v_4d, g_gamma=g_gamma, scale=scale, @@ -233,6 +238,8 @@ def wrapper_cute(): return { "B": B, + "H": H, + "HV": HV, "kernel_fla_ms": kernel_fla_ms, "kernel_cute_ms": kernel_cute_ms, "kernel_speedup": kernel_fla_ms / kernel_cute_ms, @@ -257,16 +264,19 @@ def main(): default=[1, 2, 4, 8, 16, 32, 64, 128, 256], ) parser.add_argument("--heads", type=int, default=32) + parser.add_argument("--num-v-heads", type=int, default=None, help="Number of value heads (default: same as --heads)") parser.add_argument("--head-dim", type=int, default=128) parser.add_argument("--layer-idx", type=int, default=12) parser.add_argument("--num-layers", type=int, default=24) args = parser.parse_args() - H, K, V = args.heads, args.head_dim, args.head_dim + H, HV, K, V = args.heads, args.num_v_heads or args.heads, args.head_dim, args.head_dim + if HV < H or HV % H != 0: + raise ValueError(f"num_v_heads ({HV}) must be >= heads ({H}) and divisible by heads") print("Lightning Attention Decode Benchmark") print(" la_decode (CuTe DSL) vs fla fused_recurrent_fwd (Triton)") - print(f" H={H}, K={K}, V={V}, layer={args.layer_idx}/{args.num_layers}") + print(f" H={H}, HV={HV}, K={K}, V={V}, layer={args.layer_idx}/{args.num_layers}") print(" dtype=bf16, state=fp32, T=1") # ── Kernel-only comparison ────────────────────────────────────────── @@ -282,7 +292,7 @@ def main(): results = [] for B in args.batch_sizes: - r = run_config(B, H, K, V, args.layer_idx, args.num_layers) + r = run_config(B, H, HV, K, V, args.layer_idx, args.num_layers) results.append(r) print( f"{r['B']:>5} | {r['kernel_fla_ms']:>10.4f} | {r['kernel_cute_ms']:>10.4f} | " @@ -293,7 +303,7 @@ def main(): # ── Wrapper comparison ────────────────────────────────────────────── print(f"\n{'=' * 100}") print(" Mode 2: WRAPPER (fused_recurrent_fwd vs linear_attention_decode, full call path)") - print(" fla: alloc o[NK,B,1,H,V]+ht[B,H,K,V] + kernel + sum(0); cute: cache lookup + CUstream + kernel") + print(" fla: alloc o[NK,B,1,HV,V]+ht[B,HV,K,V] + kernel + sum(0); cute: cache lookup + CUstream + kernel") print(f"{'=' * 100}") print(f"{'B':>5} | {'fla (ms)':>10} | {'cute (ms)':>10} | {'speedup':>8}") print("─" * 50) diff --git a/cula/ops/la_decode.py b/cula/ops/la_decode.py index 08831df..6ab1b46 100644 --- a/cula/ops/la_decode.py +++ b/cula/ops/la_decode.py @@ -76,7 +76,7 @@ def la_decode_kernel_small_batch_pretranspose( smem_layout_staged: cute.Layout, vec_size: cutlass.Constexpr[int], num_v_tiles: cutlass.Constexpr[int], - decay_scales: cute.Tensor, # [H] + decay_scales: cute.Tensor, # [HV] q: cute.Tensor, # [B, T, H, K] k: cute.Tensor, # [B, T, H, K] v: cute.Tensor, # [B, T, HV, V] @@ -86,6 +86,7 @@ def la_decode_kernel_small_batch_pretranspose( B: cutlass.Constexpr[int], T: cutlass.Constexpr[int], H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], K: cutlass.Constexpr[int], V: cutlass.Constexpr[int], NUM_WARPS: cutlass.Constexpr[int] = 4, @@ -94,7 +95,6 @@ def la_decode_kernel_small_batch_pretranspose( ): """Each block uses pipeline to load one batch and vectorized writeback""" - HV = H tidx, _, _ = cute.arch.thread_idx() lane_id = tidx % 32 warp_idx = cute.arch.warp_idx() @@ -121,7 +121,7 @@ def la_decode_kernel_small_batch_pretranspose( r_q = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) r_v = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) r_h = cute.make_rmem_tensor(cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32) - r_decay_scale = -cutlass.Float32(decay_scales[i_h]) + r_decay_scale = -cutlass.Float32(decay_scales[i_hv]) r_decay = cute.exp(r_decay_scale, fastmath=USE_FAST_MATH) cute.arch.barrier() @@ -244,7 +244,7 @@ def la_decode_kernel_big_batch_pretranspose( smem_layout_staged: cute.Layout, vec_size: cutlass.Constexpr[int], num_v_tiles: cutlass.Constexpr[int], - decay_scales: cute.Tensor, # [H] + decay_scales: cute.Tensor, # [HV] q: cute.Tensor, # [B, T, H, K] k: cute.Tensor, # [B, T, H, K] v: cute.Tensor, # [B, T, HV, V] @@ -254,6 +254,7 @@ def la_decode_kernel_big_batch_pretranspose( B: cutlass.Constexpr[int], T: cutlass.Constexpr[int], H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], K: cutlass.Constexpr[int], V: cutlass.Constexpr[int], NUM_WARPS: cutlass.Constexpr[int] = 4, @@ -262,7 +263,6 @@ def la_decode_kernel_big_batch_pretranspose( ): """Each block uses pipeline to load one batch and vectorized writeback""" - HV = H tidx, _, _ = cute.arch.thread_idx() lane_id = tidx % 32 warp_idx = cute.arch.warp_idx() @@ -330,7 +330,7 @@ def la_decode_kernel_big_batch_pretranspose( for i in cutlass.range_constexpr(vec_size): r_q[i] = r_q[i] * scale - r_g = cute.exp(-cutlass.Float32(decay_scales[i_h]), fastmath=USE_FAST_MATH) + r_g = cute.exp(-cutlass.Float32(decay_scales[i_hv]), fastmath=USE_FAST_MATH) # =================================================================== # Mainloop: All threads participate @@ -404,8 +404,8 @@ def la_decode_kernel_big_batch_pretranspose( @cute.jit def run_la_decode_kernel_big_batch_pretranspose( - h0_source: cute.Tensor, # [B*H, V, K] - decay_scales: cute.Tensor, # [H] + h0_source: cute.Tensor, # [pool_size*HV, V, K] + decay_scales: cute.Tensor, # [HV] q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, @@ -413,13 +413,14 @@ def run_la_decode_kernel_big_batch_pretranspose( h0_indices: cute.Tensor, softmax_scale: cutlass.Constexpr[float], H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], B: cutlass.Constexpr[int], T: cutlass.Constexpr[int], K: cutlass.Constexpr[int], V: cutlass.Constexpr[int], stream: cuda.CUstream, ): - # h0_source: (B*HV, V, K) + # h0_source: (pool_size*HV, V, K) _pool_dim0, v_dim, _k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], @@ -473,13 +474,14 @@ def run_la_decode_kernel_big_batch_pretranspose( B, T, H, + HV, K, V, NUM_WARPS_BIG, TILE_V_BIG, NUM_STAGES_BIG, ).launch( - grid=(B * H, 1, 1), + grid=(B * HV, 1, 1), block=[NUM_THREADS_BIG, 1, 1], smem=smem_bytes, stream=stream, @@ -488,8 +490,8 @@ def run_la_decode_kernel_big_batch_pretranspose( @cute.jit def run_la_decode_kernel_small_batch_pretranspose( - h0_source: cute.Tensor, # [B*H, V, K] - decay_scales: cute.Tensor, # [H] + h0_source: cute.Tensor, # [pool_size*HV, V, K] + decay_scales: cute.Tensor, # [HV] q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, @@ -497,13 +499,14 @@ def run_la_decode_kernel_small_batch_pretranspose( h0_indices: cute.Tensor, softmax_scale: cutlass.Constexpr[float], H: cutlass.Constexpr[int], + HV: cutlass.Constexpr[int], B: cutlass.Constexpr[int], T: cutlass.Constexpr[int], K: cutlass.Constexpr[int], V: cutlass.Constexpr[int], stream: cuda.CUstream, ): - # h0_source: (B*H, V, K) + # h0_source: (pool_size*HV, V, K) _pool_dim0, v_dim, _k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], @@ -557,13 +560,14 @@ def run_la_decode_kernel_small_batch_pretranspose( B, T, H, + HV, K, V, NUM_WARPS_SMALL, TILE_V_SMALL, NUM_STAGES_SMALL, ).launch( - grid=(B * H * NUM_BLOCKS_PER_STATE, 1, 1), + grid=(B * HV * NUM_BLOCKS_PER_STATE, 1, 1), block=[NUM_THREADS_SMALL, 1, 1], smem=smem_bytes, stream=stream, @@ -572,18 +576,18 @@ def run_la_decode_kernel_small_batch_pretranspose( @functools.cache def _get_compiled_kernel( - B: int, T: int, H: int, K: int, V: int, pool_dim0: int, softmax_scale: float, use_fast_math: bool = True + B: int, T: int, H: int, HV: int, K: int, V: int, pool_dim0: int, softmax_scale: float, use_fast_math: bool = True ): """Get or create compiled kernel cache.""" return {} def linear_attention_decode( - q: torch.Tensor, # [B, 1, H, HEAD_DIM], same as [B, 1, H, K] - k: torch.Tensor, # [B, 1, H, HEAD_DIM], same as [B, 1, H, K] - v: torch.Tensor, # [B, 1, H, HEAD_DIM], same as [B, 1, H, V] - s: torch.Tensor, # [pool_size, heads, V, K] - out: torch.Tensor, # [B, 1, H, HEAD_DIM] + q: torch.Tensor, # [B, H, HEAD_DIM], same as [B, H, K] + k: torch.Tensor, # [B, H, HEAD_DIM], same as [B, H, K] + v: torch.Tensor, # [B, HV, HEAD_DIM], same as [B, HV, V] + s: torch.Tensor, # [pool_size * HV, V, K] + out: torch.Tensor, # [B, HV, HEAD_DIM] softmax_scale: float, stride_q: int, stride_k: int, @@ -591,7 +595,7 @@ def linear_attention_decode( stride_s: int, stride_o: int, s_offsets: torch.Tensor, # [B] - state pool indices - decay_scales: torch.Tensor, # [H] + decay_scales: torch.Tensor, # [HV] or [H] HEAD_DIM: int, K_SPLIT_DIM: int, V_SPLIT_DIM: int, @@ -603,9 +607,9 @@ def linear_attention_decode( Args: q: Query tensor [B, H, HEAD_DIM] k: Key tensor [B, H, HEAD_DIM] - v: Value tensor [B, H, HEAD_DIM] - s: State pool tensor [pool_size, heads, K*V] - out: Output tensor [k_dim_block, length, heads, HEAD_DIM] + v: Value tensor [B, HV, HEAD_DIM] + s: State pool tensor [pool_size * HV, V, K] in BHVK layout + out: Output tensor [B, HV, HEAD_DIM] softmax_scale: Softmax scale factor stride_q: Stride of q tensor stride_k: Stride of k tensor @@ -613,7 +617,7 @@ def linear_attention_decode( stride_s: Stride of s tensor stride_o: Stride of out tensor s_offsets: State pool indices [B] - decay_scales: Decay scales per head [H] + decay_scales: Decay scales per value head [HV]. A [H] tensor is accepted and expanded. HEAD_DIM: Head dimension K_SPLIT_DIM: K split dimension (must be HEAD_DIM for no split) V_SPLIT_DIM: V split dimension (must be HEAD_DIM for no split) @@ -621,8 +625,27 @@ def linear_attention_decode( Returns: None (modifies out and s in-place) """ + if q.ndim != 3 or q.shape[2] != HEAD_DIM: + raise ValueError(f"q must have shape (B, H, HEAD_DIM), got {tuple(q.shape)}") + if k.shape != q.shape: + raise ValueError(f"k must have the same shape as q, got k={tuple(k.shape)}, q={tuple(q.shape)}") B = q.shape[0] H = q.shape[1] + if v.ndim != 3 or v.shape[0] != B or v.shape[2] != HEAD_DIM: + raise ValueError(f"v must have shape (B, HV, HEAD_DIM), got {tuple(v.shape)}") + HV = v.shape[1] + if out.shape != (B, HV, HEAD_DIM): + raise ValueError(f"out must have shape {(B, HV, HEAD_DIM)}, got {tuple(out.shape)}") + if HV < H or HV % H != 0: + raise ValueError(f"HV ({HV}) must be >= H ({H}) and divisible by H") + if decay_scales.ndim != 1: + raise ValueError(f"decay_scales must be 1D, got {tuple(decay_scales.shape)}") + if decay_scales.shape[0] == H and HV != H: + decay_scales = decay_scales.repeat_interleave(HV // H).contiguous() + elif decay_scales.shape[0] == HV: + decay_scales = decay_scales.contiguous() + else: + raise ValueError(f"decay_scales must have shape ({HV},) or ({H},), got {tuple(decay_scales.shape)}") k_dim_block = HEAD_DIM // K_SPLIT_DIM if k_dim_block > 1: @@ -630,13 +653,13 @@ def linear_attention_decode( # Get compiled kernel (cached) pool_dim0 = s.shape[0] - cache_key = (B, 1, H, HEAD_DIM, HEAD_DIM, pool_dim0, softmax_scale, USE_FAST_MATH) + cache_key = (B, 1, H, HV, HEAD_DIM, HEAD_DIM, pool_dim0, softmax_scale, USE_FAST_MATH) cache = _get_compiled_kernel(*cache_key) h0_source = s # Validate state pool dimensions - assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}" + assert s.shape[0] % HV == 0, f"s.shape[0] must be divisible by HV={HV}, got {s.shape[0]}" # First-time compilation if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -674,6 +697,7 @@ def linear_attention_decode( h0_idx_tensor, softmax_scale=softmax_scale, H=H, + HV=HV, B=B, T=1, K=HEAD_DIM, @@ -690,11 +714,11 @@ def linear_attention_decode( def seg_la_d_kernel_cute( - q: torch.Tensor, # [B, 1, heads, HEAD_DIM] - k: torch.Tensor, # [B, 1, heads, HEAD_DIM] - v: torch.Tensor, # [B, 1, heads, HEAD_DIM] - s: torch.Tensor, # [pool_size, heads, K*V] - out: torch.Tensor, # [B, 1, heads, HEAD_DIM] + q: torch.Tensor, # [B, H, HEAD_DIM] + k: torch.Tensor, # [B, H, HEAD_DIM] + v: torch.Tensor, # [B, HV, HEAD_DIM] + s: torch.Tensor, # [pool_size * HV, V, K] + out: torch.Tensor, # [B, HV, HEAD_DIM] softmax_scale: float, stride_q: int, stride_k: int, @@ -702,7 +726,7 @@ def seg_la_d_kernel_cute( stride_s: int, stride_o: int, s_offsets: torch.Tensor, # [B] - state pool indices - decay_scales: torch.Tensor, # [H] + decay_scales: torch.Tensor, # [HV] or [H] HEAD_DIM: int, K_SPLIT_DIM: int, V_SPLIT_DIM: int, diff --git a/tests/test_la_decode.py b/tests/test_la_decode.py index 5b57ac5..5a22a6b 100644 --- a/tests/test_la_decode.py +++ b/tests/test_la_decode.py @@ -48,45 +48,59 @@ def torch_la_decode_ref(q, k, v, state, decay_scales, scale): Pure PyTorch reference for single-token linear attention decode. Args: - q, k, v: [B, H, D] bf16 - state: [B, H, D, D] fp32 (K x V layout) - decay_scales: [H] fp32 (positive values; kernel does exp(-decay)) + q, k: [B, H, D] bf16 + v: [B, HV, D] bf16 + state: [B, HV, D, D] fp32 (K x V layout) + decay_scales: [HV] or [H] fp32 (positive values; kernel does exp(-decay)) scale: float Returns: - o: [B, H, D] bf16 - state_new: [B, H, D, D] fp32 + o: [B, HV, D] bf16 + state_new: [B, HV, D, D] fp32 """ B, H, D = q.shape + HV = v.shape[1] + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + group_size = HV // H + if group_size > 1: + q = q.repeat_interleave(group_size, dim=1) + k = k.repeat_interleave(group_size, dim=1) + if decay_scales.shape[0] == H and HV != H: + decay_scales = decay_scales.repeat_interleave(group_size) + q_f = q.float() * scale k_f = k.float() v_f = v.float() - decay = torch.exp(-decay_scales).view(1, H, 1, 1) # [1, H, 1, 1] - state_new = state * decay + k_f.unsqueeze(-1) * v_f.unsqueeze(-2) # [B,H,D,D] - o = torch.einsum("bhk,bhkv->bhv", q_f, state_new) # [B,H,D] + decay = torch.exp(-decay_scales).view(1, HV, 1, 1) # [1, HV, 1, 1] + state_new = state * decay + k_f.unsqueeze(-1) * v_f.unsqueeze(-2) # [B,HV,D,D] + o = torch.einsum("bhk,bhkv->bhv", q_f, state_new) # [B,HV,D] return o.to(torch.bfloat16), state_new # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def make_inputs(B, H, D, device="cuda", seed=42): +def make_inputs(B, H, D, HV=None, device="cuda", seed=42): torch.manual_seed(seed) + HV = H if HV is None else HV q = torch.randn(B, H, D, device=device, dtype=torch.bfloat16) k = torch.randn(B, H, D, device=device, dtype=torch.bfloat16) - v = torch.randn(B, H, D, device=device, dtype=torch.bfloat16) - state = torch.randn(B, H, D, D, device=device, dtype=torch.float32) * 0.01 + v = torch.randn(B, HV, D, device=device, dtype=torch.bfloat16) + state = torch.randn(B, HV, D, D, device=device, dtype=torch.float32) * 0.01 return q, k, v, state def run_la_decode(q, k, v, state_4d, decay_scales, scale): """Run la_decode with proper state layout conversion.""" - B, H, D, _ = state_4d.shape - # la_decode kernel expects BHVK layout: [B*H, V, K] - # Reference/test state is BHKV: [B, H, K, V] → transpose to BHVK - state_cute = state_4d.clone().transpose(-1, -2).contiguous().reshape(B * H, D, D) - out = torch.zeros(B, H, D, device=q.device, dtype=torch.bfloat16) + B, H, D = q.shape + HV = v.shape[1] + assert state_4d.shape == (B, HV, D, D) + + # la_decode kernel expects BHVK layout: [B*HV, V, K] + # Reference/test state is BHKV: [B, HV, K, V] -> transpose to BHVK + state_cute = state_4d.clone().transpose(-1, -2).contiguous().reshape(B * HV, D, D) + out = torch.zeros(B, HV, D, device=q.device, dtype=torch.bfloat16) s_offsets = torch.arange(B, device=q.device, dtype=torch.int32) linear_attention_decode( @@ -108,7 +122,7 @@ def run_la_decode(q, k, v, state_4d, decay_scales, scale): V_SPLIT_DIM=D, ) # Convert output state back from BHVK to BHKV for comparison - state_out = state_cute.reshape(B, H, D, D).transpose(-1, -2).contiguous() + state_out = state_cute.reshape(B, HV, D, D).transpose(-1, -2).contiguous() return out, state_out @@ -158,6 +172,27 @@ def test_different_heads(H): assert state_rmse / (state_max + 1e-8) < 0.001, f"H={H}: state mismatch" +@pytest.mark.parametrize("B", [2, 33]) +@pytest.mark.parametrize("decay_head_space", ["qk", "value"]) +def test_gva_output_vs_torch_ref(B, decay_head_space): + H, HV, D = 8, 16, 128 + scale = D**-0.5 + decay_hv = 0.5 * torch.arange(HV, device="cuda", dtype=torch.float32) / HV + decay_scales = decay_hv[:: HV // H] if decay_head_space == "qk" else decay_hv + + q, k, v, state = make_inputs(B, H, D, HV=HV) + o_ref, state_ref = torch_la_decode_ref(q, k, v, state, decay_scales, scale) + o_cute, state_cute = run_la_decode(q, k, v, state, decay_scales, scale) + + rmse = torch.sqrt(torch.mean((o_cute.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.01, f"GVA B={B}: output mismatch" + + state_rmse = torch.sqrt(torch.mean((state_cute - state_ref) ** 2)).item() + state_max = torch.abs(state_ref).max().item() + assert state_rmse / (state_max + 1e-8) < 0.001, f"GVA B={B}: state mismatch" + + def test_zero_decay(): """With decay=0, state_new = state_old + k⊗v (no decay applied).""" B, H, D = 2, 32, 128 @@ -229,20 +264,54 @@ def test_vs_fla(B): # --------------------------------------------------------------------------- +@pytest.mark.skipif(not HAS_FLA, reason="fla not available") +@pytest.mark.parametrize("B", [2, 33]) +def test_gva_vs_fla(B): + H, HV, D = 8, 16, 128 + scale = D**-0.5 + g_gamma = -(8 / HV * 0.5) * torch.arange(HV, device="cuda", dtype=torch.float32) + decay_scales = -g_gamma + group_size = HV // H + + q, k, v, state = make_inputs(B, H, D, HV=HV) + + q_4d = q.repeat_interleave(group_size, dim=1).unsqueeze(1) + k_4d = k.repeat_interleave(group_size, dim=1).unsqueeze(1) + v_4d = v.unsqueeze(1) + with torch.no_grad(): + o_fla, _ = fused_recurrent_fwd( + q_4d, + k_4d, + v_4d, + g_gamma=g_gamma, + scale=scale, + initial_state=state.clone(), + output_final_state=True, + ) + o_fla = o_fla.squeeze(1).to(torch.bfloat16) + + o_cute, _ = run_la_decode(q, k, v, state, decay_scales, scale) + + rmse = torch.sqrt(torch.mean((o_cute.float() - o_fla.float()) ** 2)).item() + max_ref = torch.abs(o_fla.float()).max().item() + assert rmse / (max_ref + 1e-8) < 0.005, f"GVA B={B}: vs fla mismatch, rel_rmse={rmse / (max_ref + 1e-8):.6f}" + + # End-to-End Prefill -> Decode Test # --------------------------------------------------------------------------- -def test_prefill_decode_e2e(): +@pytest.mark.parametrize("H, HV", [(8, 8), (4, 8)]) +def test_prefill_decode_e2e(H, HV): """Verify prefill output state passes directly into decode without transpose.""" from cula.ops.lightning_attn_sm100 import lightning_attn_fwd - B, S, H, D = 2, 64, 8, 128 + B, S, D = 2, 64, 128 scale = D**-0.5 - decay_scales = 0.5 * torch.arange(H, device="cuda", dtype=torch.float32) / H + decay_scales = 0.5 * torch.arange(HV, device="cuda", dtype=torch.float32) / HV # Dummy prefill tokens q_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) k_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) - v_pre = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + v_pre = torch.randn(B, S, HV, D, device="cuda", dtype=torch.bfloat16) # 1. Run Prefill (Generates BHVK ht) _, ht = lightning_attn_fwd(q_pre, k_pre, v_pre, decay_scales, scale=scale, output_final_state=True) @@ -253,7 +322,7 @@ def test_prefill_decode_e2e(): # Dummy decode tokens q_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) k_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) - v_dec = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + v_dec = torch.randn(B, HV, D, device="cuda", dtype=torch.bfloat16) # 2. Run Decode (run_la_decode handles BHKV→BHVK internally) out_dec, state_new = run_la_decode(q_dec, k_dec, v_dec, ht_kv, decay_scales, scale) diff --git a/tests/test_la_decode_pool.py b/tests/test_la_decode_pool.py index c9a7579..88765fd 100644 --- a/tests/test_la_decode_pool.py +++ b/tests/test_la_decode_pool.py @@ -32,12 +32,21 @@ def torch_la_decode_ref(q, k, v, state, decay_scales, scale): - """Pure PyTorch reference — state is [B, H, K, V] (BHKV).""" + """Pure PyTorch reference; state is [B, HV, K, V] (BHKV).""" B, H, D = q.shape + HV = v.shape[1] + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + group_size = HV // H + if group_size > 1: + q = q.repeat_interleave(group_size, dim=1) + k = k.repeat_interleave(group_size, dim=1) + if decay_scales.shape[0] == H and HV != H: + decay_scales = decay_scales.repeat_interleave(group_size) + q_f = q.float() * scale k_f = k.float() v_f = v.float() - decay = torch.exp(-decay_scales).view(1, H, 1, 1) + decay = torch.exp(-decay_scales).view(1, HV, 1, 1) state_new = state * decay + k_f.unsqueeze(-1) * v_f.unsqueeze(-2) o = torch.einsum("bhk,bhkv->bhv", q_f, state_new) return o.to(torch.bfloat16), state_new @@ -47,15 +56,16 @@ def run_la_decode_with_pool(q, k, v, state_pool_4d, s_offsets, decay_scales, sca """ Run la_decode with a state pool and arbitrary offsets. - state_pool_4d: [pool_size, H, K, V] — the full pool (BHKV layout) + state_pool_4d: [pool_size, HV, K, V] — the full pool (BHKV layout) s_offsets: [B] — which pool slot each batch element uses """ B, H, D = q.shape + HV = v.shape[1] pool_size = state_pool_4d.shape[0] - # la_decode expects BHVK layout: [pool_size*H, V, K] - state_cute = state_pool_4d.clone().transpose(-1, -2).contiguous().reshape(pool_size * H, D, D) - out = torch.zeros(B, H, D, device=q.device, dtype=torch.bfloat16) + # la_decode expects BHVK layout: [pool_size*HV, V, K] + state_cute = state_pool_4d.clone().transpose(-1, -2).contiguous().reshape(pool_size * HV, D, D) + out = torch.zeros(B, HV, D, device=q.device, dtype=torch.bfloat16) linear_attention_decode( q, @@ -76,7 +86,7 @@ def run_la_decode_with_pool(q, k, v, state_pool_4d, s_offsets, decay_scales, sca V_SPLIT_DIM=D, ) - state_out = state_cute.reshape(pool_size, H, D, D).transpose(-1, -2).contiguous() + state_out = state_cute.reshape(pool_size, HV, D, D).transpose(-1, -2).contiguous() return out, state_out @@ -145,6 +155,35 @@ def test_non_identity_offsets(): assert rel_err < 0.01, f"Non-identity offsets {offsets}: rel_err={rel_err:.6f}" +def test_gva_non_identity_offsets(): + """GVA with offsets: q/k use H heads while v/state/out use HV heads.""" + B = 4 + POOL_SIZE = 6 + H, HV, D = 4, 8, 128 + scale = D**-0.5 + decay_scales = 0.3 * torch.arange(HV, device="cuda", dtype=torch.float32) / HV + + torch.manual_seed(42) + q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + k = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16) + v = torch.randn(B, HV, D, device="cuda", dtype=torch.bfloat16) + state_pool = torch.randn(POOL_SIZE, HV, D, D, device="cuda", dtype=torch.float32) * 0.1 + + offsets = [2, 0, 5, 1] + s_offsets = torch.tensor(offsets, device="cuda", dtype=torch.int32) + + out, _ = run_la_decode_with_pool(q, k, v, state_pool, s_offsets, decay_scales, scale) + + state_selected = state_pool[s_offsets.long()] + o_ref, _ = torch_la_decode_ref(q, k, v, state_selected, decay_scales, scale) + + rmse = torch.sqrt(torch.mean((out.float() - o_ref.float()) ** 2)).item() + max_ref = torch.abs(o_ref.float()).max().item() + rel_err = rmse / (max_ref + 1e-8) + + assert rel_err < 0.01, f"GVA non-identity offsets {offsets}: rel_err={rel_err:.6f}" + + # --------------------------------------------------------------------------- # Test 3: Reversed offsets (another non-identity pattern) # ---------------------------------------------------------------------------