diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 3c770ec8f..d6aa5047e 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -15,7 +15,7 @@ te_dtype_to_triton_dtype, ) from ..quantized_tensor import Quantizer -from .utils import num_programs, block_size, use_blocked, make_ln_out +from .utils import num_programs, block_size, use_blocked, make_ln_out, get_num_sms from .common import get_fp8_max from .rmsnorm import ( _rmsnorm_fwd_triton, @@ -122,6 +122,46 @@ def _get_fp8_transpose_configs(): True: _rmsnorm_bwd_dg_reduce_triton, False: _rmsnorm_bwd_dg_reduce_triton_impl, } + + +_TARGET_PRGMS_PER_CU = 8 +_MAX_PRGMS_PER_CU = 16 +_I32_OFFSET_LIMIT = 1 << 31 + +_FWD_FLAT_GRID_MAX_H = 2048 + +_FWD_CAP_LADDER = ((128, 16), (256, 4), (512, 2)) +_BWD_CAP_LADDER = ((128, 4), (512, 2)) + + +def _rows_per_pid(rows, hidden, num_sms, cap_ladder): + """Largest power-of-two rows-per-program for a narrow-H RMSNorm launch. + + Picks the biggest ``rpp`` that (a) does not exceed the hidden-size cap, + (b) evenly divides ``rows`` -- so the kernel needs no row-tail mask -- and + (c) still leaves at least ``_TARGET_PRGMS_PER_CU * num_sms`` programs. If + the program-count target cannot be met we still pack a little (4 or 2) to + amortise the gamma load, provided divisibility holds. + """ + target = _TARGET_PRGMS_PER_CU * num_sms + cap = 1 + for threshold, c in cap_ladder: + if hidden <= threshold: + cap = c + break + rpp = 1 + while rpp * 2 <= cap and rows % (rpp * 2) == 0 and rows // (rpp * 2) >= target: + rpp *= 2 + if rpp == 1: + if hidden <= 128 and rows % 4 == 0: + rpp = 4 + elif hidden <= 512 and rows % 2 == 0: + rpp = 2 + # The kernel packs rpp rows per program with no row-tail mask, so rpp must divide rows + assert rows % rpp == 0, f"rows_per_pid={rpp} does not divide rows={rows}" + return rpp + + # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_rmsnorm_fwd_triton( input: torch.Tensor, @@ -208,7 +248,27 @@ def _te_norm_fwd_triton( IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) BLOCK_SIZE = block_size(input_tensor, norm=kernel) USE_BLOCKED = use_blocked(input_tensor) - NUM_PRGMS = N if kernel=='layer' else num_programs(input_tensor, sm_margin) + + # Row-packing (RMSNorm, non-blocked, non-FP8 only) + ROWS_PER_PID_FWD = 1 + if kernel == 'rms' and not USE_BLOCKED and not IS_FP8: + ROWS_PER_PID_FWD = _rows_per_pid(N, H, get_num_sms(sm_margin), _FWD_CAP_LADDER) + + # HOIST_GAMMA: load gamma once per program when the program processes more than one row + HOIST_GAMMA_FWD = True + if kernel == 'layer': + # LayerNorm requires one program per row on both paths. + NUM_PRGMS = N + elif kernel == 'rms' and not USE_BLOCKED: + use_flat = (not IS_FP8) and (H <= _FWD_FLAT_GRID_MAX_H) + if use_flat: + NUM_PRGMS = N // ROWS_PER_PID_FWD + HOIST_GAMMA_FWD = ROWS_PER_PID_FWD > 1 + else: + NUM_PRGMS = num_programs(input_tensor, sm_margin) + HOIST_GAMMA_FWD = True # persistent loop reuses gamma across rows + else: + NUM_PRGMS = num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False APPLY_ATOMIC = N < 512 or kernel == 'rms' ATOMIC_REDUCTION_BLOCK_SIZE=256 @@ -301,6 +361,10 @@ def _te_norm_fwd_triton( out_ptr.data_ptr() % 16 == 0 and output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0 ) + kwargs["ROWS_PER_PID"] = ROWS_PER_PID_FWD + kwargs["HOIST_GAMMA"] = HOIST_GAMMA_FWD + max_off = max(input_row_stride, output_row_stride) * N + H + kwargs["NEEDS_I64_OFFSETS"] = max_off >= _I32_OFFSET_LIMIT kernel_func[grid_fwd](**kwargs) @@ -358,9 +422,23 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, M, N = x_.shape blk_size = block_size(x_) USE_BLOCKED = use_blocked(x_) - NUM_PRGMS = num_programs(x_, sm_margin) + + # Multi-row-per-program for narrow N + ROWS_PER_PID_BWD = 1 + if not USE_BLOCKED: + ROWS_PER_PID_BWD = _rows_per_pid(M, N, get_num_sms(sm_margin), _BWD_CAP_LADDER) + + base_prgms = num_programs(x_, sm_margin) + if USE_BLOCKED: + NUM_PRGMS = base_prgms + dg_tmp_rows = M + elif ROWS_PER_PID_BWD > 1: + NUM_PRGMS = min(M // ROWS_PER_PID_BWD, _MAX_PRGMS_PER_CU * base_prgms) + dg_tmp_rows = NUM_PRGMS + else: + NUM_PRGMS = base_prgms + dg_tmp_rows = NUM_PRGMS need_reduction = NUM_PRGMS > 1 - dg_tmp_rows = M if USE_BLOCKED else NUM_PRGMS dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) @@ -370,6 +448,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) grid_bwd = lambda meta: (NUM_PRGMS, ) + # See forward: gate i64 offsets on actual overflow to stay on the fast i32 path. + needs_i64_offsets_bwd = max(x_.stride(0), dz_.stride(0), dx.stride(0)) * M + N >= _I32_OFFSET_LIMIT bwd_kernel = _rmsnorm_bwd_kernels[autotune] bwd_kwargs = dict( n_rows=M, n_cols=N, @@ -380,6 +460,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma, GRAD_OUTPUT_ALIGNED_16=grad_output_aligned_16, DX_ALIGNED_16=dx_aligned_16, DG_ALIGNED_16=dg_aligned_16, + ROWS_PER_PID=ROWS_PER_PID_BWD, + NEEDS_I64_OFFSETS=needs_i64_offsets_bwd, ) if not autotune: bwd_kwargs["num_warps"] = 8 diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index fc82b7f9d..cdf2a8858 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -4,10 +4,32 @@ import torch import triton import triton.language as tl -from itertools import product def get_autotune_config(): - return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] + return [triton.Config({'waves_per_eu': 0}, num_warps=nw) for nw in (1, 4, 8, 16)] + + +def _prune_rms_configs(configs, named_args, **kwargs): + """Prune autotune configs whose ``num_warps`` is a poor fit for ``n_cols``. + + Bounds threads-per-row to ``[n_cols/16, n_cols*2]`` so the autotuner never + wastes a measurement on a config that is obviously too serial (too few + lanes) or mostly idle (too many lanes). Shared by the fwd and bwd kernels. + """ + n_cols = named_args.get('n_cols') + if n_cols is None: + n_cols = kwargs.get('n_cols') + if n_cols is None: + return configs + out = [] + for cfg in configs: + threads = cfg.num_warps * 64 # AMD wavefront = 64 + if threads < n_cols / 16: # too serial + continue + if threads > n_cols * 2: # too many idle lanes + continue + out.append(cfg) + return out if out else configs # TODO(micky774) Implement fused MXFP8 quantization within the kernel @@ -32,6 +54,9 @@ def _rmsnorm_fwd_triton_impl( FP8_MAX: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, OUTPUT_ALIGNED_16: tl.constexpr, + ROWS_PER_PID: tl.constexpr = 1, + NEEDS_I64_OFFSETS: tl.constexpr = False, + HOIST_GAMMA: tl.constexpr = True, ): row_start = tl.program_id(0) @@ -128,35 +153,48 @@ def _rmsnorm_fwd_triton_impl( else: mask = col_offsets < n_cols - # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. - g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): - g += 1 inv_n_cols = 1.0 / n_cols - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): - input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - if INPUT_ALIGNED_16: - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) - row_norm = row * row - row_norm = tl.sum(row_norm, axis=-1) - norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) - - # Store rsigma (norm_factor) - rsigma_output_ptr = rsigma_ptr + row_idx - tl.store(rsigma_output_ptr, norm_factor) - - rms_norm = row * norm_factor * g - - output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets - if OUTPUT_ALIGNED_16: - output_ptrs = tl.multiple_of(output_ptrs, (16, )) - if IS_FP8: - amax_temp = tl.max(tl.abs(rms_norm), axis=-1) - amax = tl.maximum(amax, amax_temp) - rms_norm = rms_norm * scale - rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) - tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) + if HOIST_GAMMA: + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) + for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): + base_row = chunk_idx * ROWS_PER_PID + for i in tl.static_range(ROWS_PER_PID): + row_idx = base_row + i + if NEEDS_I64_OFFSETS: + row_off = row_idx.to(tl.int64) + else: + row_off = row_idx + input_ptrs = input_ptr + row_off * input_row_stride + col_offsets + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if HOIST_GAMMA and ROWS_PER_PID == 1: + row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) + else: + row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + row_norm = tl.sum(row * row, axis=-1) + norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) + + # Store rsigma (norm_factor) + tl.store(rsigma_ptr + row_idx, norm_factor) + + if not HOIST_GAMMA: + g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) + if (ZERO_CENTERED_GAMMA): + g += 1 + rms_norm = row * norm_factor * g + + output_ptrs = output_ptr + row_off * output_row_stride + col_offsets + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) + if IS_FP8: + amax_temp = tl.max(tl.abs(rms_norm), axis=-1) + amax = tl.maximum(amax, amax_temp) + rms_norm = rms_norm * scale + rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) + tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: @@ -164,7 +202,12 @@ def _rmsnorm_fwd_triton_impl( scale_inv = tl.fdiv(1.0, scale) tl.store(scale_inv_ptr, scale_inv) -autotune_dec = triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) +autotune_dec = triton.autotune( + configs=get_autotune_config(), + key=['n_rows', 'n_cols'], + prune_configs_by={'early_config_prune': _prune_rms_configs}, + use_cuda_graph=True, +) _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) @triton.jit @@ -172,7 +215,9 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, - DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr, + ROWS_PER_PID: tl.constexpr = 1, + NEEDS_I64_OFFSETS: tl.constexpr = False): row_start = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) inv_n_cols = 1.0 / n_cols @@ -290,39 +335,50 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p if (ZERO_CENTERED_GAMMA): g += 1. - for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): - input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets - dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets + # Persistent outer loop with multi-row-per-chunk. + n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) + for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): + base_row = chunk_idx * ROWS_PER_PID + for i in tl.static_range(ROWS_PER_PID): + row_idx = base_row + i + if NEEDS_I64_OFFSETS: + row_off = row_idx.to(tl.int64) + else: + row_off = row_idx + input_ptrs = input_ptr + row_off * input_row_stride + col_offsets + grad_output_ptrs = grad_output_ptr + row_off * output_row_stride + col_offsets + dx_ptrs = dx_ptr + row_off * input_row_stride + col_offsets - if INPUT_ALIGNED_16: - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - if GRAD_OUTPUT_ALIGNED_16: - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) - if DX_ALIGNED_16: - dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) - x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) - grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) + x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) + grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) - norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) - grad_sum = tl.sum(grad_output * x * g, axis=0) - c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols + norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) + grad_sum = tl.sum(grad_output * x * g, axis=0) + c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols - grad_input = norm_factor * (grad_output * g - c_scalar * x) - tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) + grad_input = norm_factor * (grad_output * g - c_scalar * x) + tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) - dg = grad_output * x * norm_factor - dg_col_redux += dg.to(tl.float32) + dg = grad_output * x * norm_factor + dg_col_redux += dg.to(tl.float32) + # Each program owns exactly one dg_tmp partial row (index == row_start). tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) -# Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle -# autotune via the same flag. +# Autotune wrapper. Mirrors the fwd autotune layout (same config set + prune) +# so callers can toggle autotune via the same flag. _rmsnorm_bwd_triton = triton.autotune( configs=get_autotune_config(), key=['n_rows', 'n_cols'], + prune_configs_by={'early_config_prune': _prune_rms_configs}, use_cuda_graph=True, )(_rmsnorm_bwd_triton_impl) diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 884fab5e3..d4e6f3fdf 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -7,6 +7,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.utils import get_sm_count from .common import te_dtype_to_torch_dtype def get_ln_sm_margin(sm_margin_type): @@ -34,12 +35,10 @@ def get_inf_ln_sm_margin(): def get_num_sms(sm_margin=None): - num_sms = torch.cuda.get_device_properties( - torch.cuda.current_device() - ).multi_processor_count + n = get_sm_count() if sm_margin is not None and sm_margin > 0: - num_sms = max(num_sms - int(sm_margin), 1) - return num_sms + n = max(n - int(sm_margin), 1) + return n def num_programs(x, sm_margin=None):