-
Notifications
You must be signed in to change notification settings - Fork 31
[Fix] TE RMSNorm Triton Kernel Optimization #615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,10 +4,32 @@ | |
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
| from itertools import product | ||
|
|
||
| def get_autotune_config(): | ||
| return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] | ||
| return [triton.Config({'waves_per_eu': 0}, num_warps=nw) for nw in (1, 4, 8, 16)] | ||
|
|
||
|
|
||
| def _prune_rms_configs(configs, named_args, **kwargs): | ||
| """Prune autotune configs whose ``num_warps`` is a poor fit for ``n_cols``. | ||
|
|
||
| Bounds threads-per-row to ``[n_cols/16, n_cols*2]`` so the autotuner never | ||
| wastes a measurement on a config that is obviously too serial (too few | ||
| lanes) or mostly idle (too many lanes). Shared by the fwd and bwd kernels. | ||
| """ | ||
| n_cols = named_args.get('n_cols') | ||
| if n_cols is None: | ||
| n_cols = kwargs.get('n_cols') | ||
| if n_cols is None: | ||
| return configs | ||
| out = [] | ||
| for cfg in configs: | ||
| threads = cfg.num_warps * 64 # AMD wavefront = 64 | ||
| if threads < n_cols / 16: # too serial | ||
| continue | ||
| if threads > n_cols * 2: # too many idle lanes | ||
| continue | ||
| out.append(cfg) | ||
| return out if out else configs | ||
|
|
||
|
|
||
| # TODO(micky774) Implement fused MXFP8 quantization within the kernel | ||
|
|
@@ -32,6 +54,9 @@ def _rmsnorm_fwd_triton_impl( | |
| FP8_MAX: tl.constexpr, | ||
| INPUT_ALIGNED_16: tl.constexpr, | ||
| OUTPUT_ALIGNED_16: tl.constexpr, | ||
| ROWS_PER_PID: tl.constexpr = 1, | ||
| NEEDS_I64_OFFSETS: tl.constexpr = False, | ||
| HOIST_GAMMA: tl.constexpr = True, | ||
| ): | ||
|
|
||
| row_start = tl.program_id(0) | ||
|
|
@@ -128,51 +153,71 @@ def _rmsnorm_fwd_triton_impl( | |
|
|
||
| else: | ||
| mask = col_offsets < n_cols | ||
| # gamma is invariant across rows -- load + ZERO_CENTERED adjustment once per program. | ||
| g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) | ||
| if (ZERO_CENTERED_GAMMA): | ||
| g += 1 | ||
| inv_n_cols = 1.0 / n_cols | ||
| for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): | ||
| input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets | ||
| if INPUT_ALIGNED_16: | ||
| input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
| row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) | ||
| row_norm = row * row | ||
| row_norm = tl.sum(row_norm, axis=-1) | ||
| norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) | ||
|
|
||
| # Store rsigma (norm_factor) | ||
| rsigma_output_ptr = rsigma_ptr + row_idx | ||
| tl.store(rsigma_output_ptr, norm_factor) | ||
|
|
||
| rms_norm = row * norm_factor * g | ||
|
|
||
| output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets | ||
| if OUTPUT_ALIGNED_16: | ||
| output_ptrs = tl.multiple_of(output_ptrs, (16, )) | ||
| if IS_FP8: | ||
| amax_temp = tl.max(tl.abs(rms_norm), axis=-1) | ||
| amax = tl.maximum(amax, amax_temp) | ||
| rms_norm = rms_norm * scale | ||
| rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) | ||
| tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) | ||
| if HOIST_GAMMA: | ||
| g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) | ||
| if (ZERO_CENTERED_GAMMA): | ||
| g += 1 | ||
| n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Since the kernel default is |
||
| for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): | ||
| base_row = chunk_idx * ROWS_PER_PID | ||
| for i in tl.static_range(ROWS_PER_PID): | ||
| row_idx = base_row + i | ||
| if NEEDS_I64_OFFSETS: | ||
| row_off = row_idx.to(tl.int64) | ||
| else: | ||
| row_off = row_idx | ||
| input_ptrs = input_ptr + row_off * input_row_stride + col_offsets | ||
| if INPUT_ALIGNED_16: | ||
| input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
| if HOIST_GAMMA and ROWS_PER_PID == 1: | ||
| row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) | ||
| else: | ||
| row = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| row_norm = tl.sum(row * row, axis=-1) | ||
| norm_factor = tl.math.rsqrt(row_norm * inv_n_cols + epsilon) | ||
|
|
||
| # Store rsigma (norm_factor) | ||
| tl.store(rsigma_ptr + row_idx, norm_factor) | ||
|
|
||
| if not HOIST_GAMMA: | ||
| g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) | ||
| if (ZERO_CENTERED_GAMMA): | ||
| g += 1 | ||
| rms_norm = row * norm_factor * g | ||
|
|
||
| output_ptrs = output_ptr + row_off * output_row_stride + col_offsets | ||
| if OUTPUT_ALIGNED_16: | ||
| output_ptrs = tl.multiple_of(output_ptrs, (16, )) | ||
| if IS_FP8: | ||
| amax_temp = tl.max(tl.abs(rms_norm), axis=-1) | ||
| amax = tl.maximum(amax, amax_temp) | ||
| rms_norm = rms_norm * scale | ||
| rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) | ||
| tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) | ||
| if IS_FP8: | ||
| tl.atomic_max(q_amax_ptr, amax, sem="relaxed") | ||
| if row_start == 0: | ||
| scale = tl.load(q_scale_ptr) | ||
| scale_inv = tl.fdiv(1.0, scale) | ||
| tl.store(scale_inv_ptr, scale_inv) | ||
|
|
||
| autotune_dec = triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) | ||
| autotune_dec = triton.autotune( | ||
| configs=get_autotune_config(), | ||
| key=['n_rows', 'n_cols'], | ||
| prune_configs_by={'early_config_prune': _prune_rms_configs}, | ||
| use_cuda_graph=True, | ||
| ) | ||
| _rmsnorm_fwd_triton = autotune_dec(_rmsnorm_fwd_triton_impl) | ||
|
|
||
| @triton.jit | ||
| def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, | ||
| n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, | ||
| USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, | ||
| INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, | ||
| DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): | ||
| DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr, | ||
| ROWS_PER_PID: tl.constexpr = 1, | ||
| NEEDS_I64_OFFSETS: tl.constexpr = False): | ||
| row_start = tl.program_id(0) | ||
| col_offsets = tl.arange(0, BLOCK_SIZE) | ||
| inv_n_cols = 1.0 / n_cols | ||
|
|
@@ -290,39 +335,50 @@ def _rmsnorm_bwd_triton_impl(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_p | |
| if (ZERO_CENTERED_GAMMA): | ||
| g += 1. | ||
|
|
||
| for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): | ||
| input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets | ||
| grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets | ||
| dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets | ||
| # Persistent outer loop with multi-row-per-chunk. | ||
| n_chunks = tl.cdiv(n_rows, ROWS_PER_PID) | ||
| for chunk_idx in tl.range(row_start, n_chunks, NUM_PRGMS, num_stages=2): | ||
| base_row = chunk_idx * ROWS_PER_PID | ||
| for i in tl.static_range(ROWS_PER_PID): | ||
| row_idx = base_row + i | ||
| if NEEDS_I64_OFFSETS: | ||
| row_off = row_idx.to(tl.int64) | ||
| else: | ||
| row_off = row_idx | ||
| input_ptrs = input_ptr + row_off * input_row_stride + col_offsets | ||
| grad_output_ptrs = grad_output_ptr + row_off * output_row_stride + col_offsets | ||
| dx_ptrs = dx_ptr + row_off * input_row_stride + col_offsets | ||
|
|
||
| if INPUT_ALIGNED_16: | ||
| input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
| if GRAD_OUTPUT_ALIGNED_16: | ||
| grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) | ||
| if DX_ALIGNED_16: | ||
| dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) | ||
| if INPUT_ALIGNED_16: | ||
| input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
| if GRAD_OUTPUT_ALIGNED_16: | ||
| grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) | ||
| if DX_ALIGNED_16: | ||
| dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) | ||
|
|
||
| x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
|
|
||
| norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) | ||
| grad_sum = tl.sum(grad_output * x * g, axis=0) | ||
| c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols | ||
| norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) | ||
| grad_sum = tl.sum(grad_output * x * g, axis=0) | ||
| c_scalar = norm_factor * norm_factor * grad_sum * inv_n_cols | ||
|
|
||
| grad_input = norm_factor * (grad_output * g - c_scalar * x) | ||
| tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) | ||
| grad_input = norm_factor * (grad_output * g - c_scalar * x) | ||
| tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) | ||
|
|
||
| dg = grad_output * x * norm_factor | ||
| dg_col_redux += dg.to(tl.float32) | ||
| dg = grad_output * x * norm_factor | ||
| dg_col_redux += dg.to(tl.float32) | ||
|
|
||
| # Each program owns exactly one dg_tmp partial row (index == row_start). | ||
| tl.store(dg_ptr + row_start * n_cols + col_offsets, dg_col_redux, mask=mask) | ||
|
|
||
|
|
||
| # Autotune wrapper. Mirrors the fwd autotune layout so callers can toggle | ||
| # autotune via the same flag. | ||
| # Autotune wrapper. Mirrors the fwd autotune layout (same config set + prune) | ||
| # so callers can toggle autotune via the same flag. | ||
| _rmsnorm_bwd_triton = triton.autotune( | ||
| configs=get_autotune_config(), | ||
| key=['n_rows', 'n_cols'], | ||
| prune_configs_by={'early_config_prune': _prune_rms_configs}, | ||
| use_cuda_graph=True, | ||
| )(_rmsnorm_bwd_triton_impl) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer | ||
| from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer | ||
| from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer | ||
| from transformer_engine.pytorch.utils import get_sm_count | ||
| from .common import te_dtype_to_torch_dtype | ||
|
|
||
| def get_ln_sm_margin(sm_margin_type): | ||
|
|
@@ -34,12 +35,10 @@ def get_inf_ln_sm_margin(): | |
|
|
||
|
|
||
| def get_num_sms(sm_margin=None): | ||
| num_sms = torch.cuda.get_device_properties( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use get_sm_count from transformer_engine.pytorch.utils |
||
| torch.cuda.current_device() | ||
| ).multi_processor_count | ||
| n = get_sm_count() | ||
| if sm_margin is not None and sm_margin > 0: | ||
| num_sms = max(num_sms - int(sm_margin), 1) | ||
| return num_sms | ||
| n = max(n - int(sm_margin), 1) | ||
| return n | ||
|
|
||
|
|
||
| def num_programs(x, sm_margin=None): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.