From 6b6872f02b60b2a8de10ecc5be4dfaaf59aa3902 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Mon, 8 Jun 2026 14:40:35 +0000 Subject: [PATCH 1/7] [Fix] Optimized RMSNorm for Qwen3 shapes --- .../pytorch/triton_kernels/norms_common.py | 80 ++++++++- .../pytorch/triton_kernels/rmsnorm.py | 152 ++++++++++++------ .../pytorch/triton_kernels/utils.py | 18 ++- 3 files changed, 192 insertions(+), 58 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 3c770ec8f..ecd71cc0c 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -15,7 +15,7 @@ te_dtype_to_triton_dtype, ) from ..quantized_tensor import Quantizer -from .utils import num_programs, block_size, use_blocked, make_ln_out +from .utils import num_programs, block_size, use_blocked, make_ln_out, get_num_sms from .common import get_fp8_max from .rmsnorm import ( _rmsnorm_fwd_triton, @@ -122,6 +122,42 @@ def _get_fp8_transpose_configs(): True: _rmsnorm_bwd_dg_reduce_triton, False: _rmsnorm_bwd_dg_reduce_triton_impl, } + + +_TARGET_PRGMS_PER_CU = 8 +_MAX_PRGMS_PER_CU = 16 +_I32_OFFSET_LIMIT = 1 << 31 + +_FWD_CAP_LADDER = ((128, 16), (256, 4), (512, 2)) +_BWD_CAP_LADDER = ((128, 4), (512, 2)) + + +def _rows_per_pid(rows, hidden, num_sms, cap_ladder): + """Largest power-of-two rows-per-program for a narrow-H RMSNorm launch. + + Picks the biggest ``rpp`` that (a) does not exceed the hidden-size cap, + (b) evenly divides ``rows`` -- so the kernel needs no row-tail mask -- and + (c) still leaves at least ``_TARGET_PRGMS_PER_CU * num_sms`` programs. If + the program-count target cannot be met we still pack a little (4 or 2) to + amortise the gamma load, provided divisibility holds. + """ + target = _TARGET_PRGMS_PER_CU * num_sms + cap = 1 + for threshold, c in cap_ladder: + if hidden <= threshold: + cap = c + break + rpp = 1 + while rpp * 2 <= cap and rows % (rpp * 2) == 0 and rows // (rpp * 2) >= target: + rpp *= 2 + if rpp == 1: + if hidden <= 128 and rows % 4 == 0: + rpp = 4 + elif hidden <= 512 and rows % 2 == 0: + rpp = 2 + return rpp + + # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_rmsnorm_fwd_triton( input: torch.Tensor, @@ -208,7 +244,22 @@ def _te_norm_fwd_triton( IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) BLOCK_SIZE = block_size(input_tensor, norm=kernel) USE_BLOCKED = use_blocked(input_tensor) - NUM_PRGMS = N if kernel=='layer' else num_programs(input_tensor, sm_margin) + + # Row-packing (RMSNorm, non-blocked, non-FP8 only) + ROWS_PER_PID_FWD = 1 + if kernel == 'rms' and not USE_BLOCKED and not IS_FP8: + ROWS_PER_PID_FWD = _rows_per_pid(N, H, get_num_sms(sm_margin), _FWD_CAP_LADDER) + + if kernel == 'layer': + # LayerNorm requires one program per row on both paths. + NUM_PRGMS = N + elif kernel == 'rms' and not USE_BLOCKED: + if IS_FP8: + NUM_PRGMS = num_programs(input_tensor, sm_margin) + else: + NUM_PRGMS = N // ROWS_PER_PID_FWD + else: + NUM_PRGMS = num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False APPLY_ATOMIC = N < 512 or kernel == 'rms' ATOMIC_REDUCTION_BLOCK_SIZE=256 @@ -301,6 +352,9 @@ def _te_norm_fwd_triton( out_ptr.data_ptr() % 16 == 0 and output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0 ) + kwargs["ROWS_PER_PID"] = ROWS_PER_PID_FWD + max_off = max(input_row_stride, output_row_stride) * N + H + kwargs["NEEDS_I64_OFFSETS"] = max_off >= _I32_OFFSET_LIMIT kernel_func[grid_fwd](**kwargs) @@ -358,9 +412,23 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, M, N = x_.shape blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) - NUM_PRGMS = num_programs(x_, sm_margin) + + # Multi-row-per-program for narrow N + ROWS_PER_PID_BWD = 1 + if not USE_BLOCKED: + ROWS_PER_PID_BWD = _rows_per_pid(M, N, get_num_sms(sm_margin), _BWD_CAP_LADDER) + + base_prgms = num_programs(x_, sm_margin) + if USE_BLOCKED: + NUM_PRGMS = base_prgms + dg_tmp_rows = M + elif ROWS_PER_PID_BWD > 1: + NUM_PRGMS = min(M // ROWS_PER_PID_BWD, _MAX_PRGMS_PER_CU * base_prgms) + dg_tmp_rows = NUM_PRGMS + else: + NUM_PRGMS = base_prgms + dg_tmp_rows = NUM_PRGMS need_reduction = NUM_PRGMS > 1 - dg_tmp_rows = M if USE_BLOCKED else NUM_PRGMS dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) @@ -370,6 +438,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) grid_bwd = lambda meta: (NUM_PRGMS, ) + # See forward: gate i64 offsets on actual overflow to stay on the fast i32 path. + needs_i64_offsets_bwd = max(x_.stride(0), dz_.stride(0), dx.stride(0)) * M + N >= _I32_OFFSET_LIMIT bwd_kernel = _rmsnorm_bwd_kernels[autotune] bwd_kwargs = dict( n_rows=M, n_cols=N, @@ -380,6 +450,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, DX_ALIGNED_16=dx_aligned_16, DG_ALIGNED_16=dg_aligned_16, + ROWS_PER_PID=ROWS_PER_PID_BWD, + NEEDS_I64_OFFSETS=needs_i64_offsets_bwd, ) if not autotune: bwd_kwargs["num_warps"] = 8 diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index fc82b7f9d..7bb351120 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -7,7 +7,31 @@ from itertools import product def get_autotune_config(): - return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] + return [triton.Config({'waves_per_eu': we}, num_warps=nw) + for (we, nw) in product([0, 1, 2, 4], [1, 2, 4, 8, 16])] + + +def _prune_rms_configs(configs, named_args, **kwargs): + """Prune autotune configs whose ``num_warps`` is a poor fit for ``n_cols``. + + Bounds threads-per-row to ``[n_cols/16, n_cols*2]`` so the autotuner never + wastes a measurement on a config that is obviously too serial (too few + lanes) or mostly idle (too many lanes). Shared by the fwd and bwd kernels. + """ + n_cols = named_args.get('n_cols') + if n_cols is None: + n_cols = kwargs.get('n_cols') + if n_cols is None: + return configs + out = [] + for cfg in configs: + threads = cfg.num_warps * 64 # AMD wavefront = 64 + if threads < n_cols / 16: # too serial + continue + if threads > n_cols * 2: # too many idle lanes + continue + out.append(cfg) + return out if out else configs # TODO(micky774) Implement fused MXFP8 quantization within the kernel @@ -32,6 +56,8 @@ def _rmsnorm_fwd_triton_impl( FP8_MAX: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, OUTPUT_ALIGNED_16: tl.constexpr, + ROWS_PER_PID: tl.constexpr = 1, + NEEDS_I64_OFFSETS: tl.constexpr = False, ): row_start = tl.program_id(0) @@ -128,35 +154,45 @@ def _rmsnorm_fwd_triton_impl( else: mask = col_offsets < n_cols - # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): - g += 1 inv_n_cols = 1.0 / n_cols - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): - input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - if INPUT_ALIGNED_16: - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) - row_norm = row * row - row_norm = tl.sum(row_norm, axis=-1) - norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) + if BLOCK_SIZE <= 512: + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) + for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): + base_row = chunk_idx * ROWS_PER_PID + for i in tl.static_range(ROWS_PER_PID): + row_idx = base_row + i + if NEEDS_I64_OFFSETS: + row_off = row_idx.to(tl.int64) + else: + row_off = row_idx + input_ptrs = input_ptr + row_off * input_row_stride + col_offsets + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + row_norm = tl.sum(row * row, axis=-1) + norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) - # Store rsigma (norm_factor) - rsigma_output_ptr = rsigma_ptr + row_idx - tl.store(rsigma_output_ptr, norm_factor) + # Store rsigma (norm_factor) + tl.store(rsigma_ptr + row_idx, norm_factor) - rms_norm = row * norm_factor * g + if BLOCK_SIZE > 512: + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + rms_norm = row * norm_factor * g - output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets - if OUTPUT_ALIGNED_16: - output_ptrs = tl.multiple_of(output_ptrs, (16, )) - if IS_FP8: - amax_temp = tl.max(tl.abs(rms_norm), axis=-1) - amax = tl.maximum(amax, amax_temp) - rms_norm = rms_norm * scale - rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) + output_ptrs = output_ptr + row_off * output_row_stride + col_offsets + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) + if IS_FP8: + amax_temp = tl.max(tl.abs(rms_norm), axis=-1) + amax = tl.maximum(amax, amax_temp) + rms_norm = rms_norm * scale + rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) + tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: @@ -164,7 +200,12 @@ def _rmsnorm_fwd_triton_impl( scale_inv = tl.fdiv(1.0, scale) tl.store(scale_inv_ptr, scale_inv) -autotune_dec = triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) +autotune_dec = triton.autotune( + configs=get_autotune_config(), + key=['n_rows', 'n_cols'], + prune_configs_by={'early_config_prune': _prune_rms_configs}, + use_cuda_graph=True, +) _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) @triton.jit @@ -172,7 +213,9 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, - DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr, + ROWS_PER_PID: tl.constexpr = 1, + NEEDS_I64_OFFSETS: tl.constexpr = False): row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) inv_n_cols = 1.0 / n_cols @@ -290,39 +333,50 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p if (ZERO_CENTERED_GAMMA): g += 1. - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): - input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets - dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets + # Persistent outer loop with multi-row-per-chunk. + n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) + for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): + base_row = chunk_idx * ROWS_PER_PID + for i in tl.static_range(ROWS_PER_PID): + row_idx = base_row + i + if NEEDS_I64_OFFSETS: + row_off = row_idx.to(tl.int64) + else: + row_off = row_idx + input_ptrs = input_ptr + row_off * input_row_stride + col_offsets + grad_output_ptrs = grad_output_ptr + row_off * output_row_stride + col_offsets + dx_ptrs = dx_ptr + row_off * input_row_stride + col_offsets - if INPUT_ALIGNED_16: - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - if GRAD_OUTPUT_ALIGNED_16: - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) - if DX_ALIGNED_16: - dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) - x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) - grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) + x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) - norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) - grad_sum = tl.sum(grad_output * x * g, axis=0) - c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols + norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + grad_sum = tl.sum(grad_output * x * g, axis=0) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols - grad_input = norm_factor * (grad_output * g - c_scalar * x) - tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) + grad_input = norm_factor * (grad_output * g - c_scalar * x) + tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) - dg = grad_output * x * norm_factor - dg_col_redux += dg.to(tl.float32) + dg = grad_output * x * norm_factor + dg_col_redux += dg.to(tl.float32) + # Each program owns exactly one dg_tmp partial row (index == row_start). tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) -# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle -# autotune via the same flag. +# Autotune wrapper. Mirrors the fwd autotune layout (same config set + prune) +# so callers can toggle autotune via the same flag. _rmsnorm_bwd_triton = triton.autotune( configs=get_autotune_config(), key=['n_rows', 'n_cols'], + prune_configs_by={'early_config_prune': _prune_rms_configs}, use_cuda_graph=True, )(_rmsnorm_bwd_triton_impl) diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 884fab5e3..3a610907f 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -33,13 +33,21 @@ def get_inf_ln_sm_margin(): return get_ln_sm_margin("INF") +# Per-device SM count cache. get_num_sms is called on every RMSNorm dispatch +# (forward and backward), and get_device_properties is a driver round-trip; +# the active device's SM count is constant, so cache it. +_NUM_SMS_CACHE: "dict[int, int]" = {} + + def get_num_sms(sm_margin=None): - num_sms = torch.cuda.get_device_properties( - torch.cuda.current_device() - ).multi_processor_count + dev = torch.cuda.current_device() + n = _NUM_SMS_CACHE.get(dev) + if n is None: + n = torch.cuda.get_device_properties(dev).multi_processor_count + _NUM_SMS_CACHE[dev] = n if sm_margin is not None and sm_margin > 0: - num_sms = max(num_sms - int(sm_margin), 1) - return num_sms + n = max(n - int(sm_margin), 1) + return n def num_programs(x, sm_margin=None): From 3a2b7b1efa9bc982fd0bf48165340c890e415f5b Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Wed, 10 Jun 2026 04:03:26 +0000 Subject: [PATCH 2/7] [Fix] Improved regressions and overall performance --- .../pytorch/triton_kernels/norms_common.py | 14 +++++++++++--- .../pytorch/triton_kernels/rmsnorm.py | 14 ++++++++------ transformer_engine/pytorch/triton_kernels/utils.py | 13 ++----------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index ecd71cc0c..380e45204 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -128,6 +128,8 @@ def _get_fp8_transpose_configs(): _MAX_PRGMS_PER_CU = 16 _I32_OFFSET_LIMIT = 1 << 31 +_FWD_FLAT_GRID_MAX_H = 2048 + _FWD_CAP_LADDER = ((128, 16), (256, 4), (512, 2)) _BWD_CAP_LADDER = ((128, 4), (512, 2)) @@ -250,14 +252,19 @@ def _te_norm_fwd_triton( if kernel == 'rms' and not USE_BLOCKED and not IS_FP8: ROWS_PER_PID_FWD = _rows_per_pid(N, H, get_num_sms(sm_margin), _FWD_CAP_LADDER) + # HOIST_GAMMA: load gamma once per program when the program processes more than one row + HOIST_GAMMA_FWD = True if kernel == 'layer': # LayerNorm requires one program per row on both paths. NUM_PRGMS = N elif kernel == 'rms' and not USE_BLOCKED: - if IS_FP8: - NUM_PRGMS = num_programs(input_tensor, sm_margin) - else: + use_flat = (not IS_FP8) and (H <= _FWD_FLAT_GRID_MAX_H) + if use_flat: NUM_PRGMS = N // ROWS_PER_PID_FWD + HOIST_GAMMA_FWD = ROWS_PER_PID_FWD > 1 + else: + NUM_PRGMS = num_programs(input_tensor, sm_margin) + HOIST_GAMMA_FWD = True # persistent loop reuses gamma across rows else: NUM_PRGMS = num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False @@ -353,6 +360,7 @@ def _te_norm_fwd_triton( output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0 ) kwargs["ROWS_PER_PID"] = ROWS_PER_PID_FWD + kwargs["HOIST_GAMMA"] = HOIST_GAMMA_FWD max_off = max(input_row_stride, output_row_stride) * N + H kwargs["NEEDS_I64_OFFSETS"] = max_off >= _I32_OFFSET_LIMIT diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 7bb351120..cdf2a8858 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -4,11 +4,9 @@ import torch import triton import triton.language as tl -from itertools import product def get_autotune_config(): - return [triton.Config({'waves_per_eu': we}, num_warps=nw) - for (we, nw) in product([0, 1, 2, 4], [1, 2, 4, 8, 16])] + return [triton.Config({'waves_per_eu': 0}, num_warps=nw) for nw in (1, 4, 8, 16)] def _prune_rms_configs(configs, named_args, **kwargs): @@ -58,6 +56,7 @@ def _rmsnorm_fwd_triton_impl( OUTPUT_ALIGNED_16: tl.constexpr, ROWS_PER_PID: tl.constexpr = 1, NEEDS_I64_OFFSETS: tl.constexpr = False, + HOIST_GAMMA: tl.constexpr = True, ): row_start = tl.program_id(0) @@ -155,7 +154,7 @@ def _rmsnorm_fwd_triton_impl( else: mask = col_offsets < n_cols inv_n_cols = 1.0 / n_cols - if BLOCK_SIZE <= 512: + if HOIST_GAMMA: g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1 @@ -171,14 +170,17 @@ def _rmsnorm_fwd_triton_impl( input_ptrs = input_ptr + row_off * input_row_stride + col_offsets if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) - row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + if HOIST_GAMMA and ROWS_PER_PID == 1: + row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + else: + row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) row_norm = tl.sum(row * row, axis=-1) norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) # Store rsigma (norm_factor) tl.store(rsigma_ptr + row_idx, norm_factor) - if BLOCK_SIZE > 512: + if not HOIST_GAMMA: g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) if (ZERO_CENTERED_GAMMA): g += 1 diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 3a610907f..d4e6f3fdf 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -7,6 +7,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.utils import get_sm_count from .common import te_dtype_to_torch_dtype def get_ln_sm_margin(sm_margin_type): @@ -33,18 +34,8 @@ def get_inf_ln_sm_margin(): return get_ln_sm_margin("INF") -# Per-device SM count cache. get_num_sms is called on every RMSNorm dispatch -# (forward and backward), and get_device_properties is a driver round-trip; -# the active device's SM count is constant, so cache it. -_NUM_SMS_CACHE: "dict[int, int]" = {} - - def get_num_sms(sm_margin=None): - dev = torch.cuda.current_device() - n = _NUM_SMS_CACHE.get(dev) - if n is None: - n = torch.cuda.get_device_properties(dev).multi_processor_count - _NUM_SMS_CACHE[dev] = n + n = get_sm_count() if sm_margin is not None and sm_margin > 0: n = max(n - int(sm_margin), 1) return n From f4f58662b0ac16294e4db883301affdb6afb8561 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Wed, 10 Jun 2026 04:09:54 +0000 Subject: [PATCH 3/7] [Nit] Removed trailing spaces --- transformer_engine/pytorch/triton_kernels/norms_common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 380e45204..c62bc5f13 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -124,9 +124,9 @@ def _get_fp8_transpose_configs(): } -_TARGET_PRGMS_PER_CU = 8 -_MAX_PRGMS_PER_CU = 16 -_I32_OFFSET_LIMIT = 1 << 31 +_TARGET_PRGMS_PER_CU = 8 +_MAX_PRGMS_PER_CU = 16 +_I32_OFFSET_LIMIT = 1 << 31 _FWD_FLAT_GRID_MAX_H = 2048 From 8ab5b5bb38329a6b9f6408ceb19b5c095af8d0b0 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 12 Jun 2026 17:24:01 +0000 Subject: [PATCH 4/7] [Fix] Addressed ROWS_PER_PID divisibility --- transformer_engine/pytorch/triton_kernels/norms_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index c62bc5f13..d6aa5047e 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -157,6 +157,8 @@ def _rows_per_pid(rows, hidden, num_sms, cap_ladder): rpp = 4 elif hidden <= 512 and rows % 2 == 0: rpp = 2 + # The kernel packs rpp rows per program with no row-tail mask, so rpp must divide rows + assert rows % rpp == 0, f"rows_per_pid={rpp} does not divide rows={rows}" return rpp From 36cd5693be1522ee6612938075346335a7281495 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 12 Jun 2026 19:14:15 +0000 Subject: [PATCH 5/7] [Fix] Added mask --- .../pytorch/triton_kernels/norms_common.py | 11 +++++------ .../pytorch/triton_kernels/rmsnorm.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index d6aa5047e..4ec41b512 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -138,10 +138,11 @@ def _rows_per_pid(rows, hidden, num_sms, cap_ladder): """Largest power-of-two rows-per-program for a narrow-H RMSNorm launch. Picks the biggest ``rpp`` that (a) does not exceed the hidden-size cap, - (b) evenly divides ``rows`` -- so the kernel needs no row-tail mask -- and - (c) still leaves at least ``_TARGET_PRGMS_PER_CU * num_sms`` programs. If - the program-count target cannot be met we still pack a little (4 or 2) to - amortise the gamma load, provided divisibility holds. + (b) evenly divides ``rows`` so the packed kernel wastes no masked tail + iterations, and (c) still leaves at least ``_TARGET_PRGMS_PER_CU * + num_sms`` programs. If the target cannot be met we still pack a little + (4 or 2) to amortise the gamma load. Divisibility is only a performance + choice: the kernel masks any partial tail, so any ``rpp`` stays in bounds. """ target = _TARGET_PRGMS_PER_CU * num_sms cap = 1 @@ -157,8 +158,6 @@ def _rows_per_pid(rows, hidden, num_sms, cap_ladder): rpp = 4 elif hidden <= 512 and rows % 2 == 0: rpp = 2 - # The kernel packs rpp rows per program with no row-tail mask, so rpp must divide rows - assert rows % rpp == 0, f"rows_per_pid={rpp} does not divide rows={rows}" return rpp diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index cdf2a8858..cd1b8e6c9 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -163,6 +163,9 @@ def _rmsnorm_fwd_triton_impl( base_row = chunk_idx * ROWS_PER_PID for i in tl.static_range(ROWS_PER_PID): row_idx = base_row + i + # Mask rows past n_rows: ROWS_PER_PID need not divide n_rows. + row_valid = row_idx < n_rows + row_mask = mask & row_valid if NEEDS_I64_OFFSETS: row_off = row_idx.to(tl.int64) else: @@ -171,14 +174,14 @@ def _rmsnorm_fwd_triton_impl( if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) if HOIST_GAMMA and ROWS_PER_PID == 1: - row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + row = tl.load(input_ptrs, mask=row_mask, other=0.0, cache_modifier=".cg").to(tl.float32) else: - row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + row = tl.load(input_ptrs, mask=row_mask, other=0.0).to(tl.float32) row_norm = tl.sum(row * row, axis=-1) norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) # Store rsigma (norm_factor) - tl.store(rsigma_ptr + row_idx, norm_factor) + tl.store(rsigma_ptr + row_idx, norm_factor, mask=row_valid) if not HOIST_GAMMA: g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) @@ -194,7 +197,7 @@ def _rmsnorm_fwd_triton_impl( amax = tl.maximum(amax, amax_temp) rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) + tl.store(output_ptrs, rms_norm.to(output_type), mask=row_mask) if IS_FP8: tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: @@ -341,6 +344,9 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p base_row = chunk_idx * ROWS_PER_PID for i in tl.static_range(ROWS_PER_PID): row_idx = base_row + i + # Mask rows past n_rows: ROWS_PER_PID need not divide n_rows. + row_valid = row_idx < n_rows + row_mask = mask & row_valid if NEEDS_I64_OFFSETS: row_off = row_idx.to(tl.int64) else: From d2e5d6044dbe706e639dcb4bd9b85196f853c85d Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 12 Jun 2026 19:21:23 +0000 Subject: [PATCH 6/7] [Fix] Added mask for bwd --- transformer_engine/pytorch/triton_kernels/rmsnorm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index cd1b8e6c9..c87116447 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -362,15 +362,15 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p if DX_ALIGNED_16: dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) - x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) - grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) + x = tl.load(input_ptrs, mask=row_mask, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptrs, mask=row_mask, other=0.0).to(tl.float32) - norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + norm_factor = tl.load(rsigma_ptr + row_idx, mask=row_valid, other=1.0).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols grad_input = norm_factor * (grad_output * g - c_scalar * x) - tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) + tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=row_mask) dg = grad_output * x * norm_factor dg_col_redux += dg.to(tl.float32) From 700dd9459ea77581a17a7089dfb975eb192d14d6 Mon Sep 17 00:00:00 2001 From: AllenFarcas Date: Fri, 12 Jun 2026 19:49:08 +0000 Subject: [PATCH 7/7] [Fix] Added mask for bwd --- .../pytorch/triton_kernels/rmsnorm.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index c87116447..0a64d60a3 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -165,7 +165,6 @@ def _rmsnorm_fwd_triton_impl( row_idx = base_row + i # Mask rows past n_rows: ROWS_PER_PID need not divide n_rows. row_valid = row_idx < n_rows - row_mask = mask & row_valid if NEEDS_I64_OFFSETS: row_off = row_idx.to(tl.int64) else: @@ -174,9 +173,9 @@ def _rmsnorm_fwd_triton_impl( if INPUT_ALIGNED_16: input_ptrs = tl.multiple_of(input_ptrs, (16, )) if HOIST_GAMMA and ROWS_PER_PID == 1: - row = tl.load(input_ptrs, mask=row_mask, other=0.0, cache_modifier=".cg").to(tl.float32) + row = tl.load(input_ptrs, mask=mask & row_valid, other=0.0, cache_modifier=".cg").to(tl.float32) else: - row = tl.load(input_ptrs, mask=row_mask, other=0.0).to(tl.float32) + row = tl.load(input_ptrs, mask=mask & row_valid, other=0.0).to(tl.float32) row_norm = tl.sum(row * row, axis=-1) norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) @@ -197,7 +196,7 @@ def _rmsnorm_fwd_triton_impl( amax = tl.maximum(amax, amax_temp) rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - tl.store(output_ptrs, rms_norm.to(output_type), mask=row_mask) + tl.store(output_ptrs, rms_norm.to(output_type), mask=mask & row_valid) if IS_FP8: tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: @@ -346,7 +345,6 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p row_idx = base_row + i # Mask rows past n_rows: ROWS_PER_PID need not divide n_rows. row_valid = row_idx < n_rows - row_mask = mask & row_valid if NEEDS_I64_OFFSETS: row_off = row_idx.to(tl.int64) else: @@ -362,15 +360,15 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p if DX_ALIGNED_16: dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) - x = tl.load(input_ptrs, mask=row_mask, other=0.0).to(tl.float32) - grad_output = tl.load(grad_output_ptrs, mask=row_mask, other=0.0).to(tl.float32) + x = tl.load(input_ptrs, mask=mask & row_valid, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptrs, mask=mask & row_valid, other=0.0).to(tl.float32) norm_factor = tl.load(rsigma_ptr + row_idx, mask=row_valid, other=1.0).to(tl.float32) grad_sum = tl.sum(grad_output * x * g, axis=0) c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols grad_input = norm_factor * (grad_output * g - c_scalar * x) - tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=row_mask) + tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask & row_valid) dg = grad_output * x * norm_factor dg_col_redux += dg.to(tl.float32)