[Fix] TE RMSNorm Triton Kernel Optimization #615
Based on the shared benchmarks, it looks like there is actually a regression in the Triton kernels across some shapes, e.g.
def get_num_sms(sm_margin=None):
    num_sms = torch.cuda.get_device_properties(
Use get_sm_count from transformer_engine.pytorch.utils
Claude Walkthrough
Intent. Restores RMSNorm Triton throughput on narrow-hidden Qwen3 shapes (notably 32768×128 QK-norm) where the previous one-program-per-row scheme bottlenecked on launch overhead and redundant gamma reloads, causing a ~0.1× perf cliff. The fix packs multiple rows into each program and tightens the autotune search space.
Key changes.
Walkthrough.
Testing. No new tests are added in this PR. The author notes that existing unit tests pass locally; correctness for the packed path relies on the heuristic only enabling packing when
Notes for reviewers.
Generated by Claude. To request a code review, comment `/claude review`.
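For illustration only, here is a minimal sketch of how a dispatcher could choose a packing factor so that it always divides the row count, which is the invariant the review note on `n_chunks` attributes to `_rows_per_pid`. The function name `pick_rows_per_pid`, the candidate list, and the `max_packed_cols` threshold are hypothetical and are not taken from the PR.

```python
def pick_rows_per_pid(n_rows, n_cols, candidates=(8, 4, 2), max_packed_cols=256):
    """Return a rows-per-program factor that divides n_rows exactly, or 1.

    Hypothetical heuristic: pack rows only for narrow hidden sizes, and only
    with a factor that leaves no partial chunk at the end of the tensor.
    """
    if n_cols > max_packed_cols:
        return 1  # wide rows: keep one row per program
    for rows in candidates:  # prefer the largest factor that divides n_rows
        if n_rows % rows == 0:
            return rows
    return 1  # no candidate divides n_rows: fall back to the unpacked path


print(pick_rows_per_pid(32768, 128))   # narrow QK-norm shape from the walkthrough
print(pick_rows_per_pid(32767, 128))   # odd row count, so packing is disabled
```

Under these assumptions the 32768×128 shape gets packed, while any row count that none of the candidates divides falls back to one row per program.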
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
if (ZERO_CENTERED_GAMMA):
    g += 1
n_chunks = tl.cdiv(n_rows, ROWS_PER_PID)
Nit: n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) rounds up, and the inner tl.static_range(ROWS_PER_PID) writes row_idx = base_row + i to rsigma_ptr + row_idx, output_ptrs, etc. with only a column mask. Correctness depends entirely on the dispatcher in norms_common.py returning ROWS_PER_PID > 1 only when n_rows % ROWS_PER_PID == 0 (which _rows_per_pid does enforce today).
Since the kernel default is ROWS_PER_PID: tl.constexpr = 1, a future caller passing a non-divisor value would silently corrupt memory. Worth a short comment near n_chunks documenting the divisibility invariant (or, more defensively, a row_idx < n_rows mask on the inner stores).
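The hazard can be reproduced without a GPU. The sketch below is a plain-Python emulation, not the kernel itself: it assumes a single sequential loop over chunks (the real kernel's program scheduling may differ), and the helper names `cdiv` and `rows_touched` are made up for this example. It lists the row indices the packed loop would store to, with and without a `row_idx < n_rows` guard.

```python
def cdiv(a, b):
    """Ceiling division, mirroring what tl.cdiv computes."""
    return (a + b - 1) // b


def rows_touched(n_rows, rows_per_pid, guard):
    """Emulate which row indices the packed loop would write.

    guard=True models folding `row_idx < n_rows` into the store mask.
    """
    touched = []
    n_chunks = cdiv(n_rows, rows_per_pid)      # rounds up, as in the kernel
    for chunk in range(n_chunks):
        base_row = chunk * rows_per_pid
        for i in range(rows_per_pid):          # stands in for tl.static_range
            row_idx = base_row + i
            if guard and row_idx >= n_rows:
                continue                       # masked store: nothing written
            touched.append(row_idx)
    return touched


unguarded = rows_touched(n_rows=6, rows_per_pid=4, guard=False)
guarded = rows_touched(n_rows=6, rows_per_pid=4, guard=True)
print(max(unguarded) >= 6)   # True: rows 6 and 7 lie past the end of the tensor
print(guarded)               # only rows 0 through 5
```

When `n_rows` is a multiple of `rows_per_pid`, both variants touch exactly the same rows, which is why the current dispatcher is safe; the guard only matters for a caller that breaks the divisibility invariant.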
Reviewed the RMSNorm Triton optimization for Qwen3 QK-norm shapes. The row-packing approach is reasonable for narrow-H, but the fwd dispatch in
Copyright headers: OK (all 3 files carry AMD copyrights ending in 2026).
Description
Optimize the RMSNorm Triton kernel for Qwen3 shapes.
Fixes # https://github.com/ROCm/frameworks-internal/issues/16614
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: