[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057
[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057plugyawn wants to merge 3 commits into
Conversation
Greptile SummaryThis PR fixes a launch-scaling bug in the THD fused RoPE path where the original kernel launched freqs_len × n_seqs CUDA blocks (most of which were dead blocks that early-exited), instead of only the total_tokens actually present in the packed batch. Two new CUDA kernels dispatch exactly one block per packed local token row and locate the owning sequence via binary search over cu_seqlens, reusing the existing fused_rope_block_forward/backward device helpers so the per-token math is unchanged.
Confidence Score: 5/5The change is safe to merge. The new kernels reuse existing device helpers without altering per-token math, the binary search correctly maps every packed token to its owning sequence (including zero-length spans), and the parity test verifies bitwise equality between old and new paths across all parametrize combinations. The core CUDA change is tightly scoped: both new kernels delegate to the same fused_rope_block_forward/backward device functions with identical s_id and s_id_for_freqs computations as the original THD kernel. The dispatch heuristic is conservative and fully overridable. The parity test is comprehensive across dtypes, cp_size, interleaved, start_positions, and zero-length spans. transformer_engine/common/fused_rope/fused_rope.cu deserves the closest look for the new kernel entry-points and the heuristic dispatcher. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fused_rope_forward / fused_rope_backward] --> B{qkv_format == THD?}
B -- No --> C[SBHD/BSHD path\ndim3 blocks = s,b]
B -- Yes --> D[read total_tokens\nfrom input.shape 0]
D --> E[fused_rope_thd_use_token_linear\nb, s, total_tokens]
E --> F{env override?}
F -- 0 forced --> G[Old kernel\ndim3 blocks = s,b\nmany dead blocks]
F -- 1 forced --> H[New token-linear kernel\ndim3 blocks = total_tokens]
F -- unset --> I{b >= 64 AND\ns*b >= 8*total_tokens?}
I -- No --> G
I -- Yes --> H
H --> J[fused_rope_thd_find_seq_id\nbinary search over cu_seqlens]
J --> K[compute s_id, cur_seqlens,\ns_id_for_freqs + CP offset]
K --> L[fused_rope_block_forward\nor fused_rope_block_backward]
G --> M[original THD path:\nif t_id >= end return\nthen fused_rope_block_forward/backward]
Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| int t_id = blockIdx.x; | ||
| int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size); |
There was a problem hiding this comment.
Redundant binary search across all threads in the block
Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| int t_id = blockIdx.x; | ||
| int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size); | ||
| int start = cu_seqlens[b_id] / cp_size; | ||
| int end = cu_seqlens[b_id + 1] / cp_size; | ||
| int s_id = t_id - start; | ||
| int cur_seqlens = end - start; |
There was a problem hiding this comment.
No guard for
t_id exceeding valid cu_seqlens range
The old kernel explicitly filters dead blocks with if (t_id >= end) return; before any computation. The new kernel does not: it trusts that blockIdx.x < cu_seqlens[nseq]/cp_size because total_tokens is read from input.data.shape[0]. If a caller passes a tensor with shape[0] larger than cu_seqlens[-1]/cp_size, the binary search lands on b_id = nseq-1, computes s_id = t_id - start >= cur_seqlens, and fused_rope_block_forward indexes freqs at an out-of-range s_id_for_freqs. Adding if (t_id >= (int)(cu_seqlens[nseq] / cp_size)) return; after the binary search would restore the safety property the old kernel had.
|
@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work @sudhakarsingh27 Could you take a look? |
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
331a3a0 to
6c46696
Compare
|
Thanks! Signed! fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements. |
Description
Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.
Addresses #2866, which finds an interesting case with RoPE scales by freqs_len × n_spans, which is pathological; it should scale by total tokens. I reproduced the issue and found that it's causing a noticeable drops on even plausibly routine shapes. For eg: the [128/512] and [512/128] cases here.
The new kernel reuses the existing
fused_rope_block_forwardandfused_rope_block_backwarddevice helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one bloc/packed token.This is mostly pathological, however, so I've added a condition on the dispatch to avoid the unnecessary binary search overhead, although the overhead appears to be not-that-relevant. The condition is: token-linear only when
b >= 64and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this the usual shape of TE updates, so I could remove it!Some more relevant tests:
Microbenchmark on H100 (bf16,
h=32,d=d2=128,freqs_len=T_local=65536, single GPU):Fixes: #2866.
Type of change
Changes
Please list the changes introduced in this PR:
NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.fused_rope_block_forwardandfused_rope_block_backwarddevice helpers.Checklist: