Skip to content

[Fix] TE RMSNorm Triton Kernel Optimization#615

Open
AllenFarcas wants to merge 3 commits into
devfrom
alfarcas/rmsnorm-optim
Open

[Fix] TE RMSNorm Triton Kernel Optimization#615
AllenFarcas wants to merge 3 commits into
devfrom
alfarcas/rmsnorm-optim

Conversation

@AllenFarcas

@AllenFarcas AllenFarcas commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Description

Optimize RMSNorm triton kernel for Qwen3 shapes.

Fixes # https://github.com/ROCm/frameworks-internal/issues/16614

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pack multiple rows per program (ROWS_PER_PID, up to 16 for H≤128) instead of one program per row — slashes program-launch count for narrow-H shapes.
  • Apply the same packing to the backward kernel (capped at 4 for H≤128), targeting the Qwen3 32768×128 QK-norm gradient pass.
  • Eliminate redundant gamma reloads and launch overhead on narrow rows, the root cause of the 0.1× cliff on Qwen3 QK shapes.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774

Micky774 commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Based on the shared benchmarks it looks like there's actually a regression in the triton kernels across some shapes e.g.

Qwen3-235B/hidden, 8192, 4096, bfloat16, fwd: 4211.3 GB/s --> 3573.2 GB/s similarly for the bwd of the same config. Can you run a general benchmark comparing the updated triton implementation to the current one on dev so we can see what the tradeoffs are? Maybe we can optimize heuristics a bit.



def get_num_sms(sm_margin=None):
num_sms = torch.cuda.get_device_properties(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use get_sm_count from transformer_engine.pytorch.utils

@github-actions

github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown

Claude Walkthrough

Intent. Restores RMSNorm Triton throughput on narrow-hidden Qwen3 shapes (notably 32768×128 QK-norm) where the previous one-program-per-row scheme bottlenecked on launch overhead and redundant gamma reloads, causing a ~0.1× perf cliff. The fix packs multiple rows into each program and tightens the autotune search space.

Key changes.

  • Introduces a _rows_per_pid heuristic that picks the largest power-of-two row-packing factor that fits the hidden-size cap, divides the row count evenly, and still leaves enough work for the SMs (transformer_engine/pytorch/triton_kernels/norms_common.py:128).
  • Plumbs ROWS_PER_PID and a NEEDS_I64_OFFSETS flag into both forward and backward dispatch, with separate fwd/bwd cap ladders (fwd up to 16 at H≤128, bwd up to 4) — fwd at norms_common.py:247, bwd at norms_common.py:413.
  • Rewrites the RMSNorm forward and backward kernel bodies into a chunked outer loop with a tl.static_range(ROWS_PER_PID) inner loop, hoists the gamma load only when BLOCK_SIZE <= 512, and switches row offsets to int64 only when total offset would overflow int32 (transformer_engine/pytorch/triton_kernels/rmsnorm.py:154, rmsnorm.py:336).
  • Expands the autotune grid (num_warps now includes 1 and 2) and adds an _prune_rms_configs early pruner that bounds threads-per-row to [n_cols/16, n_cols*2]; the backward kernel now uses the same prune as forward (rmsnorm.py:9, rmsnorm.py:18, rmsnorm.py:382).
  • Caches per-device SM count in get_num_sms to avoid a driver round-trip on every dispatch (transformer_engine/pytorch/triton_kernels/utils.py:36).

Walkthrough.

norms_common.py — Adds _rows_per_pid plus two cap ladders (_FWD_CAP_LADDER, _BWD_CAP_LADDER) that encode the empirical observation that very narrow rows (H≤128) benefit from aggressive packing while wider rows do not. The heuristic targets _TARGET_PRGMS_PER_CU * num_sms programs and only packs when packing divides rows evenly, so the kernel needs no row-tail mask. In _te_norm_fwd_triton, packing is gated to RMSNorm + non-blocked + non-FP8 — LayerNorm and FP8 paths keep the original N-program or num_programs() launch. Grid size becomes N // ROWS_PER_PID_FWD. The new NEEDS_I64_OFFSETS kwarg is computed from max(stride) * N + H against 1<<31, so the kernel stays on the fast int32 path unless it would actually overflow. te_rmsnorm_bwd_triton mirrors this, additionally clamping NUM_PRGMS to _MAX_PRGMS_PER_CU * base_prgms and sizing the dg_tmp partial-reduction buffer to NUM_PRGMS rather than M, which falls out naturally from one-partial-row-per-program.

rmsnorm.py — Both kernel impls gain ROWS_PER_PID and NEEDS_I64_OFFSETS constexpr params (defaulted, so existing call sites stay valid). The persistent outer loop now iterates n_chunks = cdiv(n_rows, ROWS_PER_PID) strides and uses tl.static_range over the packed rows so Triton fully unrolls the inner work — that's where the gamma reuse and launch-overhead savings come from. Gamma hoisting is conditional: at BLOCK_SIZE > 512 the register pressure makes it better to reload inside the loop, so the old behavior is preserved there. Row offsets are cast to int64 only when the precomputed flag says they must. The autotune config product gains num_warps ∈ {1, 2} for the narrow shapes where 4+ warps leave lanes idle, and _prune_rms_configs discards any config whose num_warps * 64 falls outside [n_cols/16, n_cols*2] before the autotuner ever measures it — keeping tuning time bounded after the grid expansion. The backward autotuner is updated to use the same pruner.

utils.pyget_num_sms is now called on every RMSNorm dispatch (twice per forward/backward); the SM count is constant per device, so the value is memoized in _NUM_SMS_CACHE keyed by device index, with sm_margin still applied per call.

Testing. No new tests are added in this PR. The author notes that existing unit tests pass locally; correctness for the packed path relies on the heuristic only enabling packing when rows % ROWS_PER_PID == 0, which removes the need for a row-tail mask. Reviewers may want to confirm existing RMSNorm tests cover the new paths (packed fwd, packed bwd, BLOCK_SIZE > 512 gamma-reload branch, i64-offset branch).

Notes for reviewers.

  • Packing is intentionally disabled for LayerNorm, blocked-norm, and FP8 RMSNorm. The FP8 amax / scale-inv store in the forward kernel still hinges on row_start == 0, which is preserved.
  • The NEEDS_I64_OFFSETS gate is computed from strides on the host; if a caller passed a tensor whose strides understated the addressing extent, the int32 path could silently wrap. The current dispatch derives it from the actual tensors being launched, so this should be safe in-tree.
  • dg_tmp_rows = NUM_PRGMS (instead of M) is a memory win on tall-and-narrow shapes but depends on each program writing exactly one partial row at offset row_start; that invariant is held by the single tl.store(dg_ptr + row_start * n_cols + ...) after the chunk loop.
  • Autotune grid expansion (12 → 20 configs) is partly offset by the early pruner, but first-time autotune cost for shapes outside the cache may go up.
  • get_num_sms cache is keyed by device index and never invalidated; fine in practice but worth noting if GPUs are hot-swapped in a process.

Generated by Claude. To request a code review, comment `/claude review`.

Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
if (ZERO_CENTERED_GAMMA):
g += 1
n_chunks = tl.cdiv(n_rows, ROWS_PER_PID)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) rounds up, and the inner tl.static_range(ROWS_PER_PID) writes row_idx = base_row + i to rsigma_ptr + row_idx, output_ptrs, etc. with only a column mask. Correctness depends entirely on the dispatcher in norms_common.py returning ROWS_PER_PID > 1 only when n_rows % ROWS_PER_PID == 0 (which _rows_per_pid does enforce today).

Since the kernel default is ROWS_PER_PID: tl.constexpr = 1, a future caller passing a non-divisor value would silently corrupt memory. Worth a short comment near n_chunks documenting the divisibility invariant (or, more defensively, a row_idx < n_rows mask on the inner stores).

Comment thread transformer_engine/pytorch/triton_kernels/norms_common.py Outdated
@github-actions

github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown

Reviewed the RMSNorm Triton optimization for Qwen3 QK-norm shapes. The row-packing approach is reasonable for narrow-H, but the fwd dispatch in norms_common.py may be the source of the wide-H regression already flagged: when ROWS_PER_PID_FWD == 1 (any H > 512, non-FP8), NUM_PRGMS is set to N (one program per row), dropping the prior persistent-grid behavior. The bwd path correctly guards on > 1; fwd should mirror it. Other findings are nits (divisibility invariant should be commented in the kernel; trailing whitespace on three constants). The _NUM_SMS_CACHE duplication is already covered by @ipanfilo's inline comment, not duplicated here.

Copyright headers: OK (all 3 files carry AMD copyrights ending in 2026).

@AllenFarcas AllenFarcas requested a review from ipanfilo June 10, 2026 04:10
@AllenFarcas

Copy link
Copy Markdown
Contributor Author

@Micky774 I updated the benchmarks run and the results obtained here and solved all regressions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants