[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors. #3057

Open: plugyawn wants to merge 3 commits into NVIDIA:main from plugyawn:rope-thd-token-linear

Conversation

@plugyawn plugyawn commented May 28, 2026

Description

Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.

Addresses #2866, which reports that the THD fused RoPE launch scales with freqs_len × n_spans. That is pathological; it should scale with the total token count. I reproduced the issue and found that it causes noticeable slowdowns even on plausibly routine shapes, e.g. the [128/512] and [512/128] cases in the table below.

The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one block per packed token.
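
To make the scaling difference concrete, here is a minimal sketch of the two launch sizes. The function names are hypothetical and exist only for illustration; the shapes are taken from the H100 microbenchmark in this PR (freqs_len = T_local = 65536).

```python
def old_grid_blocks(freqs_len: int, n_seqs: int) -> int:
    """Old THD path: one block per (position, sequence) pair."""
    return freqs_len * n_seqs


def token_linear_grid_blocks(total_tokens: int) -> int:
    """Token-linear path: one block per packed local token."""
    return total_tokens


# Microbenchmark shapes from this PR: 65536 packed tokens, freqs_len = 65536.
total_tokens = 65536
freqs_len = 65536
for n_seqs in (1, 128, 2401):
    old = old_grid_blocks(freqs_len, n_seqs)
    new = token_linear_grid_blocks(total_tokens)
    print(f"n_seqs={n_seqs:5d}  old_blocks={old:>11d}  new_blocks={new:>6d}")
```

The old grid grows linearly with n_seqs while the token count stays fixed, which matches the direction of the timings reported in this PR.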

| n_seqs | max span | old layer fwd+bwd (ms) | new layer fwd+bwd (ms) | layer speedup | old paired-RoPE share | new paired-RoPE share |
|---|---|---|---|---|---|---|
| 128 | 512 | 41.8151 | 23.0284 | 1.816x | 49.12% | 6.14% |
| 512 | 128 | 102.1047 | 23.0167 | 4.436x | 79.38% | 6.59% |
| 1024 | 64 | 182.9933 | 23.3783 | 7.827x | 88.36% | 6.77% |
| 2401 | 28 | 401.0516 | 24.5668 | 16.325x | 94.40% | 6.41% |

The slow case is mostly pathological, however, so I've gated the dispatch to avoid unnecessary binary-search overhead on ordinary shapes, although that overhead appears to be negligible. The condition: use token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this is the usual shape of TE updates, so I could remove it!
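
The dispatch condition can be sketched in a few lines. This is an illustrative Python model, not the C++ dispatcher itself; it assumes b is the number of sequences and s is freqs_len, so that the old launch issues s * b blocks.

```python
def use_token_linear(b: int, s: int, total_tokens: int) -> bool:
    """Model of the heuristic: enough sequences AND the old grid is >= 8x the token count."""
    old_blocks = s * b  # assumption: old launch is one block per (position, sequence)
    return b >= 64 and old_blocks >= 8 * total_tokens


# Microbenchmark-like shape: 128 sequences packed into 65536 tokens.
print(use_token_linear(b=128, s=65536, total_tokens=65536))
# Too few sequences: stays on the old path even though the grid is oversized.
print(use_token_linear(b=32, s=65536, total_tokens=65536))
# Dense batch where every sequence fills all s positions: old grid equals the token count.
print(use_token_linear(b=64, s=512, total_tokens=64 * 512))
```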

Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):

| n_seqs | old fwd+bwd (ms) | new fwd+bwd (ms) | speedup |
|---|---|---|---|
| 1 | 1.2746 | 1.2734 | 1.001x |
| 8 | 1.8860 | 1.3827 | 1.364x |
| 32 | 3.9359 | 1.4462 | 2.722x |
| 128 | 12.1849 | 1.5024 | 8.110x |
| 512 | 44.9411 | 1.5600 | 28.808x |
| 1024 | 89.1110 | 1.5919 | 55.977x |
| 2401 | 208.4182 | 1.6373 | 127.296x |

Fixes: #2866.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add token-linear THD fused RoPE forward/backward kernels that launch one CUDA block per packed local token row.
  • Add NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
  • Reuse the existing fused_rope_block_forward and fused_rope_block_backward device helpers.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation <<(none?)>>
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on May 28, 2026
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a launch-scaling bug in the THD fused RoPE path where the original kernel launched freqs_len × n_seqs CUDA blocks (most of which were dead blocks that early-exited), instead of only the total_tokens actually present in the packed batch. Two new CUDA kernels dispatch exactly one block per packed local token row and locate the owning sequence via binary search over cu_seqlens, reusing the existing fused_rope_block_forward/backward device helpers so the per-token math is unchanged.

  • New token-linear kernels: fused_rope_thd_token_forward_kernel and fused_rope_thd_token_backward_kernel with a fused_rope_thd_find_seq_id binary-search helper that maps each absolute token index back to its sequence.
  • Heuristic dispatch: fused_rope_thd_use_token_linear activates the new path when b >= 64 and the old grid would be 8x larger than total_tokens; an env var overrides the heuristic for testing.
  • Tests and benchmarks: A parity test forcing both paths asserts bitwise equality of forward outputs and input gradients; two benchmark scripts measure speedup across n_seqs sweeps.
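
The token-to-sequence lookup described above can be modelled as follows. This is a sketch under stated assumptions, not the body of fused_rope_thd_find_seq_id: it assumes local sequence offsets are cu_seqlens[i] // cp_size and that t_id is a valid local token index. The function name find_seq_id is hypothetical.

```python
def find_seq_id(cu_seqlens, t_id, cp_size=1):
    """Return b such that cu_seqlens[b]//cp_size <= t_id < cu_seqlens[b+1]//cp_size."""
    lo, hi = 0, len(cu_seqlens) - 1  # hi starts at nseq
    # Invariant: offset(lo) <= t_id < offset(hi); requires t_id < total local tokens.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if cu_seqlens[mid] // cp_size <= t_id:
            lo = mid  # sequence mid starts at or before t_id
        else:
            hi = mid  # sequence mid starts after t_id
    return lo


# Three sequences of lengths 4, 0, 4: the zero-length span owns no tokens.
cu = [0, 4, 4, 8]
print([find_seq_id(cu, t) for t in range(8)])
```

Taking the last sequence whose start is at or before t_id is what lets zero-length spans be skipped: a token at offset 4 resolves to sequence 2, not to the empty sequence 1.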

Confidence Score: 5/5

The change is safe to merge. The new kernels reuse existing device helpers without altering per-token math, the binary search correctly maps every packed token to its owning sequence (including zero-length spans), and the parity test verifies bitwise equality between old and new paths across all parametrize combinations.

The core CUDA change is tightly scoped: both new kernels delegate to the same fused_rope_block_forward/backward device functions with identical s_id and s_id_for_freqs computations as the original THD kernel. The dispatch heuristic is conservative and fully overridable. The parity test is comprehensive across dtypes, cp_size, interleaved, start_positions, and zero-length spans.

transformer_engine/common/fused_rope/fused_rope.cu deserves the closest look for the new kernel entry-points and the heuristic dispatcher.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/fused_rope/fused_rope.cu | Core CUDA change: adds token-linear forward/backward kernels with a binary-search helper and a heuristic dispatcher; math is identical to the existing THD path but launch topology is fixed. |
| tests/pytorch/test_fused_rope.py | Adds test_fused_rope_thd_token_linear_parity which bitwise-compares old and new forced paths; parametrization covers cp_size 1/2, interleaved/non, start_positions, and zero-length spans. Heuristic auto-select path is not covered by a regression test. |
| benchmarks/attention/benchmark_rope_thd_token_linear.py | New microbenchmark comparing old/new/heuristic paths across n_seqs sweep; lengths correctly aligned to 2×cp_size multiples and env var correctly scoped with a context manager. |
| benchmarks/attention/benchmark_rope_thd_full_layer.py | New full-layer benchmark (TransformerLayer fwd+bwd) measuring speedup and RoPE share under old/new/heuristic regimes; well-structured, gracefully skips matplotlib when not installed. |
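
Scoping the override with a context manager, as the benchmark overview mentions, might look like the sketch below. The helper name token_linear_mode is hypothetical and is not taken from the benchmark scripts; only the environment variable name comes from this PR. Whether the library re-reads the variable on every call is not stated in this thread and is assumed here.

```python
import os
from contextlib import contextmanager

ENV_VAR = "NVTE_FUSED_ROPE_THD_TOKEN_LINEAR"


@contextmanager
def token_linear_mode(value):
    """Temporarily force the path ('0' = old, '1' = new) or pass None for the heuristic."""
    previous = os.environ.get(ENV_VAR)
    try:
        if value is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = value
        yield
    finally:
        # Restore whatever was set before entering, even if the body raised.
        if previous is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = previous


with token_linear_mode("1"):
    print(os.environ[ENV_VAR])  # forced new path inside the block
```

Restoring the previous value in the finally clause keeps one benchmark regime from leaking into the next.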

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fused_rope_forward / fused_rope_backward] --> B{qkv_format == THD?}
    B -- No --> C[SBHD/BSHD path\ndim3 blocks = s,b]
    B -- Yes --> D[read total_tokens\nfrom input.shape 0]
    D --> E[fused_rope_thd_use_token_linear\nb, s, total_tokens]
    E --> F{env override?}
    F -- 0 forced --> G[Old kernel\ndim3 blocks = s,b\nmany dead blocks]
    F -- 1 forced --> H[New token-linear kernel\ndim3 blocks = total_tokens]
    F -- unset --> I{b >= 64 AND\ns*b >= 8*total_tokens?}
    I -- No --> G
    I -- Yes --> H
    H --> J[fused_rope_thd_find_seq_id\nbinary search over cu_seqlens]
    J --> K[compute s_id, cur_seqlens,\ns_id_for_freqs + CP offset]
    K --> L[fused_rope_block_forward\nor fused_rope_block_backward]
    G --> M[original THD path:\nif t_id >= end return\nthen fused_rope_block_forward/backward]


Comment on lines +250 to +251
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
Contributor

P2 Redundant binary search across all threads in the block

Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
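
The read counts quoted above can be checked with a small model. It assumes one cu_seqlens read per binary-search step and 256 threads per block (8 warps of 32); both helper names are hypothetical.

```python
def search_iterations(nseq: int) -> int:
    """ceil(log2(nseq)) for nseq >= 1, computed with integer arithmetic."""
    return (nseq - 1).bit_length()


def cu_seqlens_reads_per_block(nseq: int, threads_per_block: int, search_once: bool) -> int:
    """Reads issued per block, assuming one read per search step per searching thread."""
    searchers = 1 if search_once else threads_per_block
    return search_iterations(nseq) * searchers


print(cu_seqlens_reads_per_block(2401, 256, search_once=False))  # every thread searches
print(cu_seqlens_reads_per_block(2401, 256, search_once=True))   # thread 0 searches, then broadcasts
```

This counts issued reads only; it says nothing about how many of them are served from cache, so it is an upper bound on the traffic saved rather than a timing prediction.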


Author

Smart bot!

Comment on lines +250 to +255
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int s_id = t_id - start;
int cur_seqlens = end - start;
Contributor

P2 No guard for t_id exceeding valid cu_seqlens range

The old kernel explicitly filters dead blocks with if (t_id >= end) return; before any computation. The new kernel does not: it trusts that blockIdx.x < cu_seqlens[nseq]/cp_size because total_tokens is read from input.data.shape[0]. If a caller passes a tensor with shape[0] larger than cu_seqlens[-1]/cp_size, the binary search lands on b_id = nseq-1, computes s_id = t_id - start >= cur_seqlens, and fused_rope_block_forward indexes freqs at an out-of-range s_id_for_freqs. Adding if (t_id >= (int)(cu_seqlens[nseq] / cp_size)) return; after the binary search would restore the safety property the old kernel had.
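
The failure mode and the proposed guard can be illustrated with a host-side Python model of the index computation. This is a sketch, not the kernel: token_coords is a hypothetical name, bisect stands in for the device binary search, and the clamp to nseq - 1 models the search landing on the last sequence for an oversized t_id.

```python
from bisect import bisect_right


def token_coords(cu_seqlens, t_id, cp_size=1, guard=True):
    """Return (b_id, s_id, cur_seqlens) for a local token, or None if the guard fires."""
    local = [c // cp_size for c in cu_seqlens]  # local (per-rank) sequence offsets
    nseq = len(local) - 1
    # Last sequence starting at or before t_id, clamped to the final sequence.
    b_id = min(bisect_right(local, t_id) - 1, nseq - 1)
    if guard and t_id >= local[nseq]:
        return None  # dead block: t_id is past the last packed token
    start, end = local[b_id], local[b_id + 1]
    return b_id, t_id - start, end - start


cu = [0, 3, 3, 7]  # lengths 3, 0, 4
print(token_coords(cu, 3))               # first token of the third sequence
print(token_coords(cu, 7))               # guarded: no token 7 exists
print(token_coords(cu, 7, guard=False))  # unguarded: s_id equals cur_seqlens, out of range
```

In the unguarded case s_id reaches cur_seqlens, which is the out-of-range position the comment above warns about.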

Member

ptrendx commented May 28, 2026

@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
Nice improvement :-).

@sudhakarsingh27 Could you take a look?

plugyawn and others added 3 commits May 29, 2026 03:23
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci

Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
plugyawn force-pushed the rope-thd-token-linear branch from 331a3a0 to 6c46696 on May 28, 2026 21:55
Author

plugyawn commented May 28, 2026

Thanks! Signed!

fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements.



Development

Successfully merging this pull request may close these issues.

[Performance] Fused RoPE THD kernel becomes dominant bottleneck in long-context training with many packed sequences
