From 8c81119b8570cfd050a9aeea9db926a62c63a28c Mon Sep 17 00:00:00 2001 From: plugyawn Date: Fri, 29 May 2026 00:18:14 +0530 Subject: [PATCH 1/3] Add token-linear THD fused RoPE path Signed-off-by: plugyawn --- .../benchmark_rope_thd_token_linear.py | 272 ++++++++++++++++++ tests/pytorch/test_fused_rope.py | 175 ++++++++++- .../common/fused_rope/fused_rope.cu | 156 +++++++++- 3 files changed, 585 insertions(+), 18 deletions(-) create mode 100644 benchmarks/attention/benchmark_rope_thd_token_linear.py diff --git a/benchmarks/attention/benchmark_rope_thd_token_linear.py b/benchmarks/attention/benchmark_rope_thd_token_linear.py new file mode 100644 index 0000000000..2c591f352b --- /dev/null +++ b/benchmarks/attention/benchmark_rope_thd_token_linear.py @@ -0,0 +1,272 @@ +"""Microbenchmark for the token-linear THD fused RoPE path. + +Holds the local packed-token count fixed and sweeps the number of packed +sequences. For each point, measures forward and backward latency of the fused +RoPE kernel under three regimes: + + * forced-old: ``NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0`` + * forced-new: ``NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=1`` + * heuristic: variable unset + +Outputs a CSV and a PNG. Intended to be run on a single GPU; not distributed. +""" + +from __future__ import annotations + +import argparse +import csv +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable + +import torch + +from transformer_engine.pytorch.attention.rope import ( + RotaryPositionEmbedding, + apply_rotary_pos_emb, +) + + +@contextmanager +def env(name: str, value: str | None): + prev = os.environ.get(name) + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + try: + yield + finally: + if prev is None: + os.environ.pop(name, None) + else: + os.environ[name] = prev + + +def build_cu_seqlens(total_tokens: int, n_seqs: int, cp_size: int = 1) -> torch.Tensor: + """Build a cu_seqlens whose local packed length equals ``total_tokens``. + + Per-sequence lengths are equal to ``total_tokens / n_seqs`` rounded down to + a multiple of ``2 * cp_size``; any leftover tokens are tacked onto the last + span so that the local total is exact. + """ + pad = 2 * cp_size + per = (total_tokens // n_seqs // pad) * pad + if per <= 0: + raise ValueError( + f"n_seqs={n_seqs} is too large for total_tokens={total_tokens} with cp_size={cp_size}" + ) + lengths = [per] * n_seqs + deficit = total_tokens - per * n_seqs + lengths[-1] += (deficit // pad) * pad + cu = [0] + for length in lengths: + cu.append(cu[-1] + length) + return torch.tensor(cu, dtype=torch.int32) + + +def time_fwd_bwd( + fn, + iters: int, + warmup: int, +) -> tuple[float, float]: + """Return (fwd_ms, fwd_plus_bwd_ms) averaged across ``iters`` iterations.""" + torch.cuda.synchronize() + for _ in range(warmup): + out = fn() + out.sum().backward() + torch.cuda.synchronize() + + start_fwd = torch.cuda.Event(enable_timing=True) + end_fwd = torch.cuda.Event(enable_timing=True) + end_bwd = torch.cuda.Event(enable_timing=True) + + fwd_total = 0.0 + full_total = 0.0 + for _ in range(iters): + start_fwd.record() + out = fn() + end_fwd.record() + out.sum().backward() + end_bwd.record() + torch.cuda.synchronize() + fwd_total += start_fwd.elapsed_time(end_fwd) + full_total += start_fwd.elapsed_time(end_bwd) + return fwd_total / iters, full_total / iters + + +def main(argv: Iterable[str] | None = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--total-tokens", type=int, default=65536) + parser.add_argument("--freqs-len", type=int, default=65536) + parser.add_argument("--head-num", type=int, default=32) + parser.add_argument("--hidden", type=int, default=128) + parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16") + parser.add_argument("--rotary-percent", type=float, default=1.0) + parser.add_argument("--interleaved", action="store_true") + parser.add_argument("--cp-size", type=int, default=1) + parser.add_argument( + "--n-seqs", + type=int, + nargs="+", + default=[1, 8, 32, 128, 512, 1024, 2401], + ) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--out-dir", type=Path, default=Path("rope_thd_bench")) + args = parser.parse_args(argv) + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required to run this benchmark") + + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[ + args.dtype + ] + device = torch.device("cuda:0") + + rotary = RotaryPositionEmbedding( + args.hidden, args.rotary_percent, interleaved=args.interleaved + ) + freqs = rotary(args.freqs_len).to(device) + + args.out_dir.mkdir(parents=True, exist_ok=True) + csv_path = args.out_dir / "rope_thd_bench.csv" + fields = [ + "n_seqs", + "regime", + "fwd_ms", + "fwd_bwd_ms", + "bwd_ms", + "blocks_old", + "blocks_new", + "speedup_fwd_bwd_vs_old", + ] + rows = [] + + print( + f"# total_tokens={args.total_tokens} freqs_len={args.freqs_len} h={args.head_num} " + f"d={args.hidden} dtype={args.dtype} cp={args.cp_size}" + ) + print("# n_seqs regime fwd_ms fwd_bwd_ms bwd_ms blocks_old blocks_new speedup") + + by_nseq_old: dict[int, float] = {} + + for n_seqs in args.n_seqs: + cu = build_cu_seqlens(args.total_tokens, n_seqs, cp_size=args.cp_size).to( + device + ) + actual_total = int(cu[-1].item()) + t = torch.rand( + (actual_total // args.cp_size, args.head_num, args.hidden), + dtype=dtype, + device=device, + requires_grad=True, + ) + + def runner() -> torch.Tensor: + # Reset grad in-place to keep the autograd graph fresh. + if t.grad is not None: + t.grad = None + return apply_rotary_pos_emb( + t, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + cp_size=args.cp_size, + cp_rank=0, + ) + + blocks_old = args.freqs_len * n_seqs + blocks_new = actual_total // args.cp_size + + for regime, value in [("old", "0"), ("new", "1"), ("heuristic", None)]: + with env("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", value): + fwd_ms, full_ms = time_fwd_bwd( + runner, iters=args.iters, warmup=args.warmup + ) + bwd_ms = full_ms - fwd_ms + if regime == "old": + by_nseq_old[n_seqs] = full_ms + speedup = (by_nseq_old[n_seqs] / full_ms) if regime != "old" else 1.0 + rows.append( + { + "n_seqs": n_seqs, + "regime": regime, + "fwd_ms": f"{fwd_ms:.4f}", + "fwd_bwd_ms": f"{full_ms:.4f}", + "bwd_ms": f"{bwd_ms:.4f}", + "blocks_old": blocks_old, + "blocks_new": blocks_new, + "speedup_fwd_bwd_vs_old": f"{speedup:.3f}", + } + ) + print( + f"{n_seqs:>6} {regime:>10} fwd={fwd_ms:7.3f} fwd_bwd={full_ms:7.3f} " + f"bwd={bwd_ms:7.3f} blocks_old={blocks_old} blocks_new={blocks_new} " + f"speedup={speedup:.2f}x" + ) + + with csv_path.open("w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=fields) + writer.writeheader() + writer.writerows(rows) + print(f"\nWrote {csv_path}") + + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot") + return + + # Aggregate per regime. + nseqs = sorted({int(r["n_seqs"]) for r in rows}) + by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + for n in nseqs: + for regime in by_regime: + for r in rows: + if int(r["n_seqs"]) == n and r["regime"] == regime: + by_regime[regime].append(float(r["fwd_bwd_ms"])) + break + + fig, axes = plt.subplots(1, 2, figsize=(13, 5)) + + ax = axes[0] + for regime in ("old", "new", "heuristic"): + ax.plot(nseqs, by_regime[regime], marker="o", label=regime) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("n_seqs (packed spans, log)") + ax.set_ylabel("fwd + bwd latency (ms, log)") + ax.set_title( + f"Fused THD RoPE latency vs n_seqs\nT_local={args.total_tokens}, " + f"freqs_len={args.freqs_len}, h={args.head_num}, d={args.hidden}, " + f"{args.dtype}, cp={args.cp_size}" + ) + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + ax = axes[1] + speedup_new = [by_regime["old"][i] / by_regime["new"][i] for i in range(len(nseqs))] + ax.plot(nseqs, speedup_new, marker="o", color="tab:green", label="new vs old") + ax.axhline(1.0, color="gray", linestyle="--", alpha=0.5) + ax.set_xscale("log") + ax.set_xlabel("n_seqs (log)") + ax.set_ylabel("speedup (old / new)") + ax.set_title("Token-linear speedup over (s × b) launch") + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + fig.tight_layout() + png_path = args.out_dir / "rope_thd_bench.png" + fig.savefig(png_path, dpi=120) + print(f"Wrote {png_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 50624df9e0..3eeaa18ee5 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -21,7 +21,9 @@ def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch. # Gradient is a full tensor -def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: +def _non_overlapping_grad( + output: Union[List[torch.Tensor], torch.Tensor] +) -> torch.Tensor: if isinstance(output, List): return sum(torch.sum(t * torch.ones_like(t)) for t in output) else: @@ -79,7 +81,9 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + rotary_pos_emb = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) emb = rotary_pos_emb(seq_length * cp_size) assert emb.is_contiguous() @@ -150,7 +154,9 @@ def test_fused_rope_thd( # Get arbitrary offsets to be used with RoPE for all the sequences start_positions = ( - torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device) + torch.randint( + 0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device + ) if start_positions else None ) @@ -160,7 +166,8 @@ def test_fused_rope_thd( for i in range(1, len(cu_seqlens)): cu_seqlens_padded.append( cu_seqlens_padded[i - 1] - + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2) + + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) + * (cp_size * 2) ) else: cu_seqlens_padded = cu_seqlens @@ -178,7 +185,9 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + rotary_pos_emb = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) emb = rotary_pos_emb(cu_seqlens_padded[-1]) assert emb.is_contiguous() @@ -252,9 +261,9 @@ def test_unfused_rope_thd_vs_bshd( # that causes unexpected issues. seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) - cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( - device=device, dtype=torch.int32 - ) + cu_seqlens = torch.cumsum( + torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0 + ).to(device=device, dtype=torch.int32) # Create a tensor in THD format thd = torch.rand( @@ -274,7 +283,9 @@ def test_unfused_rope_thd_vs_bshd( sbhd = sbhd.to(dtype=dtype, device=device) sbhd.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + rotary_pos_emb = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) emb = rotary_pos_emb(max_seqlen) assert emb.is_contiguous() @@ -345,7 +356,8 @@ def test_unfused_rope_thd_vs_bshd( grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd ) torch.testing.assert_close( - grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), + grad_unfused_thd, ) assert output_unfused_thd.is_contiguous() @@ -407,9 +419,13 @@ def test_fused_qkv_rope( t = t.transpose(0, 1).contiguous() t.requires_grad = True - rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + rotary_pos_emb_q = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) emb_q = rotary_pos_emb_q(seq_length * cp_size) - rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + rotary_pos_emb_k = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) emb_k = rotary_pos_emb_k(seq_length * cp_size) for cp_rank in range(cp_size): @@ -495,3 +511,138 @@ def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_wi atol=1e-8, rtol=1e-8, ) + + +def _make_packed_thd_cu_seqlens( + n_seqs: int, + mean_len: int, + cp_size: int, + rng: torch.Generator, + include_zero_length: bool = False, +) -> torch.Tensor: + """Build a cu_seqlens tensor for a packed THD batch. + + Each per-sequence length is padded to a multiple of ``2 * cp_size`` so the + integer divisions inside the kernel are exact (matching how Megatron-style + callers pad cu_seqlens for context parallel). Optionally injects zero-length + spans to exercise the upper-bound search. + """ + lengths = torch.randint( + low=1, + high=max(2, 2 * mean_len), + size=(n_seqs,), + generator=rng, + dtype=torch.int64, + ) + if include_zero_length and n_seqs >= 4: + # Sprinkle a handful of zero-length spans, including back-to-back ones + # and one at the front, to exercise boundary cases. + zero_idx = [0, n_seqs // 3, n_seqs // 3 + 1, n_seqs - 2] + for idx in zero_idx: + if 0 <= idx < n_seqs: + lengths[idx] = 0 + pad = 2 * cp_size + lengths = ((lengths + pad - 1) // pad) * pad + # Restore zero-length spans after padding (pad rounds 0 to 0 already). + cu = torch.zeros(n_seqs + 1, dtype=torch.int32) + cu[1:] = torch.cumsum(lengths, dim=0).to(torch.int32) + return cu + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("interleaved", [False, True]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize( + "n_seqs,mean_len,include_zero_length", + [ + (1, 2048, False), + (8, 256, False), + (64, 64, False), + (513, 16, False), + (2401, 8, False), + (128, 32, True), + ], +) +@pytest.mark.parametrize("start_positions", [False, True]) +def test_fused_rope_thd_token_linear_parity( + monkeypatch: pytest.MonkeyPatch, + dtype: torch.dtype, + hidden_size: int, + rotary_percent: float, + interleaved: bool, + cp_size: int, + n_seqs: int, + mean_len: int, + include_zero_length: bool, + start_positions: bool, +) -> None: + """Forces the old and the new THD fused kernel back-to-back and asserts + bitwise equality on both the forward output and the input gradient. The new + kernel must enumerate exactly the same useful blocks as the old one, with + identical per-token math, so equality must hold without tolerance. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = torch.device("cuda:0") + head_num = 16 + + rng = torch.Generator(device="cpu") + rng.manual_seed(0xC0FFEE + n_seqs * 13 + (1 if include_zero_length else 0)) + + cu_seqlens = _make_packed_thd_cu_seqlens( + n_seqs, mean_len, cp_size, rng, include_zero_length=include_zero_length + ).to(device) + total_local = int(cu_seqlens[-1].item()) // cp_size + if total_local == 0: + pytest.skip("empty packed batch after padding") + + start_positions_t = ( + torch.randint(0, 4, (n_seqs,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + t = torch.rand( + (total_local, head_num, hidden_size), dtype=dtype, device=device, generator=None + ) + t.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding( + hidden_size, rotary_percent, interleaved=interleaved + ) + # `freqs` must cover (max span length per CP rank + start_positions offset + + # CP dual-chunk offset). Use the global cu_seqlens[-1] length as an upper + # bound, matching how callers size the freqs tensor in practice. + emb = rotary_pos_emb(int(cu_seqlens[-1].item()) + 32) + + def run(force_path: str) -> Tuple[torch.Tensor, torch.Tensor]: + monkeypatch.setenv("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", force_path) + cp_rank = 0 + out = apply_rotary_pos_emb( + t, + emb, + start_positions=start_positions_t, + interleaved=interleaved, + fused=True, + tensor_format="thd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss = _overlapping_grad(out) + loss.backward() + grad = t.grad.detach().clone() + t.grad = None + return out.detach().clone(), grad + + out_old, grad_old = run("0") + out_new, grad_new = run("1") + + # Both paths call the same per-token device function with the same + # arguments and write disjoint output rows. Bitwise equality is the right + # bar. + torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0) + torch.testing.assert_close(grad_new, grad_old, rtol=0.0, atol=0.0) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 27dc11ab43..3bcfe57888 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -10,10 +10,32 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/system.h" #include "../utils.cuh" namespace transformer_engine { +// Returns the largest sequence index `b` such that +// `cu_seqlens[b] / cp_size <= t_id`. Used by the token-linear THD kernels to +// locate the sequence span that owns local packed token `t_id`. Uses the same +// integer-division semantics as the existing THD kernels so that the +// per-sequence boundaries agree exactly. +__device__ __forceinline__ int fused_rope_thd_find_seq_id(const int *cu_seqlens, const int nseq, + const int t_id, const int cp_size) { + int lo = 0; + int hi = nseq; + while (lo + 1 < hi) { + int mid = (lo + hi) >> 1; + int mid_start = cu_seqlens[mid] / cp_size; + if (mid_start <= t_id) { + lo = mid; + } else { + hi = mid; + } + } + return lo; +} + template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, const bool interleaved, const int s_id, @@ -215,6 +237,98 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +// Token-linear THD forward kernel. Each block handles exactly one packed local +// token row. The block locates its owning sequence via binary search over the +// divided cumulative sequence boundaries, then defers to the same +// `fused_rope_block_forward` device function as the original kernel. +template +__global__ void fused_rope_thd_token_forward_kernel( + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int nseq, + const int h, const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) { + int t_id = blockIdx.x; + int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size); + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int s_id = t_id - start; + int cur_seqlens = end - start; + + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs += cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; + } + } + + fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +// Token-linear THD backward kernel. Mirrors the forward variant and dispatches +// to `fused_rope_block_backward`. +template +__global__ void fused_rope_thd_token_backward_kernel( + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int nseq, + const int h, const int d, const int d2, const int stride_t, const int stride_h, + const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) { + int t_id = blockIdx.x; + int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size); + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int s_id = t_id - start; + int cur_seqlens = end - start; + + int offset_block = t_id * stride_t; + int offset_block_dst = t_id * o_stride_t; + + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs += cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; + } + } + + fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); +} + +// Host-side dispatcher. Selects the token-linear THD path when it would +// eliminate a meaningful number of dead blocks. The environment variable +// NVTE_FUSED_ROPE_THD_TOKEN_LINEAR overrides the heuristic for testing and +// benchmarking: "0" forces the old kernel, "1" forces the new one. Read on +// every call so tests can toggle it inside a single process. +inline bool fused_rope_thd_use_token_linear(const NVTE_QKV_Format qkv_format, const int b, + const int s, const int64_t total_tokens) { + if (qkv_format != NVTE_QKV_Format::NVTE_THD) return false; + if (total_tokens <= 0) return false; + + const int env_override = transformer_engine::getenv("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", -1); + if (env_override == 0) return false; + if (env_override == 1) return true; + + // Heuristic: only worth it when the old launch would issue at least 8x as + // many blocks as there are useful tokens, and when there are enough + // sequences for binary-search overhead to be amortized. + if (b < 64) return false; + if (static_cast(s) * static_cast(b) < 8 * total_tokens) return false; + return true; +} + template __device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, const bool interleaved, const int s_id, @@ -467,9 +581,8 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const int64_t total_tokens, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; @@ -487,6 +600,16 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; + if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) { + dim3 blocks(static_cast(total_tokens)); + fused_rope_thd_token_forward_kernel<<>>( + input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, b, h, d, + d2, stride_s_or_t, stride_h, stride_d, o_stride_s_or_t, o_stride_h, o_stride_d); + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + dim3 blocks(s, b); fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, @@ -501,9 +624,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const int stride_d, const int64_t total_tokens, + cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; @@ -521,6 +644,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; + if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) { + dim3 blocks(static_cast(total_tokens)); + fused_rope_thd_token_backward_kernel<<>>( + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, + cp_rank, b, h, d, d2, stride_s_or_t, stride_h, stride_d, o_stride_s_or_t, o_stride_h, + o_stride_d); + NVTE_CHECK_CUDA(cudaGetLastError()); + return; + } + + dim3 blocks(s, b); fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, @@ -579,6 +713,12 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { + // For THD the packed local token count is the first dimension of the input + // tensor. SBHD/BSHD ignore this value. + const int64_t total_tokens = + (qkv_format == NVTE_QKV_Format::NVTE_THD && !input.data.shape.empty()) + ? static_cast(input.data.shape[0]) + : 0; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), @@ -587,7 +727,7 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten reinterpret_cast(start_positions.data.dptr), reinterpret_cast(output->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream);); + stride_b, stride_h, stride_d, total_tokens, stream);); } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, @@ -597,6 +737,10 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { + const int64_t total_tokens = + (qkv_format == NVTE_QKV_Format::NVTE_THD && !output_grads.data.shape.empty()) + ? static_cast(output_grads.data.shape[0]) + : 0; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), @@ -605,7 +749,7 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream);); + stride_b, stride_h, stride_d, total_tokens, stream);); } void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs, From 059a2e299a28d863f98fa1d84825875236b2cb29 Mon Sep 17 00:00:00 2001 From: plugyawn Date: Fri, 29 May 2026 01:03:04 +0530 Subject: [PATCH 2/3] Add THD RoPE full-layer benchmark Signed-off-by: plugyawn --- .../benchmark_rope_thd_full_layer.py | 373 ++++++++++++++++++ 1 file changed, 373 insertions(+) create mode 100644 benchmarks/attention/benchmark_rope_thd_full_layer.py diff --git a/benchmarks/attention/benchmark_rope_thd_full_layer.py b/benchmarks/attention/benchmark_rope_thd_full_layer.py new file mode 100644 index 0000000000..67caf4e0ed --- /dev/null +++ b/benchmarks/attention/benchmark_rope_thd_full_layer.py @@ -0,0 +1,373 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Full TransformerLayer benchmark for token-linear THD fused RoPE. + +This benchmark keeps the local packed-token count and RoPE table length fixed +while varying the number of packed THD spans. It compares the old fused RoPE +launch, the new token-linear launch, and the heuristic path on a TE +TransformerLayer using THD input and rotary embeddings. It also measures a +paired RoPE-only operation with the same tensor shape, so the output table can +report both end-to-end layer speedup and the fraction of layer time attributable +to fused RoPE. +""" + +from __future__ import annotations + +import argparse +import csv +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Iterable + +import torch + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.attention.rope import ( + RotaryPositionEmbedding, + apply_rotary_pos_emb, +) + + +@contextmanager +def env(name: str, value: str | None): + prev = os.environ.get(name) + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + try: + yield + finally: + if prev is None: + os.environ.pop(name, None) + else: + os.environ[name] = prev + + +def build_cu_seqlens(total_tokens: int, n_seqs: int) -> tuple[torch.Tensor, int]: + """Build balanced packed THD cu_seqlens with an exact total token count.""" + per = total_tokens // n_seqs + if per <= 0: + raise ValueError( + f"n_seqs={n_seqs} is too large for total_tokens={total_tokens}" + ) + rem = total_tokens - per * n_seqs + lengths = [per + (1 if i < rem else 0) for i in range(n_seqs)] + cu = [0] + max_seqlen = 0 + for length in lengths: + cu.append(cu[-1] + length) + max_seqlen = max(max_seqlen, length) + return torch.tensor(cu, dtype=torch.int32), max_seqlen + + +def zero_grads(params: Iterable[torch.Tensor], x: torch.Tensor) -> None: + if x.grad is not None: + x.grad = None + for p in params: + if p.grad is not None: + p.grad = None + + +def time_fwd_bwd( + fn: Callable[[], torch.Tensor], warmup: int, iters: int +) -> tuple[float, float]: + torch.cuda.synchronize() + for _ in range(warmup): + out = fn() + out.sum().backward() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + fwd_end = torch.cuda.Event(enable_timing=True) + bwd_end = torch.cuda.Event(enable_timing=True) + fwd_total = 0.0 + full_total = 0.0 + for _ in range(iters): + start.record() + out = fn() + fwd_end.record() + out.sum().backward() + bwd_end.record() + torch.cuda.synchronize() + fwd_total += start.elapsed_time(fwd_end) + full_total += start.elapsed_time(bwd_end) + return fwd_total / iters, full_total / iters + + +def make_layer(args: argparse.Namespace, dtype: torch.dtype) -> te.TransformerLayer: + sigma = 0.02 + + def init_method(tensor: torch.Tensor) -> torch.Tensor: + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return te.TransformerLayer( + args.hidden_size, + args.ffn_hidden_size, + args.num_heads, + layernorm_epsilon=1e-5, + hidden_dropout=0.0, + attention_dropout=0.0, + init_method=init_method, + output_layer_init_method=init_method, + layer_number=1, + kv_channels=args.head_dim, + self_attn_mask_type="padding_causal", + tp_group=None, + tp_size=1, + params_dtype=dtype, + get_rng_state_tracker=None, + fuse_wgrad_accumulation=False, + seq_length=args.freqs_len, + micro_batch_size=1, + sequence_parallel=False, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder", + set_parallel_mode=True, + fuse_qkv_params=True, + zero_centered_gamma=False, + qkv_weight_interleaved=True, + bias=True, + attn_input_format="thd", + rotary_pos_interleaved=args.interleaved, + device="cuda", + ).to(dtype=dtype, device="cuda") + + +def main(argv: Iterable[str] | None = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--total-tokens", type=int, default=65536) + parser.add_argument("--freqs-len", type=int, default=65536) + parser.add_argument("--hidden-size", type=int, default=1536) + parser.add_argument("--ffn-hidden-size", type=int, default=6144) + parser.add_argument("--num-heads", type=int, default=12) + parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16") + parser.add_argument("--interleaved", action="store_true") + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--iters", type=int, default=5) + # n_seqs=50 is intentionally omitted from the default sweep because the + # balanced-span shape has max_seqlen~=1311 and can hit a cuDNN fused-attn + # execution failure unrelated to RoPE on the tested H100 stack. The high-span + # cases below are the issue-relevant regime where RoPE launch waste dominates. + parser.add_argument("--n-seqs", type=int, nargs="+", default=[128, 512, 1024, 2401]) + parser.add_argument( + "--out-dir", type=Path, default=Path("rope_thd_full_layer_bench") + ) + args = parser.parse_args(argv) + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required") + if args.hidden_size % args.num_heads != 0: + raise SystemExit("--hidden-size must be divisible by --num-heads") + args.head_dim = args.hidden_size // args.num_heads + if args.freqs_len < args.total_tokens: + raise SystemExit( + "--freqs-len should be >= --total-tokens for this long-context benchmark" + ) + + torch.manual_seed(1234) + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype] + device = torch.device("cuda") + + rotary = RotaryPositionEmbedding(args.head_dim, interleaved=args.interleaved) + freqs = rotary(args.freqs_len).to(device=device) + + args.out_dir.mkdir(parents=True, exist_ok=True) + csv_path = args.out_dir / "rope_thd_full_layer_bench.csv" + fields = [ + "n_seqs", + "regime", + "max_seqlen", + "layer_fwd_ms", + "layer_fwd_bwd_ms", + "layer_bwd_ms", + "rope_pair_fwd_ms", + "rope_pair_fwd_bwd_ms", + "rope_pair_bwd_ms", + "rope_pair_pct_layer", + "layer_speedup_vs_old", + "rope_pair_speedup_vs_old", + ] + rows: list[dict[str, str | int]] = [] + + print( + "# full-layer THD RoPE benchmark: " + f"T={args.total_tokens} freqs_len={args.freqs_len} hidden={args.hidden_size} " + f"ffn={args.ffn_hidden_size} heads={args.num_heads} dtype={args.dtype}", + flush=True, + ) + print( + "# n_seqs regime max_seqlen layer_fwd layer_fwd_bwd rope_pair_fwd_bwd " + "rope_pct layer_speedup", + flush=True, + ) + + for n_seqs in args.n_seqs: + cu_cpu, max_seqlen = build_cu_seqlens(args.total_tokens, n_seqs) + cu = cu_cpu.to(device=device) + x = torch.randn( + args.total_tokens, + args.hidden_size, + dtype=dtype, + device=device, + requires_grad=True, + ) + q = torch.randn( + args.total_tokens, + args.num_heads, + args.head_dim, + dtype=dtype, + device=device, + requires_grad=True, + ) + k = torch.randn_like(q, requires_grad=True) + layer = make_layer(args, dtype) + layer.train() + params = tuple(layer.parameters()) + + layer_old = None + rope_old = None + + for regime, override in (("old", "0"), ("new", "1"), ("heuristic", None)): + with env("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", override): + + def layer_fn() -> torch.Tensor: + zero_grads(params, x) + return layer( + x, + rotary_pos_emb=freqs, + cu_seqlens_q=cu, + cu_seqlens_kv=cu, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + ) + + def rope_pair_fn() -> torch.Tensor: + if q.grad is not None: + q.grad = None + if k.grad is not None: + k.grad = None + q_out = apply_rotary_pos_emb( + q, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + ) + k_out = apply_rotary_pos_emb( + k, + freqs, + tensor_format="thd", + fused=True, + cu_seqlens=cu, + interleaved=args.interleaved, + ) + return q_out + k_out + + layer_fwd, layer_full = time_fwd_bwd(layer_fn, args.warmup, args.iters) + rope_fwd, rope_full = time_fwd_bwd( + rope_pair_fn, args.warmup, args.iters + ) + + if regime == "old": + layer_old = layer_full + rope_old = rope_full + assert layer_old is not None and rope_old is not None + layer_speedup = layer_old / layer_full + rope_speedup = rope_old / rope_full + rope_pct = 100.0 * rope_full / layer_full + rows.append( + { + "n_seqs": n_seqs, + "regime": regime, + "max_seqlen": max_seqlen, + "layer_fwd_ms": f"{layer_fwd:.4f}", + "layer_fwd_bwd_ms": f"{layer_full:.4f}", + "layer_bwd_ms": f"{layer_full - layer_fwd:.4f}", + "rope_pair_fwd_ms": f"{rope_fwd:.4f}", + "rope_pair_fwd_bwd_ms": f"{rope_full:.4f}", + "rope_pair_bwd_ms": f"{rope_full - rope_fwd:.4f}", + "rope_pair_pct_layer": f"{rope_pct:.2f}", + "layer_speedup_vs_old": f"{layer_speedup:.3f}", + "rope_pair_speedup_vs_old": f"{rope_speedup:.3f}", + } + ) + print( + f"{n_seqs:>6} {regime:>10} {max_seqlen:>10} " + f"layer_fwd={layer_fwd:8.3f} layer_fwd_bwd={layer_full:8.3f} " + f"rope_pair_fwd_bwd={rope_full:8.3f} rope_pct={rope_pct:6.2f}% " + f"layer_speedup={layer_speedup:6.3f}x", + flush=True, + ) + + with csv_path.open("w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=fields) + writer.writeheader() + writer.writerows(rows) + print(f"\nWrote {csv_path}") + + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot") + return + + nseqs = sorted({int(r["n_seqs"]) for r in rows}) + by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + pct_by_regime = {regime: [] for regime in ("old", "new", "heuristic")} + for n in nseqs: + for regime in by_regime: + row = next( + r for r in rows if int(r["n_seqs"]) == n and r["regime"] == regime + ) + by_regime[regime].append(float(row["layer_fwd_bwd_ms"])) + pct_by_regime[regime].append(float(row["rope_pair_pct_layer"])) + + fig, axes = plt.subplots(1, 3, figsize=(17, 5)) + ax = axes[0] + for regime, values in by_regime.items(): + ax.plot(nseqs, values, marker="o", label=regime) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("TransformerLayer fwd+bwd (ms)") + ax.set_title("Full THD TransformerLayer") + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + ax = axes[1] + speedups = [by_regime["old"][i] / by_regime["new"][i] for i in range(len(nseqs))] + ax.plot(nseqs, speedups, marker="o", color="tab:green") + ax.axhline(1.0, color="gray", linestyle="--", alpha=0.5) + ax.set_xscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("Layer speedup (old / new)") + ax.set_title("End-to-end layer speedup") + ax.grid(True, which="both", alpha=0.3) + + ax = axes[2] + for regime, values in pct_by_regime.items(): + ax.plot(nseqs, values, marker="o", label=regime) + ax.set_xscale("log") + ax.set_xlabel("n_seqs") + ax.set_ylabel("paired RoPE fwd+bwd / layer fwd+bwd (%)") + ax.set_title("RoPE share estimate") + ax.grid(True, which="both", alpha=0.3) + ax.legend() + + fig.tight_layout() + png_path = args.out_dir / "rope_thd_full_layer_bench.png" + fig.savefig(png_path, dpi=120) + print(f"Wrote {png_path}") + + +if __name__ == "__main__": + main() From 6c46696f1b71edb9013202e20091b3c3dacd8569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 May 2026 20:53:32 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: plugyawn --- .../benchmark_rope_thd_full_layer.py | 24 +++------- .../benchmark_rope_thd_token_linear.py | 16 ++----- tests/pytorch/test_fused_rope.py | 45 ++++++------------- 3 files changed, 23 insertions(+), 62 deletions(-) diff --git a/benchmarks/attention/benchmark_rope_thd_full_layer.py b/benchmarks/attention/benchmark_rope_thd_full_layer.py index 67caf4e0ed..2533dad3ed 100644 --- a/benchmarks/attention/benchmark_rope_thd_full_layer.py +++ b/benchmarks/attention/benchmark_rope_thd_full_layer.py @@ -51,9 +51,7 @@ def build_cu_seqlens(total_tokens: int, n_seqs: int) -> tuple[torch.Tensor, int] """Build balanced packed THD cu_seqlens with an exact total token count.""" per = total_tokens // n_seqs if per <= 0: - raise ValueError( - f"n_seqs={n_seqs} is too large for total_tokens={total_tokens}" - ) + raise ValueError(f"n_seqs={n_seqs} is too large for total_tokens={total_tokens}") rem = total_tokens - per * n_seqs lengths = [per + (1 if i < rem else 0) for i in range(n_seqs)] cu = [0] @@ -72,9 +70,7 @@ def zero_grads(params: Iterable[torch.Tensor], x: torch.Tensor) -> None: p.grad = None -def time_fwd_bwd( - fn: Callable[[], torch.Tensor], warmup: int, iters: int -) -> tuple[float, float]: +def time_fwd_bwd(fn: Callable[[], torch.Tensor], warmup: int, iters: int) -> tuple[float, float]: torch.cuda.synchronize() for _ in range(warmup): out = fn() @@ -154,9 +150,7 @@ def main(argv: Iterable[str] | None = None) -> None: # execution failure unrelated to RoPE on the tested H100 stack. The high-span # cases below are the issue-relevant regime where RoPE launch waste dominates. parser.add_argument("--n-seqs", type=int, nargs="+", default=[128, 512, 1024, 2401]) - parser.add_argument( - "--out-dir", type=Path, default=Path("rope_thd_full_layer_bench") - ) + parser.add_argument("--out-dir", type=Path, default=Path("rope_thd_full_layer_bench")) args = parser.parse_args(argv) if not torch.cuda.is_available(): @@ -165,9 +159,7 @@ def main(argv: Iterable[str] | None = None) -> None: raise SystemExit("--hidden-size must be divisible by --num-heads") args.head_dim = args.hidden_size // args.num_heads if args.freqs_len < args.total_tokens: - raise SystemExit( - "--freqs-len should be >= --total-tokens for this long-context benchmark" - ) + raise SystemExit("--freqs-len should be >= --total-tokens for this long-context benchmark") torch.manual_seed(1234) dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype] @@ -270,9 +262,7 @@ def rope_pair_fn() -> torch.Tensor: return q_out + k_out layer_fwd, layer_full = time_fwd_bwd(layer_fn, args.warmup, args.iters) - rope_fwd, rope_full = time_fwd_bwd( - rope_pair_fn, args.warmup, args.iters - ) + rope_fwd, rope_full = time_fwd_bwd(rope_pair_fn, args.warmup, args.iters) if regime == "old": layer_old = layer_full @@ -325,9 +315,7 @@ def rope_pair_fn() -> torch.Tensor: pct_by_regime = {regime: [] for regime in ("old", "new", "heuristic")} for n in nseqs: for regime in by_regime: - row = next( - r for r in rows if int(r["n_seqs"]) == n and r["regime"] == regime - ) + row = next(r for r in rows if int(r["n_seqs"]) == n and r["regime"] == regime) by_regime[regime].append(float(row["layer_fwd_bwd_ms"])) pct_by_regime[regime].append(float(row["rope_pair_pct_layer"])) diff --git a/benchmarks/attention/benchmark_rope_thd_token_linear.py b/benchmarks/attention/benchmark_rope_thd_token_linear.py index 2c591f352b..734c7289be 100644 --- a/benchmarks/attention/benchmark_rope_thd_token_linear.py +++ b/benchmarks/attention/benchmark_rope_thd_token_linear.py @@ -120,14 +120,10 @@ def main(argv: Iterable[str] | None = None) -> None: if not torch.cuda.is_available(): raise SystemExit("CUDA is required to run this benchmark") - dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[ - args.dtype - ] + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype] device = torch.device("cuda:0") - rotary = RotaryPositionEmbedding( - args.hidden, args.rotary_percent, interleaved=args.interleaved - ) + rotary = RotaryPositionEmbedding(args.hidden, args.rotary_percent, interleaved=args.interleaved) freqs = rotary(args.freqs_len).to(device) args.out_dir.mkdir(parents=True, exist_ok=True) @@ -153,9 +149,7 @@ def main(argv: Iterable[str] | None = None) -> None: by_nseq_old: dict[int, float] = {} for n_seqs in args.n_seqs: - cu = build_cu_seqlens(args.total_tokens, n_seqs, cp_size=args.cp_size).to( - device - ) + cu = build_cu_seqlens(args.total_tokens, n_seqs, cp_size=args.cp_size).to(device) actual_total = int(cu[-1].item()) t = torch.rand( (actual_total // args.cp_size, args.head_num, args.hidden), @@ -184,9 +178,7 @@ def runner() -> torch.Tensor: for regime, value in [("old", "0"), ("new", "1"), ("heuristic", None)]: with env("NVTE_FUSED_ROPE_THD_TOKEN_LINEAR", value): - fwd_ms, full_ms = time_fwd_bwd( - runner, iters=args.iters, warmup=args.warmup - ) + fwd_ms, full_ms = time_fwd_bwd(runner, iters=args.iters, warmup=args.warmup) bwd_ms = full_ms - fwd_ms if regime == "old": by_nseq_old[n_seqs] = full_ms diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 3eeaa18ee5..332e8ff6ec 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -21,9 +21,7 @@ def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch. # Gradient is a full tensor -def _non_overlapping_grad( - output: Union[List[torch.Tensor], torch.Tensor] -) -> torch.Tensor: +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: if isinstance(output, List): return sum(torch.sum(t * torch.ones_like(t)) for t in output) else: @@ -81,9 +79,7 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb = rotary_pos_emb(seq_length * cp_size) assert emb.is_contiguous() @@ -154,9 +150,7 @@ def test_fused_rope_thd( # Get arbitrary offsets to be used with RoPE for all the sequences start_positions = ( - torch.randint( - 0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device - ) + torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device) if start_positions else None ) @@ -166,8 +160,7 @@ def test_fused_rope_thd( for i in range(1, len(cu_seqlens)): cu_seqlens_padded.append( cu_seqlens_padded[i - 1] - + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) - * (cp_size * 2) + + math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2) ) else: cu_seqlens_padded = cu_seqlens @@ -185,9 +178,7 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb = rotary_pos_emb(cu_seqlens_padded[-1]) assert emb.is_contiguous() @@ -261,9 +252,9 @@ def test_unfused_rope_thd_vs_bshd( # that causes unexpected issues. seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) - cu_seqlens = torch.cumsum( - torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0 - ).to(device=device, dtype=torch.int32) + cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( + device=device, dtype=torch.int32 + ) # Create a tensor in THD format thd = torch.rand( @@ -283,9 +274,7 @@ def test_unfused_rope_thd_vs_bshd( sbhd = sbhd.to(dtype=dtype, device=device) sbhd.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb = rotary_pos_emb(max_seqlen) assert emb.is_contiguous() @@ -419,13 +408,9 @@ def test_fused_qkv_rope( t = t.transpose(0, 1).contiguous() t.requires_grad = True - rotary_pos_emb_q = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb_q = rotary_pos_emb_q(seq_length * cp_size) - rotary_pos_emb_k = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb_k = rotary_pos_emb_k(seq_length * cp_size) for cp_rank in range(cp_size): @@ -605,14 +590,10 @@ def test_fused_rope_thd_token_linear_parity( else None ) - t = torch.rand( - (total_local, head_num, hidden_size), dtype=dtype, device=device, generator=None - ) + t = torch.rand((total_local, head_num, hidden_size), dtype=dtype, device=device, generator=None) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding( - hidden_size, rotary_percent, interleaved=interleaved - ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) # `freqs` must cover (max span length per CP rank + start_positions offset + # CP dual-chunk offset). Use the global cu_seqlens[-1] length as an upper # bound, matching how callers size the freqs tensor in practice.